diff --git a/cloud_pipelines_backend/api_router.py b/cloud_pipelines_backend/api_router.py index d40c817..57340ee 100644 --- a/cloud_pipelines_backend/api_router.py +++ b/cloud_pipelines_backend/api_router.py @@ -169,7 +169,7 @@ def inject_user_name(func: typing.Callable, parameter_name: str = "user_name"): SessionDep = typing.Annotated[orm.Session, fastapi.Depends(get_session)] - def inject_session_dependency(func: typing.Callable) -> typing.Callable: + def inject_session_dependency(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]: return replace_annotations( func, original_annotation=orm.Session, new_annotation=SessionDep ) @@ -521,13 +521,18 @@ def admin_set_execution_node_status( ): with session.begin(): execution_node = session.get(backend_types_sql.ExecutionNode, id) + if execution_node is None: + raise fastapi.HTTPException(status_code=404, detail=f"ExecutionNode {id} not found") execution_node.container_execution_status = status @router.get("/api/admin/sql_engine_connection_pool_status") async def get_sql_engine_connection_pool_status( session: typing.Annotated[orm.Session, fastapi.Depends(get_session)], ) -> str: - return session.get_bind().pool.status() + bind = session.get_bind() + if not isinstance(bind, sqlalchemy.Engine): + raise TypeError("Expected Engine, got Connection") + return bind.pool.status() # # Needs to be called after all routes have been added to the router # app.include_router(router) @@ -564,14 +569,14 @@ async def get_sql_engine_connection_pool_status( # https://fastapi.tiangolo.com/tutorial/dependencies # We'll replace the original function annotations. # It works!!! -def replace_annotations(func, original_annotation, new_annotation): +def replace_annotations(func: typing.Callable[..., typing.Any], original_annotation: type, new_annotation: type) -> typing.Callable[..., typing.Any]: for name in list(func.__annotations__): if func.__annotations__[name] == original_annotation: func.__annotations__[name] = new_annotation return func -def add_parameter_annotation_metadata(func, parameter_name: str, annotation_metadata): +def add_parameter_annotation_metadata(func: typing.Callable[..., typing.Any], parameter_name: str, annotation_metadata: typing.Any) -> typing.Callable[..., typing.Any]: func.__annotations__[parameter_name] = typing.Annotated[ func.__annotations__[parameter_name], annotation_metadata ] diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 6877c36..15148a4 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -87,7 +87,10 @@ def create( # TODO: Load and validate all components # TODO: Fetch missing components and populate component specs - pipeline_name = root_task.component_ref.spec.name + component_spec = root_task.component_ref.spec + if component_spec is None: + raise ApiServiceError("component_ref.spec is required") + pipeline_name = component_spec.name with session.begin(): root_execution_node = _recursively_create_all_executions_and_artifacts_root( @@ -172,7 +175,7 @@ def list( include_pipeline_names: bool = False, include_execution_stats: bool = False, ) -> ListPipelineJobsResponse: - page_token_dict = _decode_page_token(page_token) + page_token_dict = _decode_page_token(page_token) if page_token else {} OFFSET_KEY = "offset" offset = page_token_dict.get(OFFSET_KEY, 0) page_size = 10 @@ -195,7 +198,7 @@ def list( value = current_user # TODO: Maybe make this a bit more robust. # We need to change the filter since it goes into the next_page_token. - filter = filter.replace( + filter = (filter or "").replace( "created_by:me", f"created_by:{current_user}" ) if value: @@ -215,7 +218,7 @@ def list( ) next_page_offset = offset + page_size next_page_token_dict = {OFFSET_KEY: next_page_offset, FILTER_KEY: filter} - next_page_token = _encode_page_token(next_page_token_dict) + next_page_token: str | None = _encode_page_token(next_page_token_dict) if len(pipeline_runs) < page_size: next_page_token = None @@ -281,8 +284,9 @@ def _calculate_execution_status_stats( ) ) execution_status_stat_rows = session.execute(query).tuples().all() - execution_status_stats = dict(execution_status_stat_rows) - + execution_status_stats: dict[bts.ContainerExecutionStatus, int] = { + status: count for status, count in execution_status_stat_rows if status is not None + } return execution_status_stats def list_annotations( @@ -665,7 +669,7 @@ def get_graph_execution_state( status_stats = child_execution_status_stats.setdefault( child_execution_id, {} ) - status_stats[status.value] = count + status_stats[status.value if status is not None else "UNKNOWN"] = count return GetGraphExecutionStateResponse( child_execution_status_stats=child_execution_status_stats, ) @@ -876,7 +880,7 @@ def _read_container_execution_log_from_uri( if storage_provider: # TODO: Switch to storage_provider.parse_uri_get_accessor uri_accessor = storage_provider.make_uri(log_uri) - log_text = uri_accessor.get_reader().download_as_text() + log_text: str = uri_accessor.get_reader().download_as_text() return log_text if "://" not in log_uri: @@ -889,14 +893,14 @@ def _read_container_execution_log_from_uri( gcs_client = storage.Client() blob = storage.Blob.from_string(log_uri, client=gcs_client) - log_text = blob.download_as_text() + log_text = str(blob.download_as_text()) return log_text elif log_uri.startswith("hf://"): from cloud_pipelines_backend.storage_providers import huggingface_repo_storage storage_provider = huggingface_repo_storage.HuggingFaceRepoStorageProvider() uri_accessor = storage_provider.parse_uri_get_accessor(uri_string=log_uri) - log_text = uri_accessor.get_reader().download_as_text() + log_text = str(uri_accessor.get_reader().download_as_text()) return log_text else: raise NotImplementedError( @@ -1378,7 +1382,7 @@ def _recursively_create_all_executions_and_artifacts( ] = {} for input_spec in child_component_spec.inputs or []: input_argument = (child_task_spec.arguments or {}).get(input_spec.name) - input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None + child_input_artifact_node: _ArtifactNodeOrDynamicDataType | None = None if input_argument is None and not input_spec.optional: # Not failing on unconnected required input if there is a default value if input_spec.default is None: @@ -1390,21 +1394,21 @@ def _recursively_create_all_executions_and_artifacts( if input_argument is None: pass elif isinstance(input_argument, structures.GraphInputArgument): - input_artifact_node = input_artifact_nodes.get( + child_input_artifact_node = input_artifact_nodes.get( input_argument.graph_input.input_name ) - if input_artifact_node is None: + if child_input_artifact_node is None: # Warning: unconnected upstream # TODO: Support using upstream graph input's default value when needed for required input (non-trivial feature). # Feature: "Unconnected upstream with optional default" pass elif isinstance(input_argument, structures.TaskOutputArgument): task_output_source = input_argument.task_output - input_artifact_node = task_output_artifact_nodes[ + child_input_artifact_node = task_output_artifact_nodes[ task_output_source.task_id ][task_output_source.output_name] elif isinstance(input_argument, str): - input_artifact_node = ( + child_input_artifact_node = ( _construct_constant_artifact_node_and_add_to_session( session=session, value=input_argument, @@ -1412,7 +1416,7 @@ def _recursively_create_all_executions_and_artifacts( ) ) # Not adding constant inputs to the DB. We'll add them to `ExecutionNode.constant_arguments` - # input_artifact_node = ( + # child_input_artifact_node = ( # _construct_constant_artifact_node( # value=input_argument, # artifact_type=input_spec.type, @@ -1420,14 +1424,14 @@ def _recursively_create_all_executions_and_artifacts( # ) elif isinstance(input_argument, structures.DynamicDataArgument): # We'll deal with dynamic data (e.g. secrets) when launching the container. - input_artifact_node = input_argument + child_input_artifact_node = input_argument else: raise ApiServiceError( f"Unexpected task argument: {input_spec.name}={input_argument}. {child_task_spec=}" ) - if input_artifact_node: + if child_input_artifact_node: child_task_input_artifact_nodes[input_spec.name] = ( - input_artifact_node + child_input_artifact_node ) # Creating child task nodes and their output artifacts @@ -1488,7 +1492,7 @@ def _toposort_tasks( ) # Topologically sorting tasks to detect cycles - task_dependents = {k: {} for k in task_dependencies} + task_dependents: dict[str, dict[str, bool]] = {k: {} for k in task_dependencies} for task_id, dependencies in task_dependencies.items(): for dependency in dependencies: task_dependents[dependency][task_id] = True diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index f642383..abe86a0 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -64,7 +64,7 @@ def generate_unique_id() -> str: # # Needed to put a union type into DB # class SqlIOTypeStruct(_BaseModel): -# type: structures.TypeSpecType +# type_: structures.TypeSpecType # renamed to avoid mypy type-comment parsing # No. We'll represent TypeSpecType as name:str + properties:dict # Supported cases: # * type: "name" diff --git a/cloud_pipelines_backend/component_library_api_server.py b/cloud_pipelines_backend/component_library_api_server.py index 94560a9..f7e632f 100644 --- a/cloud_pipelines_backend/component_library_api_server.py +++ b/cloud_pipelines_backend/component_library_api_server.py @@ -158,7 +158,7 @@ class PublishedComponentResponse: # text: str | None = None @staticmethod - def from_db(published_component: PublishedComponentRow): + def from_db(published_component: PublishedComponentRow) -> "PublishedComponentResponse": return PublishedComponentResponse( digest=published_component.digest, published_by=published_component.published_by, diff --git a/cloud_pipelines_backend/component_structures.py b/cloud_pipelines_backend/component_structures.py index 1a6f287..609f69f 100644 --- a/cloud_pipelines_backend/component_structures.py +++ b/cloud_pipelines_backend/component_structures.py @@ -312,7 +312,7 @@ class TaskOutputReference(_BaseModel): output_name: str task_id: str # Used for linking to the upstream task in serialized component file. - # type: Optional[TypeSpecType] = None # Can be used to override the reference data type + # type_: Optional[TypeSpecType] = None # Can be used to override the reference data type @dataclasses.dataclass diff --git a/cloud_pipelines_backend/database_ops.py b/cloud_pipelines_backend/database_ops.py index eb47855..4c52cac 100644 --- a/cloud_pipelines_backend/database_ops.py +++ b/cloud_pipelines_backend/database_ops.py @@ -1,3 +1,5 @@ +from typing import Any + import sqlalchemy from . import backend_types_sql as bts @@ -29,7 +31,7 @@ def create_db_engine( # Using PyMySQL instead of missing MySQLdb database_uri = database_uri.replace("mysql://", "mysql+pymysql://") - create_engine_kwargs = {} + create_engine_kwargs: dict[str, Any] = {} if database_uri == "sqlite://": create_engine_kwargs["poolclass"] = sqlalchemy.pool.StaticPool @@ -74,6 +76,8 @@ def migrate_db(db_engine: sqlalchemy.Engine): # bts.ExecutionNode.__table__.indexes.remove(index1) # Or we need to avoid calling the Index constructor. - for index in bts.ExecutionNode.__table__.indexes: + table = bts.ExecutionNode.__table__ + assert isinstance(table, sqlalchemy.Table) + for index in table.indexes: if index.name == "ix_execution_node_container_execution_cache_key": index.create(db_engine, checkfirst=True) diff --git a/cloud_pipelines_backend/instrumentation/api_tracing.py b/cloud_pipelines_backend/instrumentation/api_tracing.py index 14ab3b2..2301270 100644 --- a/cloud_pipelines_backend/instrumentation/api_tracing.py +++ b/cloud_pipelines_backend/instrumentation/api_tracing.py @@ -8,7 +8,7 @@ import logging import secrets -from starlette.middleware.base import BaseHTTPMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request from starlette.responses import Response @@ -42,7 +42,7 @@ class RequestContextMiddleware(BaseHTTPMiddleware): This ensures all logs during the request processing include the same request_id. """ - async def dispatch(self, request: Request, call_next) -> Response: + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: """Process each request with a new request_id. Args: diff --git a/cloud_pipelines_backend/instrumentation/otel_tracing.py b/cloud_pipelines_backend/instrumentation/otel_tracing.py index c2a2191..4f86f94 100644 --- a/cloud_pipelines_backend/instrumentation/otel_tracing.py +++ b/cloud_pipelines_backend/instrumentation/otel_tracing.py @@ -58,6 +58,7 @@ def setup_api_tracing(app: fastapi.FastAPI) -> None: _validate_otel_config(otel_endpoint, otel_protocol) + otel_exporter: otel_grpc_trace_exporter.OTLPSpanExporter | otel_http_trace_exporter.OTLPSpanExporter if otel_protocol == _OTEL_PROTOCOL_GRPC: otel_exporter = otel_grpc_trace_exporter.OTLPSpanExporter( endpoint=otel_endpoint diff --git a/cloud_pipelines_backend/launchers/container_component_utils.py b/cloud_pipelines_backend/launchers/container_component_utils.py index 09436f1..fc56331 100644 --- a/cloud_pipelines_backend/launchers/container_component_utils.py +++ b/cloud_pipelines_backend/launchers/container_component_utils.py @@ -34,11 +34,11 @@ def resolve_container_command_line( container_spec = component_spec.implementation.container # Need to preserve the order to make the kubernetes output names deterministic - output_paths = dict() + output_paths: dict[str, str] = {} input_paths = dict() inputs_consumed_by_value = {} - def expand_command_part(arg) -> str | list[str] | None: + def expand_command_part(arg: object) -> str | list[str] | None: if arg is None: return None if isinstance(arg, (str, int, float, bool)): @@ -89,7 +89,7 @@ def expand_command_part(arg) -> str | list[str] | None: arg = arg.if_structure condition_result = expand_command_part(arg.condition) condition_result_bool = ( - condition_result and condition_result.lower() == "true" + isinstance(condition_result, str) and condition_result.lower() == "true" ) result_node = arg.then_value if condition_result_bool else arg.else_value if result_node is None: @@ -107,7 +107,7 @@ def expand_command_part(arg) -> str | list[str] | None: else: raise TypeError(f"Unrecognized argument type: {arg}") - def expand_argument_list(argument_list): + def expand_argument_list(argument_list: list | None) -> list[str]: expanded_list = [] if argument_list is not None: for part in argument_list: diff --git a/cloud_pipelines_backend/launchers/huggingface_launchers.py b/cloud_pipelines_backend/launchers/huggingface_launchers.py index 5a0adcd..963f543 100644 --- a/cloud_pipelines_backend/launchers/huggingface_launchers.py +++ b/cloud_pipelines_backend/launchers/huggingface_launchers.py @@ -147,7 +147,7 @@ def get_input_path(input_name: str) -> str: ).as_posix() # hf_uri = huggingface_repo_storage.HuggingFaceRepoUri.parse(uri) download_input_uris[uri] = container_path - return container_path + return str(container_path) def get_output_path(output_name: str) -> str: uri = output_uris[output_name] @@ -387,7 +387,8 @@ def has_failed(self) -> bool: @property def started_at(self) -> datetime.datetime | None: # HF Jobs do not provide started_at, so using created_at - return self._job.created_at + created_at: datetime.datetime | None = self._job.created_at + return created_at @property def ended_at(self) -> datetime.datetime | None: @@ -404,13 +405,14 @@ def launcher_error_message(self) -> str | None: _logger.info( f"launcher_error_message: {self._id=}: {self._job.status.message=}" ) - return self._job.status.message + message: str | None = self._job.status.message + return message return None def get_log(self) -> str: if self.has_ended: try: - return ( + log_text: str = ( huggingface_repo_storage.HuggingFaceRepoStorageProvider( client=self._get_api_client() ) @@ -418,6 +420,7 @@ def get_log(self) -> str: .get_reader() .download_as_text() ) + return log_text except Exception as ex: _logger.warning( f"get_log: {self._id=}: Error getting log from URI: {self._log_uri}", @@ -434,13 +437,11 @@ def upload_log(self): pass def stream_log_lines(self) -> typing.Iterator[str]: - return ( - self._get_api_client() - .fetch_job_logs( + return iter( + self._get_api_client().fetch_job_logs( job_id=self._id, namespace=self._namespace, ) - .__iter__() ) def terminate(self): diff --git a/cloud_pipelines_backend/launchers/kubernetes_launchers.py b/cloud_pipelines_backend/launchers/kubernetes_launchers.py index 953a1c8..1d8675e 100644 --- a/cloud_pipelines_backend/launchers/kubernetes_launchers.py +++ b/cloud_pipelines_backend/launchers/kubernetes_launchers.py @@ -141,7 +141,7 @@ def __init__( pod_postprocessor: PodPostProcessor | None = None, _storage_provider: storage_provider_interfaces.StorageProvider, _create_volume_and_volume_mount: typing.Callable[ - [str, str, str, bool], + ..., tuple[k8s_client_lib.V1Volume, k8s_client_lib.V1VolumeMount], ], ): @@ -275,7 +275,7 @@ def get_input_path(input_name: str) -> str: ) volume_map[volume.name] = volume volume_mounts.append(volume_mount) - return container_path + return str(container_path) def get_output_path(output_name: str) -> str: uri = output_uris[output_name] @@ -294,7 +294,7 @@ def get_output_path(output_name: str) -> str: ) volume_map[volume.name] = volume volume_mounts.append(volume_mount) - return container_path + return str(container_path) # Resolving the command line. # Also indirectly populates volumes and volume_mounts. @@ -391,7 +391,7 @@ def __init__( pod_postprocessor: PodPostProcessor | None = None, _storage_provider: storage_provider_interfaces.StorageProvider, _create_volume_and_volume_mount: typing.Callable[ - [str, str, str, bool], + ..., tuple[k8s_client_lib.V1Volume, k8s_client_lib.V1VolumeMount], ], ): @@ -696,7 +696,8 @@ def exit_code(self) -> int | None: main_container_terminated_state = self._get_main_container_terminated_state() if main_container_terminated_state is None: return None - return main_container_terminated_state.exit_code + exit_code: int = main_container_terminated_state.exit_code + return exit_code @property def has_ended(self) -> bool: @@ -720,12 +721,14 @@ def started_at(self) -> datetime.datetime | None: main_container_state.terminated ) if terminated_state is not None: - return terminated_state.started_at + started_at: datetime.datetime | None = terminated_state.started_at + return started_at running_state: k8s_client_lib.V1ContainerStateRunning = ( main_container_state.running ) if running_state is not None: - return running_state.started_at + started_at_running: datetime.datetime | None = running_state.started_at + return started_at_running return None @property @@ -733,7 +736,8 @@ def ended_at(self) -> datetime.datetime | None: terminated_state = self._get_main_container_terminated_state() if terminated_state is None: return None - return terminated_state.finished_at + finished_at: datetime.datetime | None = terminated_state.finished_at + return finished_at @property def launcher_error_message(self) -> str | None: @@ -796,7 +800,7 @@ def get_refreshed(self) -> LaunchedKubernetesContainer: def get_log(self) -> str: launcher = self._get_launcher() core_api_client = k8s_client_lib.CoreV1Api(api_client=launcher._api_client) - return core_api_client.read_namespaced_pod_log( + log: str = core_api_client.read_namespaced_pod_log( name=self._pod_name, namespace=self._namespace, container=_MAIN_CONTAINER_NAME, @@ -806,6 +810,7 @@ def get_log(self) -> str: # stream="All", _request_timeout=launcher._request_timeout, ) + return log def upload_log(self): launcher = self._get_launcher() @@ -864,7 +869,7 @@ def __init__( pod_postprocessor: PodPostProcessor | None = None, _storage_provider: storage_provider_interfaces.StorageProvider, _create_volume_and_volume_mount: typing.Callable[ - [str, str, str, bool], + ..., tuple[k8s_client_lib.V1Volume, k8s_client_lib.V1VolumeMount], ], ): @@ -1116,7 +1121,7 @@ def exit_code(self) -> int | None: if non_zero_exit_codes: # Returning 1st non-zero exit code. # There should only be one. - return non_zero_exit_codes[0] + return int(non_zero_exit_codes[0]) # We should not reach here. Failed jobs should have at least one non-zero exit code. # But just in case, we return None instead of 0 to indicate that exit code is unknown. return None @@ -1138,7 +1143,8 @@ def started_at(self) -> datetime.datetime | None: job_status = self._debug_job.status if not job_status: return None - return job_status.start_time + start_time: datetime.datetime | None = job_status.start_time + return start_time @property def ended_at(self) -> datetime.datetime | None: @@ -1152,7 +1158,8 @@ def ended_at(self) -> datetime.datetime | None: ] if not ended_condition_times: return None - return ended_condition_times[0] + transition_time: datetime.datetime | None = ended_condition_times[0] + return transition_time @property def launcher_error_message(self) -> str | None: @@ -1243,7 +1250,7 @@ def get_refreshed(self) -> LaunchedKubernetesJob: pod_map[index_str] = pod else: for pod in pod_list_response.items: - index_str: str = pod.metadata.name + index_str = str(pod.metadata.name) pod_map[index_str] = pod new_launched_container = copy.copy(self) new_launched_container._debug_job = job @@ -1253,7 +1260,7 @@ def get_refreshed(self) -> LaunchedKubernetesJob: def _get_log_by_pod_key(self, pod_name: str) -> str: launcher = self._get_launcher() core_api_client = k8s_client_lib.CoreV1Api(api_client=launcher._api_client) - log = core_api_client.read_namespaced_pod_log( + log: str = core_api_client.read_namespaced_pod_log( name=pod_name, namespace=self._namespace, container=_MAIN_CONTAINER_NAME, @@ -1401,12 +1408,14 @@ def windows_path_to_docker_path(path: str) -> str: def _kubernetes_serialize(obj) -> dict[str, Any]: shallow_client = k8s_client_lib.ApiClient.__new__(k8s_client_lib.ApiClient) - return shallow_client.sanitize_for_serialization(obj) + result: dict[str, Any] = shallow_client.sanitize_for_serialization(obj) + return result def _kubernetes_deserialize(obj_dict: dict[str, Any], cls: type[_T]) -> _T: shallow_client = k8s_client_lib.ApiClient.__new__(k8s_client_lib.ApiClient) - return shallow_client._ApiClient__deserialize(obj_dict, cls) + result: _T = shallow_client._ApiClient__deserialize(obj_dict, cls) + return result def _update_dict_recursively(d1: dict, d2: dict): diff --git a/cloud_pipelines_backend/launchers/local_docker_launchers.py b/cloud_pipelines_backend/launchers/local_docker_launchers.py index b8cf70e..27aa9c7 100644 --- a/cloud_pipelines_backend/launchers/local_docker_launchers.py +++ b/cloud_pipelines_backend/launchers/local_docker_launchers.py @@ -85,6 +85,8 @@ def launch_container_task( log_uri: str, annotations: dict[str, Any] | None = None, ) -> "LaunchedDockerContainer": + if not isinstance(component_spec.implementation, structures.ContainerImplementation): + raise TypeError("Expected ContainerImplementation") container_spec = component_spec.implementation.container # TODO: Validate the output URIs. Don't forget about (`C:\*` and `C:/*` paths) @@ -168,7 +170,7 @@ def get_input_path(input_name: str) -> str: read_only=True, ) ) - return container_path + return str(container_path) def get_output_path(output_name: str) -> str: uri = output_uris[output_name] @@ -184,7 +186,7 @@ def get_output_path(output_name: str) -> str: read_only=False, ) ) - return container_path + return str(container_path) # Resolving the command line. # Also indirectly populates volumes and volume_mounts. @@ -281,11 +283,11 @@ def status(self) -> interfaces.ContainerStatus: def exit_code(self) -> int | None: if not self.has_ended: return None - return self._container.attrs["State"]["ExitCode"] + return int(self._container.attrs["State"]["ExitCode"]) @property def has_ended(self) -> bool: - return self._container.attrs["State"]["Status"] == "exited" + return str(self._container.attrs["State"]["Status"]) == "exited" @property def has_succeeded(self) -> bool: @@ -316,9 +318,8 @@ def launcher_error_message(self) -> str | None: return None def get_log(self) -> str: - return self._container.logs(stdout=True, stderr=True, timestamps=True).decode( - "utf-8", errors="replace" - ) + log_bytes: bytes = self._container.logs(stdout=True, stderr=True, timestamps=True) + return log_bytes.decode("utf-8", errors="replace") def upload_log(self): log = self.get_log() diff --git a/cloud_pipelines_backend/orchestrator_sql.py b/cloud_pipelines_backend/orchestrator_sql.py index b98c832..b262aab 100644 --- a/cloud_pipelines_backend/orchestrator_sql.py +++ b/cloud_pipelines_backend/orchestrator_sql.py @@ -45,7 +45,7 @@ def __init__( logs_root_uri: str, default_task_annotations: dict[str, Any] | None = None, sleep_seconds_between_queue_sweeps: float = 1.0, - output_data_purge_duration: datetime.timedelta = None, + output_data_purge_duration: datetime.timedelta | None = None, ): self._session_factory = session_factory self._launcher = launcher @@ -296,7 +296,8 @@ def internal_process_one_queued_execution( # TODO: Move the time filtering to the DB side. # TODO: Filter by `ended_at`, not `created_at`. # Note: Need `.astimezone` to avoid `TypeError: can't compare offset-naive and offset-aware datetimes` - if execution_candidate.container_execution.created_at.astimezone( + if execution_candidate.container_execution.created_at is not None + and execution_candidate.container_execution.created_at.astimezone( datetime.timezone.utc ) > data_purge_threshold_time @@ -513,7 +514,7 @@ def generate_execution_log_uri( # Handling annotations - full_annotations = {} + full_annotations: dict[str, Any] = {} _update_dict_recursive( full_annotations, copy.deepcopy(self._default_task_annotations or {}) ) @@ -525,7 +526,7 @@ def generate_execution_log_uri( ) full_annotations[common_annotations.PIPELINE_RUN_CREATED_BY_ANNOTATION_KEY] = ( - pipeline_run.created_by + pipeline_run.created_by or "" ) full_annotations[common_annotations.PIPELINE_RUN_ID_ANNOTATION_KEY] = ( pipeline_run.id @@ -758,14 +759,16 @@ def _maybe_preload_value( ) return None try: - text = data.decode("utf-8") + text: str = data.decode("utf-8") return text except Exception: pass + return None + output_artifact_data_map = container_execution.output_artifact_data_map or {} output_artifact_uris: dict[str, str] = { output_name: output_artifact_info_dict["uri"] - for output_name, output_artifact_info_dict in container_execution.output_artifact_data_map.items() + for output_name, output_artifact_info_dict in output_artifact_data_map.items() } # We need to first check that the output data exists. Otherwise `uri_reader.get_info()`` throws `IndexError` @@ -1062,7 +1065,8 @@ def _maybe_get_small_artifact_value( _logger.exception("Error during preloading small artifact values.") return None try: - text = data.decode("utf-8") + text: str = data.decode("utf-8") return text except Exception: pass + return None diff --git a/cloud_pipelines_backend/storage_providers/huggingface_repo_storage.py b/cloud_pipelines_backend/storage_providers/huggingface_repo_storage.py index 1b7bb73..146bc3d 100644 --- a/cloud_pipelines_backend/storage_providers/huggingface_repo_storage.py +++ b/cloud_pipelines_backend/storage_providers/huggingface_repo_storage.py @@ -88,7 +88,7 @@ def parse_uri_get_accessor(self, uri_string: str) -> interfaces.UriAccessor: def upload(self, source_path: str, destination_uri: HuggingFaceRepoUri): _LOGGER.debug(f"Uploading from {source_path} to {destination_uri}") - if pathlib.Path(source_path).is_dir: + if pathlib.Path(source_path).is_dir(): self._client.upload_folder( repo_type=destination_uri.repo_type, repo_id=destination_uri.repo_id, @@ -118,7 +118,7 @@ def download(self, source_uri: HuggingFaceRepoUri, destination_path: str): ) import shutil - if cache_data_path.is_dir: + if cache_data_path.is_dir(): shutil.copytree(cache_data_path, destination_path, dirs_exist_ok=True) else: pathlib.Path(destination_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/pyproject.toml b/pyproject.toml index 06075a8..9464b1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dev = [ "pytest>=8.4.2", "pytest-asyncio>=0.25.2", "ruff>=0.11", + "types-pyyaml>=6.0.12.20250915", ] huggingface = [ "huggingface-hub[oauth]>=0.35.3", @@ -56,9 +57,12 @@ known-first-party = ["cloud_pipelines", "cloud_pipelines_backend"] [tool.mypy] python_version = "3.10" +namespace_packages = true +explicit_package_bases = true warn_return_any = true warn_unused_configs = true ignore_missing_imports = true +exclude = ["build/"] [tool.setuptools.packages.find] include = ["cloud_pipelines_backend*"] diff --git a/start_local.py b/start_local.py index 52a3278..a0c1d0e 100644 --- a/start_local.py +++ b/start_local.py @@ -49,7 +49,7 @@ # endregion # region: Orchestrator configuration -default_task_annotations = {} +default_task_annotations: dict[str, str] = {} sleep_seconds_between_queue_sweeps: float = 1.0 # endregion diff --git a/uv.lock b/uv.lock index 5419396..7181220 100644 --- a/uv.lock +++ b/uv.lock @@ -276,6 +276,7 @@ dev = [ { name = "pytest" }, { name = "pytest-asyncio" }, { name = "ruff" }, + { name = "types-pyyaml" }, ] huggingface = [ { name = "huggingface-hub", extra = ["oauth"] }, @@ -301,6 +302,7 @@ dev = [ { name = "pytest", specifier = ">=8.4.2" }, { name = "pytest-asyncio", specifier = ">=0.25.2" }, { name = "ruff", specifier = ">=0.11" }, + { name = "types-pyyaml", specifier = ">=6.0.12.20250915" }, ] huggingface = [{ name = "huggingface-hub", extras = ["oauth"], specifier = ">=0.35.3" }] @@ -1797,6 +1799,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/42/3efaf858001d2c2913de7f354563e3a3a2f0decae3efe98427125a8f441e/typer-0.16.0-py3-none-any.whl", hash = "sha256:1f79bed11d4d02d4310e3c1b7ba594183bcedb0ac73b27a9e5f28f6fb5b98855", size = 46317, upload-time = "2025-05-26T14:30:30.523Z" }, ] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20250915" +source = { registry = "https://pypi.python.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/69/3c51b36d04da19b92f9e815be12753125bd8bc247ba0470a982e6979e71c/types_pyyaml-6.0.12.20250915.tar.gz", hash = "sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3", size = 17522 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bd/e0/1eed384f02555dde685fff1a1ac805c1c7dcb6dd019c916fe659b1c1f9ec/types_pyyaml-6.0.12.20250915-py3-none-any.whl", hash = "sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6", size = 20338 }, +] + [[package]] name = "typing-extensions" version = "4.14.0"