diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index ed9c7bc..e5d92de 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -1,7 +1,5 @@ -import base64 import dataclasses import datetime -import json import logging import typing from typing import Any, Final, Optional @@ -12,7 +10,7 @@ from . import backend_types_sql as bts from . import component_structures as structures from . import errors -from . import filter_query_models +from . import filter_query_sql if typing.TYPE_CHECKING: from cloud_pipelines.orchestration.storage_providers import ( @@ -69,8 +67,6 @@ class ListPipelineJobsResponse: class PipelineRunsApiService_Sql: _PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name" - _PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset" - _PAGE_TOKEN_FILTER_KEY: Final[str] = "filter" _DEFAULT_PAGE_SIZE: Final[int] = 10 def create( @@ -173,22 +169,12 @@ def list( include_pipeline_names: bool = False, include_execution_stats: bool = False, ) -> ListPipelineJobsResponse: - if filter and filter_query: - raise errors.ApiValidationError( - "Cannot use both 'filter' and 'filter_query'. Use one or the other." - ) - - if filter_query: - filter_query_models.FilterQuery.model_validate_json(filter_query) - raise NotImplementedError("filter_query is not yet implemented.") - - filter_value, offset = _resolve_filter_value( - filter=filter, - page_token=page_token, - ) - where_clauses, next_page_filter_value = _build_filter_where_clauses( - filter_value=filter_value, + where_clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=filter, + filter_query_value=filter_query, + page_token_value=page_token, current_user=current_user, + page_size=self._DEFAULT_PAGE_SIZE, ) pipeline_runs = list( @@ -200,14 +186,10 @@ def list( .limit(self._DEFAULT_PAGE_SIZE) ).all() ) - next_page_offset = offset + self._DEFAULT_PAGE_SIZE - next_page_token_dict = { - self._PAGE_TOKEN_OFFSET_KEY: next_page_offset, - self._PAGE_TOKEN_FILTER_KEY: next_page_filter_value, - } - next_page_token = _encode_page_token(next_page_token_dict) - if len(pipeline_runs) < self._DEFAULT_PAGE_SIZE: - next_page_token = None + + next_page_token = ( + next_token if len(pipeline_runs) >= self._DEFAULT_PAGE_SIZE else None + ) return ListPipelineJobsResponse( pipeline_runs=[ @@ -348,88 +330,6 @@ def delete_annotation( session.commit() -def _resolve_filter_value( - *, - filter: str | None, - page_token: str | None, -) -> tuple[str | None, int]: - """Decode page_token and return the effective (filter_value, offset). - - If a page_token is present, its stored filter takes precedence over the - raw filter parameter (the token carries the resolved filter forward across pages). - """ - page_token_dict = _decode_page_token(page_token) - offset = page_token_dict.get( - PipelineRunsApiService_Sql._PAGE_TOKEN_OFFSET_KEY, - 0, - ) - if page_token: - filter = page_token_dict.get( - PipelineRunsApiService_Sql._PAGE_TOKEN_FILTER_KEY, - None, - ) - return filter, offset - - -def _build_filter_where_clauses( - *, - filter_value: str | None, - current_user: str | None, -) -> tuple[list[sql.ColumnElement], str | None]: - """Parse a filter string into SQLAlchemy WHERE clauses. - - Returns (where_clauses, next_page_filter_value). The second value is the - filter string with shorthand values resolved (e.g. "created_by:me" becomes - "created_by:alice@example.com") so it can be embedded in the next page token. - """ - where_clauses: list[sql.ColumnElement] = [] - parsed_filter = _parse_filter(filter_value) if filter_value else {} - for key, value in parsed_filter.items(): - if key == "_text": - raise NotImplementedError("Text search is not implemented yet.") - elif key == "created_by": - if value == "me": - if current_user is None: - current_user = "" - value = current_user - # TODO: Maybe make this a bit more robust. - # We need to change the filter since it goes into the next_page_token. - filter_value = filter_value.replace( - "created_by:me", f"created_by:{current_user}" - ) - if value: - where_clauses.append(bts.PipelineRun.created_by == value) - else: - where_clauses.append(bts.PipelineRun.created_by == None) - else: - raise NotImplementedError(f"Unsupported filter {filter_value}.") - return where_clauses, filter_value - - -def _decode_page_token(page_token: str) -> dict[str, Any]: - return json.loads(base64.b64decode(page_token)) if page_token else {} - - -def _encode_page_token(page_token_dict: dict[str, Any]) -> str: - return (base64.b64encode(json.dumps(page_token_dict).encode("utf8"))).decode( - "utf-8" - ) - - -def _parse_filter(filter: str) -> dict[str, str]: - # TODO: Improve - parts = filter.strip().split() - parsed_filter = {} - for part in parts: - key, sep, value = part.partition(":") - if sep: - parsed_filter[key] = value - else: - parsed_filter.setdefault("_text", "") - parsed_filter["_text"] += part - return parsed_filter - - # ========== ExecutionNodeApiService_Sql diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index 04a8bf5..586a993 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -2,7 +2,7 @@ import datetime import enum import typing -from typing import Any +from typing import Any, Final import sqlalchemy as sql from sqlalchemy import orm @@ -294,6 +294,9 @@ class ExecutionToAncestorExecutionLink(_TableBase): # So we need to jump through extra hoops to make the relationship many-to-many again. class ExecutionNode(_TableBase): __tablename__ = "execution_node" + _IX_EXECUTION_NODE_CACHE_KEY: Final[str] = ( + "ix_execution_node_container_execution_cache_key" + ) id: orm.Mapped[IdType] = orm.mapped_column( primary_key=True, init=False, insert_default=generate_unique_id ) @@ -491,6 +494,9 @@ class ContainerExecution(_TableBase): class PipelineRunAnnotation(_TableBase): __tablename__ = "pipeline_run_annotation" + _IX_ANNOTATION_RUN_ID_KEY_VALUE: Final[str] = ( + "ix_pipeline_run_annotation_pipeline_run_id_key_value" + ) pipeline_run_id: orm.Mapped[IdType] = orm.mapped_column( sql.ForeignKey(PipelineRun.id), primary_key=True, @@ -500,6 +506,15 @@ class PipelineRunAnnotation(_TableBase): key: orm.Mapped[str] = orm.mapped_column(default=None, primary_key=True) value: orm.Mapped[str | None] = orm.mapped_column(default=None) + __table_args__ = ( + sql.Index( + _IX_ANNOTATION_RUN_ID_KEY_VALUE, + "pipeline_run_id", + "key", + "value", + ), + ) + class Secret(_TableBase): __tablename__ = "secret" diff --git a/cloud_pipelines_backend/component_structures.py b/cloud_pipelines_backend/component_structures.py index 1bc4550..73fe271 100644 --- a/cloud_pipelines_backend/component_structures.py +++ b/cloud_pipelines_backend/component_structures.py @@ -54,7 +54,6 @@ import pydantic.alias_generators from pydantic.dataclasses import dataclass as pydantic_dataclasses - # PrimitiveTypes = Union[str, int, float, bool] PrimitiveTypes = str diff --git a/cloud_pipelines_backend/database_ops.py b/cloud_pipelines_backend/database_ops.py index 3d94ed1..c39d1fc 100644 --- a/cloud_pipelines_backend/database_ops.py +++ b/cloud_pipelines_backend/database_ops.py @@ -75,5 +75,11 @@ def migrate_db(db_engine: sqlalchemy.Engine): # Or we need to avoid calling the Index constructor. for index in bts.ExecutionNode.__table__.indexes: - if index.name == "ix_execution_node_container_execution_cache_key": + if index.name == bts.ExecutionNode._IX_EXECUTION_NODE_CACHE_KEY: index.create(db_engine, checkfirst=True) + break + + for index in bts.PipelineRunAnnotation.__table__.indexes: + if index.name == bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE: + index.create(db_engine, checkfirst=True) + break diff --git a/cloud_pipelines_backend/filter_query_sql.py b/cloud_pipelines_backend/filter_query_sql.py new file mode 100644 index 0000000..b4c6564 --- /dev/null +++ b/cloud_pipelines_backend/filter_query_sql.py @@ -0,0 +1,246 @@ +import base64 +import json +from typing import Any, Final + +import sqlalchemy as sql + +from . import backend_types_sql as bts +from . import errors +from . import filter_query_models + +# --------------------------------------------------------------------------- +# Page-token helpers +# --------------------------------------------------------------------------- + +_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset" +_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter" +_PAGE_TOKEN_FILTER_QUERY_KEY: Final[str] = "filter_query" + + +def _encode_page_token(*, page_token_dict: dict[str, Any]) -> str: + return base64.b64encode(json.dumps(page_token_dict).encode("utf-8")).decode("utf-8") + + +def _decode_page_token(*, page_token: str | None) -> dict[str, Any]: + return json.loads(base64.b64decode(page_token)) if page_token else {} + + +def _resolve_filter_value( + *, + filter: str | None, + filter_query: str | None, + page_token: str | None, +) -> tuple[str | None, str | None, int]: + """Decode page_token and return the effective (filter_value, filter_query_value, offset). + + If a page_token is present, its stored values take precedence over the + raw parameters (the token carries resolved values forward across pages). + """ + page_token_dict = _decode_page_token(page_token=page_token) + offset = page_token_dict.get(_PAGE_TOKEN_OFFSET_KEY, 0) + if page_token: + filter = page_token_dict.get(_PAGE_TOKEN_FILTER_KEY) + filter_query = page_token_dict.get(_PAGE_TOKEN_FILTER_QUERY_KEY) + return filter, filter_query, offset + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def build_list_filters( + *, + filter_value: str | None, + filter_query_value: str | None, + page_token_value: str | None, + current_user: str | None, + page_size: int, +) -> tuple[list[sql.ColumnElement], int, str]: + """Resolve pagination token, legacy filter, and filter_query into WHERE clauses. + + Returns (where_clauses, offset, next_page_token_encoded). + """ + if filter_value and filter_query_value: + raise errors.ApiValidationError( + "Cannot use both 'filter' and 'filter_query'. Use one or the other." + ) + + filter_value, filter_query_value, offset = _resolve_filter_value( + filter=filter_value, + filter_query=filter_query_value, + page_token=page_token_value, + ) + + where_clauses, resolved_filter = _build_filter_where_clauses( + filter_value=filter_value, + current_user=current_user, + ) + + if filter_query_value: + parsed = filter_query_models.FilterQuery.model_validate_json(filter_query_value) + where_clauses.append(filter_query_to_where_clause(filter_query=parsed)) + + next_page_token = _encode_page_token( + page_token_dict={ + _PAGE_TOKEN_OFFSET_KEY: offset + page_size, + _PAGE_TOKEN_FILTER_KEY: resolved_filter, + _PAGE_TOKEN_FILTER_QUERY_KEY: filter_query_value, + } + ) + + return where_clauses, offset, next_page_token + + +def filter_query_to_where_clause( + *, + filter_query: filter_query_models.FilterQuery, +) -> sql.ColumnElement: + predicates = filter_query.and_ or filter_query.or_ + is_and = filter_query.and_ is not None + clauses = [_predicate_to_clause(predicate=p) for p in predicates] + return sql.and_(*clauses) if is_and else sql.or_(*clauses) + + +# --------------------------------------------------------------------------- +# Legacy filter parsing +# --------------------------------------------------------------------------- + + +def _parse_filter(filter: str) -> dict[str, str]: + # TODO: Improve + parts = filter.strip().split() + parsed_filter = {} + for part in parts: + key, sep, value = part.partition(":") + if sep: + parsed_filter[key] = value + else: + parsed_filter.setdefault("_text", "") + parsed_filter["_text"] += part + return parsed_filter + + +def _build_filter_where_clauses( + *, + filter_value: str | None, + current_user: str | None, +) -> tuple[list[sql.ColumnElement], str | None]: + """Parse a filter string into SQLAlchemy WHERE clauses. + + Returns (where_clauses, next_page_filter_value). The second value is the + filter string with shorthand values resolved (e.g. "created_by:me" becomes + "created_by:alice@example.com") so it can be embedded in the next page token. + """ + where_clauses: list[sql.ColumnElement] = [] + parsed_filter = _parse_filter(filter_value) if filter_value else {} + for key, value in parsed_filter.items(): + if key == "_text": + raise NotImplementedError("Text search is not implemented yet.") + elif key == "created_by": + if value == "me": + if current_user is None: + current_user = "" + value = current_user + # TODO: Maybe make this a bit more robust. + # We need to change the filter since it goes into the next_page_token. + filter_value = filter_value.replace( + "created_by:me", f"created_by:{current_user}" + ) + if value: + where_clauses.append(bts.PipelineRun.created_by == value) + else: + where_clauses.append(bts.PipelineRun.created_by == None) + else: + raise NotImplementedError(f"Unsupported filter {filter_value}.") + return where_clauses, filter_value + + +# --------------------------------------------------------------------------- +# Predicate dispatch +# --------------------------------------------------------------------------- + + +def _predicate_to_clause( + *, + predicate, +) -> sql.ColumnElement: + match predicate: + case filter_query_models.AndPredicate(): + return sql.and_( + *[_predicate_to_clause(predicate=p) for p in predicate.and_] + ) + case filter_query_models.OrPredicate(): + return sql.or_(*[_predicate_to_clause(predicate=p) for p in predicate.or_]) + case filter_query_models.NotPredicate(): + return sql.not_(_predicate_to_clause(predicate=predicate.not_)) + case filter_query_models.KeyExistsPredicate(): + return _key_exists_to_clause(predicate=predicate) + case filter_query_models.ValueEqualsPredicate(): + return _value_equals_to_clause(predicate=predicate) + case filter_query_models.ValueContainsPredicate(): + return _value_contains_to_clause(predicate=predicate) + case filter_query_models.ValueInPredicate(): + return _value_in_to_clause(predicate=predicate) + case _: + # TODO: TimeRangePredicate -- not supported currently, will be supported in the future. + raise NotImplementedError( + f"Predicate type {type(predicate).__name__} is not yet implemented." + ) + + +# --------------------------------------------------------------------------- +# Generic annotation-based leaf predicates +# --------------------------------------------------------------------------- + + +def _annotation_exists(*, conditions: list[sql.ColumnElement]) -> sql.ColumnElement: + return sql.exists( + sql.select(bts.PipelineRunAnnotation.pipeline_run_id).where( + bts.PipelineRunAnnotation.pipeline_run_id == bts.PipelineRun.id, + *conditions, + ) + ) + + +def _key_exists_to_clause( + *, predicate: filter_query_models.KeyExistsPredicate +) -> sql.ColumnElement: + return _annotation_exists( + conditions=[bts.PipelineRunAnnotation.key == predicate.key_exists.key], + ) + + +def _value_equals_to_clause( + *, predicate: filter_query_models.ValueEqualsPredicate +) -> sql.ColumnElement: + return _annotation_exists( + conditions=[ + bts.PipelineRunAnnotation.key == predicate.value_equals.key, + bts.PipelineRunAnnotation.value == predicate.value_equals.value, + ], + ) + + +def _value_contains_to_clause( + *, predicate: filter_query_models.ValueContainsPredicate +) -> sql.ColumnElement: + return _annotation_exists( + conditions=[ + bts.PipelineRunAnnotation.key == predicate.value_contains.key, + bts.PipelineRunAnnotation.value.contains( + predicate.value_contains.value_substring + ), + ], + ) + + +def _value_in_to_clause( + *, predicate: filter_query_models.ValueInPredicate +) -> sql.ColumnElement: + return _annotation_exists( + conditions=[ + bts.PipelineRunAnnotation.key == predicate.value_in.key, + bts.PipelineRunAnnotation.value.in_(predicate.value_in.values), + ], + ) diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py index 3a4f662..a46607e 100644 --- a/tests/test_api_server_sql.py +++ b/tests/test_api_server_sql.py @@ -1,3 +1,5 @@ +import json + import pytest from sqlalchemy import orm @@ -6,6 +8,7 @@ from cloud_pipelines_backend import component_structures as structures from cloud_pipelines_backend import database_ops from cloud_pipelines_backend import errors +from cloud_pipelines_backend import filter_query_sql class TestExecutionStatusSummary: @@ -438,134 +441,395 @@ def test_list_annotations_empty(self, session_factory, service): assert annotations == {} -class TestResolveFilterValue: - def test_no_page_token_no_filter(self): - filter_value, offset = api_server_sql._resolve_filter_value( - filter=None, page_token=None - ) - assert filter_value is None - assert offset == 0 +class TestFilterQueryApiWiring: + def test_filter_query_validates_invalid_json(self, session_factory, service): + from pydantic import ValidationError + + invalid_json = '{"bad_key": "not_valid"}' + with session_factory() as session: + with pytest.raises(ValidationError): + service.list( + session=session, + filter_query=invalid_json, + ) + + def test_mutual_exclusivity_rejected(self, session_factory, service): + with session_factory() as session: + with pytest.raises(errors.ApiValidationError, match="Cannot use both"): + service.list( + session=session, + filter="created_by:alice", + filter_query='{"and": [{"key_exists": {"key": "team"}}]}', + ) - def test_no_page_token_with_filter(self): - filter_value, offset = api_server_sql._resolve_filter_value( - filter="created_by:alice", - page_token=None, - ) - assert filter_value == "created_by:alice" - assert offset == 0 - def test_page_token_overrides_filter(self): - token = api_server_sql._encode_page_token( - {"offset": 20, "filter": "created_by:bob"} +class TestFilterQueryIntegration: + def _set_annotation(self, *, session_factory, service, run_id, key, value): + with session_factory() as session: + service.set_annotation( + session=session, + id=run_id, + key=key, + value=value, + user_name="test-user", + ) + + def test_annotation_key_exists(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", ) - filter_value, offset = api_server_sql._resolve_filter_value( - filter="created_by:alice", - page_token=token, + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="team", + value="ml-ops", ) - assert filter_value == "created_by:bob" - assert offset == 20 - def test_page_token_without_filter_key(self): - token = api_server_sql._encode_page_token({"offset": 10}) - filter_value, offset = api_server_sql._resolve_filter_value( - filter="created_by:alice", - page_token=token, + fq = json.dumps({"and": [{"key_exists": {"key": "team"}}]}) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].id == run1.id + + def test_annotation_value_equals(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", ) - assert filter_value is None - assert offset == 10 + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="team", + value="ml-ops", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run2.id, + key="team", + value="data-eng", + ) + + fq = json.dumps({"and": [{"value_equals": {"key": "team", "value": "ml-ops"}}]}) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].id == run1.id - def test_page_token_without_offset_key(self): - token = api_server_sql._encode_page_token({"filter": "created_by:bob"}) - filter_value, offset = api_server_sql._resolve_filter_value( - filter=None, - page_token=token, + def test_annotation_value_contains(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="team", + value="ml-ops", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run2.id, + key="team", + value="data-eng", ) - assert filter_value == "created_by:bob" - assert offset == 0 + fq = json.dumps( + {"and": [{"value_contains": {"key": "team", "value_substring": "ml"}}]} + ) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].id == run1.id -class TestBuildFilterWhereClauses: - def test_no_filter(self): - clauses, next_filter = api_server_sql._build_filter_where_clauses( - filter_value=None, - current_user=None, + def test_annotation_value_in(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run3 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="env", + value="prod", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run2.id, + key="env", + value="staging", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run3.id, + key="env", + value="dev", ) - assert clauses == [] - assert next_filter is None - def test_created_by_literal(self): - clauses, next_filter = api_server_sql._build_filter_where_clauses( - filter_value="created_by:alice", - current_user=None, + fq = json.dumps( + {"and": [{"value_in": {"key": "env", "values": ["prod", "staging"]}}]} ) - assert len(clauses) == 1 - assert next_filter == "created_by:alice" + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 2 + result_ids = {r.id for r in result.pipeline_runs} + assert run1.id in result_ids + assert run2.id in result_ids - def test_created_by_me_resolves(self): - clauses, next_filter = api_server_sql._build_filter_where_clauses( - filter_value="created_by:me", - current_user="alice@example.com", + def test_and_multiple_annotations(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="team", + value="ml-ops", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="env", + value="prod", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run2.id, + key="team", + value="ml-ops", ) - assert len(clauses) == 1 - assert next_filter == "created_by:alice@example.com" - def test_created_by_me_no_current_user(self): - clauses, next_filter = api_server_sql._build_filter_where_clauses( - filter_value="created_by:me", - current_user=None, + fq = json.dumps( + { + "and": [ + {"key_exists": {"key": "team"}}, + {"value_equals": {"key": "env", "value": "prod"}}, + ] + } ) - assert len(clauses) == 1 - assert next_filter == "created_by:" + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].id == run1.id - def test_created_by_empty_value(self): - clauses, next_filter = api_server_sql._build_filter_where_clauses( - filter_value="created_by:", - current_user=None, + def test_or_annotations(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run3 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="team", + value="ml-ops", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run2.id, + key="project", + value="search", ) - assert len(clauses) == 1 - assert next_filter == "created_by:" - def test_unsupported_key_raises(self): - with pytest.raises(NotImplementedError, match="Unsupported filter"): - api_server_sql._build_filter_where_clauses( - filter_value="unknown_key:value", - current_user=None, + fq = json.dumps( + { + "or": [ + {"key_exists": {"key": "team"}}, + {"key_exists": {"key": "project"}}, + ] + } + ) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, ) + assert len(result.pipeline_runs) == 2 + result_ids = {r.id for r in result.pipeline_runs} + assert run1.id in result_ids + assert run2.id in result_ids + + def test_not_annotation(self, session_factory, service): + run1 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + run2 = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run1.id, + key="deprecated", + value="true", + ) - def test_text_search_raises(self): - with pytest.raises(NotImplementedError, match="Text search"): - api_server_sql._build_filter_where_clauses( - filter_value="some_text_without_colon", - current_user=None, + fq = json.dumps({"and": [{"not": {"key_exists": {"key": "deprecated"}}}]}) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, ) + assert len(result.pipeline_runs) == 1 + assert result.pipeline_runs[0].id == run2.id + def test_no_match(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec()) -class TestFilterQueryApiWiring: - def test_filter_query_returns_not_implemented(self, session_factory, service): - valid_json = '{"and": [{"key_exists": {"key": "team"}}]}' + fq = json.dumps({"and": [{"key_exists": {"key": "nonexistent"}}]}) + with session_factory() as session: + result = service.list( + session=session, + filter_query=fq, + ) + assert len(result.pipeline_runs) == 0 + + def test_time_range_raises_not_implemented(self, session_factory, service): + fq = json.dumps( + { + "and": [ + { + "time_range": { + "key": "system/pipeline_run.date.created_at", + "start_time": "2024-01-01T00:00:00Z", + } + } + ] + } + ) with session_factory() as session: - with pytest.raises(NotImplementedError, match="not yet implemented"): + with pytest.raises(NotImplementedError, match="TimeRangePredicate"): service.list( session=session, - filter_query=valid_json, + filter_query=fq, ) - def test_filter_query_validates_before_501(self, session_factory, service): - from pydantic import ValidationError + def test_pagination_preserves_filter_query(self, session_factory, service): + for _ in range(12): + run = _create_run( + session_factory, + service, + root_task=_make_task_spec(), + created_by="test-user", + ) + self._set_annotation( + session_factory=session_factory, + service=service, + run_id=run.id, + key="team", + value="ml-ops", + ) - invalid_json = '{"bad_key": "not_valid"}' + fq = json.dumps({"and": [{"key_exists": {"key": "team"}}]}) with session_factory() as session: - with pytest.raises(ValidationError): - service.list( - session=session, - filter_query=invalid_json, - ) + page1 = service.list( + session=session, + filter_query=fq, + ) + assert len(page1.pipeline_runs) == 10 + assert page1.next_page_token is not None + + decoded = filter_query_sql._decode_page_token(page_token=page1.next_page_token) + assert decoded["filter_query"] == fq - def test_mutual_exclusivity_rejected(self, session_factory, service): with session_factory() as session: - with pytest.raises(errors.ApiValidationError, match="Cannot use both"): - service.list( - session=session, - filter="created_by:alice", - filter_query='{"and": [{"key_exists": {"key": "team"}}]}', - ) + page2 = service.list( + session=session, + page_token=page1.next_page_token, + ) + assert len(page2.pipeline_runs) == 2 + assert page2.next_page_token is None diff --git a/tests/test_filter_query_sql.py b/tests/test_filter_query_sql.py new file mode 100644 index 0000000..2e0bad8 --- /dev/null +++ b/tests/test_filter_query_sql.py @@ -0,0 +1,391 @@ +import pytest +import sqlalchemy as sql +from sqlalchemy.dialects import sqlite as sqlite_dialect + +from cloud_pipelines_backend import backend_types_sql as bts +from cloud_pipelines_backend import errors +from cloud_pipelines_backend import filter_query_models +from cloud_pipelines_backend import filter_query_sql + + +def _compile(clause: sql.ColumnElement) -> str: + """Compile a SQLAlchemy clause to a normalized string for assertion.""" + raw = str( + clause.compile( + dialect=sqlite_dialect.dialect(), + compile_kwargs={"literal_binds": True}, + ) + ) + return " ".join(raw.split()) + + +class TestLeafPredicates: + def test_key_exists(self): + fq = filter_query_models.FilterQuery.model_validate( + {"and": [{"key_exists": {"key": "team"}}]}, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team')" + ) + + def test_value_equals(self): + fq = filter_query_models.FilterQuery.model_validate( + {"and": [{"value_equals": {"key": "team", "value": "ml-ops"}}]}, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team'" + " AND pipeline_run_annotation.value = 'ml-ops')" + ) + + def test_value_contains(self): + fq = filter_query_models.FilterQuery.model_validate( + {"and": [{"value_contains": {"key": "team", "value_substring": "ml"}}]}, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team'" + " AND (pipeline_run_annotation.value LIKE '%' || 'ml' || '%'))" + ) + + def test_value_in(self): + fq = filter_query_models.FilterQuery.model_validate( + {"and": [{"value_in": {"key": "env", "values": ["prod", "staging"]}}]}, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'env'" + " AND pipeline_run_annotation.value IN ('prod', 'staging'))" + ) + + +class TestLogicalPredicates: + def test_and_predicates(self): + fq = filter_query_models.FilterQuery.model_validate( + { + "and": [ + {"key_exists": {"key": "team"}}, + {"value_equals": {"key": "env", "value": "prod"}}, + ] + }, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "(EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team'))" + " AND" + " (EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'env'" + " AND pipeline_run_annotation.value = 'prod'))" + ) + + def test_or_predicates(self): + fq = filter_query_models.FilterQuery.model_validate( + { + "or": [ + {"key_exists": {"key": "team"}}, + {"key_exists": {"key": "project"}}, + ] + }, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "(EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team'))" + " OR" + " (EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'project'))" + ) + + def test_not_predicate(self): + fq = filter_query_models.FilterQuery.model_validate( + {"and": [{"not": {"key_exists": {"key": "deprecated"}}}]}, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "NOT (EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'deprecated'))" + ) + + def test_nested_and_or(self): + fq = filter_query_models.FilterQuery.model_validate( + { + "and": [ + { + "or": [ + {"key_exists": {"key": "team"}}, + {"key_exists": {"key": "project"}}, + ] + }, + {"value_equals": {"key": "env", "value": "prod"}}, + ] + }, + ) + clause = filter_query_sql.filter_query_to_where_clause(filter_query=fq) + assert _compile(clause) == ( + "((EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team'))" + " OR" + " (EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'project')))" + " AND" + " (EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'env'" + " AND pipeline_run_annotation.value = 'prod'))" + ) + + +class TestUnsupportedPredicate: + def test_time_range_raises_not_implemented(self): + predicate = filter_query_models.TimeRangePredicate( + time_range=filter_query_models.TimeRange( + key="system/pipeline_run.date.created_at", + start_time="2024-01-01T00:00:00Z", + ) + ) + with pytest.raises(NotImplementedError, match="TimeRangePredicate"): + filter_query_sql._predicate_to_clause(predicate=predicate) + + +class TestPageToken: + def test_decode_none(self): + token = filter_query_sql._decode_page_token(page_token=None) + assert token == {} + + def test_encode_decode_roundtrip(self): + encoded = filter_query_sql._encode_page_token( + page_token_dict={ + "offset": 20, + "filter": "created_by:alice", + "filter_query": '{"and": [{"key_exists": {"key": "team"}}]}', + } + ) + decoded = filter_query_sql._decode_page_token(page_token=encoded) + assert decoded["offset"] == 20 + assert decoded["filter"] == "created_by:alice" + assert decoded["filter_query"] == '{"and": [{"key_exists": {"key": "team"}}]}' + + def test_decode_with_filter_query(self): + fq_json = '{"or": [{"value_equals": {"key": "env", "value": "prod"}}]}' + encoded = filter_query_sql._encode_page_token( + page_token_dict={ + "offset": 10, + "filter_query": fq_json, + } + ) + decoded = filter_query_sql._decode_page_token(page_token=encoded) + assert decoded["filter_query"] == fq_json + assert decoded.get("filter") is None + assert decoded["offset"] == 10 + + def test_decode_empty_string(self): + token = filter_query_sql._decode_page_token(page_token="") + assert token == {} + + +class TestBuildFilterWhereClauses: + def test_no_filter(self): + clauses, next_filter = filter_query_sql._build_filter_where_clauses( + filter_value=None, + current_user=None, + ) + assert clauses == [] + assert next_filter is None + + def test_created_by_literal(self): + clauses, next_filter = filter_query_sql._build_filter_where_clauses( + filter_value="created_by:alice", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:alice" + + def test_created_by_me_resolves(self): + clauses, next_filter = filter_query_sql._build_filter_where_clauses( + filter_value="created_by:me", + current_user="alice@example.com", + ) + assert len(clauses) == 1 + assert next_filter == "created_by:alice@example.com" + + def test_created_by_me_no_current_user(self): + clauses, next_filter = filter_query_sql._build_filter_where_clauses( + filter_value="created_by:me", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:" + + def test_created_by_empty_value(self): + clauses, next_filter = filter_query_sql._build_filter_where_clauses( + filter_value="created_by:", + current_user=None, + ) + assert len(clauses) == 1 + assert next_filter == "created_by:" + + def test_unsupported_key_raises(self): + with pytest.raises(NotImplementedError, match="Unsupported filter"): + filter_query_sql._build_filter_where_clauses( + filter_value="unknown_key:value", + current_user=None, + ) + + def test_text_search_raises(self): + with pytest.raises(NotImplementedError, match="Text search"): + filter_query_sql._build_filter_where_clauses( + filter_value="some_text_without_colon", + current_user=None, + ) + + +class TestBuildListFilters: + def test_no_filters(self): + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=None, + page_token_value=None, + current_user=None, + page_size=10, + ) + assert clauses == [] + assert offset == 0 + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["offset"] == 10 + assert decoded.get("filter") is None + assert decoded.get("filter_query") is None + + def test_mutual_exclusivity_raises(self): + with pytest.raises(errors.ApiValidationError, match="Cannot use both"): + filter_query_sql.build_list_filters( + filter_value="created_by:alice", + filter_query_value='{"and": [{"key_exists": {"key": "team"}}]}', + page_token_value=None, + current_user=None, + page_size=10, + ) + + def test_legacy_filter_produces_clauses(self): + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value="created_by:alice", + filter_query_value=None, + page_token_value=None, + current_user=None, + page_size=10, + ) + assert len(clauses) == 1 + assert offset == 0 + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["filter"] == "created_by:alice" + + def test_filter_query_produces_clauses(self): + fq = '{"and": [{"key_exists": {"key": "team"}}]}' + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=fq, + page_token_value=None, + current_user=None, + page_size=10, + ) + assert len(clauses) == 1 + assert _compile(clauses[0]) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id" + " FROM pipeline_run_annotation, pipeline_run" + " WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team')" + ) + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["filter_query"] == fq + + def test_page_token_restores_offset_and_filters(self): + encoded_token = filter_query_sql._encode_page_token( + page_token_dict={ + "offset": 20, + "filter": "created_by:alice", + } + ) + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=None, + page_token_value=encoded_token, + current_user=None, + page_size=10, + ) + assert offset == 20 + assert len(clauses) == 1 + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["offset"] == 30 + assert decoded["filter"] == "created_by:alice" + + def test_page_token_restores_filter_query(self): + fq = '{"and": [{"key_exists": {"key": "env"}}]}' + encoded_token = filter_query_sql._encode_page_token( + page_token_dict={ + "offset": 10, + "filter_query": fq, + } + ) + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=None, + page_token_value=encoded_token, + current_user=None, + page_size=5, + ) + assert offset == 10 + assert len(clauses) == 1 + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["offset"] == 15 + assert decoded["filter_query"] == fq + + def test_page_size_reflected_in_next_token(self): + _, _, next_token = filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=None, + page_token_value=None, + current_user=None, + page_size=25, + ) + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["offset"] == 25 + + def test_created_by_me_resolved_in_next_token(self): + clauses, offset, next_token = filter_query_sql.build_list_filters( + filter_value="created_by:me", + filter_query_value=None, + page_token_value=None, + current_user="bob@example.com", + page_size=10, + ) + assert len(clauses) == 1 + decoded = filter_query_sql._decode_page_token(page_token=next_token) + assert decoded["filter"] == "created_by:bob@example.com"