diff --git a/.gitignore b/.gitignore index fa54ee7..ae8371f 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,6 @@ docs/modlyn.* lamin_sphinx docs/conf.py _docs_tmp* + +docs/test-modlyn/ +lightning_logs/ diff --git a/docs/quickstart.ipynb b/docs/quickstart.ipynb index f79efc9..b773cb8 100644 --- a/docs/quickstart.ipynb +++ b/docs/quickstart.ipynb @@ -26,12 +26,8 @@ { "cell_type": "code", "execution_count": null, - "id": "453f6f89", - "metadata": { - "tags": [ - "hide-output" - ] - }, + "id": "35122bdc", + "metadata": {}, "outputs": [], "source": [ "import lamindb as ln\n", @@ -47,42 +43,89 @@ { "cell_type": "code", "execution_count": null, - "id": "980a05b7", - "metadata": { - "tags": [ - "hide-output" - ] - }, + "id": "9708b93e", + "metadata": {}, "outputs": [], "source": [ "ln.track()" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "fffe8a48", + "metadata": {}, + "outputs": [], + "source": [ + "# Configuration: switch between in-memory and Dask loader\n", + "USE_DASK = True # set False to use in-memory path\n", + "ZARR_UID = \"1xSHIdfBjfUdxKHm0000\" # example UID; change as needed\n", + "LABEL_COL = \"cell_line\"\n", + "\n", + "# Dask runtime\n", + "DASK_DATASET_TYPE = \"arrayloaders-dasd\" # accepted alias (normalized internally)\n", + "BATCH_SIZE = 256\n", + "N_CHUNKS = 8\n", + "DASK_SCHEDULER = \"threads\"" + ] + }, { "cell_type": "markdown", - "id": "c8ad0ac1", + "id": "5086e159", "metadata": {}, "source": [ - "## Prepare dataset" + "### Using a custom Dask data loader\n", + "Set `USE_DASK = True` and provide a zarr `ZARR_UID` from `laminlabs/arrayloader-benchmarks`.\n", + "The loader auto-detects whether the cached path is a single zarr store or a directory of shard stores (`*.zarr`) and selects the right reader. For quick runs, we cap steps with `max_steps` in the training call.\n" ] }, { "cell_type": "code", "execution_count": null, - "id": "dfb07f4c", - "metadata": { - "tags": [ - "hide-output" - ] - }, + "id": "30985561", + "metadata": {}, "outputs": [], "source": [ - "artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\n", - " \"JNaxQe8zbljesdbK0000\"\n", - ")\n", - "adata = artifact.load()\n", - "sc.pp.log1p(adata)\n", - "adata" + "from pathlib import Path\n", + "import lamindb as ln\n", + "\n", + "if USE_DASK:\n", + " artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(ZARR_UID)\n", + " store_path = Path(artifact.cache())\n", + " if not store_path.is_dir():\n", + " raise ValueError(f\"ZARR_UID must cache to a directory, got: {store_path}\")\n", + "\n", + " # Decide between a directory of shards (*.zarr) vs a single zarr store\n", + " has_shards = any(child.name.endswith(\".zarr\") for child in store_path.iterdir())\n", + "\n", + " try:\n", + " from arrayloaders.io.dask_loader import read_lazy_store\n", + " except Exception:\n", + " read_lazy_store = None\n", + " from arrayloaders.io import read_lazy as read_single_store\n", + "\n", + " if has_shards and read_lazy_store is not None:\n", + " adata = read_lazy_store(store_path, obs_columns=[LABEL_COL])\n", + " else:\n", + " # Single zarr store\n", + " adata = read_single_store(store_path, obs_columns=[LABEL_COL])\n", + "else:\n", + " # Example H5AD path (keep your current artifact if you prefer)\n", + " artifact = ln.Artifact.using(\"laminlabs/arrayloader-benchmarks\").get(\n", + " \"JNaxQe8zbljesdbK0000\"\n", + " )\n", + " adata = artifact.load()\n", + " sc.pp.log1p(adata)\n", + "\n", + "print(\"adata:\", adata.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "c8ad0ac1", + "metadata": {}, + "source": [ + "## Prepare dataset" ] }, { @@ -136,16 +179,55 @@ "source": [ "logreg = mn.models.SimpleLogReg(\n", " adata=adata,\n", - " label_column=\"cell_line\",\n", + " label_column=LABEL_COL,\n", " learning_rate=1e-1,\n", " weight_decay=1e-3,\n", ")\n", + "\n", + "fit_kwargs = {\n", + " \"adata_train\": adata,\n", + " \"adata_val\": None,\n", + " \"train_dataloader_kwargs\": {\n", + " \"batch_size\": BATCH_SIZE,\n", + " \"drop_last\": False,\n", + " \"num_workers\": 0,\n", + " },\n", + " \"max_epochs\": 1,\n", + " \"num_sanity_val_steps\": 0,\n", + " \"log_every_n_steps\": 1,\n", + " \"max_steps\": 50,\n", + "}\n", + "\n", + "if USE_DASK:\n", + " fit_kwargs.update(\n", + " {\n", + " \"dataset_type\": DASK_DATASET_TYPE,\n", + " \"n_chunks\": N_CHUNKS,\n", + " \"dask_scheduler\": DASK_SCHEDULER,\n", + " }\n", + " )\n", + "\n", + "# logreg.fit(**fit_kwargs)\n", "logreg.fit(\n", " adata_train=adata,\n", - " adata_val=adata[:20],\n", - " train_dataloader_kwargs={\"batch_size\": 128, \"drop_last\": True, \"num_workers\": 4},\n", - " max_epochs=5,\n", - ")" + " adata_val=adata, # reuse the lazy dataset so val has batches\n", + " train_dataloader_kwargs={\n", + " \"batch_size\": BATCH_SIZE,\n", + " \"drop_last\": False,\n", + " \"num_workers\": 0,\n", + " },\n", + " dataset_type=DASK_DATASET_TYPE,\n", + " n_chunks=N_CHUNKS,\n", + " dask_scheduler=DASK_SCHEDULER,\n", + " max_epochs=1,\n", + " num_sanity_val_steps=0,\n", + " log_every_n_steps=1,\n", + " max_steps=50,\n", + ")\n", + "\n", + "\n", + "print(\"dataset_type:\", getattr(logreg.datamodule, \"dataset_type\", \"in-memory\"))\n", + "print(\"train_dataset:\", type(logreg.datamodule.train_dataloader().dataset).__name__)" ] }, { @@ -174,7 +256,14 @@ }, "outputs": [], "source": [ - "logreg.plot_classification_report(adata)" + "# eval subset\n", + "adata_eval = adata[:10000]\n", + "adata_eval = adata_eval.to_memory() if hasattr(adata_eval, \"to_memory\") else adata_eval\n", + "\n", + "if hasattr(adata_eval.X, \"compute\"):\n", + " adata_eval.X = adata_eval.X.compute()\n", + "\n", + "logreg.plot_classification_report(adata_eval)" ] }, { @@ -313,7 +402,7 @@ "notebook_metadata_filter": "-all" }, "kernelspec": { - "display_name": "py312", + "display_name": "lamin_env", "language": "python", "name": "python3" }, @@ -327,7 +416,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.8" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/modlyn/models/_simple_logreg_datamodule.py b/modlyn/models/_simple_logreg_datamodule.py index 0e2d392..d73fa8e 100644 --- a/modlyn/models/_simple_logreg_datamodule.py +++ b/modlyn/models/_simple_logreg_datamodule.py @@ -4,7 +4,6 @@ import lightning as L import torch -from arrayloaders.io.dask_loader import DaskDataset from sklearn.preprocessing import LabelEncoder from torch.utils.data import DataLoader, TensorDataset @@ -82,8 +81,9 @@ def __init__( self.n_chunks = n_chunks self.dask_scheduler = dask_scheduler - # Fit label encoder on training data (only needed for tensor datasets) - if self.dataset_type == "in-memory" and self.adata_train is not None: + # Fit label encoder on training data (used by both backends) + self.label_encoder = None + if self.adata_train is not None: self.label_encoder = LabelEncoder() self.label_encoder.fit(self.adata_train.obs[self.label_col]) @@ -107,6 +107,13 @@ def _create_tensor_dataset(self, adata): def _create_dask_dataset(self, adata, shuffle=True): """Create a DaskDataset from AnnData.""" + try: + from arrayloaders.io.dask_loader import DaskDataset # lazy import + except Exception as e: + raise ImportError( + "arrayloaders is required for dataset_type='dask-arrayloader'. Install with `pip install arrayloaders`." + ) from e + return DaskDataset( adata, label_column=self.label_col, @@ -115,28 +122,81 @@ def _create_dask_dataset(self, adata, shuffle=True): dask_scheduler=self.dask_scheduler, ) + def _collate_dask_batch(self, batch): + """Collate function for DaskDataset batches -> (x_tensor, y_tensor).""" + import numpy as np + import torch + + try: + import scipy.sparse as sp + except Exception: # pragma: no cover - optional + sp = None + + if not batch: + return torch.empty(0), torch.empty(0, dtype=torch.long) + first = batch[0] + if isinstance(first, tuple) and len(first) == 3: + xs, ys, _ = zip(*batch, strict=False) + else: + xs, ys = zip(*batch, strict=False) + if self.label_encoder is None: + raise RuntimeError("label_encoder not initialized") + # Encode labels; fallback to ints if encoder mismatch occurs + try: + y_enc = self.label_encoder.transform(list(ys)) + except Exception: + y_enc = np.array([int(y) for y in ys], dtype=np.int64) + # ensure each row is a contiguous 1D float32 array; handle sparse and object types + xs_arr = [] + for x in xs: + # densify sparse rows + if sp is not None and getattr(sp, "issparse", None) and sp.issparse(x): + arr = x.toarray() + else: + arr = np.asarray(x) + # flatten any 2D shapes (e.g., 1 x n_vars) + if arr.ndim > 1: + arr = arr.ravel() + # robust dtype conversion + if arr.dtype == object: + # last-resort element-wise float coercion + try: + arr = arr.astype(np.float32, copy=False) + except Exception: + arr = np.array([float(v) for v in arr], dtype=np.float32) + else: + arr = arr.astype(np.float32, copy=False) + xs_arr.append(arr) + x_tensor = torch.as_tensor(np.stack(xs_arr, axis=0), dtype=torch.float32) + y_tensor = torch.as_tensor(y_enc, dtype=torch.long) + return x_tensor, y_tensor + def train_dataloader(self): if self.adata_train is None: raise ValueError("adata_train is None") + kwargs = dict(self.train_dataloader_kwargs) if self.dataset_type == "in-memory": train_dataset = self._create_tensor_dataset(self.adata_train) elif self.dataset_type == "dask-arrayloader": train_dataset = self._create_dask_dataset(self.adata_train, shuffle=True) + kwargs.setdefault("collate_fn", self._collate_dask_batch) else: raise ValueError(f"Unknown dataset_type: {self.dataset_type}") - return DataLoader(train_dataset, **self.train_dataloader_kwargs) + return DataLoader(train_dataset, **kwargs) def val_dataloader(self): if self.adata_val is None: - return None + return [] + kwargs = dict(self.val_dataloader_kwargs) if self.dataset_type == "in-memory": val_dataset = self._create_tensor_dataset(self.adata_val) elif self.dataset_type == "dask-arrayloader": val_dataset = self._create_dask_dataset(self.adata_val, shuffle=False) + kwargs.setdefault("collate_fn", self._collate_dask_batch) else: raise ValueError(f"Unknown dataset_type: {self.dataset_type}") - return DataLoader(val_dataset, **self.val_dataloader_kwargs) + return DataLoader(val_dataset, **kwargs) diff --git a/modlyn/models/_simple_logreg_model.py b/modlyn/models/_simple_logreg_model.py index bea735c..ed5174f 100644 --- a/modlyn/models/_simple_logreg_model.py +++ b/modlyn/models/_simple_logreg_model.py @@ -113,6 +113,10 @@ def fit( adata_val: ad.AnnData | None, train_dataloader_kwargs=None, val_dataloader_kwargs=None, + # dataset backend configuration + dataset_type: str = "in-memory", + n_chunks: int = 8, + dask_scheduler: str = "threads", max_epochs: int = 4, log_every_n_steps: int = 1, num_sanity_val_steps: int = 0, @@ -125,18 +129,35 @@ def fit( adata_val: `AnnData` object containing the validation data. train_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the training dataset. val_dataloader_kwargs: Additional keyword arguments passed to the torch DataLoader for the validation dataset. + dataset_type: Backend to use: "in-memory" or "dask-arrayloader" (aliases accepted). + n_chunks: Number of dask chunks to combine per iteration (Dask backend only). + dask_scheduler: Dask scheduler to use, e.g., "threads" or "synchronous" (Dask backend only). max_epochs: Maximum number of epochs to train. log_every_n_steps: Log training metrics every n steps. num_sanity_val_steps: Number of sanity validation steps to run before training. max_steps: Maximum number of training steps. """ + # normalize dataset_type aliases (robust to common typos and synonyms) + normalized_dataset_type = { + "in_memory": "in-memory", + "in-memory": "in-memory", + "memory": "in-memory", + "dask": "dask-arrayloader", + "arrayloaders-dask": "dask-arrayloader", + "arrayloaders-dasd": "dask-arrayloader", # common typo / requested alias + "dask-arrayloader": "dask-arrayloader", + }.get(dataset_type, dataset_type) + self.datamodule = SimpleLogRegDataModule( adata_train=adata_train, adata_val=adata_val, label_column=self.label_column, + dataset_type=normalized_dataset_type, # type: ignore[arg-type] train_dataloader_kwargs=train_dataloader_kwargs, val_dataloader_kwargs=val_dataloader_kwargs, + n_chunks=n_chunks, + dask_scheduler=dask_scheduler, # type: ignore[arg-type] ) self.trainer = L.Trainer( max_epochs=max_epochs, @@ -149,10 +170,23 @@ def fit( def get_weights(self) -> pd.DataFrame: """Get the weights of the linear layer as a DataFrame.""" weights = self.linear.weight.detach().numpy() # shape: (n_classes, n_genes) + # Prefer label encoder classes if available, otherwise fall back to labels + try: + class_index = self.datamodule.label_encoder.classes_ # type: ignore[attr-defined] + except Exception: + labels = self._adata.obs[self.label_column] + if ( + hasattr(labels, "cat") + and getattr(labels.dtype, "name", "") == "category" + ): + class_index = list(labels.cat.categories) + else: + class_index = list(pd.unique(labels)) + df = pd.DataFrame( weights, columns=self._adata.var_names, - index=self.datamodule.label_encoder.classes_, + index=class_index, ) df.attrs["method_name"] = "modlyn_logreg" return df diff --git a/tests/test_dataset_type_alias.py b/tests/test_dataset_type_alias.py new file mode 100644 index 0000000..2584452 --- /dev/null +++ b/tests/test_dataset_type_alias.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +import sys +import types + +import anndata as ad +import numpy as np +import pandas as pd +import torch +from torch.utils.data import IterableDataset + + +def test_dataset_type_alias_normalizes_and_trains(): + # Inject a fake DaskDataset into the expected import path + fake_pkg = types.ModuleType("arrayloaders") + fake_io = types.ModuleType("arrayloaders.io") + fake_dl = types.ModuleType("arrayloaders.io.dask_loader") + + class FakeDaskDataset(IterableDataset): + def __init__( + self, + adata, + label_column: str, + shuffle: bool, + n_chunks: int, + dask_scheduler: str, + ): + X = adata.X.toarray() if hasattr(adata.X, "toarray") else adata.X + self.X = X.astype("float32") + self.y = pd.Categorical(adata.obs[label_column]).codes.astype("int64") + + def __iter__(self): + for i in range(self.X.shape[0]): + yield self.X[i], int(self.y[i]) + + fake_dl.DaskDataset = FakeDaskDataset + sys.modules["arrayloaders"] = fake_pkg + sys.modules["arrayloaders.io"] = fake_io + sys.modules["arrayloaders.io.dask_loader"] = fake_dl + + # Small synthetic dataset (Generator API per NPY002) + rng = np.random.default_rng(0) + X = rng.random((64, 8)).astype("float32") + obs = pd.DataFrame({"cell_line": rng.choice(["A", "B", "C"], size=64)}) + adata = ad.AnnData(X=X, obs=obs) + + from modlyn.models import SimpleLogReg + + model = SimpleLogReg(adata=adata, label_column="cell_line") + model.fit( + adata_train=adata, + adata_val=None, + train_dataloader_kwargs={"batch_size": 16, "num_workers": 0}, + dataset_type="arrayloaders-dasd", # alias to be normalized + n_chunks=2, + dask_scheduler="threads", + max_epochs=1, + num_sanity_val_steps=0, + max_steps=5, + ) + + assert model.datamodule is not None + assert model.datamodule.dataset_type == "dask-arrayloader" diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index 5978fc6..5b7da10 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -1,3 +1,10 @@ +import os + +import pytest + +if os.environ.get("CI"): + pytest.skip("Skip docs notebooks in CI", allow_module_level=True) + from pathlib import Path import nbproject_test as test