Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions cloud_pipelines_backend/api_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def inject_user_name(func: typing.Callable, parameter_name: str = "user_name"):

SessionDep = typing.Annotated[orm.Session, fastapi.Depends(get_session)]

def inject_session_dependency(func: typing.Callable) -> typing.Callable:
def inject_session_dependency(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]:
return replace_annotations(
func, original_annotation=orm.Session, new_annotation=SessionDep
)
Expand Down Expand Up @@ -521,13 +521,18 @@ def admin_set_execution_node_status(
):
with session.begin():
execution_node = session.get(backend_types_sql.ExecutionNode, id)
if execution_node is None:
raise fastapi.HTTPException(status_code=404, detail=f"ExecutionNode {id} not found")
execution_node.container_execution_status = status

@router.get("/api/admin/sql_engine_connection_pool_status")
async def get_sql_engine_connection_pool_status(
session: typing.Annotated[orm.Session, fastapi.Depends(get_session)],
) -> str:
return session.get_bind().pool.status()
bind = session.get_bind()
if not isinstance(bind, sqlalchemy.Engine):
raise TypeError("Expected Engine, got Connection")
return bind.pool.status()

# # Needs to be called after all routes have been added to the router
# app.include_router(router)
Expand Down Expand Up @@ -564,14 +569,14 @@ async def get_sql_engine_connection_pool_status(
# https://fastapi.tiangolo.com/tutorial/dependencies
# We'll replace the original function annotations.
# It works!!!
def replace_annotations(func, original_annotation, new_annotation):
def replace_annotations(func: typing.Callable[..., typing.Any], original_annotation: type, new_annotation: type) -> typing.Callable[..., typing.Any]:
for name in list(func.__annotations__):
if func.__annotations__[name] == original_annotation:
func.__annotations__[name] = new_annotation
return func


def add_parameter_annotation_metadata(func, parameter_name: str, annotation_metadata):
def add_parameter_annotation_metadata(func: typing.Callable[..., typing.Any], parameter_name: str, annotation_metadata: typing.Any) -> typing.Callable[..., typing.Any]:
func.__annotations__[parameter_name] = typing.Annotated[
func.__annotations__[parameter_name], annotation_metadata
]
Expand Down
44 changes: 24 additions & 20 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ def create(
# TODO: Load and validate all components
# TODO: Fetch missing components and populate component specs

pipeline_name = root_task.component_ref.spec.name
component_spec = root_task.component_ref.spec
if component_spec is None:
raise ApiServiceError("component_ref.spec is required")
pipeline_name = component_spec.name

with session.begin():
root_execution_node = _recursively_create_all_executions_and_artifacts_root(
Expand Down Expand Up @@ -172,7 +175,7 @@ def list(
include_pipeline_names: bool = False,
include_execution_stats: bool = False,
) -> ListPipelineJobsResponse:
page_token_dict = _decode_page_token(page_token)
page_token_dict = _decode_page_token(page_token) if page_token else {}
OFFSET_KEY = "offset"
offset = page_token_dict.get(OFFSET_KEY, 0)
page_size = 10
Expand All @@ -195,7 +198,7 @@ def list(
value = current_user
# TODO: Maybe make this a bit more robust.
# We need to change the filter since it goes into the next_page_token.
filter = filter.replace(
filter = (filter or "").replace(
"created_by:me", f"created_by:{current_user}"
)
if value:
Expand All @@ -215,7 +218,7 @@ def list(
)
next_page_offset = offset + page_size
next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter}
next_page_token = _encode_page_token(next_page_token_dict)
next_page_token: str | None = _encode_page_token(next_page_token_dict)
if len(pipeline_runs) < page_size:
next_page_token = None

Expand Down Expand Up @@ -281,8 +284,9 @@ def _calculate_execution_status_stats(
)
)
execution_status_stat_rows = session.execute(query).tuples().all()
execution_status_stats = dict(execution_status_stat_rows)

execution_status_stats: dict[bts.ContainerExecutionStatus, int] = {
status: count for status, count in execution_status_stat_rows if status is not None
}
return execution_status_stats

def list_annotations(
Expand Down Expand Up @@ -665,7 +669,7 @@ def get_graph_execution_state(
status_stats = child_execution_status_stats.setdefault(
child_execution_id, {}
)
status_stats[status.value] = count
status_stats[status.value if status is not None else "UNKNOWN"] = count
return GetGraphExecutionStateResponse(
child_execution_status_stats=child_execution_status_stats,
)
Expand Down Expand Up @@ -876,7 +880,7 @@ def _read_container_execution_log_from_uri(
if storage_provider:
# TODO: Switch to storage_provider.parse_uri_get_accessor
uri_accessor = storage_provider.make_uri(log_uri)
log_text = uri_accessor.get_reader().download_as_text()
log_text: str = uri_accessor.get_reader().download_as_text()
return log_text

if "://" not in log_uri:
Expand All @@ -889,14 +893,14 @@ def _read_container_execution_log_from_uri(

gcs_client = storage.Client()
blob = storage.Blob.from_string(log_uri, client=gcs_client)
log_text = blob.download_as_text()
log_text = str(blob.download_as_text())
return log_text
elif log_uri.startswith("hf://"):
from cloud_pipelines_backend.storage_providers import huggingface_repo_storage

storage_provider = huggingface_repo_storage.HuggingFaceRepoStorageProvider()
uri_accessor = storage_provider.parse_uri_get_accessor(uri_string=log_uri)
log_text = uri_accessor.get_reader().download_as_text()
log_text = str(uri_accessor.get_reader().download_as_text())
return log_text
else:
raise NotImplementedError(
Expand Down Expand Up @@ -1378,7 +1382,7 @@ def _recursively_create_all_executions_and_artifacts(
] = {}
for input_spec in child_component_spec.inputs or []:
input_argument = (child_task_spec.arguments or {}).get(input_spec.name)
input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None
child_input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None
if input_argument is None and not input_spec.optional:
# Not failing on unconnected required input if there is a default value
if input_spec.default is None:
Expand All @@ -1390,44 +1394,44 @@ def _recursively_create_all_executions_and_artifacts(
if input_argument is None:
pass
elif isinstance(input_argument, structures.GraphInputArgument):
input_artifact_node = input_artifact_nodes.get(
child_input_artifact_node = input_artifact_nodes.get(
input_argument.graph_input.input_name
)
if input_artifact_node is None:
if child_input_artifact_node is None:
# Warning: unconnected upstream
# TODO: Support using upstream graph input's default value when needed for required input (non-trivial feature).
# Feature: "Unconnected upstream with optional default"
pass
elif isinstance(input_argument, structures.TaskOutputArgument):
task_output_source = input_argument.task_output
input_artifact_node = task_output_artifact_nodes[
child_input_artifact_node = task_output_artifact_nodes[
task_output_source.task_id
][task_output_source.output_name]
elif isinstance(input_argument, str):
input_artifact_node = (
child_input_artifact_node = (
_construct_constant_artifact_node_and_add_to_session(
session=session,
value=input_argument,
artifact_type=input_spec.type,
)
)
# Not adding constant inputs to the DB. We'll add them to `ExecutionNode.constant_arguments`
# input_artifact_node = (
# child_input_artifact_node = (
# _construct_constant_artifact_node(
# value=input_argument,
# artifact_type=input_spec.type,
# )
# )
elif isinstance(input_argument, structures.DynamicDataArgument):
# We'll deal with dynamic data (e.g. secrets) when launching the container.
input_artifact_node = input_argument
child_input_artifact_node = input_argument
else:
raise ApiServiceError(
f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}"
)
if input_artifact_node:
if child_input_artifact_node:
child_task_input_artifact_nodes[input_spec.name] = (
input_artifact_node
child_input_artifact_node
)

# Creating child task nodes and their output artifacts
Expand Down Expand Up @@ -1488,7 +1492,7 @@ def _toposort_tasks(
)

# Topologically sorting tasks to detect cycles
task_dependents = {k: {} for k in task_dependencies}
task_dependents: dict[str, dict[str, bool]] = {k: {} for k in task_dependencies}
for task_id, dependencies in task_dependencies.items():
for dependency in dependencies:
task_dependents[dependency][task_id] = True
Expand Down
2 changes: 1 addition & 1 deletion cloud_pipelines_backend/backend_types_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def generate_unique_id() -> str:

# # Needed to put a union type into DB
# class SqlIOTypeStruct(_BaseModel):
# type: structures.TypeSpecType
# type_: structures.TypeSpecType # renamed to avoid mypy type-comment parsing
# No. We'll represent TypeSpecType as name:str + properties:dict
# Supported cases:
# * type: "name"
Expand Down
2 changes: 1 addition & 1 deletion cloud_pipelines_backend/component_library_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class PublishedComponentResponse:
# text: str | None = None

@staticmethod
def from_db(published_component: PublishedComponentRow):
def from_db(published_component: PublishedComponentRow) -> "PublishedComponentResponse":
return PublishedComponentResponse(
digest=published_component.digest,
published_by=published_component.published_by,
Expand Down
2 changes: 1 addition & 1 deletion cloud_pipelines_backend/component_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ class TaskOutputReference(_BaseModel):

output_name: str
task_id: str # Used for linking to the upstream task in serialized component file.
# type: Optional[TypeSpecType] = None # Can be used to override the reference data type
# type_: Optional[TypeSpecType] = None # Can be used to override the reference data type


@dataclasses.dataclass
Expand Down
8 changes: 6 additions & 2 deletions cloud_pipelines_backend/database_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import sqlalchemy

from . import backend_types_sql as bts
Expand Down Expand Up @@ -29,7 +31,7 @@ def create_db_engine(
# Using PyMySQL instead of missing MySQLdb
database_uri = database_uri.replace("mysql://", "mysql+pymysql://")

create_engine_kwargs = {}
create_engine_kwargs: dict[str, Any] = {}
if database_uri == "sqlite://":
create_engine_kwargs["poolclass"] = sqlalchemy.pool.StaticPool

Expand Down Expand Up @@ -74,6 +76,8 @@ def migrate_db(db_engine: sqlalchemy.Engine):
# bts.ExecutionNode.__table__.indexes.remove(index1)
# Or we need to avoid calling the Index constructor.

for index in bts.ExecutionNode.__table__.indexes:
table = bts.ExecutionNode.__table__
assert isinstance(table, sqlalchemy.Table)
for index in table.indexes:
if index.name == "ix_execution_node_container_execution_cache_key":
index.create(db_engine, checkfirst=True)
4 changes: 2 additions & 2 deletions cloud_pipelines_backend/instrumentation/api_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import logging
import secrets

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JFYI: We try to import modules, not individual attributes.

from starlette.requests import Request
from starlette.responses import Response

Expand Down Expand Up @@ -42,7 +42,7 @@ class RequestContextMiddleware(BaseHTTPMiddleware):
This ensures all logs during the request processing include the same request_id.
"""

async def dispatch(self, request: Request, call_next) -> Response:
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
"""Process each request with a new request_id.

Args:
Expand Down
1 change: 1 addition & 0 deletions cloud_pipelines_backend/instrumentation/otel_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def setup_api_tracing(app: fastapi.FastAPI) -> None:

_validate_otel_config(otel_endpoint, otel_protocol)

otel_exporter: otel_grpc_trace_exporter.OTLPSpanExporter | otel_http_trace_exporter.OTLPSpanExporter
if otel_protocol == _OTEL_PROTOCOL_GRPC:
otel_exporter = otel_grpc_trace_exporter.OTLPSpanExporter(
endpoint=otel_endpoint
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def resolve_container_command_line(
container_spec = component_spec.implementation.container

# Need to preserve the order to make the kubernetes output names deterministic
output_paths = dict()
output_paths: dict[str, str] = {}
input_paths = dict()
inputs_consumed_by_value = {}

def expand_command_part(arg) -> str | list[str] | None:
def expand_command_part(arg: object) -> str | list[str] | None:
if arg is None:
return None
if isinstance(arg, (str, int, float, bool)):
Expand Down Expand Up @@ -89,7 +89,7 @@ def expand_command_part(arg) -> str | list[str] | None:
arg = arg.if_structure
condition_result = expand_command_part(arg.condition)
condition_result_bool = (
condition_result and condition_result.lower() == "true"
isinstance(condition_result, str) and condition_result.lower() == "true"
)
result_node = arg.then_value if condition_result_bool else arg.else_value
if result_node is None:
Expand All @@ -107,7 +107,7 @@ def expand_command_part(arg) -> str | list[str] | None:
else:
raise TypeError(f"Unrecognized argument type: {arg}")

def expand_argument_list(argument_list):
def expand_argument_list(argument_list: list | None) -> list[str]:
expanded_list = []
if argument_list is not None:
for part in argument_list:
Expand Down
17 changes: 9 additions & 8 deletions cloud_pipelines_backend/launchers/huggingface_launchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def get_input_path(input_name: str) -> str:
).as_posix()
# hf_uri = huggingface_repo_storage.HuggingFaceRepoUri.parse(uri)
download_input_uris[uri] = container_path
return container_path
return str(container_path)
Copy link
Contributor

@Ark-kun Ark-kun Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

container_path is returned by .as_posix(), so it is already a str; the str() wrapper is redundant.


def get_output_path(output_name: str) -> str:
uri = output_uris[output_name]
Expand Down Expand Up @@ -387,7 +387,8 @@ def has_failed(self) -> bool:
@property
def started_at(self) -> datetime.datetime | None:
# HF Jobs do not provide started_at, so using created_at
return self._job.created_at
created_at: datetime.datetime | None = self._job.created_at
return created_at

@property
def ended_at(self) -> datetime.datetime | None:
Expand All @@ -404,20 +405,22 @@ def launcher_error_message(self) -> str | None:
_logger.info(
f"launcher_error_message: {self._id=}: {self._job.status.message=}"
)
return self._job.status.message
message: str | None = self._job.status.message
return message
return None

def get_log(self) -> str:
if self.has_ended:
try:
return (
log_text: str = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.download_as_text() already returns str, doesn't it?

huggingface_repo_storage.HuggingFaceRepoStorageProvider(
client=self._get_api_client()
)
.make_uri(self._log_uri)
.get_reader()
.download_as_text()
)
return log_text
except Exception as ex:
_logger.warning(
f"get_log: {self._id=}: Error getting log from URI: {self._log_uri}",
Expand All @@ -434,13 +437,11 @@ def upload_log(self):
pass

def stream_log_lines(self) -> typing.Iterator[str]:
return (
self._get_api_client()
.fetch_job_logs(
return iter(
self._get_api_client().fetch_job_logs(
job_id=self._id,
namespace=self._namespace,
)
.__iter__()
)

def terminate(self):
Expand Down
Loading