diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index e5d92de..e53e2a3 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -68,6 +68,14 @@ class ListPipelineJobsResponse: class PipelineRunsApiService_Sql: _PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" _DEFAULT_PAGE_SIZE: Final[int] = 10 + _SYSTEM_KEY_RESERVED_MSG = ( + "Annotation keys starting with " + f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use." + ) + + def _fail_if_changing_system_annotation(self, *, key: str) -> None: + if key.startswith(filter_query_sql.SYSTEM_KEY_PREFIX): + raise errors.ApiValidationError(self._SYSTEM_KEY_RESERVED_MSG) def create( self, @@ -105,6 +113,19 @@ def create( }, ) session.add(pipeline_run) + # Mirror created_by into the annotations table so it's searchable + # via filter_query like any other annotation. + if created_by is not None: + # Flush to populate pipeline_run.id (server-generated) before inserting the annotation FK. + # TODO: Use ORM relationship instead of explicit flush + manual FK assignment. 
+ session.flush() + session.add( + bts.PipelineRunAnnotation( + pipeline_run_id=pipeline_run.id, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value=created_by, + ) + ) session.commit() session.refresh(pipeline_run) @@ -295,6 +316,7 @@ def set_annotation( user_name: str | None = None, skip_user_check: bool = False, ): + self._fail_if_changing_system_annotation(key=key) pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") @@ -317,6 +339,7 @@ def delete_annotation( user_name: str | None = None, skip_user_check: bool = False, ): + self._fail_if_changing_system_annotation(key=key) pipeline_run = session.get(bts.PipelineRun, id) if not pipeline_run: raise errors.ItemNotFoundError(f"Pipeline run {id} not found.") diff --git a/cloud_pipelines_backend/database_ops.py b/cloud_pipelines_backend/database_ops.py index c39d1fc..92a36ea 100644 --- a/cloud_pipelines_backend/database_ops.py +++ b/cloud_pipelines_backend/database_ops.py @@ -1,6 +1,8 @@ import sqlalchemy +from sqlalchemy import orm from . import backend_types_sql as bts +from . 
import filter_query_sql def create_db_engine_and_migrate_db( @@ -83,3 +85,60 @@ def migrate_db(db_engine: sqlalchemy.Engine): if index.name == bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE: index.create(db_engine, checkfirst=True) break + + _backfill_pipeline_run_created_by_annotations(db_engine=db_engine) + + +def _is_pipeline_run_annotation_key_already_backfilled( + *, + session: orm.Session, + key: str, +) -> bool: + """Return True if at least one annotation with the given key exists.""" + return session.query( + sqlalchemy.exists( + sqlalchemy.select(sqlalchemy.literal(1)) + .select_from(bts.PipelineRunAnnotation) + .where( + bts.PipelineRunAnnotation.key == key, + ) + ) + ).scalar() + + +def _backfill_pipeline_run_created_by_annotations( + *, + db_engine: sqlalchemy.Engine, +) -> None: + """Copy pipeline_run.created_by into pipeline_run_annotation so + annotation-based search works for created_by. + + The check and insert share one session/transaction, which narrows (but + does not eliminate) the TOCTOU window between concurrent startup processes. + + Skips entirely if any created_by annotation key already exists (i.e. the + write-path is populating them, so the backfill has already run or is + no longer needed). 
+ """ + with orm.Session(db_engine) as session: + if _is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + ): + return + + stmt = sqlalchemy.insert(bts.PipelineRunAnnotation).from_select( + ["pipeline_run_id", "key", "value"], + sqlalchemy.select( + bts.PipelineRun.id, + sqlalchemy.literal( + filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + ), + bts.PipelineRun.created_by, + ).where( + bts.PipelineRun.created_by.isnot(None), + bts.PipelineRun.created_by != "", + ), + ) + session.execute(stmt) + session.commit() diff --git a/cloud_pipelines_backend/filter_query_models.py b/cloud_pipelines_backend/filter_query_models.py index 821afd2..b966323 100644 --- a/cloud_pipelines_backend/filter_query_models.py +++ b/cloud_pipelines_backend/filter_query_models.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from typing import Annotated import pydantic @@ -58,21 +59,45 @@ def _at_least_one_time_bound(self) -> TimeRange: # --- Predicate wrapper models (one field each) --- -class KeyExistsPredicate(_BaseModel): +class KeyPredicateBase(_BaseModel): + """Base for predicates that target an annotation key.""" + + @property + @abc.abstractmethod + def key(self) -> str: ... 
+ + +class KeyExistsPredicate(KeyPredicateBase): key_exists: KeyExists + @property + def key(self) -> str: + return self.key_exists.key + -class ValueContainsPredicate(_BaseModel): +class ValueContainsPredicate(KeyPredicateBase): value_contains: ValueContains + @property + def key(self) -> str: + return self.value_contains.key -class ValueInPredicate(_BaseModel): + +class ValueInPredicate(KeyPredicateBase): value_in: ValueIn + @property + def key(self) -> str: + return self.value_in.key + -class ValueEqualsPredicate(_BaseModel): +class ValueEqualsPredicate(KeyPredicateBase): value_equals: ValueEquals + @property + def key(self) -> str: + return self.value_equals.key + class TimeRangePredicate(_BaseModel): time_range: TimeRange diff --git a/cloud_pipelines_backend/filter_query_sql.py b/cloud_pipelines_backend/filter_query_sql.py index b4c6564..f29d1ee 100644 --- a/cloud_pipelines_backend/filter_query_sql.py +++ b/cloud_pipelines_backend/filter_query_sql.py @@ -1,5 +1,6 @@ import base64 import json +import enum from typing import Any, Final import sqlalchemy as sql @@ -8,6 +9,21 @@ from . import errors from . 
import filter_query_models +SYSTEM_KEY_PREFIX: Final[str] = "system/" + + +class PipelineRunAnnotationSystemKey(enum.StrEnum): + CREATED_BY = f"{SYSTEM_KEY_PREFIX}pipeline_run.created_by" + + +SYSTEM_KEY_SUPPORTED_PREDICATES: dict[PipelineRunAnnotationSystemKey, set[type]] = { + PipelineRunAnnotationSystemKey.CREATED_BY: { + filter_query_models.KeyExistsPredicate, + filter_query_models.ValueEqualsPredicate, + filter_query_models.ValueInPredicate, + }, +} + # --------------------------------------------------------------------------- # Page-token helpers # --------------------------------------------------------------------------- @@ -44,6 +60,78 @@ def _resolve_filter_value( return filter, filter_query, offset +# --------------------------------------------------------------------------- +# PipelineRunAnnotationSystemKey validation and resolution +# --------------------------------------------------------------------------- + + +def _check_predicate_allowed(*, predicate: filter_query_models.Predicate) -> None: + """Raise if a system key is used with an unsupported predicate type.""" + if not isinstance(predicate, filter_query_models.KeyPredicateBase): + return + key = predicate.key + + try: + system_key = PipelineRunAnnotationSystemKey(key) + except ValueError: + return + + supported = SYSTEM_KEY_SUPPORTED_PREDICATES.get(system_key, set()) + if type(predicate) not in supported: + raise errors.ApiValidationError( + f"Predicate {type(predicate).__name__} is not supported " + f"for system key {system_key!r}. 
" + f"Supported: {[t.__name__ for t in supported]}" + ) + + +def _resolve_system_key_value( + *, + key: str, + value: str, + current_user: str | None, +) -> str: + """Resolve special placeholder values for system keys.""" + if key == PipelineRunAnnotationSystemKey.CREATED_BY and value == "me": + return current_user if current_user is not None else "" + return value + + +def _maybe_resolve_system_values( + *, + predicate: filter_query_models.ValueEqualsPredicate, + current_user: str | None, +) -> filter_query_models.ValueEqualsPredicate: + """Resolve special values in a ValueEqualsPredicate.""" + key = predicate.value_equals.key + value = predicate.value_equals.value + resolved = _resolve_system_key_value( + key=key, + value=value, + current_user=current_user, + ) + if resolved != value: + return filter_query_models.ValueEqualsPredicate( + value_equals=filter_query_models.ValueEquals(key=key, value=resolved) + ) + return predicate + + +def _validate_and_resolve_predicate( + *, + predicate: filter_query_models.Predicate, + current_user: str | None, +) -> filter_query_models.Predicate: + """Validate system key support, then resolve special values.""" + _check_predicate_allowed(predicate=predicate) + if isinstance(predicate, filter_query_models.ValueEqualsPredicate): + return _maybe_resolve_system_values( + predicate=predicate, + current_user=current_user, + ) + return predicate + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- @@ -79,7 +167,12 @@ def build_list_filters( if filter_query_value: parsed = filter_query_models.FilterQuery.model_validate_json(filter_query_value) - where_clauses.append(filter_query_to_where_clause(filter_query=parsed)) + where_clauses.append( + filter_query_to_where_clause( + filter_query=parsed, + current_user=current_user, + ) + ) next_page_token = _encode_page_token( page_token_dict={ @@ -95,10 +188,13 @@ def 
build_list_filters( def filter_query_to_where_clause( *, filter_query: filter_query_models.FilterQuery, + current_user: str | None = None, ) -> sql.ColumnElement: predicates = filter_query.and_ or filter_query.or_ is_and = filter_query.and_ is not None - clauses = [_predicate_to_clause(predicate=p) for p in predicates] + clauses = [ + _predicate_to_clause(predicate=p, current_user=current_user) for p in predicates + ] return sql.and_(*clauses) if is_and else sql.or_(*clauses) @@ -163,17 +259,35 @@ def _build_filter_where_clauses( def _predicate_to_clause( *, - predicate, + predicate: filter_query_models.Predicate, + current_user: str | None = None, ) -> sql.ColumnElement: + predicate = _validate_and_resolve_predicate( + predicate=predicate, + current_user=current_user, + ) + match predicate: case filter_query_models.AndPredicate(): return sql.and_( - *[_predicate_to_clause(predicate=p) for p in predicate.and_] + *[ + _predicate_to_clause(predicate=p, current_user=current_user) + for p in predicate.and_ + ] ) case filter_query_models.OrPredicate(): - return sql.or_(*[_predicate_to_clause(predicate=p) for p in predicate.or_]) + return sql.or_( + *[ + _predicate_to_clause(predicate=p, current_user=current_user) + for p in predicate.or_ + ] + ) case filter_query_models.NotPredicate(): - return sql.not_(_predicate_to_clause(predicate=predicate.not_)) + return sql.not_( + _predicate_to_clause( + predicate=predicate.not_, current_user=current_user + ) + ) case filter_query_models.KeyExistsPredicate(): return _key_exists_to_clause(predicate=predicate) case filter_query_models.ValueEqualsPredicate(): diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py index a46607e..b240cd5 100644 --- a/tests/test_api_server_sql.py +++ b/tests/test_api_server_sql.py @@ -360,6 +360,29 @@ def test_create_without_created_by(self, session_factory, service): result = _create_run(session_factory, service, root_task=_make_task_spec()) assert result.created_by is None + def 
test_create_writes_created_by_annotation(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice@example.com", + ) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "alice@example.com" + ) + + def test_create_without_created_by_no_annotation(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + not in annotations + ) + class TestPipelineRunAnnotationCrud: def test_set_annotation(self, session_factory, service): @@ -379,7 +402,7 @@ def test_set_annotation(self, session_factory, service): ) with session_factory() as session: annotations = service.list_annotations(session=session, id=run.id) - assert annotations == {"team": "ml-ops"} + assert annotations["team"] == "ml-ops" def test_set_annotation_overwrites(self, session_factory, service): run = _create_run( @@ -406,7 +429,7 @@ def test_set_annotation_overwrites(self, session_factory, service): ) with session_factory() as session: annotations = service.list_annotations(session=session, id=run.id) - assert annotations == {"team": "new-value"} + assert annotations["team"] == "new-value" def test_delete_annotation(self, session_factory, service): run = _create_run( @@ -432,7 +455,7 @@ def test_delete_annotation(self, session_factory, service): ) with session_factory() as session: annotations = service.list_annotations(session=session, id=run.id) - assert annotations == {} + assert "team" not in annotations def test_list_annotations_empty(self, session_factory, service): run = _create_run(session_factory, service, root_task=_make_task_spec()) @@ -440,6 +463,43 @@ def 
test_list_annotations_empty(self, session_factory, service): annotations = service.list_annotations(session=session, id=run.id) assert annotations == {} + def test_set_annotation_rejects_system_key(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + with pytest.raises( + errors.ApiValidationError, match="reserved for system use" + ): + service.set_annotation( + session=session, + id=run.id, + key="system/pipeline_run.created_by", + value="hacker", + user_name="user1", + ) + + def test_delete_annotation_rejects_system_key(self, session_factory, service): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="user1", + ) + with session_factory() as session: + with pytest.raises( + errors.ApiValidationError, match="reserved for system use" + ): + service.delete_annotation( + session=session, + id=run.id, + key="system/pipeline_run.created_by", + user_name="user1", + ) + class TestFilterQueryApiWiring: def test_filter_query_validates_invalid_json(self, session_factory, service): @@ -833,3 +893,94 @@ def test_pagination_preserves_filter_query(self, session_factory, service): ) assert len(page2.pipeline_runs) == 2 assert page2.next_page_token is None + + def test_filter_query_created_by(self, session_factory, service): + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="bob", + ) + + fq = json.dumps( + { + "and": [ + { + "value_equals": { + "key": filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + "value": "alice", + } + } + ] + } + ) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].created_by == "alice" + + def 
test_filter_query_created_by_me(self, session_factory, service): + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="bob", + ) + + fq = json.dumps( + { + "and": [ + { + "value_equals": { + "key": filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + "value": "me", + } + } + ] + } + ) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + current_user="alice", + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].created_by == "alice" + + def test_filter_query_created_by_unsupported_predicate( + self, session_factory, service + ): + fq = json.dumps( + { + "and": [ + { + "value_contains": { + "key": filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + "value_substring": "al", + } + } + ] + } + ) + with session_factory() as session: + with pytest.raises(errors.ApiValidationError, match="not supported"): + service.list( + session=session, + filter_query=fq, + ) diff --git a/tests/test_database_ops.py b/tests/test_database_ops.py new file mode 100644 index 0000000..6183519 --- /dev/null +++ b/tests/test_database_ops.py @@ -0,0 +1,445 @@ +from typing import Any + +import pytest +import sqlalchemy +from sqlalchemy import orm + +from cloud_pipelines_backend import api_server_sql +from cloud_pipelines_backend import backend_types_sql as bts +from cloud_pipelines_backend import component_structures as structures +from cloud_pipelines_backend import database_ops +from cloud_pipelines_backend import filter_query_sql + + +def _make_task_spec( + *, + pipeline_name: str = "test-pipeline", +) -> structures.TaskSpec: + return structures.TaskSpec( + component_ref=structures.ComponentReference( + spec=structures.ComponentSpec( + name=pipeline_name, + implementation=structures.ContainerImplementation( + container=structures.ContainerSpec(image="test-image") + ), 
+ ) + ) + ) + + +@pytest.fixture() +def session_factory() -> orm.sessionmaker: + db_engine = database_ops.create_db_engine(database_uri="sqlite://") + bts._TableBase.metadata.create_all(db_engine) + return orm.sessionmaker(db_engine) + + +def _create_run( + session_factory: orm.sessionmaker, + service: api_server_sql.PipelineRunsApiService_Sql, + **kwargs, +) -> api_server_sql.PipelineRunResponse: + """Create a pipeline run using a fresh session (mirrors production per-request sessions).""" + with session_factory() as session: + return service.create(session, **kwargs) + + +def _get_index_names( + *, + engine: sqlalchemy.Engine, + table_name: str, +) -> set[str]: + inspector = sqlalchemy.inspect(engine) + return {idx["name"] for idx in inspector.get_indexes(table_name)} + + +class TestMigrateDb: + """Verify migrate_db creates indexes on a pre-existing DB missing them.""" + + @pytest.fixture() + def bare_engine(self) -> sqlalchemy.Engine: + """Create tables then drop all indexes so migrate_db can recreate them.""" + db_engine = database_ops.create_db_engine(database_uri="sqlite://") + bts._TableBase.metadata.create_all(db_engine) + + with db_engine.connect() as conn: + for table_name in ("execution_node", "pipeline_run_annotation"): + for idx_name in _get_index_names( + engine=db_engine, table_name=table_name + ): + conn.execute(sqlalchemy.text(f"DROP INDEX IF EXISTS {idx_name}")) + conn.commit() + return db_engine + + def test_creates_execution_node_cache_key_index( + self, + bare_engine: sqlalchemy.Engine, + ) -> None: + assert bts.ExecutionNode._IX_EXECUTION_NODE_CACHE_KEY not in _get_index_names( + engine=bare_engine, + table_name="execution_node", + ) + database_ops.migrate_db(db_engine=bare_engine) + assert bts.ExecutionNode._IX_EXECUTION_NODE_CACHE_KEY in _get_index_names( + engine=bare_engine, + table_name="execution_node", + ) + + def test_creates_annotation_run_id_key_value_index( + self, + bare_engine: sqlalchemy.Engine, + ) -> None: + assert ( + 
bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE + not in _get_index_names( + engine=bare_engine, table_name="pipeline_run_annotation" + ) + ) + database_ops.migrate_db(db_engine=bare_engine) + assert ( + bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE + in _get_index_names( + engine=bare_engine, table_name="pipeline_run_annotation" + ) + ) + + def test_idempotent( + self, + bare_engine: sqlalchemy.Engine, + ) -> None: + database_ops.migrate_db(db_engine=bare_engine) + indexes_after_first = { + "execution_node": _get_index_names( + engine=bare_engine, table_name="execution_node" + ), + "pipeline_run_annotation": _get_index_names( + engine=bare_engine, + table_name="pipeline_run_annotation", + ), + } + database_ops.migrate_db(db_engine=bare_engine) + indexes_after_second = { + "execution_node": _get_index_names( + engine=bare_engine, table_name="execution_node" + ), + "pipeline_run_annotation": _get_index_names( + engine=bare_engine, + table_name="pipeline_run_annotation", + ), + } + assert indexes_after_first == indexes_after_second + + +class TestIsAnnotationKeyAlreadyBackfilled: + def test_false_on_empty_db( + self, + session_factory: orm.sessionmaker, + ) -> None: + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + ) + is False + ) + + def test_false_with_unrelated_annotation( + self, + session_factory: orm.sessionmaker, + ) -> None: + service = api_server_sql.PipelineRunsApiService_Sql() + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + service.set_annotation(session=session, id=run.id, key="team", value="ml") + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + ) + is False + ) + + def 
test_true_when_key_exists( + self, + session_factory: orm.sessionmaker, + ) -> None: + service = api_server_sql.PipelineRunsApiService_Sql() + _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + ) + is True + ) + + def test_matches_exact_key( + self, + session_factory: orm.sessionmaker, + ) -> None: + """Only returns True for the exact key queried, not other keys.""" + service = api_server_sql.PipelineRunsApiService_Sql() + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + service.set_annotation(session=session, id=run.id, key="team", value="ml") + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key="team", + ) + is True + ) + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key="other_key", + ) + is False + ) + + def test_true_after_backfill( + self, + session_factory: orm.sessionmaker, + ) -> None: + """Create a run, delete its write-path annotation, then backfill.""" + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="bob", + ) + with session_factory() as session: + session.query(bts.PipelineRunAnnotation).filter_by( + pipeline_run_id=run.id, + key=key, + ).delete() + session.commit() + + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=key, + ) + is False + ) + 
database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + with session_factory() as session: + assert ( + database_ops._is_pipeline_run_annotation_key_already_backfilled( + session=session, + key=key, + ) + is True + ) + + +class TestCreatedByBackfill: + def test_backfill_populates_annotation_value( + self, + session_factory: orm.sessionmaker, + ) -> None: + """The INSERT path produces the correct annotation value.""" + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + + # Remove write-path annotation so the backfill INSERT actually runs + with session_factory() as session: + session.query(bts.PipelineRunAnnotation).filter_by( + pipeline_run_id=run.id, + key=key, + ).delete() + session.commit() + + with session_factory() as session: + assert key not in service.list_annotations(session=session, id=run.id) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert annotations[key] == "alice" + + def test_backfill_skips_empty_created_by( + self, + session_factory: orm.sessionmaker, + ) -> None: + """Runs with created_by='' are not backfilled: the explicit != '' filter excludes them.""" + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + + # Create a run then set created_by to empty string directly in DB + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + db_run = session.get(bts.PipelineRun, run.id) + db_run.created_by = "" + session.commit() + + with session_factory() as session: + assert key not in 
service.list_annotations(session=session, id=run.id) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert key not in annotations + + def test_backfill_idempotent( + self, + session_factory: orm.sessionmaker, + ) -> None: + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + annotations[filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY] + == "alice" + ) + + def test_backfill_skips_null_created_by( + self, + session_factory: orm.sessionmaker, + ) -> None: + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + + run = _create_run(session_factory, service, root_task=_make_task_spec()) + assert run.created_by is None + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + with session_factory() as session: + annotations = service.list_annotations(session=session, id=run.id) + assert ( + filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + not in annotations + ) + + def test_backfill_mixed_runs_and_repeated_backfills( + self, + session_factory: orm.sessionmaker, + ) -> None: + """Simulates a realistic sequence: create runs, backfill, create more runs, backfill again. 
+ Verifies all annotations are correct and no duplicates are created.""" + service = api_server_sql.PipelineRunsApiService_Sql() + engine = session_factory.kw["bind"] + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + + run_alice = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + run_no_user = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + ) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + run_bob = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="bob", + ) + run_alice2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + with session_factory() as session: + assert ( + service.list_annotations(session=session, id=run_alice.id)[key] + == "alice" + ) + assert key not in service.list_annotations( + session=session, id=run_no_user.id + ) + assert ( + service.list_annotations(session=session, id=run_bob.id)[key] == "bob" + ) + assert ( + service.list_annotations(session=session, id=run_alice2.id)[key] + == "alice" + ) + + def test_backfill_uses_single_session( + self, + session_factory: orm.sessionmaker, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Check and insert happen in the same session (single transaction).""" + service = api_server_sql.PipelineRunsApiService_Sql() + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="alice", + ) + key = filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + with session_factory() as session: + session.query(bts.PipelineRunAnnotation).filter_by( + pipeline_run_id=run.id, + key=key, + ).delete() + session.commit() + + engine = session_factory.kw["bind"] + session_count = 0 + 
_original_session_init = orm.Session.__init__ + + def _counting_init( + self: Any, + *args: Any, + **kwargs: Any, + ) -> None: + nonlocal session_count + session_count += 1 + _original_session_init(self, *args, **kwargs) + + monkeypatch.setattr(orm.Session, "__init__", _counting_init) + + database_ops._backfill_pipeline_run_created_by_annotations(db_engine=engine) + + assert session_count == 1 diff --git a/tests/test_filter_query_sql.py b/tests/test_filter_query_sql.py index 2e0bad8..cd1d763 100644 --- a/tests/test_filter_query_sql.py +++ b/tests/test_filter_query_sql.py @@ -389,3 +389,141 @@ def test_created_by_me_resolved_in_next_token(self): assert len(clauses) == 1 decoded = filter_query_sql._decode_page_token(page_token=next_token) assert decoded["filter"] == "created_by:bob@example.com" + + +class TestPipelineRunAnnotationSystemKeyValidation: + def test_check_predicate_allowed_skips_logical_operator(self): + pred = filter_query_models.AndPredicate( + **{ + "and": [ + filter_query_models.KeyExistsPredicate( + key_exists=filter_query_models.KeyExists(key="x") + ) + ] + } + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_supported(self): + pred = filter_query_models.ValueEqualsPredicate( + value_equals=filter_query_models.ValueEquals( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value="alice", + ) + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_unsupported(self): + pred = filter_query_models.ValueContainsPredicate( + value_contains=filter_query_models.ValueContains( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value_substring="al", + ) + ) + with pytest.raises(errors.ApiValidationError, match="not supported"): + filter_query_sql._check_predicate_allowed(predicate=pred) + + def test_check_predicate_allowed_non_system_key(self): + pred = filter_query_models.ValueContainsPredicate( + 
value_contains=filter_query_models.ValueContains( + key="team", value_substring="ml" + ) + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_key_exists_supported(self): + pred = filter_query_models.KeyExistsPredicate( + key_exists=filter_query_models.KeyExists( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY + ) + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_value_in_supported(self): + pred = filter_query_models.ValueInPredicate( + value_in=filter_query_models.ValueIn( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + values=["alice"], + ) + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_skips_time_range(self): + pred = filter_query_models.TimeRangePredicate( + time_range=filter_query_models.TimeRange( + key="created_at", start_time="2024-01-01T00:00:00Z" + ) + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_skips_or_predicate(self): + pred = filter_query_models.OrPredicate( + **{ + "or": [ + filter_query_models.KeyExistsPredicate( + key_exists=filter_query_models.KeyExists(key="x") + ) + ] + } + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_check_predicate_allowed_skips_not_predicate(self): + pred = filter_query_models.NotPredicate( + **{ + "not": filter_query_models.KeyExistsPredicate( + key_exists=filter_query_models.KeyExists(key="x") + ) + }, + ) + assert filter_query_sql._check_predicate_allowed(predicate=pred) is None + + def test_resolve_system_key_value_me(self): + result = filter_query_sql._resolve_system_key_value( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value="me", + current_user="alice@example.com", + ) + assert result == "alice@example.com" + + def 
test_resolve_system_key_value_me_no_user(self): + result = filter_query_sql._resolve_system_key_value( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value="me", + current_user=None, + ) + assert result == "" + + def test_resolve_system_key_value_passthrough(self): + result = filter_query_sql._resolve_system_key_value( + key="team", + value="me", + current_user="alice", + ) + assert result == "me" + + def test_maybe_resolve_system_values(self): + pred = filter_query_models.ValueEqualsPredicate( + value_equals=filter_query_models.ValueEquals( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value="me", + ) + ) + resolved = filter_query_sql._maybe_resolve_system_values( + predicate=pred, + current_user="bob@example.com", + ) + assert resolved.value_equals.value == "bob@example.com" + + def test_validate_and_resolve_predicate(self): + pred = filter_query_models.ValueEqualsPredicate( + value_equals=filter_query_models.ValueEquals( + key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY, + value="me", + ) + ) + resolved = filter_query_sql._validate_and_resolve_predicate( + predicate=pred, + current_user="charlie@example.com", + ) + assert resolved.value_equals.value == "charlie@example.com"