From 8da663a3a0d9a26d221d0d7f5c82fefa05b41530 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 10:55:35 +0200 Subject: [PATCH 1/3] add placeholder cli entry point --- ablate/core/__init__.py | 0 ablate/core/cli/main.py | 2 ++ 2 files changed, 2 insertions(+) create mode 100644 ablate/core/__init__.py create mode 100644 ablate/core/cli/main.py diff --git a/ablate/core/__init__.py b/ablate/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ablate/core/cli/main.py b/ablate/core/cli/main.py new file mode 100644 index 0000000..81495e1 --- /dev/null +++ b/ablate/core/cli/main.py @@ -0,0 +1,2 @@ +def main() -> None: # TODO + pass From 8dc1d8fb1c2f19bd6ea6717bbc280960ddefaf4b Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 11:00:03 +0200 Subject: [PATCH 2/3] add basic run types --- ablate/core/types/__init__.py | 4 ++++ ablate/core/types/runs.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 ablate/core/types/__init__.py create mode 100644 ablate/core/types/runs.py diff --git a/ablate/core/types/__init__.py b/ablate/core/types/__init__.py new file mode 100644 index 0000000..f4e9ef1 --- /dev/null +++ b/ablate/core/types/__init__.py @@ -0,0 +1,4 @@ +from .runs import GroupedRun, Run + + +__all__ = ["GroupedRun", "Run"] diff --git a/ablate/core/types/runs.py b/ablate/core/types/runs.py new file mode 100644 index 0000000..bf1c459 --- /dev/null +++ b/ablate/core/types/runs.py @@ -0,0 +1,33 @@ +from typing import Any, Dict, List, Tuple + +from pydantic import BaseModel + + +class Run(BaseModel): + """A single run of an experiment. + + Args: + id: Unique identifier for the run. + params: Parameters used for the run. + metrics: Metrics recorded during the run. + temporal: Temporal data recorded during the run. + """ + + id: str + params: Dict[str, Any] + metrics: Dict[str, float] + temporal: Dict[str, list[Tuple[int, float]]] = {} + + +class GroupedRun(BaseModel): + """A collection of runs grouped by a key-value pair. + + Args: + key: Key used to group the runs. + value: Value used to group the runs. + runs: List of runs that belong to this group. + """ + + key: str + value: str + runs: List[Run] From d54e0d328449428d5adffafe13015ab3bfe157fd Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 11:14:37 +0200 Subject: [PATCH 3/3] add tests to types --- tests/core/test_types.py | 56 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 tests/core/test_types.py diff --git a/tests/core/test_types.py b/tests/core/test_types.py new file mode 100644 index 0000000..b789406 --- /dev/null +++ b/tests/core/test_types.py @@ -0,0 +1,56 @@ +from pydantic import ValidationError +import pytest + +from ablate.core.types import GroupedRun, Run + + +def test_run() -> None: + run = Run(id="test_run", params={"param1": 1}, metrics={"metric1": 0.5}) + assert run.id == "test_run" + assert run.params == {"param1": 1} + assert run.metrics == {"metric1": 0.5} + assert run.temporal == {} + + +def test_temporal_data() -> None: + run = Run( + id="test_run", + params={"param1": 1}, + metrics={"metric1": 0.5}, + temporal={"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]}, + ) + assert run.temporal == {"metric1": [(0, 0.0), (1, 1.0), (2, 2.0)]} + assert run.temporal["metric1"][0] == (0, 0.0) + + +def test_invalid_metrics_type() -> None: + with pytest.raises(ValidationError): + Run(id="bad", params={}, metrics="not a dict") + + +def test_grouped_run() -> None: + run1 = Run(id="run1", params={"param1": 1}, metrics={"metric1": 0.5}) + run2 = Run(id="run2", params={"param2": 2}, metrics={"metric2": 0.8}) + grouped_run = GroupedRun(key="group_key", value="group_value", runs=[run1, run2]) + assert grouped_run.key == "group_key" + assert grouped_run.value == "group_value" + assert len(grouped_run.runs) == 2 + assert grouped_run.runs[0].id == "run1" + assert grouped_run.runs[1].id == "run2" + + +def test_invalid_runs_type() -> None: + with pytest.raises(ValidationError): + GroupedRun(key="group_key", value="group_value", runs=["not a run"]) + + +def test_run_roundtrip_serialization() -> None: + run = Run( + id="run1", + params={"x": 1}, + metrics={"acc": 0.9}, + temporal={"loss": [(0, 0.1), (1, 0.05)]}, + ) + data = run.model_dump() + recovered = Run(**data) + assert recovered == run