Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions ablate/queries/grouped_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def bottomk(self, metric: AbstractMetric, k: int) -> Query:
def aggregate(
self,
method: Literal["first", "last", "best", "worst", "mean"],
over: AbstractMetric,
over: AbstractMetric | None = None,
) -> Query:
"""Aggregate each group of runs using a specified method.

Expand All @@ -172,13 +172,21 @@ def aggregate(
Args:
method: Aggregation strategy to apply per group.
over: The metric used for comparison when using "best" or "worst" methods.
Has no effect for "first", "last", or "mean" methods.
Defaults to None.

Raises:
ValueError: If an unsupported aggregation method is provided.
ValueError: If an unsupported aggregation method is provided or if the
"best" or "worst" method is used without a specified metric.


Returns:
A new query with the aggregated runs from each group.
"""
if method in {"best", "worst"} and over is None:
raise ValueError(
f"Method '{method}' requires a metric to be specified for comparison."
)
from .query import Query

match method:
Expand Down
2 changes: 1 addition & 1 deletion ablate/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import Self


if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
from ablate.blocks import AbstractBlock
from ablate.core.types import Run

Expand Down
42 changes: 24 additions & 18 deletions tests/queries/test_grouped_query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import pytest

from ablate.core.types import GroupedRun, Run
Expand All @@ -6,7 +8,8 @@
from ablate.queries.selectors import Metric, Param


def make_runs() -> list[Run]:
@pytest.fixture
def runs() -> List[Run]:
return [
Run(id="a", params={"model": "resnet", "seed": 1}, metrics={"accuracy": 0.7}),
Run(id="b", params={"model": "resnet", "seed": 2}, metrics={"accuracy": 0.8}),
Expand All @@ -15,20 +18,18 @@ def make_runs() -> list[Run]:
]


@pytest.fixture
def grouped(runs: List[Run]) -> GroupedQuery:
    """Provide the ``runs`` fixture wrapped in a query grouped by ``model``."""
    query = Query(runs)
    return query.groupby(Param("model"))


def test_filter_keeps_matching_groups(grouped: GroupedQuery) -> None:
    """Filtering on the group value "resnet" leaves one group of resnet runs."""
    resnet_only = grouped.filter(lambda group: group.value == "resnet")
    assert len(resnet_only) == 1
    for run in resnet_only.all():
        assert run.params["model"] == "resnet"


def test_map_modifies_each_group() -> None:
grouped = make_grouped()

def test_map_modifies_each_group(grouped: GroupedQuery) -> None:
def fn(group: GroupedRun) -> GroupedRun:
group.runs = [Run(**{**r.model_dump(), "id": r.id + "_x"}) for r in group.runs]
return group
Expand All @@ -37,23 +38,23 @@ def fn(group: GroupedRun) -> GroupedRun:
assert all(r.id.endswith("_x") for r in mapped.all())


def test_sort_sorts_within_each_group(grouped: GroupedQuery) -> None:
    """After an ascending sort, each group's accuracies are in sorted order."""
    accuracy = Metric("accuracy", direction="max")
    # Bind the result to a new name instead of rebinding the fixture argument.
    ordered = grouped.sort(accuracy, ascending=True)
    for group in ordered._grouped:
        accuracies = [run.metrics["accuracy"] for run in group.runs]
        assert accuracies == sorted(accuracies)


def test_head_tail_topk_bottomk_all_return_expected_shape(
    grouped: GroupedQuery,
) -> None:
    """Selecting one run per group returns two runs for every selector."""
    first = grouped.head(1).all()
    assert len(first) == 2

    last = grouped.tail(1).all()
    assert len(last) == 2

    top = grouped.topk(Metric("accuracy", direction="max"), 1).all()
    assert len(top) == 2

    bottom = grouped.bottomk(Metric("accuracy", direction="max"), 1).all()
    assert len(bottom) == 2


def test_aggregate_all_strategies() -> None:
grouped = make_grouped()
def test_aggregate_all_strategies(grouped: GroupedQuery) -> None:
m = Metric("accuracy", direction="max")
assert len(grouped.aggregate("first", over=m).all()) == 2
assert len(grouped.aggregate("last", over=m).all()) == 2
Expand All @@ -65,6 +66,13 @@ def test_aggregate_all_strategies() -> None:
grouped.aggregate("unsupported", over=m)


def test_aggregate_best_worst_missing_over(grouped: GroupedQuery) -> None:
    """"best" and "worst" raise ``ValueError`` when ``over`` is omitted."""
    # Each entry pairs a method with the error-message pattern expected for it.
    expectations = (
        ("best", "Method 'best' requires a metric"),
        ("worst", "Method 'worst' requires a metric"),
    )
    for method, message in expectations:
        with pytest.raises(ValueError, match=message):
            grouped.aggregate(method)


