diff --git a/ablate/sources/__init__.py b/ablate/sources/__init__.py new file mode 100644 index 0000000..8a9496e --- /dev/null +++ b/ablate/sources/__init__.py @@ -0,0 +1,5 @@ +from .abstract_source import AbstractSource +from .mlflow_source import MLflow + + +__all__ = ["AbstractSource", "MLflow"] diff --git a/ablate/sources/abstract_source.py b/ablate/sources/abstract_source.py new file mode 100644 index 0000000..c1be956 --- /dev/null +++ b/ablate/sources/abstract_source.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod +from typing import List + +from ablate.core.types import Run + + +class AbstractSource(ABC): + @abstractmethod + def load(self) -> List[Run]: + """Load the data from the source. + + Returns: + A list of runs with their parameters, metrics, and optionally temporal data. + """ diff --git a/ablate/sources/mlflow_source.py b/ablate/sources/mlflow_source.py new file mode 100644 index 0000000..41c3362 --- /dev/null +++ b/ablate/sources/mlflow_source.py @@ -0,0 +1,57 @@ +from pathlib import Path +from typing import List +from urllib.parse import urlparse + +from ablate.core.types import Run + +from .abstract_source import AbstractSource + + +class MLflow(AbstractSource): + def __init__(self, experiment_names: List[str], tracking_uri: str | None) -> None: + """MLflow source for loading runs from a MLflow server. + + Args: + experiment_names: A list of experiment names to load runs from. + tracking_uri: The URI or local path to the MLflow tracking server. + If None, use the default tracking URI set in the MLflow configuration. + Defaults to None. + + Raises: + ImportError: If the `mlflow` package is not installed. + """ + try: + from mlflow.tracking import MlflowClient + except ImportError as e: + raise ImportError( + "MLflow source requires `mlflow`. " + "Please install with `pip install ablate[mlflow]`." + ) from e + + self.tracking_uri = tracking_uri + self.experiment_names = experiment_names + if not tracking_uri: + self.client = MlflowClient() + return + if urlparse(tracking_uri).scheme in {"http", "https", "file"}: + uri = tracking_uri + else: + uri = Path(tracking_uri).resolve().as_uri() + self.client = MlflowClient(uri) + + def load(self) -> List[Run]: + runs = self.client.search_runs( + [ + self.client.get_experiment_by_name(n).experiment_id + for n in self.experiment_names + ] + ) + records: List[Run] = [] + for run in runs: + p, m, t = run.data.params, run.data.metrics, {} + p.update(run.data.tags) + for name in m: + history = self.client.get_metric_history(run.info.run_id, name) + t[name] = [(h.step, h.value) for h in history] + records.append(Run(id=run.info.run_id, params=p, metrics=m, temporal=t)) + return records diff --git a/tests/sources/test_mlflow_source.py b/tests/sources/test_mlflow_source.py new file mode 100644 index 0000000..5821999 --- /dev/null +++ b/tests/sources/test_mlflow_source.py @@ -0,0 +1,62 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from ablate.sources import MLflow + + +@pytest.mark.filterwarnings("ignore::pydantic.PydanticDeprecatedSince20") +@pytest.mark.parametrize( + ("tracking_uri", "expected_uri_startswith"), + [ + (None, None), + ("/some/local/path", "file://"), + ("http://mlflow.mycompany.com", "http://"), + ("https://mlflow.mycompany.com", "https://"), + ("file:///already/uri", "file://"), + ], +) +@patch("mlflow.tracking.MlflowClient") +def test_mlflow_uri_resolution( + client: MagicMock, + tracking_uri: str, + expected_uri_startswith: str, +) -> None: + mock_client = client.return_value + mock_run = SimpleNamespace( + info=SimpleNamespace(run_id="run-1"), + data=SimpleNamespace( + params={"lr": "0.01"}, + metrics={"accuracy": 0.9}, + tags={"mlflow.runName": "example"}, + ), + ) + mock_client.get_experiment_by_name.return_value = SimpleNamespace( + experiment_id="123" + ) + mock_client.search_runs.return_value = [mock_run] + mock_client.get_metric_history.return_value = [SimpleNamespace(step=1, value=0.9)] + + source = MLflow(tracking_uri=tracking_uri, experiment_names=["default"]) + runs = source.load() + + if expected_uri_startswith is None: + client.assert_called_once_with() + else: + uri_arg = client.call_args.args[0] + assert uri_arg.startswith(expected_uri_startswith) + + r = runs[0] + assert r.id == "run-1" + assert r.params == {"lr": "0.01", "mlflow.runName": "example"} + assert r.metrics == {"accuracy": 0.9} + assert r.temporal == {"accuracy": [(1, 0.9)]} + + +def test_import_error_if_mlflow_not_installed() -> None: + with ( + patch.dict("sys.modules", {"mlflow.tracking": None}), + pytest.raises(ImportError, match="MLflow source requires `mlflow`"), + ): + MLflow(tracking_uri="/fake", experiment_names=["exp"])