Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions src/tinker/types/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,45 @@ class ParsedCheckpointTinkerPath(BaseModel):

@classmethod
def from_tinker_path(cls, tinker_path: str) -> "ParsedCheckpointTinkerPath":
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath"""
"""Parse a tinker path to an instance of ParsedCheckpointTinkerPath.

Supports two formats:
- Standard: tinker://run-id/weights/0001
- With suffix: tinker://run-id:suffix/weights/0001
(e.g., tinker://run-id:train:0/weights/0001)
"""
if not tinker_path.startswith("tinker://"):
raise ValueError(f"Invalid tinker path: {tinker_path}")
parts = tinker_path[9:].split("/")
if len(parts) != 3:
raise ValueError(f"Invalid tinker path: {tinker_path}")
if parts[1] not in ["weights", "sampler_weights"]:
raise ValueError(f"Invalid tinker path: {tinker_path}")
checkpoint_type = "training" if parts[1] == "weights" else "sampler"

# Remove the tinker:// prefix
path_parts = tinker_path[9:]

# Split into segments
# Format: run_id_with_type/checkpoint_type/checkpoint_id
segments = path_parts.split("/")
if len(segments) != 3:
raise ValueError(
f"Invalid tinker path: {tinker_path}. "
f"Expected: tinker://run-id/weights/0001 or tinker://run-id:train:0/weights/0001"
)

run_id_with_type = segments[0]
checkpoint_type_segment = segments[1]
checkpoint_id = segments[2]

# Validate checkpoint type
if checkpoint_type_segment not in ["weights", "sampler_weights"]:
raise ValueError(
f"Invalid checkpoint type: {checkpoint_type_segment}. "
f"Expected: weights or sampler_weights"
)


checkpoint_type = "training" if checkpoint_type_segment == "weights" else "sampler"

return cls(
tinker_path=tinker_path,
training_run_id=parts[0],
training_run_id=run_id_with_type,
checkpoint_type=checkpoint_type,
checkpoint_id="/".join(parts[1:]),
checkpoint_id="/".join([checkpoint_type_segment, checkpoint_id]),
)
62 changes: 62 additions & 0 deletions tests/test_checkpoint_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,65 @@ def test_after_without_run_id_error(self) -> None:
)
assert result.exit_code != 0
assert "--run-id" in self._get_error_message(result)


class TestParsedCheckpointTinkerPath:
"""Tests for ParsedCheckpointTinkerPath.from_tinker_path()."""

def test_standard_format(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id/weights/0001"
)
assert result.training_run_id == "run-id"
assert result.checkpoint_type == "training"
assert result.checkpoint_id == "weights/0001"

def test_sampler_format(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id/sampler_weights/0001"
)
assert result.training_run_id == "run-id"
assert result.checkpoint_type == "sampler"
assert result.checkpoint_id == "sampler_weights/0001"

def test_with_train_suffix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://5f2d7413-3980-502a-b012-9b7e122b3305:train:0/sampler_weights/final"
)
assert result.training_run_id == "5f2d7413-3980-502a-b012-9b7e122b3305:train:0"
assert result.checkpoint_type == "sampler"
assert result.checkpoint_id == "sampler_weights/final"

def test_with_sampler_suffix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

result = ParsedCheckpointTinkerPath.from_tinker_path(
"tinker://run-id:sampler/weights/0001"
)
assert result.training_run_id == "run-id:sampler"
assert result.checkpoint_type == "training"
assert result.checkpoint_id == "weights/0001"

def test_invalid_missing_prefix(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("run-id/weights/0001")

def test_invalid_wrong_checkpoint_type(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

with pytest.raises(ValueError, match="Invalid checkpoint type"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/invalid/0001")

def test_invalid_not_enough_parts(self) -> None:
from tinker.types.checkpoint import ParsedCheckpointTinkerPath

with pytest.raises(ValueError, match="Invalid tinker path"):
ParsedCheckpointTinkerPath.from_tinker_path("tinker://run-id/weights")