Empty file added ablate/core/__init__.py
2 changes: 2 additions & 0 deletions ablate/core/cli/main.py
@@ -0,0 +1,2 @@
def main() -> None: # TODO
pass
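main is only a stub in this diff. A quick way to smoke-test the entry point locally while the TODO is unfilled (the __main__ guard below is illustrative, not part of this PR):

from ablate.core.cli.main import main

if __name__ == "__main__":
    main()  # currently a no-op until the TODO is implemented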
4 changes: 4 additions & 0 deletions ablate/core/types/__init__.py
@@ -0,0 +1,4 @@
from .runs import GroupedRun, Run


__all__ = ["GroupedRun", "Run"]
33 changes: 33 additions & 0 deletions ablate/core/types/runs.py
@@ -0,0 +1,33 @@
from typing import Any, Dict, List, Tuple

from pydantic import BaseModel


class Run(BaseModel):
"""A single run of an experiment.

Args:
id: Unique identifier for the run.
params: Parameters used for the run.
metrics: Metrics recorded during the run.
        temporal: Per-metric time series recorded during the run, as (step, value) pairs.
"""

id: str
params: Dict[str, Any]
metrics: Dict[str, float]
    temporal: Dict[str, List[Tuple[int, float]]] = {}  # pydantic copies this mutable default per instance


class GroupedRun(BaseModel):
"""A collection of runs grouped by a key-value pair.

Args:
key: Key used to group the runs.
value: Value used to group the runs.
runs: List of runs that belong to this group.
"""

key: str
value: str
runs: List[Run]
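Taken together, the two models compose: a GroupedRun holds the subset of Runs that share one parameter value. A minimal sketch of how grouping might look (group_by_param is a hypothetical helper, not part of this diff):

from collections import defaultdict
from typing import DefaultDict, List

from ablate.core.types import GroupedRun, Run


def group_by_param(runs: List[Run], key: str) -> List[GroupedRun]:
    # Hypothetical helper: bucket runs by the string value of one param.
    buckets: DefaultDict[str, List[Run]] = defaultdict(list)
    for run in runs:
        buckets[str(run.params.get(key))].append(run)
    return [GroupedRun(key=key, value=value, runs=group) for value, group in buckets.items()]


runs = [
    Run(id="a", params={"lr": 0.1}, metrics={"acc": 0.81}),
    Run(id="b", params={"lr": 0.1}, metrics={"acc": 0.83}),
    Run(id="c", params={"lr": 0.01}, metrics={"acc": 0.78}),
]
assert {g.value for g in group_by_param(runs, "lr")} == {"0.1", "0.01"}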
56 changes: 56 additions & 0 deletions tests/core/test_types.py
@@ -0,0 +1,56 @@
from pydantic import ValidationError
import pytest

from ablate.core.types import GroupedRun, Run


def test_run() -> None:
run = Run(id="test_run", params={"param1": 1}, metrics={"metric1": 0.5})
assert run.id == "test_run"
assert run.params == {"param1": 1}
assert run.metrics == {"metric1": 0.5}
assert run.temporal == {}


def test_temporal_data() -> None:
run = Run(
id="test_run",
params={"param1": 1},
metrics={"metric1": 0.5},
temporal={"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]},
)
assert run.temporal == {"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]}
assert run.temporal["metric1"][0] == (0, 0.0)


def test_invalid_metrics_type() -> None:
with pytest.raises(ValidationError):
Run(id="bad", params={}, metrics="not a dict")


def test_grouped_run() -> None:
run1 = Run(id="run1", params={"param1": 1}, metrics={"metric1": 0.5})
run2 = Run(id="run2", params={"param2": 2}, metrics={"metric2": 0.8})
grouped_run = GroupedRun(key="group_key", value="group_value", runs=[run1, run2])
assert grouped_run.key == "group_key"
assert grouped_run.value == "group_value"
assert len(grouped_run.runs) == 2
assert grouped_run.runs[0].id == "run1"
assert grouped_run.runs[1].id == "run2"


def test_invalid_runs_type() -> None:
with pytest.raises(ValidationError):
GroupedRun(key="group_key", value="group_value", runs=["not a run"])


def test_run_roundtrip_serialization() -> None:
run = Run(
id="run1",
params={"x": 1},
metrics={"acc": 0.9},
temporal={"loss": [(0, 0.1), (1, 0.05)]},
)
data = run.model_dump()
recovered = Run(**data)
assert recovered == run
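The round-trip test above stays in Python objects; pydantic v2 (which the model_dump call implies) also supports a JSON round trip. A small self-contained sketch, assuming the same models:

from ablate.core.types import Run

run = Run(id="run1", params={"x": 1}, metrics={"acc": 0.9}, temporal={"loss": [(0, 0.1)]})
payload = run.model_dump_json()                 # tuples in `temporal` become JSON arrays
assert Run.model_validate_json(payload) == run  # validation turns the arrays back into tuples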