diff --git a/ablate/exporters/markdown_exporter.py b/ablate/exporters/markdown_exporter.py index 1cbf12a..4db50e9 100644 --- a/ablate/exporters/markdown_exporter.py +++ b/ablate/exporters/markdown_exporter.py @@ -57,11 +57,7 @@ def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str: if not isinstance(block, MetricPlot): raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.") - filename = render_metric_plot( - block.build(runs), - self.assets_dir, - type(block).__name__, - ) + filename = render_metric_plot(block, runs, self.assets_dir) if filename is None: return ( f"*No data available for {', '.join(m.label for m in block.metrics)}*" diff --git a/ablate/exporters/utils.py b/ablate/exporters/utils.py index e710e68..08cd170 100644 --- a/ablate/exporters/utils.py +++ b/ablate/exporters/utils.py @@ -1,11 +1,12 @@ import hashlib from pathlib import Path +from typing import List import matplotlib.pyplot as plt -import pandas as pd import seaborn as sns -from ablate.blocks import H1, H2, H3, H4, H5, H6 +from ablate.blocks import H1, H2, H3, H4, H5, H6, MetricPlot +from ablate.core.types.runs import Run HEADING_LEVELS = {H1: 1, H2: 2, H3: 3, H4: 4, H5: 5, H6: 6} @@ -19,11 +20,12 @@ def apply_default_plot_style() -> None: def render_metric_plot( - df: pd.DataFrame, + block: MetricPlot, + runs: List[Run], output_dir: Path, - name_prefix: str, ) -> str | None: apply_default_plot_style() + df = block.build(runs) if df.empty: return None @@ -38,11 +40,11 @@ def render_metric_plot( ) ax.set_xlabel("Step") ax.set_ylabel("Value") - ax.legend(title="Run", loc="best", frameon=False) + ax.legend(title=block.identifier.label, loc="best", frameon=False) plt.tight_layout() h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12] - filename = f"{name_prefix}_{h}.png" + filename = f"{type(block).__name__}_{h}.png" fig.savefig(output_dir / filename) plt.close(fig) return filename diff --git a/ablate/queries/grouped_query.py b/ablate/queries/grouped_query.py index a49d954..056b9bc 100644 --- a/ablate/queries/grouped_query.py +++ b/ablate/queries/grouped_query.py @@ -2,14 +2,14 @@ from collections import defaultdict from copy import deepcopy -from typing import TYPE_CHECKING, Callable, Dict, List, Literal +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Union from ablate.core.types import GroupedRun, Run if TYPE_CHECKING: # pragma: no cover from .query import Query # noqa: TC004 - from .selectors import AbstractMetric + from .selectors import AbstractMetric, AbstractParam class GroupedQuery: @@ -75,6 +75,35 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> GroupedQuery: ] ) + def project( + self, selectors: Union[AbstractParam, List[AbstractParam]] + ) -> GroupedQuery: + """Project the parameter space of the grouped runs in the grouped query to a + subset of parameters only including the specified selectors. + + This function is intended to be used for reducing the dimensionality of the + parameter space and therefore operates on a deep copy of the grouped runs in the + grouped query. + + Args: + selectors: Selector or list of selectors to project the grouped runs by. + + Returns: + A new grouped query with the projected grouped runs. + """ + if not isinstance(selectors, list): + selectors = [selectors] + + names = {s.name for s in selectors} + projected: List[GroupedRun] = [] + + for group in deepcopy(self._grouped): + for run in group.runs: + run.params = {k: v for k, v in run.params.items() if k in names} + projected.append(group) + + return GroupedQuery(projected) + def head(self, n: int) -> Query: """Get the first n runs inside each grouped run. @@ -195,8 +224,10 @@ def aggregate( case "last": return self.tail(1) case "best": + assert over is not None return self.topk(over, 1) case "worst": + assert over is not None return self.bottomk(over, 1) case "mean": return Query([self._mean_run(g) for g in self._grouped]) diff --git a/ablate/queries/query.py b/ablate/queries/query.py index 87306c5..af6c45f 100644 --- a/ablate/queries/query.py +++ b/ablate/queries/query.py @@ -64,6 +64,31 @@ def sort(self, key: AbstractMetric, ascending: bool = False) -> Query: """ return Query(sorted(self._runs[:], key=key, reverse=not ascending)) + def project(self, selectors: Union[AbstractParam, List[AbstractParam]]) -> Query: + """Project the parameter space of the runs in the query to a subset of + parameters only including the specified selectors. + + This function is intended to be used for reducing the dimensionality of the + parameter space and therefore operates on a deep copy of the runs in the query. + + Args: + selectors: Selector or list of selectors to project the runs by. + + Returns: + A new query with the projected runs. + """ + if not isinstance(selectors, list): + selectors = [selectors] + + names = {s.name for s in selectors} + projected: List[Run] = [] + + for run in deepcopy(self._runs): + run.params = {k: v for k, v in run.params.items() if k in names} + projected.append(run) + + return Query(projected) + def groupby( self, selectors: Union[AbstractParam, List[AbstractParam]], diff --git a/tests/queries/test_grouped_query.py b/tests/queries/test_grouped_query.py index 5093415..3f31fe5 100644 --- a/tests/queries/test_grouped_query.py +++ b/tests/queries/test_grouped_query.py @@ -111,3 +111,11 @@ def test_copy_and_deepcopy(grouped: GroupedQuery) -> None: assert all( dr is not gr for dr, gr in zip(deep._grouped, grouped._grouped, strict=False) ) + + +def test_grouped_query_project_reduces_param_space(grouped: GroupedQuery) -> None: + grouped = grouped.project(Param("model")) + for group in grouped._grouped: + for run in group.runs: + assert set(run.params.keys()) == {"model"} + assert set(run.metrics.keys()) == {"accuracy"} diff --git a/tests/queries/test_query.py b/tests/queries/test_query.py index 9ba3597..7ed459f 100644 --- a/tests/queries/test_query.py +++ b/tests/queries/test_query.py @@ -93,3 +93,10 @@ def test_query_deepcopy(runs: List[Run]) -> None: def test_query_len(runs: List[Run]) -> None: assert len(Query(runs)) == 3 + + +def test_project_reduces_parameter_space(runs: List[Run]) -> None: + q = Query(runs).project(Param("model")) + for run in q.all(): + assert set(run.params.keys()) == {"model"} + assert set(run.metrics.keys()) == {"accuracy"}