diff --git a/ablate/blocks/figure_blocks.py b/ablate/blocks/figure_blocks.py index c109c2f..6cc7c17 100644 --- a/ablate/blocks/figure_blocks.py +++ b/ablate/blocks/figure_blocks.py @@ -22,21 +22,21 @@ def build(self, runs: List[Run]) -> pd.DataFrame: ... class MetricPlot(AbstractFigureBlock): def __init__( self, - metric: AbstractMetric | List[AbstractMetric], + metrics: AbstractMetric | List[AbstractMetric], identifier: Param | None = None, runs: List[Run] | None = None, ) -> None: """Block for plotting metrics over time. Args: - metric: Metric or list of metrics to be plotted over time. + metrics: Metric or list of metrics to be plotted over time. identifier: Optional identifier for the runs. If None, the run ID is used. Defaults to None. runs: Optional list of runs to be used for the block instead of the default runs from the report. Defaults to None. """ super().__init__(runs) - self.metrics = metric if isinstance(metric, list) else [metric] + self.metrics = metrics if isinstance(metrics, list) else [metrics] self.identifier = identifier or Id() def build(self, runs: List[Run]) -> pd.DataFrame: diff --git a/tests/blocks/test_blocks.py b/tests/blocks/test_blocks.py index f15185f..5b54204 100644 --- a/tests/blocks/test_blocks.py +++ b/tests/blocks/test_blocks.py @@ -38,9 +38,7 @@ def test_table_block() -> None: def test_metric_plot_single() -> None: - plot = MetricPlot( - metric=Metric("accuracy", direction="max"), identifier=Param("seed") - ) + plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("seed")) df = plot.build(make_runs()) assert isinstance(df, pd.DataFrame) assert set(df.columns) >= {"step", "value", "metric", "run", "run_id"} @@ -48,9 +46,7 @@ def test_metric_plot_single() -> None: def test_metric_plot_multi() -> None: - plot = MetricPlot( - metric=[Metric("accuracy", direction="max")], identifier=Param("seed") - ) + plot = MetricPlot([Metric("accuracy", direction="max")], identifier=Param("seed")) df = plot.build(make_runs()) assert isinstance(df, pd.DataFrame) assert all(k in df.columns for k in ["step", "value", "metric", "run"])