Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions ablate/blocks/figure_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,21 @@ def build(self, runs: List[Run]) -> pd.DataFrame: ...
class MetricPlot(AbstractFigureBlock):
def __init__(
self,
metric: AbstractMetric | List[AbstractMetric],
metrics: AbstractMetric | List[AbstractMetric],
identifier: Param | None = None,
runs: List[Run] | None = None,
) -> None:
"""Block for plotting metrics over time.

Args:
metric: Metric or list of metrics to be plotted over time.
metrics: Metric or list of metrics to be plotted over time.
identifier: Optional identifier for the runs. If None, the run ID is used.
Defaults to None.
runs: Optional list of runs to be used for the block instead of the default
runs from the report. Defaults to None.
"""
super().__init__(runs)
self.metrics = metric if isinstance(metric, list) else [metric]
self.metrics = metrics if isinstance(metrics, list) else [metrics]
self.identifier = identifier or Id()

def build(self, runs: List[Run]) -> pd.DataFrame:
Expand Down
8 changes: 2 additions & 6 deletions tests/blocks/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,15 @@ def test_table_block() -> None:


def test_metric_plot_single() -> None:
plot = MetricPlot(
metric=Metric("accuracy", direction="max"), identifier=Param("seed")
)
plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("seed"))
df = plot.build(make_runs())
assert isinstance(df, pd.DataFrame)
assert set(df.columns) >= {"step", "value", "metric", "run", "run_id"}
assert df["metric"].unique().tolist() == ["accuracy"]


def test_metric_plot_multi() -> None:
plot = MetricPlot(
metric=[Metric("accuracy", direction="max")], identifier=Param("seed")
)
plot = MetricPlot([Metric("accuracy", direction="max")], identifier=Param("seed"))
df = plot.build(make_runs())
assert isinstance(df, pd.DataFrame)
assert all(k in df.columns for k in ["step", "value", "metric", "run"])