141 changes: 114 additions & 27 deletions modelopt/torch/utils/dataset_utils.py
@@ -178,6 +178,11 @@ def _auto_preprocess_sample(
ValueError: If the tokenizer is missing/incompatible for chat-format datasets,
or if no recognized column is found.
"""
    # We check ``is not None`` (rather than ``key in sample`` or truthy ``.get(...)``)
# so we skip HF's schema-unification ``None`` padding while still treating
# legitimate empty strings as valid values (the caller filters falsy
# samples downstream). Chat keys still use truthy checks because an empty
# ``messages``/``conversations`` list isn't meaningful to template.
chat_key = next((k for k in ("messages", "conversations") if sample.get(k)), None)
if chat_key is not None:
if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
@@ -191,17 +196,19 @@
kwargs["tools"] = tools
return tokenizer.apply_chat_template(sample[chat_key], tokenize=False, **kwargs)

if "prompt" in sample:
if sample.get("prompt") is not None:
parts = [sample["prompt"]]
parts.extend(sample[k] for k in ("completion", "response", "output") if sample.get(k))
parts.extend(
sample[k] for k in ("completion", "response", "output") if sample.get(k) is not None
)
return "\n".join(parts)

if "text" in sample:
if sample.get("text") is not None:
return sample["text"]

if "input" in sample:
if sample.get("input") is not None:
parts = [sample["input"]]
if sample.get("output"):
if sample.get("output") is not None:
parts.append(sample["output"])
return "\n".join(parts)

@@ -231,6 +238,15 @@ def get_dataset_samples(
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key).
For ``.jsonl`` paths, the file is first loaded via HuggingFace's ``json``
builder and routed through the same auto-preprocess path as unregistered HF
datasets so chat / prompt / text columns are handled consistently with live
HF datasets. If that path fails on JSON parsing or PyArrow schema
unification, it falls back to a line-by-line reader that extracts the
legacy ``text`` field for backward compatibility. The fallback is also
used when the optional ``datasets`` package isn't installed, preserving
legacy plain-``.jsonl`` workflows in base installations. Local JSONL
files only expose the ``train`` split; passing any other ``split`` raises.
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
@@ -245,18 +261,43 @@
Returns:
Samples: The list of samples.
"""
-    # Local JSONL file path support (each line is a JSON object with a `text` field).
-    if dataset_name.endswith(".jsonl"):
-        return get_jsonl_text_samples(dataset_name, num_samples, key="text")
+    # Local JSONL: load via HF's ``json`` builder and route through the same
+    # auto-preprocess path as unregistered HF datasets so chat / prompt / text
+    # columns are handled consistently with a downloaded HF dataset. Never
+    # matches ``SUPPORTED_DATASET_CONFIG``.
+    is_jsonl = dataset_name.endswith(".jsonl") and os.path.isfile(dataset_name)

# HF's file-based builders only expose ``train`` for the ``data_files`` form
# we use, so any other split is a caller error. Surface it up front rather
# than letting ``load_dataset`` fail and silently dropping into the
# text-field fallback (which would ignore the requested split).
if is_jsonl and split is not None:
invalid = [s for s in _normalize_splits(split) if s != "train"]
if invalid:
raise ValueError(
f"Local JSONL files only expose the 'train' split, got {invalid}. "
"Either omit ``split`` or pass ``split='train'``."
)
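Given this guard, requesting any non-``train`` split on a JSONL fails up front; a small sketch of the expected behavior (``pytest`` and the path are illustrative):

import pytest

with pytest.raises(ValueError, match="only expose the 'train' split"):
    get_dataset_samples("/tmp/calib.jsonl", num_samples=1, split="validation")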

-    from datasets import load_dataset
+    # Lazy ``datasets`` import: legacy ``.jsonl`` workflows historically didn't
+    # require the optional ``datasets`` extra, so keep them working with just
+    # the stdlib reader when the package isn't installed.
+    try:
+        from datasets import load_dataset
+    except ImportError:
+        if is_jsonl:
+            return get_jsonl_text_samples(dataset_name, num_samples, key="text")
+        raise

     local_dataset_path = None
     if os.path.exists(dataset_name):  # Local path
         local_dataset_path = dataset_name
-        dataset_name = os.path.basename(os.path.normpath(local_dataset_path))
+        if not is_jsonl:
+            # Directory paths may match a registered key via their basename
+            # (e.g. /hf-local/abisee/cnn_dailymail -> cnn_dailymail).
+            dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

-    is_registered = dataset_name in SUPPORTED_DATASET_CONFIG
+    is_registered = not is_jsonl and dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
@@ -292,29 +333,73 @@ def _preprocess(sample: dict) -> str:
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": local_dataset_path or dataset_name}
if is_jsonl:
config = {"path": "json", "data_files": local_dataset_path}
else:
config = {"path": local_dataset_path or dataset_name}
# HF's file-based builders (incl. ``json``) label a string/list ``data_files``
# as the ``train`` split unconditionally — the filename on disk is ignored.
# Named splits require a dict ``data_files={"train": ..., "test": ...}``,
# which we don't expose here.
splits = _normalize_splits(split) if split is not None else ["train"]
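The split-labeling behavior the comment above relies on can be checked directly; a sketch assuming a recent ``datasets`` release (file names are hypothetical):

from datasets import load_dataset

ds = load_dataset("json", data_files="eval.jsonl")
print(list(ds))  # ['train'] -- the "eval" file name is ignored

ds = load_dataset("json", data_files={"train": "a.jsonl", "test": "b.jsonl"})
print(list(ds))  # ['train', 'test']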

def _preprocess(sample: dict) -> str:
return _auto_preprocess_sample(sample, dataset_name, tokenizer)

# Narrow the legacy fallback to JSON-parsing / Arrow schema failures. Any
# other error (split-not-found, IO, OOM, ...) should surface to the caller
# rather than be hidden by the text-field reader. Imported lazily because
# the exact module paths vary across versions; an empty tuple is a valid
# ``except`` target that catches nothing if neither is importable.
fallback_types: tuple[type[BaseException], ...] = ()
try:
from datasets.exceptions import DatasetGenerationError

fallback_types += (DatasetGenerationError,)
except ImportError:
pass
try:
from pyarrow.lib import ArrowInvalid

fallback_types += (ArrowInvalid,)
except ImportError:
pass
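The empty-tuple trick the comment describes is standard Python; a minimal self-contained demonstration:

# ``except ()`` is legal and matches nothing, so the guarded ``try`` degrades
# to an unguarded one when neither fallback type could be imported.
def parse(types: tuple[type[BaseException], ...]) -> str:
    try:
        raise ValueError("boom")
    except types:
        return "caught by fallback"

print(parse((ValueError,)))  # caught by fallback
try:
    parse(())  # empty tuple: nothing is caught
except ValueError as e:
    print(f"propagated: {e}")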

        # load_dataset does not support a list of splits while streaming, so load each separately.
        print(f"Loading dataset with {config=} and {splits=}")
-        dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
-
-        num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
-        num_per_split[-1] += num_samples - sum(num_per_split)
-
-        samples: list[str] = []
-        for dataset, n in zip(dataset_splits, num_per_split):
-            for i, sample in enumerate(dataset):
-                if i >= n:
-                    break
-                text = _preprocess(sample)
-                if text:
-                    samples.append(text)
-
-        return samples
+        try:
+            dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]
+
+            num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
+            num_per_split[-1] += num_samples - sum(num_per_split)
Comment on lines 343 to +373

⚠️ Potential issue | 🟡 Minor

Reject an empty split list before computing per-split quotas.

split=[] currently reaches num_samples // len(dataset_splits) and crashes with ZeroDivisionError. Please fail fast with a ValueError once splits has been normalized.

Proposed fix

         splits = _normalize_splits(split) if split is not None else ["train"]
+        if not splits:
+            raise ValueError("``split`` must contain at least one split name.")

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed.

In modelopt/torch/utils/dataset_utils.py around lines 343-373: after normalizing splits via _normalize_splits (or defaulting to ["train"]), validate that the resulting splits list is not empty and raise a ValueError if it is, before calling load_dataset and before computing num_per_split, so a ZeroDivisionError cannot occur. Fail fast with a clear message such as "empty splits list after normalization" rather than allowing a division by zero.
+            samples: list[str] = []
+            for dataset, n in zip(dataset_splits, num_per_split):
+                for i, sample in enumerate(dataset):
+                    if i >= n:
+                        break
+                    text = _preprocess(sample)
+                    if text:
+                        samples.append(text)
+
+            return samples
+        except fallback_types as e:
+            # Backward-compat fallback: legacy callers passed JSONL files whose only usable
+            # field is ``text``. If the HF ``json`` builder fails on schema inference or
+            # JSON parsing, fall back to a line-by-line reader that pulls ``text`` directly.
+            if not is_jsonl:
+                raise
+            assert local_dataset_path is not None  # is_jsonl implies the path exists
+            try:
+                fallback_samples = get_jsonl_text_samples(local_dataset_path, num_samples, key="text")
+            except Exception:
+                # Fallback can't help either — surface the original HF error.
+                raise e from None
+            warn(
+                f"Failed to load {local_dataset_path} via the HF 'json' builder "
+                f"({type(e).__name__}: {e}); fell back to legacy text-field reader."
+            )
+            return fallback_samples
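A hedged sketch of what trips this fallback in practice (the type clash is illustrative; the exact error class depends on the installed ``datasets``/``pyarrow`` versions):

import json

# A shared column with irreconcilable types can fail Arrow schema unification
# in the HF ``json`` builder; the legacy reader then recovers ``text`` per line.
path = "/tmp/mixed.jsonl"
with open(path, "w") as f:
    f.write(json.dumps({"text": "ok", "meta": {"a": 1}}) + "\n")
    f.write(json.dumps({"text": "also ok", "meta": "not a dict"}) + "\n")

samples = get_dataset_samples(path, num_samples=2)  # warns, then falls back
print(samples)  # ['ok', 'also ok'] via the text-field reader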


class _CustomDataset(torch.utils.data.Dataset):
@@ -345,8 +430,10 @@ def get_dataset_dataloader(
"""Get a dataloader with the dataset name and tokenizer of the target model.

Args:
-        dataset_name: Name of the dataset to load, or a path to a ``.jsonl`` file.
-            If a ``.jsonl`` file is provided, each line must be a JSON object with a ``text`` field.
+        dataset_name: Name of the dataset to load, a path to a ``.jsonl`` file, or a list
+            mixing the two. Each entry is loaded via :func:`get_dataset_samples` and the
+            resulting samples are concatenated before tokenization. ``num_samples`` may be
+            an ``int`` (applied to a single source) or a list aligned with ``dataset_name``.
tokenizer: Instance of HuggingFace tokenizer.
batch_size: Batch size of the returned dataloader.
num_samples: Number of samples from the dataset.
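A sketch of the new multi-source form described in the docstring (dataset names and quotas are illustrative; ``tok`` is any HF tokenizer instance):

dataloader = get_dataset_dataloader(
    dataset_name=["cnn_dailymail", "/data/calib.jsonl"],  # registered + local JSONL
    tokenizer=tok,
    batch_size=4,
    num_samples=[256, 64],  # per-source quotas, aligned with dataset_name
)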