Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions ablate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from . import blocks, queries, sources
from . import blocks, exporters, queries, sources
from .report import Report


__all__ = ["blocks", "queries", "sources"]
__all__ = ["blocks", "exporters", "queries", "Report", "sources"]

__version__ = "0.1.0"
3 changes: 2 additions & 1 deletion ablate/blocks/abstract_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(self, runs: List[Run] | None = None) -> None:

@abstractmethod
def build(self, runs: List[Run]) -> Any:
"""Build the intermediate representation of the block, ready for rendering.
"""Build the intermediate representation of the block, ready for rendering
and export.

Args:
runs: List of runs to be used for the block.
Expand Down
2 changes: 1 addition & 1 deletion ablate/blocks/figure_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class AbstractFigureBlock(AbstractBlock, ABC):
def build(self, runs: List[Run]) -> pd.DataFrame: ...


class MetricPlot(AbstractBlock):
class MetricPlot(AbstractFigureBlock):
def __init__(
self,
metric: AbstractMetric | List[AbstractMetric],
Expand Down
15 changes: 9 additions & 6 deletions ablate/blocks/text_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,25 @@ def build(self, runs: List[Run]) -> str:
return self.text.strip()


class H1(AbstractTextBlock): ...
class _Heading(AbstractTextBlock): ...


class H2(AbstractTextBlock): ...
class H1(_Heading): ...


class H3(AbstractTextBlock): ...
class H2(_Heading): ...


class H4(AbstractTextBlock): ...
class H3(_Heading): ...


class H5(AbstractTextBlock): ...
class H4(_Heading): ...


class H6(AbstractTextBlock): ...
class H5(_Heading): ...


class H6(_Heading): ...


class Text(AbstractTextBlock): ...
5 changes: 5 additions & 0 deletions ablate/exporters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .abstract_exporter import AbstractExporter
from .markdown_exporter import Markdown


__all__ = ["AbstractExporter", "Markdown"]
97 changes: 97 additions & 0 deletions ablate/exporters/abstract_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, cast

from ablate.blocks import (
AbstractBlock,
AbstractFigureBlock,
AbstractTableBlock,
AbstractTextBlock,
)
from ablate.core.types import Run
from ablate.report import Report


class AbstractExporter(ABC):
@abstractmethod
def export(self, report: Report) -> None:
"""Export the report.

Should call the `render_blocks` to generate the content of the report.

Args:
report: The report to be exported.
"""

def render_blocks(self, report: Report) -> List[Any]:
"""Render a blocks of the report.

Args:
report: The report to be rendered.

Raises:
ValueError: If the block type is not supported.

Returns:
List of rendered blocks.
"""
render_map = {
AbstractTextBlock: self.render_text,
AbstractTableBlock: self.render_table,
AbstractFigureBlock: self.render_figure,
}
content = []
for block in report.blocks:
for block_type, render_fn in render_map.items():
if isinstance(block, block_type):
fn = cast("Callable[[AbstractBlock, List[Run]], Any]", render_fn)
content.append(self._apply_render_fn(block, fn, report.runs))
break
else:
raise ValueError(f"Unknown block type: '{type(block)}'.")
return content

@staticmethod
def _apply_render_fn(
block: AbstractBlock,
fn: Callable[[AbstractBlock, List[Run]], Any],
runs: List[Run],
) -> Any:
if block.runs:
return fn(block, block.runs)
return fn(block, runs)

@abstractmethod
def render_text(self, block: AbstractTextBlock, runs: List[Run]) -> Any:
"""Render a text block.

Args:
block: The text block to be rendered.
runs: The list of runs to be used for the block.

Returns:
The rendered text block.
"""

@abstractmethod
def render_table(self, block: AbstractTableBlock, runs: List[Run]) -> Any:
"""Render a table block.

Args:
block: The table block to be rendered.
runs: The list of runs to be used for the block.

Returns:
The rendered table block.
"""

@abstractmethod
def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> Any:
"""Render a figure block.

Args:
block: The figure block to be rendered.
runs: The list of runs to be used for the block.

Returns:
The rendered figure block.
"""
69 changes: 69 additions & 0 deletions ablate/exporters/markdown_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from pathlib import Path
from typing import List

from ablate.blocks import (
AbstractFigureBlock,
AbstractTableBlock,
AbstractTextBlock,
MetricPlot,
Text,
)
from ablate.blocks.text_blocks import _Heading
from ablate.core.types import Run
from ablate.exporters.abstract_exporter import AbstractExporter
from ablate.report import Report

from .utils import HEADING_LEVELS, render_metric_plot


class Markdown(AbstractExporter):
def __init__(
self,
output_path: str = "report.md",
assets_dir: str | None = None,
) -> None:
"""Export the report as a markdown file.

Args:
output_path: The path to the output markdown file. Defaults to "report.md".
assets_dir: The directory to store the assets (figures, etc.). If None,
defaults to the parent directory of the output file with a ".ablate"
subdirectory. Defaults to None.
"""
self.output_path = Path(output_path)
self.assets_dir = (
Path(assets_dir) if assets_dir else self.output_path.parent / ".ablate"
)
self.assets_dir.mkdir(exist_ok=True)

def export(self, report: Report) -> None:
content = self.render_blocks(report)
with self.output_path.open("w", encoding="utf-8") as f:
for block_output in content:
f.write(block_output)
f.write("\n\n")

def render_text(self, block: AbstractTextBlock, runs: List[Run]) -> str:
if isinstance(block, Text):
return block.build(runs)
if isinstance(block, _Heading):
return f"{'#' * HEADING_LEVELS[type(block)]} {block.build(runs)}"
raise NotImplementedError(f"Unsupported text block: '{type(block)}'.")

def render_table(self, block: AbstractTableBlock, runs: List[Run]) -> str:
return block.build(runs).to_markdown(index=False)

def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> str:
if not isinstance(block, MetricPlot):
raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.")

filename = render_metric_plot(
block.build(runs),
self.assets_dir,
type(block).__name__,
)
if filename is None:
return (
f"*No data available for {', '.join(m.label for m in block.metrics)}*"
)
return f"![{filename}](.ablate/{filename})"
48 changes: 48 additions & 0 deletions ablate/exporters/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import hashlib
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from ablate.blocks import H1, H2, H3, H4, H5, H6


HEADING_LEVELS = {H1: 1, H2: 2, H3: 3, H4: 4, H5: 5, H6: 6}


def apply_default_plot_style() -> None:
sns.set_style("whitegrid")
sns.set_context("paper", font_scale=0.8)
sns.set_palette("muted")
plt.rcParams["figure.dpi"] = 300


def render_metric_plot(
df: pd.DataFrame,
output_dir: Path,
name_prefix: str,
) -> str | None:
apply_default_plot_style()
if df.empty:
return None

fig, ax = plt.subplots()
sns.lineplot(
data=df,
x="step",
y="value",
hue="run",
style="metric" if df["metric"].nunique() > 1 else None,
ax=ax,
)
ax.set_xlabel("Step")
ax.set_ylabel("Value")
ax.legend(title="Run", loc="best", frameon=False)
plt.tight_layout()

h = hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12]
filename = f"{name_prefix}_{h}.png"
fig.savefig(output_dir / filename)
plt.close(fig)
return filename
40 changes: 40 additions & 0 deletions ablate/report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List

from typing_extensions import Self


if TYPE_CHECKING:
from ablate.blocks import AbstractBlock
from ablate.core.types import Run


class Report:
def __init__(self, runs: List[Run]) -> None:
"""Report mapping a list of runs to a list of blocks.

Args:
runs: List of runs to be associated with the report.
"""
self.runs = runs
self.blocks: List[AbstractBlock] = []

def add(self, *blocks: AbstractBlock) -> Self:
"""Add one or more blocks to the report.

Returns:
The updated report with the added blocks.
"""
for block in blocks:
self.blocks.append(block)
return self

def __iadd__(self, block: AbstractBlock) -> Self:
self.blocks.append(block)
return self

def __add__(self, block: AbstractBlock) -> Report:
r = Report(self.runs)
r.blocks = self.blocks + [block]
return r
Loading