diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index c4da627f14..962f680b4f 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -89,8 +89,10 @@ def get_args() -> argparse.Namespace: "--calib_dataset_name", type=str, default="nemotron-post-training-dataset-v2", - choices=get_supported_datasets(), - help="Dataset name for calibration", + help=( + f"HF Dataset name or local path for calibration (supported options: {', '.join(get_supported_datasets())}. " + "You can also pass any other dataset and see if auto-detection for your dataset works." + ), ) parser.add_argument( "--calib_num_samples", type=int, default=1024, help="Number of samples for calibration" diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index d3e2239baa..29e0d8a882 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -17,10 +17,13 @@ import copy import json +import os from collections.abc import Callable +from pathlib import Path from typing import TYPE_CHECKING, Any from warnings import warn +import requests import torch from torch.utils.data import DataLoader from tqdm import tqdm @@ -48,9 +51,9 @@ "name": "SFT", "split": ["code", "math", "science", "chat", "safety"], }, - "preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["input"]) - + "\n" - + sample["output"], + "preprocess": lambda sample: ( + "\n".join(turn["content"] for turn in sample["input"]) + "\n" + sample["output"] + ), }, "nemotron-post-training-dataset-v2": { "config": { @@ -104,14 +107,16 @@ __all__ = [ "create_forward_loop", + "download_hf_dataset_as_jsonl", "get_dataset_dataloader", "get_dataset_samples", + "get_jsonl_text_samples", "get_max_batch_size", "get_supported_datasets", ] -def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]: +def get_jsonl_text_samples(jsonl_path: str, num_samples: int, key: str = "text") -> list[str]: """Load up to ``num_samples`` entries from a JSONL file using the ``text`` field. Each non-empty line must be a JSON object containing a ``text`` field. @@ -142,12 +147,12 @@ def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]: f"got {type(obj)}." ) - if "text" not in obj: + if key not in obj: raise ValueError( - f"Missing required field 'text' in JSONL file {jsonl_path} at line {line_idx}." + f"Missing required field '{key}' in JSONL file {jsonl_path} at line {line_idx}." ) - samples.append(str(obj["text"])) + samples.append(str(obj[key])) return samples @@ -158,9 +163,7 @@ def _normalize_splits(split: str | list[str]) -> list[str]: def _auto_preprocess_sample( - sample: dict, - dataset_name: str, - tokenizer: "PreTrainedTokenizerBase | None" = None, + sample: dict, dataset_name: str, tokenizer: "PreTrainedTokenizerBase | None" = None ) -> str: """Auto-detect dataset format and preprocess a single sample based on column conventions. @@ -223,7 +226,10 @@ def get_dataset_samples( ``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``. Args: - dataset_name: Name or HuggingFace path of the dataset to load, or a path to a ``.jsonl``/``.jsonl.gz`` file. + dataset_name: Name or HuggingFace path of the dataset to load, a local directory path, + or a path to a ``.jsonl`` file. For local directory paths, the + predefined config from ``SUPPORTED_DATASET_CONFIG`` is matched if the base folder name + matches a registered key (e.g. 
``/hf-local/abisee/cnn_dailymail`` matches ``cnn_dailymail`` key). num_samples: Number of samples to load from the dataset. apply_chat_template: Whether to apply the chat template to the samples (if supported by the dataset). For unregistered datasets with a @@ -240,15 +246,22 @@ def get_dataset_samples( """ # Local JSONL file path support (each line is a JSON object with a `text` field). if dataset_name.endswith(".jsonl"): - return _get_jsonl_text_samples(dataset_name, num_samples) + return get_jsonl_text_samples(dataset_name, num_samples, key="text") from datasets import load_dataset + local_dataset_path = None + if os.path.exists(dataset_name): # Local path + local_dataset_path = dataset_name + dataset_name = os.path.basename(os.path.normpath(local_dataset_path)) + is_registered = dataset_name in SUPPORTED_DATASET_CONFIG if is_registered: dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name] config = dataset_config["config"].copy() + if local_dataset_path: + config["path"] = local_dataset_path splits = _normalize_splits(split) if split is not None else config.pop("split", [None]) if split is not None: config.pop("split", None) @@ -274,17 +287,18 @@ def _preprocess(sample: dict) -> str: return dataset_config["preprocess"](sample) else: - warn( + print( f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. " "Auto-detecting format from column names." ) - config = {"path": dataset_name} + config = {"path": local_dataset_path or dataset_name} splits = _normalize_splits(split) if split is not None else ["train"] def _preprocess(sample: dict) -> str: return _auto_preprocess_sample(sample, dataset_name, tokenizer) # load_dataset does not support a list of splits while streaming, so load each separately. + print(f"Loading dataset with {config=} and {splits=}") dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits] num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits) @@ -649,3 +663,89 @@ def create_forward_loop( def model_type_is_enc_dec(model): enc_dec_model_list = ["t5", "bart", "whisper"] return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list) + + +def download_hf_dataset_as_jsonl( + dataset_name: str, + output_dir: str | Path, + json_keys: list[str] = ["text"], + name: str | None = None, + split: str | None = "train", + max_samples_per_split: int | None = None, +) -> list[str]: + """Download a Hugging Face dataset and save as JSONL files. + + Args: + dataset_name: Name or HuggingFace path of the dataset to download + output_dir: Directory to save the JSONL files + json_keys: List of keys to extract from the dataset. Defaults to ["text"]. + name: Name of the subset to download + split: Split of the dataset to download. Defaults to "train". + max_samples_per_split: Maximum number of samples to download per split. Defaults to None. + + Returns: + List of paths to downloaded JSONL files. 
+ """ + from datasets import load_dataset + from huggingface_hub.utils import build_hf_headers + + print(f"Downloading dataset {dataset_name} from Hugging Face") + jsonl_paths: list[str] = [] + + try: + response = requests.get( + f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}", + headers=build_hf_headers(), + timeout=10, + ) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError(f"Failed to fetch dataset splits for {dataset_name}: {e}") from e + + response_json = response.json() + print(f"\nFound {len(response_json['splits'])} total splits for {dataset_name}:") + for entry in response_json["splits"]: + print(f"\t{entry}") + + splits_to_process = [] + for entry in response_json["splits"]: + if name is not None and name != entry.get("config", None): + continue + if split is not None and split != entry["split"]: + continue + splits_to_process.append(entry) + + print(f"\nFound {len(splits_to_process)} splits to process:") + for entry in splits_to_process: + print(f"\t{entry}") + + for entry in splits_to_process: + skip_processing = False + path = entry["dataset"] + name = entry.get("config", None) + split = entry["split"] + if max_samples_per_split is not None: + split = f"{split}[:{max_samples_per_split}]" + jsonl_file_path = f"{output_dir}/{path.replace('/', '--')}_{name}_{split}.jsonl" + + print(f"\nLoading HF dataset {path=}, {name=}, {split=}") + if os.path.exists(jsonl_file_path): + jsonl_paths.append(jsonl_file_path) + print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists") + continue + ds = load_dataset(path=path, name=name, split=split) + + for key in json_keys: + if key not in ds.features: + warn(f"[SKIP] {key=} not found in {ds.features=}") + skip_processing = True + break + + if skip_processing: + continue + + print(f"Saving raw dataset to {jsonl_file_path}") + ds.to_json(jsonl_file_path) + jsonl_paths.append(jsonl_file_path) + + return jsonl_paths diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index 1c47a38dee..fbe298d5e8 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -56,17 +56,13 @@ import argparse import json import multiprocessing -import os from pathlib import Path -from warnings import warn -import requests -from datasets import load_dataset -from huggingface_hub.utils import build_hf_headers from megatron.core.datasets import indexed_dataset from transformers import AutoTokenizer from modelopt.torch.utils import num2hrb +from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl __all__ = ["megatron_preprocess_data"] @@ -188,82 +184,6 @@ def process_json_file( return final_enc_len -def _download_hf_dataset( - dataset: str, - output_dir: str | Path, - json_keys: list[str], - name: str | None = None, - split: str | None = "train", - max_samples_per_split: int | None = None, -) -> list[str]: - """Download a Hugging Face dataset and save as JSONL files. - - Returns: - List of paths to downloaded JSONL files. 
- """ - print(f"Downloading dataset {dataset} from Hugging Face") - jsonl_paths: list[str] = [] - - try: - response = requests.get( - f"https://datasets-server.huggingface.co/splits?dataset={dataset}", - headers=build_hf_headers(), - timeout=10, - ) - response.raise_for_status() - except requests.RequestException as e: - raise RuntimeError(f"Failed to fetch dataset splits for {dataset}: {e}") from e - - response_json = response.json() - print(f"\nFound {len(response_json['splits'])} total splits for {dataset}:") - for entry in response_json["splits"]: - print(f"\t{entry}") - - splits_to_process = [] - for entry in response_json["splits"]: - if name is not None and name != entry.get("config", None): - continue - if split is not None and split != entry["split"]: - continue - splits_to_process.append(entry) - - print(f"\nFound {len(splits_to_process)} splits to process:") - for entry in splits_to_process: - print(f"\t{entry}") - - for entry in splits_to_process: - skip_processing = False - path = entry["dataset"] - name = entry.get("config", None) - split = entry["split"] - if max_samples_per_split is not None: - split = f"{split}[:{max_samples_per_split}]" - jsonl_file_path = f"{output_dir}/raw/{path.replace('/', '--')}_{name}_{split}.jsonl" - - print(f"\nLoading HF dataset {path=}, {name=}, {split=}") - if os.path.exists(jsonl_file_path): - jsonl_paths.append(jsonl_file_path) - print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists") - continue - ds = load_dataset(path=path, name=name, split=split) - - for key in json_keys: - if key not in ds.features: - warn(f"[SKIP] {key=} not found in {ds.features=}") - skip_processing = True - break - - if skip_processing: - continue - - print(f"Saving raw dataset to {jsonl_file_path}") - ds.to_json(jsonl_file_path) - jsonl_paths.append(jsonl_file_path) - - print(f"\n\nTokenizing JSONL paths: {jsonl_paths}\n") - return jsonl_paths - - def megatron_preprocess_data( *, input_dir: str | Path | None = None, @@ -309,14 +229,15 @@ def megatron_preprocess_data( ) if hf_dataset is not None: - jsonl_paths = _download_hf_dataset( + jsonl_paths = download_hf_dataset_as_jsonl( hf_dataset, - output_dir, + f"{output_dir}/raw", json_keys, name=hf_name, split=hf_split, max_samples_per_split=hf_max_samples_per_split, ) + print(f"\n\nTokenizing downloaded JSONL files: {jsonl_paths}\n") if input_dir is not None: file_names = sorted(Path(input_dir).glob("*.jsonl")) @@ -338,7 +259,7 @@ def megatron_preprocess_data( num_tokens = partition.process_json_file(name, output_dir, encoder) final_enc_len += num_tokens - print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}") + print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}\nDone!") def main(): diff --git a/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py index de6bc71818..a0c3b9c519 100644 --- a/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py +++ b/tests/gpu_megatron/torch/utils/plugins/test_megatron_preprocess_data.py @@ -17,32 +17,10 @@ import os from pathlib import Path -from datasets import load_dataset - +from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl from modelopt.torch.utils.plugins.megatron_preprocess_data import megatron_preprocess_data -def download_and_prepare_minipile_dataset(output_dir: Path) -> Path: - """Download the nanotron/minipile_100_samples dataset and convert to JSONL format. 
- - Args: - output_dir: Directory to save the JSONL file - - Returns: - Path to the created JSONL file - """ - dataset = load_dataset("nanotron/minipile_100_samples", split="train") - - jsonl_file = output_dir / "minipile_100_samples.jsonl" - - with open(jsonl_file, "w", encoding="utf-8") as f: - for item in dataset: - json_obj = {"text": item["text"]} - f.write(json.dumps(json_obj) + "\n") - - return jsonl_file - - def test_megatron_preprocess_data_with_minipile_jsonl(tmp_path): """Test megatron_preprocess_data with nanotron/minipile_100_samples dataset. @@ -52,9 +30,10 @@ def test_megatron_preprocess_data_with_minipile_jsonl(tmp_path): 3. Calls megatron_preprocess_data with jsonl_paths 4. Verifies that output files are created """ - input_jsonl = download_and_prepare_minipile_dataset(tmp_path) + input_jsonl = download_hf_dataset_as_jsonl("nanotron/minipile_100_samples", tmp_path / "raw") + assert len(input_jsonl) == 1, "Expected 1 JSONL file" + input_jsonl = Path(input_jsonl[0]) - assert input_jsonl.exists(), "Input JSONL file should exist" assert input_jsonl.stat().st_size > 0, "Input JSONL file should not be empty" with open(input_jsonl, encoding="utf-8") as f: @@ -71,7 +50,7 @@ def test_megatron_preprocess_data_with_minipile_jsonl(tmp_path): workers=1, ) - output_prefix = tmp_path / "minipile_100_samples" + output_prefix = tmp_path / "nanotron--minipile_100_samples_default_train" expected_bin_file = f"{output_prefix}_text_document.bin" expected_idx_file = f"{output_prefix}_text_document.idx" diff --git a/tests/unit/torch/utils/test_dataset_utils.py b/tests/unit/torch/utils/test_dataset_utils.py index f48643fb2c..b0d034025c 100644 --- a/tests/unit/torch/utils/test_dataset_utils.py +++ b/tests/unit/torch/utils/test_dataset_utils.py @@ -18,7 +18,7 @@ import pytest import torch -from modelopt.torch.utils.dataset_utils import _process_batch +from modelopt.torch.utils.dataset_utils import _process_batch, get_dataset_samples def setup_test_data(): @@ -101,3 +101,27 @@ def mock_infer_collect(**kwargs): # Verify all values were processed in the correct order assert processed_values == [0, 1, 2, 3] + + +@pytest.mark.parametrize("test_local_path", [True, False]) +def test_get_dataset_samples_with_unsupported_minipile_dataset(tmp_path, test_local_path): + pytest.importorskip("datasets") + pytest.importorskip("huggingface_hub") + + from huggingface_hub import snapshot_download + + dataset_name = "nanotron/minipile_100_samples" + if test_local_path: + local_dir = str(tmp_path / dataset_name) + snapshot_download( + repo_id=dataset_name, + repo_type="dataset", + local_dir=local_dir, + ) + dataset_name = local_dir + + samples = get_dataset_samples(dataset_name, num_samples=5) + + assert isinstance(samples, list) + assert len(samples) == 5 + assert all(isinstance(s, str) and len(s) > 0 for s in samples)
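
A minimal usage sketch (not part of the patch above) of the two public helpers this diff adds or re-exports from modelopt.torch.utils.dataset_utils; the dataset name mirrors the tests in this PR, and the output directory and sample counts are illustrative assumptions:

    # Sketch only: exercises download_hf_dataset_as_jsonl, get_jsonl_text_samples,
    # and the auto-detection fallback in get_dataset_samples, per the signatures
    # shown in this diff. "/tmp/raw" and the sample counts are placeholder values.
    from modelopt.torch.utils.dataset_utils import (
        download_hf_dataset_as_jsonl,
        get_dataset_samples,
        get_jsonl_text_samples,
    )

    # Download the "train" split of a HF dataset as JSONL files.
    jsonl_paths = download_hf_dataset_as_jsonl(
        "nanotron/minipile_100_samples",
        output_dir="/tmp/raw",
        max_samples_per_split=100,
    )

    # Read a few samples back from the first JSONL file via the "text" key.
    texts = get_jsonl_text_samples(jsonl_paths[0], num_samples=5, key="text")

    # Unregistered dataset names (or local snapshot paths) now fall back to
    # column-name auto-detection instead of being rejected.
    samples = get_dataset_samples("nanotron/minipile_100_samples", num_samples=5)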