From cfb08b4e36fb10749337ccb37ca7a3a62d3fee8d Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 12:13:37 +0200 Subject: [PATCH 1/5] add abstract source --- ablate/sources/abstract_source.py | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 ablate/sources/abstract_source.py diff --git a/ablate/sources/abstract_source.py b/ablate/sources/abstract_source.py new file mode 100644 index 0000000..4e2e221 --- /dev/null +++ b/ablate/sources/abstract_source.py @@ -0,0 +1,9 @@ +from abc import ABC, abstractmethod +from typing import List + +from ablate.core.types import Run + + +class AbstractSource(ABC): + @abstractmethod + def load(self) -> List[Run]: ... From c16bf973d5411a556248c418f56ac9b41d78ef66 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 12:13:46 +0200 Subject: [PATCH 2/5] add mlflow source --- ablate/sources/__init__.py | 5 +++++ ablate/sources/mlflow_source.py | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+) create mode 100644 ablate/sources/__init__.py create mode 100644 ablate/sources/mlflow_source.py diff --git a/ablate/sources/__init__.py b/ablate/sources/__init__.py new file mode 100644 index 0000000..8a9496e --- /dev/null +++ b/ablate/sources/__init__.py @@ -0,0 +1,5 @@ +from .abstract_source import AbstractSource +from .mlflow_source import MLflow + + +__all__ = ["AbstractSource", "MLflow"] diff --git a/ablate/sources/mlflow_source.py b/ablate/sources/mlflow_source.py new file mode 100644 index 0000000..872f225 --- /dev/null +++ b/ablate/sources/mlflow_source.py @@ -0,0 +1,38 @@ +from pathlib import Path +from typing import List + +from ablate.core.types import Run + +from .abstract_source import AbstractSource + + +class MLflow(AbstractSource): + def __init__(self, tracking_uri: str, experiment_names: List[str]) -> None: + try: + from mlflow.tracking import MlflowClient + except ImportError as e: + raise ImportError( + "MLflow source requires `mlflow`. " + "Please install with `pip install ablate[mlflow]`." + ) from e + + self.tracking_uri = tracking_uri + self.experiment_names = experiment_names + self.client = MlflowClient(Path(tracking_uri).resolve().as_uri()) + + def load(self) -> List[Run]: + runs = self.client.search_runs( + [ + self.client.get_experiment_by_name(n).experiment_id + for n in self.experiment_names + ] + ) + records: List[Run] = [] + for run in runs: + p, m, t = run.data.params, run.data.metrics, {} + p.update(run.data.tags) + for name in m: + history = self.client.get_metric_history(run.info.run_id, name) + t[name] = [(h.step, h.value) for h in history] + records.append(Run(id=run.info.run_id, params=p, metrics=m, temporal=t)) + return records From 27924bbf04ffa8cc0871db31595d491910ea39f5 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 12:14:04 +0200 Subject: [PATCH 3/5] add mlflow source tests --- tests/sources/test_mlflow_source.py | 49 +++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tests/sources/test_mlflow_source.py diff --git a/tests/sources/test_mlflow_source.py b/tests/sources/test_mlflow_source.py new file mode 100644 index 0000000..855c1e2 --- /dev/null +++ b/tests/sources/test_mlflow_source.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +from ablate.sources import MLflow + + +if TYPE_CHECKING: + from ablate.core.types import Run + + +@pytest.mark.filterwarnings("ignore::pydantic.PydanticDeprecatedSince20") +@patch("mlflow.tracking.MlflowClient") +def test_load_converts_mlflow_runs_to_runs(client: MagicMock) -> None: + mock_client = client.return_value + mock_run = SimpleNamespace( + info=SimpleNamespace(run_id="run-1"), + data=SimpleNamespace( + params={"lr": "0.01"}, + metrics={"accuracy": 0.9}, + tags={"mlflow.runName": "example"}, + ), + ) + + mock_client.get_experiment_by_name.return_value = SimpleNamespace( + experiment_id="123" + ) + mock_client.search_runs.return_value = [mock_run] + mock_client.get_metric_history.return_value = [SimpleNamespace(step=1, value=0.9)] + + source = MLflow(tracking_uri="/fake/path", experiment_names=["default"]) + runs = source.load() + r: Run = runs[0] + + assert len(runs) == 1 + assert r.id == "run-1" + assert r.params == {"lr": "0.01", "mlflow.runName": "example"} + assert r.metrics == {"accuracy": 0.9} + assert r.temporal == {"accuracy": [(1, 0.9)]} + + +def test_import_error_if_mlflow_not_installed() -> None: + with ( + patch.dict("sys.modules", {"mlflow.tracking": None}), + pytest.raises(ImportError, match="MLflow source requires `mlflow`"), + ): + MLflow(tracking_uri="/fake", experiment_names=["exp"]) From 78e49d20e5e79b8f9c5ae37ab7ff36accc7f427e Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 12:53:15 +0200 Subject: [PATCH 4/5] add docstrings and handle mlflow uris --- ablate/sources/abstract_source.py | 7 ++++++- ablate/sources/mlflow_source.py | 23 +++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/ablate/sources/abstract_source.py b/ablate/sources/abstract_source.py index 4e2e221..c1be956 100644 --- a/ablate/sources/abstract_source.py +++ b/ablate/sources/abstract_source.py @@ -6,4 +6,9 @@ class AbstractSource(ABC): @abstractmethod - def load(self) -> List[Run]: ... + def load(self) -> List[Run]: + """Load the data from the source. + + Returns: + A list of runs with their parameters, metrics, and optionally temporal data. + """ diff --git a/ablate/sources/mlflow_source.py b/ablate/sources/mlflow_source.py index 872f225..41c3362 100644 --- a/ablate/sources/mlflow_source.py +++ b/ablate/sources/mlflow_source.py @@ -1,5 +1,6 @@ from pathlib import Path from typing import List +from urllib.parse import urlparse from ablate.core.types import Run @@ -7,7 +8,18 @@ class MLflow(AbstractSource): - def __init__(self, tracking_uri: str, experiment_names: List[str]) -> None: + def __init__(self, experiment_names: List[str], tracking_uri: str | None) -> None: + """MLflow source for loading runs from a MLflow server. + + Args: + experiment_names: A list of experiment names to load runs from. + tracking_uri: The URI or local path to the MLflow tracking server. + If None, use the default tracking URI set in the MLflow configuration. + Defaults to None. + + Raises: + ImportError: If the `mlflow` package is not installed. + """ try: from mlflow.tracking import MlflowClient except ImportError as e: @@ -18,7 +30,14 @@ def __init__(self, tracking_uri: str, experiment_names: List[str]) -> None: self.tracking_uri = tracking_uri self.experiment_names = experiment_names - self.client = MlflowClient(Path(tracking_uri).resolve().as_uri()) + if not tracking_uri: + self.client = MlflowClient() + return + if urlparse(tracking_uri).scheme in {"http", "https", "file"}: + uri = tracking_uri + else: + uri = Path(tracking_uri).resolve().as_uri() + self.client = MlflowClient(uri) def load(self) -> List[Run]: runs = self.client.search_runs( From 90d8889f62872b1b3f00307d9f1b55a94e279056 Mon Sep 17 00:00:00 2001 From: Simon Rampp Date: Sun, 11 May 2025 12:53:20 +0200 Subject: [PATCH 5/5] update tests --- tests/sources/test_mlflow_source.py | 33 ++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tests/sources/test_mlflow_source.py b/tests/sources/test_mlflow_source.py index 855c1e2..5821999 100644 --- a/tests/sources/test_mlflow_source.py +++ b/tests/sources/test_mlflow_source.py @@ -1,5 +1,4 @@ from types import SimpleNamespace -from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest @@ -7,13 +6,23 @@ from ablate.sources import MLflow -if TYPE_CHECKING: - from ablate.core.types import Run - - @pytest.mark.filterwarnings("ignore::pydantic.PydanticDeprecatedSince20") +@pytest.mark.parametrize( + ("tracking_uri", "expected_uri_startswith"), + [ + (None, None), + ("/some/local/path", "file://"), + ("http://mlflow.mycompany.com", "http://"), + ("https://mlflow.mycompany.com", "https://"), + ("file:///already/uri", "file://"), + ], +) @patch("mlflow.tracking.MlflowClient") -def test_load_converts_mlflow_runs_to_runs(client: MagicMock) -> None: +def test_mlflow_uri_resolution( + client: MagicMock, + tracking_uri: str, + expected_uri_startswith: str, +) -> None: mock_client = client.return_value mock_run = SimpleNamespace( info=SimpleNamespace(run_id="run-1"), @@ -23,18 +32,22 @@ def test_load_converts_mlflow_runs_to_runs(client: MagicMock) -> None: tags={"mlflow.runName": "example"}, ), ) - mock_client.get_experiment_by_name.return_value = SimpleNamespace( experiment_id="123" ) mock_client.search_runs.return_value = [mock_run] mock_client.get_metric_history.return_value = [SimpleNamespace(step=1, value=0.9)] - source = MLflow(tracking_uri="/fake/path", experiment_names=["default"]) + source = MLflow(tracking_uri=tracking_uri, experiment_names=["default"]) runs = source.load() - r: Run = runs[0] - assert len(runs) == 1 + if expected_uri_startswith is None: + client.assert_called_once_with() + else: + uri_arg = client.call_args.args[0] + assert uri_arg.startswith(expected_uri_startswith) + + r = runs[0] assert r.id == "run-1" assert r.params == {"lr": "0.01", "mlflow.runName": "example"} assert r.metrics == {"accuracy": 0.9}