Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ablate/sources/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .abstract_source import AbstractSource
from .mlflow_source import MLflow


# Names exported as this package's public API (what `from ... import *` exposes).
__all__ = ["AbstractSource", "MLflow"]
14 changes: 14 additions & 0 deletions ablate/sources/abstract_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from typing import List

from ablate.core.types import Run


class AbstractSource(ABC):
    """Common interface for objects that supply experiment runs."""

    @abstractmethod
    def load(self) -> List[Run]:
        """Fetch the runs provided by this source.

        Returns:
            The runs, each carrying its parameters and metrics and, where the
            source offers it, temporal data.
        """
57 changes: 57 additions & 0 deletions ablate/sources/mlflow_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path
from typing import List
from urllib.parse import urlparse

from ablate.core.types import Run

from .abstract_source import AbstractSource


class MLflow(AbstractSource):
    """Source that loads runs from one or more MLflow experiments."""

    def __init__(
        self, experiment_names: List[str], tracking_uri: str | None = None
    ) -> None:
        """MLflow source for loading runs from a MLflow server.

        Args:
            experiment_names: A list of experiment names to load runs from.
            tracking_uri: The URI or local path to the MLflow tracking server.
                Any value with a URI scheme (e.g. `http://`, `https://`, `file://`,
                `sqlite:///`) is passed to MLflow unchanged; anything else is
                treated as a local path and converted to an absolute `file://` URI.
                If None or empty, use the default tracking URI set in the MLflow
                configuration. Defaults to None.

        Raises:
            ImportError: If the `mlflow` package is not installed.
        """
        # Imported lazily so that `mlflow` stays an optional dependency.
        try:
            from mlflow.tracking import MlflowClient
        except ImportError as e:
            raise ImportError(
                "MLflow source requires `mlflow`. "
                "Please install with `pip install ablate[mlflow]`."
            ) from e

        self.tracking_uri = tracking_uri
        self.experiment_names = experiment_names
        if not tracking_uri:
            self.client = MlflowClient()
            return
        self.client = MlflowClient(self._as_uri(tracking_uri))

    @staticmethod
    def _as_uri(tracking_uri: str) -> str:
        """Return `tracking_uri` unchanged if it is a URI, else as a `file://` URI."""
        # A one-letter "scheme" is a Windows drive letter (e.g. `C:\mlruns`), so
        # only multi-character schemes are treated as real URIs.
        if len(urlparse(tracking_uri).scheme) > 1:
            return tracking_uri
        return Path(tracking_uri).resolve().as_uri()

    def _experiment_ids(self) -> List[str]:
        """Map the configured experiment names to MLflow experiment ids."""
        ids, missing = [], []
        for name in self.experiment_names:
            # MLflow returns None rather than raising for an unknown experiment.
            experiment = self.client.get_experiment_by_name(name)
            if experiment is None:
                missing.append(name)
            else:
                ids.append(experiment.experiment_id)
        if missing:
            raise ValueError(f"MLflow experiment(s) not found: {missing}")
        return ids

    def _search_all_runs(self, experiment_ids: List[str]) -> list:
        """Collect the runs of the given experiments across all result pages."""
        runs: list = []
        token = None
        while True:
            page = self.client.search_runs(experiment_ids, page_token=token)
            runs.extend(page)
            # MLflow's paged result exposes a continuation token; a plain list
            # (or a final page) has none, which ends the loop.
            token = getattr(page, "token", None)
            if not token:
                return runs

    def load(self) -> List[Run]:
        """Load all runs of the configured experiments.

        Tags are merged into the parameters (tags win on key clashes), and the
        full history of every metric is stored as `(step, value)` pairs.

        Returns:
            A list of runs with their parameters, metrics, and temporal data.

        Raises:
            ValueError: If any of the configured experiments does not exist.
        """
        records: List[Run] = []
        for run in self._search_all_runs(self._experiment_ids()):
            run_id = run.info.run_id
            # Build fresh dicts so the MLflow run object is not mutated.
            params = {**run.data.params, **run.data.tags}
            metrics = dict(run.data.metrics)
            temporal = {}
            for name in metrics:
                history = self.client.get_metric_history(run_id, name)
                temporal[name] = [(h.step, h.value) for h in history]
            records.append(
                Run(id=run_id, params=params, metrics=metrics, temporal=temporal)
            )
        return records
62 changes: 62 additions & 0 deletions tests/sources/test_mlflow_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from ablate.sources import MLflow


@pytest.mark.filterwarnings("ignore::pydantic.PydanticDeprecatedSince20")
@pytest.mark.parametrize(
    ("tracking_uri", "expected_uri_startswith"),
    [
        (None, None),
        ("/some/local/path", "file://"),
        ("http://mlflow.mycompany.com", "http://"),
        ("https://mlflow.mycompany.com", "https://"),
        ("file:///already/uri", "file://"),
    ],
)
@patch("mlflow.tracking.MlflowClient")
def test_mlflow_uri_resolution(
    client: MagicMock,
    tracking_uri: str | None,
    expected_uri_startswith: str | None,
) -> None:
    """Check how the tracking URI reaches the client and how runs are converted.

    With no tracking URI the client must be built without arguments; otherwise
    the first positional argument must start with the expected scheme.
    """
    mock_client = client.return_value
    mock_run = SimpleNamespace(
        info=SimpleNamespace(run_id="run-1"),
        data=SimpleNamespace(
            params={"lr": "0.01"},
            metrics={"accuracy": 0.9},
            tags={"mlflow.runName": "example"},
        ),
    )
    mock_client.get_experiment_by_name.return_value = SimpleNamespace(
        experiment_id="123"
    )
    mock_client.search_runs.return_value = [mock_run]
    mock_client.get_metric_history.return_value = [SimpleNamespace(step=1, value=0.9)]

    source = MLflow(tracking_uri=tracking_uri, experiment_names=["default"])
    runs = source.load()

    if expected_uri_startswith is None:
        client.assert_called_once_with()
    else:
        # Exactly one client must be created, with the URI as first argument.
        client.assert_called_once()
        uri_arg = client.call_args.args[0]
        assert uri_arg.startswith(expected_uri_startswith)

    mock_client.get_experiment_by_name.assert_called_once_with("default")
    mock_client.get_metric_history.assert_called_once_with("run-1", "accuracy")

    assert len(runs) == 1
    r = runs[0]
    assert r.id == "run-1"
    # Tags are expected to be merged into the parameters.
    assert r.params == {"lr": "0.01", "mlflow.runName": "example"}
    assert r.metrics == {"accuracy": 0.9}
    assert r.temporal == {"accuracy": [(1, 0.9)]}


def test_import_error_if_mlflow_not_installed() -> None:
    """Constructing `MLflow` without `mlflow` available raises a helpful ImportError."""
    # A `None` entry in `sys.modules` makes importing that module fail.
    hidden_modules = {"mlflow.tracking": None}
    expected_message = "MLflow source requires `mlflow`"
    with patch.dict("sys.modules", hidden_modules):
        with pytest.raises(ImportError, match=expected_message):
            MLflow(tracking_uri="/fake", experiment_names=["exp"])