From b154665b859c3137b7e06b6a3b4f6e370d95f541 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 19:11:39 +0200 Subject: [PATCH 1/2] add optional csv export to markdown exporter --- ablate/exporters/markdown_exporter.py | 21 +++++++++++++++++++-- ablate/exporters/utils.py | 8 ++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/ablate/exporters/markdown_exporter.py b/ablate/exporters/markdown_exporter.py index 4db50e9..67b43e8 100644 --- a/ablate/exporters/markdown_exporter.py +++ b/ablate/exporters/markdown_exporter.py @@ -13,7 +13,7 @@ from ablate.exporters.abstract_exporter import AbstractExporter from ablate.report import Report -from .utils import HEADING_LEVELS, render_metric_plot +from .utils import HEADING_LEVELS, hash_dataframe, render_metric_plot class Markdown(AbstractExporter): @@ -21,6 +21,7 @@ def __init__( self, output_path: str = "report.md", assets_dir: str | None = None, + export_csv: bool = False, ) -> None: """Export the report as a markdown file. @@ -29,12 +30,15 @@ def __init__( assets_dir: The directory to store the assets (figures, etc.). If None, defaults to the parent directory of the output file with a ".ablate" subdirectory. Defaults to None. + export_csv: Whether to export tables and plots as CSV files. + Defaults to False. """ self.output_path = Path(output_path) self.assets_dir = ( Path(assets_dir) if assets_dir else self.output_path.parent / ".ablate" ) self.assets_dir.mkdir(exist_ok=True) + self.export_csv = export_csv def export(self, report: Report) -> None: content = self.render_blocks(report) @@ -51,12 +55,25 @@ def render_text(self, block: AbstractTextBlock, runs: List[Run]) -> str: raise NotImplementedError(f"Unsupported text block: '{type(block)}'.") def render_table(self, block: AbstractTableBlock, runs: List[Run]) -> str: - return block.build(runs).to_markdown(index=False) + df = block.build(runs) + if self.export_csv: + df.to_csv( + self.assets_dir / f"{type(block).__name__}_{hash_dataframe(df)}.csv", + index=False, + ) + return df.to_markdown(index=False) def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str: if not isinstance(block, MetricPlot): raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.") + if self.export_csv: + df = block.build(runs) + df.to_csv( + self.assets_dir / f"{type(block).__name__}_{hash_dataframe(df)}.csv", + index=False, + ) + filename = render_metric_plot(block, runs, self.assets_dir) if filename is None: return ( diff --git a/ablate/exporters/utils.py b/ablate/exporters/utils.py index 08cd170..0afd863 100644 --- a/ablate/exporters/utils.py +++ b/ablate/exporters/utils.py @@ -3,6 +3,7 @@ from typing import List import matplotlib.pyplot as plt +import pandas as pd import seaborn as sns from ablate.blocks import H1, H2, H3, H4, H5, H6, MetricPlot @@ -19,6 +20,10 @@ def apply_default_plot_style() -> None: plt.rcParams["figure.dpi"] = 300 +def hash_dataframe(df: pd.DataFrame) -> str: + return hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12] + + def render_metric_plot( block: MetricPlot, runs: List[Run], @@ -43,8 +48,7 @@ def render_metric_plot( ax.legend(title=block.identifier.label, loc="best", frameon=False) plt.tight_layout() - h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12] - filename = f"{type(block).__name__}_{h}.png" + filename = f"{type(block).__name__}_{hash_dataframe(df)}.png" fig.savefig(output_dir / filename) plt.close(fig) return filename From 0d9e9063636ac0ec32b94ca60c7a6ea0128d4d70 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 19:11:45 +0200 Subject: [PATCH 2/2] add tests --- tests/exporters/test_markdown_exporter.py | 37 +++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/tests/exporters/test_markdown_exporter.py b/tests/exporters/test_markdown_exporter.py index 1f731a2..324b5f7 100644 --- a/tests/exporters/test_markdown_exporter.py +++ b/tests/exporters/test_markdown_exporter.py @@ -150,3 +150,40 @@ def test_export_heading_variants(tmp_path: Path, runs: List[Run]) -> None: Markdown(output_path=str(out_path)).export(report) content = out_path.read_text() assert "## Section Title" in content + + +def test_export_table_block_with_csv(tmp_path: Path, runs: List[Run]) -> None: + table = Table( + columns=[Param("model"), Metric("accuracy", direction="max")], + ) + report = Report(runs).add(table) + out_path = tmp_path / "report.md" + exporter = Markdown(output_path=str(out_path), export_csv=True) + exporter.export(report) + + content = out_path.read_text() + assert "resnet" in content + asset_dir = tmp_path / ".ablate" + csv_files = list(asset_dir.glob("Table_*.csv")) + assert len(csv_files) == 1 + csv_content = csv_files[0].read_text() + assert "model,accuracy" in csv_content + assert "resnet,0.8" in csv_content + + +def test_export_figure_block_with_csv(tmp_path: Path, runs: List[Run]) -> None: + plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("model")) + report = Report(runs).add(plot) + out_path = tmp_path / "report.md" + exporter = Markdown(output_path=str(out_path), export_csv=True) + exporter.export(report) + + asset_dir = tmp_path / ".ablate" + png_files = list(asset_dir.glob("MetricPlot_*.png")) + assert len(png_files) == 1 + csv_files = list(asset_dir.glob("MetricPlot_*.csv")) + assert len(csv_files) == 1 + csv_content = csv_files[0].read_text() + assert "step,value,metric,run,run_id" in csv_content + assert "0,0.5,accuracy,resnet,run1" in csv_content + assert "1,0.9,accuracy,resnet,run2" in csv_content