diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index 3604503..9544e00 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -99,6 +99,7 @@ def __init__( gcs_bucket: str, pathways_service: str, expected_tpu_instances: Mapping[Any, Any], + proxy_job_name: str | None, ): """Initializes the TPU manager.""" self.cluster = cluster @@ -111,7 +112,7 @@ def __init__( random.choices(string.ascii_lowercase + string.digits, k=5) ) user = os.environ.get("USER", "user") - self._proxy_job_name = f"isc-proxy-{user}-{suffix}" + self._proxy_job_name = proxy_job_name or f"isc-proxy-{user}-{suffix}" self._port_forward_process = None self._proxy_port = None @@ -194,6 +195,7 @@ def connect( gcs_bucket: str, pathways_service: str, expected_tpu_instances: Mapping[str, int], + proxy_job_name: str | None = None, ) -> Iterator["_ISCPathways"]: """Connects to a Pathways server if the cluster exists. If not, creates it. @@ -205,12 +207,16 @@ def connect( pathways_service: The service name and port of the Pathways head pod. expected_tpu_instances: A dictionary mapping TPU machine types to the number of instances. For example: {"tpuv6e:2x2": 2} + proxy_job_name: The name to use for the deployed proxy. If not provided, a + random name will be generated. Yields: The Pathways manager. """ + _logger.info("Validating Pathways service and TPU instances...") validators.validate_pathways_service(pathways_service) validators.validate_tpu_instances(expected_tpu_instances) + _logger.info("Validation complete.") gke_utils.fetch_cluster_credentials( cluster_name=cluster, project_id=project, location=region ) @@ -222,5 +228,6 @@ def connect( gcs_bucket=gcs_bucket, pathways_service=pathways_service, expected_tpu_instances=expected_tpu_instances, + proxy_job_name=proxy_job_name, ) as t: yield t