Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 141 additions & 93 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
import json
import logging
import typing
from typing import Any, Optional
from typing import Any, Final, Optional

import sqlalchemy as sql
from sqlalchemy import orm

from . import backend_types_sql as bts
from . import component_structures as structures
from . import errors

if typing.TYPE_CHECKING:
from cloud_pipelines.orchestration.storage_providers import (
Expand All @@ -26,12 +33,6 @@ def _get_current_time() -> datetime.datetime:
return datetime.datetime.now(tz=datetime.timezone.utc)


from . import component_structures as structures
from . import backend_types_sql as bts
from . import errors
from .errors import ItemNotFoundError


# ==== PipelineJobService
@dataclasses.dataclass(kw_only=True)
class PipelineRunResponse:
Expand Down Expand Up @@ -65,12 +66,11 @@ class ListPipelineJobsResponse:
next_page_token: str | None = None


import sqlalchemy as sql
from sqlalchemy import orm


class PipelineRunsApiService_Sql:
PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
_PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
_PAGE_TOKEN_OFFSET_KEY: Final[str] = "offset"
_PAGE_TOKEN_FILTER_KEY: Final[str] = "filter"
_DEFAULT_PAGE_SIZE: Final[int] = 10

def create(
self,
Expand Down Expand Up @@ -104,7 +104,7 @@ def create(
annotations=annotations,
created_by=created_by,
extra_data={
self.PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
self._PIPELINE_NAME_EXTRA_DATA_KEY: pipeline_name,
},
)
session.add(pipeline_run)
Expand All @@ -116,7 +116,7 @@ def create(
def get(self, session: orm.Session, id: bts.IdType) -> PipelineRunResponse:
    """Fetch a single pipeline run by ID and convert it to an API response.

    Raises:
        errors.ItemNotFoundError: If no pipeline run with `id` exists.
    """
    pipeline_run = session.get(bts.PipelineRun, id)
    if not pipeline_run:
        # Namespaced exception (errors.ItemNotFoundError) for consistency
        # with the rest of this module; the span previously carried a stale
        # duplicate of the un-namespaced raise.
        raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
    return PipelineRunResponse.from_db(pipeline_run)

def terminate(
Expand All @@ -128,7 +128,7 @@ def terminate(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (terminated_by != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be terminated by {terminated_by}"
Expand Down Expand Up @@ -166,98 +166,86 @@ def list(
*,
session: orm.Session,
page_token: str | None = None,
# page_size: int = 10,
filter: str | None = None,
current_user: str | None = None,
include_pipeline_names: bool = False,
include_execution_stats: bool = False,
) -> ListPipelineJobsResponse:
page_token_dict = _decode_page_token(page_token)
OFFSET_KEY = "offset"
offset = page_token_dict.get(OFFSET_KEY, 0)
page_size = 10

FILTER_KEY = "filter"
if page_token:
filter = page_token_dict.get(FILTER_KEY, None)
where_clauses = []
parsed_filter = _parse_filter(filter) if filter else {}
for key, value in parsed_filter.items():
if key == "_text":
raise NotImplementedError("Text search is not implemented yet.")
elif key == "created_by":
if value == "me":
if current_user is None:
# raise ApiServiceError(
# f"The `created_by:me` filter requires `current_user`."
# )
current_user = ""
value = current_user
# TODO: Maybe make this a bit more robust.
# We need to change the filter since it goes into the next_page_token.
filter = filter.replace(
"created_by:me", f"created_by:{current_user}"
)
if value:
where_clauses.append(bts.PipelineRun.created_by == value)
else:
where_clauses.append(bts.PipelineRun.created_by == None)
else:
raise NotImplementedError(f"Unsupported filter {filter}.")
filter_value, offset = _resolve_filter_value(
filter=filter,
page_token=page_token,
)
where_clauses, next_page_filter_value = _build_filter_where_clauses(
filter_value=filter_value,
current_user=current_user,
)

pipeline_runs = list(
session.scalars(
sql.select(bts.PipelineRun)
.where(*where_clauses)
.order_by(bts.PipelineRun.created_at.desc())
.offset(offset)
.limit(page_size)
.limit(self._DEFAULT_PAGE_SIZE)
).all()
)
next_page_offset = offset + page_size
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
next_page_offset = offset + self._DEFAULT_PAGE_SIZE
next_page_token_dict = {
self._PAGE_TOKEN_OFFSET_KEY: next_page_offset,
self._PAGE_TOKEN_FILTER_KEY: next_page_filter_value,
}
next_page_token = _encode_page_token(next_page_token_dict)
if len(pipeline_runs) < page_size:
if len(pipeline_runs) < self._DEFAULT_PAGE_SIZE:
next_page_token = None

def create_pipeline_run_response(
pipeline_run: bts.PipelineRun,
) -> PipelineRunResponse:
response = PipelineRunResponse.from_db(pipeline_run)
if include_pipeline_names:
pipeline_name = None
extra_data = pipeline_run.extra_data or {}
if self.PIPELINE_NAME_EXTRA_DATA_KEY in extra_data:
pipeline_name = extra_data[self.PIPELINE_NAME_EXTRA_DATA_KEY]
else:
execution_node = session.get(
bts.ExecutionNode, pipeline_run.root_execution_id
)
if execution_node:
task_spec = structures.TaskSpec.from_json_dict(
execution_node.task_spec
)
component_spec = task_spec.component_ref.spec
if component_spec:
pipeline_name = component_spec.name
response.pipeline_name = pipeline_name
if include_execution_stats:
execution_status_stats = self._calculate_execution_status_stats(
session=session, root_execution_id=pipeline_run.root_execution_id
)
response.execution_status_stats = {
status.value: count
for status, count in execution_status_stats.items()
}
return response

return ListPipelineJobsResponse(
pipeline_runs=[
create_pipeline_run_response(pipeline_run)
self._create_pipeline_run_response(
session=session,
pipeline_run=pipeline_run,
include_pipeline_names=include_pipeline_names,
include_execution_stats=include_execution_stats,
)
for pipeline_run in pipeline_runs
],
next_page_token=next_page_token,
)

def _create_pipeline_run_response(
    self,
    *,
    session: orm.Session,
    pipeline_run: bts.PipelineRun,
    include_pipeline_names: bool,
    include_execution_stats: bool,
) -> PipelineRunResponse:
    """Convert a DB pipeline run row into an API response.

    Optionally enriches the response with the pipeline name (taken from the
    run's extra_data when present, otherwise derived from the root
    execution's component spec) and with per-status execution counts.
    """
    response = PipelineRunResponse.from_db(pipeline_run)

    if include_pipeline_names:
        name = None
        extras = pipeline_run.extra_data or {}
        if self._PIPELINE_NAME_EXTRA_DATA_KEY in extras:
            name = extras[self._PIPELINE_NAME_EXTRA_DATA_KEY]
        else:
            # Fall back to the component spec of the root execution node.
            root_node = session.get(
                bts.ExecutionNode, pipeline_run.root_execution_id
            )
            if root_node:
                spec = structures.TaskSpec.from_json_dict(
                    root_node.task_spec
                ).component_ref.spec
                if spec:
                    name = spec.name
        response.pipeline_name = name

    if include_execution_stats:
        stats = self._calculate_execution_status_stats(
            session=session, root_execution_id=pipeline_run.root_execution_id
        )
        response.execution_status_stats = {
            status.value: count for status, count in stats.items()
        }

    return response

def _calculate_execution_status_stats(
self, session: orm.Session, root_execution_id: bts.IdType
) -> dict[bts.ContainerExecutionStatus, int]:
Expand Down Expand Up @@ -316,7 +304,7 @@ def set_annotation(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (user_name != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
Expand All @@ -338,7 +326,7 @@ def delete_annotation(
):
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise ItemNotFoundError(f"Pipeline run {id} not found.")
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
if not skip_user_check and (user_name != pipeline_run.created_by):
raise errors.PermissionError(
f"The pipeline run {id} was started by {pipeline_run.created_by} and cannot be changed by {user_name}"
Expand All @@ -349,6 +337,64 @@ def delete_annotation(
session.commit()


def _resolve_filter_value(
    *,
    filter: str | None,
    page_token: str | None,
) -> tuple[str | None, int]:
    """Decode page_token and return the effective (filter_value, offset).

    A non-empty page_token takes precedence: the filter stored inside the
    token replaces the raw `filter` argument, so the resolved filter is
    carried forward across pages.
    """
    token_data = _decode_page_token(page_token)
    start_offset = token_data.get(
        PipelineRunsApiService_Sql._PAGE_TOKEN_OFFSET_KEY, 0
    )
    effective_filter = (
        token_data.get(PipelineRunsApiService_Sql._PAGE_TOKEN_FILTER_KEY)
        if page_token
        else filter
    )
    return effective_filter, start_offset


def _build_filter_where_clauses(
    *,
    filter_value: str | None,
    current_user: str | None,
) -> tuple[list[sql.ColumnElement], str | None]:
    """Parse a filter string into SQLAlchemy WHERE clauses.

    Returns (where_clauses, next_page_filter_value). The second value is the
    filter string with shorthand values resolved (e.g. "created_by:me" becomes
    "created_by:alice@example.com") so it can be embedded in the next page token.

    Raises:
        NotImplementedError: For text search ("_text") or any unsupported
            filter key.
    """
    where_clauses: list[sql.ColumnElement] = []
    parsed_filter = _parse_filter(filter_value) if filter_value else {}
    for key, value in parsed_filter.items():
        if key == "_text":
            raise NotImplementedError("Text search is not implemented yet.")
        elif key == "created_by":
            if value == "me":
                # NOTE(review): when there is no authenticated user,
                # `created_by:me` degrades to matching runs with no creator
                # (created_by IS NULL). Consider raising an error for
                # unauthenticated callers instead — confirm intended behavior.
                if current_user is None:
                    current_user = ""
                value = current_user
                # TODO: Maybe make this a bit more robust.
                # We need to change the filter since it goes into the next_page_token.
                filter_value = filter_value.replace(
                    "created_by:me", f"created_by:{current_user}"
                )
            if value:
                where_clauses.append(bts.PipelineRun.created_by == value)
            else:
                # `.is_(None)` is the idiomatic SQLAlchemy NULL test; it emits
                # the same `IS NULL` SQL as `== None` but is lint-clean (E711).
                where_clauses.append(bts.PipelineRun.created_by.is_(None))
        else:
            raise NotImplementedError(f"Unsupported filter {filter_value}.")
    return where_clauses, filter_value


def _decode_page_token(page_token: str) -> dict[str, Any]:
return json.loads(base64.b64decode(page_token)) if page_token else {}

Expand Down Expand Up @@ -524,7 +570,7 @@ class ExecutionNodesApiService_Sql:
def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse:
execution_node = session.get(bts.ExecutionNode, id)
if execution_node is None:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")

parent_pipeline_run_id = session.scalar(
sql.select(bts.PipelineRun.id).where(
Expand Down Expand Up @@ -676,7 +722,7 @@ def get_container_execution_state(
) -> GetContainerExecutionStateResponse:
execution = session.get(bts.ExecutionNode, id)
if not execution:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
container_execution = execution.container_execution
if not container_execution:
raise RuntimeError(
Expand All @@ -696,7 +742,7 @@ def get_artifacts(
if not session.scalar(
sql.select(sql.exists().where(bts.ExecutionNode.id == id))
):
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")

input_artifact_links = session.scalars(
sql.select(bts.InputArtifactLink)
Expand Down Expand Up @@ -742,7 +788,7 @@ def get_container_execution_log(
) -> GetContainerExecutionLogResponse:
execution = session.get(bts.ExecutionNode, id)
if not execution:
raise ItemNotFoundError(f"Execution with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Execution with {id=} does not exist.")
container_execution = execution.container_execution
execution_extra_data = execution.extra_data or {}
system_error_exception_full = execution_extra_data.get(
Expand Down Expand Up @@ -829,7 +875,9 @@ def stream_container_execution_log(
) -> typing.Iterator[str]:
execution = session.get(bts.ExecutionNode, execution_id)
if not execution:
raise ItemNotFoundError(f"Execution with {execution_id=} does not exist.")
raise errors.ItemNotFoundError(
f"Execution with {execution_id=} does not exist."
)
container_execution = execution.container_execution
if not container_execution:
raise ApiServiceError(
Expand Down Expand Up @@ -970,7 +1018,7 @@ class ArtifactNodesApiService_Sql:
def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse:
artifact_node = session.get(bts.ArtifactNode, id)
if artifact_node is None:
raise ItemNotFoundError(f"Artifact with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Artifact with {id=} does not exist.")
artifact_data = artifact_node.artifact_data
result = GetArtifactInfoResponse(id=artifact_node.id)
if artifact_data:
Expand All @@ -986,7 +1034,7 @@ def get_signed_artifact_url(
.where(bts.ArtifactNode.id == id)
)
if not artifact_data:
raise ItemNotFoundError(f"Artifact node with {id=} does not exist.")
raise errors.ItemNotFoundError(f"Artifact node with {id=} does not exist.")
if not artifact_data.uri:
raise ValueError(f"Artifact node with {id=} does not have artifact URI.")
if artifact_data.is_dir:
Expand Down
Loading