diff --git a/ablate/queries/grouped_query.py b/ablate/queries/grouped_query.py index 16317ca..a49d954 100644 --- a/ablate/queries/grouped_query.py +++ b/ablate/queries/grouped_query.py @@ -158,7 +158,7 @@ def bottomk(self, metric: AbstractMetric, k: int) -> Query: def aggregate( self, method: Literal["first", "last", "best", "worst", "mean"], - over: AbstractMetric, + over: AbstractMetric | None = None, ) -> Query: """Aggregate each group of runs using a specified method. @@ -172,13 +172,21 @@ def aggregate( Args: method: Aggregation strategy to apply per group. over: The metric used for comparison when using "best" or "worst" methods. + Has no effect for "first", "last", or "mean" methods. + Defaults to None. Raises: - ValueError: If an unsupported aggregation method is provided. + ValueError: If an unsupported aggregation method is provided or if the + "best" or "worst" method is used without a specified metric. + Returns: A new query with the aggregated runs from each group. """ + if method in {"best", "worst"} and over is None: + raise ValueError( + f"Method '{method}' requires a metric to be specified for comparison." + ) from .query import Query match method: diff --git a/ablate/report.py b/ablate/report.py index 2fa8a0a..5dbfde4 100644 --- a/ablate/report.py +++ b/ablate/report.py @@ -5,7 +5,7 @@ from typing_extensions import Self -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from ablate.blocks import AbstractBlock from ablate.core.types import Run diff --git a/tests/queries/test_grouped_query.py b/tests/queries/test_grouped_query.py index 80bde50..5093415 100644 --- a/tests/queries/test_grouped_query.py +++ b/tests/queries/test_grouped_query.py @@ -1,3 +1,5 @@ +from typing import List + import pytest from ablate.core.types import GroupedRun, Run @@ -6,7 +8,8 @@ from ablate.queries.selectors import Metric, Param -def make_runs() -> list[Run]: +@pytest.fixture +def runs() -> List[Run]: return [ Run(id="a", params={"model": "resnet", "seed": 1}, metrics={"accuracy": 0.7}), Run(id="b", params={"model": "resnet", "seed": 2}, metrics={"accuracy": 0.8}), @@ -15,20 +18,18 @@ def make_runs() -> list[Run]: ] -def make_grouped() -> GroupedQuery: - return Query(make_runs()).groupby(Param("model")) +@pytest.fixture +def grouped(runs: List[Run]) -> GroupedQuery: + return Query(runs).groupby(Param("model")) -def test_filter_keeps_matching_groups() -> None: - grouped = make_grouped() +def test_filter_keeps_matching_groups(grouped: GroupedQuery) -> None: filtered = grouped.filter(lambda g: g.value == "resnet") assert len(filtered) == 1 assert all(run.params["model"] == "resnet" for run in filtered.all()) -def test_map_modifies_each_group() -> None: - grouped = make_grouped() - +def test_map_modifies_each_group(grouped: GroupedQuery) -> None: def fn(group: GroupedRun) -> GroupedRun: group.runs = [Run(**{**r.model_dump(), "id": r.id + "_x"}) for r in group.runs] return group @@ -37,23 +38,23 @@ def fn(group: GroupedRun) -> GroupedRun: assert all(r.id.endswith("_x") for r in mapped.all()) -def test_sort_sorts_within_each_group() -> None: - grouped = make_grouped().sort(Metric("accuracy", direction="max"), ascending=True) +def test_sort_sorts_within_each_group(grouped: GroupedQuery) -> None: + grouped = grouped.sort(Metric("accuracy", direction="max"), ascending=True) for group in grouped._grouped: accs = [r.metrics["accuracy"] for r in group.runs] assert accs == sorted(accs) -def test_head_tail_topk_bottomk_all_return_expected_shape() -> None: - grouped = make_grouped() +def test_head_tail_topk_bottomk_all_return_expected_shape( + grouped: GroupedQuery, +) -> None: assert len(grouped.head(1).all()) == 2 assert len(grouped.tail(1).all()) == 2 assert len(grouped.topk(Metric("accuracy", direction="max"), 1).all()) == 2 assert len(grouped.bottomk(Metric("accuracy", direction="max"), 1).all()) == 2 -def test_aggregate_all_strategies() -> None: - grouped = make_grouped() +def test_aggregate_all_strategies(grouped: GroupedQuery) -> None: m = Metric("accuracy", direction="max") assert len(grouped.aggregate("first", over=m).all()) == 2 assert len(grouped.aggregate("last", over=m).all()) == 2 @@ -65,6 +66,13 @@ def test_aggregate_all_strategies() -> None: grouped.aggregate("unsupported", over=m) +def test_aggregate_best_worst_missing_over(grouped: GroupedQuery) -> None: + with pytest.raises(ValueError, match="Method 'best' requires a metric"): + grouped.aggregate("best") + with pytest.raises(ValueError, match="Method 'worst' requires a metric"): + grouped.aggregate("worst") + + def test_aggregate_mean_collapses_metadata_and_temporal() -> None: run1 = Run( id="a", @@ -87,13 +95,11 @@ def test_aggregate_mean_collapses_metadata_and_temporal() -> None: assert agg.temporal["acc"] == [(1, 0.4), (2, 0.8)] -def test_to_query_and_all_return_same_runs() -> None: - grouped = make_grouped() +def test_to_query_and_all_return_same_runs(grouped: GroupedQuery) -> None: assert grouped._to_query().all() == grouped.all() -def test_copy_and_deepcopy() -> None: - grouped = make_grouped() +def test_copy_and_deepcopy(grouped: GroupedQuery) -> None: shallow = grouped.copy() deep = grouped.deepcopy() diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 7f19fb7..9ba3597 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -1,9 +1,14 @@ +from typing import List + +import pytest + from ablate.core.types import Run from ablate.queries.query import Query from ablate.queries.selectors import Metric, Param -def make_runs() -> list[Run]: +@pytest.fixture +def runs() -> List[Run]: return [ Run(id="a", params={"model": "resnet", "seed": 1}, metrics={"accuracy": 0.7}), Run(id="b", params={"model": "resnet", "seed": 2}, metrics={"accuracy": 0.8}), @@ -11,33 +16,33 @@ def make_runs() -> list[Run]: ] -def test_filter() -> None: - runs = Query(make_runs()).filter(Param("model") == "resnet").all() +def test_filter(runs: List[Run]) -> None: + runs = Query(runs).filter(Param("model") == "resnet").all() assert len(runs) == 2 assert all(r.params["model"] == "resnet" for r in runs) -def test_sort() -> None: - runs = Query(make_runs()).sort(Metric("accuracy", direction="min")).all() +def test_sort(runs: List[Run]) -> None: + runs = Query(runs).sort(Metric("accuracy", direction="min")).all() assert [r.id for r in runs] == ["c", "b", "a"] -def test_head_tail() -> None: - q = Query(make_runs()) +def test_head_tail(runs: List[Run]) -> None: + q = Query(runs) assert q.head(1).all()[0].id == "a" assert q.tail(1).all()[0].id == "c" -def test_topk_bottomk() -> None: - q = Query(make_runs()) +def test_topk_bottomk(runs: List[Run]) -> None: + q = Query(runs) top = q.topk(Metric("accuracy", direction="max"), k=2).all() bot = q.bottomk(Metric("accuracy", direction="max"), k=1).all() assert [r.id for r in top] == ["c", "b"] assert [r.id for r in bot] == ["a"] -def test_map() -> None: - q = Query(make_runs()) +def test_map(runs: List[Run]) -> None: + q = Query(runs) def upper_case_id(run: Run) -> Run: run.id = run.id.upper() @@ -45,23 +50,23 @@ def upper_case_id(run: Run) -> Run: updated = q.map(upper_case_id).all() assert updated[0].id == "A" - assert make_runs()[0].id == "a" + assert runs[0].id == "a" -def test_groupby_single_key() -> None: - gq = Query(make_runs()).groupby(Param("model")) +def test_groupby_single_key(runs: List[Run]) -> None: + gq = Query(runs).groupby(Param("model")) assert len(gq) == 2 assert {g.value for g in gq._grouped} == {"resnet", "vit"} -def test_groupby_multiple_keys() -> None: - gq = Query(make_runs()).groupby([Param("model"), Param("seed")]) +def test_groupby_multiple_keys(runs: List[Run]) -> None: + gq = Query(runs).groupby([Param("model"), Param("seed")]) keys = {(g.key, g.value) for g in gq._grouped} assert len(keys) == 3 -def test_groupdiff() -> None: - grouped = Query(make_runs()).groupdiff(Param("seed"))._grouped +def test_groupdiff(runs: List[Run]) -> None: + grouped = Query(runs).groupdiff(Param("seed"))._grouped expected_group_sizes = {"resnet": 2, "vit": 1} for group in grouped: model = group.runs[0].params["model"] @@ -69,8 +74,8 @@ def test_groupdiff() -> None: assert all(len(g.value) == 8 for g in grouped) -def test_query_copy_shallow() -> None: - original = Query(make_runs()) +def test_query_copy_shallow(runs: List[Run]) -> None: + original = Query(runs) copied = original.copy() assert copied._runs == original._runs @@ -78,13 +83,13 @@ def test_query_copy_shallow() -> None: assert copied._runs[0] is original._runs[0] -def test_query_deepcopy() -> None: - original = Query(make_runs()) +def test_query_deepcopy(runs: List[Run]) -> None: + original = Query(runs) deepcopied = original.deepcopy() assert deepcopied._runs == original._runs assert deepcopied._runs is not original._runs assert deepcopied._runs[0] is not original._runs[0] -def test_query_len() -> None: - assert len(Query(make_runs())) == 3 +def test_query_len(runs: List[Run]) -> None: + assert len(Query(runs)) == 3