6 changes: 4 additions & 2 deletions examples/megatron_bridge/prune_minitron.py
@@ -89,8 +89,10 @@ def get_args() -> argparse.Namespace:
"--calib_dataset_name",
type=str,
default="nemotron-post-training-dataset-v2",
choices=get_supported_datasets(),
help="Dataset name for calibration",
help=(
f"HF Dataset name or local path for calibration (supported options: {', '.join(get_supported_datasets())}. "
"You can also pass any other dataset and see if auto-detection for your dataset works."
),
)
parser.add_argument(
"--calib_num_samples", type=int, default=1024, help="Number of samples for calibration"
128 changes: 114 additions & 14 deletions modelopt/torch/utils/dataset_utils.py
@@ -17,10 +17,13 @@

import copy
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING, Any
from warnings import warn

import requests
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -48,9 +51,9 @@
"name": "SFT",
"split": ["code", "math", "science", "chat", "safety"],
},
"preprocess": lambda sample: "\n".join(turn["content"] for turn in sample["input"])
+ "\n"
+ sample["output"],
"preprocess": lambda sample: (
"\n".join(turn["content"] for turn in sample["input"]) + "\n" + sample["output"]
),
},
"nemotron-post-training-dataset-v2": {
"config": {
@@ -104,14 +107,16 @@

__all__ = [
"create_forward_loop",
"download_hf_dataset_as_jsonl",
"get_dataset_dataloader",
"get_dataset_samples",
"get_jsonl_text_samples",
"get_max_batch_size",
"get_supported_datasets",
]


def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
def get_jsonl_text_samples(jsonl_path: str, num_samples: int, key: str = "text") -> list[str]:
"""Load up to ``num_samples`` entries from a JSONL file using the ``text`` field.

Each non-empty line must be a JSON object containing a ``text`` field.
@@ -142,12 +147,12 @@ def _get_jsonl_text_samples(jsonl_path: str, num_samples: int) -> list[str]:
f"got {type(obj)}."
)

if "text" not in obj:
if key not in obj:
raise ValueError(
f"Missing required field 'text' in JSONL file {jsonl_path} at line {line_idx}."
f"Missing required field '{key}' in JSONL file {jsonl_path} at line {line_idx}."
)

samples.append(str(obj["text"]))
samples.append(str(obj[key]))

return samples
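
Since ``_get_jsonl_text_samples`` is now public and accepts a ``key`` argument, a small usage sketch may help. The import path follows the file shown in this diff; the JSONL contents are made up:

```python
# Sketch only: write a tiny JSONL file, then read it back with a non-default key.
import json
import tempfile

from modelopt.torch.utils.dataset_utils import get_jsonl_text_samples

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for i in range(4):
        f.write(json.dumps({"prompt": f"sample {i}"}) + "\n")

samples = get_jsonl_text_samples(f.name, num_samples=2, key="prompt")
print(samples)  # expected: ['sample 0', 'sample 1']
```
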

@@ -158,9 +163,7 @@ def _normalize_splits(split: str | list[str]) -> list[str]:


def _auto_preprocess_sample(
sample: dict,
dataset_name: str,
tokenizer: "PreTrainedTokenizerBase | None" = None,
sample: dict, dataset_name: str, tokenizer: "PreTrainedTokenizerBase | None" = None
) -> str:
"""Auto-detect dataset format and preprocess a single sample based on column conventions.

@@ -223,7 +226,10 @@ def get_dataset_samples(
``messages``/``conversations`` (chat), ``prompt``, ``text``, or ``input``.

Args:
dataset_name: Name or HuggingFace path of the dataset to load, or a path to a ``.jsonl``/``.jsonl.gz`` file.
dataset_name: Name or HuggingFace path of the dataset to load, a local directory path,
or a path to a ``.jsonl`` file. For local directory paths, the
predefined config from ``SUPPORTED_DATASET_CONFIG`` is used if the base folder name
matches a registered key (e.g. ``/hf-local/abisee/cnn_dailymail`` matches the ``cnn_dailymail`` key).
num_samples: Number of samples to load from the dataset.
apply_chat_template: Whether to apply the chat template to the samples
(if supported by the dataset). For unregistered datasets with a
@@ -240,15 +246,22 @@
"""
# Local JSONL file path support (each line is a JSON object with a `text` field).
if dataset_name.endswith(".jsonl"):
return _get_jsonl_text_samples(dataset_name, num_samples)
return get_jsonl_text_samples(dataset_name, num_samples, key="text")

from datasets import load_dataset

local_dataset_path = None
if os.path.exists(dataset_name): # Local path
local_dataset_path = dataset_name
dataset_name = os.path.basename(os.path.normpath(local_dataset_path))

is_registered = dataset_name in SUPPORTED_DATASET_CONFIG

if is_registered:
dataset_config = SUPPORTED_DATASET_CONFIG[dataset_name]
config = dataset_config["config"].copy()
if local_dataset_path:
config["path"] = local_dataset_path
splits = _normalize_splits(split) if split is not None else config.pop("split", [None])
if split is not None:
config.pop("split", None)
@@ -274,17 +287,18 @@ def _preprocess(sample: dict) -> str:
return dataset_config["preprocess"](sample)

else:
warn(
print(
f"Dataset '{dataset_name}' is not in SUPPORTED_DATASET_CONFIG. "
"Auto-detecting format from column names."
)
config = {"path": dataset_name}
config = {"path": local_dataset_path or dataset_name}
splits = _normalize_splits(split) if split is not None else ["train"]

def _preprocess(sample: dict) -> str:
return _auto_preprocess_sample(sample, dataset_name, tokenizer)

# load_dataset does not support a list of splits while streaming, so load each separately.
print(f"Loading dataset with {config=} and {splits=}")
dataset_splits = [load_dataset(streaming=True, **config, split=s) for s in splits]

num_per_split = [num_samples // len(dataset_splits)] * len(dataset_splits)
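
A hedged sketch of how the new local-path branch could be exercised. The directory below is hypothetical; since its base folder name matches the registered ``cnn_dailymail`` key, the predefined config is reused with ``path`` pointed at the local copy, while an unrecognized folder name would fall back to column-based auto-detection:

```python
from modelopt.torch.utils.dataset_utils import get_dataset_samples

# Hypothetical local snapshot whose base folder name matches a registered key.
samples = get_dataset_samples(
    "/hf-local/abisee/cnn_dailymail",
    num_samples=64,
)
print(f"Loaded {len(samples)} calibration samples")
```
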
@@ -649,3 +663,89 @@ def create_forward_loop(
def model_type_is_enc_dec(model):
enc_dec_model_list = ["t5", "bart", "whisper"]
return any(model_name in model.__class__.__name__.lower() for model_name in enc_dec_model_list)


def download_hf_dataset_as_jsonl(
dataset_name: str,
output_dir: str | Path,
json_keys: list[str] = ["text"],
name: str | None = None,
split: str | None = "train",
max_samples_per_split: int | None = None,
) -> list[str]:
"""Download a Hugging Face dataset and save as JSONL files.

Args:
dataset_name: Name or HuggingFace path of the dataset to download
output_dir: Directory to save the JSONL files
json_keys: List of keys to extract from the dataset. Defaults to ["text"].
name: Name of the subset to download
split: Split of the dataset to download. Defaults to "train".
max_samples_per_split: Maximum number of samples to download per split. Defaults to None.

Returns:
List of paths to downloaded JSONL files.
"""
from datasets import load_dataset
from huggingface_hub.utils import build_hf_headers

print(f"Downloading dataset {dataset_name} from Hugging Face")
jsonl_paths: list[str] = []

try:
response = requests.get(
f"https://datasets-server.huggingface.co/splits?dataset={dataset_name}",
headers=build_hf_headers(),
timeout=10,
)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch dataset splits for {dataset_name}: {e}") from e

response_json = response.json()
print(f"\nFound {len(response_json['splits'])} total splits for {dataset_name}:")
for entry in response_json["splits"]:
print(f"\t{entry}")

splits_to_process = []
for entry in response_json["splits"]:
if name is not None and name != entry.get("config", None):
continue
if split is not None and split != entry["split"]:
continue
splits_to_process.append(entry)

print(f"\nFound {len(splits_to_process)} splits to process:")
for entry in splits_to_process:
print(f"\t{entry}")

for entry in splits_to_process:
skip_processing = False
path = entry["dataset"]
name = entry.get("config", None)
split = entry["split"]
if max_samples_per_split is not None:
split = f"{split}[:{max_samples_per_split}]"
jsonl_file_path = f"{output_dir}/{path.replace('/', '--')}_{name}_{split}.jsonl"

print(f"\nLoading HF dataset {path=}, {name=}, {split=}")
if os.path.exists(jsonl_file_path):
jsonl_paths.append(jsonl_file_path)
print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists")
continue
ds = load_dataset(path=path, name=name, split=split)

for key in json_keys:
if key not in ds.features:
warn(f"[SKIP] {key=} not found in {ds.features=}")
skip_processing = True
break

if skip_processing:
continue

print(f"Saving raw dataset to {jsonl_file_path}")
ds.to_json(jsonl_file_path)
jsonl_paths.append(jsonl_file_path)

return jsonl_paths
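
A hedged usage sketch for the new ``download_hf_dataset_as_jsonl`` helper. It first queries the datasets-server ``/splits`` endpoint (whose response, per the code above, carries a ``splits`` list of ``dataset``/``config``/``split`` entries) and then saves each selected split as JSONL. The dataset id and column name below are assumptions; network access is required, and with ``name=None`` every config of the dataset is downloaded:

```python
import os

from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl

os.makedirs("./calib_data", exist_ok=True)  # created up front; the helper writes files directly into output_dir

jsonl_paths = download_hf_dataset_as_jsonl(
    "abisee/cnn_dailymail",          # assumed example dataset id
    output_dir="./calib_data",
    json_keys=["article"],           # must exist in ds.features, or the split is skipped
    split="train",
    max_samples_per_split=1000,
)
print(jsonl_paths)
```
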
89 changes: 5 additions & 84 deletions modelopt/torch/utils/plugins/megatron_preprocess_data.py
@@ -56,17 +56,13 @@
import argparse
import json
import multiprocessing
import os
from pathlib import Path
from warnings import warn

import requests
from datasets import load_dataset
from huggingface_hub.utils import build_hf_headers
from megatron.core.datasets import indexed_dataset
from transformers import AutoTokenizer

from modelopt.torch.utils import num2hrb
from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl

__all__ = ["megatron_preprocess_data"]

@@ -188,82 +184,6 @@ def process_json_file(
return final_enc_len


def _download_hf_dataset(
dataset: str,
output_dir: str | Path,
json_keys: list[str],
name: str | None = None,
split: str | None = "train",
max_samples_per_split: int | None = None,
) -> list[str]:
"""Download a Hugging Face dataset and save as JSONL files.

Returns:
List of paths to downloaded JSONL files.
"""
print(f"Downloading dataset {dataset} from Hugging Face")
jsonl_paths: list[str] = []

try:
response = requests.get(
f"https://datasets-server.huggingface.co/splits?dataset={dataset}",
headers=build_hf_headers(),
timeout=10,
)
response.raise_for_status()
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch dataset splits for {dataset}: {e}") from e

response_json = response.json()
print(f"\nFound {len(response_json['splits'])} total splits for {dataset}:")
for entry in response_json["splits"]:
print(f"\t{entry}")

splits_to_process = []
for entry in response_json["splits"]:
if name is not None and name != entry.get("config", None):
continue
if split is not None and split != entry["split"]:
continue
splits_to_process.append(entry)

print(f"\nFound {len(splits_to_process)} splits to process:")
for entry in splits_to_process:
print(f"\t{entry}")

for entry in splits_to_process:
skip_processing = False
path = entry["dataset"]
name = entry.get("config", None)
split = entry["split"]
if max_samples_per_split is not None:
split = f"{split}[:{max_samples_per_split}]"
jsonl_file_path = f"{output_dir}/raw/{path.replace('/', '--')}_{name}_{split}.jsonl"

print(f"\nLoading HF dataset {path=}, {name=}, {split=}")
if os.path.exists(jsonl_file_path):
jsonl_paths.append(jsonl_file_path)
print(f"\t[SKIP] Raw dataset {jsonl_file_path} already exists")
continue
ds = load_dataset(path=path, name=name, split=split)

for key in json_keys:
if key not in ds.features:
warn(f"[SKIP] {key=} not found in {ds.features=}")
skip_processing = True
break

if skip_processing:
continue

print(f"Saving raw dataset to {jsonl_file_path}")
ds.to_json(jsonl_file_path)
jsonl_paths.append(jsonl_file_path)

print(f"\n\nTokenizing JSONL paths: {jsonl_paths}\n")
return jsonl_paths


def megatron_preprocess_data(
*,
input_dir: str | Path | None = None,
@@ -309,14 +229,15 @@ def megatron_preprocess_data(
)

if hf_dataset is not None:
jsonl_paths = _download_hf_dataset(
jsonl_paths = download_hf_dataset_as_jsonl(
hf_dataset,
output_dir,
f"{output_dir}/raw",
json_keys,
name=hf_name,
split=hf_split,
max_samples_per_split=hf_max_samples_per_split,
)
print(f"\n\nTokenizing downloaded JSONL files: {jsonl_paths}\n")

if input_dir is not None:
file_names = sorted(Path(input_dir).glob("*.jsonl"))
@@ -338,7 +259,7 @@ def megatron_preprocess_data(
num_tokens = partition.process_json_file(name, output_dir, encoder)
final_enc_len += num_tokens

print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}")
print(f"\n\n>>> Total number of tokens currently processed: {num2hrb(final_enc_len)}\nDone!")


def main():
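
Finally, an outline of the HF-download path that ``megatron_preprocess_data`` now delegates to the shared helper: raw JSONL files land under ``<output_dir>/raw`` and are then handed to the tokenization step. The dataset id is an assumed example with a ``text`` column, and tokenizer/worker settings are deliberately omitted, so treat this as an outline rather than an end-to-end command:

```python
import os

from modelopt.torch.utils.dataset_utils import download_hf_dataset_as_jsonl

output_dir = "./preprocessed"
os.makedirs(f"{output_dir}/raw", exist_ok=True)

jsonl_paths = download_hf_dataset_as_jsonl(
    "Salesforce/wikitext",   # assumed example; any dataset exposing the json_keys columns works
    f"{output_dir}/raw",     # mirrors the call site added in this diff
    json_keys=["text"],
    split="train",
    max_samples_per_split=1000,
)
print(f"Tokenizing downloaded JSONL files: {jsonl_paths}")
# megatron_preprocess_data would now tokenize each of these files, plus any
# *.jsonl found under input_dir when one is given.
```
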