Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ablate/exporters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .abstract_exporter import AbstractExporter
from .markdown_exporter import Markdown
from .notebook_exporter import Notebook


__all__ = ["AbstractExporter", "Markdown"]
__all__ = ["AbstractExporter", "Markdown", "Notebook"]
74 changes: 74 additions & 0 deletions ablate/exporters/notebook_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import sys
from typing import List

from matplotlib import pyplot as plt

from ablate.blocks import (
AbstractFigureBlock,
AbstractTableBlock,
AbstractTextBlock,
MetricPlot,
Text,
)
from ablate.blocks.text_blocks import _Heading
from ablate.core.types import Run
from ablate.exporters.abstract_exporter import AbstractExporter
from ablate.exporters.utils import HEADING_LEVELS
from ablate.report import Report

from .utils import create_metric_plot


def running_in_notebook() -> bool:
return any("jupyter" in arg or "ipykernel" in arg for arg in sys.argv)


class Notebook(AbstractExporter):
def __init__(self) -> None:
super().__init__()
try:
from IPython.display import display # noqa: F401
except ImportError as e:
raise ImportError(
"Notebook exporter requires `jupyter`. "
"Please install with `pip install ablate[jupyter]`."
) from e

def export(self, report: Report) -> None:
if not running_in_notebook():
raise RuntimeError(
"Notebook exporter can only be used inside a Jupyter notebook."
)
self.render_blocks(report)

def render_text(self, block: AbstractTextBlock, runs: List[Run]) -> None:
from IPython.display import Markdown, display

if isinstance(block, Text):
display(Markdown(block.build(runs)))
elif isinstance(block, _Heading):
level = HEADING_LEVELS[type(block)]
display(Markdown(f"{'#' * level} {block.build(runs)}"))
else:
raise NotImplementedError(f"Unsupported text block: '{type(block)}'")

def render_table(self, block: AbstractTableBlock, runs: List[Run]) -> None:
from IPython.display import display

display(block.build(runs))

def render_figure(self, block: AbstractFigureBlock, runs: List[Run]) -> None:
from IPython.display import Markdown, display

if not isinstance(block, MetricPlot):
raise NotImplementedError(f"Unsupported figure block: '{type(block)}'.")

df = block.build(runs)
if df.empty:
m = f"*No data available for {', '.join(m.label for m in block.metrics)}*"
display(Markdown(m))
return

fig = create_metric_plot(df, block.identifier.label)
plt.close(fig)
display(fig)
24 changes: 14 additions & 10 deletions ablate/exporters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,8 @@ def hash_dataframe(df: pd.DataFrame) -> str:
return hashlib.md5(df.to_csv(index=False).encode("utf-8")).hexdigest()[:12]


def render_metric_plot(
block: MetricPlot,
runs: List[Run],
output_dir: Path,
) -> str | None:
def create_metric_plot(df: pd.DataFrame, label: str) -> plt.Figure:
apply_default_plot_style()
df = block.build(runs)
if df.empty:
return None

