Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tinker/lib/internal_client_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ async def _async_cleanup(self):

@staticmethod
def _is_retryable_status_code(status_code: int) -> bool:
return status_code in (408, 409, 429) or (500 <= status_code < 600)
return status_code in (408, 429) or (500 <= status_code < 600)

@staticmethod
def _is_retryable_exception(exception: Exception) -> bool:
Expand Down
75 changes: 73 additions & 2 deletions src/tinker/lib/public_interfaces/training_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, Literal, Tuple

from tinker import types
from tinker._exceptions import ConflictError
from tinker.lib.client_connection_pool_type import ClientConnectionPoolType
from tinker.lib.public_interfaces.api_future import APIFuture, AwaitableConcurrentFuture
from tinker.lib.telemetry import Telemetry, capture_exceptions
Expand Down Expand Up @@ -55,6 +56,30 @@
}


def _matching_checkpoint_ids(
checkpoint_name: str, checkpoint_type: Literal["training", "sampler"]
) -> set[str]:
prefix = "weights" if checkpoint_type == "training" else "sampler_weights"
return {
checkpoint_name,
f"{prefix}/{checkpoint_name}",
}


def _find_matching_checkpoint_path(
checkpoints: list[types.Checkpoint],
checkpoint_name: str,
checkpoint_type: Literal["training", "sampler"],
) -> str | None:
matching_ids = _matching_checkpoint_ids(checkpoint_name, checkpoint_type)
for checkpoint in checkpoints:
if checkpoint.checkpoint_type != checkpoint_type:
continue
if checkpoint.checkpoint_id in matching_ids:
return checkpoint.tinker_path
return None


class TrainingClient(TelemetryProvider):
"""Client for training ML models with forward/backward passes and optimization.

Expand Down Expand Up @@ -609,7 +634,20 @@ async def _send_request():
)

async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
try:
future = await self.holder.execute_with_retries(_send_request)
except ConflictError:
recovered_path = await self._recover_checkpoint_path_from_conflict(
checkpoint_name=name,
checkpoint_type="training",
)
if recovered_path is None:
raise
logger.warning(
"Recovered from save_state 409 conflict by reusing checkpoint path for '%s'",
name,
)
return types.SaveWeightsResponse(path=recovered_path)
return await _APIFuture(
types.SaveWeightsResponse,
self.holder,
Expand Down Expand Up @@ -656,6 +694,24 @@ async def _send_request():
queue_state_observer=self._queue_state_logger,
)

async def _recover_checkpoint_path_from_conflict(
self,
checkpoint_name: str,
checkpoint_type: Literal["training", "sampler"],
) -> str | None:
"""Resolve an existing checkpoint path after a save-name conflict."""

async def _send_request():
with self.holder.aclient(ClientConnectionPoolType.TRAIN) as client:
return await client.weights.list(model_id=self._guaranteed_model_id())

checkpoints_response = await self.holder.execute_with_retries(_send_request)
return _find_matching_checkpoint_path(
checkpoints_response.checkpoints,
checkpoint_name=checkpoint_name,
checkpoint_type=checkpoint_type,
)

@capture_exceptions(fatal=True)
def load_state(self, path: str) -> APIFuture[types.LoadWeightsResponse]:
"""Load model weights from a saved checkpoint.
Expand Down Expand Up @@ -745,7 +801,22 @@ async def _send_request():
)

async with self._take_turn(request_id):
future = await self.holder.execute_with_retries(_send_request)
try:
future = await self.holder.execute_with_retries(_send_request)
except ConflictError:
if name is None:
raise
recovered_path = await self._recover_checkpoint_path_from_conflict(
checkpoint_name=name,
checkpoint_type="sampler",
)
if recovered_path is None:
raise
logger.warning(
"Recovered from save_weights_for_sampler 409 conflict by reusing checkpoint path for '%s'",
name,
)
return types.SaveWeightsForSamplerResponseInternal(path=recovered_path)
return await _APIFuture(
types.SaveWeightsForSamplerResponseInternal,
self.holder,
Expand Down
99 changes: 99 additions & 0 deletions tests/test_service_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@

import json
import os
from datetime import UTC, datetime

import httpx
import pytest
from respx import MockRouter

import tinker
from tinker import types
from tinker.lib.internal_client_holder import InternalClientHolder
from tinker.lib.public_interfaces.training_client import (
TrainingClient,
_find_matching_checkpoint_path,
_matching_checkpoint_ids,
)

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")

Expand Down Expand Up @@ -202,3 +209,95 @@ def test_create_training_client_from_state_sync_uses_public_endpoint(

# Verify it uses the public endpoint (info_lite), not the full training run endpoint
assert info_lite_route.called


def test_matching_checkpoint_ids_supports_bare_and_prefixed_names() -> None:
assert _matching_checkpoint_ids("000010", "training") == {"000010", "weights/000010"}
assert _matching_checkpoint_ids("000010", "sampler") == {
"000010",
"sampler_weights/000010",
}


def test_find_matching_checkpoint_path_filters_by_type_and_name() -> None:
checkpoints = [
types.Checkpoint(
checkpoint_id="weights/000010",
checkpoint_type="training",
time=datetime.now(UTC),
tinker_path="tinker://run-1/weights/000010",
),
types.Checkpoint(
checkpoint_id="sampler_weights/000010",
checkpoint_type="sampler",
time=datetime.now(UTC),
tinker_path="tinker://run-1/sampler_weights/000010",
),
]

assert (
_find_matching_checkpoint_path(checkpoints, "000010", "training")
== "tinker://run-1/weights/000010"
)
assert (
_find_matching_checkpoint_path(checkpoints, "000010", "sampler")
== "tinker://run-1/sampler_weights/000010"
)


def test_retryable_status_codes_do_not_include_409() -> None:
assert InternalClientHolder._is_retryable_status_code(408)
assert not InternalClientHolder._is_retryable_status_code(409)
assert InternalClientHolder._is_retryable_status_code(429)


class _DummyHolder:
def __init__(self, checkpoints: list[types.Checkpoint]):
self._response = types.CheckpointsListResponse(checkpoints=checkpoints)

async def execute_with_retries(self, _func):
return self._response


@pytest.mark.asyncio
async def test_recover_checkpoint_path_from_conflict_returns_matching_path() -> None:
client = TrainingClient.__new__(TrainingClient)
client.holder = _DummyHolder(
checkpoints=[
types.Checkpoint(
checkpoint_id="weights/000010",
checkpoint_type="training",
time=datetime.now(UTC),
tinker_path="tinker://run-1/weights/000010",
)
]
)
client._guaranteed_model_id = lambda: "run-1" # type: ignore[method-assign]

recovered = await client._recover_checkpoint_path_from_conflict(
checkpoint_name="000010",
checkpoint_type="training",
)
assert recovered == "tinker://run-1/weights/000010"


@pytest.mark.asyncio
async def test_recover_checkpoint_path_from_conflict_returns_none_when_missing() -> None:
client = TrainingClient.__new__(TrainingClient)
client.holder = _DummyHolder(
checkpoints=[
types.Checkpoint(
checkpoint_id="weights/000011",
checkpoint_type="training",
time=datetime.now(UTC),
tinker_path="tinker://run-1/weights/000011",
)
]
)
client._guaranteed_model_id = lambda: "run-1" # type: ignore[method-assign]

recovered = await client._recover_checkpoint_path_from_conflict(
checkpoint_name="000010",
checkpoint_type="training",
)
assert recovered is None