98 changes: 98 additions & 0 deletions examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md
@@ -0,0 +1,98 @@
# Bypass Distillation Tutorial: Nemotron-3-Nano-30B-A3B (KV-heads-only)

A minimal end-to-end demonstration that **bypass distillation improves quality** at the same compression budget. The setup is a **toy pruning task on a real production model** — we compress only KV heads (12 → 9, a modest 25% reduction) so a single comparison surfaces the bypass benefit cleanly without needing extensive downstream evaluation. The model itself ([Nemotron-3-Nano-30B-A3B-Base-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16)) is a real 30B-A3B MoE-Mamba hybrid, not a tiny stand-in.

## What this tutorial does

The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare:

1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training).
2. **With bypass** — the bypass step runs ~10M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher.

Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head, no_op}` (the no_op variant lets the solver drop attention entirely on a layer if doing so is cheap enough). FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change.
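
To make the search space concrete: with 6 attention layers and 3 variants each, there are only 3^6 = 729 candidate assignments. A minimal sketch (illustration only, assuming `target_num_kv_heads` acts as an upper bound on total KV heads; the actual MIP also optimizes the scored quality objective):

```bash
# Illustration only: enumerate per-layer KV-head choices (2 = teacher, 1 = distilled
# variant, 0 = attention no_op) across 6 layers and count assignments that satisfy
# the budget target_num_kv_heads <= 9. The real solver picks the best-scoring one.
awk 'BEGIN {
  feasible = 0
  for (a = 0; a < 729; a++) {       # 3^6 assignments, base-3 encoded
    t = a; heads = 0
    for (i = 0; i < 6; i++) { heads += t % 3; t = int(t / 3) }
    if (heads <= 9) feasible++
  }
  printf "%d of 729 assignments fit the 9-KV-head budget\n", feasible
}'
```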

**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`).
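
For a quick side-by-side across runs, the relevant summaries can be pulled straight out of the log. A minimal sketch, assuming the log format shown below in Step A; `PUZZLE_DIR` is hypothetical shorthand for your `puzzle_dir`:

```bash
# Print each model's validation summary line from the pipeline log.
PUZZLE_DIR=/path/to/puzzle_dir   # hypothetical; same value as `puzzle_dir` in the YAML
grep -A1 "validate_model_with_kl_div" "${PUZZLE_DIR}/log.txt"
```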

## Hardware & install

- 8×H100 80GB (the teacher needs ≥60 GiB for activation scoring on a 4096 context).
- Container: `nvcr.io/nvidia/nemo:26.04` or later.
- `pip install -e ".[dev]"` from the modelopt repo root.
- Mamba kernels (required by Nemotron-3-Nano's hybrid backbone):

```bash
pip install mamba-ssm[causal-conv1d] --no-build-isolation
```

- HF auth set up so the model is downloadable: `huggingface-cli login`.
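
With all of the above in place, a quick sanity check avoids a failed first run (a sketch; the exact package layout may differ):

```bash
# Verify GPU count, Mamba kernels, and HF authentication before launching.
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader   # expect 8x H100 80GB
python -c "import mamba_ssm, causal_conv1d; print('Mamba kernels OK')"
huggingface-cli whoami                                           # should print your HF account
```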

## Step A — pipeline without bypass

Edit `examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml` to point `puzzle_dir` and `dataset_path` at writable locations, then:

```bash
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
```

This runs the 8-step puzzletron pipeline (convert → score pruning activations → prune → build replacement library → calc subblock stats → score replacements → MIP → realize). With `bypass:` added in Step B the pipeline grows to 9 steps; without it, the bypass step is skipped and progress prints `N/8`. Wall-clock: roughly **1 h on 8×H100** for this KV-heads-only task (KV-head importance scoring is a single forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring).
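
The pipeline logs to a single file, so progress can be followed from a second shell (sketch):

```bash
# Follow step counters (`N/8`, or `N/9` once bypass is enabled) as they print.
PUZZLE_DIR=/path/to/puzzle_dir   # hypothetical shorthand for your `puzzle_dir`
tail -f "${PUZZLE_DIR}/log.txt"
```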

When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain:

```text
validate_model_with_kl_div(model_name='teacher', ...)
Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...}
...
validate_model_with_kl_div(model_name='solution_0', ...)
Average losses = {..., 'token_accuracy_top_1': ..., ...}
```

Record the teacher's `token_accuracy_top_1` and `solution_0`'s `token_accuracy_top_1`. **Move or rename `${puzzle_dir}/single_sequence_replacement_solutions--validation/` and `${puzzle_dir}/mip/` aside** before Step B if you want to keep the no-bypass artifacts — Step B reuses the same `puzzle_dir` and the library/scoring/MIP outputs will be overwritten.
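
One way to set the Step A artifacts aside (a sketch; the `.no_bypass` suffix is arbitrary):

```bash
# Preserve the no-bypass library/scoring/MIP outputs before Step B overwrites them.
PUZZLE_DIR=/path/to/puzzle_dir   # hypothetical shorthand for your `puzzle_dir`
mv "${PUZZLE_DIR}/single_sequence_replacement_solutions--validation" \
   "${PUZZLE_DIR}/single_sequence_replacement_solutions--validation.no_bypass"
mv "${PUZZLE_DIR}/mip" "${PUZZLE_DIR}/mip.no_bypass"
```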

## Step B — pipeline with bypass

Add `bypass: defaults` to the `defaults:` list of the same config file used in Step A (replacing the existing empty `- bypass:` entry):

```yaml
defaults:
- pruning: kv_heads_pruning
- scoring: ../validate_solutions_defaults
- realize_model: ../validate_solutions_defaults
- bypass: defaults # <-- changed from `bypass:`
- override hydra/hydra_logging: disabled
- _self_
```
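
The same edit as a one-liner, assuming the entry reads exactly `- bypass:` with no trailing comment:

```bash
# In-place edit: turn the empty `- bypass:` entry into `- bypass: defaults`.
sed -i 's/^\([[:space:]]*-[[:space:]]*\)bypass:[[:space:]]*$/\1bypass: defaults/' \
  examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
```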

Re-run the same command:

```bash
torchrun --nproc_per_node=8 examples/puzzletron/main.py \
--config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml
```

Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~60 min for 10M tokens) and the downstream library/scoring/MIP steps rerun. Wall-clock: roughly **+1.5 h** on top of Step A.

Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs/bypass_heads_1/` and creates a symlink `${puzzle_dir}/ckpts/bypass_heads_1` that the replacement library builder picks up automatically.
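
A quick check that the artifacts landed where the library builder expects them (sketch):

```bash
# The symlink should resolve into the bypass run directory.
PUZZLE_DIR=/path/to/puzzle_dir   # hypothetical shorthand for your `puzzle_dir`
ls -l "${PUZZLE_DIR}/ckpts/bypass_heads_1"
ls "${PUZZLE_DIR}/bypass/bypass_runs/bypass_heads_1/"
```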

Capture `solution_0`'s `token_accuracy_top_1` from the new realize-model log section.

## Results

Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on Nemotron-3-Nano-30B-A3B-Base-BF16:

| Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` |
|------------------------------|----------------------:|----------:|-----------------------:|
| Teacher | 12 | 0.5950 | 0.8468 |
| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 |
| Pruned, **with bypass** (10M-token BLD) | 9 | **0.6055**| **0.8441** |

**Bypass closes roughly three-quarters of the regression gap** at this compression budget:

- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**.
- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**.
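
The recovery fractions follow directly from the table; reproducible in one line:

```bash
# Recovery = 1 - (gap with bypass) / (gap without bypass), numbers from the table above.
awk 'BEGIN {
  printf "lm_loss recovery:   %.0f%%\n", 100 * (1 - (0.6055 - 0.5950) / (0.6347 - 0.5950))
  printf "top-1 acc recovery: %.0f%%\n", 100 * (1 - (0.8468 - 0.8441) / (0.8468 - 0.8373))
}'
```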

For 10M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher.

## Going further: full accuracy recovery

Bypass distillation is Stage 1 of the PUZZLE pipeline — local, per-block KD that tightens the replacement library. For larger compression targets (or more aggressive KV pruning) you'll want Stage 2: **global knowledge distillation** on the realized student. See [`examples/pruning/puzzletron/`](../pruning/puzzletron/) for the Megatron-Bridge recipe and concrete MMLU recovery numbers.
2 changes: 1 addition & 1 deletion examples/puzzletron/README.md
@@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of

In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md).

- > **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).
+ > **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100, Nemotron-3-Nano-30B-A3B-Base-BF16 on 8x H100 — see the [bypass distillation tutorial](Nemotron-3-Nano-30B-A3B-Base-BF16.md)). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md).

## Environment

@@ -0,0 +1,130 @@
# @package bypass
# Bypass Distillation Configuration
# This config defines parameters for blockwise local distillation (BLD),
# which trains alternative transformer block configurations using per-block
# knowledge distillation from a teacher model.

# Runtime Configuration
dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability
seed: 42 # Random seed for reproducibility

# Experiment Tracking
experiment_id: # Unique identifier for this experiment. Will be dynamically set
experiment_dir: # Directory for this experiment. Will be dynamically set
iter_num: 1 # Current iteration number
step_num: 1 # Current step number within iteration
token_count: 0 # Token count tracker (auto-updated during training)

# Data Configuration
data:
data_column: "messages"
block_size: 512 # Sequence length (tokens per sample)
bos_rate: 0.5
fim_rate: 0
fim_spm_rate: 0
source_datasets_to_discard: []
load_from_disk: true # Load preprocessed data from disk (true) or stream it (false)
keep_in_memory: false
val_dataset_name: valid
max_eval_samples: 4
eval_samples_per_process: # Samples per GPU during distributed eval (auto if null)
shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data

# Training Configuration
training:
learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001)
training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check)
micro_batch_size: 2
val_micro_batch_size: 1
warmup_ratio: 0.05
warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} # Auto-calculated warmup steps
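# (Illustration, assuming the resolver computes total_steps = training_tokens /
#  (block_size * micro_batch_size) and warmup_steps = warmup_ratio * total_steps;
#  with the defaults above: 0.05 * 1e4 / (512 * 2) ≈ 0.5, so warmup is at most a
#  single step on this sanity-run budget.)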
min_lr_factor: 1e-5
grad_accumulation_steps: 1
skip_first_batches: 0 # Use for debugging, or to skip a few batches that cause crashes or optimization issues.
weight_decay: 0.1
decay_lr: true
beta1: 0.9
beta2: 0.95
use_grad_scaling: false
grad_clip: 1.0
grad_clip_type: norm
clipping_count: 0
log_interval: 5
eval_interval: 5

# Model Loading Configuration
resume_checkpoint_path: # Path to resume training from checkpoint
find_last_ckpt_for_resume: true # Auto-resume by finding last checkpoint (bool)
parameter_count:
init_checkpoint_path: # Path to initialize weights from

model:
student_weights_dtype: "bf16" # Student model weight precision

model_overrides:
delete_old_checkpoints: true # Clean up old checkpoints to save disk space
save_interval_seconds: 12900 # Save checkpoint every ~3.6 hours
save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled)
save_checkpoint_when_done: true # Save final checkpoint when training completes

# Architecture modifications for student model
model_config_overrides:
ffn:
- intermediate_size:
no_op: # Disable FFN entirely (true/false)
attention:
- num_key_value_heads: # Number of kv-heads (for GQA)
no_op: # Disable attention entirely (true/false)

# Model Factory Configuration - Controls student model creation and initialization
model_factory:
factory: bypass_factory_fn # Unified factory supporting all layer types
block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student block outputs. Options: vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss
gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode
mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode
mlp_init_config: # Configuration for MLP initialization (if needed)
activations_log_dir: # Directory with activation statistics (required for PruneByActivationsLog)
linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc.
submodule_for_loss_calculation: # Specific submodule for loss calc.
keys_to_learn: # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically.

# Validation Configuration
disable_initial_validate: false
validate_teacher_model: true
validate_student_model: true
disable_validation: false # Enable validation to exercise all code paths
best_val_loss: 1e+9 # Track best validation loss achieved

# Performance Optimization
compile: false # Use PyTorch compilation
disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available)
teacher_model_load_on_cpu: false

# Checkpoint Management
save_checkpoint_before_training: false # Save initial checkpoint before training
disable_checkpoint_save: false # Disable all checkpoint saving
save_best_ckpt: true # Save checkpoint when validation improves
kill_after_first_save: false # Exit after first checkpoint save (for testing)
realize_best_or_latest: "best"

wandb_log: false
wandb:
project:
entity:

# Multiple bypass configurations to train sequentially.
# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn.
# If empty or absent, a single run uses the settings above.
configs:
- model_config_overrides:
ffn:
- intermediate_size: 3072
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
- model_config_overrides:
ffn:
- intermediate_size: 5888
attention:
- num_key_value_heads: 8
keys_to_learn: subblock_ffn
@@ -0,0 +1,110 @@
defaults:
- pruning: kv_heads_pruning
- scoring: ../validate_solutions_defaults
- realize_model: ../validate_solutions_defaults
- bypass: defaults
- override hydra/hydra_logging: disabled
- _self_

puzzle_dir: ???
descriptor: nemotron_h
teacher_dir: ${puzzle_dir}/ckpts/teacher/
replacement_library_path: ${puzzle_dir}/replacement_library.json
dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2

skip_realize_model: false

# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library
# exposes {teacher 2-head, 1-head, no_op} per attention layer; FFN and Mamba
# blocks are copied verbatim from the teacher.
build_replacement_library:
add_ffn_no_ops: false
add_attention_no_ops: true

calc_subblock_stats:
batch_sizes: [64, 96, 128]
prefill_seq_len: 4096
generation_seq_len: 4096
num_active_tokens_override: # Optional override for sequence lengths
prefill_queue_size: 0
allocate_prefill_query: false
benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
merge_with_existing_stats: false
subblock_stats_filename: "subblock_stats.json"
moe_stats_filename: "moe_stats.json"
runtime_stats:
backend: trt_torch

scoring:
descriptor: ${descriptor}
solutions_to_validate:
skip_existing_solutions: true

replacement_library_path: ${replacement_library_path}
solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
teacher_dir: ${to_path:${teacher_dir}}
output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation

eval_samples: 128
micro_batch_size: 1
seed: 42
shuffle_seed: 444
dataset_path: ${dataset_path}

mip:
single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
gathered_metrics_path:
puzzle_profile:

objective: metrics.cosine_embedding_loss_hidden_states
bigger_is_better: false

subblock_stats_args:
- batch_size: 96
weights_dtype: torch.bfloat16
activations_dtype: torch.bfloat16
kv_cache_dtype: torch.bfloat16

report_additional_costs:
- stats.memory_mib
- stats.num_params
- stats.num_kv_heads
- stats.has_attention
- stats.has_ffn
- stats.kv_cache_memory_mib
- stats.attention_memory_mib
- stats.ffn_memory_mib
- stats.ffn_num_params
- stats.attention_num_params

human_constraints:
target_num_kv_heads: 9 # toy KV-heads-only target; see nemotron-3-nano-30b-a3b.yaml

mip_constraints:
metric_overrides:
max_seconds_per_solution: 60

realize_model:
descriptor: ${descriptor}
teacher_dir: ${to_path:${teacher_dir}}
tokenizer_name: ${to_path:${teacher_dir}}
replacement_library_path: ${replacement_library_path}
save_models: true
solutions_path: # Filled dynamically

# Validate params
skip_validation: false
eval_samples: 128
micro_batch_size: 1
seed: 42
shuffle_seed: 444
dataset_path: ${dataset_path}

nccl_timeout_minutes: ${timedelta_minutes:10}

# This section redirects Hydra outputs
hydra:
run:
dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}