From bbcfa4f7e73ef3cde017c6011deed2c9cb50f332 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 24 Apr 2026 15:47:46 -0700 Subject: [PATCH 1/7] support hf jsonl file format Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 35 ++- tests/unit/torch/utils/test_dataset_utils.py | 238 ++++++++++++++++++- 2 files changed, 263 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 01cb3abe88f..b0a8a444a79 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -230,6 +230,9 @@ def get_dataset_samples( or a path to a ``.jsonl`` file. For local directory paths, the predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key). + For ``.jsonl`` paths, the file is loaded via HuggingFace's ``json`` builder and + routed through the same auto-preprocess path as unregistered HF datasets, so + chat / prompt / text columns are handled consistently with live HF datasets. num_samples: Number of samples to load from the dataset. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). For unregistered datasets with a @@ -244,18 +247,23 @@ def get_dataset_samples( Returns: Samples: The list of samples. """ - # Local JSONL file path support (each line is a JSON object with a `text` field). - if dataset_name.endswith(".jsonl"): - return get_jsonl_text_samples(dataset_name, num_samples, key="text") - from datasets import load_dataset + # Local JSONL: load via HF's ``json`` builder and route through the same + # auto-preprocess path as unregistered HF datasets so chat / prompt / text + # columns are handled consistently with a downloaded HF dataset. Never + # matches ``SUPPORTED_DATASET_CONFIG``. + is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name) + local_dataset_path = None if os.path.exists(dataset_name): # Local path local_dataset_path = dataset_name - dataset_name = os.path.basename(os.path.normpath(local_dataset_path)) + if not is_jsonl: + # Directory paths may match a registered key via their basename + # (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail). + dataset_name = os.path.basename(os.path.normpath(local_dataset_path)) - is_registered = dataset_name in SUPPORTED_DATASET_CONFIG + is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG if is_registered: dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] @@ -291,7 +299,14 @@ def _preprocess(sample: dict) -> str: f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. " "Auto-detecting format from column names." ) - config = {"path": local_dataset_path or dataset_name} + if is_jsonl: + config = {"path": "json", "data_files": local_dataset_path} + else: + config = {"path": local_dataset_path or dataset_name} + # HF's file-based builders (incl. ``json``) label a string/list ``data_files`` + # as the ``train`` split unconditionally — the filename on disk is ignored. + # Named splits require a dict ``data_files={"train": ..., "test": ...}``, + # which we don't expose here. splits = _normalize_splits(split) if split is not None else ["train"] def _preprocess(sample: dict) -> str: @@ -344,8 +359,10 @@ def get_dataset_dataloader( """Get a dataloader with the dataset name and tokenizer of the target model. Args: - dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file. - If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field. + dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list + mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the + resulting samples are concatenated before tokenization. ``num_samples`` may be + an ``int`` (applied to a single source) or a list aligned with ``dataset_name``. tokenizer: Instance of HuggingFace tokenizer. batch_size: Batch size of the returned dataloader. num_samples: Number of samples from the dataset. diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 9a89d53672e..bd4fceddfb5 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -18,7 +18,11 @@ import pytest import torch -from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples +from modelopt.torch.utils.dataset_utils import ( + _process_batch, + get_dataset_dataloader, + get_dataset_samples, +) def setup_test_data(): @@ -167,3 +171,235 @@ def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_lo assert isinstance(samples, list) assert len(samples) == 5 assert all(isinstance(s, str) and len(s) > 0 for s in samples) + + +# --------------------------------------------------------------------------- +# Local JSONL loading — must flow through the same auto-preprocess path as a +# downloaded HF dataset, so chat / prompt / text columns are all handled. +# --------------------------------------------------------------------------- + + +def _write_jsonl(path, rows): + """Write a list of dicts to *path* as JSONL. Returns the path as ``str``.""" + import json + + with open(path, "w", encoding="utf-8") as f: + f.writelines(json.dumps(row) + "\n" for row in rows) + return str(path) + + +@pytest.fixture +def chat_tokenizer(): + """Mock tokenizer whose ``apply_chat_template`` joins messages role:content.""" + tok = Mock() + tok.apply_chat_template = Mock( + side_effect=lambda msgs, tokenize=False, **kw: " | ".join( + f"{m['role']}:{m['content']}" for m in msgs + ) + ) + return tok + + +class TestLocalJsonlLoading: + """Local ``.jsonl`` paths route through HF's ``json`` builder + auto-preprocess.""" + + def test_text_column(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "plain.jsonl", + [{"text": f"plain {i}"} for i in range(3)], + ) + samples = get_dataset_samples(path, num_samples=3) + assert samples == ["plain 0", "plain 1", "plain 2"] + + def test_messages_column_uses_chat_template(self, tmp_path, chat_tokenizer): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "chat.jsonl", + [ + { + "messages": [ + {"role": "user", "content": f"hello {i}"}, + {"role": "assistant", "content": f"hi {i}"}, + ] + } + for i in range(3) + ], + ) + samples = get_dataset_samples(path, num_samples=3, tokenizer=chat_tokenizer) + assert len(samples) == 3 + assert samples[0] == "user:hello 0 | assistant:hi 0" + # apply_chat_template must have been invoked once per sample + assert chat_tokenizer.apply_chat_template.call_count == 3 + + def test_conversations_column_uses_chat_template(self, tmp_path, chat_tokenizer): + """Auto-detect also recognizes ``conversations`` (Magpie-style).""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "convs.jsonl", + [ + { + "conversations": [ + {"role": "user", "content": "q"}, + {"role": "assistant", "content": "a"}, + ] + } + ], + ) + samples = get_dataset_samples(path, num_samples=1, tokenizer=chat_tokenizer) + assert samples == ["user:q | assistant:a"] + + def test_prompt_completion_concatenated(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "prompt.jsonl", + [{"prompt": "Q?", "completion": "A."}], + ) + samples = get_dataset_samples(path, num_samples=1) + assert samples == ["Q?\nA."] + + def test_input_output_concatenated(self, tmp_path): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "io.jsonl", + [{"input": "in", "output": "out"}], + ) + samples = get_dataset_samples(path, num_samples=1) + assert samples == ["in\nout"] + + def test_num_samples_honored(self, tmp_path): + """Loads only the requested number of rows even when the file is larger.""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "many.jsonl", + [{"text": f"row {i}"} for i in range(100)], + ) + samples = get_dataset_samples(path, num_samples=5) + assert len(samples) == 5 + assert samples == [f"row {i}" for i in range(5)] + + def test_tools_forwarded_to_chat_template(self, tmp_path, chat_tokenizer): + """If a row carries a ``tools`` field, it's passed through to apply_chat_template.""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "tools.jsonl", + [ + { + "messages": [{"role": "user", "content": "x"}], + "tools": [{"name": "calc"}], + } + ], + ) + get_dataset_samples(path, num_samples=1, tokenizer=chat_tokenizer) + _, kwargs = chat_tokenizer.apply_chat_template.call_args + assert kwargs.get("tools") == [{"name": "calc"}] + + def test_unrecognized_columns_raise(self, tmp_path): + """Auto-detect raises ValueError when no recognized column is present.""" + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "bad.jsonl", + [{"unrelated_field": "value"}], + ) + with pytest.raises(ValueError, match="Cannot auto-detect format"): + get_dataset_samples(path, num_samples=1) + + +# --------------------------------------------------------------------------- +# get_dataset_dataloader — blending across multiple sources +# --------------------------------------------------------------------------- + + +class _FakeTokenizer: + """Minimal callable tokenizer that mimics the HF tokenizer surface used by the dataloader. + + Tokenizes by character ordinal and left-pads to the longest sample (capped at max_length). + Avoids a hard dependency on ``transformers`` in the test environment. + """ + + padding_side = "left" + pad_token_id = 0 + + def __call__(self, texts, return_tensors=None, padding=True, truncation=True, max_length=16): + ids = [[ord(c) % 100 + 1 for c in t][:max_length] for t in texts] + n = max(len(x) for x in ids) + input_ids = [[self.pad_token_id] * (n - len(x)) + x for x in ids] + attention = [[0] * (n - len(x)) + [1] * len(x) for x in ids] + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention, dtype=torch.long), + } + + +@pytest.fixture +def pad_tokenizer(): + return _FakeTokenizer() + + +class TestGetDatasetDataloaderBlending: + """``get_dataset_dataloader`` accepts a list of sources and concatenates them.""" + + def test_single_jsonl(self, tmp_path, pad_tokenizer): + pytest.importorskip("datasets") + path = _write_jsonl( + tmp_path / "single.jsonl", + [{"text": f"row {i}"} for i in range(4)], + ) + loader = get_dataset_dataloader( + dataset_name=path, + tokenizer=pad_tokenizer, + batch_size=2, + num_samples=4, + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 2 + assert batches[0]["input_ids"].shape[0] == 2 + assert "attention_mask" in batches[0] + + def test_list_of_jsonl_blends(self, tmp_path, pad_tokenizer): + """Two local JSONL files concatenated into a single dataloader.""" + pytest.importorskip("datasets") + a = _write_jsonl(tmp_path / "a.jsonl", [{"text": f"a{i}"} for i in range(3)]) + b = _write_jsonl(tmp_path / "b.jsonl", [{"text": f"b{i}"} for i in range(2)]) + + loader = get_dataset_dataloader( + dataset_name=[a, b], + tokenizer=pad_tokenizer, + batch_size=5, + num_samples=[3, 2], + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 1 + assert batches[0]["input_ids"].shape[0] == 5 + + def test_mixed_formats_blended(self, tmp_path, pad_tokenizer): + """Mixing a text-column JSONL with a prompt/completion JSONL — both should flow.""" + pytest.importorskip("datasets") + plain = _write_jsonl(tmp_path / "plain.jsonl", [{"text": "hello"}]) + pc = _write_jsonl(tmp_path / "pc.jsonl", [{"prompt": "Q?", "completion": "A."}]) + + loader = get_dataset_dataloader( + dataset_name=[plain, pc], + tokenizer=pad_tokenizer, + batch_size=2, + num_samples=[1, 1], + max_sample_length=16, + ) + batches = list(loader) + assert len(batches) == 1 + assert batches[0]["input_ids"].shape[0] == 2 + + def test_length_mismatch_raises(self, tmp_path, pad_tokenizer): + """``dataset_name`` and ``num_samples`` lists must align.""" + pytest.importorskip("datasets") + a = _write_jsonl(tmp_path / "a.jsonl", [{"text": "x"}]) + b = _write_jsonl(tmp_path / "b.jsonl", [{"text": "y"}]) + with pytest.raises(AssertionError, match="same length"): + get_dataset_dataloader( + dataset_name=[a, b], + tokenizer=pad_tokenizer, + num_samples=[1], + max_sample_length=16, + ) From 35273990aecf0857016faf26a12d66556291c8a9 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 24 Apr 2026 15:58:52 -0700 Subject: [PATCH 2/7] keep backward compatibility Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 51 ++++++++++++++------ tests/unit/torch/utils/test_dataset_utils.py | 22 ++++++++- 2 files changed, 57 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index b0a8a444a79..d99c299963f 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -314,21 +314,42 @@ def _preprocess(sample: dict) -> str: # load_dataset does not support a list of splits while streaming, so load each separately. print(f"Loading dataset with {config=} and {splits=}") - dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] - - num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) - num_per_split[-1] += num_samples - sum(num_per_split) - - samples: list[str] = [] - for dataset, n in zip(dataset_splits, num_per_split): - for i, sample in enumerate(dataset): - if i >= n: - break - text = _preprocess(sample) - if text: - samples.append(text) - - return samples + try: + dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] + + num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) + num_per_split[-1] += num_samples - sum(num_per_split) + + samples: list[str] = [] + for dataset, n in zip(dataset_splits, num_per_split): + for i, sample in enumerate(dataset): + if i >= n: + break + text = _preprocess(sample) + if text: + samples.append(text) + + return samples + except Exception as e: + # Backward-compat fallback: legacy callers passed JSONL files whose only usable + # field is ``text``. If the HF ``json`` builder or auto-detect can't handle the + # file (schema inference error, unrecognized columns, etc.), fall back to a + # line-by-line reader that pulls the ``text`` field directly. + if is_jsonl: + assert local_dataset_path is not None # is_jsonl implies the path exists + try: + fallback_samples = get_jsonl_text_samples( + local_dataset_path, num_samples, key="text" + ) + except Exception: + # Fallback can't help either — surface the original HF error. + raise e from None + warn( + f"Failed to load {local_dataset_path} via the HF 'json' builder " + f"({type(e).__name__}: {e}); fell back to legacy text-field reader." + ) + return fallback_samples + raise class _CustomDataset(torch.utils.data.Dataset): diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index bd4fceddfb5..b351854f6d2 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -295,7 +295,12 @@ def test_tools_forwarded_to_chat_template(self, tmp_path, chat_tokenizer): assert kwargs.get("tools") == [{"name": "calc"}] def test_unrecognized_columns_raise(self, tmp_path): - """Auto-detect raises ValueError when no recognized column is present.""" + """Auto-detect raises ValueError when no recognized column is present. + + The HF builder loads the rows fine; auto-detect rejects them. There's no + ``text`` field to fall back to, so the error propagates instead of being + masked by the legacy fallback. + """ pytest.importorskip("datasets") path = _write_jsonl( tmp_path / "bad.jsonl", @@ -304,6 +309,21 @@ def test_unrecognized_columns_raise(self, tmp_path): with pytest.raises(ValueError, match="Cannot auto-detect format"): get_dataset_samples(path, num_samples=1) + def test_legacy_text_fallback_on_hf_builder_failure(self, tmp_path): + """If the HF json builder raises, fall back to the legacy text-field reader.""" + pytest.importorskip("datasets") + # Mixed-type ``meta`` field across rows — int vs string — trips PyArrow + # schema unification in the HF json builder. The rows still carry a + # ``text`` field, so the legacy reader can recover the samples. + rows = [ + {"text": "row a", "meta": 1}, + {"text": "row b", "meta": "two"}, + {"text": "row c", "meta": 3}, + ] + path = _write_jsonl(tmp_path / "mixed.jsonl", rows) + samples = get_dataset_samples(path, num_samples=3) + assert samples == ["row a", "row b", "row c"] + # --------------------------------------------------------------------------- # get_dataset_dataloader — blending across multiple sources From ad3a37a080be3377305a6a1315f9d8f316fcf968 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 24 Apr 2026 16:20:53 -0700 Subject: [PATCH 3/7] more tests for general dataset loading functionalities Signed-off-by: Shengliang Xu --- tests/unit/torch/utils/test_dataset_utils.py | 81 ++++++++++++++++++++ 1 file changed, 81 insertions(+) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index b351854f6d2..fbaa29e2904 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -423,3 +423,84 @@ def test_length_mismatch_raises(self, tmp_path, pad_tokenizer): num_samples=[1], max_sample_length=16, ) + + +# --------------------------------------------------------------------------- +# Live HF dataset round-trips. ``hf-internal-testing/dataset_with_data_files`` +# is a 10-row x {train,test} fixture maintained by HF for their own CI — tiny +# enough to download in a unit test and stable across releases. +# --------------------------------------------------------------------------- + +_HF_TINY = "hf-internal-testing/dataset_with_data_files" # train, test splits, ``text`` col + + +def _hf_dump_to_jsonl(name: str, split: str, path) -> str: + from datasets import load_dataset + + ds = load_dataset(name, split=split) + ds.to_json(str(path), lines=True) + return str(path) + + +class TestHfTinyDataset: + """End-to-end coverage with a real (tiny) HF dataset.""" + + def test_load_single_split_directly(self): + pytest.importorskip("datasets") + samples = get_dataset_samples(_HF_TINY, num_samples=4, split="train") + assert len(samples) == 4 + assert all(isinstance(s, str) and s for s in samples) + + def test_load_multiple_splits_directly(self): + """``split=["train", "test"]`` divides ``num_samples`` across both splits.""" + pytest.importorskip("datasets") + samples = get_dataset_samples(_HF_TINY, num_samples=6, split=["train", "test"]) + assert len(samples) == 6 + # Default per-split is num_samples // n + remainder; for 6/2 → 3 from each. + # We can't assert exact origin without re-reading, but both splits should + # contribute, which we'll confirm by comparing against direct loads below. + train_only = set(get_dataset_samples(_HF_TINY, num_samples=10, split="train")) + test_only = set(get_dataset_samples(_HF_TINY, num_samples=10, split="test")) + assert any(s in train_only for s in samples) + assert any(s in test_only for s in samples) + + def test_default_split_is_train(self): + pytest.importorskip("datasets") + default_samples = get_dataset_samples(_HF_TINY, num_samples=4) + train_samples = get_dataset_samples(_HF_TINY, num_samples=4, split="train") + assert default_samples == train_samples + + def test_download_to_jsonl_then_load(self, tmp_path): + """Dump the HF dataset to JSONL, then reload it via the local-jsonl path.""" + pytest.importorskip("datasets") + jsonl_path = _hf_dump_to_jsonl(_HF_TINY, "train", tmp_path / "train.jsonl") + from_jsonl = get_dataset_samples(jsonl_path, num_samples=10) + from_hf = get_dataset_samples(_HF_TINY, num_samples=10, split="train") + assert from_jsonl == from_hf + + def test_dataloader_blending_two_hf_datasets(self, pad_tokenizer): + """Two HF datasets concatenated via ``get_dataset_dataloader``.""" + pytest.importorskip("datasets") + loader = get_dataset_dataloader( + dataset_name=[_HF_TINY, "hf-internal-testing/multi_dir_dataset"], + tokenizer=pad_tokenizer, + batch_size=4, + num_samples=[3, 1], + max_sample_length=16, + ) + batches = list(loader) + assert sum(b["input_ids"].shape[0] for b in batches) == 4 + + def test_dataloader_mixing_hf_and_local_jsonl(self, tmp_path, pad_tokenizer): + """Live HF dataset blended with a local synthetic JSONL file.""" + pytest.importorskip("datasets") + local = _write_jsonl(tmp_path / "local.jsonl", [{"text": f"local {i}"} for i in range(2)]) + loader = get_dataset_dataloader( + dataset_name=[_HF_TINY, local], + tokenizer=pad_tokenizer, + batch_size=5, + num_samples=[3, 2], + max_sample_length=16, + ) + batches = list(loader) + assert sum(b["input_ids"].shape[0] for b in batches) == 5 From 165c7450e16ea42db2747edf6de86626c3b30b8e Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Fri, 24 Apr 2026 16:23:48 -0700 Subject: [PATCH 4/7] Fix comments Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index d99c299963f..3ea345c09f8 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -230,9 +230,12 @@ def get_dataset_samples( or a path to a ``.jsonl`` file. For local directory paths, the predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key). - For ``.jsonl`` paths, the file is loaded via HuggingFace's ``json`` builder and - routed through the same auto-preprocess path as unregistered HF datasets, so - chat / prompt / text columns are handled consistently with live HF datasets. + For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json`` + builder and routed through the same auto-preprocess path as unregistered HF + datasets so chat / prompt / text columns are handled consistently with live + HF datasets. If that path fails (e.g. PyArrow schema unification across + heterogeneous rows), it falls back to a line-by-line reader that extracts + the legacy ``text`` field for backward compatibility. num_samples: Number of samples to load from the dataset. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). For unregistered datasets with a From 9a94525e1f7892052f2c561102d2a97a1619970b Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 29 Apr 2026 14:45:27 -0700 Subject: [PATCH 5/7] lazy HF dependency Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 87 +++++++++++++++++++-------- 1 file changed, 63 insertions(+), 24 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index dec15ac2465..04ea726c118 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -234,9 +234,12 @@ def get_dataset_samples( For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json`` builder and routed through the same auto-preprocess path as unregistered HF datasets so chat / prompt / text columns are handled consistently with live - HF datasets. If that path fails (e.g. PyArrow schema unification across - heterogeneous rows), it falls back to a line-by-line reader that extracts - the legacy ``text`` field for backward compatibility. + HF datasets. If that path fails on JSON parsing or PyArrow schema + unification, it falls back to a line-by-line reader that extracts the + legacy ``text`` field for backward compatibility. The fallback is also + used when the optional ``datasets`` package isn't installed, preserving + legacy plain-``.jsonl`` workflows in base installations. Local JSONL + files only expose the ``train`` split; passing any other ``split`` raises. num_samples: Number of samples to load from the dataset. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). For unregistered datasets with a @@ -251,14 +254,34 @@ def get_dataset_samples( Returns: Samples: The list of samples. """ - from datasets import load_dataset - # Local JSONL: load via HF's ``json`` builder and route through the same # auto-preprocess path as unregistered HF datasets so chat / prompt / text # columns are handled consistently with a downloaded HF dataset. Never # matches ``SUPPORTED_DATASET_CONFIG``. is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name) + # HF's file-based builders only expose ``train`` for the ``data_files`` form + # we use, so any other split is a caller error. Surface it up front rather + # than letting ``load_dataset`` fail and silently dropping into the + # text-field fallback (which would ignore the requested split). + if is_jsonl and split is not None: + invalid = [s for s in _normalize_splits(split) if s != "train"] + if invalid: + raise ValueError( + f"Local JSONL files only expose the 'train' split, got {invalid}. " + "Either omit ``split`` or pass ``split='train'``." + ) + + # Lazy ``datasets`` import: legacy ``.jsonl`` workflows historically didn't + # require the optional ``datasets`` extra, so keep them working with just + # the stdlib reader when the package isn't installed. + try: + from datasets import load_dataset + except ImportError: + if is_jsonl: + return get_jsonl_text_samples(dataset_name, num_samples, key="text") + raise + local_dataset_path = None if os.path.exists(dataset_name): # Local path local_dataset_path = dataset_name @@ -316,6 +339,25 @@ def _preprocess(sample: dict) -> str: def _preprocess(sample: dict) -> str: return _auto_preprocess_sample(sample, dataset_name, tokenizer) + # Narrow the legacy fallback to JSON-parsing / Arrow schema failures. Any + # other error (split-not-found, IO, OOM, ...) should surface to the caller + # rather than be hidden by the text-field reader. Imported lazily because + # the exact module paths vary across versions; an empty tuple is a valid + # ``except`` target that catches nothing if neither is importable. + fallback_types: tuple[type[BaseException], ...] = () + try: + from datasets.exceptions import DatasetGenerationError + + fallback_types += (DatasetGenerationError,) + except ImportError: + pass + try: + from pyarrow.lib import ArrowInvalid + + fallback_types += (ArrowInvalid,) + except ImportError: + pass + # load_dataset does not support a list of splits while streaming, so load each separately. print(f"Loading dataset with {config=} and {splits=}") try: @@ -334,26 +376,23 @@ def _preprocess(sample: dict) -> str: samples.append(text) return samples - except Exception as e: + except fallback_types as e: # Backward-compat fallback: legacy callers passed JSONL files whose only usable - # field is ``text``. If the HF ``json`` builder or auto-detect can't handle the - # file (schema inference error, unrecognized columns, etc.), fall back to a - # line-by-line reader that pulls the ``text`` field directly. - if is_jsonl: - assert local_dataset_path is not None # is_jsonl implies the path exists - try: - fallback_samples = get_jsonl_text_samples( - local_dataset_path, num_samples, key="text" - ) - except Exception: - # Fallback can't help either — surface the original HF error. - raise e from None - warn( - f"Failed to load {local_dataset_path} via the HF 'json' builder " - f"({type(e).__name__}: {e}); fell back to legacy text-field reader." - ) - return fallback_samples - raise + # field is ``text``. If the HF ``json`` builder fails on schema inference or + # JSON parsing, fall back to a line-by-line reader that pulls ``text`` directly. + if not is_jsonl: + raise + assert local_dataset_path is not None # is_jsonl implies the path exists + try: + fallback_samples = get_jsonl_text_samples(local_dataset_path, num_samples, key="text") + except Exception: + # Fallback can't help either — surface the original HF error. + raise e from None + warn( + f"Failed to load {local_dataset_path} via the HF 'json' builder " + f"({type(e).__name__}: {e}); fell back to legacy text-field reader." + ) + return fallback_samples class _CustomDataset(torch.utils.data.Dataset): From 409eb3857d95767ae8c7f8c5d3c17e626182ac62 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 29 Apr 2026 14:56:46 -0700 Subject: [PATCH 6/7] fix review Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 12 +++++++++--- tests/unit/torch/utils/test_dataset_utils.py | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 04ea726c118..8f6d1d14cad 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -178,6 +178,12 @@ def _auto_preprocess_sample( ValueError: If the tokenizer is missing/incompatible for chat-format datasets, or if no recognized column is found. """ + # Truthy ``sample.get`` checks instead of ``key in sample``: HF's schema + # unification fills missing values with ``None`` across heterogeneous JSONL + # rows, so a row that only has ``text`` would still expose ``prompt=None`` + # in the unified schema. Falling through on null/empty lets such rows + # match the next column (e.g. ``text``) instead of crashing on + # ``"\n".join([None])``. chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None) if chat_key is not None: if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): @@ -191,15 +197,15 @@ def _auto_preprocess_sample( kwargs["tools"] = tools return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs) - if "prompt" in sample: + if sample.get("prompt"): parts = [sample["prompt"]] parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k)) return "\n".join(parts) - if "text" in sample: + if sample.get("text"): return sample["text"] - if "input" in sample: + if sample.get("input"): parts = [sample["input"]] if sample.get("output"): parts.append(sample["output"]) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index a773ce5cbfd..5ff60b53547 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -392,6 +392,26 @@ def test_unrecognized_columns_raise(self, tmp_path): with pytest.raises(ValueError, match="Cannot auto-detect format"): get_dataset_samples(path, num_samples=1) + def test_sparse_recognized_column_falls_through_to_text(self, tmp_path): + """Sparse ``prompt`` column (None on most rows) must not shadow ``text``. + + HF's schema unification fills missing values with None across heterogeneous + rows, so a row with only ``text`` ends up exposing ``prompt=None`` in the + unified schema. Auto-detect must skip null-valued recognized columns + rather than crash on ``"\\n".join([None])``. + """ + pytest.importorskip("datasets") + rows = [ + {"text": "row a"}, + {"text": "row b", "prompt": "ignored", "completion": "stuff"}, + {"text": "row c"}, + ] + path = _write_jsonl(tmp_path / "sparse.jsonl", rows) + samples = get_dataset_samples(path, num_samples=3) + # text-only rows fall through to ``text``; the prompt-bearing row uses + # the prompt+completion path. + assert samples == ["row a", "ignored\nstuff", "row c"] + def test_legacy_text_fallback_on_hf_builder_failure(self, tmp_path): """If the HF json builder raises, fall back to the legacy text-field reader.""" pytest.importorskip("datasets") From ccbfc6464631f77ddd1bb71bec1dc9434d712600 Mon Sep 17 00:00:00 2001 From: Shengliang Xu Date: Wed, 29 Apr 2026 15:02:25 -0700 Subject: [PATCH 7/7] more comments Signed-off-by: Shengliang Xu --- modelopt/torch/utils/dataset_utils.py | 23 ++++++++++---------- tests/unit/torch/utils/test_dataset_utils.py | 20 +++++++++++++++++ 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 8f6d1d14cad..ed1efce881c 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -178,12 +178,11 @@ def _auto_preprocess_sample( ValueError: If the tokenizer is missing/incompatible for chat-format datasets, or if no recognized column is found. """ - # Truthy ``sample.get`` checks instead of ``key in sample``: HF's schema - # unification fills missing values with ``None`` across heterogeneous JSONL - # rows, so a row that only has ``text`` would still expose ``prompt=None`` - # in the unified schema. Falling through on null/empty lets such rows - # match the next column (e.g. ``text``) instead of crashing on - # ``"\n".join([None])``. + # ``is not None`` (rather than ``key in sample`` or truthy ``.get(...)``) + # so we skip HF's schema-unification ``None`` padding while still treating + # legitimate empty strings as valid values (the caller filters falsy + # samples downstream). Chat keys still use truthy checks because an empty + # ``messages``/``conversations`` list isn't meaningful to template. chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None) if chat_key is not None: if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): @@ -197,17 +196,19 @@ def _auto_preprocess_sample( kwargs["tools"] = tools return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs) - if sample.get("prompt"): + if sample.get("prompt") is not None: parts = [sample["prompt"]] - parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k)) + parts.extend( + sample[k] for k in ("completion", "response", "output") if sample.get(k) is not None + ) return "\n".join(parts) - if sample.get("text"): + if sample.get("text") is not None: return sample["text"] - if sample.get("input"): + if sample.get("input") is not None: parts = [sample["input"]] - if sample.get("output"): + if sample.get("output") is not None: parts.append(sample["output"]) return "\n".join(parts) diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index 5ff60b53547..08602279504 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -412,6 +412,26 @@ def test_sparse_recognized_column_falls_through_to_text(self, tmp_path): # the prompt+completion path. assert samples == ["row a", "ignored\nstuff", "row c"] + def test_empty_string_columns_treated_as_present(self, tmp_path): + """Empty strings are valid values, not absent columns. + + ``prompt=""`` should still take the prompt path (caller filters empty + results downstream), and ``text=""`` rows must not crash the load — + only ``None`` should fall through to the next column. + """ + pytest.importorskip("datasets") + rows = [ + {"text": ""}, # blank but valid; caller drops empty samples + {"prompt": "", "completion": "from-completion"}, + {"text": "kept"}, + ] + path = _write_jsonl(tmp_path / "blank.jsonl", rows) + samples = get_dataset_samples(path, num_samples=3) + # ``{"text": ""}`` produces "" and is filtered by the caller. + # ``{"prompt": "", "completion": "from-completion"}`` joins to + # "\nfrom-completion". ``{"text": "kept"}`` is kept verbatim. + assert samples == ["\nfrom-completion", "kept"] + def test_legacy_text_fallback_on_hf_builder_failure(self, tmp_path): """If the HF json builder raises, fall back to the legacy text-field reader.""" pytest.importorskip("datasets")