diff --git a/ablate/sources/mlflow_source.py b/ablate/sources/mlflow_source.py index 41c3362..edd6c59 100644 --- a/ablate/sources/mlflow_source.py +++ b/ablate/sources/mlflow_source.py @@ -40,12 +40,12 @@ def __init__(self, experiment_names: List[str], tracking_uri: str | None) -> Non self.client = MlflowClient(uri) def load(self) -> List[Run]: - runs = self.client.search_runs( - [ - self.client.get_experiment_by_name(n).experiment_id - for n in self.experiment_names - ] - ) + ids = [self.client.get_experiment_by_name(n) for n in self.experiment_names] + if not all(ids): + raise ValueError( + f"One or more experiment names not found: {self.experiment_names}" + ) + runs = self.client.search_runs([e.experiment_id for e in ids if e]) records: List[Run] = [] for run in runs: p, m, t = run.data.params, run.data.metrics, {} diff --git a/tests/sources/test_mlflow_source.py b/tests/sources/test_mlflow_source.py index 5821999..5b8901e 100644 --- a/tests/sources/test_mlflow_source.py +++ b/tests/sources/test_mlflow_source.py @@ -60,3 +60,14 @@ def test_import_error_if_mlflow_not_installed() -> None: pytest.raises(ImportError, match="MLflow source requires `mlflow`"), ): MLflow(tracking_uri="/fake", experiment_names=["exp"]) + + +@patch("mlflow.tracking.MlflowClient") +def test_mlflow_raises_on_invalid_experiment_name(client: MagicMock) -> None: + mock_client = client.return_value + mock_client.get_experiment_by_name.return_value = None + + source = MLflow(tracking_uri="/fake/path", experiment_names=["nonexistent"]) + + with pytest.raises(ValueError, match="One or more experiment names not found"): + source.load()