Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 47 additions & 10 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class GetPipelineRunResponse(PipelineRunResponse):
class ListPipelineJobsResponse:
pipeline_runs: list[PipelineRunResponse]
next_page_token: str | None = None
sql: str | None = None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Brainstorming ideas:



class PipelineRunsApiService_Sql:
Expand Down Expand Up @@ -196,6 +197,36 @@ def terminate(
execution_node.extra_data["desired_state"] = "TERMINATED"
session.commit()

@staticmethod
def _compile_sql_string(
    stmt: sql.Select,
    dialect: sql.engine.Dialect,
) -> str:
    """Render *stmt* as a human-readable SQL string (debugging aid).

    First attempts to inline every bound parameter via
    ``literal_binds=True`` so the result is a self-contained query::

        SELECT ... WHERE key = 'environment' AND created_at < '2024-01-15' LIMIT 10

    Some column types have no ``literal_processor`` and compilation then
    raises ``CompileError`` or ``NotImplementedError``; in that case the
    statement is compiled with placeholders and the parameter mapping is
    appended as a trailing SQL comment::

        SELECT ... WHERE key = :key_1 AND created_at < :created_at_1 LIMIT :param_1
        -- params: {'key_1': 'environment', 'created_at_1': '2024-01-15', 'param_1': 10}
    """
    try:
        # Happy path: fully inlined, copy-pasteable query text.
        return str(
            stmt.compile(
                dialect=dialect,
                compile_kwargs={"literal_binds": True},
            )
        )
    except (sql.exc.CompileError, NotImplementedError):
        # Fallback: placeholder syntax plus the raw parameter dict.
        fallback = stmt.compile(dialect=dialect)
        rendered = str(fallback)
        if fallback.params:
            rendered += f"\n-- params: {fallback.params}"
        return rendered

# Note: This method must be last to not shadow the "list" type
def list(
self,
Expand All @@ -207,6 +238,7 @@ def list(
current_user: str | None = None,
include_pipeline_names: bool = False,
include_execution_stats: bool = False,
include_sql: bool = False,
) -> ListPipelineJobsResponse:
where_clauses = filter_query_sql.build_list_filters(
filter_value=filter,
Expand All @@ -215,18 +247,22 @@ def list(
current_user=current_user,
)

pipeline_runs = list(
session.scalars(
sql.select(bts.PipelineRun)
.where(*where_clauses)
.order_by(
bts.PipelineRun.created_at.desc(),
bts.PipelineRun.id.desc(),
)
.limit(self._DEFAULT_PAGE_SIZE)
).all()
stmt = (
sql.select(bts.PipelineRun)
.where(*where_clauses)
.order_by(
bts.PipelineRun.created_at.desc(),
bts.PipelineRun.id.desc(),
)
.limit(self._DEFAULT_PAGE_SIZE)
)

sql_string = None
if include_sql:
sql_string = self._compile_sql_string(stmt, session.bind.dialect)

pipeline_runs = list(session.scalars(stmt).all())

next_page_token = filter_query_sql.maybe_next_page_token(
rows=pipeline_runs, page_size=self._DEFAULT_PAGE_SIZE
)
Expand All @@ -242,6 +278,7 @@ def list(
for pipeline_run in pipeline_runs
],
next_page_token=next_page_token,
sql=sql_string,
)

def _create_pipeline_run_response(
Expand Down
81 changes: 81 additions & 0 deletions tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,87 @@ def test_list_filter_created_by_me(self, session_factory, service):
assert len(result.pipeline_runs) == 1
assert result.pipeline_runs[0].created_by == "alice@example.com"

def test_list_include_sql_default_none(self, session_factory, service):
    """list() leaves the sql field unset when include_sql is not requested."""
    _create_run(session_factory, service, root_task=_make_task_spec())

    with session_factory() as session:
        listing = service.list(session=session)
    assert listing.sql is None

def test_list_include_sql_true(self, session_factory, service):
    """include_sql=True attaches the fully rendered query to the response."""
    _create_run(session_factory, service, root_task=_make_task_spec())

    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id,"
        " pipeline_run.annotations, pipeline_run.created_by,"
        " pipeline_run.created_at, pipeline_run.updated_at,"
        " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run"
        " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )
    with session_factory() as session:
        response = service.list(session=session, include_sql=True)
    assert response.sql == expected_sql

def test_list_include_sql_with_filter_query(self, session_factory, service):
    """The rendered SQL reflects annotation filters with inlined literals."""
    created = _create_run(session_factory, service, root_task=_make_task_spec())
    with session_factory() as session:
        service.set_annotation(session=session, id=created.id, key="team", value="ml")

    filter_query = json.dumps({"and": [{"key_exists": {"key": "team"}}]})
    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id,"
        " pipeline_run.annotations, pipeline_run.created_by,"
        " pipeline_run.created_at, pipeline_run.updated_at,"
        " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run \n"
        "WHERE EXISTS (SELECT pipeline_run_annotation.pipeline_run_id \n"
        "FROM pipeline_run_annotation \n"
        "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id"
        " AND pipeline_run_annotation.\"key\" = 'team')"
        " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )
    with session_factory() as session:
        response = service.list(
            session=session, filter_query=filter_query, include_sql=True
        )
    assert response.sql == expected_sql

def test_list_include_sql_with_cursor(self, session_factory, service):
    """Pagination cursors appear as inlined tuple literals in the SQL."""
    # 12 runs > the page size of 10, guaranteeing a second page exists.
    for index in range(12):
        _create_run(
            session_factory,
            service,
            root_task=_make_task_spec(f"pipeline-{index}"),
        )

    with session_factory() as session:
        first_page = service.list(session=session)
    assert first_page.next_page_token is not None

    with session_factory() as session:
        second_page = service.list(
            session=session,
            page_token=first_page.next_page_token,
            include_sql=True,
        )

    # The token encodes "<created_at iso>~<id>"; the compiled SQL renders
    # the timestamp in the dialect's literal format.
    token_dt_iso, token_id = first_page.next_page_token.split("~")
    token_dt = datetime.datetime.fromisoformat(token_dt_iso)
    rendered_dt = token_dt.strftime("%Y-%m-%d %H:%M:%S.%f")
    expected_sql = (
        "SELECT pipeline_run.id, pipeline_run.root_execution_id,"
        " pipeline_run.annotations, pipeline_run.created_by,"
        " pipeline_run.created_at, pipeline_run.updated_at,"
        " pipeline_run.parent_pipeline_id, pipeline_run.extra_data \n"
        "FROM pipeline_run \n"
        f"WHERE (pipeline_run.created_at, pipeline_run.id)"
        f" < ('{rendered_dt}', '{token_id}')"
        " ORDER BY pipeline_run.created_at DESC, pipeline_run.id DESC\n"
        " LIMIT 10 OFFSET 0"
    )
    assert second_page.sql == expected_sql


class TestCreatePipelineRunResponse:
def test_base_response(self, session_factory, service):
Expand Down