From cf9dd5d2b14e073063d10ee2aa9e44f25067f443 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 18:29:13 +0200 Subject: [PATCH 1/2] fix incorrect types and add ignores where needed in tests --- tests/core/test_types.py | 4 ++-- tests/exporters/test_markdown_exporter.py | 24 +++++++++++------------ tests/queries/test_grouped_query.py | 2 +- tests/queries/test_selectors.py | 6 +++--- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/core/test_types.py b/tests/core/test_types.py index b789406..7c1da86 100644 --- a/tests/core/test_types.py +++ b/tests/core/test_types.py @@ -25,7 +25,7 @@ def test_temporal_data() -> None: def test_invalid_metrics_type() -> None: with pytest.raises(ValidationError): - Run(id="bad", params={}, metrics="not a dict") + Run(id="bad", params={}, metrics="not a dict") # type: ignore[arg-type] def test_grouped_run() -> None: @@ -41,7 +41,7 @@ def test_grouped_run() -> None: def test_invalid_runs_type() -> None: with pytest.raises(ValidationError): - GroupedRun(key="group_key", value="group_value", runs=["not a run"]) + GroupedRun(key="group_key", value="group_value", runs=["not a run"]) # type: ignore[list-item] def test_run_roundtrip_serialization() -> None: diff --git a/tests/exporters/test_markdown_exporter.py b/tests/exporters/test_markdown_exporter.py index 6a2d57a..1f731a2 100644 --- a/tests/exporters/test_markdown_exporter.py +++ b/tests/exporters/test_markdown_exporter.py @@ -25,12 +25,12 @@ class DummyBlock: class DummyTextBlock(AbstractTextBlock): - def build(self, runs: List[Run]) -> str: + def build(self, runs: List[Run]) -> str: # type: ignore[override] return "not supported" class DummyFigureBlock(AbstractFigureBlock): - def build(self, runs: List[Run]) -> str: + def build(self, runs: List[Run]) -> str: # type: ignore[override] return "not a dataframe" @@ -55,7 +55,7 @@ def runs() -> list[Run]: def test_export_text_blocks(tmp_path: Path, runs: List[Run]) -> None: report = Report(runs).add(H1("Heading 1"), Text("Some paragraph text.")) out_path = tmp_path / "report.md" - Markdown(output_path=out_path).export(report) + Markdown(output_path=str(out_path)).export(report) content = out_path.read_text() assert "# Heading 1" in content @@ -71,7 +71,7 @@ def test_export_table_block(tmp_path: Path, runs: List[Run]) -> None: ) report = Report(runs).add(table) out_path = tmp_path / "report.md" - Markdown(output_path=out_path).export(report) + Markdown(output_path=str(out_path)).export(report) content = out_path.read_text().replace("\r\n", "\n") assert re.search(r"\|\s*Model\s*\|\s*Accuracy\s*\|", content) @@ -86,7 +86,7 @@ def test_export_figure_block(tmp_path: Path, runs: List[Run]) -> None: plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("model")) report = Report(runs).add(plot) out_path = tmp_path / "report.md" - exporter = Markdown(output_path=out_path) + exporter = Markdown(output_path=str(out_path)) exporter.export(report) content = out_path.read_text() @@ -104,7 +104,7 @@ def test_export_figure_block_empty(tmp_path: Path) -> None: plot = MetricPlot(Metric("accuracy", direction="max"), identifier=Param("model")) report = Report([empty_run]).add(plot) out_path = tmp_path / "report.md" - Markdown(output_path=out_path).export(report) + Markdown(output_path=str(out_path)).export(report) content = out_path.read_text() assert "*No data available for accuracy*" in content @@ -112,21 +112,21 @@ def test_export_figure_block_empty(tmp_path: Path) -> None: def test_unknown_block_raises(tmp_path: Path, runs: List[Run]) -> None: report = Report(runs) - report += DummyBlock() + report += DummyBlock() # type: ignore[arg-type] with pytest.raises(ValueError, match="Unknown block type"): - Markdown(output_path=tmp_path / "out.md").export(report) + Markdown(output_path=str(tmp_path / "out.md")).export(report) def test_unsupported_figure_block_raises(tmp_path: Path, runs: List[Run]) -> None: report = Report(runs).add(DummyFigureBlock()) - exporter = Markdown(output_path=tmp_path / "out.md") + exporter = Markdown(output_path=str(tmp_path / "out.md")) with pytest.raises(NotImplementedError, match="Unsupported figure block"): exporter.export(report) def test_unsupported_text_block_raises(tmp_path: Path, runs: List[Run]) -> None: report = Report(runs).add(DummyTextBlock("oops")) - exporter = Markdown(output_path=tmp_path / "out.md") + exporter = Markdown(output_path=str(tmp_path / "out.md")) with pytest.raises(NotImplementedError, match="Unsupported text block"): exporter.export(report) @@ -138,7 +138,7 @@ def test_block_level_runs_override_global(tmp_path: Path, runs: List[Run]) -> No Table([Param("model"), Metric("accuracy", "max")], runs=scoped_runs), ) out_path = tmp_path / "report.md" - Markdown(output_path=out_path).export(report) + Markdown(output_path=str(out_path)).export(report) content = out_path.read_text() assert "resnet" in content assert content.count("resnet") == 1 @@ -147,6 +147,6 @@ def test_block_level_runs_override_global(tmp_path: Path, runs: List[Run]) -> No def test_export_heading_variants(tmp_path: Path, runs: List[Run]) -> None: report = Report(runs).add(H2("Section Title")) out_path = tmp_path / "headings.md" - Markdown(output_path=out_path).export(report) + Markdown(output_path=str(out_path)).export(report) content = out_path.read_text() assert "## Section Title" in content diff --git a/tests/queries/test_grouped_query.py b/tests/queries/test_grouped_query.py index 3f31fe5..4fe5622 100644 --- a/tests/queries/test_grouped_query.py +++ b/tests/queries/test_grouped_query.py @@ -63,7 +63,7 @@ def test_aggregate_all_strategies(grouped: GroupedQuery) -> None: assert len(grouped.aggregate("mean", over=m).all()) == 2 with pytest.raises(ValueError, match="Unsupported aggregation method"): - grouped.aggregate("unsupported", over=m) + grouped.aggregate("unsupported", over=m) # type: ignore[arg-type] def test_aggregate_best_worst_missing_over(grouped: GroupedQuery) -> None: diff --git a/tests/queries/test_selectors.py b/tests/queries/test_selectors.py index 142aecc..60db694 100644 --- a/tests/queries/test_selectors.py +++ b/tests/queries/test_selectors.py @@ -56,7 +56,7 @@ def test_metric_selector(example_run: Run) -> None: def test_invalid_metric_direction() -> None: with pytest.raises(ValueError, match="Invalid direction"): - Metric("accuracy", direction="invalid") + Metric("accuracy", direction="invalid") # type: ignore[arg-type] @pytest.mark.parametrize( @@ -71,7 +71,7 @@ def test_invalid_metric_direction() -> None: def test_temporal_metric_selector( example_run: Run, reduction: str, expected: float ) -> None: - selector = TemporalMetric("accuracy", direction="max", reduction=reduction) + selector = TemporalMetric("accuracy", direction="max", reduction=reduction) # type: ignore[arg-type] assert selector(example_run) == expected @@ -82,4 +82,4 @@ def test_temporal_metric_missing_returns_nan(example_run: Run) -> None: def test_temporal_metric_invalid_reduction() -> None: with pytest.raises(ValueError, match="Invalid reduction method"): - TemporalMetric("accuracy", direction="max", reduction="median") + TemporalMetric("accuracy", direction="max", reduction="median") # type: ignore[arg-type] From 5c05c0913653e30c9fe28c746a58ef09eaeafc9d Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Mon, 12 May 2025 18:29:47 +0200 Subject: [PATCH 2/2] add mypy to pre-commit --- .pre-commit-config.yaml | 5 +++++ pyproject.toml | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e76b64d..104ef44 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,3 +11,8 @@ repos: - id: codespell additional_dependencies: - tomli + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + args: ["--config-file=pyproject.toml"] diff --git a/pyproject.toml b/pyproject.toml index 6a23002..e908080 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,3 +86,8 @@ testpaths = "tests" [tool.codespell] skip = "uv.lock" + +[tool.mypy] +files = ["ablate", "tests"] +explicit_package_bases = true +ignore_missing_imports = true