diff --git a/src/tinker/lib/internal_client_holder.py b/src/tinker/lib/internal_client_holder.py index 80407c3..86be483 100644 --- a/src/tinker/lib/internal_client_holder.py +++ b/src/tinker/lib/internal_client_holder.py @@ -400,7 +400,7 @@ async def _async_cleanup(self): @staticmethod def _is_retryable_status_code(status_code: int) -> bool: - return status_code in (408, 409, 429) or (500 <= status_code < 600) + return status_code in (408, 429) or (500 <= status_code < 600) @staticmethod def _is_retryable_exception(exception: Exception) -> bool: diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index b047e2d..40cd7a6 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple from tinker import types +from tinker._exceptions import ConflictError from tinker.lib.client_connection_pool_type import ClientConnectionPoolType from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture from tinker.lib.telemetry import Telemetry, capture_exceptions @@ -55,6 +56,30 @@ } +def _matching_checkpoint_ids( + checkpoint_name: str, checkpoint_type: Literal["training", "sampler"] +) -> set[str]: + prefix = "weights" if checkpoint_type == "training" else "sampler_weights" + return { + checkpoint_name, + f"{prefix}/{checkpoint_name}", + } + + +def _find_matching_checkpoint_path( + checkpoints: list[types.Checkpoint], + checkpoint_name: str, + checkpoint_type: Literal["training", "sampler"], +) -> str | None: + matching_ids = _matching_checkpoint_ids(checkpoint_name, checkpoint_type) + for checkpoint in checkpoints: + if checkpoint.checkpoint_type != checkpoint_type: + continue + if checkpoint.checkpoint_id in matching_ids: + return checkpoint.tinker_path + return None + + class TrainingClient(TelemetryProvider): """Client for training ML models with forward/backward passes and optimization. @@ -609,7 +634,20 @@ async def _send_request(): ) async with self._take_turn(request_id): - future = await self.holder.execute_with_retries(_send_request) + try: + future = await self.holder.execute_with_retries(_send_request) + except ConflictError: + recovered_path = await self._recover_checkpoint_path_from_conflict( + checkpoint_name=name, + checkpoint_type="training", + ) + if recovered_path is None: + raise + logger.warning( + "Recovered from save_state 409 conflict by reusing checkpoint path for '%s'", + name, + ) + return types.SaveWeightsResponse(path=recovered_path) return await _APIFuture( types.SaveWeightsResponse, self.holder, @@ -656,6 +694,24 @@ async def _send_request(): queue_state_observer=self._queue_state_logger, ) + async def _recover_checkpoint_path_from_conflict( + self, + checkpoint_name: str, + checkpoint_type: Literal["training", "sampler"], + ) -> str | None: + """Resolve an existing checkpoint path after a save-name conflict.""" + + async def _send_request(): + with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client: + return await client.weights.list(model_id=self._guaranteed_model_id()) + + checkpoints_response = await self.holder.execute_with_retries(_send_request) + return _find_matching_checkpoint_path( + checkpoints_response.checkpoints, + checkpoint_name=checkpoint_name, + checkpoint_type=checkpoint_type, + ) + @capture_exceptions(fatal=True) def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]: """Load model weights from a saved checkpoint. @@ -745,7 +801,22 @@ async def _send_request(): ) async with self._take_turn(request_id): - future = await self.holder.execute_with_retries(_send_request) + try: + future = await self.holder.execute_with_retries(_send_request) + except ConflictError: + if name is None: + raise + recovered_path = await self._recover_checkpoint_path_from_conflict( + checkpoint_name=name, + checkpoint_type="sampler", + ) + if recovered_path is None: + raise + logger.warning( + "Recovered from save_weights_for_sampler 409 conflict by reusing checkpoint path for '%s'", + name, + ) + return types.SaveWeightsForSamplerResponseInternal(path=recovered_path) return await _APIFuture( types.SaveWeightsForSamplerResponseInternal, self.holder, diff --git a/tests/test_service_client.py b/tests/test_service_client.py index 3557b2b..5ba8792 100644 --- a/tests/test_service_client.py +++ b/tests/test_service_client.py @@ -4,6 +4,7 @@ import json import os +from datetime import UTC, datetime import httpx import pytest @@ -11,6 +12,12 @@ import tinker from tinker import types +from tinker.lib.internal_client_holder import InternalClientHolder +from tinker.lib.public_interfaces.training_client import ( + TrainingClient, + _find_matching_checkpoint_path, + _matching_checkpoint_ids, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -202,3 +209,95 @@ def test_create_training_client_from_state_sync_uses_public_endpoint( # Verify it uses the public endpoint (info_lite), not the full training run endpoint assert info_lite_route.called + + +def test_matching_checkpoint_ids_supports_bare_and_prefixed_names() -> None: + assert _matching_checkpoint_ids("000010", "training") == {"000010", "weights/000010"} + assert _matching_checkpoint_ids("000010", "sampler") == { + "000010", + "sampler_weights/000010", + } + + +def test_find_matching_checkpoint_path_filters_by_type_and_name() -> None: + checkpoints = [ + types.Checkpoint( + checkpoint_id="weights/000010", + checkpoint_type="training", + time=datetime.now(UTC), + tinker_path="tinker://run-1/weights/000010", + ), + types.Checkpoint( + checkpoint_id="sampler_weights/000010", + checkpoint_type="sampler", + time=datetime.now(UTC), + tinker_path="tinker://run-1/sampler_weights/000010", + ), + ] + + assert ( + _find_matching_checkpoint_path(checkpoints, "000010", "training") + == "tinker://run-1/weights/000010" + ) + assert ( + _find_matching_checkpoint_path(checkpoints, "000010", "sampler") + == "tinker://run-1/sampler_weights/000010" + ) + + +def test_retryable_status_codes_do_not_include_409() -> None: + assert InternalClientHolder._is_retryable_status_code(408) + assert not InternalClientHolder._is_retryable_status_code(409) + assert InternalClientHolder._is_retryable_status_code(429) + + +class _DummyHolder: + def __init__(self, checkpoints: list[types.Checkpoint]): + self._response = types.CheckpointsListResponse(checkpoints=checkpoints) + + async def execute_with_retries(self, _func): + return self._response + + +@pytest.mark.asyncio +async def test_recover_checkpoint_path_from_conflict_returns_matching_path() -> None: + client = TrainingClient.__new__(TrainingClient) + client.holder = _DummyHolder( + checkpoints=[ + types.Checkpoint( + checkpoint_id="weights/000010", + checkpoint_type="training", + time=datetime.now(UTC), + tinker_path="tinker://run-1/weights/000010", + ) + ] + ) + client._guaranteed_model_id = lambda: "run-1" # type: ignore[method-assign] + + recovered = await client._recover_checkpoint_path_from_conflict( + checkpoint_name="000010", + checkpoint_type="training", + ) + assert recovered == "tinker://run-1/weights/000010" + + +@pytest.mark.asyncio +async def test_recover_checkpoint_path_from_conflict_returns_none_when_missing() -> None: + client = TrainingClient.__new__(TrainingClient) + client.holder = _DummyHolder( + checkpoints=[ + types.Checkpoint( + checkpoint_id="weights/000011", + checkpoint_type="training", + time=datetime.now(UTC), + tinker_path="tinker://run-1/weights/000011", + ) + ] + ) + client._guaranteed_model_id = lambda: "run-1" # type: ignore[method-assign] + + recovered = await client._recover_checkpoint_path_from_conflict( + checkpoint_name="000010", + checkpoint_type="training", + ) + assert recovered is None