diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 232658e..c2cb93d 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -85,6 +85,7 @@ class GetPipelineRunResponse(PipelineRunResponse): class ListPipelineJobsResponse: pipeline_runs: list[PipelineRunResponse] next_page_token: str | None = None + sql: str | None = None class PipelineRunsApiService_Sql: @@ -196,6 +197,36 @@ def terminate( execution_node.extra_data["desired_state"] = "TERMINATED" session.commit() + @staticmethod + def _compile_sql_string( + stmt: sql.Select, + dialect: sql.engine.Dialect, + ) -> str: + """Compile a SQLAlchemy statement to a SQL string for debugging. + + Uses ``literal_binds=True`` to inline bound parameters as literal + values, producing a self-contained query string:: + + SELECT ... WHERE key = 'environment' AND created_at < '2024-01-15' LIMIT 10 + + If a column type lacks a ``literal_processor`` (raises CompileError or + NotImplementedError), falls back to placeholder syntax with a params + comment:: + + SELECT ... WHERE key = :key_1 AND created_at < :created_at_1 LIMIT :param_1 + -- params: {'key_1': 'environment', 'created_at_1': '2024-01-15', 'param_1': 10} + """ + try: + compiled = stmt.compile( + dialect=dialect, + compile_kwargs={"literal_binds": True}, + ) + return str(compiled) + except (sql.exc.CompileError, NotImplementedError): + compiled = stmt.compile(dialect=dialect) + params_suffix = f"\n-- params: {compiled.params}" if compiled.params else "" + return str(compiled) + params_suffix + # Note: This method must be last to not shadow the "list" type def list( self, @@ -207,6 +238,7 @@ def list( current_user: str | None = None, include_pipeline_names: bool = False, include_execution_stats: bool = False, + include_sql: bool = False, ) -> ListPipelineJobsResponse: where_clauses = filter_query_sql.build_list_filters( filter_value=filter, @@ -215,18 +247,22 @@ def list( current_user=current_user, ) - pipeline_runs = list( - session.scalars( - sql.select(bts.PipelineRun) - .where(*where_clauses) - .order_by( - bts.PipelineRun.created_at.desc(), - bts.PipelineRun.id.desc(), - ) - .limit(self._DEFAULT_PAGE_SIZE) - ).all() + stmt = ( + sql.select(bts.PipelineRun) + .where(*where_clauses) + .order_by( + bts.PipelineRun.created_at.desc(), + bts.PipelineRun.id.desc(), + ) + .limit(self._DEFAULT_PAGE_SIZE) ) + sql_string = None + if include_sql: + sql_string = self._compile_sql_string(stmt, session.bind.dialect) + + pipeline_runs = list(session.scalars(stmt).all()) + next_page_token = filter_query_sql.maybe_next_page_token( rows=pipeline_runs, page_size=self._DEFAULT_PAGE_SIZE ) @@ -242,6 +278,7 @@ def list( for pipeline_run in pipeline_runs ], next_page_token=next_page_token, + sql=sql_string, ) def _create_pipeline_run_response( diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py index 9586886..90f2b2c 100644 --- a/tests/test_api_server_sql.py +++ b/tests/test_api_server_sql.py @@ -295,6 +295,87 @@ def test_list_filter_created_by_me(self, session_factory, service): assert len(result.pipeline_runs) == 1 assert result.pipeline_runs[0].created_by == "alice@example.com" + def test_list_include_sql_default_none(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec()) + + with session_factory() as session: + result = service.list(session=session) + assert result.sql is None + + def test_list_include_sql_true(self, session_factory, service): + _create_run(session_factory, service, root_task=_make_task_spec()) + + with session_factory() as session: + result = service.list(session=session, include_sql=True) + expected = ( + "SELECT pipeline_run.id, pipeline_run.root_execution_id," + " pipeline_run.annotations, pipeline_run.created_by," + " pipeline_run.created_at, pipeline_run.updated_at," + " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n" + "FROM pipeline_run" + " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n" + " LIMIT 10 OFFSET 0" + ) + assert result.sql == expected + + def test_list_include_sql_with_filter_query(self, session_factory, service): + run = _create_run(session_factory, service, root_task=_make_task_spec()) + with session_factory() as session: + service.set_annotation(session=session, id=run.id, key="team", value="ml") + + fq = json.dumps({"and": [{"key_exists": {"key": "team"}}]}) + with session_factory() as session: + result = service.list(session=session, filter_query=fq, include_sql=True) + expected = ( + "SELECT pipeline_run.id, pipeline_run.root_execution_id," + " pipeline_run.annotations, pipeline_run.created_by," + " pipeline_run.created_at, pipeline_run.updated_at," + " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n" + "FROM pipeline_run \n" + "WHERE EXISTS (SELECT pipeline_run_annotation.pipeline_run_id \n" + "FROM pipeline_run_annotation \n" + "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id" + " AND pipeline_run_annotation.\"key\" = 'team')" + " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n" + " LIMIT 10 OFFSET 0" + ) + assert result.sql == expected + + def test_list_include_sql_with_cursor(self, session_factory, service): + for i in range(12): + _create_run( + session_factory, + service, + root_task=_make_task_spec(f"pipeline-{i}"), + ) + + with session_factory() as session: + page1 = service.list(session=session) + assert page1.next_page_token is not None + + with session_factory() as session: + page2 = service.list( + session=session, + page_token=page1.next_page_token, + include_sql=True, + ) + + cursor_dt_iso, cursor_id = page1.next_page_token.split("~") + cursor_dt = datetime.datetime.fromisoformat(cursor_dt_iso) + sql_dt = cursor_dt.strftime("%Y-%m-%d %H:%M:%S.%f") + expected = ( + "SELECT pipeline_run.id, pipeline_run.root_execution_id," + " pipeline_run.annotations, pipeline_run.created_by," + " pipeline_run.created_at, pipeline_run.updated_at," + " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n" + "FROM pipeline_run \n" + f"WHERE (pipeline_run.created_at, pipeline_run.id)" + f" < ('{sql_dt}', '{cursor_id}')" + " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n" + " LIMIT 10 OFFSET 0" + ) + assert page2.sql == expected + class TestCreatePipelineRunResponse: def test_base_response(self, session_factory, service):