fig, ax = plt.subplots()
sns.lineplot(
data=df,
Expand All @@ -45,9 +37,21 @@ def render_metric_plot(
)
ax.set_xlabel("Step")
ax.set_ylabel("Value")
ax.legend(title=block.identifier.label, loc="best", frameon=False)
ax.legend(title=label, loc="best", frameon=False)
plt.tight_layout()
return fig


def render_metric_plot(
block: MetricPlot,
runs: List[Run],
output_dir: Path,
) -> str | None:
df = block.build(runs)
if df.empty:
return None

fig = create_metric_plot(df, block.identifier.label)
filename = f"{type(block).__name__}_{hash_dataframe(df)}.png"
fig.savefig(output_dir / filename)
plt.close(fig)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ ablate = "ablate.core.cli.main:main"

[project.optional-dependencies]
mlflow = ["mlflow>=2.22.0"]
jupyter = ["jupyter>=1.1.1"]

[tool.ruff]
line-length = 88
Expand Down
Empty file added tests/exporters/__init__.py
Empty file.
22 changes: 3 additions & 19 deletions tests/exporters/test_markdown_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,19 @@
import matplotlib.pyplot as plt
import pytest

from ablate.blocks import (
H1,
H2,
AbstractFigureBlock,
AbstractTextBlock,
MetricPlot,
Table,
Text,
)
from ablate.blocks import H1, H2, MetricPlot, Table, Text
from ablate.core.types import Run
from ablate.exporters import Markdown
from ablate.queries import Metric, Param
from ablate.report import Report

from .utils import DummyFigureBlock, DummyTextBlock


class DummyBlock:
pass


class DummyTextBlock(AbstractTextBlock):
def build(self, runs: List[Run]) -> str: # type: ignore[override]
return "not supported"


class DummyFigureBlock(AbstractFigureBlock):
def build(self, runs: List[Run]) -> str: # type: ignore[override]
return "not a dataframe"


@pytest.fixture
def runs() -> list[Run]:
return [
Expand Down
120 changes: 120 additions & 0 deletions tests/exporters/test_notebook_exporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from typing import List
from unittest.mock import patch

import matplotlib.pyplot as plt
import pandas as pd
import pytest

from ablate import Report
from ablate.blocks import H2, MetricPlot, Table, Text
from ablate.core.types import Run
from ablate.exporters.notebook_exporter import Notebook, running_in_notebook
from ablate.queries import Metric, Param

from .utils import DummyFigureBlock, DummyTextBlock


@pytest.fixture
def runs() -> List[Run]:
return [
Run(
id="r1",
params={"model": "resnet"},
metrics={"acc": 0.8},
temporal={"acc": [(0, 0.5), (1, 0.8)]},
),
Run(
id="r2",
params={"model": "resnet"},
metrics={"acc": 0.9},
temporal={"acc": [(0, 0.6), (1, 0.9)]},
),
]


def test_render_text_block(runs: List[Run]) -> None:
block = Text("Hello *world*")
with patch("IPython.display.display") as mock_display:
Notebook().render_text(block, runs)
assert mock_display.call_count == 1
assert "Hello" in mock_display.call_args[0][0].data


def test_render_heading_block(runs: List[Run]) -> None:
block = H2("Section")
with patch("IPython.display.display") as mock_display:
Notebook().render_text(block, runs)
assert mock_display.call_count == 1
assert mock_display.call_args[0][0].data.startswith("##")


def test_render_table_block(runs: List[Run]) -> None:
table = Table(columns=[Param("model"), Metric("acc", direction="max")])
with patch("IPython.display.display") as mock_display:
Notebook().render_table(table, runs)
df = mock_display.call_args[0][0]
assert isinstance(df, pd.DataFrame)
assert "model" in df.columns
assert len(df) == 2


def test_render_metric_plot(runs: List[Run]) -> None:
plot = MetricPlot(Metric("acc", direction="max"), identifier=Param("model"))
with patch("IPython.display.display") as mock_display:
Notebook().render_figure(plot, runs)
fig = mock_display.call_args[0][0]
assert isinstance(fig, plt.Figure)


def test_render_empty_plot() -> None:
empty_run = Run(id="x", params={}, metrics={}, temporal={})
plot = MetricPlot(Metric("acc", direction="max"), identifier=Param("model"))
with patch("IPython.display.display") as mock_display:
Notebook().render_figure(plot, [empty_run])
text = mock_display.call_args[0][0]
assert "*No data available for acc*" in text.data


def test_running_in_notebook_true(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("sys.argv", ["ipykernel_launcher"])
assert running_in_notebook() is True


def test_running_in_notebook_false(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("sys.argv", ["some_script.py"])
assert running_in_notebook() is False


def test_notebook_import_error() -> None:
with (
patch.dict("sys.modules", {"IPython.display": None}),
pytest.raises(ImportError, match="requires `jupyter`"),
):
Notebook()


def test_export_runs(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("sys.argv", ["ipykernel_launcher"])
report = Report([])
with patch.object(Notebook, "render_blocks", return_value=[]) as mock_render:
Notebook().export(report)
mock_render.assert_called_once_with(report)


def test_export_outside_notebook(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr("sys.argv", ["script.py"])

with pytest.raises(RuntimeError, match="only be used inside a Jupyter notebook"):
Notebook().export(Report([]))


def test_render_text_not_implemented(runs: List[Run]) -> None:
dummy = DummyTextBlock("oops")
with pytest.raises(NotImplementedError):
Notebook().render_text(dummy, runs)


def test_render_figure_not_implemented(runs: List[Run]) -> None:
dummy = DummyFigureBlock()
with pytest.raises(NotImplementedError):
Notebook().render_figure(dummy, runs)
14 changes: 14 additions & 0 deletions tests/exporters/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from typing import List

from ablate.blocks import AbstractFigureBlock, AbstractTextBlock
from ablate.core.types import Run


class DummyTextBlock(AbstractTextBlock):
def build(self, runs: List[Run]) -> str: # type: ignore[override]
return "not supported"


class DummyFigureBlock(AbstractFigureBlock):
def build(self, runs: List[Run]) -> str: # type: ignore[override]
return "not a dataframe"
Loading