diff --git a/ablate/core/__init__.py b/ablate/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ablate/core/cli/main.py b/ablate/core/cli/main.py
new file mode 100644
index 0000000..81495e1
--- /dev/null
+++ b/ablate/core/cli/main.py
@@ -0,0 +1,2 @@
+def main() -> None:  # TODO: implement the CLI entry point
+    pass
diff --git a/ablate/core/types/__init__.py b/ablate/core/types/__init__.py
new file mode 100644
index 0000000..f4e9ef1
--- /dev/null
+++ b/ablate/core/types/__init__.py
@@ -0,0 +1,4 @@
+from .runs import GroupedRun, Run
+
+
+__all__ = ["GroupedRun", "Run"]
diff --git a/ablate/core/types/runs.py b/ablate/core/types/runs.py
new file mode 100644
index 0000000..bf1c459
--- /dev/null
+++ b/ablate/core/types/runs.py
@@ -0,0 +1,33 @@
+from typing import Any, Dict, List, Tuple
+
+from pydantic import BaseModel
+
+
+class Run(BaseModel):
+    """A single run of an experiment.
+
+    Args:
+        id: Unique identifier for the run.
+        params: Parameters used for the run.
+        metrics: Metrics recorded during the run.
+        temporal: Temporal data recorded during the run.
+    """
+
+    id: str
+    params: Dict[str, Any]
+    metrics: Dict[str, float]
+    temporal: Dict[str, List[Tuple[int, float]]] = {}
+
+
+class GroupedRun(BaseModel):
+    """A collection of runs grouped by a key-value pair.
+
+    Args:
+        key: Key used to group the runs.
+        value: Value used to group the runs.
+        runs: List of runs that belong to this group.
+    """
+
+    key: str
+    value: str
+    runs: List[Run]
diff --git a/tests/core/test_types.py b/tests/core/test_types.py
new file mode 100644
index 0000000..b789406
--- /dev/null
+++ b/tests/core/test_types.py
@@ -0,0 +1,56 @@
+from pydantic import ValidationError
+import pytest
+
+from ablate.core.types import GroupedRun, Run
+
+
+def test_run() -> None:
+    run = Run(id="test_run", params={"param1": 1}, metrics={"metric1": 0.5})
+    assert run.id == "test_run"
+    assert run.params == {"param1": 1}
+    assert run.metrics == {"metric1": 0.5}
+    assert run.temporal == {}
+
+
+def test_temporal_data() -> None:
+    run = Run(
+        id="test_run",
+        params={"param1": 1},
+        metrics={"metric1": 0.5},
+        temporal={"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]},
+    )
+    assert run.temporal == {"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]}
+    assert run.temporal["metric1"][0] == (0, 0.0)
+
+
+def test_invalid_metrics_type() -> None:
+    with pytest.raises(ValidationError):
+        Run(id="bad", params={}, metrics="not a dict")  # type: ignore[arg-type]
+
+
+def test_grouped_run() -> None:
+    run1 = Run(id="run1", params={"param1": 1}, metrics={"metric1": 0.5})
+    run2 = Run(id="run2", params={"param2": 2}, metrics={"metric2": 0.8})
+    grouped_run = GroupedRun(key="group_key", value="group_value", runs=[run1, run2])
+    assert grouped_run.key == "group_key"
+    assert grouped_run.value == "group_value"
+    assert len(grouped_run.runs) == 2
+    assert grouped_run.runs[0].id == "run1"
+    assert grouped_run.runs[1].id == "run2"
+
+
+def test_invalid_runs_type() -> None:
+    with pytest.raises(ValidationError):
+        GroupedRun(key="group_key", value="group_value", runs=["not a run"])  # type: ignore[list-item]
+
+
+def test_run_roundtrip_serialization() -> None:
+    run = Run(
+        id="run1",
+        params={"x": 1},
+        metrics={"acc": 0.9},
+        temporal={"loss": [(0, 0.1), (1, 0.05)]},
+    )
+    data = run.model_dump()
+    recovered = Run(**data)
+    assert recovered == run
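For context, a minimal usage sketch of the Run / GroupedRun types added above. This is illustrative only and not part of the diff; the group_by_param helper is hypothetical and simply buckets runs by the stringified value of one parameter.

    # Hypothetical helper: group runs by one parameter (not part of this diff).
    from collections import defaultdict
    from typing import Dict, List

    from ablate.core.types import GroupedRun, Run


    def group_by_param(runs: List[Run], key: str) -> List[GroupedRun]:
        # Bucket runs by the stringified value of the given parameter,
        # since GroupedRun.value is declared as str.
        buckets: Dict[str, List[Run]] = defaultdict(list)
        for run in runs:
            buckets[str(run.params.get(key))].append(run)
        return [
            GroupedRun(key=key, value=value, runs=members)
            for value, members in buckets.items()
        ]


    runs = [
        Run(id="a", params={"lr": 0.1}, metrics={"acc": 0.90}),
        Run(id="b", params={"lr": 0.1}, metrics={"acc": 0.92}),
        Run(id="c", params={"lr": 0.01}, metrics={"acc": 0.88}),
    ]
    for group in group_by_param(runs, "lr"):
        print(group.key, group.value, [r.id for r in group.runs])
    # lr 0.1 ['a', 'b']
    # lr 0.01 ['c']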