diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 903d438..4a90b01 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -11,7 +11,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: ["3.10", "3.11", "3.12", "3.13"]
+        python-version: ["3.10"] # add later: "3.11", "3.12", "3.13"
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
diff --git a/ablate/__init__.py b/ablate/__init__.py
index 3dc1f76..6ba6c11 100644
--- a/ablate/__init__.py
+++ b/ablate/__init__.py
@@ -1 +1,6 @@
+from . import queries, sources
+
+
+__all__ = ["queries", "sources"]
+
 __version__ = "0.1.0"
diff --git a/ablate/queries/__init__.py b/ablate/queries/__init__.py
new file mode 100644
index 0000000..766e727
--- /dev/null
+++ b/ablate/queries/__init__.py
@@ -0,0 +1,24 @@
+from .grouped_query import GroupedQuery
+from .query import Query
+from .selectors import (
+    AbstractMetric,
+    AbstractParam,
+    AbstractSelector,
+    Id,
+    Metric,
+    Param,
+    TemporalMetric,
+)
+
+
+__all__ = [
+    "AbstractMetric",
+    "AbstractParam",
+    "AbstractSelector",
+    "GroupedQuery",
+    "Id",
+    "Metric",
+    "Param",
+    "Query",
+    "TemporalMetric",
+]
diff --git a/ablate/queries/grouped_query.py b/ablate/queries/grouped_query.py
new file mode 100644
index 0000000..16317ca
--- /dev/null
+++ b/ablate/queries/grouped_query.py
@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from copy import deepcopy
+from typing import TYPE_CHECKING, Callable, Dict, List, Literal
+
+from ablate.core.types import GroupedRun, Run
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .query import Query  # noqa: TC004
+    from .selectors import AbstractMetric
+
+
+class GroupedQuery:
+    def __init__(self, groups: List[GroupedRun]) -> None:
+        """Query interface for manipulating grouped runs in a functional way.
+
+        All methods operate on a shallow copy of the runs in the query, so the
+        original list is never modified; the run objects are assumed to be immutable.
+
+        Args:
+            groups: A list of grouped runs to be queried.
+        """
+        self._grouped = groups
+
+    def filter(self, fn: Callable[[GroupedRun], bool]) -> GroupedQuery:
+        """Filter the grouped runs in the grouped query based on a predicate function.
+
+        Args:
+            fn: Predicate function that takes in a grouped run and returns a boolean
+                value.
+
+        Returns:
+            A new grouped query with the grouped runs that satisfy the predicate
+            function.
+        """
+        return GroupedQuery([g for g in self._grouped[:] if fn(g)])
+
+    def map(self, fn: Callable[[GroupedRun], GroupedRun]) -> GroupedQuery:
+        """Apply a function to each grouped run in the grouped query.
+
+        This function is intended to be used for modifying the grouped runs in the
+        grouped query. The function is applied to a deep copy of each grouped run,
+        so the original grouped runs are never modified.
+
+        Args:
+            fn: Function that takes in a grouped run and returns a new grouped run
+                object.
+
+        Returns:
+            A new grouped query with the modified grouped runs.
+        """
+        return GroupedQuery([fn(deepcopy(g)) for g in self._grouped])
+
+    def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery:
+        """Sort the runs inside each grouped run in the grouped query based on a metric.
+
+        Args:
+            key: Metric to sort the grouped runs by.
+            ascending: Whether to sort in ascending order.
+                Defaults to False (descending order).
+
+        Returns:
+            A new grouped query with the grouped runs sorted by the specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=key, reverse=not ascending),
+                )
+                for g in self._grouped
+            ]
+        )
+
+    def head(self, n: int) -> Query:
+        """Get the first n runs inside each grouped run.
+
+        Args:
+            n: Number of runs to return per group.
+
+        Returns:
+            A new query with the first n runs from each grouped run.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(key=g.key, value=g.value, runs=g.runs[:n])
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def tail(self, n: int) -> Query:
+        """Get the last n runs inside each grouped run.
+
+        Args:
+            n: Number of runs to return per group.
+
+        Returns:
+            A new query with the last n runs from each grouped run.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(key=g.key, value=g.value, runs=g.runs[-n:])
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def topk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the top k runs inside each grouped run based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of top runs to return per group.
+
+        Returns:
+            A new query with the top k runs from each grouped run based on the
+            specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=metric, reverse=metric.direction == "max")[
+                        :k
+                    ],
+                )
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def bottomk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the bottom k runs inside each grouped run based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of bottom runs to return per group.
+
+        Returns:
+            A new query with the bottom k runs from each grouped run based on the
+            specified metric.
+        """
+        return GroupedQuery(
+            [
+                GroupedRun(
+                    key=g.key,
+                    value=g.value,
+                    runs=sorted(g.runs, key=metric, reverse=metric.direction == "min")[
+                        :k
+                    ],
+                )
+                for g in self._grouped
+            ]
+        )._to_query()
+
+    def aggregate(
+        self,
+        method: Literal["first", "last", "best", "worst", "mean"],
+        over: AbstractMetric,
+    ) -> Query:
+        """Aggregate each group of runs using a specified method.
+
+        Supported methods include:
+        - "first": Selects the first run from each group.
+        - "last": Selects the last run from each group.
+        - "best": Selects the run with the best value based on the given metric.
+        - "worst": Selects the run with the worst value based on the given metric.
+        - "mean": Computes the mean run across all runs in each group, including
+            averaged metrics and temporal data, and collapsed metadata.
+
+        Args:
+            method: Aggregation strategy to apply per group.
+            over: The metric used for comparison when using "best" or "worst" methods.
+
+        Raises:
+            ValueError: If an unsupported aggregation method is provided.
+
+        Returns:
+            A new query with the aggregated runs from each group.
+        """
+        from .query import Query
+
+        match method:
+            case "first":
+                return self.head(1)
+            case "last":
+                return self.tail(1)
+            case "best":
+                return self.topk(over, 1)
+            case "worst":
+                return self.bottomk(over, 1)
+            case "mean":
+                return Query([self._mean_run(g) for g in self._grouped])
+            case _:
+                raise ValueError(
+                    f"Unsupported aggregation method: '{method}'. Must be "
+                    "'first', 'last', 'best', 'worst', or 'mean'."
+                )
+
+    @staticmethod
+    def _mean_run(group: GroupedRun) -> Run:
+        def _mean(values: List[float]) -> float:
+            return sum(values) / len(values) if values else float("nan")
+
+        def _mean_temporal(runs: List[Run]) -> Dict[str, List[tuple[int, float]]]:
+            all_keys = set().union(*(r.temporal.keys() for r in runs))
+            step_accumulator: Dict[str, Dict[int, List[float]]] = {}
+
+            for key in all_keys:
+                step_values = defaultdict(list)
+                for run in runs:
+                    for step, val in run.temporal.get(key, []):
+                        step_values[step].append(val)
+                step_accumulator[key] = step_values
+
+            return {
+                key: sorted(
+                    (step, sum(vals) / len(vals)) for step, vals in step_values.items()
+                )
+                for key, step_values in step_accumulator.items()
+            }
+
+        def _common_metadata(attr: str) -> Dict[str, str]:
+            key_sets = [getattr(r, attr).keys() for r in group.runs]
+            common_keys = set.intersection(*map(set, key_sets))
+            result = {}
+            for k in common_keys:
+                values = {str(getattr(r, attr)[k]) for r in group.runs}
+                result[k] = next(iter(values)) if len(values) == 1 else "#"
+            return result
+
+        all_metrics = [r.metrics for r in group.runs]
+        all_keys = set().union(*all_metrics)
+        mean_metrics = {
+            k: _mean([m[k] for m in all_metrics if k in m]) for k in all_keys
+        }
+
+        return Run(
+            id=f"grouped:{group.key}:{group.value}",
+            params=_common_metadata("params"),
+            metrics=mean_metrics,
+            temporal=_mean_temporal(group.runs),
+        )
+
+    def _to_query(self) -> Query:
+        from .query import Query
+
+        return Query([run for group in self._grouped for run in group.runs])
+
+    def all(self) -> List[Run]:
+        """Collect all runs in the grouped query by flattening the grouped runs.
+
+        Returns:
+            A list of all runs in the grouped query.
+        """
+        return deepcopy(self._to_query()._runs)
+
+    def copy(self) -> GroupedQuery:
+        """Obtain a shallow copy of the grouped query.
+
+        Returns:
+            A new grouped query with the same grouped runs as the original grouped
+            query.
+        """
+        return GroupedQuery(self._grouped[:])
+
+    def deepcopy(self) -> GroupedQuery:
+        """Obtain a deep copy of the grouped query.
+
+        Returns:
+            A new grouped query with deep copies of the grouped runs in the original
+            grouped query.
+        """
+        return GroupedQuery(deepcopy(self._grouped))
+
+    def __len__(self) -> int:
+        """Get the number of grouped runs in the grouped query.
+
+        Returns:
+            The number of grouped runs in the grouped query.
+        """
+        return len(self._grouped)
diff --git a/ablate/queries/query.py b/ablate/queries/query.py
new file mode 100644
index 0000000..87306c5
--- /dev/null
+++ b/ablate/queries/query.py
@@ -0,0 +1,218 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from copy import deepcopy
+import hashlib
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Set, Tuple, Union
+
+from ablate.core.types import GroupedRun, Run
+
+from .grouped_query import GroupedQuery
+
+
+if TYPE_CHECKING:  # pragma: no cover
+    from .selectors import AbstractMetric, AbstractParam
+
+
+class Query:
+    def __init__(self, runs: List[Run]) -> None:
+        """Query interface for manipulating runs in a functional way.
+
+        All methods operate on a shallow copy of the runs in the query, so the
+        original list is never modified; the run objects are assumed to be immutable.
+
+        Args:
+            runs: List of runs to be queried.
+        """
+        self._runs = runs
+
+    def filter(self, fn: Callable[[Run], bool]) -> Query:
+        """Filter the runs in the query based on a predicate function.
+
+        Args:
+            fn: Predicate function that takes in a run and returns a boolean value.
+
+        Returns:
+            A new query with the runs that satisfy the predicate function.
+        """
+        return Query([r for r in self._runs[:] if fn(r)])
+
+    def map(self, fn: Callable[[Run], Run]) -> Query:
+        """Apply a function to each run in the query.
+
+        This function is intended to be used for modifying the runs in the query. The
+        function is applied to a deep copy of each run, so the originals are unchanged.
+
+        Args:
+            fn: Function that takes in a run and returns a new run object.
+
+        Returns:
+            A new query with the modified runs.
+        """
+        return Query([fn(r) for r in deepcopy(self._runs)])
+
+    def sort(self, key: AbstractMetric, ascending: bool = False) -> Query:
+        """Sort the runs in the query based on a metric.
+
+        Args:
+            key: Metric to sort the runs by.
+            ascending: Whether to sort in ascending order.
+                Defaults to False (descending order).
+
+        Returns:
+            A new query with the runs sorted by the specified metric.
+        """
+        return Query(sorted(self._runs[:], key=key, reverse=not ascending))
+
+    def groupby(
+        self,
+        selectors: Union[AbstractParam, List[AbstractParam]],
+    ) -> GroupedQuery:
+        """Group the runs in the query by one or more selectors.
+
+        Args:
+            selectors: Selector or list of selectors to group the runs by.
+
+        Returns:
+            A grouped query containing the grouped runs.
+        """
+        if not isinstance(selectors, list):
+            selectors = [selectors]
+
+        def key_fn(run: Run) -> Tuple[Any, ...]:
+            return tuple(selector(run) for selector in selectors)
+
+        groups = defaultdict(list)
+        for run in self._runs:
+            groups[key_fn(run)].append(run)
+
+        grouped = [
+            GroupedRun(
+                key="+".join(selector.name for selector in selectors),
+                value="|".join(map(str, k)),
+                runs=v,
+            )
+            for k, v in groups.items()
+        ]
+        return GroupedQuery(grouped)
+
+    def groupdiff(
+        self,
+        selectors: Union[AbstractParam, List[AbstractParam]],
+    ) -> GroupedQuery:
+        """Group the runs by all params except the given selectors.
+
+        Similar to `groupby`, but the selectors are excluded from the grouping key.
+
+        Args:
+            selectors: Selector or list of selectors to exclude from the grouping key.
+
+        Returns:
+            A grouped query containing the grouped runs.
+        """
+        if not isinstance(selectors, list):
+            selectors = [selectors]
+
+        def exclude_keys(
+            d: Dict[str, Any],
+            keys: Set[str],
+        ) -> Tuple[Tuple[str, Any], ...]:
+            return tuple(sorted((k, v) for k, v in d.items() if k not in keys))
+
+        exclude_names = {s.name for s in selectors}
+        groups = defaultdict(list)
+        for run in self._runs:
+            key = exclude_keys(run.params, exclude_names)
+            groups[key].append(run)
+
+        def _hash_key(kv: Tuple[Tuple[str, Any], ...]) -> str:
+            raw = ",".join(f"{k}={v}" for k, v in sorted(kv))
+            return hashlib.md5(raw.encode()).hexdigest()[:8]
+
+        grouped = [
+            GroupedRun(
+                key="-".join(s.name for s in selectors),
+                value=_hash_key(k),
+                runs=runs,
+            )
+            for k, runs in groups.items()
+        ]
+        return GroupedQuery(grouped)
+
+    def head(self, n: int) -> Query:
+        """Get the first n runs in the query.
+
+        Args:
+            n: Number of runs to return.
+
+        Returns:
+            A new query with the first n runs.
+        """
+        return Query(self._runs[:n])
+
+    def tail(self, n: int) -> Query:
+        """Get the last n runs in the query.
+
+        Args:
+            n: Number of runs to return.
+
+        Returns:
+            A new query with the last n runs.
+        """
+        return Query(self._runs[-n:])
+
+    def topk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the top k runs in the query based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of top runs to return.
+
+        Returns:
+            A new query with the top k runs based on the specified metric.
+        """
+        return self.sort(metric, ascending=metric.direction == "min").head(k)
+
+    def bottomk(self, metric: AbstractMetric, k: int) -> Query:
+        """Get the bottom k runs in the query based on a metric.
+
+        Args:
+            metric: Metric to sort the runs by.
+            k: Number of bottom runs to return.
+
+        Returns:
+            A new query with the bottom k runs based on the specified metric.
+        """
+        return self.sort(metric, ascending=metric.direction == "max").head(k)
+
+    def all(self) -> List[Run]:
+        """Collect all runs in the query.
+
+        Returns:
+            A list of all runs in the query.
+        """
+        return deepcopy(self._runs)
+
+    def copy(self) -> Query:
+        """Obtain a shallow copy of the query.
+
+        Returns:
+            A new query with the same runs as the original query.
+        """
+        return Query(self._runs[:])
+
+    def deepcopy(self) -> Query:
+        """Obtain a deep copy of the query.
+
+        Returns:
+            A new query with deep copies of the runs in the original query.
+        """
+        return Query(deepcopy(self._runs))
+
+    def __len__(self) -> int:
+        """Get the number of runs in the query.
+
+        Returns:
+            The number of runs in the query.
+        """
+        return len(self._runs)
diff --git a/ablate/queries/selectors.py b/ablate/queries/selectors.py
new file mode 100644
index 0000000..f44708b
--- /dev/null
+++ b/ablate/queries/selectors.py
@@ -0,0 +1,130 @@
+from abc import ABC, abstractmethod
+from operator import eq, ge, gt, le, lt, ne
+from typing import Any, Callable, Literal
+
+from ablate.core.types import Run
+
+
+class AbstractSelector(ABC):
+    def __init__(self, name: str) -> None:
+        """Abstract class for selecting runs based on a specific attribute.
+
+        Args:
+            name: Name of the attribute to select on.
+        """
+        self.name = name
+
+    @abstractmethod
+    def __call__(self, run: Run) -> Any: ...
+
+    def _cmp(self, op: Callable[[Any, Any], bool], other: Any) -> Callable[[Run], bool]:
+        return lambda run: op(self.__call__(run), other)
+
+    def __eq__(self, other: object) -> Callable[[Run], bool]:  # type: ignore[override]
+        return self._cmp(eq, other)
+
+    def __ne__(self, other: object) -> Callable[[Run], bool]:  # type: ignore[override]
+        return self._cmp(ne, other)
+
+    def __lt__(self, other: Any) -> Callable[[Run], bool]:
+        return self._cmp(lt, other)
+
+    def __le__(self, other: Any) -> Callable[[Run], bool]:
+        return self._cmp(le, other)
+
+    def __gt__(self, other: Any) -> Callable[[Run], bool]:
+        return self._cmp(gt, other)
+
+    def __ge__(self, other: Any) -> Callable[[Run], bool]:
+        return self._cmp(ge, other)
+
+
+class AbstractParam(AbstractSelector, ABC): ...
+
+
+class Id(AbstractParam):
+    def __init__(self) -> None:
+        """Selector for the ID of the run."""
+        super().__init__("id")
+
+    def __call__(self, run: Run) -> str:
+        return run.id
+
+
+class Param(AbstractParam):
+    """Selector for a specific parameter of the run."""
+
+    def __call__(self, run: Run) -> int | float | str | None:
+        return run.params.get(self.name)
+
+
+class AbstractMetric(AbstractSelector, ABC):
+    def __init__(
+        self,
+        name: str,
+        direction: Literal["min", "max"],
+    ) -> None:
+        super().__init__(name)
+        if direction not in ("min", "max"):
+            raise ValueError(
+                f"Invalid direction: '{direction}'. Must be 'min' or 'max'."
+            )
+        self.direction = direction
+
+
+class Metric(AbstractMetric):
+    """Selector for a specific metric of the run.
+
+    Args:
+        name: Name of the metric to select on.
+        direction: Direction of the metric. "min" for minimization, "max" for
+            maximization.
+    """
+
+    def __call__(self, run: Run) -> float:
+        val = run.metrics.get(self.name)
+        if val is None:
+            return float("-inf") if self.direction == "max" else float("inf")
+        return val
+
+
+class TemporalMetric(AbstractMetric):
+    def __init__(
+        self,
+        name: str,
+        direction: Literal["min", "max"],
+        reduction: Literal["min", "max", "first", "last"] | None = None,
+    ) -> None:
+        """Selector for a specific temporal metric of the run.
+
+        Args:
+            name: Name of the temporal metric to select on.
+            direction: Direction of the metric. "min" for minimization, "max" for
+                maximization.
+            reduction: Reduction method to apply to the temporal metric. "min" for
+                minimum, "max" for maximum, "first" for the first value, and "last"
+                for the last value. If None, the direction is used as the reduction.
+                Defaults to None.
+        """
+        super().__init__(name, direction)
+        if reduction is not None and reduction not in ("min", "max", "first", "last"):
+            raise ValueError(
+                f"Invalid reduction method: '{reduction}'. Must be 'min', 'max', "
+                "'first', or 'last'."
+            )
+        self.reduction = reduction or direction
+
+    def __call__(self, run: Run) -> float:
+        values = run.temporal.get(self.name, [])
+        if not values:
+            return float("nan")
+
+        match self.reduction:
+            case "min":
+                return min(v for _, v in values)
+            case "max":
+                return max(v for _, v in values)
+            case "first":
+                return values[0][1]
+            case "last":
+                return values[-1][1]
diff --git a/tests/queries/test_grouped_query.py b/tests/queries/test_grouped_query.py
new file mode 100644
index 0000000..80bde50
--- /dev/null
+++ b/tests/queries/test_grouped_query.py
@@ -0,0 +1,107 @@
+import pytest
+
+from ablate.core.types import GroupedRun, Run
+from ablate.queries.grouped_query import GroupedQuery
+from ablate.queries.query import Query
+from ablate.queries.selectors import Metric, Param
+
+
+def make_runs() -> list[Run]:
+    return [
+        Run(id="a", params={"model": "resnet", "seed": 1}, metrics={"accuracy": 0.7}),
+        Run(id="b", params={"model": "resnet", "seed": 2}, metrics={"accuracy": 0.8}),
+        Run(id="c", params={"model": "vit", "seed": 1}, metrics={"accuracy": 0.6}),
+        Run(id="d", params={"model": "vit", "seed": 2}, metrics={"accuracy": 0.9}),
+    ]
+
+
+def make_grouped() -> GroupedQuery:
+    return Query(make_runs()).groupby(Param("model"))
+
+
+def test_filter_keeps_matching_groups() -> None:
+    grouped = make_grouped()
+    filtered = grouped.filter(lambda g: g.value == "resnet")
+    assert len(filtered) == 1
+    assert all(run.params["model"] == "resnet" for run in filtered.all())
+
+
+def test_map_modifies_each_group() -> None:
+    grouped = make_grouped()
+
+    def fn(group: GroupedRun) -> GroupedRun:
+        group.runs = [Run(**{**r.model_dump(), "id": r.id + "_x"}) for r in group.runs]
+        return group
+
+    mapped = grouped.map(fn)
+    assert all(r.id.endswith("_x") for r in mapped.all())
+
+
+def test_sort_sorts_within_each_group() -> None:
+    grouped = make_grouped().sort(Metric("accuracy", direction="max"), ascending=True)
+    for group in grouped._grouped:
+        accs = [r.metrics["accuracy"] for r in group.runs]
+        assert accs == sorted(accs)
+
+
+def test_head_tail_topk_bottomk_all_return_expected_shape() -> None:
+    grouped = make_grouped()
+    assert len(grouped.head(1).all()) == 2
+    assert len(grouped.tail(1).all()) == 2
+    assert len(grouped.topk(Metric("accuracy", direction="max"), 1).all()) == 2
+    assert len(grouped.bottomk(Metric("accuracy", direction="max"), 1).all()) == 2
+
+
+def test_aggregate_all_strategies() -> None:
+    grouped = make_grouped()
+    m = Metric("accuracy", direction="max")
direction="max") + assert len(grouped.aggregate("first", over=m).all()) == 2 + assert len(grouped.aggregate("last", over=m).all()) == 2 + assert len(grouped.aggregate("best", over=m).all()) == 2 + assert len(grouped.aggregate("worst", over=m).all()) == 2 + assert len(grouped.aggregate("mean", over=m).all()) == 2 + + with pytest.raises(ValueError, match="Unsupported aggregation method"): + grouped.aggregate("unsupported", over=m) + + +def test_aggregate_mean_collapses_metadata_and_temporal() -> None: + run1 = Run( + id="a", + params={"model": "resnet", "seed": 1}, + metrics={"acc": 0.8}, + temporal={"acc": [(1, 0.2), (2, 0.6)]}, + ) + run2 = Run( + id="b", + params={"model": "resnet", "seed": 2}, + metrics={"acc": 0.4}, + temporal={"acc": [(1, 0.6), (2, 1.0)]}, + ) + grouped = GroupedQuery([GroupedRun(key="model", value="resnet", runs=[run1, run2])]) + agg = grouped.aggregate("mean", over=Metric("acc", direction="max")).all()[0] + + assert agg.params["model"] == "resnet" + assert agg.params["seed"] == "#" + assert agg.metrics["acc"] == pytest.approx(0.6) + assert agg.temporal["acc"] == [(1, 0.4), (2, 0.8)] + + +def test_to_query_and_all_return_same_runs() -> None: + grouped = make_grouped() + assert grouped._to_query().all() == grouped.all() + + +def test_copy_and_deepcopy() -> None: + grouped = make_grouped() + shallow = grouped.copy() + deep = grouped.deepcopy() + + assert shallow._grouped == grouped._grouped + assert shallow._grouped is not grouped._grouped + + assert deep._grouped == grouped._grouped + assert deep._grouped is not grouped._grouped + assert all( + dr is not gr for dr, gr in zip(deep._grouped, grouped._grouped, strict=False) + ) diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py new file mode 100644 index 0000000..7f19fb7 --- /dev/null +++ b/tests/queries/test_query.py @@ -0,0 +1,90 @@ +from ablate.core.types import Run +from ablate.queries.query import Query +from ablate.queries.selectors import Metric, Param + + +def make_runs() -> list[Run]: + return [ + Run(id="a", params={"model": "resnet", "seed": 1}, metrics={"accuracy": 0.7}), + Run(id="b", params={"model": "resnet", "seed": 2}, metrics={"accuracy": 0.8}), + Run(id="c", params={"model": "vit", "seed": 1}, metrics={"accuracy": 0.9}), + ] + + +def test_filter() -> None: + runs = Query(make_runs()).filter(Param("model") == "resnet").all() + assert len(runs) == 2 + assert all(r.params["model"] == "resnet" for r in runs) + + +def test_sort() -> None: + runs = Query(make_runs()).sort(Metric("accuracy", direction="min")).all() + assert [r.id for r in runs] == ["c", "b", "a"] + + +def test_head_tail() -> None: + q = Query(make_runs()) + assert q.head(1).all()[0].id == "a" + assert q.tail(1).all()[0].id == "c" + + +def test_topk_bottomk() -> None: + q = Query(make_runs()) + top = q.topk(Metric("accuracy", direction="max"), k=2).all() + bot = q.bottomk(Metric("accuracy", direction="max"), k=1).all() + assert [r.id for r in top] == ["c", "b"] + assert [r.id for r in bot] == ["a"] + + +def test_map() -> None: + q = Query(make_runs()) + + def upper_case_id(run: Run) -> Run: + run.id = run.id.upper() + return run + + updated = q.map(upper_case_id).all() + assert updated[0].id == "A" + assert make_runs()[0].id == "a" + + +def test_groupby_single_key() -> None: + gq = Query(make_runs()).groupby(Param("model")) + assert len(gq) == 2 + assert {g.value for g in gq._grouped} == {"resnet", "vit"} + + +def test_groupby_multiple_keys() -> None: + gq = Query(make_runs()).groupby([Param("model"), Param("seed")]) + keys 
+    assert len(keys) == 3
+
+
+def test_groupdiff() -> None:
+    grouped = Query(make_runs()).groupdiff(Param("seed"))._grouped
+    expected_group_sizes = {"resnet": 2, "vit": 1}
+    for group in grouped:
+        model = group.runs[0].params["model"]
+        assert len(group.runs) == expected_group_sizes[model]
+    assert all(len(g.value) == 8 for g in grouped)
+
+
+def test_query_copy_shallow() -> None:
+    original = Query(make_runs())
+    copied = original.copy()
+
+    assert copied._runs == original._runs
+    assert copied._runs is not original._runs
+    assert copied._runs[0] is original._runs[0]
+
+
+def test_query_deepcopy() -> None:
+    original = Query(make_runs())
+    deepcopied = original.deepcopy()
+    assert deepcopied._runs == original._runs
+    assert deepcopied._runs is not original._runs
+    assert deepcopied._runs[0] is not original._runs[0]
+
+
+def test_query_len() -> None:
+    assert len(Query(make_runs())) == 3
diff --git a/tests/queries/test_selectors.py b/tests/queries/test_selectors.py
new file mode 100644
index 0000000..142aecc
--- /dev/null
+++ b/tests/queries/test_selectors.py
@@ -0,0 +1,86 @@
+import pytest
+
+from ablate.core.types import Run
+from ablate.queries.selectors import Id, Metric, Param, TemporalMetric
+
+
+@pytest.fixture
+def example_run() -> Run:
+    return Run(
+        id="run-42",
+        params={"model": "resnet", "lr": 0.001},
+        metrics={"accuracy": 0.91, "loss": 0.1},
+        temporal={"accuracy": [(0, 0.5), (1, 0.8), (2, 0.9)]},
+    )
+
+
+def test_id_selector(example_run: Run) -> None:
+    selector = Id()
+    assert selector(example_run) == "run-42"
+    assert selector(example_run) != "run-x"
+
+
+def test_param_selector(example_run: Run) -> None:
+    selector = Param("lr")
+    assert selector(example_run) == 0.001
+    assert (selector > 0.0001)(example_run)
+    assert not (selector < 0.0001)(example_run)
+
+    missing_selector = Param("missing")
+    assert missing_selector(example_run) is None
+
+
+def test_param_comparisons(example_run: Run) -> None:
+    selector = Param("lr")
+    assert (selector == 0.001)(example_run)
+    assert not (selector == 0.01)(example_run)
+
+    assert (selector != 0.01)(example_run)
+    assert not (selector != 0.001)(example_run)
+
+    assert (selector <= 0.001)(example_run)
+    assert (selector >= 0.001)(example_run)
+    assert not (selector <= 0.0001)(example_run)
+    assert not (selector >= 0.01)(example_run)
+
+
+def test_metric_selector(example_run: Run) -> None:
+    selector = Metric("accuracy", direction="max")
+    assert selector(example_run) == 0.91
+    assert (selector > 0.5)(example_run)
+    assert not (selector < 0.5)(example_run)
+
+    missing = Metric("missing", direction="max")
+    assert missing(example_run) == float("-inf")
+
+
+def test_invalid_metric_direction() -> None:
+    with pytest.raises(ValueError, match="Invalid direction"):
+        Metric("accuracy", direction="invalid")
+
+
+@pytest.mark.parametrize(
+    ("reduction", "expected"),
+    [
+        ("min", 0.5),
+        ("max", 0.9),
+        ("first", 0.5),
+        ("last", 0.9),
+    ],
+)
+def test_temporal_metric_selector(
+    example_run: Run, reduction: str, expected: float
+) -> None:
+    selector = TemporalMetric("accuracy", direction="max", reduction=reduction)
+    assert selector(example_run) == expected
+
+
+def test_temporal_metric_missing_returns_nan(example_run: Run) -> None:
+    selector = TemporalMetric("not_logged", direction="max", reduction="max")
+    # NaN is the only float that compares unequal to itself.
+    assert selector(example_run) != selector(example_run)
+
+
+def test_temporal_metric_invalid_reduction() -> None:
+    with pytest.raises(ValueError, match="Invalid reduction method"):
TemporalMetric("accuracy", direction="max", reduction="median")