6 changes: 1 addition & 5 deletions ablate/exporters/markdown_exporter.py

@@ -57,11 +57,7 @@ def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str:
         if not isinstance(block, MetricPlot):
             raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.")

-        filename = render_metric_plot(
-            block.build(runs),
-            self.assets_dir,
-            type(block).__name__,
-        )
+        filename = render_metric_plot(block, runs, self.assets_dir)
         if filename is None:
             return (
                 f"*No data available for {', '.join(m.label for m in block.metrics)}*"
14 changes: 8 additions & 6 deletions ablate/exporters/utils.py

@@ -1,11 +1,12 @@
 import hashlib
 from pathlib import Path
+from typing import List

 import matplotlib.pyplot as plt
-import pandas as pd
 import seaborn as sns

-from ablate.blocks import H1, H2, H3, H4, H5, H6
+from ablate.blocks import H1, H2, H3, H4, H5, H6, MetricPlot
+from ablate.core.types.runs import Run


 HEADING_LEVELS = {H1: 1, H2: 2, H3: 3, H4: 4, H5: 5, H6: 6}
@@ -19,11 +20,12 @@ def apply_default_plot_style() -> None:


 def render_metric_plot(
-    df: pd.DataFrame,
+    block: MetricPlot,
+    runs: List[Run],
     output_dir: Path,
-    name_prefix: str,
 ) -> str | None:
     apply_default_plot_style()
+    df = block.build(runs)
     if df.empty:
         return None

@@ -38,11 +40,11 @@
     )
     ax.set_xlabel("Step")
     ax.set_ylabel("Value")
-    ax.legend(title="Run", loc="best", frameon=False)
+    ax.legend(title=block.identifier.label, loc="best", frameon=False)
     plt.tight_layout()

     h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12]
-    filename = f"{name_prefix}_{h}.png"
+    filename = f"{type(block).__name__}_{h}.png"
     fig.savefig(output_dir / filename)
     plt.close(fig)
     return filename
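
A note on the refactored helper: render_metric_plot now receives the block and the runs, builds the DataFrame internally via block.build(runs), takes the legend title from block.identifier.label, and derives the output filename from the block's class name plus a short hash of the plotted data, so identical data maps to the same asset name. The naming recipe in isolation looks roughly like the sketch below; the DataFrame contents and the MetricPlot prefix string are invented for illustration, and only the hashing line is taken from the diff.

import hashlib

import pandas as pd

# Illustrative data only; in the real helper this frame comes from block.build(runs).
df = pd.DataFrame({"run": ["a", "a"], "step": [0, 1], "value": [0.10, 0.12]})

# Same recipe as render_metric_plot: hash the CSV-serialized frame and prefix
# the filename with the block's class name (a stand-in string here).
h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12]
filename = f"MetricPlot_{h}.png"
print(filename)  # MetricPlot_<12 hex chars>.png
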
35 changes: 33 additions & 2 deletions ablate/queries/grouped_query.py

@@ -2,14 +2,14 @@

 from collections import defaultdict
 from copy import deepcopy
-from typing import TYPE_CHECKING, Callable, Dict, List, Literal
+from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Union

 from ablate.core.types import GroupedRun, Run


 if TYPE_CHECKING:  # pragma: no cover
     from .query import Query  # noqa: TC004
-    from .selectors import AbstractMetric
+    from .selectors import AbstractMetric, AbstractParam


 class GroupedQuery:
@@ -75,6 +75,35 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery:
             ]
         )

+    def project(
+        self, selectors: Union[AbstractParam, List[AbstractParam]]
+    ) -> GroupedQuery:
+        """Project the parameter space of the grouped runs in the grouped query to a
+        subset of parameters only including the specified selectors.
+
+        This function is intended to be used for reducing the dimensionality of the
+        parameter space and therefore operates on a deep copy of the grouped runs in the
+        grouped query.
+
+        Args:
+            selectors: Selector or list of selectors to project the grouped runs by.
+
+        Returns:
+            A new grouped query with the projected grouped runs.
+        """
+        if not isinstance(selectors, list):
+            selectors = [selectors]
+
+        names = {s.name for s in selectors}
+        projected: List[GroupedRun] = []
+
+        for group in deepcopy(self._grouped):
+            for run in group.runs:
+                run.params = {k: v for k, v in run.params.items() if k in names}
+            projected.append(group)
+
+        return GroupedQuery(projected)
+
     def head(self, n: int) -> Query:
         """Get the first n runs inside each grouped run.

@@ -195,8 +224,10 @@ def aggregate(
             case "last":
                 return self.tail(1)
             case "best":
+                assert over is not None
                 return self.topk(over, 1)
             case "worst":
+                assert over is not None
                 return self.bottomk(over, 1)
             case "mean":
                 return Query([self._mean_run(g) for g in self._grouped])
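
For orientation, a rough usage sketch of the new grouped projection, not taken from the repository: the Query and Param import paths are guesses, groupby is assumed to return a GroupedQuery, and the runs list is left to the caller.

from typing import List

from ablate.core.types import Run  # import path as used in this diff
from ablate.queries import Query  # assumed import path
from ablate.queries.selectors import Param  # assumed import path


def keep_only_model_param(runs: List[Run]) -> None:
    # Group by the "model" parameter, then drop every other parameter from a
    # deep copy of the grouped runs; the original query keeps its full params.
    grouped = Query(runs).groupby(Param("model"))
    slim = grouped.project(Param("model"))

    for group in slim._grouped:
        for run in group.runs:
            assert set(run.params.keys()) == {"model"}

The two assert over is not None additions are a separate tightening of aggregate: the best and worst branches forward to topk and bottomk, which need a metric to rank by, so calling aggregate("best") or aggregate("worst") without over now fails at the assertion rather than further down in the ranking code.
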
25 changes: 25 additions & 0 deletions ablate/queries/query.py

@@ -64,6 +64,31 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> Query:
         """
         return Query(sorted(self._runs[:], key=key, reverse=not ascending))

+    def project(self, selectors: Union[AbstractParam, List[AbstractParam]]) -> Query:
+        """Project the parameter space of the runs in the query to a subset of
+        parameters only including the specified selectors.
+
+        This function is intended to be used for reducing the dimensionality of the
+        parameter space and therefore operates on a deep copy of the runs in the query.
+
+        Args:
+            selectors: Selector or list of selectors to project the runs by.
+
+        Returns:
+            A new query with the projected runs.
+        """
+        if not isinstance(selectors, list):
+            selectors = [selectors]
+
+        names = {s.name for s in selectors}
+        projected: List[Run] = []
+
+        for run in deepcopy(self._runs):
+            run.params = {k: v for k, v in run.params.items() if k in names}
+            projected.append(run)
+
+        return Query(projected)
+
     def groupby(
         self,
         selectors: Union[AbstractParam, List[AbstractParam]],
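
The flat Query.project mirrors the grouped variant: parameters not named by the selectors are dropped from a deep copy of each run, and a new Query is returned so the call chains with groupby and the other query operations. At its core the projection is just a key filter over each run's params dict, which the standalone snippet below reproduces without touching any ablate types; the parameter names and values are invented.

# The same dict comprehension Query.project applies to each run, on plain data.
names = {"model", "lr"}  # names taken from the selectors to keep
params = {"model": "resnet50", "lr": 0.1, "seed": 0}  # made-up run params

projected = {k: v for k, v in params.items() if k in names}
assert projected == {"model": "resnet50", "lr": 0.1}
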
8 changes: 8 additions & 0 deletions tests/queries/test_grouped_query.py

@@ -111,3 +111,11 @@ def test_copy_and_deepcopy(grouped: GroupedQuery) -> None:
     assert all(
         dr is not gr for dr, gr in zip(deep._grouped, grouped._grouped, strict=False)
     )
+
+
+def test_grouped_query_project_reduces_param_space(grouped: GroupedQuery) -> None:
+    grouped = grouped.project(Param("model"))
+    for group in grouped._grouped:
+        for run in group.runs:
+            assert set(run.params.keys()) == {"model"}
+            assert set(run.metrics.keys()) == {"accuracy"}
7 changes: 7 additions & 0 deletions tests/queries/test_query.py

@@ -93,3 +93,10 @@ def test_query_deepcopy(runs: List[Run]) -> None:

 def test_query_len(runs: List[Run]) -> None:
     assert len(Query(runs)) == 3
+
+
+def test_project_reduces_parameter_space(runs: List[Run]) -> None:
+    q = Query(runs).project(Param("model"))
+    for run in q.all():
+        assert set(run.params.keys()) == {"model"}
+        assert set(run.metrics.keys()) == {"accuracy"}