Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions ablate/exporters/markdown_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
from ablate.exporters.abstract_exporter import AbstractExporter
from ablate.report import Report

from .utils import HEADING_LEVELS, render_metric_plot
from .utils import HEADING_LEVELS, hash_dataframe, render_metric_plot


class Markdown(AbstractExporter):
def __init__(
self,
output_path: str = "report.md",
assets_dir: str | None = None,
export_csv: bool = False,
) -> None:
"""Export the report as a markdown file.

Expand All @@ -29,12 +30,15 @@ def __init__(
assets_dir: The directory to store the assets (figures, etc.). If None,
defaults to the parent directory of the output file with a ".ablate"
subdirectory. Defaults to None.
export_csv: Whether to export tables and plots as CSV files.
Defaults to False.
"""
self.output_path = Path(output_path)
self.assets_dir = (
Path(assets_dir) if assets_dir else self.output_path.parent / ".ablate"
)
self.assets_dir.mkdir(exist_ok=True)
self.export_csv = export_csv

def export(self, report: Report) -> None:
content = self.render_blocks(report)
Expand All @@ -51,12 +55,25 @@ def render_text(self, block: AbstractTextBlock, runs: List[Run]) -> str:
raise NotImplementedError(f"Unsupported text block: '{type(block)}'.")

def render_table(self, block: AbstractTableBlock, runs: List[Run]) -> str:
    """Render a table block as a markdown table.

    The dataframe is built once from the block. When ``self.export_csv`` is
    set, it is also written (without its index) as a CSV file into
    ``self.assets_dir`` before the markdown is returned.

    Args:
        block: The table block to render.
        runs: The runs handed to ``block.build`` to produce the dataframe.

    Returns:
        The dataframe rendered via ``to_markdown(index=False)``.
    """
    df = block.build(runs)
    if self.export_csv:
        # File name is the block's class name plus the dataframe content hash.
        csv_name = f"{type(block).__name__}_{hash_dataframe(df)}.csv"
        df.to_csv(self.assets_dir / csv_name, index=False)
    return df.to_markdown(index=False)

def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str:
if not isinstance(block, MetricPlot):
raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.")

if self.export_csv:
df = block.build(runs)
df.to_csv(
self.assets_dir / f"{type(block).__name__}_{hash_dataframe(df)}.csv",
index=False,
)

filename = render_metric_plot(block, runs, self.assets_dir)
if filename is None:
return (
Expand Down
8 changes: 6 additions & 2 deletions ablate/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ablate.blocks import H1, H2, H3, H4, H5, H6, MetricPlot
Expand All @@ -19,6 +20,10 @@ def apply_default_plot_style() -> None:
plt.rcParams["figure.dpi"] = 300


def hash_dataframe(df: pd.DataFrame) -> str:
    """Return a short hexadecimal fingerprint of a dataframe's contents.

    The frame is serialized to CSV without its index, encoded as UTF-8 and
    hashed with MD5; only the first 12 characters of the hex digest are kept.
    """
    serialized = df.to_csv(index=False)
    digest = hashlib.md5(serialized.encode("utf-8"))
    return digest.hexdigest()[:12]


def render_metric_plot(
block: MetricPlot,
runs: List[Run],
Expand All @@ -43,8 +48,7 @@ def render_metric_plot(
ax.legend(title=block.identifier.label, loc="best", frameon=False)
plt.tight_layout()

h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12]
filename = f"{type(block).__name__}_{h}.png"
filename = f"{type(block).__name__}_{hash_dataframe(df)}.png"
fig.savefig(output_dir / filename)
plt.close(fig)
return filename
37 changes: 37 additions & 0 deletions tests/exporters/test_markdown_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,40 @@ def test_export_heading_variants(tmp_path: Path, runs: List[Run]) -> None:
Markdown(output_path=str(out_path)).export(report)
content = out_path.read_text()
assert "## Section Title" in content


def test_export_table_block_with_csv(tmp_path: Path, runs: List[Run]) -> None:
    """A table exported with ``export_csv=True`` also yields one CSV asset."""
    report_file = tmp_path / "report.md"
    columns = [Param("model"), Metric("accuracy", direction="max")]
    report = Report(runs).add(Table(columns=columns))

    Markdown(output_path=str(report_file), export_csv=True).export(report)

    # The markdown report itself still contains the rendered table.
    assert "resnet" in report_file.read_text()

    # Exactly one CSV is expected in the ".ablate" directory beside the report.
    exported = list((tmp_path / ".ablate").glob("Table_*.csv"))
    assert len(exported) == 1
    csv_text = exported[0].read_text()
    assert "model,accuracy" in csv_text
    assert "resnet,0.8" in csv_text


def test_export_figure_block_with_csv(tmp_path: Path, runs: List[Run]) -> None:
    """A metric plot exported with ``export_csv=True`` yields a PNG and a CSV."""
    report_file = tmp_path / "report.md"
    accuracy_plot = MetricPlot(
        Metric("accuracy", direction="max"),
        identifier=Param("model"),
    )
    report = Report(runs).add(accuracy_plot)

    Markdown(output_path=str(report_file), export_csv=True).export(report)

    assets = tmp_path / ".ablate"
    assert len(list(assets.glob("MetricPlot_*.png"))) == 1

    exported = list(assets.glob("MetricPlot_*.csv"))
    assert len(exported) == 1
    csv_text = exported[0].read_text()
    # Header row first, then one expected data row per step.
    expected_lines = (
        "step,value,metric,run,run_id",
        "0,0.5,accuracy,resnet,run1",
        "1,0.9,accuracy,resnet,run2",
    )
    for expected in expected_lines:
        assert expected in csv_text