Merged
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
@@ -11,3 +11,8 @@ repos:
       - id: codespell
         additional_dependencies:
           - tomli
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v1.15.0
+    hooks:
+      - id: mypy
+        args: ["--config-file=pyproject.toml"]
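Note: the new hook runs mypy on every commit. A minimal sketch of exercising the hook id declared above against the whole repository, assuming pre-commit is installed locally (the helper script itself is hypothetical, not part of the PR):

# Hypothetical helper: runs the `mypy` hook declared in .pre-commit-config.yaml.
import subprocess

# --all-files checks the entire repository instead of only the staged files.
subprocess.run(["pre-commit", "run", "mypy", "--all-files"], check=True)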
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -86,3 +86,8 @@ testpaths = "tests"
 
 [tool.codespell]
 skip = "uv.lock"
+
+[tool.mypy]
+files = ["ablate", "tests"]
+explicit_package_bases = true
+ignore_missing_imports = true
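For reference, the same configuration can be exercised without pre-commit. A minimal sketch using mypy's programmatic entry point, assuming mypy is installed and that the [tool.mypy] table above is picked up from pyproject.toml (the script is illustrative, not part of the PR):

# Hypothetical check script, equivalent to `mypy --config-file=pyproject.toml`.
from mypy import api

# With no explicit targets, mypy checks the `files` listed in [tool.mypy];
# explicit_package_bases and ignore_missing_imports are read from the same table.
stdout, stderr, exit_code = api.run(["--config-file=pyproject.toml"])
print(stdout, end="")
raise SystemExit(exit_code)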
4 changes: 2 additions & 2 deletions tests/core/test_types.py
@@ -25,7 +25,7 @@ def test_temporal_data() -> None:
 
 def test_invalid_metrics_type() -> None:
     with pytest.raises(ValidationError):
-        Run(id="bad", params={}, metrics="not a dict")
+        Run(id="bad", params={}, metrics="not a dict")  # type: ignore[arg-type]
 
 
 def test_grouped_run() -> None:
@@ -41,7 +41,7 @@ def test_grouped_run() -> None:
 
 def test_invalid_runs_type() -> None:
     with pytest.raises(ValidationError):
-        GroupedRun(key="group_key", value="group_value", runs=["not a run"])
+        GroupedRun(key="group_key", value="group_value", runs=["not a run"])  # type: ignore[list-item]
 
 
 def test_run_roundtrip_serialization() -> None:
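The two ignores above exist because these tests pass deliberately wrong types: pydantic rejects them at runtime with a ValidationError, while mypy flags the same calls statically with the arg-type and list-item error codes. A minimal illustration of the pattern, using a hypothetical model rather than the real Run class:

from pydantic import BaseModel

class Example(BaseModel):
    metrics: dict[str, float]

# mypy reports [arg-type] on this call; at runtime pydantic raises ValidationError,
# which is exactly what the tests assert with pytest.raises.
Example(metrics="not a dict")  # type: ignore[arg-type]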
24 changes: 12 additions & 12 deletions tests/exporters/test_markdown_exporter.py
@@ -25,12 +25,12 @@ class DummyBlock:
 
 
 class DummyTextBlock(AbstractTextBlock):
-    def build(self, runs: List[Run]) -> str:
+    def build(self, runs: List[Run]) -> str:  # type: ignore[override]
         return "not supported"
 
 
 class DummyFigureBlock(AbstractFigureBlock):
-    def build(self, runs: List[Run]) -> str:
+    def build(self, runs: List[Run]) -> str:  # type: ignore[override]
         return "not a dataframe"
 
 
@@ -55,7 +55,7 @@ def runs() -> list[Run]:
 def test_export_text_blocks(tmp_path: Path, runs: List[Run]) -> None:
     report = Report(runs).add(H1("Heading 1"), Text("Some paragraph text."))
     out_path = tmp_path / "report.md"
-    Markdown(output_path=out_path).export(report)
+    Markdown(output_path=str(out_path)).export(report)
 
     content = out_path.read_text()
     assert "# Heading 1" in content
@@ -71,7 +71,7 @@ def test_export_table_block(tmp_path: Path, runs: List[Run]) -> None:
     )
     report = Report(runs).add(table)
     out_path = tmp_path / "report.md"
-    Markdown(output_path=out_path).export(report)
+    Markdown(output_path=str(out_path)).export(report)
 
     content = out_path.read_text().replace("\r\n", "\n")
     assert re.search(r"\|\s*Model\s*\|\s*Accuracy\s*\|", content)
@@ -86,7 +86,7 @@ def test_export_figure_block(tmp_path: Path, runs: List[Run]) -> None:
     plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("model"))
     report = Report(runs).add(plot)
     out_path = tmp_path / "report.md"
-    exporter = Markdown(output_path=out_path)
+    exporter = Markdown(output_path=str(out_path))
     exporter.export(report)
 
     content = out_path.read_text()
@@ -104,29 +104,29 @@ def test_export_figure_block_empty(tmp_path: Path) -> None:
     plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("model"))
     report = Report([empty_run]).add(plot)
     out_path = tmp_path / "report.md"
-    Markdown(output_path=out_path).export(report)
+    Markdown(output_path=str(out_path)).export(report)
 
     content = out_path.read_text()
     assert "*No data available for accuracy*" in content
 
 
 def test_unknown_block_raises(tmp_path: Path, runs: List[Run]) -> None:
     report = Report(runs)
-    report += DummyBlock()
+    report += DummyBlock()  # type: ignore[arg-type]
     with pytest.raises(ValueError, match="Unknown block type"):
-        Markdown(output_path=tmp_path / "out.md").export(report)
+        Markdown(output_path=str(tmp_path / "out.md")).export(report)
 
 
 def test_unsupported_figure_block_raises(tmp_path: Path, runs: List[Run]) -> None:
     report = Report(runs).add(DummyFigureBlock())
-    exporter = Markdown(output_path=tmp_path / "out.md")
+    exporter = Markdown(output_path=str(tmp_path / "out.md"))
     with pytest.raises(NotImplementedError, match="Unsupported figure block"):
         exporter.export(report)
 
 
 def test_unsupported_text_block_raises(tmp_path: Path, runs: List[Run]) -> None:
     report = Report(runs).add(DummyTextBlock("oops"))
-    exporter = Markdown(output_path=tmp_path / "out.md")
+    exporter = Markdown(output_path=str(tmp_path / "out.md"))
     with pytest.raises(NotImplementedError, match="Unsupported text block"):
         exporter.export(report)
 
@@ -138,7 +138,7 @@ def test_block_level_runs_override_global(tmp_path: Path, runs: List[Run]) -> None:
         Table([Param("model"), Metric("accuracy", "max")], runs=scoped_runs),
     )
     out_path = tmp_path / "report.md"
-    Markdown(output_path=out_path).export(report)
+    Markdown(output_path=str(out_path)).export(report)
     content = out_path.read_text()
     assert "resnet" in content
     assert content.count("resnet") == 1
@@ -147,6 +147,6 @@ def test_export_heading_variants(tmp_path: Path, runs: List[Run]) -> None:
 def test_export_heading_variants(tmp_path: Path, runs: List[Run]) -> None:
     report = Report(runs).add(H2("Section Title"))
     out_path = tmp_path / "headings.md"
-    Markdown(output_path=out_path).export(report)
+    Markdown(output_path=str(out_path)).export(report)
     content = out_path.read_text()
     assert "## Section Title" in content
2 changes: 1 addition & 1 deletion tests/queries/test_grouped_query.py
@@ -63,7 +63,7 @@ def test_aggregate_all_strategies(grouped: GroupedQuery) -> None:
     assert len(grouped.aggregate("mean", over=m).all()) == 2
 
     with pytest.raises(ValueError, match="Unsupported aggregation method"):
-        grouped.aggregate("unsupported", over=m)
+        grouped.aggregate("unsupported", over=m)  # type: ignore[arg-type]
 
 
 def test_aggregate_best_worst_missing_over(grouped: GroupedQuery) -> None:
6 changes: 3 additions & 3 deletions tests/queries/test_selectors.py
@@ -56,7 +56,7 @@ def test_metric_selector(example_run: Run) -> None:
 
 def test_invalid_metric_direction() -> None:
     with pytest.raises(ValueError, match="Invalid direction"):
-        Metric("accuracy", direction="invalid")
+        Metric("accuracy", direction="invalid")  # type: ignore[arg-type]
 
 
 @pytest.mark.parametrize(
@@ -71,7 +71,7 @@ def test_invalid_metric_direction() -> None:
 def test_temporal_metric_selector(
     example_run: Run, reduction: str, expected: float
 ) -> None:
-    selector = TemporalMetric("accuracy", direction="max", reduction=reduction)
+    selector = TemporalMetric("accuracy", direction="max", reduction=reduction)  # type: ignore[arg-type]
     assert selector(example_run) == expected
 
 
@@ -82,4 +82,4 @@ def test_temporal_metric_missing_returns_nan(example_run: Run) -> None:
 
 def test_temporal_metric_invalid_reduction() -> None:
     with pytest.raises(ValueError, match="Invalid reduction method"):
-        TemporalMetric("accuracy", direction="max", reduction="median")
+        TemporalMetric("accuracy", direction="max", reduction="median")  # type: ignore[arg-type]
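The ignore on the parametrized call above is the usual Literal-versus-str mismatch: pytest.mark.parametrize hands the test a plain str, while the selector presumably annotates direction and reduction with Literal types. A minimal sketch of that pattern, using a hypothetical function rather than the real TemporalMetric API:

from typing import Literal

Reduction = Literal["min", "max", "mean", "last"]

def reduce_series(values: list[float], reduction: Reduction) -> float:
    # Trivial stand-in for reducing a temporal metric to a single value.
    if reduction == "min":
        return min(values)
    if reduction == "max":
        return max(values)
    if reduction == "mean":
        return sum(values) / len(values)
    return values[-1]

def check(reduction: str) -> float:
    # str is wider than the Literal, so mypy reports [arg-type] here even though
    # every parametrized value is valid at runtime -- hence the ignore in the tests.
    return reduce_series([1.0, 2.0, 3.0], reduction)  # type: ignore[arg-type]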