def test_aggregate_mean_collapses_metadata_and_temporal() -> None:
run1 = Run(
id="a",
Expand All @@ -87,13 +95,11 @@ def test_aggregate_mean_collapses_metadata_and_temporal() -> None:
assert agg.temporal["acc"] == [(1, 0.4), (2, 0.8)]


def test_to_query_and_all_return_same_runs(grouped: GroupedQuery) -> None:
    """``_to_query().all()`` compares equal to ``all()`` on the grouped query."""
    via_query = grouped._to_query().all()
    direct = grouped.all()
    assert via_query == direct


def test_copy_and_deepcopy() -> None:
grouped = make_grouped()
def test_copy_and_deepcopy(grouped: GroupedQuery) -> None:
shallow = grouped.copy()
deep = grouped.deepcopy()

Expand Down
53 changes: 29 additions & 24 deletions tests/queries/test_query.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,95 @@
from typing import List

import pytest

from ablate.core.types import Run
from ablate.queries.query import Query
from ablate.queries.selectors import Metric, Param


@pytest.fixture
def runs() -> List[Run]:
    """Provide three sample runs: "a" and "b" (resnet) and "c" (vit)."""
    # Columns: run id, model, seed, accuracy.
    specs = [
        ("a", "resnet", 1, 0.7),
        ("b", "resnet", 2, 0.8),
        ("c", "vit", 1, 0.9),
    ]
    return [
        Run(
            id=run_id,
            params={"model": model, "seed": seed},
            metrics={"accuracy": accuracy},
        )
        for run_id, model, seed, accuracy in specs
    ]


def test_filter(runs: List[Run]) -> None:
    """Filtering on ``model == "resnet"`` keeps exactly the two resnet runs."""
    query = Query(runs)
    # Use a distinct local name so the fixture argument is not shadowed.
    matched = query.filter(Param("model") == "resnet").all()
    assert len(matched) == 2
    for run in matched:
        assert run.params["model"] == "resnet"


def test_sort(runs: List[Run]) -> None:
    """Sorting by accuracy with direction "min" orders the ids as c, b, a."""
    query = Query(runs)
    ordered = query.sort(Metric("accuracy", direction="min")).all()
    ordered_ids = [run.id for run in ordered]
    assert ordered_ids == ["c", "b", "a"]


def test_head_tail(runs: List[Run]) -> None:
    """``head(1)`` starts with run "a" and ``tail(1)`` starts with run "c"."""
    query = Query(runs)

    first = query.head(1).all()[0]
    assert first.id == "a"

    last = query.tail(1).all()[0]
    assert last.id == "c"


def test_topk_bottomk(runs: List[Run]) -> None:
    """Top-2 by accuracy gives ids c, b; bottom-1 gives id a."""
    query = Query(runs)
    best_two = query.topk(Metric("accuracy", direction="max"), k=2).all()
    worst_one = query.bottomk(Metric("accuracy", direction="max"), k=1).all()

    top_ids = [run.id for run in best_two]
    assert top_ids == ["c", "b"]

    bottom_ids = [run.id for run in worst_one]
    assert bottom_ids == ["a"]


def test_map(runs: List[Run]) -> None:
    """The first mapped run has id "A" while the first input run keeps "a"."""
    query = Query(runs)

    def upper_case_id(run: Run) -> Run:
        # Mutates the run it is handed; the final assertion checks that the
        # fixture's own run was not the object mutated.
        run.id = run.id.upper()
        return run

    mapped = query.map(upper_case_id).all()
    assert mapped[0].id == "A"
    assert runs[0].id == "a"


def test_groupby_single_key(runs: List[Run]) -> None:
    """Grouping by ``model`` yields two groups valued "resnet" and "vit"."""
    by_model = Query(runs).groupby(Param("model"))
    assert len(by_model) == 2

    group_values = {group.value for group in by_model._grouped}
    assert group_values == {"resnet", "vit"}


def test_groupby_multiple_keys(runs: List[Run]) -> None:
    """Grouping by ``model`` and ``seed`` yields three distinct key/value pairs."""
    query = Query(runs)
    by_model_and_seed = query.groupby([Param("model"), Param("seed")])
    distinct_pairs = {
        (group.key, group.value) for group in by_model_and_seed._grouped
    }
    assert len(distinct_pairs) == 3


def test_groupdiff(runs: List[Run]) -> None:
    """``groupdiff(Param("seed"))`` gives a 2-run resnet group and a 1-run vit group.

    Every group's ``value`` is also expected to have length 8.
    """
    groups = Query(runs).groupdiff(Param("seed"))._grouped
    sizes_by_model = {"resnet": 2, "vit": 1}

    for group in groups:
        # Identify the group by the model of its first run.
        model = group.runs[0].params["model"]
        assert len(group.runs) == sizes_by_model[model]

    for group in groups:
        assert len(group.value) == 8


def test_query_copy_shallow(runs: List[Run]) -> None:
    """``copy()`` gives an equal but distinct ``_runs`` list sharing its first run."""
    source = Query(runs)
    clone = source.copy()

    assert clone._runs == source._runs
    # The container is new, but its first element is the same object.
    assert clone._runs is not source._runs
    assert clone._runs[0] is source._runs[0]


def test_query_deepcopy(runs: List[Run]) -> None:
    """``deepcopy()`` gives an equal ``_runs`` list with a distinct first run."""
    source = Query(runs)
    clone = source.deepcopy()

    assert clone._runs == source._runs
    # Neither the container nor its first element is shared with the source.
    assert clone._runs is not source._runs
    assert clone._runs[0] is not source._runs[0]


def test_query_len(runs: List[Run]) -> None:
    """``len()`` of a query built from the fixture runs is 3."""
    query = Query(runs)
    assert len(query) == 3