From b20a4d9937bdb1a94835461b744c022afda55de3 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 30 Apr 2026 02:57:25 -0700 Subject: [PATCH] Add Nemotron-3-Nano-30B-A3B-BF16 e2e pruning tutorial and update Nemotron-Nano-9B-v2 docs Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- examples/dataset/MEGATRON_DATA_PREP.md | 28 ++- .../README.md | 164 ++++++++++++++ .../nemo_evaluator.yaml | 212 ++++++++++++++++++ .../NVIDIA-Nemotron-Nano-9B-v2/README.md | 2 +- .../nemo_evaluator.yaml | 2 - .../utils/plugins/megatron_preprocess_data.py | 38 +++- 6 files changed, 438 insertions(+), 8 deletions(-) create mode 100644 examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md create mode 100644 examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml diff --git a/examples/dataset/MEGATRON_DATA_PREP.md b/examples/dataset/MEGATRON_DATA_PREP.md index c3904d2a0f..7d9ad60e79 100644 --- a/examples/dataset/MEGATRON_DATA_PREP.md +++ b/examples/dataset/MEGATRON_DATA_PREP.md @@ -97,8 +97,8 @@ Tokenization commands for all Nemotron Pre-Training and Post-Training datasets u Two parameters vary by model — set them before running the commands below: ```bash -TOKENIZER=nvidia/NVIDIA-Nemotron-Nano-9B-v2 # HuggingFace tokenizer (or local path) -OUTPUT_DIR=tokenized_nemotron_v2 # Output directory for tokenized files +TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 # HuggingFace tokenizer (or local path) +OUTPUT_DIR=tokenized_nemotron_3 # Output directory for tokenized files ``` > [!TIP] @@ -154,13 +154,14 @@ python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ Datasets below are from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3). All use `--reasoning_content inline` to preserve `` traces. The collection contains many more datasets — if you care about benchmarks not covered here (e.g. multilingual, agentic/tool use, SWE, safety), pick the relevant datasets from the collection and tokenize them the same way. -**[nvidia/Nemotron-Math-v2](https://huggingface.co/datasets/nvidia/Nemotron-Math-v2)** — tokenize `high_part00` and `high_part01` separately: +**[nvidia/Nemotron-Math-v2](https://huggingface.co/datasets/nvidia/Nemotron-Math-v2)** — tokenize `high_part00` and `high_part01` separately. `--hf_streaming` is required because the messages contain extra fields (e.g. `tool_calls`) that cause Arrow type-cast errors in non-streaming mode when using tokenizers with complex chat templates (such as Nemotron v3): ```bash for SPLIT in high_part00 high_part01; do python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ --hf_dataset nvidia/Nemotron-Math-v2 \ --hf_split ${SPLIT} \ + --hf_streaming \ --json_keys messages \ --tokenizer ${TOKENIZER} \ --output_dir ${OUTPUT_DIR} \ @@ -170,6 +171,26 @@ for SPLIT in high_part00 high_part01; do done ``` +**[nvidia/Nemotron-SFT-Math-v3](https://huggingface.co/datasets/nvidia/Nemotron-SFT-Math-v3)** — stored as raw JSONL on HuggingFace, download before tokenizing (more reliable than streaming for this dataset due to complex nested `tool_calls` fields): + +```bash +hf download nvidia/Nemotron-SFT-Math-v3 \ + --repo-type dataset \ + --local-dir datasets/Nemotron-SFT-Math-v3/ +python -m modelopt.torch.utils.plugins.megatron_preprocess_data \ + --jsonl_paths datasets/Nemotron-SFT-Math-v3/data/train.jsonl \ + --json_keys messages \ + --tokenizer ${TOKENIZER} \ + --output_dir ${OUTPUT_DIR} \ + --workers 96 \ + --max_sequence_length 256_000 \ + --reasoning_content inline + +# Rename to avoid generic file name +mv ${OUTPUT_DIR}/train_messages.bin ${OUTPUT_DIR}/nvidia--Nemotron-SFT-Math-v3_default_train_messages.bin +mv ${OUTPUT_DIR}/train_messages.idx ${OUTPUT_DIR}/nvidia--Nemotron-SFT-Math-v3_default_train_messages.idx +``` + **[nvidia/Nemotron-SFT-Competitive-Programming-v2](https://huggingface.co/datasets/nvidia/Nemotron-SFT-Competitive-Programming-v2)** — stored as raw JSONL on HuggingFace, download before tokenizing: ```bash @@ -233,6 +254,7 @@ nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000.{bi nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000.{bin,idx} nvidia--Nemotron-Math-v2_default_high_part00_messages.{bin,idx} nvidia--Nemotron-Math-v2_default_high_part01_messages.{bin,idx} +nvidia--Nemotron-SFT-Math-v3_default_train_messages.{bin,idx} competitive_programming_python_00_messages.{bin,idx} competitive_programming_cpp_00_messages.{bin,idx} MCQ_messages.{bin,idx} diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md new file mode 100644 index 0000000000..a73576e6e3 --- /dev/null +++ b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/README.md @@ -0,0 +1,164 @@ +# Nemotron-3-Nano-30B-A3B: Prune + Distill + Quantize + vLLM Deployment + +End-to-end optimization of [NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16) demonstrating how ModelOpt techniques stack: Minitron structured pruning → Megatron-Bridge knowledge distillation to recover accuracy → FP8 quantization → vLLM deployment and throughput benchmarking. This document covers: + +1. **[Data Preparation](#1-data-preparation)** — tokenizing the training blend for distillation +2. **[Pruning](#2-pruning)** — Minitron structured pruning +3. **[Distillation](#3-distillation)** — recovering accuracy via Megatron-Bridge knowledge distillation +4. **[Evaluation](#4-evaluation)** — benchmarking with NeMo Evaluator across MMLU Pro, GPQA Diamond, AIME, and more +5. **[Quantization](#5-quantization)** — FP8 PTQ on the distilled checkpoint using ModelOpt's `examples/llm_ptq/hf_ptq.py` script +6. **[vLLM Inference Benchmarking](#6-vllm-inference-benchmarking)** — throughput comparison of BF16 vs FP8 on a single H100 + +**Environment:** Container `nvcr.io/nvidia/nemo:26.02`, ModelOpt 0.45.0. See the [Megatron-Bridge README](../../../megatron_bridge/README.md) for environment setup (including ModelOpt mount path) and container usage. + +## Results + +TODO + +--- + +## Steps to Reproduce + +### 1. Data Preparation + +See [examples/dataset/MEGATRON_DATA_PREP.md](../../../dataset/MEGATRON_DATA_PREP.md) for tokenization commands for all datasets used in this blend. + +For this experiment: `TOKENIZER=nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16`, `OUTPUT_DIR=tokenized_nemotron_3`. + +> [!NOTE] +> Compared to experiments in [NVIDIA-Nemotron-Nano-9B-v2](../NVIDIA-Nemotron-Nano-9B-v2/README.md), we use `Nemotron-SFT-Math-v3` instead of `Nemotron-Math-v2 / high_part01` since it is higher quality with full reasoning traces. + +#### Data Blend + +**30% Pretraining (Code 5, General 20, MATH 5) + 70% Post-training v1/v3 (Math 30, Coding 20, Science 15, IF 5)** + +```bash +DATA_BLEND=" \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-Code_train_text_max10000000 \ +20 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-General_train_text_max10000000 \ +5 tokenized_nemotron_3/nvidia--Nemotron-Pretraining-SFT-v1_Nemotron-SFT-MATH_train_text_max10000000 \ +10 tokenized_nemotron_3/nvidia--Nemotron-Math-v2_default_high_part00_messages \ +20 tokenized_nemotron_3/nvidia--Nemotron-SFT-Math-v3_default_train_messages \ +15 tokenized_nemotron_3/competitive_programming_python_00_messages \ +5 tokenized_nemotron_3/competitive_programming_cpp_00_messages \ +10 tokenized_nemotron_3/nvidia--Nemotron-Post-Training-Dataset-v1_default_stem_messages_max5000000 \ +3 tokenized_nemotron_3/MCQ_messages \ +2 tokenized_nemotron_3/RQA_messages \ +3 tokenized_nemotron_3/reasoning_on_messages \ +2 tokenized_nemotron_3/reasoning_off_messages \ +" +``` + +| Dataset | Tokens | Weight | Notes | +| ----------------------------------------------------- | ------ | ------ | ---------------------------------------------- | +| Nemotron-Pretraining-SFT-v1 / Code (10M samples) | 7B | 5 | Pretraining code | +| Nemotron-Pretraining-SFT-v1 / General (10M samples) | 16B | 20 | Upweighted to close MMLU gap | +| Nemotron-Pretraining-SFT-v1 / MATH (10M samples) | 13B | 5 | Pretraining math | +| Nemotron-Math-v2 / high_part00 | 13B | 10 | Hard math reasoning | +| Nemotron-SFT-Math-v3 / train | 52B | 20 | Hard math reasoning with full reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / python_00 | 7B | 15 | Python reasoning traces | +| Nemotron-SFT-Competitive-Programming-v2 / cpp_00 | 7B | 5 | C++ reasoning traces | +| Nemotron-Post-Training-Dataset-v1 / stem (5M samples) | 22B | 10 | Broad STEM | +| Nemotron-Science-v1 / MCQ | 0.5B | 3 | GPQA MCQ format alignment | +| Nemotron-Science-v1 / RQA | 0.3B | 2 | GPQA format diversity | +| Nemotron-SFT-IF-Chat-v2 / reasoning_on | 2B | 3 | Instruction following (thinking on) | +| Nemotron-SFT-IF-Chat-v2 / reasoning_off | 1B | 2 | Instruction following (thinking off) | + +#### General Guidelines + +The optimal blend is 30% pretraining and 70% post-training data. Exact proportions may vary depending on the benchmarks you care about. The blend above was designed to maximize recovery on popular General Knowledge, Reasoning, Instruction Following, and Tool Calling benchmarks. The key design decisions were: + +- **30% pretraining data** closes the MMLU gap that arises from training exclusively on reasoning-heavy post-training data. The General split (20%) is upweighted specifically to recover general knowledge recall. +- **Math (30%)** is the largest post-training category because AIME and MMLU Pro respond strongly to more math reasoning tokens. We use a mix of `Nemotron-Math-v2` and `Nemotron-SFT-Math-v3` for higher quality math reasoning signal with full reasoning traces. +- **Science (15%)** uses `Nemotron-Post-Training-Dataset-v1 / stem` as the primary source for volume and GPQA stability, with small allocations to `Nemotron-Science-v1` MCQ/RQA subsets for format alignment with GPQA's multiple-choice structure. +- **Instruction following (5%)** saturates quickly so a small allocation is sufficient. + +This blend intentionally omits capabilities not targeted in this experiment (e.g. long context and multilingual benchmarks). Depending on what benchmarks matter for your use case, you can substitute or add datasets from the [Nemotron Post-Training v3 collection](https://huggingface.co/collections/nvidia/nemotron-post-training-v3), for example: + +| Capability | Relevant datasets | +| --- | --- | +| Multilingual | `Nemotron-SFT-Multilingual-v1` | +| Agentic / tool use | `Nemotron-SFT-Tool-Call-v1`, `Nemotron-SFT-Tool-Call-v2` | +| Software engineering (SWE) | `Nemotron-SFT-SWE-v1` | +| Safety / alignment | `Nemotron-SFT-Safety-v1` | +| Long context | `Nemotron-SFT-Long-Context-v1` | + +When adding new datasets, reduce weights of lower-priority categories proportionally to keep the total at 100%. + +--- + +### 2. Pruning + +TODO + +--- + +### 3. Distillation + +TODO + +--- + +### 4. Evaluation + +The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/). + +Before running, update the following fields in the yaml: + +- `execution.hostname` — your Slurm login node hostname +- `execution.account` — your Slurm account +- `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned or quantized) +- `evaluation.nemo_evaluator_config.config.params.extra.tokenizer` — same path as `checkpoint_path` + +Set the required environment variables and run: + +> [!TIP] +> Uncomment `limit_samples` under any task to run a small subset and verify the end-to-end eval pipeline before launching full evals. + +```bash +pip install "nemo-evaluator-launcher[all]==0.1.90" + +# Required environment variables +export HF_TOKEN= +export SLURM_JOB_DIR= +export HF_HOME= +export VLLM_CACHE_ROOT= + +# Additional unused but required environment variables +export API_KEY=xxxxxx +export INFERENCE_API_KEY=xxxxxx +export OPENAI_CLIENT_ID=xxxxxx +export OPENAI_CLIENT_SECRET=xxxxxx + +nemo-evaluator-launcher run --config nemo_evaluator.yaml +``` + +**Tasks and exact metric names reported in the results table:** + +| Benchmark | Tool | Metric name | +| --- | --- | --- | +| MMLU | [lm-evaluation-harness](https://github.com/EleutherAI/lm-evaluation-harness) (5-shot) | `mmlu` | +| MMLU Pro | NeMo Evaluator | `mmlu-pro_pass_at_1_symbolic_correct` | +| GPQA Diamond | NeMo Evaluator | `gpqa_pass_at_1_symbolic_correct` | +| LiveCodeBench v6 | NeMo Evaluator | `livecodebench_pass_at_1_accuracy` | +| AIME 2025 | NeMo Evaluator | `aime25_pass_at_1_symbolic_correct` | +| IFBench | NeMo Evaluator | `ifbench_pass_at_1_average_score` | +| SciCode (Subtask) | NeMo Evaluator | `scicode_pass_at_1_subtask_accuracy` | +| BFCL v3 | NeMo Evaluator | `bfcl_v3_overall_accuracy_accuracy` | +| BFCL v4 | NeMo Evaluator | `bfcl_v4_overall_accuracy_accuracy` | + +**Key vLLM settings:** Tool calling is enabled via `--enable-auto-tool-choice --tool-call-parser qwen3_coder`. + +For more details on NeMo Evaluator, see the [GitHub repo](https://github.com/NVIDIA-NeMo/evaluator) and [documentation](https://docs.nvidia.com/nemo/evaluator/latest/). + +--- + +### 5. Quantization + +TODO + +--- + +### 6. vLLM Inference Benchmarking + +TODO diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml new file mode 100644 index 0000000000..4c7d78a863 --- /dev/null +++ b/examples/pruning/minitron/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/nemo_evaluator.yaml @@ -0,0 +1,212 @@ +# NeMo Evaluator Launcher config for Nemotron-3-Nano-30B-A3B and pruned variants +# ------------------------------------------------------------------------------ +# Before running, update the following fields in the yaml: +# - `execution.hostname` — your Slurm login node hostname +# - `execution.account` — your Slurm account +# - `deployment.checkpoint_path` — Hugging Face checkpoint path (original, pruned or quantized) +# - `evaluation.nemo_evaluator_config.config.params.extra.tokenizer` — same path as `checkpoint_path` +# +# Usage: +# pip install "nemo-evaluator-launcher[all]==0.1.90" +# +# # Set required environment variables: +# export HF_TOKEN= +# export SLURM_JOB_DIR= +# export HF_HOME= +# export VLLM_CACHE_ROOT= +# +# # Set additional unused but required environment variables: +# export API_KEY=xxxxxx +# export INFERENCE_API_KEY=xxxxxx +# export OPENAI_CLIENT_ID=xxxxxx +# export OPENAI_CLIENT_SECRET=xxxxxx +# +# nemo-evaluator-launcher run --config nemo_evaluator.yaml +# + +defaults: + - execution: slurm/default + - deployment: vllm + - _self_ + +execution: + type: slurm + hostname: + username: ${oc.env:USER} + account: + partition: batch + num_nodes: 1 + ntasks_per_node: 1 + gpus_per_node: 8 + gres: "gpu:8" + walltime: 04:00:00 + sbatch_comment: "{\"OccupiedIdleGPUsJobReaper\":{\"exemptIdleTimeMins\":\"1920\",\"reason\":\"benchmarking\",\"description\":\"Some evals need idle time\ + \ else gets cancelled\"}}" + subproject: nel + output_dir: ${oc.env:SLURM_JOB_DIR} + mode: sequential + + mounts: + mount_home: false + deployment: + n_tasks: 1 + +# Note: Only tp=1 works for Nano (Mamba-based hybrid architecture) +deployment: + # Update this to your Hugging Face checkpoint path (original, pruned or quantized) + checkpoint_path: + served_model_name: Nemotron-3-Nano-30B-A3B + port: 8000 + tensor_parallel_size: 1 + pipeline_parallel_size: 1 + data_parallel_size: 8 + gpu_memory_utilization: 0.90 + extra_args: "--max-model-len 262144 --enable-log-requests --no-enable-prefix-caching --trust-remote-code --mamba_ssm_cache_dtype float32 --enable-auto-tool-choice\ + \ --tool-call-parser qwen3_coder --reasoning-parser-plugin /checkpoint/nano_v3_reasoning_parser.py --reasoning-parser nano_v3" + env_vars: + VLLM_FLASHINFER_MOE_BACKEND: throughput + endpoints: + chat: /v1/chat/completions + completions: /v1/completions + health: /health + multiple_instances: true + +evaluation: + nemo_evaluator_config: + target: + api_endpoint: + adapter_config: + use_system_prompt: true + use_reasoning: false + params_to_add: + chat_template_kwargs: + enable_thinking: true + skip_special_tokens: false + use_caching: true + tracking_requests_stats: true + log_failed_requests: true + use_request_logging: true + max_logged_requests: 10 + use_response_logging: true + max_logged_responses: 10 + config: + params: + parallelism: 64 + max_new_tokens: 131072 + temperature: 0.99999 + top_p: 0.99999 + request_timeout: 3600 + max_retries: 10 + extra: + tokenizer_backend: huggingface + # Update tokenizer path to match checkpoint_path above + tokenizer: + env_vars: + HF_TOKEN: HF_TOKEN + HF_HOME: HF_HOME + VLLM_CACHE_ROOT: VLLM_CACHE_ROOT + API_KEY: API_KEY + INFERENCE_API_KEY: INFERENCE_API_KEY + OPENAI_CLIENT_ID: OPENAI_CLIENT_ID + OPENAI_CLIENT_SECRET: OPENAI_CLIENT_SECRET + + tasks: + # 1. MMLU Pro + - name: ns_mmlu_pro + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 1 + args: "++prompt_config=eval/aai/mcq-10choices-boxed" + + # 2. GPQA Diamond + - name: ns_gpqa + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 8 + args: "++prompt_config=eval/aai/mcq-4choices" + + # 3. LiveCodeBench + - name: ns_livecodebench + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 8 + dataset_split: test_v6_2408_2505 + + # 4. AIME 2025 + - name: ns_aime2025 + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 64 + + # 5. IFBench + - name: ns_ifbench + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 8 + + # 6. SciCode + - name: ns_scicode + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + # limit_samples: 8 + extra: + num_repeats: 8 + + # 7. BFCL v3 — tool calling benchmark (requires --enable-auto-tool-choice in deployment) + - name: ns_bfcl_v3 + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + temperature: 0.6 + top_p: 0.95 + parallelism: 32 + # limit_samples: 8 + extra: + num_repeats: 1 + args: ++use_client_parsing=False + + # 8. BFCL v4 — tool calling benchmark (requires --enable-auto-tool-choice in deployment) + - name: ns_bfcl_v4 + env_vars: + HF_TOKEN: HF_TOKEN + nemo_evaluator_config: + config: + params: + max_new_tokens: 8192 + temperature: 0.6 + top_p: 0.95 + parallelism: 128 + # limit_samples: 8 + extra: + num_repeats: 1 + args: ++use_client_parsing=False diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md b/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md index 620c5780a4..f5135df101 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md +++ b/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/README.md @@ -225,7 +225,7 @@ python /opt/Megatron-Bridge/examples/conversion/convert_checkpoints.py export \ ### 4. Evaluation -The eval config xin [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). +The eval config in [nemo_evaluator.yaml](nemo_evaluator.yaml) is for Slurm-based evaluation — it submits a vLLM serving job and runs evals against it. For local model execution and evaluation, refer to the [NeMo Evaluator documentation](https://docs.nvidia.com/nemo/evaluator/latest/) or this [blog](https://huggingface.co/blog/nvidia/nemotron-3-nano-evaluation-recipe). Before running, update the following fields in the yaml: diff --git a/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/nemo_evaluator.yaml b/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/nemo_evaluator.yaml index 256a4031be..9a534a47aa 100644 --- a/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/nemo_evaluator.yaml +++ b/examples/pruning/minitron/NVIDIA-Nemotron-Nano-9B-v2/nemo_evaluator.yaml @@ -50,8 +50,6 @@ execution: mount_home: false deployment: n_tasks: 1 - batch_comment: "{\"OccupiedIdleGPUsJobReaper\":{\"exemptIdleTimeMins\":\"1920\",\"reason\":\"benchmarking\",\"description\":\"Required data validation\ - \ and evaluation\"}}" # Note: Only tp=1 works for Nano (Mamba-based architecture) deployment: diff --git a/modelopt/torch/utils/plugins/megatron_preprocess_data.py b/modelopt/torch/utils/plugins/megatron_preprocess_data.py index 81dac1580b..0aa29f1c55 100644 --- a/modelopt/torch/utils/plugins/megatron_preprocess_data.py +++ b/modelopt/torch/utils/plugins/megatron_preprocess_data.py @@ -161,15 +161,49 @@ def _process_messages(self, messages: list[dict]) -> list[dict]: """ if self.reasoning_content == "native": return messages + + def _fix_arguments(args): + """Ensure tool_call.arguments is a dict for Jinja2 |items compatibility.""" + if isinstance(args, dict): + return args + if isinstance(args, str): + try: + parsed = json.loads(args) + return parsed if isinstance(parsed, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + return {} + processed = [] for msg in messages: - if "reasoning_content" not in msg: + has_tool_calls = "tool_calls" in msg and isinstance(msg.get("tool_calls"), list) + needs_copy = "reasoning_content" in msg or has_tool_calls + if not needs_copy: processed.append(msg) continue msg = dict(msg) # shallow copy — don't mutate the original - rc = msg.pop("reasoning_content") + rc = msg.pop("reasoning_content", None) if self.reasoning_content == "inline" and rc: msg["content"] = f"\n{rc}\n\n{msg.get('content', '')}" + # Always normalize tool_call.arguments to dict so Jinja2 |items doesn't crash. + # The Nemotron v3 chat template reassigns tool_call = tool_call.function when + # the nested OpenAI format is used, so we fix both the direct and nested levels. + if has_tool_calls: + fixed_tool_calls = [] + for tc in msg["tool_calls"]: + if not isinstance(tc, dict): + fixed_tool_calls.append(tc) + continue + tc = dict(tc) + if "arguments" in tc and not isinstance(tc["arguments"], dict): + tc["arguments"] = _fix_arguments(tc["arguments"]) + if isinstance(tc.get("function"), dict): + fn = dict(tc["function"]) + if "arguments" in fn and not isinstance(fn["arguments"], dict): + fn["arguments"] = _fix_arguments(fn["arguments"]) + tc["function"] = fn + fixed_tool_calls.append(tc) + msg["tool_calls"] = fixed_tool_calls processed.append(msg) return processed