From 691f4f6a3ecdc6c83533b61c9673b989d212c679 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 01:30:47 -0700 Subject: [PATCH 01/13] Add bypass distillation stage to puzzletron pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces blockwise local distillation (BLD) as an optional 5th step between pruning and replacement-library construction. When `bypass:` is present in the Hydra config, puzzletron trains alternative transformer block configurations against the teacher via per-block knowledge distillation, then surfaces the trained subblocks to MIP through symlinks under `puzzle_dir/ckpts/`. New module `modelopt/torch/puzzletron/bypass_distillation/`: - training_loop.py: pipeline-parallel KD loop with cosine LR schedule, per-block AdamW + GradScaler, validation, and time/step-based checkpointing. - stitched_model_factory.py: unified factory (FFN/attention/MoE/Mamba/ whole-block) driven by `mlp_init_mode` x `keys_to_learn`; composes multiple pruning mixins (experts_removal + kv_heads + ffn_intermediate) when student/teacher configs differ along multiple axes. - bypass_checkpoint_utils.py: stitched-module state save/load with `latest` symlink and `saving_completed` marker; resume scans only plain `iter-NNNNNN-ckpt` directories. - bypass_utils.py: experiment_id derivation encoding every override axis (FFN+KV / experts+KV / KV-only) so sweeps cannot collide. Pipeline integration (puzzletron_nas_plugin.py): - _progress_step(hydra_cfg, stage) helper + canonical stage order produce coherent `Puzzletron Progress N/T` strings; total grows from 8 to 9 when bypass is configured. - Skip-if-done caching for every stage (convert / score / prune / bypass / library build). - Stale-library detection: any `ckpts/*` entry newer than `replacement_library.json` triggers a rebuild so post-bypass weights are picked up automatically. - build_replacement_library orders bypass-trained subblocks before Truncate-init variants so `drop_duplicates(keep="first")` is deterministic. - Auto-download HF model when input_hf_model_path is not a local dir. KV-heads pruning generalized beyond Llama: - KVHeadsLayerDescriptor registered for GptOss, NemotronH, NemotronH-V2, Qwen3-VL. - _lm_attrs() probes text_config and language_config so VL configs (Qwen3-VL, Llava, Llama-4) read num_attention_heads and head_dim from the right sub-config. - _init_attention_biases falls back to state-dict probing when a config (e.g. GptOssConfig) doesn't expose o_proj_bias / attention_bias as top-level attributes. MIP: new target_num_kv_heads constraint over stats.num_kv_heads for KV-cache-only sweeps. Tools: - _copy_auto_map_code_files: copies custom modeling_*.py files alongside config.json so trust-remote-code models reload correctly; identifier-shape guard rejects malformed auto_map entries. - tools/robust_json.py: JSON encoder for dataclasses, paths, enums, Namespaces, OmegaConf nodes, functions/classes, and timedeltas; used by bypass to serialize resume state. Multi-mixin and bug fixes in legacy paths: - child_init._process_single_layer accepts a single mixin or a list, enabling experts_removal + kv_heads + ffn_intermediate stacking. - update_model_config.override returns the original `item` (not None) when an override is None — preserves the original value instead of clobbering it. Data utilities: - create_train_dataloader for the infinite ConstantLengthDataset used by bypass training. 
- ConstantLengthDataset falls back to newline-joined message contents when the tokenizer has no chat_template (base models). - format_stitched_losses filters NaN entries from no-op blocks. Tutorial + configs: - examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md: end-to-end KV-heads-only demonstration on a real 30B-A3B MoE-Mamba teacher, showing bypass closes ~74% of the lm_loss regression gap. - Per-family configs under configs/nemotron-3-nano-30b-a3b/. - Generic bypass/defaults.yaml template under llama-3_1-8B. Tests: - Unit: normalized-MSE losses; get_distributed_modules_ownership. - GPU: test_bypass.py parametrizes block-pruning, KV-head compression, multi-config sweep, and checkpoint-contents tests across 9 model families (extracted into PUZZLETRON_FAMILIES). Signed-off-by: Sepehr Sameni --- .../Nemotron-3-Nano-30B-A3B-Base-BF16.md | 96 ++ examples/puzzletron/README.md | 2 +- .../bypass/defaults.yaml | 130 +++ ...DIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml | 110 ++ .../bypass/defaults.yaml | 120 +++ .../nemotron-3-nano-30b-a3b.yaml | 30 + .../pruning/kv_heads_pruning.yaml | 24 + .../pruning/pruning_defaults.yaml | 33 + .../validate_model_defaults.yaml | 17 + .../validate_solutions_defaults.yaml | 10 + examples/puzzletron/main.py | 25 +- modelopt/torch/puzzletron/__init__.py | 1 + .../anymodel/model_descriptor/base.py | 13 + .../gpt_oss/gpt_oss_model_descriptor.py | 24 +- .../nemotron_h/nemotron_h_model_descriptor.py | 20 +- .../nemotron_h_v2_model_descriptor.py | 20 +- .../qwen3_vl/qwen3_vl_model_descriptor.py | 15 +- .../bypass_distillation/__init__.py | 22 + .../bypass_checkpoint_utils.py | 192 ++++ .../bypass_distillation/bypass_utils.py | 80 ++ .../bypass_distillation/data_classes.py | 43 + .../stitched_model_factory.py | 646 ++++++++++++ .../bypass_distillation/training_loop.py | 978 ++++++++++++++++++ modelopt/torch/puzzletron/mip/run_puzzle.py | 5 + .../pruning/kv_heads_pruning_mixin.py | 4 +- .../torch/puzzletron/pruning/pruning_utils.py | 46 +- .../torch/puzzletron/puzzletron_nas_plugin.py | 208 +++- .../build_replacement_library.py | 16 +- .../torch/puzzletron/sewing_kit/passage.py | 1 + modelopt/torch/puzzletron/sewing_kit/utils.py | 54 + .../tools/bypassed_training/child_init.py | 44 +- .../puzzletron/tools/checkpoint_utils_hf.py | 57 +- .../torch/puzzletron/tools/robust_json.py | 77 ++ .../puzzletron/utils/data/dataloaders.py | 50 +- .../torch/puzzletron/utils/data/dataset.py | 9 +- modelopt/torch/puzzletron/utils/parsing.py | 15 + tests/_test_utils/torch/puzzletron/utils.py | 37 + tests/gpu/torch/puzzletron/test_bypass.py | 662 ++++++++++++ tests/gpu/torch/puzzletron/test_puzzletron.py | 14 +- .../torch/puzzletron/test_bypass_losses.py | 117 +++ .../torch/puzzletron/test_bypass_utils.py | 87 ++ 41 files changed, 4059 insertions(+), 95 deletions(-) create mode 100644 examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md create mode 100644 examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml create mode 100644 
examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml create mode 100644 examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml create mode 100644 modelopt/torch/puzzletron/bypass_distillation/__init__.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/data_classes.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py create mode 100644 modelopt/torch/puzzletron/bypass_distillation/training_loop.py create mode 100644 modelopt/torch/puzzletron/tools/robust_json.py create mode 100644 tests/gpu/torch/puzzletron/test_bypass.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_losses.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_utils.py diff --git a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md new file mode 100644 index 00000000000..3f48460cb2d --- /dev/null +++ b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md @@ -0,0 +1,96 @@ +# Bypass Distillation Tutorial: Nemotron-3-Nano-30B-A3B (KV-heads-only) + +A minimal end-to-end demonstration that **bypass distillation improves quality** at the same compression budget. The setup is a **toy pruning task on a real production model** — we compress only KV heads (12 → 9, a modest 25% reduction) so a single comparison surfaces the bypass benefit cleanly without needing extensive downstream evaluation. The model itself ([Nemotron-3-Nano-30B-A3B-Base-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16)) is a real 30B-A3B MoE-Mamba hybrid, not a tiny stand-in. + +## What this tutorial does + +The teacher has 6 attention layers (each with `num_key_value_heads=2`) interleaved between Mamba and MoE-FFN blocks — **12 KV heads total** across the whole model. We compress to **9 KV heads (75% of teacher)** in two ways and compare: + +1. **Without bypass** — replacement library uses Truncate-init weights (KV heads sliced from teacher; no further training). +2. **With bypass** — the bypass step runs ~10M tokens of per-block knowledge distillation, training a 1-KV-head variant per attention layer against the teacher. + +Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: 9`), so MIP picks per attention layer from `{teacher 2-head, 1-head, no_op}` (the no_op variant lets the solver drop attention entirely on a layer if doing so is cheap enough). FFN/MoE/Mamba blocks are copied verbatim from the teacher in both runs — only attention weights change. + +**Metrics:** `lm_loss` and `token_accuracy_top_1` measured against the same held-out dataset by the realize-model step (printed automatically to `puzzle_dir/log.txt`). + +## Hardware & install + +- 8×H100 80GB (the teacher needs ≥60 GiB for activation scoring on a 4096 context). +- Container: `nvcr.io/nvidia/nemo:26.04` or later. +- `pip install -e ".[dev]"` from the modelopt repo root. +- Mamba kernels (required by Nemotron-3-Nano's hybrid backbone): + ```bash + pip install mamba-ssm[causal-conv1d] --no-build-isolation + ``` +- HF auth set up so the model is downloadable: `huggingface-cli login`. 
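+
+Optionally pre-fetch the checkpoint. The pipeline auto-downloads the model when
+`input_hf_model_path` is not a local directory, so this only front-loads the
+~60 GB transfer (the `--local-dir` path below is an example, not required):
+
+```bash
+huggingface-cli download nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 \
+  --local-dir /workspace/models/Nemotron-3-Nano-30B-A3B-Base-BF16
+```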
+ +## Step A — pipeline without bypass + +Edit `examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml` to point `puzzle_dir` and `dataset_path` at writable locations, then: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml +``` + +This runs the 8-step puzzletron pipeline (convert → score pruning activations → prune → build replacement library → score replacements → MIP → realize). With `bypass:` added in Step B the pipeline grows to 9 steps; without it, the bypass step is skipped and progress prints `N/8`. Wall-clock: roughly **1h on 8×H100** for this KV-heads-only task (KV-head importance scoring is one forward pass via `IndependentKvHeadContributionHook`, much cheaper than iterative FFN-channel scoring). + +When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain: + +``` +validate_model_with_kl_div(model_name='teacher', ...) +Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...} +... +validate_model_with_kl_div(model_name='solution_0', ...) +Average losses = {..., 'token_accuracy_top_1': ..., ...} +``` + +Record the teacher's `token_accuracy_top_1` and `solution_0`'s `token_accuracy_top_1`. **Move or rename `${puzzle_dir}/single_sequence_replacement_solutions--validation/` and `${puzzle_dir}/mip/` aside** before Step B if you want to keep the no-bypass artifacts — Step B reuses the same `puzzle_dir` and the library/scoring/MIP outputs will be overwritten. + +## Step B — pipeline with bypass + +Add `bypass: defaults` to the `defaults:` list of `NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml` (replace the existing empty `- bypass:` entry): + +```yaml +defaults: + - pruning: kv_heads_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: defaults # <-- changed from `bypass:` + - override hydra/hydra_logging: disabled + - _self_ +``` + +Re-run the same command: + +```bash +torchrun --nproc_per_node=8 examples/puzzletron/main.py \ + --config examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml +``` + +Skip-if-done caching reuses Step A's converted teacher checkpoint, activation scores, and pruned checkpoints. Only Step 5 (bypass distillation, ~60 min for 10M tokens) and the downstream library/scoring/MIP rerun. Wall-clock: roughly **+1.5 h** on top of Step A. + +Bypass writes its outputs under `${puzzle_dir}/bypass/bypass_runs/bypass_heads_1/` and creates a symlink `${puzzle_dir}/ckpts/bypass_heads_1` that the replacement library builder picks up automatically. + +Capture `solution_0`'s `token_accuracy_top_1` from the new realize-model log section. + +## Results + +Reducing total KV heads from 12 → 9 (75% of teacher) at fixed FFN/MoE/Mamba on Nemotron-3-Nano-30B-A3B-Base-BF16: + +| Run | `target_num_kv_heads` | `lm_loss` | `token_accuracy_top_1` | +|------------------------------|----------------------:|----------:|-----------------------:| +| Teacher | 12 | 0.5950 | 0.8468 | +| Pruned, **no bypass** (Truncate-init) | 9 | 0.6347 | 0.8373 | +| Pruned, **with bypass** (10M-token BLD) | 9 | **0.6055**| **0.8441** | + +**Bypass closes ~74% of the regression gap** at this compression budget: + +- `lm_loss` gap to teacher: `0.0397` without bypass → `0.0105` with bypass — bypass recovers **74%**. 
+- `token_accuracy_top_1` gap to teacher: `0.0095` without bypass → `0.0027` with bypass — bypass recovers **72%**. + +For 10M tokens of per-block KD, that's a substantial lift on a real 30B-A3B teacher. + +## Going further: full accuracy recovery + +Bypass distillation is Stage 1 of the PUZZLE pipeline — local, per-block KD that tightens the replacement library. For larger compression targets (or more aggressive KV pruning) you'll want Stage 2: **global knowledge distillation** on the realized student. See [`examples/pruning/puzzletron/`](../pruning/puzzletron/) for the Megatron-Bridge recipe and concrete MMLU recovery numbers. diff --git a/examples/puzzletron/README.md b/examples/puzzletron/README.md index 571b40ca499..93f8ced1cd5 100644 --- a/examples/puzzletron/README.md +++ b/examples/puzzletron/README.md @@ -11,7 +11,7 @@ To use the Puzzle algorithm effectively, we need to specify the target number of In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model reducing GPU memory usage from 113 GiB to 96 GiB (15% reduction) with less than 1% regression in the token_accuracy_top_10 metric. Other supported models should be compressed in a similar way. For GptOss there is one [additional step to be performed](GPTOSS.md). -> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). +> **Note:** Other models are also supported. See the [configs](./configs/) directory for additional model configurations (e.g., Llama-3.2-3B-Instruct on 1x H100, Qwen2.5-7B-Instruct on 1x H100, Qwen3-8B on 1x H100, Nemotron-Nano-12B-v2 on 1x H100, Mistral-Small-24B-Instruct-2501 on 4x H100, Nemotron-3-Nano-30B-A3B-Base-BF16 on 8x H100 — see the [bypass distillation tutorial](Nemotron-3-Nano-30B-A3B-Base-BF16.md)). For information on adding support for new models, see the [AnyModel Guide](../../modelopt/torch/puzzletron/anymodel/README.md). ## Environment diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml new file mode 100644 index 00000000000..7a0be378949 --- /dev/null +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml @@ -0,0 +1,130 @@ +# @package bypass +# Bypass Distillation Configuration +# This config defines parameters for blockwise local distillation (BLD), +# which trains alternative transformer block configurations using per-block +# knowledge distillation from a teacher model. + +# Runtime Configuration +dtype: "bf16" # Model precision: bf16 for efficiency, fp32 for stability +seed: 42 # Random seed for reproducibility + +# Experiment Tracking +experiment_id: # Unique identifier for this experiment. Will be dynamically set +experiment_dir: # Directory for this experiment. 
Will be dynamically set +iter_num: 1 # Current iteration number +step_num: 1 # Current step number within iteration +token_count: 0 # Token count tracker (auto-updated during training) + +# Data Configuration +data: + data_column: "messages" + block_size: 512 # Sequence length (tokens per sample) + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true # Load preprocessed data from disk or from stream + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: null # Samples per GPU during distributed eval (auto if null) + shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data + +# Training Configuration +training: + learning_rate: 1e-4 # Initial learning rate (1e-4 = 0.0001) + training_tokens: 1e+4 # Total training tokens (10K tokens - sanity check) + micro_batch_size: 2 + val_micro_batch_size: 1 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} # Auto-calculated warmup steps + min_lr_factor: 1e-5 + grad_accumulation_steps: 1 + skip_first_batches: 0 # Use for debugging or to skip few batches which cause crashes or optimization issues. + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 5 + eval_interval: 5 + +# Model Loading Configuration +resume_checkpoint_path: null # Path to resume training from checkpoint +find_last_ckpt_for_resume: True # Auto-resume by finding last checkpoint (bool) +parameter_count: null +init_checkpoint_path: null # Path to initialize weights from + +model: + student_weights_dtype: "bf16" # Student model weight precision + + model_overrides: + delete_old_checkpoints: true # Clean up old checkpoints to save disk space + save_interval_seconds: 12900 # Save checkpoint every ~3.5 hours + save_interval: 1e+9 # Save checkpoint every 1B steps (effectively disabled) + save_checkpoint_when_done: true # Save final checkpoint when training completes + +# Architecture modifications for student model + model_config_overrides: + ffn: + - intermediate_size: + no_op: # Disable FFN entirely (true/false) + attention: + - num_key_value_heads: # Number of kv-heads (for GQA) + no_op: # Disable attention entirely (true/false) + +# Model Factory Configuration - Controls student model creation and initialization +model_factory: + factory: bypass_factory_fn # Unified factory supporting all layer types + block_loss_func: normalized_mse_loss # Loss function for comparing teacher/student blocks. vectorwise_normalized_mse_loss / batched_normalized_mse_loss / normalized_mse_loss + gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode + mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode + mlp_init_config: # Configuration for MLP initialization (if needed) + activations_log_dir: null # Directory with activation statistics (required for PruneByActivationsLog) + linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc. + submodule_for_loss_calculation: null # Specific submodule for loss calc. + keys_to_learn: null # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically. 
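+  # Illustrative pairing of overrides and keys_to_learn (values borrowed from the
+  # Nemotron-3-Nano KV-heads-only config in this change): to train only the
+  # attention sub-block of a 1-KV-head student while copying FFN weights verbatim
+  # from the teacher, set
+  #   model.model_config_overrides.attention: [{num_key_value_heads: 1}]
+  #   model_factory.keys_to_learn: subblock_attention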
+ +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false # Enable validation to exercise all code paths +best_val_loss: 1e+9 # Track best validation loss achieved + +# Performance Optimization +compile: false # Use PyTorch compilation +disable_fa2: false # Disable Flash Attention 2 (false = use FA2 if available) +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false # Save initial checkpoint before training +disable_checkpoint_save: false # Disable all checkpoint saving +save_best_ckpt: true # Save checkpoint when validation improves +kill_after_first_save: false # Exit after first checkpoint save (for testing) +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Multiple bypass configurations to train sequentially. +# Each entry overrides model.model_config_overrides and optionally model_factory.keys_to_learn. +# If empty or absent, a single run uses the settings above. +configs: + - model_config_overrides: + ffn: + - intermediate_size: 3072 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn + - model_config_overrides: + ffn: + - intermediate_size: 5888 + attention: + - num_key_value_heads: 8 + keys_to_learn: subblock_ffn diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml new file mode 100644 index 00000000000..f57ec578785 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16.yaml @@ -0,0 +1,110 @@ +defaults: + - pruning: kv_heads_pruning + - scoring: ../validate_solutions_defaults + - realize_model: ../validate_solutions_defaults + - bypass: defaults + - override hydra/hydra_logging: disabled + - _self_ + +puzzle_dir: ??? +descriptor: nemotron_h +teacher_dir: ${puzzle_dir}/ckpts/teacher/ +replacement_library_path: ${puzzle_dir}/replacement_library.json +dataset_path: ??? # path to Nemotron-Post-Training-Dataset-v2 + +skip_realize_model: false + +# KV-heads-only pruning: lock off FFN/MoE-side variants. The replacement library +# exposes {teacher 2-head, 1-head, no_op} per attention layer; FFN and Mamba +# blocks are copied verbatim from the teacher. 
+
+build_replacement_library:
+  add_ffn_no_ops: false
+  add_attention_no_ops: true
+
+calc_subblock_stats:
+  batch_sizes: [64, 96, 128]
+  prefill_seq_len: 4096
+  generation_seq_len: 4096
+  num_active_tokens_override: # Optional override for sequence lengths
+  prefill_queue_size: 0
+  allocate_prefill_query: false
+  benchmark_iterations: # Set to a number (e.g., 1000) to enable runtime benchmarking
+  merge_with_existing_stats: false
+  subblock_stats_filename: "subblock_stats.json"
+  moe_stats_filename: "moe_stats.json"
+  runtime_stats:
+    backend: trt_torch
+
+scoring:
+  descriptor: ${descriptor}
+  solutions_to_validate:
+  skip_existing_solutions: true
+
+  replacement_library_path: ${replacement_library_path}
+  solutions_path: ${to_path:${puzzle_dir}/single_sequence_replacement_solutions.json}
+  teacher_dir: ${to_path:${teacher_dir}}
+  output_dir: ${puzzle_dir}/single_sequence_replacement_solutions--validation
+
+  eval_samples: 128
+  micro_batch_size: 1
+  seed: 42
+  shuffle_seed: 444
+  dataset_path: ${dataset_path}
+
+mip:
+  single_block_replacement_validation_dir: ${to_path:${scoring.output_dir}}
+  subblock_stats_path: ${to_path:${puzzle_dir}/${calc_subblock_stats.subblock_stats_filename}}
+  output_path: ${to_path:${puzzle_dir}/mip/puzzle_solutions}
+  gathered_metrics_path:
+  puzzle_profile:
+
+  objective: metrics.cosine_embedding_loss_hidden_states
+  bigger_is_better: false
+
+  subblock_stats_args:
+    - batch_size: 96
+      weights_dtype: torch.bfloat16
+      activations_dtype: torch.bfloat16
+      kv_cache_dtype: torch.bfloat16
+
+  report_additional_costs:
+    - stats.memory_mib
+    - stats.num_params
+    - stats.num_kv_heads
+    - stats.has_attention
+    - stats.has_ffn
+    - stats.kv_cache_memory_mib
+    - stats.attention_memory_mib
+    - stats.ffn_memory_mib
+    - stats.ffn_num_params
+    - stats.attention_num_params
+
+  human_constraints:
+    target_num_kv_heads: 9 # toy KV-heads-only target; see nemotron-3-nano-30b-a3b.yaml
+
+  mip_constraints:
+  metric_overrides:
+  max_seconds_per_solution: 60
+
+realize_model:
+  descriptor: ${descriptor}
+  teacher_dir: ${to_path:${teacher_dir}}
+  tokenizer_name: ${to_path:${teacher_dir}}
+  replacement_library_path: ${replacement_library_path}
+  save_models: true
+  solutions_path: # Filled dynamically
+
+  # Validate params
+  skip_validation: false
+  eval_samples: 128
+  micro_batch_size: 1
+  seed: 42
+  shuffle_seed: 444
+  dataset_path: ${dataset_path}
+
+nccl_timeout_minutes: ${timedelta_minutes:10}
+
+# This section redirects Hydra outputs
+hydra:
+  run:
+    dir: ${puzzle_dir}/hydra_logs/${now:%Y-%m-%d}/${now:%H-%M-%S}
diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml
new file mode 100644
index 00000000000..f9f744d31ce
--- /dev/null
+++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml
@@ -0,0 +1,120 @@
+# @package bypass
+# Bypass Distillation Configuration — Nemotron-3-Nano-30B-A3B (KV-heads-only toy task).
+#
+# Trains a single 1-KV-head variant per attention layer using per-block knowledge
+# distillation against the teacher (`subblock_attention` keys only — FFN/MoE/Mamba
+# blocks are frozen). The trained weights are saved into the replacement library
+# and consumed by the MIP solver alongside the no_op variant.
+#
+# Tutorial budget: ~10M tokens (quick sanity; ~60 min on 8×H100 in the tutorial
+# run). Increase `training_tokens` for a stronger bypass effect.
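+#
+# Illustrative warmup arithmetic (an assumption: the exact formula lives in the
+# `${warmup_steps:...}` resolver, registered with the other custom Hydra
+# resolvers, e.g. in modelopt/torch/puzzletron/utils/parsing.py):
+#   per-process micro-batch steps ≈ training_tokens / (block_size × micro_batch_size)
+#                                 ≈ 1e7 / (4096 × 2) ≈ 1220
+#   warmup_steps ≈ warmup_ratio × steps ≈ 0.05 × 1220 ≈ 61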
+ +# Runtime Configuration +dtype: "bf16" +seed: 42 + +# Experiment Tracking (auto-set by set_experiment_id → "bypass_heads_1") +experiment_id: +experiment_dir: +iter_num: 1 +step_num: 1 +token_count: 0 + +# Data Configuration +data: + data_column: "messages" + block_size: 4096 + bos_rate: 0.5 + fim_rate: 0 + fim_spm_rate: 0 + source_datasets_to_discard: [] + load_from_disk: true + keep_in_memory: false + val_dataset_name: valid + max_eval_samples: 4 + eval_samples_per_process: null + shuffle_train_data_seed: ${random_int:0,9999} + +# Training Configuration +training: + learning_rate: 3e-4 + training_tokens: 1e+7 # 10M tokens (toy budget) + micro_batch_size: 2 + val_micro_batch_size: 2 + warmup_ratio: 0.05 + warmup_steps: ${warmup_steps:${.training_tokens},${..data.block_size},${.micro_batch_size},${.warmup_ratio}} + min_lr_factor: 1e-5 + grad_accumulation_steps: 8 + skip_first_batches: 0 + weight_decay: 0.1 + decay_lr: true + beta1: 0.9 + beta2: 0.95 + use_grad_scaling: false + grad_clip: 1.0 + grad_clip_type: norm + clipping_count: 0 + log_interval: 100 + eval_interval: 100 + +# Model Loading Configuration +resume_checkpoint_path: null +find_last_ckpt_for_resume: True +parameter_count: null +init_checkpoint_path: null + +model: + student_weights_dtype: "bf16" + + model_overrides: + delete_old_checkpoints: true + save_interval_seconds: 12900 + save_interval: 1e+9 + save_checkpoint_when_done: true + + # Architecture override: only attention is touched. FFN/MoE/Mamba sub-blocks + # use teacher weights verbatim (the `ffn` key is omitted on purpose). + model_config_overrides: + attention: + - num_key_value_heads: 1 + no_op: + +# Model Factory Configuration +model_factory: + factory: bypass_factory_fn + block_loss_func: normalized_mse_loss + gqa_init_mode: AverageKV + mlp_init_mode: Truncate # FFN is frozen; this knob is dormant for KV-only tasks + mlp_init_config: + activations_log_dir: null + linear_init_mode: FromTeacher + submodule_for_loss_calculation: null + keys_to_learn: subblock_attention # train ONLY the attention sub-block + +# Validation Configuration +disable_initial_validate: false +validate_teacher_model: true +validate_student_model: true +disable_validation: false +best_val_loss: 1e+9 + +# Performance Optimization +compile: false +disable_fa2: false +teacher_model_load_on_cpu: false + +# Checkpoint Management +save_checkpoint_before_training: false +disable_checkpoint_save: false +save_best_ckpt: true +kill_after_first_save: false +realize_best_or_latest: "best" + +wandb_log: false +wandb: + project: + entity: + +# Single architectural variant — `set_experiment_id` produces "bypass_heads_1". +# Add more entries here to train multiple variants in one bypass run. +configs: [] diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml new file mode 100644 index 00000000000..8a48d4f3914 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/nemotron-3-nano-30b-a3b.yaml @@ -0,0 +1,30 @@ +defaults: + - NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 + - _self_ + +# Input Hugging Face model to compress. +# Auto-downloads from HuggingFace if the path is not a local directory. 
+input_hf_model_path: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16 + +# Dataset path for pruning and NAS scoring +dataset_path: /workspace/datasets/Nemotron-Post-Training-Dataset-v2 + +# Working directory for puzzletron outputs +puzzle_dir: /workspace/puzzle_dir + +# Toy KV-heads-only constraint. +# Teacher has 6 attention layers × num_key_value_heads=2 = 12 KV heads total. +# Target 9 leaves 75% of teacher KV heads — the MIP solver picks per-layer from +# {teacher 2-head, 1-head, no_op} so some layers stay full, some collapse to 1 +# head, and some become no_op. +mip: + human_constraints: + target_num_kv_heads: 9 + +# KV-heads-only toy pruning task. +# teacher num_attention_heads = 32, num_key_value_heads = 2 (n_heads_in_group = 16) +# Bypass-trains a single 1-KV-head variant per attention layer +# (n_heads_in_group = 32). Combined with `add_attention_no_ops: true` in the base +# config, MIP picks per-layer from {teacher 2-head, 1-head, no_op}. +pruning: + n_heads_in_group_list: [32] # 32 / 32 = 1 KV head per attention layer diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml new file mode 100644 index 00000000000..df37b7403c0 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/kv_heads_pruning.yaml @@ -0,0 +1,24 @@ +defaults: + - /pruning/pruning_defaults@_here_ + +# Score per-KV-head importance and create the pruned-checkpoint variants used +# to build the replacement library. +hook_class: ${get_object:modelopt.torch.prune.importance_hooks.base_hooks.IndependentKvHeadContributionHook} + +activations_log_dir: ${puzzle_dir}/pruning/pruning_scores/attn_${pruning.activation_hooks_kwargs.method}/${pruning.experiment_id} + +pruning_mixin: + _target_: modelopt.torch.puzzletron.pruning.kv_heads_pruning_mixin.KVHeadsPruningMixIn + layer_descriptor: + _target_: modelopt.torch.puzzletron.anymodel.models.nemotron_h.nemotron_h_model_descriptor.NemotronHKVHeadsLayerDescriptor + +activation_hooks_kwargs: + method: independent_kv_head_contribution + optimize_for: memory + target_layer: "mixer.o_proj" # Nemotron-H attention is under `mixer`, not `self_attn` + layer_input_descriptors_path: + +# Teacher: num_attention_heads = 32, num_key_value_heads = 2 (n_heads_in_group = 16) +# Single 1-KV-head variant: n_heads_in_group = 32 → num_kv_heads = 32 / 32 = 1 +n_heads_in_group_list: [32] +gqa_init_mode: "PruneKVHeads" diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml new file mode 100644 index 00000000000..e05e775bee3 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/pruning/pruning_defaults.yaml @@ -0,0 +1,33 @@ +defaults: + - /validate_model_defaults + +descriptor: ${descriptor} +model_name_or_path: ${teacher_dir} +experiment_id: ${pruning.eval_samples}samples_diverse_mini +activations_log_dir: ??? +activation_hooks_kwargs: ??? 
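+# `???` marks mandatory values: OmegaConf raises MissingMandatoryValue unless an
+# including config supplies them (kv_heads_pruning.yaml fills both fields above).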
+ +# Data: +eval_samples: 1000 # default is 10000 +micro_batch_size: 4 +dataset_path: ${dataset_path} +val_dataset_name: train + +# Prune ckpts +pruned_ckpts_output_dir: ${puzzle_dir}/pruning/${pruning.experiment_id} + +## FFN pruning +ffn_list: +mlp_init_mode: "Truncate" # PruneByActivationsLog + +## KV-heads pruning +n_heads_in_group_list: +gqa_init_mode: "AverageKV" + +## Hidden dimension pruning +hidden_size_list: +hidden_size_init_mode: "PruneByChannelRanking" +linear_init_mode: "FromTeacher" + +mlp_init_config_yaml: + activations_log_dir: ${pruning.activations_log_dir} diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml new file mode 100644 index 00000000000..ce1749d9698 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_model_defaults.yaml @@ -0,0 +1,17 @@ +model_dtype: torch.bfloat16 # dtype to cast the model for validate_model +autocast_dtype: torch.bfloat16 # dtype for torch.autocast for validate_model +block_size: 8192 +bos_rate: 0.5 +data_column: messages +val_dataset_name: valid +shuffle_seed: 81436 +seed: 42 +fim_rate: 0 +fim_spm_rate: 0 +source_datasets_to_discard: +varlen: false +write_results: false +calc_losses_on_cpu: false +activations_log_dir: +model_name_or_path: +load_dataset_fn: ${get_object:modelopt.torch.puzzletron.utils.data.dataloaders.load_from_disk_fn} diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml new file mode 100644 index 00000000000..ec139023794 --- /dev/null +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/validate_solutions_defaults.yaml @@ -0,0 +1,10 @@ +defaults: + - /validate_model_defaults + - _self_ + +solutions_to_validate: +skip_validation: false +save_models: false +bigger_is_better: false +sort_solutions_by: +calculate_full_score_ablations: false diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 8ceed378318..990609da4ec 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -67,7 +67,6 @@ def run_full_puzzletron(hydra_config_path: str): Args: config_path: Path to the YAML configuration file """ - mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline") dist.setup(timeout=timedelta(minutes=10)) # Register Hydra custom resolvers (needed for config resolution) @@ -77,12 +76,17 @@ def run_full_puzzletron(hydra_config_path: str): hydra_config_dir = str(hydra_config_path.parent) hydra_config_name = hydra_config_path.stem - # Load hydra config + # Load hydra config to determine total step count (bypass adds one step) hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir( config_dir=hydra_config_dir, config_name=hydra_config_name, overrides=[], ) + start_step, total_steps = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "start") + + mtpz.tools.mprint( + f"Puzzletron Progress {start_step}/{total_steps}: starting puzzletron pipeline" + ) # Convert model (convert from HF to DeciLM, score pruning activations, # prune the model and save pruned checkpoints) @@ -113,7 +117,10 @@ def run_full_puzzletron(hydra_config_path: str): ) dist.cleanup() - mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + complete_step, _ = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "complete") + mtpz.tools.mprint( + f"Puzzletron Progress {complete_step}/{total_steps}: 
puzzletron pipeline completed (multi-gpu)" + ) def run_mip_only(hydra_config_path: str): @@ -140,21 +147,27 @@ def run_mip_only(hydra_config_path: str): config_name=hydra_config_name, overrides=[], ) + mip_step, total_steps = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "mip") # Check if sweep mode is enabled if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): mtpz.tools.mprint( - "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)" + f"Puzzletron Progress {mip_step}/{total_steps}: running MIP sweep for multiple compression rates (multi-gpu)" ) mtpz.mip.run_mip_sweep(hydra_cfg) else: # mip_and_realize_models (distributed processing) # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API - mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)") + mtpz.tools.mprint( + f"Puzzletron Progress {mip_step}/{total_steps}: running MIP and realizing models (multi-gpu)" + ) mtpz.mip.launch_mip_and_realize_model(hydra_cfg) dist.cleanup() - mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)") + complete_step, _ = mtpz.puzzletron_nas_plugin._progress_step(hydra_cfg, "complete") + mtpz.tools.mprint( + f"Puzzletron Progress {complete_step}/{total_steps}: puzzletron pipeline completed (multi-gpu)" + ) def main(): diff --git a/modelopt/torch/puzzletron/__init__.py b/modelopt/torch/puzzletron/__init__.py index 15389dedfa2..0af53b5cef3 100644 --- a/modelopt/torch/puzzletron/__init__.py +++ b/modelopt/torch/puzzletron/__init__.py @@ -19,6 +19,7 @@ anymodel, block_config, build_library_and_stats, + bypass_distillation, dataset, entrypoint, mip, diff --git a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py index 3c1749d46ec..58b045bd21c 100644 --- a/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py +++ b/modelopt/torch/puzzletron/anymodel/model_descriptor/base.py @@ -169,6 +169,19 @@ def uses_autocast() -> bool: """ return True + @staticmethod + def pruning_mixins() -> Dict[str, Any]: + """Return available pruning mixins for bypass distillation. + + Override in subclasses to provide model-specific pruning mixins, e.g. + ``{"kv_heads": KVHeadsPruningMixIn(...), "experts_removal": ExpertRemovalPruningMixIn(...)}``. + + Returns an empty dict by default so that descriptors that do not need + model-specific weight-slicing (e.g. Llama with standard FFN truncation) + can rely on the generic ``create_child_state_dict`` fallback path. + """ + return {} + @staticmethod def get_language_model_config(config): """Get the language model config from a PretrainedConfig. diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index c8fd86b4bb6..342766c949c 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -28,6 +28,10 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import ( + KVHeadsLayerDescriptor, + KVHeadsPruningMixIn, +) # Expert removal is supported for unquantized models (test models). 
# Production models use MXFP4 quantized MoE with combined tensors @@ -37,7 +41,11 @@ from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size -__all__ = ["GptOssModelDescriptor", "GptOssExpertRemovalLayerDescriptor"] +__all__ = [ + "GptOssExpertRemovalLayerDescriptor", + "GptOssKVHeadsLayerDescriptor", + "GptOssModelDescriptor", +] @ModelDescriptorFactory.register_decorator("gpt_oss") @@ -173,7 +181,19 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: Note: Expert removal works for unquantized models (test models). Production models use MXFP4 quantization which is not yet supported. """ - return {"expert_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor())} + return { + "experts_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()), + } + + +@dataclass +class GptOssKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "self_attn.o_proj" + attn_prefix_name: str = "model.layers.{layer_idx}.self_attn" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) @dataclass diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 1c5706d1944..52667b91f70 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -29,11 +29,19 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import ( + KVHeadsLayerDescriptor, + KVHeadsPruningMixIn, +) from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHExpertRemovalLayerDescriptor", "NemotronHModelDescriptor"] +__all__ = [ + "NemotronHExpertRemovalLayerDescriptor", + "NemotronHKVHeadsLayerDescriptor", + "NemotronHModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -51,6 +59,15 @@ def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: return matches +@dataclass +class NemotronHKVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @dataclass class NemotronHExpertRemovalLayerDescriptor(ExpertRemovalLayerDescriptor): target_name: str = "mixer.gate" @@ -251,4 +268,5 @@ def build_attention_predicates() -> Dict[str, re.Pattern]: def pruning_mixins() -> Dict[str, PruningMixIn]: return { "experts_removal": ExpertRemovalPruningMixIn(NemotronHExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(NemotronHKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index a1e326f2357..aefe0919e9d 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -29,11 +29,19 @@ 
FFNIntermediateLayerDescriptor, FFNIntermediatePruningMixIn, ) +from ....pruning.kv_heads_pruning_mixin import ( + KVHeadsLayerDescriptor, + KVHeadsPruningMixIn, +) from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same -__all__ = ["NemotronHV2FFNIntermediateLayerDescriptor", "NemotronHV2ModelDescriptor"] +__all__ = [ + "NemotronHV2FFNIntermediateLayerDescriptor", + "NemotronHV2KVHeadsLayerDescriptor", + "NemotronHV2ModelDescriptor", +] def get_dynamic_modules(module_cls_str: str) -> List[Type[nn.Module]]: @@ -69,6 +77,15 @@ class NemotronHV2FFNIntermediateLayerDescriptor(FFNIntermediateLayerDescriptor): linear_weight_names: List[str] = field(default_factory=lambda: ["down_proj", "up_proj"]) +@dataclass +class NemotronHV2KVHeadsLayerDescriptor(KVHeadsLayerDescriptor): + o_proj_name: str = "mixer.o_proj" + attn_prefix_name: str = "backbone.layers.{layer_idx}.mixer" + qkvo_weight_names: List[str] = field( + default_factory=lambda: ["q_proj", "k_proj", "v_proj", "o_proj"] + ) + + @ModelDescriptorFactory.register_decorator("nemotron_h_v2") class NemotronHV2ModelDescriptor(ModelDescriptor): _DECODER_LAYER_CLS: Type[nn.Module] = None @@ -251,5 +268,6 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: "ffn_intermediate": FFNIntermediatePruningMixIn( NemotronHV2FFNIntermediateLayerDescriptor() ), + "kv_heads": KVHeadsPruningMixIn(NemotronHV2KVHeadsLayerDescriptor()), # TODO: Add expert removal support when ExpertRemovalPruningMixIn is migrated } diff --git a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py index aeedd419923..a0f9c95c6ce 100644 --- a/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/qwen3_vl/qwen3_vl_model_descriptor.py @@ -26,9 +26,13 @@ ) from ....block_config import BlockConfig -from ....pruning.expert_removal_pruning_mixin import ExpertRemovalLayerDescriptor +from ....pruning.expert_removal_pruning_mixin import ( + ExpertRemovalLayerDescriptor, + ExpertRemovalPruningMixIn, +) from ....pruning.ffn_intermediate_pruning_mixin import FFNIntermediateLayerDescriptor -from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn +from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same, return_tuple_of_size @@ -56,6 +60,13 @@ def get_language_model_config(config): """Qwen3-VL has nested text_config for language model parameters.""" return config.text_config if hasattr(config, "text_config") else config + @staticmethod + def pruning_mixins() -> Dict[str, PruningMixIn]: + return { + "experts_removal": ExpertRemovalPruningMixIn(Qwen3VLExpertRemovalLayerDescriptor()), + "kv_heads": KVHeadsPruningMixIn(Qwen3VLKVHeadsLayerDescriptor()), + } + @staticmethod def decoder_layer_cls(): return Qwen3VLMoeTextDecoderLayer diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py new file mode 100644 index 00000000000..790166b4519 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -0,0 +1,22 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Bypass distillation (blockwise local distillation) for the PUZZLE framework. + +This module implements Stage 1 of the PUZZLE pipeline: training alternative transformer +block configurations using per-block knowledge distillation from a teacher model. +""" + +from .training_loop import launch_bypass_distillation diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py new file mode 100644 index 00000000000..d1d95939282 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Checkpoint utilities for bypass distillation.""" + +import re +from collections import OrderedDict +from pathlib import Path +from typing import Optional, Type, Union + +import torch +from omegaconf import DictConfig +from tqdm import tqdm + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_dump + +from .stitched_model_factory import StitchedModuleDescriptor + + +def find_latest_run_dir(run_parent_dir: Union[str, Path]) -> str | None: + """Find the latest plain-iter checkpoint directory within a run parent directory. + + Resume must pick a directory created by the step-interval / time-based / final save + paths (named ``iter-NNNNNN-ckpt``) — not ``best-iter-*`` (which corresponds to a + validation-best snapshot whose optimizer state may be stale relative to the latest + iter), nor ``start-iter-*`` / ``final-iter-*`` (markers, not resume points). + """ + run_parent_dir = Path(run_parent_dir) + + # Check for the "latest" symlink — set only by save_bypass_checkpoint, always + # points at a plain ``iter-*`` directory. Fast path. + latest_dir = run_parent_dir / "latest" + if latest_dir.exists() and (latest_dir / "saving_completed").exists(): + return str(latest_dir) + + # Fallback: scan plain ``iter-NNNNNN-ckpt`` directories only. 
+ iter_re = re.compile(r"^iter-(\d+)-ckpt$") + candidate_dirs: list[tuple[int, Path]] = [] + for d in run_parent_dir.iterdir(): + if not d.is_dir(): + continue + match = iter_re.match(d.name) + if match: + candidate_dirs.append((int(match.group(1)), d)) + + if not candidate_dirs: + return None + + candidate_dirs.sort(key=lambda x: x[0], reverse=True) + for _, ckpt_dir in candidate_dirs: + if (ckpt_dir / "saving_completed").exists(): + return str(ckpt_dir) + return None + + +def load_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_path: str | Path, +) -> None: + """Load local state from a checkpoint. + + Loads both optimizer and state dicts into stitched module descriptors. + Modifies stitched_module_descriptors in place. + """ + device = torch.device(f"cuda:{dist.local_rank()}") + load_dir = Path(checkpoint_path) + + if not load_dir.exists(): + raise RuntimeError(f'Can\'t load local state. "{load_dir}" does not exist.') + + for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = load_dir / "stitched" / f"{stitched_module_name}.state_dict.pth" + mprint(f"Loading state dict for module {stitched_module_name} from {state_dict_path}") + loaded_state_dict = torch.load(state_dict_path, map_location=device, weights_only=True) + loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} + + stitched_module.load_state_dict(loaded_state_dict) + del loaded_state_dict + + if optimizer is not None: + optimizer_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" + ) + mprint( + f"Loading optimizer state for module {stitched_module_name} from {optimizer_state_path}" + ) + loaded_optimizer_state = torch.load( + optimizer_state_path, map_location=device, weights_only=True + ) + optimizer.load_state_dict(loaded_optimizer_state) + del loaded_optimizer_state + + +def _save_local_file(obj, save_path: Path | str, overwrite=True): + save_path = Path(save_path) + if save_path.exists(): + if not overwrite: + mprint(f'WARNING: Local save path "{save_path}" already exists. 
Skipping') + return + torch.save(obj, save_path) + + +def _save_local_state( + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + overwrite=True, +) -> None: + save_dir = Path(checkpoint_dir) / "stitched" + + if dist.is_master(): + save_dir.mkdir(parents=True, exist_ok=True) + + # Main process creates the directory, so we must wait for it to finish + dist.barrier() + + for stitched_module_name, stitched_module_descriptor in tqdm( + stitched_module_descriptors.items() + ): + optimizer = stitched_module_descriptor.optimizer + + state_dict_path = save_dir / f"{stitched_module_name}.state_dict.pth" + aprint(f"Saving state dict for module {stitched_module_name} to {state_dict_path}") + state_dict = { + **stitched_module_descriptor.owned_parameters, + **stitched_module_descriptor.owned_buffers, + } + _save_local_file(state_dict, state_dict_path, overwrite=overwrite) + + if optimizer is not None: + optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" + mprint( + f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" + ) + _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + + dist.barrier() + + +def save_bypass_checkpoint( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + model: torch.nn.Module, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + checkpoint_dir: Path | str, + reference_checkpoint_dir: Optional[Path] = None, +) -> None: + """Save a bypass distillation checkpoint.""" + checkpoint_dir = Path(checkpoint_dir) + mprint("Starting checkpoint save") + mprint(f"Saving checkpoint to {checkpoint_dir}") + + # Save stitched module states + _save_local_state( + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=checkpoint_dir, + overwrite=cfg.bypass.model.model_overrides.delete_old_checkpoints, + ) + # Save as HF checkpoint + save_checkpoint(model=model, checkpoint_dir=checkpoint_dir, descriptor=descriptor) + + if dist.is_master(): + # Create 'latest' symlink + latest_symlink = Path(cfg.bypass.experiment_dir) / "latest" + latest_symlink.unlink(missing_ok=True) + latest_symlink.symlink_to(checkpoint_dir.name) + # Save config args json + json_dump(cfg.bypass, checkpoint_dir / "args.json") + # Save completed file + completed_file = checkpoint_dir / "saving_completed" + completed_file.touch() + + dist.barrier() + mprint("Checkpoint save done") diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py new file mode 100644 index 00000000000..34140c1be10 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions for bypass distillation.""" + +from pathlib import Path + +from omegaconf import DictConfig + +import modelopt.torch.utils.distributed as dist + + +def set_experiment_id(cfg: DictConfig) -> None: + """Set the experiment ID based on the model config overrides. + + The ID encodes every override that affects the produced student so that + sweeps over (FFN size × KV heads) or (num_experts × KV heads) get distinct + directories instead of clobbering each other. + """ + if cfg.bypass.experiment_id is not None: + return + + overrides = cfg.bypass.model.model_config_overrides + parts: list[str] = [] + + if "ffn" in overrides: + ffn_override = overrides.ffn[0] + if "intermediate_size" in ffn_override and ffn_override["intermediate_size"] is not None: + parts.append(f"ffn_{ffn_override['intermediate_size']}") + elif "moe" in ffn_override and ffn_override["moe"] is not None: + parts.append(f"experts_{ffn_override['moe']['num_local_experts']}") + + if "attention" in overrides: + attn_override = overrides.attention[0] + if ( + "num_key_value_heads" in attn_override + and attn_override["num_key_value_heads"] is not None + ): + parts.append(f"heads_{attn_override['num_key_value_heads']}") + + if parts: + cfg.bypass.experiment_id = "bypass_" + "_".join(parts) + + +def set_experiment_dir(cfg: DictConfig) -> None: + """Set the experiment directory for the bypass run. + + Stores the path as a string in the OmegaConf node (OmegaConf only supports + primitive types natively). Use sites should reconstruct ``Path(...)`` as needed. + """ + experiment_dir = Path(cfg.puzzle_dir) / "bypass" / "bypass_runs" / cfg.bypass.experiment_id + cfg.bypass.experiment_dir = str(experiment_dir) + if dist.is_master(): + experiment_dir.mkdir(parents=True, exist_ok=True) + + +def get_distributed_modules_ownership(module_count: int, world_size: int) -> list[int]: + """Map module (block) indices to GPU ranks for pipeline-parallel distribution.""" + modules_process_ownership: list[int] = [] + + for i in range(world_size): + num_modules_for_process = module_count // world_size + if i < module_count % world_size: + num_modules_for_process += 1 + + modules_process_ownership.extend([i] * num_modules_for_process) + + return modules_process_ownership diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py new file mode 100644 index 00000000000..7c169e9c427 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Data classes for bypass distillation training.""" + +import dataclasses +from typing import TypeAlias + + +IterNum: TypeAlias = int +GlobalRank: TypeAlias = int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class IterStatistics: + step_num: int + token_count: int + iter_duration: float + lr: float + clipping_count: int + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class LocalTrainingStats: + iter_num: int + stitched_module_losses: dict[str, float] + + +@dataclasses.dataclass(kw_only=True, frozen=True) +class TimeToSaveSignal: + step_num: int diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py new file mode 100644 index 00000000000..815750a1919 --- /dev/null +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -0,0 +1,646 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Factory for creating stitched teacher/student models for bypass distillation.""" + +import copy +import dataclasses +import re +from argparse import Namespace +from collections import OrderedDict +from pathlib import Path +from typing import Any, Callable, Mapping, Optional, Sequence, Type + +import torch +from omegaconf import DictConfig, OmegaConf +from torch.amp.grad_scaler import GradScaler +from torch.optim import AdamW, Optimizer +from transformers import PretrainedConfig, PreTrainedModel + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor +from modelopt.torch.puzzletron.pruning.pruning_utils import ( + GQAInitMode, + LinearInitMode, + MlpInitMode, +) +from modelopt.torch.puzzletron.sewing_kit import ( + ExternalTarget, + FunctionTarget, + InputArgs, + ModuleTarget, + Needle, + RemoteTarget, + StitchedModule, + always_true_predicate, +) +from modelopt.torch.puzzletron.sewing_kit.core import InputReducer +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) +from modelopt.torch.puzzletron.tools.bypassed_training.child_init import ( + create_child_state_dict, + update_model_config, +) +from modelopt.torch.puzzletron.tools.logger import mprint +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import create_sharded_model +from modelopt.torch.puzzletron.utils.parsing import format_block_configs, parse_dtype + +StitchedModulesProcessOwnership = list[int] +SyncDistributedModelWeightsFn = Callable[[], None] +Config = Mapping[str, Any] +Args = Namespace + + +@dataclasses.dataclass +class StitchedModuleDescriptor: + stitched_module: StitchedModule + owned_parameters: dict[str, torch.nn.Parameter] + owned_buffers: dict[str, torch.Tensor] + optimizer: Optional[Optimizer] = None + grad_scaler: Optional[GradScaler] = None + + +def default_factory( + 
    teacher_model: PreTrainedModel,
+    descriptor: Type[ModelDescriptor],
+    config: Config,
+    model_blocks_process_ownership: Sequence[int],
+    student_model: Optional[PreTrainedModel] = None,
+) -> tuple[
+    PreTrainedModel,
+    StitchedModule,
+    StitchedModule,
+    StitchedModule,
+    OrderedDict[str, StitchedModuleDescriptor],
+    PretrainedConfig,
+]:
+    raise NotImplementedError()
+
+
+# `type(default_factory)` is types.FunctionType; it only needs to serve as a
+# valid `cast` target for factories resolved by name via getattr (see training_loop).
+StitchedModelFactoryFn = type(default_factory)
+
+_SUBBLOCK_KEYS_TO_LEARN = frozenset(
+    {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"}
+)
+
+
+def _set_keys_to_learn(
+    model: PreTrainedModel,
+    descriptor: Type[ModelDescriptor],
+    keys_to_learn: str | Sequence[str],
+) -> None:
+    """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``.
+
+    * A **sequence of strings** (not a bare ``str``): each string is a full parameter
+      name; gradients are enabled only where ``named_parameters()`` names match exactly.
+    * A **single string**: if it is ``"subblock_ffn"``, ``"subblock_attention"``,
+      ``"subblock_mamba"``, or ``"entire_block"``, enables gradients for the
+      corresponding descriptor weight groups; otherwise ``re.search`` is applied
+      to each parameter name.
+    """
+    # If keys_to_learn is a sequence of strings.
+    if isinstance(keys_to_learn, Sequence) and not isinstance(keys_to_learn, str):
+        param_names = set(keys_to_learn)
+    # If keys_to_learn is a single string.
+    else:
+        # If keys_to_learn is a single string that is a subblock key.
+        if keys_to_learn in _SUBBLOCK_KEYS_TO_LEARN:
+            lm_config = descriptor.get_language_model_config(model.config)
+            weight_groups = descriptor.get_weight_groups(
+                model.state_dict().keys(), lm_config.num_hidden_layers
+            )
+
+            attn_group_names = [
+                group_name
+                for group_name in weight_groups.keys()
+                if group_name.endswith("_attention")
+            ]
+            ffn_group_names = [
+                group_name
+                for group_name in weight_groups.keys()
+                if group_name.endswith("_ffn")
+            ]
+            if keys_to_learn == "subblock_attention":
+                group_names = attn_group_names
+            elif keys_to_learn == "subblock_ffn":
+                group_names = ffn_group_names
+            elif keys_to_learn == "subblock_mamba":
+                group_names = attn_group_names  # Mamba params live in _attention groups
+            else:  # entire_block
+                group_names = attn_group_names + ffn_group_names
+
+            block_configs = getattr(lm_config, "block_configs", None)
+
+            param_names = []
+            for group_name in group_names:
+                # For hybrid models (e.g. NemotronH), a single "_attention" group
+                # name can contain either Mamba SSM params *or* GQA params depending
+                # on the block. Use the block config, not the keys_to_learn string,
+                # to decide whether each block belongs to the current subblock type.
+                if block_configs is not None:
+                    m = re.match(r"block_(\d+)_attention", group_name)
+                    if m:
+                        block_idx = int(m.group(1))
+                        if block_idx < len(block_configs):
+                            is_mamba = (
+                                getattr(block_configs[block_idx].attention, "mamba", None)
+                                is not None
+                            )
+                            # subblock_attention → GQA blocks only (not Mamba)
+                            # subblock_mamba → Mamba blocks only (not GQA)
+                            # entire_block → all blocks (no filtering)
+                            if keys_to_learn == "subblock_attention" and is_mamba:
+                                continue
+                            if keys_to_learn == "subblock_mamba" and not is_mamba:
+                                continue
+                param_names.extend(weight_groups[group_name])
+            param_names = set(param_names)
+        # If keys_to_learn is a single string that is not a subblock key, treat as regex.
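+        # e.g. keys_to_learn=r"k_proj|v_proj" enables gradients on every parameter
+        # whose name contains "k_proj" or "v_proj" (re.search, not a full match).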
+        else:
+            param_names = {
+                param_name
+                for param_name, _ in model.named_parameters()
+                if re.search(keys_to_learn, param_name)
+            }
+    # In pipeline-parallel training a rank may own only blocks that don't match
+    # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention
+    # bypass has no GQA params after the _mamba rename). That is a valid state:
+    # all its blocks will produce NaN loss and be excluded from statistics.
+    if not param_names:
+        return
+
+    # Set requires_grad to True for the selected parameters.
+    for param_name, param in model.named_parameters():
+        if param_name in param_names and torch.is_floating_point(param):
+            param.requires_grad_(True)
+
+
+def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]:
+    all_non_persistent = set()
+    for module_name, submodule in module.named_modules():
+        for buffer_name in submodule._non_persistent_buffers_set:
+            full_name = f"{module_name}.{buffer_name}" if module_name else buffer_name
+            all_non_persistent.add(full_name)
+    return all_non_persistent
+
+
+def bypass_factory_fn(
+    teacher_model: PreTrainedModel,
+    descriptor: Type[ModelDescriptor],
+    cfg: DictConfig,
+    model_blocks_process_ownership: Sequence[int],
+    student_model: Optional[PreTrainedModel] = None,
+) -> tuple[
+    PreTrainedModel,
+    StitchedModule,
+    StitchedModule,
+    StitchedModule,
+    OrderedDict[str, StitchedModuleDescriptor],
+    PretrainedConfig,
+]:
+    """Unified factory function for bypass (blockwise local) distillation.
+
+    Handles all layer types through a single pipeline: FFN, attention (GQA/MHA),
+    MoE experts, Mamba, and whole blocks. Behavior is driven entirely by
+    ``model_factory`` config fields:
+
+    - ``mlp_init_mode``: how student FFN / MoE weights are initialized
+        - ``"ExpertRemoval"``: select top-N experts from teacher (MoE models)
+        - ``"Truncate"`` / ``"PruneByActivationsLog"``: prune FFN channels (dense models)
+        - ``"CopyAsIs"``: copy weights unchanged (attention-only or Mamba-only runs)
+    - ``gqa_init_mode``: how attention KV heads are initialized (optional, default
+      ``AverageKV``). Irrelevant when the student has the same number of KV heads
+      as the teacher.
+    - ``keys_to_learn``: which parameters to train. Accepts ``"subblock_ffn"``,
+      ``"subblock_attention"``, ``"subblock_mamba"``, ``"entire_block"``, a regex
+      string, or an explicit sequence of full parameter names.
+
+    The stitching logic (pipeline-parallel per-block KD) is architecture-agnostic and unchanged
+    regardless of which layer type is being distilled.
+
+    Args:
+        teacher_model: The teacher model to use for stitching.
+        descriptor: Model descriptor for layer naming and pruning mixin lookup.
+        cfg: The bypass config section.
+        model_blocks_process_ownership: Ownership mapping of model blocks to process ranks.
+        student_model: Optionally provided pre-built student model (skips initialization).
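+
+    Example (illustrative ``model_factory`` section for an FFN-pruning run;
+    field values shown are plausible defaults, not a shipped config)::
+
+        model_factory:
+          factory: bypass_factory_fn
+          mlp_init_mode: Truncate
+          gqa_init_mode: AverageKV
+          linear_init_mode: Random
+          mlp_init_config: null
+          keys_to_learn: subblock_ffn
+          block_loss_func: normalized_mse_loss
+          submodule_for_loss_calculation: null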
+ + Returns: + Tuple of (student_model, teacher_stitched, teacher_val_stitched, + student_val_stitched, stitched_module_descriptors, student_config) + """ + device = torch.device(f"cuda:{dist.local_rank()}") + model_config_overrides = cfg.model.model_config_overrides + + block_loss_func = { + "normalized_mse_loss": normalized_mse_loss, + "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, + "batched_normalized_mse_loss": batched_normalized_mse_loss, + }[cfg.model_factory.block_loss_func] + mprint(f"{block_loss_func.__name__=}") + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + # Initialize student_model + if student_model is None: + mprint("Creating student model from teacher model") + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + if isinstance(model_config_overrides, DictConfig): + config_to_override = OmegaConf.to_container(model_config_overrides, resolve=True) + else: + config_to_override = model_config_overrides + mprint(f"{config_to_override=}") + student_model_config = update_model_config( + model_config=teacher_model.config, + model_config_overrides=config_to_override, + ) + student_model_config.use_cache = False + + mprint(f"Student model config:\n {format_block_configs(student_model_config)}") + + from modelopt.torch.puzzletron.anymodel.puzzformer import deci_x_patcher + + runtime = Namespace( + device=device, + dtype=torch.bfloat16, + global_rank=dist.rank(), + world_size=dist.size(), + is_main_process=dist.is_master(), + is_last_process=dist.is_last_process(), + ) + + with deci_x_patcher( + model_descriptor=descriptor, + block_configs=getattr(student_model_config, "block_configs", None), + ): + student_model = create_sharded_model( + runtime=runtime, + descriptor=descriptor, + model_config=student_model_config, + owned_block_indexes=owned_block_indexes, + device=device, + ) + # `_init_weights` is HF's per-module initializer; apply it across the + # whole model rather than passing the model itself as a single module. + student_model.apply(student_model._init_weights) + + student_weights_dtype = parse_dtype(cfg.model.student_weights_dtype) + descriptor.init_rotary_embedding(student_model, runtime) + student_model.type(student_weights_dtype) + + mlp_init_mode = MlpInitMode(cfg.model_factory.mlp_init_mode or MlpInitMode.CopyAsIs) + + # For expert removal, use the model-specific pruning mixin so that model-specific + # key paths (e.g. backbone.layers.{i}.mixer for Nemotron-H vs model.layers.{i}.mlp + # for GPT-OSS) are handled correctly. For all other init modes the legacy inline + # key logic in create_child_state_dict is sufficient. + _mixins = [] + if mlp_init_mode == MlpInitMode.ExpertRemoval: + _expert_mixin = descriptor.pruning_mixins().get("experts_removal") + if _expert_mixin is not None: + _mixins.append(_expert_mixin) + + # If any attention layer has fewer KV heads in the student than the teacher, use the + # model-specific KV heads mixin so that k_proj/v_proj weights are correctly sliced + # rather than copied verbatim from the (larger) teacher state dict. 
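+        # Illustrative shapes: with head_dim=128, hidden_size=4096, teacher
+        # num_key_value_heads=8 and student num_key_value_heads=2, the mixin takes
+        # k_proj.weight from [8 * 128, 4096] to [2 * 128, 4096] instead of copying
+        # the teacher tensor verbatim (which would no longer fit the student module).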
+ _kv_mixin = descriptor.pruning_mixins().get("kv_heads") + if _kv_mixin is not None: + _student_kv = [ + b.attention.num_key_value_heads + for b in student_model_config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + _teacher_kv = [ + b.attention.num_key_value_heads + for b in teacher_model.config.block_configs + if b.attention is not None and b.attention.num_key_value_heads is not None + ] + if _student_kv != _teacher_kv: + _mixins.append(_kv_mixin) + + # If any FFN layer has a smaller intermediate_size in the student than the teacher, + # use the model-specific FFN-intermediate mixin. The generic create_child_state_dict + # path is hardcoded to `model.layers.{i}.mlp.*` (Llama-style), so for families that + # place FFN under a different prefix (e.g. `backbone.layers.{i}.mixer.*` for + # Nemotron-H/H_v2) the mixin is required to slice up_proj/down_proj correctly. + # Filter out no_op FFN blocks (their intermediate_size is None) — relevant for + # hybrid families where each layer is exactly one of {attention, ffn, mamba}. + _ffn_mixin = descriptor.pruning_mixins().get("ffn_intermediate") + if _ffn_mixin is not None and mlp_init_mode in ( + MlpInitMode.Truncate, + MlpInitMode.PruneByActivationsLog, + ): + _student_ffn = [ + b.ffn.intermediate_size + for b in student_model_config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + _teacher_ffn = [ + b.ffn.intermediate_size + for b in teacher_model.config.block_configs + if b.ffn is not None and b.ffn.intermediate_size is not None + ] + if _student_ffn != _teacher_ffn: + _mixins.append(_ffn_mixin) + + if len(_mixins) == 0: + pruning_mixin = None + elif len(_mixins) == 1: + pruning_mixin = _mixins[0] + else: + pruning_mixin = _mixins + + # GQA init mode is optional: only relevant when the student has fewer KV heads than + # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. + gqa_init_mode = GQAInitMode( + cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV) + ) + + student_state_dict = create_child_state_dict( + pruning_mixin=pruning_mixin, + descriptor=descriptor, + original_state_dict=teacher_model.state_dict(), + new_state_dict=student_model.state_dict(), + original_config=teacher_model.config, + new_config=student_model_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=cfg.model_factory.mlp_init_config, + owned_block_indexes=owned_block_indexes, + linear_init_mode=LinearInitMode( + cfg.model_factory.linear_init_mode or LinearInitMode.Random + ), + ) + + # Load student state dict + missing_keys, unexpected_keys = student_model.load_state_dict( + student_state_dict, strict=False + ) + assert len(unexpected_keys) == 0, f"{unexpected_keys=}" + # GQA models have learnable logit parameters not present in the teacher state dict; + # allow those to be absent and assert nothing else is missing. 
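+        # (Illustrative key, not an actual parameter name) a missing entry such as
+        # "model.layers.0.self_attn.gqa_kv_logits" matches r"gqa_\w+_logits" below
+        # and is tolerated; any other missing key trips the assert.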
+ non_gqa_missing = [k for k in missing_keys if not re.search(r"gqa_\w+_logits", k)] + assert len(non_gqa_missing) == 0, f"Unexpected missing keys: {non_gqa_missing}" + + else: + mprint("Student model provided explicitly, not using teacher model to instantiate") + student_model_config = student_model.config + + # Set up training parameters + lm_config = descriptor.get_language_model_config(student_model_config) + all_block_indices = list(range(lm_config.num_hidden_layers)) + + student_model.requires_grad_(False) + keys_to_learn = cfg.model_factory.keys_to_learn + mprint(f"Keys to learn: {keys_to_learn}") + + _set_keys_to_learn(model=student_model, descriptor=descriptor, keys_to_learn=keys_to_learn) + + dist.barrier() + mprint(f"Global rank: {dist.rank()}, {owned_block_indexes=}") + dist.barrier() + + torch.cuda.synchronize() + torch.cuda.empty_cache() + dist.barrier() + + min_owned_index = min(owned_block_indexes) + max_owned_index = max(owned_block_indexes) + prev_rank: Optional[int] = ( + None + if min_owned_index == min(all_block_indices) + else model_blocks_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index == max(all_block_indices) + else model_blocks_process_ownership[max_owned_index + 1] + ) + + teacher_parameters = set(teacher_model.parameters()) + teacher_buffers = set(teacher_model.buffers()) + + # Setup the student model's submodules for knowledge distillation training + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.device(device): + stitched_module_descriptors = OrderedDict[str, StitchedModuleDescriptor]() + submodule_for_loss_calculation = cfg.model_factory.submodule_for_loss_calculation + + teacher_target = ModuleTarget("teacher", teacher_model) + teacher_stitcher = Needle() + teacher_val_stitcher = Needle() + + student_target = ModuleTarget("student", student_model) + student_val_stitcher = Needle() + + for local_block_index, global_block_index in enumerate(sorted(owned_block_indexes)): + module_name = descriptor.layer_block_name(global_block_index) + module = student_model.get_submodule(module_name) + + submodule_name = "" + submodule_input_descriptor = submodule_name + submodule_output_descriptor = submodule_name + + if submodule_for_loss_calculation is not None: + assert hasattr(module, submodule_for_loss_calculation) + submodule_output_descriptor = submodule_for_loss_calculation + + input_descriptor = f"{module_name}.{submodule_input_descriptor}".rstrip(".") + output_descriptor = f"{module_name}.{submodule_output_descriptor}".rstrip(".") + + # Receive activations from previous rank + if global_block_index > 0 and local_block_index == 0 and prev_rank is not None: + teacher_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + teacher_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="teacher_activations", adapter=lambda x: InputArgs(x) + ), + teacher_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + ), + ) + student_val_stitcher.stitch( + RemoteTarget(peer_rank=prev_rank).value( + name="student_activations", adapter=lambda x: InputArgs(x) + ), + student_target.input( + name=module_name, + reducer=InputReducer( + lambda acc, override, orig, *args: override + orig.drop_args(0) + ), + 
                        ),
+                    )
+
+            # Send activations to next rank or register model output
+            if local_block_index + 1 == len(owned_block_indexes):
+                if next_rank is None:
+                    student_val_stitcher.stitch(
+                        student_target.output(name=""),
+                        ExternalTarget().output("model_output"),
+                    )
+                    teacher_val_stitcher.stitch(
+                        teacher_target.output(name=""),
+                        ExternalTarget().output("model_output"),
+                    )
+                else:
+                    teacher_stitcher.stitch(
+                        teacher_target.output(name=module_name),
+                        RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"),
+                    )
+                    teacher_val_stitcher.stitch(
+                        teacher_target.output(name=module_name),
+                        RemoteTarget(peer_rank=next_rank).value(name="teacher_activations"),
+                    )
+                    student_val_stitcher.stitch(
+                        student_target.output(name=module_name),
+                        RemoteTarget(peer_rank=next_rank).value(name="student_activations"),
+                    )
+
+            # Bypass training stitches
+            teacher_stitcher.stitch(
+                teacher_target.input(name=input_descriptor),
+                ExternalTarget().input(name=input_descriptor),
+            ).stitch(
+                teacher_target.output(name=output_descriptor),
+                ExternalTarget().output(name=output_descriptor),
+            )
+
+            # Create the student block stitched module
+            student_stitched_module_loss_target = FunctionTarget(
+                "module_loss_func", block_loss_func
+            )
+            student_stitched_module_name = f"block_{global_block_index}"
+            student_submodule_target = ModuleTarget("student_submodule", module)
+            student_stitched_module = (
+                Needle()
+                .stitch(
+                    ExternalTarget().input(name=input_descriptor),
+                    student_submodule_target.input(name=submodule_input_descriptor),
+                )
+                .stitch(
+                    ExternalTarget().output(
+                        name=output_descriptor,
+                        adapter=lambda v: InputArgs(target=v)
+                        if not isinstance(v, tuple)
+                        else InputArgs(target=v[0]),
+                    ),
+                    student_stitched_module_loss_target.input(),
+                )
+                .stitch(
+                    student_submodule_target.output(
+                        name=submodule_output_descriptor,
+                        adapter=lambda v: InputArgs(input=v)
+                        if not isinstance(v, tuple)
+                        else InputArgs(input=v[0]),
+                    ),
+                    student_stitched_module_loss_target.input(),
+                )
+                .stitch(
+                    student_stitched_module_loss_target.output(),
+                    ExternalTarget().output(name="loss"),
+                )
+                .knot(
+                    ignore_extra_overrides=True,
+                    capture_cache_outputs_predicate=always_true_predicate,
+                )
+            )
+
+            assert "learning_rate" in cfg.training
+            # Do NOT enable dummy params: blocks with no real trainable parameters
+            # (e.g. Mamba blocks during an attention-only bypass run) should produce
+            # NaN loss so they are excluded from statistics, identical to the
+            # optimizer=None path in the training loop.
+
+            student_module_parameters = {
+                p_name: p
+                for p_name, p in student_stitched_module.named_parameters()
+                if p not in teacher_parameters and "dummy_param" not in p_name
+            }
+            # Compute the non-persistent buffer set once up front; calling
+            # _get_all_non_persistent_buffers_set inside the comprehension below
+            # would rebuild it for every single buffer.
+            non_persistent_buffers = _get_all_non_persistent_buffers_set(
+                student_stitched_module
+            )
+            student_module_buffers = {
+                p_name: p
+                for p_name, p in student_stitched_module.named_buffers()
+                if p not in teacher_buffers and p_name not in non_persistent_buffers
+            }
+
+            trainable_params = {
+                p_name: p
+                for p_name, p in student_module_parameters.items()
+                if p.requires_grad
+            }
+
+            optimizer = (
+                AdamW(
+                    list(trainable_params.values()),
+                    lr=cfg.training.learning_rate,
+                    weight_decay=cfg.training.weight_decay,
+                    betas=(cfg.training.beta1, cfg.training.beta2),
+                    fused=True,
+                )
+                if len(trainable_params) > 0
+                else None
+            )
+
+            grad_scaler = (
+                None
+                if optimizer is None
+                else GradScaler(device=device.type, enabled=cfg.training.use_grad_scaling)
+            )
+
+            stitched_module_descriptors[student_stitched_module_name] = StitchedModuleDescriptor(
+                stitched_module=student_stitched_module,
+                owned_parameters=student_module_parameters,
+                owned_buffers=student_module_buffers,
+                optimizer=optimizer,
+                grad_scaler=grad_scaler,
+            )
+
+        teacher_stitched_module = teacher_stitcher.knot(ignore_extra_overrides=True)
+        teacher_val_stitched_module = teacher_val_stitcher.knot(ignore_extra_overrides=True)
+        student_val_stitched_module = student_val_stitcher.knot(ignore_extra_overrides=True)
+
+    return (
+        student_model,
+        teacher_stitched_module,
+        teacher_val_stitched_module,
+        student_val_stitched_module,
+        stitched_module_descriptors,
+        student_model_config,
+    )
+
+
+# Backward-compatible name aliases
+gqa_factory_fn = bypass_factory_fn
+moe_factory_fn = bypass_factory_fn
diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py
new file mode 100644
index 00000000000..b3ca788888c
--- /dev/null
+++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py
@@ -0,0 +1,978 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Bypass distillation training loop for per-block knowledge distillation.
+
+This module implements the blockwise local distillation (BLD) stage of the PUZZLE framework.
+It trains alternative transformer block configurations using per-block knowledge distillation
+from a teacher model, producing a library of "puzzle pieces" with different efficiency/performance
+trade-offs.
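+
+Execution model (summary of the code below): each rank owns a contiguous slice of
+transformer blocks; teacher activations flow rank-to-rank through sewing-kit
+RemoteTargets, and every owned student block is trained against the teacher's
+per-block outputs with a normalized-MSE loss. Blocks with no trainable parameters
+report NaN losses and are filtered out of the aggregated statistics.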
+""" + +import logging +import math +import os +import shutil +import sys +import time +import traceback +from collections import OrderedDict, defaultdict +from pathlib import Path +from statistics import mean +from typing import Optional, Type, cast + +import datasets +import torch +import torch.distributed +import transformers +from omegaconf import DictConfig +from torch.utils.data.dataloader import DataLoader +from transformers import AutoTokenizer, PreTrainedTokenizerBase, PretrainedConfig + +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule +from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor +from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config +from modelopt.torch.puzzletron.tools.logger import aprint, mprint +from modelopt.torch.puzzletron.tools.robust_json import json_load +from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model +from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses + +from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint +from .bypass_utils import get_distributed_modules_ownership, set_experiment_dir, set_experiment_id +from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal +from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership + +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module + +time_start = time.time() + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: + """Top-level entry point for bypass distillation stage. + + Runs sewing-kit pipeline-parallel per-block knowledge distillation. + + Supports multiple bypass configurations via ``bypass.configs`` list. + Each entry overrides ``bypass.model.model_config_overrides`` and optionally + ``bypass.model_factory.keys_to_learn``, then runs a full bypass training. + + If ``bypass.configs`` is absent or empty, runs a single bypass training + with the settings already in ``bypass``. + + Args: + hydra_cfg: The full Hydra configuration with a 'bypass' section. 
+ """ + configs_list = hydra_cfg.bypass.get("configs", None) + + if not configs_list: + # Single config mode — run once with whatever is in bypass already + mprint("Starting bypass distillation (single config)") + run_bypassed_training(hydra_cfg) + mprint("Bypass distillation completed") + return + + mprint(f"Starting bypass distillation sweep ({len(configs_list)} configs)") + for i, override in enumerate(configs_list): + mprint(f"Bypass config {i + 1}/{len(configs_list)}: {override}") + + # Apply overrides for this run + if "model_config_overrides" in override: + hydra_cfg.bypass.model.model_config_overrides = override.model_config_overrides + if "keys_to_learn" in override: + hydra_cfg.bypass.model_factory.keys_to_learn = override.keys_to_learn + + # Reset per-run state so each config starts fresh + hydra_cfg.bypass.experiment_id = None + hydra_cfg.bypass.iter_num = 1 + hydra_cfg.bypass.step_num = 1 + hydra_cfg.bypass.token_count = 0 + hydra_cfg.bypass.best_val_loss = 1e9 + hydra_cfg.bypass.training.clipping_count = 0 + + run_bypassed_training(hydra_cfg) + mprint(f"Bypass config {i + 1}/{len(configs_list)} completed") + + mprint("Bypass distillation sweep completed") + + +def train( + cfg: DictConfig, + descriptor: Type[ModelDescriptor], + student_model: torch.nn.Module, + student_stitched_model: StitchedModule, + teacher_stitched_model: StitchedModule, + stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], + stitched_modules_process_ownership: StitchedModulesProcessOwnership, + train_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], + student_model_config: PretrainedConfig, + skip_first_batches: int = 0, + tokenizer: Optional[PreTrainedTokenizerBase] = None, +) -> None: + """Inner training loop for bypass distillation.""" + device = torch.device(f"cuda:{dist.local_rank()}") + + dist.barrier() + + time_last_save = time_start + iter_t0 = time.time() + + resumed_iter_num = cfg.bypass.iter_num + mprint(f"resumed_iter_num: {resumed_iter_num}") + + # Number of total stitched modules + global_stitched_modules_count = len(stitched_modules_process_ownership) + # Number of stitched modules per process + num_stitched_modules_per_process = [ + sum(1 for x in stitched_modules_process_ownership if x == owner_rank) + for owner_rank in range(dist.size()) + ] + # Indices of stitched modules owned by the current process + owned_stitched_module_indices = [ + i + for i, owner in enumerate(stitched_modules_process_ownership) + if owner == dist.rank() + ] + mprint(f"{global_stitched_modules_count=}") + mprint(f"{num_stitched_modules_per_process=}") + dist.barrier() + + if dist.is_master(): + # {iter_num: {stitched_module_name: loss}} + stitched_losses_history = dict[IterNum, dict[str, float]]() + else: + stitched_losses_history = None + + # Save checkpoint before training starts + if cfg.bypass.save_checkpoint_before_training and not cfg.bypass.disable_checkpoint_save: + subdir_name = f"start-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + # Track statistics for each iteration + iter_stats_history: dict[IterNum, IterStatistics] = {} + + # Create fake input ids for the teacher model + fake_input_ids = fake_tensor( + torch.ones( + size=(cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), + dtype=torch.long, + 
device=device, + ) + ) + + # Get pipeline neighbor ranks + min_owned_index = min(owned_stitched_module_indices) + max_owned_index = max(owned_stitched_module_indices) + prev_rank: Optional[int] = ( + None + if min_owned_index - 1 < 0 + else stitched_modules_process_ownership[min_owned_index - 1] + ) + next_rank: Optional[int] = ( + None + if max_owned_index + 1 >= global_stitched_modules_count + else stitched_modules_process_ownership[max_owned_index + 1] + ) + + torch.cuda.synchronize() + + mprint(f'Grad scaling status: {"enabled" if cfg.bypass.training.use_grad_scaling else "disabled"}') + + train_iterator = iter(train_dataloader) + + mprint("Waiting for everyone before training starts") + dist.barrier() + + step_to_save = None + # Track best loss value for each block + best_losses_by_name = dict[str, float]() + best_steps_by_name = dict[str, int]() + # Buffer variables + input_ids = torch.zeros(1, 1, dtype=torch.int64) + + aprint( + f"previous rank: {str(prev_rank):<5} next rank: {str(next_rank):<5} {owned_stitched_module_indices=}" + ) + + # Train loop start + while True: + time_now = time.time() + # Check if we've reached the maximum number of steps + if cfg.bypass.step_num >= cfg.bypass.training.max_steps: + if ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and not cfg.bypass.disable_checkpoint_save + ): + mprint("Saving final checkpoint before training completion") + subdir_name = f"final-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + existing_ckpt_paths = list(Path(cfg.bypass.experiment_dir).glob("iter-*")) + for old_ckpt_path in existing_ckpt_paths: + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + break + + is_accumulating = cfg.bypass.iter_num % cfg.bypass.training.grad_accumulation_steps != 0 + # Determine and set the learning rate for this iteration + lr = ( + _get_lr(cfg, cfg.bypass.step_num) + if cfg.bypass.training.decay_lr + else cfg.bypass.training.learning_rate + ) + for stitched_module_descriptor in stitched_module_descriptors.values(): + optimizer = stitched_module_descriptor.optimizer + if optimizer is not None: + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + if dist.is_master(): + train_data = next(train_iterator) + input_ids = train_data["input_ids"] + input_ids = input_ids.to(device) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16), torch.no_grad(): + teacher_input_ids = input_ids if prev_rank is None else fake_input_ids + teacher_output = teacher_stitched_model({}, {}, teacher_input_ids) + + input_overrides = teacher_output.captured_inputs + output_overrides = teacher_output.captured_outputs + + del teacher_output + + input_overrides["teacher_inputs"] = InputArgs(fake_input_ids) + + iter_stitched_module_losses: dict[str, float] = {} + + for local_stitched_module_index, ( + stitched_module_name, + stitched_module_descriptor, + ) in enumerate(stitched_module_descriptors.items()): + stitched_module = stitched_module_descriptor.stitched_module + optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler + + if optimizer is not None: + assert grad_scaler is not None + + with torch.autocast(device_type="cuda", 
dtype=torch.bfloat16):
+                    stitched_module_output = stitched_module(
+                        input_overrides=input_overrides,
+                        output_overrides=output_overrides,
+                    )
+                    stitched_module_loss = stitched_module_output.captured_outputs["loss"]
+                    del stitched_module_output
+                    grad_scaler.scale(stitched_module_loss).backward()
+            else:
+                stitched_module_loss = torch.full(
+                    [1], fill_value=torch.nan, dtype=torch.float32
+                )
+
+            iter_stitched_module_losses[stitched_module_name] = stitched_module_loss.item()
+
+            del stitched_module_loss
+
+            if not is_accumulating:
+                if optimizer is not None:
+                    grad_clip = cfg.bypass.training.grad_clip
+                    if grad_clip is not None:
+                        # Gradients still carry the GradScaler's loss scale at this
+                        # point; unscale them first so the clip thresholds (and the
+                        # clipping counters) see true gradient magnitudes. This is a
+                        # no-op when grad scaling is disabled, and grad_scaler.step()
+                        # below skips its internal unscale once this has been called.
+                        grad_scaler.unscale_(optimizer)
+                        if cfg.bypass.training.grad_clip_type == "norm":
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                parameters=stitched_module.parameters(),
+                                max_norm=grad_clip,
+                            )
+                            if grad_norm > grad_clip:
+                                cfg.bypass.training.clipping_count += 1
+                        elif cfg.bypass.training.grad_clip_type == "value":
+                            max_abs_grad_per_param = [
+                                p.grad.abs().max().item()
+                                for p in stitched_module.parameters()
+                                if p.grad is not None
+                            ]
+                            max_abs_grad = (
+                                max(max_abs_grad_per_param)
+                                if len(max_abs_grad_per_param) > 0
+                                else 0.0
+                            )
+                            if max_abs_grad > grad_clip:
+                                cfg.bypass.training.clipping_count += 1
+                            torch.nn.utils.clip_grad_value_(
+                                parameters=stitched_module.parameters(),
+                                clip_value=grad_clip,
+                            )
+                        else:
+                            raise RuntimeError(
+                                f"Invalid grad_clip_type: {cfg.bypass.training.grad_clip_type}"
+                            )
+
+                    assert grad_scaler is not None
+                    grad_scaler.step(optimizer)
+                    grad_scaler.update()
+                    optimizer.zero_grad(set_to_none=True)
+
+        # Collect losses from all ranks using all_gather_object
+        local_training_stats = LocalTrainingStats(
+            iter_num=cfg.bypass.iter_num,
+            stitched_module_losses=iter_stitched_module_losses,
+        )
+        all_training_stats = [None] * dist.size()
+        torch.distributed.all_gather_object(all_training_stats, local_training_stats)
+
+        if dist.is_master():
+            if cfg.bypass.iter_num == resumed_iter_num:
+                mprint(f"Starting from iter {cfg.bypass.iter_num}")
+
+            # Merge all stats into the losses history
+            assert stitched_losses_history is not None
+            merged_losses: dict[str, float] = {}
+            for stats in all_training_stats:
+                if stats is not None:
+                    merged_losses.update(stats.stitched_module_losses)
+            stitched_losses_history[cfg.bypass.iter_num] = merged_losses
+
+            cfg.bypass.token_count += cfg.bypass.training.tokens_per_iter
+            iter_t1 = time.time()
+            iter_duration = iter_t1 - iter_t0
+            iter_stats_history[cfg.bypass.iter_num] = IterStatistics(
+                token_count=cfg.bypass.token_count,
+                iter_duration=iter_duration,
+                step_num=cfg.bypass.step_num,
+                lr=lr,
+                clipping_count=cfg.bypass.training.clipping_count,
+            )
+            iter_t0 = iter_t1
+
+        # Time-based save signal (broadcast from master)
+        save_signal = [step_to_save]
+        if dist.is_master():
+            if cfg.bypass.model.model_overrides.save_interval_seconds is not None:
+                time_now = time.time()
+                if time_now - time_last_save >= cfg.bypass.model.model_overrides.save_interval_seconds:
+                    mprint(
+                        f"Time to save! 
{cfg.bypass.model.model_overrides.save_interval_seconds=}, " + f"{time_last_save=}, {time_now=}" + ) + step_to_save = cfg.bypass.step_num + 5 + save_signal = [step_to_save] + time_last_save = time_now + + torch.distributed.broadcast_object_list(save_signal, src=0) + step_to_save = save_signal[0] + + # Logging + if dist.is_master(): + assert stitched_losses_history is not None + while len(stitched_losses_history) >= cfg.bypass.training.log_interval: + lowest_iter = next(iter(stitched_losses_history.keys())) + + log_chunk = { + it: losses + for it, losses in stitched_losses_history.items() + if it - lowest_iter < cfg.bypass.training.log_interval + } + if len(log_chunk) < cfg.bypass.training.log_interval: + break + + highest_iter = list(log_chunk.keys())[-1] + highest_iter_stats = iter_stats_history[highest_iter] + + losses_by_name = defaultdict[str, list[float]](lambda: []) + for losses in log_chunk.values(): + for name, loss in losses.items(): + losses_by_name[name].append(loss) + + losses_by_name_avg = { + name: mean(losses) for name, losses in losses_by_name.items() + } + + # Update best losses tracking + for name, current_loss in losses_by_name_avg.items(): + if name not in best_losses_by_name or current_loss < best_losses_by_name[name]: + best_losses_by_name[name] = current_loss + best_steps_by_name[name] = highest_iter + + chunk_iter_durations = [ + iter_stats_history[it].iter_duration for it in log_chunk.keys() + ] + avg_chunk_iter_duration = mean(chunk_iter_durations) + avg_token_speed = cfg.bypass.training.tokens_per_iter / avg_chunk_iter_duration + mprint( + # `highest_iter` is in micro-batch units (iter_num); compare against + # max_iters (= max_steps × grad_accumulation_steps), not max_steps, + # so the progress fraction is in consistent units. 
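+                # e.g. with grad_accumulation_steps=4 and max_steps=10,000, iter
+                # 4,000/40,000 and step 1,000/10,000 mark the same point in training.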
+ f"iter {highest_iter}/{cfg.bypass.training.max_iters:,}" + f" (step {highest_iter_stats.step_num}/{cfg.bypass.training.max_steps:,}):" + f" avg_iter_time={avg_chunk_iter_duration * 1000:.2f}ms" + f" avg_token_speed={avg_token_speed:,.0f}[tok/s]" + ) + mprint( + format_stitched_losses( + losses_dict=losses_by_name_avg, + best_steps_dict=best_steps_by_name, + best_values_dict=best_losses_by_name, + step_number=highest_iter, + title="Stitched Module Losses", + ) + ) + + if cfg.bypass.wandb_log: + try: + import wandb + + wandb.log( + { + "iter": highest_iter, + "step": highest_iter_stats.step_num, + "token_count": highest_iter_stats.token_count, + "token_speed": avg_token_speed, + "lr": highest_iter_stats.lr, + "grad_clipping": highest_iter_stats.clipping_count, + }, + step=highest_iter, + ) + except ImportError: + pass + + for it in log_chunk.keys(): + del iter_stats_history[it] + del stitched_losses_history[it] + + # Validation + if ( + not is_accumulating + and (cfg.bypass.step_num % cfg.bypass.training.eval_interval) == 0 + and val_dataloader is not None + ): + from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import ( + calculate_losses_pipeline, + ) + + losses, _ = calculate_losses_pipeline( + stitched_model=student_stitched_model, + dataloader=val_dataloader, + descriptor=descriptor, + ) + + val_loss = float("inf") + if losses is not None and "lm_loss" in losses: + val_loss = losses["lm_loss"]["avg"] + mprint(f"Validation loss at iter {cfg.bypass.iter_num}: {val_loss:.4f}") + + # Broadcast val_loss so all ranks agree on checkpoint decisions + val_loss_tensor = torch.tensor([val_loss], device=device) + torch.distributed.broadcast(val_loss_tensor, src=dist.size() - 1) + val_loss = val_loss_tensor.item() + + if val_loss < cfg.bypass.best_val_loss: + cfg.bypass.best_val_loss = val_loss + if not cfg.bypass.disable_checkpoint_save and cfg.bypass.save_best_ckpt: + subdir_name = f"best-iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + # Keep only the most recent best-checkpoint when delete_old_checkpoints + # is enabled — mirrors the iter-*/final-iter-* cleanup elsewhere so a + # long run with many validation improvements doesn't fill the disk. 
+ if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + for old_ckpt_path in Path(cfg.bypass.experiment_dir).glob("best-iter-*"): + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + if cfg.bypass.kill_after_first_save: + raise RuntimeError( + "Done saving checkpoint, kill_after_first_save=True" + ) + + # Checkpoint saving (step-based or time-based) + if not is_accumulating and ( + (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0 + or step_to_save == cfg.bypass.step_num + or ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps + ) + ): + if not cfg.bypass.disable_checkpoint_save: + if (cfg.bypass.step_num % cfg.bypass.model.model_overrides.save_interval) == 0: + mprint("Saving step-interval checkpoint") + elif step_to_save == cfg.bypass.step_num: + mprint("Saving time-based checkpoint") + elif ( + cfg.bypass.model.model_overrides.save_checkpoint_when_done + and cfg.bypass.step_num >= cfg.bypass.training.max_steps + ): + mprint("Saving final checkpoint") + + subdir_name = f"iter-{cfg.bypass.iter_num:06d}-ckpt" + save_bypass_checkpoint( + cfg=cfg, + descriptor=descriptor, + model=student_model, + stitched_module_descriptors=stitched_module_descriptors, + checkpoint_dir=Path(cfg.bypass.experiment_dir) / subdir_name, + reference_checkpoint_dir=cfg.teacher_dir, + ) + + if cfg.bypass.kill_after_first_save: + dist.barrier() + raise RuntimeError( + "Done saving checkpoint, kill_after_first_save=True" + ) + + if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): + existing_ckpt_paths = list( + Path(cfg.bypass.experiment_dir).glob("iter-*") + ) + for old_ckpt_path in existing_ckpt_paths: + if old_ckpt_path.name != subdir_name: + shutil.rmtree(str(old_ckpt_path)) + + cfg.bypass.iter_num += 1 + if not is_accumulating: + cfg.bypass.step_num += 1 + + mprint("Finished successfully!") + + +# Learning rate decay scheduler (cosine with warmup) +def _get_lr(cfg: DictConfig, step: int) -> float: + warmup_steps = cfg.bypass.training.warmup_steps + lr_decay_steps = cfg.bypass.training.lr_decay_steps + # Degenerate budget (e.g. tiny `training_tokens` in tests): no room for cosine decay. + # Skip warmup/decay entirely and return base LR — avoids ZeroDivisionError on + # `lr_decay_steps - warmup_steps` and `step / warmup_steps`. 
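+    # Worked example (illustrative numbers): learning_rate=1e-3, min_lr=1e-4,
+    # warmup_steps=10, lr_decay_steps=100. Step 5 -> 5e-4 (linear warmup);
+    # step 56 -> decay_ratio = (56 - 10 - 1) / 90 = 0.5, coeff = 0.5, lr = 5.5e-4;
+    # step 101 and beyond -> 1e-4 (the min_lr floor).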
+ if lr_decay_steps <= warmup_steps: + return cfg.bypass.training.learning_rate + + # 1) linear warmup for warmup_steps steps + if step <= warmup_steps: + lr = cfg.bypass.training.learning_rate * step / warmup_steps + # 2) if step > lr_decay_steps, return min learning rate + elif step > lr_decay_steps: + lr = cfg.bypass.training.min_lr + # 3) in between, use cosine decay down to min learning rate + else: + decay_ratio = (step - warmup_steps - 1) / (lr_decay_steps - warmup_steps) + assert 0 <= decay_ratio <= 1 + coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 + lr = cfg.bypass.training.min_lr + coeff * ( + cfg.bypass.training.learning_rate - cfg.bypass.training.min_lr + ) + + return lr + + +def run_bypassed_training(cfg: DictConfig): + """Setup and orchestrate bypass distillation training.""" + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.WARN + ) + + # Suppress debug messages from HuggingFace libraries + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + + device = torch.device(f"cuda:{dist.local_rank()}") + + descriptor = ModelDescriptorFactory.get(cfg.descriptor) + trust_remote_code = descriptor.requires_trust_remote_code() + teacher_model_config = load_model_config(cfg.teacher_dir, trust_remote_code=trust_remote_code) + + try: + mprint("Waiting for distributed setup...") + dist.barrier() + + if cfg.bypass.disable_initial_validate: + cfg.bypass.validate_teacher_model = False + cfg.bypass.validate_student_model = False + + if cfg.bypass.teacher_model_load_on_cpu: + assert not cfg.bypass.validate_teacher_model, ( + "Teacher model validation is too slow on CPU" + ) + + num_hidden_layers = descriptor.get_language_model_config( + teacher_model_config + ).num_hidden_layers + + model_blocks_process_ownership = get_distributed_modules_ownership( + module_count=num_hidden_layers, + world_size=dist.size(), + ) + + owned_block_indexes = set( + block_index + for block_index, owner_rank in enumerate(model_blocks_process_ownership) + if owner_rank == dist.rank() + ) + + cfg.teacher_dir = str(Path(cfg.teacher_dir).expanduser()) + teacher_model_config = load_model_config( + cfg.teacher_dir, + trust_remote_code=trust_remote_code, + ) + # Disable KV cache during bypass forward passes. Set the attribute directly rather + # than passing it as an AutoConfig override — some custom configs (GptOss, Qwen3-VL, etc.) + # don't accept it as a known kwarg and would raise via the strict unused-kwargs check. 
+ if hasattr(teacher_model_config, "use_cache"): + teacher_model_config.use_cache = False + if hasattr(teacher_model_config, "text_config") and hasattr( + teacher_model_config.text_config, "use_cache" + ): + teacher_model_config.text_config.use_cache = False + + student_model = None + if cfg.bypass.init_checkpoint_path is not None: + mprint(f"Loading student model from {cfg.bypass.init_checkpoint_path}") + student_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=cfg.bypass.init_checkpoint_path, + owned_block_indexes=owned_block_indexes, + ) + + cfg.bypass.training.min_lr = ( + cfg.bypass.training.learning_rate * cfg.bypass.training.min_lr_factor + ) + cfg.bypass.training.batch_size_per_iter = cfg.bypass.training.micro_batch_size + cfg.bypass.training.tokens_per_iter = ( + cfg.bypass.data.block_size * cfg.bypass.training.batch_size_per_iter + ) + cfg.bypass.training.max_steps = math.ceil( + cfg.bypass.training.training_tokens / cfg.bypass.training.tokens_per_iter + ) + cfg.bypass.training.max_iters = ( + cfg.bypass.training.max_steps * cfg.bypass.training.grad_accumulation_steps + ) + cfg.bypass.training.max_token_count = ( + cfg.bypass.training.max_iters * cfg.bypass.training.tokens_per_iter + ) + cfg.bypass.training.lr_decay_steps = cfg.bypass.training.max_steps + + if cfg.bypass.training.val_micro_batch_size is None: + cfg.bypass.training.val_micro_batch_size = cfg.bypass.training.micro_batch_size + + if cfg.bypass.training.warmup_steps is None: + cfg.bypass.training.warmup_steps = 0 + + mprint(f'\n{format_global_config(cfg.bypass, "Bypass Configurations")}') + mprint(f"Max token count: {cfg.bypass.training.max_token_count:,}") + + seed = cfg.bypass.seed + torch.manual_seed(seed) + + tokenizer = AutoTokenizer.from_pretrained( + cfg.teacher_dir, + trust_remote_code=True, + token=True, + ) + + assert teacher_model_config is not None + + mprint( + f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}" + ) + teacher_model = load_and_shard_model( + descriptor=descriptor, + checkpoint_path=cfg.teacher_dir, + owned_block_indexes=owned_block_indexes, + model_config=teacher_model_config, + ) + + teacher_model.requires_grad_(False) + + # Create dataloaders + from modelopt.torch.puzzletron.utils.data.dataloaders import ( + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + ) + + if cfg.bypass.data.eval_samples_per_process is not None: + max_eval_samples = cfg.bypass.data.eval_samples_per_process * dist.size() + else: + max_eval_samples = cfg.bypass.data.max_eval_samples + + load_dataset_fn = load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn + + train_dataloader = create_train_dataloader( + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset_path=cfg.dataset_path, + content_field=cfg.bypass.data.data_column, + fim_rate=cfg.bypass.data.fim_rate, + fim_spm_rate=cfg.bypass.data.fim_spm_rate, + micro_batch_size=cfg.bypass.training.micro_batch_size, + load_dataset_fn=load_dataset_fn, + keep_in_memory=cfg.bypass.data.keep_in_memory, + source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()), + bos_rate=cfg.bypass.data.bos_rate, + shuffle_seed=cfg.bypass.data.shuffle_train_data_seed, + ) + + val_dataloader = None + if not cfg.bypass.disable_validation: + val_dataloader = create_validation_dataloader( + accelerator=None, + seed=seed, + tokenizer=tokenizer, + block_size=cfg.bypass.data.block_size, + dataset=cfg.dataset_path, + 
                content_field=cfg.bypass.data.data_column,
+                fim_rate=cfg.bypass.data.fim_rate,
+                fim_spm_rate=cfg.bypass.data.fim_spm_rate,
+                micro_batch_size=cfg.bypass.training.val_micro_batch_size,
+                eval_samples=max_eval_samples,
+                load_dataset_fn=load_dataset_fn,
+                dataset_name=cfg.bypass.data.val_dataset_name,
+                keep_in_memory=cfg.bypass.data.keep_in_memory,
+                source_datasets_to_discard=cfg.bypass.get(
+                    "source_datasets_to_discard", tuple()
+                ),
+                bos_rate=cfg.bypass.data.bos_rate,
+            )
+
+        # Set ID from experiment configuration
+        set_experiment_id(cfg)
+        # Set directory for experiment ID
+        set_experiment_dir(cfg)
+
+        dist.barrier()
+
+        with torch.device(device):
+            stitched_model_factory_fn = cast(
+                stitched_model_factory_module.StitchedModelFactoryFn,
+                getattr(stitched_model_factory_module, cfg.bypass.model_factory.factory),
+            )
+            (
+                student_model,
+                teacher_stitched_model,
+                teacher_val_stitched_module,
+                student_val_stitched_model,
+                stitched_module_descriptors,
+                student_model_config,
+            ) = stitched_model_factory_fn(
+                teacher_model=teacher_model,
+                descriptor=descriptor,
+                cfg=cfg.bypass,
+                model_blocks_process_ownership=model_blocks_process_ownership,
+                student_model=student_model,
+            )
+
+        # Check whether to resume from checkpoint
+        resume_checkpoint_path = None
+        if cfg.bypass.resume_checkpoint_path is not None:
+            resume_checkpoint_path = cfg.bypass.resume_checkpoint_path
+        elif cfg.bypass.find_last_ckpt_for_resume:
+            _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir)
+            if _ckpt_dir is None:
+                mprint(
+                    "Couldn't find any run dir for resume, assuming this is the first job"
+                )
+            else:
+                mprint(
+                    f"`cfg.bypass.find_last_ckpt_for_resume` is True. "
+                    f"Auto-found a checkpoint to resume: `{_ckpt_dir}`"
+                )
+                resume_checkpoint_path = _ckpt_dir
+
+        if resume_checkpoint_path:
+            load_local_state(
+                stitched_module_descriptors=stitched_module_descriptors,
+                checkpoint_path=resume_checkpoint_path,
+            )
+
+            # Load resume ckpt bypass configs and extract resume iter_num
+            resume_cfg = DictConfig(json_load(Path(resume_checkpoint_path) / "args.json"))
+
+            # Resume stats
+            cfg.bypass.iter_num = resume_cfg.iter_num
+            cfg.bypass.token_count = resume_cfg.token_count
+            cfg.bypass.step_num = resume_cfg.step_num
+            cfg.bypass.best_val_loss = resume_cfg.best_val_loss
+            cfg.bypass.training.clipping_count = resume_cfg.training.clipping_count
+            mprint(f"Resume from iter_num: {cfg.bypass.iter_num}")
+
+            # Only copy wandb.run_id if it exists in resume config
+            if hasattr(resume_cfg, "wandb") and hasattr(resume_cfg.wandb, "run_id"):
+                cfg.bypass.wandb.run_id = resume_cfg.wandb.run_id
+
+            cfg.bypass.save_checkpoint_before_training = False
+            cfg.bypass.validate_teacher_model = False
+            cfg.bypass.validate_student_model = False
+
+            # Store as str: the auto-found path may be a Path object, and OmegaConf
+            # nodes only accept primitive types (see set_experiment_dir).
+            cfg.bypass.resume_checkpoint_path = str(resume_checkpoint_path)
+
+        # Initialize Weights and Biases
+        if cfg.bypass.wandb_log:
+            try:
+                import wandb
+
+                wandb.init(
+                    project=cfg.bypass.wandb.project,
+                    entity=cfg.bypass.wandb.entity,
+                    config=dict(cfg.bypass),
+                )
+            except ImportError:
+                mprint("wandb not installed, disabling wandb logging")
+                cfg.bypass.wandb_log = False
+        else:
+            mprint("Weights & Biases logging disabled (wandb_log=False)")
+
+        if cfg.bypass.validate_teacher_model and val_dataloader is not None:
+            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
+                calculate_losses_pipeline,
+            )
+
+            mprint("Evaluating teacher model:")
+            losses, _ = calculate_losses_pipeline(
+                stitched_model=teacher_val_stitched_module,
+                dataloader=val_dataloader,
+                descriptor=descriptor,
+            )
+            if losses is not None:
+                mprint(f"Teacher validation losses: {losses}")
+            mprint("Evaluated teacher model")
+
+        torch.cuda.empty_cache()
+        dist.barrier()
+
+        parameter_count = sum(p.numel() for p in student_model.parameters())
+        aprint(f"Model parameter count: {parameter_count:,}")
+        cfg.bypass.parameter_count = parameter_count
+
+        dist.barrier()
+        mprint("Performing dummy runs on stitched modules:")
+        torch.cuda.synchronize()
+        with torch.no_grad(), torch.autocast(
+            device_type="cuda", dtype=torch.bfloat16
+        ), torch.device(device):
+            input_ids = torch.ones(
+                (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size),
+                dtype=torch.long,
+            )
+            dummy_fake_input_ids = fake_tensor(input_ids)
+            mprint(f"Dummy runs on stitched modules with shape: {dummy_fake_input_ids.shape=}")
+            teacher_output = teacher_stitched_model({}, {}, input_ids)
+            for stitched_module_descriptor in stitched_module_descriptors.values():
+                stitched_module = stitched_module_descriptor.stitched_module
+                stitched_module(
+                    input_overrides={
+                        **teacher_output.captured_inputs,
+                        "teacher_inputs": InputArgs(dummy_fake_input_ids),
+                    },
+                    output_overrides=teacher_output.captured_outputs,
+                )
+                for name, param in stitched_module.named_parameters(recurse=True):
+                    if "iter_num" in name:
+                        param.data = torch.zeros_like(param.data)
+            del input_ids, dummy_fake_input_ids, teacher_output
+        torch.cuda.synchronize()
+        dist.barrier()
+
+        del teacher_model
+
+        if cfg.bypass.validate_student_model and val_dataloader is not None:
+            from modelopt.torch.puzzletron.utils.validate_runtime_pipeline import (
+                calculate_losses_pipeline,
+            )
+
+            mprint("Validating model before training:")
+            losses, _ = calculate_losses_pipeline(
+                stitched_model=student_val_stitched_model,
+                dataloader=val_dataloader,
+                descriptor=descriptor,
+            )
+            if losses is not None:
+                mprint(f"Student validation losses: {losses}")
+
+        dist.barrier()
+        torch.cuda.empty_cache()
+        dist.barrier()
+
+        train(
+            cfg=cfg,
+            descriptor=descriptor,
+            student_model=student_model,
+            student_stitched_model=student_val_stitched_model,
+            teacher_stitched_model=teacher_stitched_model,
+            stitched_module_descriptors=stitched_module_descriptors,
+            stitched_modules_process_ownership=model_blocks_process_ownership,
+            train_dataloader=train_dataloader,
+            val_dataloader=val_dataloader,
+            student_model_config=student_model_config,
+            skip_first_batches=cfg.bypass.training.skip_first_batches,
+            tokenizer=tokenizer,
+        )
+
+        aprint("Finished training successfully!")
+        dist.barrier()
+
+    except Exception:
+        # SystemExit and KeyboardInterrupt derive from BaseException, so they
+        # propagate past this handler untouched; any other failure is logged
+        # and converted into a non-zero exit code.
+        print(traceback.format_exc(), file=sys.stderr)
+        sys.exit(1)
+
+    dist.barrier()
+    if dist.is_master():
+        mprint("Realizing bypass checkpoints")
+        realize_bypass_checkpoints(cfg)
+
+
+def realize_bypass_checkpoints(cfg: DictConfig):
+    """Create symlinks from bypass checkpoint directories to the ckpts directory."""
+    checkpoint_dir = Path(cfg.bypass.experiment_dir) / "latest"
+    if not checkpoint_dir.exists():
+        mprint(f"Could not find checkpoint directory: {checkpoint_dir}")
+        return
+
+    ckpts_dir = Path(cfg.puzzle_dir) / "ckpts"
+    ckpts_dir.mkdir(parents=True, exist_ok=True)
+
+    symlink_name = ckpts_dir / cfg.bypass.experiment_id
+    if symlink_name.exists() or symlink_name.is_symlink():
+        symlink_name.unlink()
+
+    symlink_name.symlink_to(checkpoint_dir, target_is_directory=True)
+    mprint(f"Created symlink: {symlink_name} -> {checkpoint_dir}")
diff --git 
a/modelopt/torch/puzzletron/mip/run_puzzle.py b/modelopt/torch/puzzletron/mip/run_puzzle.py index 761534f6df9..bade8cfb15b 100644 --- a/modelopt/torch/puzzletron/mip/run_puzzle.py +++ b/modelopt/torch/puzzletron/mip/run_puzzle.py @@ -81,6 +81,7 @@ class Type(enum.Enum): "target_throughput", "target_latency", "target_time_to_first_token", + "target_num_kv_heads", "num_params", "stats.has_attention", } @@ -167,6 +168,10 @@ def to_mip_constraints(self, subblock_stats_args) -> dict[str, Any]: if "target_memory" in self.constraints: mip_constraints["stats.memory_mib"] = self.constraints["target_memory"] + # Total KV-heads constraint (sum across attention layers; used for KV-cache-only sweeps) + if "target_num_kv_heads" in self.constraints: + mip_constraints["stats.num_kv_heads"] = self.constraints["target_num_kv_heads"] + # Throughput constraints throughput_constraints = [] if "target_throughput" in self.constraints: diff --git a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py index 740d1fada3c..80f345513bb 100644 --- a/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py +++ b/modelopt/torch/puzzletron/pruning/kv_heads_pruning_mixin.py @@ -24,7 +24,7 @@ ) from .pruning_mixin import LayerDescriptor, PruningMixIn -from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights +from .pruning_utils import GQAInitMode, _init_attention_biases, _init_attention_weights, _lm_attrs __all__ = [ "KVHeadsLayerDescriptor", @@ -74,7 +74,7 @@ def prune_single_layer( f"{attn_prefix}.{proj_name}" for proj_name in self.layer_descriptor.qkvo_weight_names ] - head_size = new_config.head_dim + head_size = _lm_attrs(new_config).head_dim for part in ["weight", "bias"]: attn_keys = [f"{name}.{part}" for name in [q_name, k_name, v_name, o_name]] q_key, k_key, v_key, o_key = attn_keys diff --git a/modelopt/torch/puzzletron/pruning/pruning_utils.py b/modelopt/torch/puzzletron/pruning/pruning_utils.py index c600e119cfa..10e3f35c4f6 100644 --- a/modelopt/torch/puzzletron/pruning/pruning_utils.py +++ b/modelopt/torch/puzzletron/pruning/pruning_utils.py @@ -52,6 +52,7 @@ class MlpInitMode(Enum): PruneByActivationsLog = "PruneByActivationsLog" ExpertRemoval = "ExpertRemoval" ConcatExpertsIntoDenseFFN = "ConcatExpertsIntoDenseFFN" + MoEChannelPruning = "MoEChannelPruning" class LinearInitMode(Enum): @@ -66,6 +67,22 @@ class HiddenSizeInitMode(Enum): CopyAsIs = "CopyAsIs" +def _lm_attrs(config): + """Return the language-model sub-config for VL configs, else the config itself. + + VL configs nest language-model fields like ``num_attention_heads``, ``head_dim``, + and ``hidden_size`` under a sub-config. The attribute name varies by family — + ``text_config`` (Qwen3-VL, Llava, Idefics) and ``language_config`` (Llama-4 and + a handful of others) are both common. Probe both before falling back to the + raw config. 
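+
+ Example (synthetic stand-in configs, for illustration only)::
+
+ >>> from types import SimpleNamespace
+ >>> vl_cfg = SimpleNamespace(text_config=SimpleNamespace(head_dim=128))
+ >>> _lm_attrs(vl_cfg).head_dim
+ 128
+ >>> dense_cfg = SimpleNamespace(head_dim=64)
+ >>> _lm_attrs(dense_cfg).head_dim
+ 64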
+ """ + for attr in ("text_config", "language_config"): + sub = getattr(config, attr, None) + if sub is not None: + return sub + return config + + def resolve_pruning_mixin( pruning_mixin, descriptor: Type[ModelDescriptor] ) -> PruningMixIn | List[PruningMixIn]: @@ -224,10 +241,13 @@ def _init_attention_weights( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = _lm_attrs(new_config) + orig_lm = _lm_attrs(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads @@ -372,17 +392,27 @@ def _init_attention_biases( head_size, mlp_init_config, ): - assert new_config.num_attention_heads == original_config.num_attention_heads, ( - f"({new_config.num_attention_heads=}) != ({original_config.num_attention_heads=})" + new_lm = _lm_attrs(new_config) + orig_lm = _lm_attrs(original_config) + assert new_lm.num_attention_heads == orig_lm.num_attention_heads, ( + f"({new_lm.num_attention_heads=}) != ({orig_lm.num_attention_heads=})" ) - num_q_heads = new_config.num_attention_heads + num_q_heads = new_lm.num_attention_heads + # block_configs lives on the outer puzzletron-converted config, not on text_config. num_kv_heads = new_config.block_configs[layer_idx].attention.num_key_value_heads orig_num_kv_heads = original_config.block_configs[layer_idx].attention.num_key_value_heads n_heads_in_group = num_q_heads // num_kv_heads orig_n_heads_in_group = num_q_heads // orig_num_kv_heads - o_proj_bias = new_config.o_proj_bias - attention_bias = new_config.attention_bias + # Some HF native configs (e.g. GptOssConfig) don't expose o_proj_bias / attention_bias as + # top-level attributes the way puzzletron's DeciLM-style configs do. Fall back to probing + # the new state dict for the actual bias keys when the attribute is missing. + o_proj_bias = getattr(new_config, "o_proj_bias", None) + if o_proj_bias is None: + o_proj_bias = o_key in new_state_dict + attention_bias = getattr(new_config, "attention_bias", None) + if attention_bias is None: + attention_bias = q_key in new_state_dict # If no biases if not (o_proj_bias or attention_bias): diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py index 253674f97af..438e5f4e298 100644 --- a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py +++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py @@ -37,6 +37,7 @@ ) from modelopt.torch.opt.searcher import BaseSearcher, SearchStateDict +from . import bypass_distillation from .activation_scoring import launch_score_activations from .anymodel.converter import ConverterFactory from .anymodel.model_descriptor import ModelDescriptorFactory @@ -100,10 +101,53 @@ class PuzzletronConfig(ModeloptBaseConfig): ) +_StageName = str + +# Canonical stage order. Stages absent from a given run (e.g. "bypass" when +# bypass isn't configured) are skipped, but the rest keep their relative order. +_STAGE_ORDER: tuple[_StageName, ...] 
= ( + "start", + "convert", + "score_activations", + "prune", + "bypass", + "build_library", + "score_blocks", + "mip", + "complete", +) + + +def _total_steps(hydra_cfg) -> int: + """Return total pipeline step count: 9 with bypass, 8 without.""" + return 9 if hydra_cfg.get("bypass", None) is not None else 8 + + +def _progress_step(hydra_cfg, stage: _StageName) -> tuple[int, int]: + """Return ``(step_number, total_steps)`` for a given pipeline stage. + + Single source of truth for the user-facing ``Puzzletron Progress N/T`` strings — + keeps numbering coherent across ``main.py``, ``convert_puzzletron_model``, and + ``PuzzletronSearcher.run_search``, and shifts MIP/realize automatically when + bypass is added or removed. + """ + has_bypass = hydra_cfg.get("bypass", None) is not None + total = _total_steps(hydra_cfg) + step = 0 + for s in _STAGE_ORDER: + if s == "bypass" and not has_bypass: + continue + step += 1 + if s == stage: + return step, total + raise ValueError(f"Unknown pipeline stage: {stage!r}") + + def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> ConvertReturnType: """1. Convert the model from HF format to AnyModel format. 2. Score the pruning activations. - 3. Prune the model and save pruned checkpoints + 3. Prune the model and save pruned checkpoints. + 4. (Optional) Run bypass distillation. The output of this step will be used by mnt.search() to perform the NAS search. """ @@ -125,37 +169,101 @@ def convert_puzzletron_model(model: nn.Module, config: PuzzletronConfig) -> Conv # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Convert HuggingFace model to Puzzletron heterogeneous format (generic, uses descriptor from config) - if dist.is_master(): - mprint( - "Puzzletron Progress 2/8: converting model to Puzzletron heterogeneous format (single-gpu)" - ) - hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + has_bypass = hydra_cfg.get("bypass", None) is not None + convert_step, N = _progress_step(hydra_cfg, "convert") + score_step, _ = _progress_step(hydra_cfg, "score_activations") + prune_step, _ = _progress_step(hydra_cfg, "prune") - # Get descriptor and converter from the hydra config - descriptor_name = hydra_cfg.descriptor - descriptor = ModelDescriptorFactory.get(descriptor_name) - converter = ConverterFactory.get(descriptor_name) + # Step 2: Convert HuggingFace model to Puzzletron heterogeneous format + hf_ckpt_teacher_dir = "ckpts/teacher" # TODO: make it configurable + teacher_dir = Path(config.puzzle_dir) / hf_ckpt_teacher_dir + if dist.is_master(): + if (teacher_dir / "config.json").exists(): + mprint( + f"Puzzletron Progress {convert_step}/{N}: teacher checkpoint already exists, skipping conversion" + ) + else: + mprint( + f"Puzzletron Progress {convert_step}/{N}: converting model to Puzzletron heterogeneous format (single-gpu)" + ) - converter.convert( - descriptor=descriptor, - input_dir=Path(config.input_model_path), - output_dir=Path(config.puzzle_dir) / hf_ckpt_teacher_dir, - ) + # Get descriptor and converter from the hydra config + descriptor_name = hydra_cfg.descriptor + descriptor = ModelDescriptorFactory.get(descriptor_name) + converter = ConverterFactory.get(descriptor_name) + + # Auto-download from HuggingFace if path doesn't exist locally + input_model_path = config.input_model_path + if not Path(input_model_path).exists(): + from huggingface_hub import snapshot_download + + if input_model_path.startswith("https://huggingface.co/"): + model_id = 
"/".join(input_model_path.rstrip("/").split("/")[-2:]) + else: + model_id = input_model_path # assume HF model ID like "org/model-name" + mprint( + f"Downloading HuggingFace model '{model_id}' — this may take several minutes " + f"for large models. Other ranks are waiting at a barrier." + ) + input_model_path = snapshot_download(repo_id=model_id) + mprint(f"Downloaded to: {input_model_path}") + + converter.convert( + descriptor=descriptor, + input_dir=Path(input_model_path), + output_dir=teacher_dir, + ) dist.barrier() - # Score_pruning_activations (distributed processing) - mprint("Puzzletron Progress 3/8: scoring pruning activations (multi-gpu)") - launch_score_activations(hydra_cfg) - - # Prune the model and save pruned checkpoints - if dist.is_master(): + # Step 3: Score pruning activations (distributed processing) + activations_log_dir = Path(hydra_cfg.pruning.activations_log_dir) + if activations_log_dir.exists() and any(activations_log_dir.glob("rank_*.pth")): mprint( - "Puzzletron Progress 4/8: pruning the model and saving pruned checkpoints (single-gpu)" + f"Puzzletron Progress {score_step}/{N}: pruning activation scores already exist, skipping scoring" ) - launch_prune_ckpt(hydra_cfg) + dist.barrier() + else: + mprint(f"Puzzletron Progress {score_step}/{N}: scoring pruning activations (multi-gpu)") + launch_score_activations(hydra_cfg) + + # Step 4: Prune the model and save pruned checkpoints (single process) + pruned_ckpts_dir = Path(hydra_cfg.pruning.pruned_ckpts_output_dir) + if dist.is_master(): + if pruned_ckpts_dir.exists() and any(pruned_ckpts_dir.iterdir()): + mprint( + f"Puzzletron Progress {prune_step}/{N}: pruned checkpoints already exist, skipping pruning" + ) + else: + mprint( + f"Puzzletron Progress {prune_step}/{N}: pruning the model and saving pruned checkpoints (single-gpu)" + ) + launch_prune_ckpt(hydra_cfg) dist.barrier() + # Step 5: Bypass distillation (optional, distributed processing) + if has_bypass: + bypass_step, _ = _progress_step(hydra_cfg, "bypass") + # Skip if a previous run already produced bypass checkpoints. The realize step + # writes a `latest` symlink under each experiment_dir; if any exists, bypass has + # completed and rerunning would waste 5-15 min on teacher load + dataloader setup + # before its own resume-from-checkpoint logic short-circuits. 
+ bypass_runs_dir = Path(config.puzzle_dir) / "bypass" / "bypass_runs" + bypass_done = bypass_runs_dir.exists() and any( + (run_dir / "latest").exists() + for run_dir in bypass_runs_dir.iterdir() + if run_dir.is_dir() + ) + if bypass_done: + mprint( + f"Puzzletron Progress {bypass_step}/{N}: bypass distillation already completed, skipping" + ) + else: + mprint( + f"Puzzletron Progress {bypass_step}/{N}: running bypass distillation (multi-gpu)" + ) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + return model, {} @@ -226,18 +334,52 @@ def run_search(self) -> None: # Instantiate nested Hydra configs (e.g., pruning_mixin, hook_class) hydra_cfg = hydra.utils.instantiate(hydra_cfg) - # Build_library_and_stats (single process) + library_step, N = _progress_step(hydra_cfg, "build_library") + scoring_step, _ = _progress_step(hydra_cfg, "score_blocks") + mip_step, _ = _progress_step(hydra_cfg, "mip") + + # Build replacement library and subblock statistics (single process) + puzzle_dir = Path(self.model.puzzle_dir) + replacement_library_path = puzzle_dir / "replacement_library.json" + subblock_stats_path = puzzle_dir / hydra_cfg.calc_subblock_stats.subblock_stats_filename + # Detect a stale library: any ckpts/* entry newer than the library file means + # a new replacement (e.g. bypass-trained subblocks) appeared after the last build + # and must be picked up. Without this check, our skip-if-done would happily reuse + # a no-bypass library even after bypass completes. + ckpts_dir = puzzle_dir / "ckpts" + library_is_stale = False + if replacement_library_path.exists() and ckpts_dir.exists(): + library_mtime = replacement_library_path.stat().st_mtime + for entry in ckpts_dir.iterdir(): + # Resolve symlinks (bypass + pruning checkpoints land here as symlinks + # to the real directories elsewhere under puzzle_dir). + resolved = entry.resolve() if entry.is_symlink() else entry + if resolved.exists() and resolved.stat().st_mtime > library_mtime: + library_is_stale = True + mprint( + f"Replacement library is stale: '{entry.name}' is newer than the existing library, will rebuild." 
+ )
+ break
 if dist.is_master():
- mprint(
- "Puzzletron Progress 5/8: building replacement library and subblock statistics (single-gpu)"
- )
- launch_build_library_and_stats(hydra_cfg)
+ if (
+ replacement_library_path.exists()
+ and subblock_stats_path.exists()
+ and not library_is_stale
+ ):
+ mprint(
+ f"Puzzletron Progress {library_step}/{N}: replacement library and subblock stats already exist, skipping"
+ )
+ else:
+ mprint(
+ f"Puzzletron Progress {library_step}/{N}: building replacement library and subblock statistics (single-gpu)"
+ )
+ launch_build_library_and_stats(hydra_cfg)
 dist.barrier()
- # Calc_one_block_scores (distributed processing)
- mprint("Puzzletron Progress 6/8: calculating one block scores (multi-gpu)")
+ # Calculate one block scores (distributed processing)
+ mprint(f"Puzzletron Progress {scoring_step}/{N}: calculating one block scores (multi-gpu)")
 launch_scoring(hydra_cfg)
- # mip_and_realize_models (distributed processing)
- mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+ # MIP search and realize models (distributed processing)
+ mprint(f"Puzzletron Progress {mip_step}/{N}: running MIP and realizing models (multi-gpu)")
 launch_mip_and_realize_model(hydra_cfg)
diff --git a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
index 999ec6c690a..c2ba5bb7093 100644
--- a/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
+++ b/modelopt/torch/puzzletron/replacement_library/build_replacement_library.py
@@ -211,7 +211,21 @@ def _build_subblocks_df(
 checkpoint_dirs = _get_last_checkpoint_from_each_experiment(
 master_puzzle_dir, trust_remote_code=trust_remote_code
 )
- checkpoint_dirs = [teacher_checkpoint_dir] + list(checkpoint_dirs - {teacher_checkpoint_dir})
+
+ # Order the non-teacher checkpoints so that downstream `drop_duplicates(keep="first")`
+ # deterministically prefers bypass-trained subblocks over Truncate-init pruned ones
+ # when both produce a row with the same architectural identifier. Without this,
+ # `set` iteration order makes the choice random (hash-of-path) and we'd sometimes
+ # discard the BLD-trained weights we just paid 30+ min to compute.
+ #
+ # Priority (lowest sort key wins): 0 = bypass-trained, 1 = everything else.
+ # Bypass checkpoints land under `<puzzle_dir>/bypass/bypass_runs/<experiment_id>/`.
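+ #
+ # Illustration (hypothetical paths):
+ #   sorted([Path("x/pruning/ffn_256"), Path("x/bypass/bypass_runs/e1/latest")],
+ #          key=_checkpoint_priority)
+ # puts the bypass-trained checkpoint first.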
+ def _checkpoint_priority(p: Path) -> tuple[int, str]:
+ is_bypass = "bypass" in p.parts and "bypass_runs" in p.parts
+ return (0 if is_bypass else 1, str(p))
+
+ non_teacher_dirs = sorted(checkpoint_dirs - {teacher_checkpoint_dir}, key=_checkpoint_priority)
+ checkpoint_dirs = [teacher_checkpoint_dir] + non_teacher_dirs
 checkpoints_to_split = [teacher_checkpoint_dir]
 subblock_rows = []
diff --git a/modelopt/torch/puzzletron/sewing_kit/passage.py b/modelopt/torch/puzzletron/sewing_kit/passage.py
index d8fa1f51cf9..c77b9dd41cd 100644
--- a/modelopt/torch/puzzletron/sewing_kit/passage.py
+++ b/modelopt/torch/puzzletron/sewing_kit/passage.py
@@ -45,6 +45,7 @@
 "PassageOutput",
 "Predicate",
 "always_false_predicate",
+ "always_true_predicate",
 "Passage",
 "patch_module",
 ]
diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py
index 3db63f60013..d8df3423a57 100644
--- a/modelopt/torch/puzzletron/sewing_kit/utils.py
+++ b/modelopt/torch/puzzletron/sewing_kit/utils.py
@@ -23,6 +23,7 @@
 Callable,
 ContextManager,
 Generic,
+ Literal,
 Optional,
 Protocol,
 TypeVar,
@@ -35,6 +36,7 @@
 import torch._dynamo
 import torch.distributed
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.utils._pytree as pytree
 from torch import Tensor
 from torch._subclasses import FakeTensor, FakeTensorMode
@@ -451,3 +453,55 @@ def _get_group_kwarg_if_necessary() -> dict:
 torch.distributed.distributed_c10d._object_to_tensor
 ).parameters.keys()
 return dict(group=None) if "group" in arg_names else dict()
+
+
+# ──────────────────────────────────────────────────────────────────────────────
+# Loss functions for bypass distillation (blockwise local knowledge distillation)
+# ──────────────────────────────────────────────────────────────────────────────
+
+Reduction = Literal["none", "mean", "sum"]
+
+
+def normalized_mse_loss(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ reduction: Reduction = "mean",
+ epsilon: float = 1e-6,
+) -> torch.Tensor:
+ """MSE loss normalized by the mean squared magnitude of the target.
+
+ Dividing by the target's self-MSE (its mean square, computed against an
+ ``epsilon``-filled tensor for numerical stability) makes the loss
+ scale-invariant, so that blocks whose activations have large magnitude do
+ not dominate training.
+ """
+ loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
+ target, torch.zeros_like(target) + epsilon, reduction=reduction
+ )
+ return loss
+
+
+def vectorwise_normalized_mse_loss(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ epsilon: float = 1e-6,
+) -> torch.Tensor:
+ """Like normalized_mse_loss, but normalization is done per-vector (last dim), then averaged."""
+ return batched_normalized_mse_loss(input, target, epsilon, batch_dims=range(input.ndim - 1))
+
+
+def batched_normalized_mse_loss(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ epsilon: float = 1e-6,
+ batch_dims: Sequence[int] = (0,),
+) -> torch.Tensor:
+ """Like normalized_mse_loss, but normalization is done on non-batch dims, then averaged.
+
+ Useful when activations within a batch item should be normalized independently
+ rather than normalizing across the full batch.
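+
+ Example (shapes are illustrative)::
+
+ >>> x = torch.randn(2, 3, 8)
+ >>> batched_normalized_mse_loss(x, x, batch_dims=(0, 1)).item()
+ 0.0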
+ """ + norm_dims = list(set(range(input.ndim)) - set(batch_dims)) + norm_of_target_vectors = F.mse_loss( + target, torch.zeros_like(target) + epsilon, reduction="none" + ).mean(norm_dims) + loss = F.mse_loss(input, target, reduction="none").mean(norm_dims) / norm_of_target_vectors + return loss.mean() diff --git a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py index b242c7d48ac..93206efaf1d 100644 --- a/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py +++ b/modelopt/torch/puzzletron/tools/bypassed_training/child_init.py @@ -83,26 +83,30 @@ def _process_single_layer( keys_to_remove = {} layer_out_state_dict = {} - # Delegate to pruning_mixin if available + # Delegate to pruning_mixin if available (supports a single mixin or a list of mixins). + # When the bypass factory composes multiple mixins (e.g. experts_removal + kv_heads), + # it passes them as a list so each can contribute its slice of the layer state dict. if pruning_mixin is not None: - _layer_out = pruning_mixin.prune_single_layer( - layer_idx=layer_idx, - parent_state_dict=parent_state_dict, - new_state_dict=new_state_dict, - original_config=original_config, - new_config=new_config, - gqa_init_mode=gqa_init_mode, - mlp_init_mode=mlp_init_mode, - mlp_init_config=mlp_init_config, - linear_init_mode=linear_init_mode, - ignored_keys=ignored_keys, - keys=keys, - is_original_mha=is_original_mha, - head_size=head_size, - hidden_size=hidden_size, - keys_to_remove=keys_to_remove, - ) - layer_out_state_dict.update(_layer_out) + _mixins = pruning_mixin if isinstance(pruning_mixin, list) else [pruning_mixin] + for _mixin in _mixins: + _layer_out = _mixin.prune_single_layer( + layer_idx=layer_idx, + parent_state_dict=parent_state_dict, + new_state_dict=new_state_dict, + original_config=original_config, + new_config=new_config, + gqa_init_mode=gqa_init_mode, + mlp_init_mode=mlp_init_mode, + mlp_init_config=mlp_init_config, + linear_init_mode=linear_init_mode, + ignored_keys=ignored_keys, + keys=keys, + is_original_mha=is_original_mha, + head_size=head_size, + hidden_size=hidden_size, + keys_to_remove=keys_to_remove, + ) + layer_out_state_dict.update(_layer_out) return layer_out_state_dict, keys_to_remove # Legacy inline processing (fallback when no pruning_mixin) @@ -791,7 +795,7 @@ def update_model_config( def override(item, item_overrides): if item_overrides is None: - return item_overrides + return item # None override means "keep original value" if dataclasses.is_dataclass(item): assert isinstance(item_overrides, dict) return dataclass_override(item, item_overrides) diff --git a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py index 1240d1c9b65..ec72f1bec28 100644 --- a/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py +++ b/modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py @@ -21,7 +21,10 @@ import concurrent.futures import dataclasses import fcntl +import inspect import os +import re +import shutil import time from collections import defaultdict from collections.abc import Callable, Mapping @@ -135,16 +138,23 @@ def load_model_config( return config +_FALLBACK_WARNED_CLASSES: set[str] = set() + + def _get_model_class_from_config(config: PretrainedConfig) -> type: """Resolve HuggingFace model class from ``config.architectures`` (see puzzletron checkpoint_utils_hf).""" if hasattr(config, "architectures") and config.architectures: model_class_name = config.architectures[0] 
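+ # Prefer the concrete class exported by transformers when it exists; otherwise warn once.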
if hasattr(transformers, model_class_name): return getattr(transformers, model_class_name) - mprint( - f"Warning: {model_class_name} not found in transformers, " - "falling back to AutoModelForCausalLM" - ) + # Warn at most once per missing class per process — the fallback path + # may be hit thousands of times during scoring/realize loops. + if model_class_name not in _FALLBACK_WARNED_CLASSES: + _FALLBACK_WARNED_CLASSES.add(model_class_name) + mprint( + f"Warning: {model_class_name} not found in transformers, " + "falling back to AutoModelForCausalLM" + ) return AutoModelForCausalLM @@ -490,6 +500,44 @@ def _build_safetensors_weight_map( return weight_map +def _copy_auto_map_code_files(model_config: PretrainedConfig, checkpoint_dir: Path) -> None: + """Copy custom modeling Python files referenced in ``auto_map`` to the checkpoint dir. + + ``PretrainedConfig.save_pretrained()`` only copies the config class's own source file + (e.g. ``configuration_nemotron_h.py``). Trust-remote-code models also need ``modeling_*.py`` + (and any other auto_map-referenced ``.py``) present alongside ``config.json``, otherwise + later ``AutoConfig.from_pretrained(..., trust_remote_code=True)`` calls fail with + "does not appear to have a file named modeling_*.py". + + We discover the source directory from the config class itself (via ``inspect.getfile``) + and copy every distinct ``.py`` referenced by the auto_map values. + """ + if not hasattr(model_config, "auto_map") or not isinstance(model_config.auto_map, dict): + return + + try: + source_dir = Path(inspect.getfile(type(model_config))).parent + except (TypeError, OSError): + # Built-in / non-file-backed config class — nothing to copy. + return + + # Module names must look like Python identifiers — refuse anything with separators + # or relative-path components so a malformed/hostile auto_map can't drive shutil.copy + # outside source_dir / checkpoint_dir. + _module_name_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + module_names = {class_ref.split(".")[0] for class_ref in model_config.auto_map.values()} + + for module_name in module_names: + if not _module_name_re.match(module_name): + mprint(f"Warning: skipping non-identifier auto_map module name: {module_name!r}") + continue + filename = f"{module_name}.py" + src = source_dir / filename + dst = Path(checkpoint_dir) / filename + if src.exists() and not dst.exists(): + shutil.copy(src, dst) + + def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str) -> None: if hasattr(model_config, "block_configs"): model_config.block_configs = [ @@ -497,3 +545,4 @@ def save_model_config(model_config: PretrainedConfig, checkpoint_dir: Path | str for conf in model_config.block_configs ] model_config.save_pretrained(checkpoint_dir) + _copy_auto_map_code_files(model_config, Path(checkpoint_dir)) diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py new file mode 100644 index 00000000000..0b424dce95c --- /dev/null +++ b/modelopt/torch/puzzletron/tools/robust_json.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# mypy: ignore-errors + +""" +Provides a robust JSON encoder that can handle various types of objects, +including dataclasses, paths, enums, namespaces, and functions. +""" + +import argparse +import dataclasses +import datetime +import inspect +import json +from enum import Enum +from pathlib import Path +from typing import Any + +from omegaconf import DictConfig, ListConfig, OmegaConf + + +class RobustJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + if isinstance(o, Path): + return str(o) + if isinstance(o, Enum): + return o.name + if isinstance(o, argparse.Namespace): + return vars(o) + if type(o).__name__ == "dtype": + return str(o) + if isinstance(o, (DictConfig, ListConfig)): + return OmegaConf.to_container(o, resolve=True) + if inspect.isfunction(o) or inspect.ismethod(o): + if o.__module__ == "__main__": + # User-defined function in main — fallback to just the name + return o.__name__ + return f"{o.__module__}.{o.__qualname__}" + if inspect.isclass(o): + return f"{o.__module__}.{o.__qualname__}" + if isinstance(o, datetime.timedelta): + return str(o) + # Fallback for arbitrary objects: return their class path + if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): + return f"{o.__class__.__module__}.{o.__class__.__qualname__}" + return super().default(o) + + +def json_dumps(obj: Any) -> str: + return json.dumps(obj, cls=RobustJSONEncoder, indent=2) + + +def json_dump(obj: Any, path: Path | str) -> None: + path = Path(path) + path.parent.mkdir(exist_ok=True, parents=True) + json_text = json_dumps(obj) + path.write_text(json_text) + + +def json_load(path: Path | str) -> dict: + path = Path(path) + text = path.read_text() + return json.loads(text) diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index f4046531491..a90550b64ec 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -31,7 +31,7 @@ from ...tools.logger import mprint from .dataset import ConstantLengthDataset -__all__ = ["create_validation_dataloader", "create_padded_tensor"] +__all__ = ["create_train_dataloader", "create_validation_dataloader", "create_padded_tensor"] def collate_none_fn( @@ -73,6 +73,54 @@ def load_streaming_fn( return dataset +def create_train_dataloader( + seed: int, + tokenizer: PreTrainedTokenizerBase, + block_size: int, + dataset_path: str | Mapping[str, Dataset], + content_field: str, + fim_rate: float, + fim_spm_rate: float, + micro_batch_size: int, + load_dataset_fn: LoadDatasetFn = load_from_disk_fn, + dataset_name: str = "train", + keep_in_memory: bool = False, + shuffle_seed: int | None = None, + source_datasets_to_discard: Sequence[str] = (), + bos_rate: float = 1.0, + num_workers: int = 0, +) -> DataLoader: + """Create an infinite training DataLoader over ConstantLengthDataset.""" + if isinstance(dataset_path, str): + dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) + else: + dataset = dataset_path + + train_data = dataset[dataset_name] + 
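+ # Optionally pre-shuffle the source split before packing into constant-length sequences.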
if shuffle_seed is not None: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) + + train_dataset = ConstantLengthDataset( + tokenizer, + train_data, + infinite=True, + seq_length=block_size, + content_field=content_field, + fim_rate=fim_rate, + fim_spm_rate=fim_spm_rate, + seed=seed, + source_datasets_to_discard=source_datasets_to_discard, + bos_rate=bos_rate, + ) + + return DataLoader( + train_dataset, + batch_size=micro_batch_size, + pin_memory=True, + num_workers=num_workers, + ) + + def create_validation_dataloader( accelerator: Accelerator | None, seed: int, diff --git a/modelopt/torch/puzzletron/utils/data/dataset.py b/modelopt/torch/puzzletron/utils/data/dataset.py index f88e44a234b..511570079ab 100644 --- a/modelopt/torch/puzzletron/utils/data/dataset.py +++ b/modelopt/torch/puzzletron/utils/data/dataset.py @@ -128,7 +128,14 @@ def __iter__(self) -> dict[str, torch.Tensor]: and {"content", "role"}.issubset(sample[0]) ): if len(sample) > 1: - sample = self.tokenizer.apply_chat_template(sample, tokenize=False) + if getattr(self.tokenizer, "chat_template", None) is not None: + sample = self.tokenizer.apply_chat_template( + sample, tokenize=False + ) + else: + # Base models have no chat template — concatenate message + # contents separated by newlines as plain text. + sample = "\n".join(m["content"] for m in sample) else: sample = sample[0]["content"] else: diff --git a/modelopt/torch/puzzletron/utils/parsing.py b/modelopt/torch/puzzletron/utils/parsing.py index 149563b4321..20d6f08f977 100644 --- a/modelopt/torch/puzzletron/utils/parsing.py +++ b/modelopt/torch/puzzletron/utils/parsing.py @@ -24,6 +24,7 @@ # mypy: ignore-errors import json +import math from pathlib import Path from typing import Any @@ -338,6 +339,20 @@ def format_stitched_losses( if not losses_dict: return "❌ No losses found" + # Filter out nan entries — these are no-op blocks (e.g. Mamba) with no trainable + # parameters. The training loop sets their stitched_module_loss to NaN intentionally + # (see training_loop.py); filtering them here keeps the table focused on the + # actually-trained blocks while still surfacing real NaNs (which would only appear + # on a block that does have an optimizer and has diverged). + losses_dict = {k: v for k, v in losses_dict.items() if not math.isnan(v)} + if best_steps_dict: + best_steps_dict = {k: v for k, v in best_steps_dict.items() if k in losses_dict} + if best_values_dict: + best_values_dict = {k: v for k, v in best_values_dict.items() if k in losses_dict} + + if not losses_dict: + return "❌ No trainable blocks found" + lines = [] # Calculate statistics diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index ea0a6fd2193..e56e93cc0e1 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -16,6 +16,7 @@ import os from pathlib import Path +import pytest import torch from _test_utils.torch.transformers_models import get_tiny_tokenizer from datasets import Dataset, DatasetDict @@ -25,6 +26,42 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.export import copy_hf_ckpt_remote_code +# Shared parametrize tuple for puzzletron GPU integration tests. +# Fields: (hf_model_name, converter, hybrid_override_pattern, has_moe_layers). +# To add a new model family, append a single pytest.param row here — every test +# that imports PUZZLETRON_FAMILIES picks it up automatically. 
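+# Example of appending a (hypothetical) new family:
+#   pytest.param("my-org/My-Model-7B", "my_converter", None, False, id="my-model-7B"),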
+PUZZLETRON_FAMILIES = [
+ pytest.param("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False, id="llama-3.1-8B"),
+ pytest.param("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B"),
+ pytest.param(
+ "mistralai/Mistral-Small-24B-Instruct-2501",
+ "mistral_small",
+ None,
+ False,
+ id="mistral-small-24B",
+ ),
+ pytest.param(
+ "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16",
+ "nemotron_h",
+ "*E",
+ True,
+ id="nemotron-3-30B-A3B",
+ ),
+ pytest.param(
+ "nvidia/NVIDIA-Nemotron-Nano-12B-v2",
+ "nemotron_h_v2",
+ "*-",
+ False,
+ id="nemotron-nano-12B-v2",
+ ),
+ pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b"),
+ pytest.param("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False, id="qwen2.5-7B"),
+ pytest.param("Qwen/Qwen3-8B", "qwen3", None, False, id="qwen3-8B"),
+ pytest.param(
+ "Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True, id="qwen3-VL-30B-A3B"
+ ),
+]
+
 def setup_test_model_and_data(
 tmp_path: Path, rank: int, hf_model_name: str, hybrid_override_pattern: str | None = None
diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py
new file mode 100644
index 00000000000..3c77222c4d0
--- /dev/null
+++ b/tests/gpu/torch/puzzletron/test_bypass.py
@@ -0,0 +1,662 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""GPU integration tests for bypass distillation (blockwise local distillation).
+
+Each test is parametrized over the same model families covered by ``test_puzzletron.py``
+(see ``PUZZLETRON_FAMILIES`` in ``tests/_test_utils/torch/puzzletron/utils.py``).
+
+Tiny model dimensions used throughout (set by ``setup_test_model_and_data``):
+ - hidden_size: 256, intermediate_size: 512, num_layers: max(2, world_size)
+ - num_attention_heads: 32, num_key_value_heads: 8
+ - num_local_experts: 16 (MoE families only, e.g. Qwen3-VL)
+ - training_tokens: 128, block_size: 64, micro_batch_size: 1 -> max_steps = 2
+
+Pruning targets (used by all four tests):
+ - pruned intermediate_size: 256 (dense) — half of teacher
+ - pruned num_local_experts: 8 (MoE) — half of teacher
+ - pruned num_key_value_heads: 4 — half of teacher
+
+mlp_init_mode is family-aware:
+ - Dense families use ``Truncate`` (FFN intermediate slicing in the generic path).
+ - MoE families use ``ExpertRemoval`` and delegate per-expert weight slicing to the
+ ``experts_removal`` mixin registered on the descriptor. ``mlp_init_config`` is
+ sourced from the family's pruning YAML (``mlp_init_config_yaml``) — no
+ per-family branching needed in this test file.
+
+To add a new model family:
+ 1. Append one row to PUZZLETRON_FAMILIES in tests/_test_utils/torch/puzzletron/utils.py.
+ 2. Ensure tests/gpu/torch/puzzletron/resources/configs/<hf_model_name>/<model_name>.yaml exists
+ and that setup_test_model_and_data() can build a tiny stand-in for it.
+ 3.
For MoE families, ensure the family's descriptor registers ``"kv_heads"`` and + ``"experts_removal"`` in ``pruning_mixins()`` (see e.g. NemotronH, GPT-OSS, + Qwen3-VL descriptors). + 4. The four bypass tests below pick up the new row automatically. +""" + +import copy +from datetime import timedelta +from functools import partial +from pathlib import Path + +import hydra +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import PUZZLETRON_FAMILIES, setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# --------------------------------------------------------------------------- +# Constants — shared tiny-model dimensions and pruning targets +# --------------------------------------------------------------------------- + +SEED = 1234 + +# Teacher tiny-model dimensions (set uniformly by setup_test_model_and_data) +TEACHER_INTERMEDIATE_SIZE = 512 +TEACHER_NUM_KV_HEADS = 8 +TEACHER_NUM_LOCAL_EXPERTS = 16 + +# Pruned targets (half of teacher) +PRUNED_INTERMEDIATE_SIZE = 256 +PRUNED_NUM_KV_HEADS = 4 +PRUNED_NUM_LOCAL_EXPERTS = 8 + +# Training budget: 128 tokens / (64 block * 1 mbs) = 2 steps — completes fast +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _block_override(has_moe_layers: bool, pruned: bool = True) -> dict: + """Return a single FFN-block override entry, family-aware. + + When ``pruned=True`` the override compresses the block (halves intermediate size for + dense or halves num_local_experts for MoE). When ``pruned=False`` it pins the block + to teacher size — used by tests that exercise attention pruning while keeping the FFN + side fixed. + """ + if has_moe_layers: + n_experts = PRUNED_NUM_LOCAL_EXPERTS if pruned else TEACHER_NUM_LOCAL_EXPERTS + return {"moe": {"num_local_experts": n_experts}, "no_op": None} + intermediate = PRUNED_INTERMEDIATE_SIZE if pruned else TEACHER_INTERMEDIATE_SIZE + return {"intermediate_size": intermediate, "no_op": None} + + +def _mlp_init_settings(has_moe_layers: bool, hydra_cfg) -> tuple[str, dict]: + """Return ``(mlp_init_mode, mlp_init_config)`` for the family. + + Dense families use ``Truncate`` (FFN intermediate slicing). MoE families use + ``ExpertRemoval``, which delegates per-expert weight slicing to the + ``experts_removal`` mixin registered on the descriptor. The expert-scores + metadata (``expert_scores_key``, ``layer_prefix_template``) is read directly + from the family's pruning YAML — no per-family branching here. 
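+
+ Example (dense branch; ``hydra_cfg`` is not consulted and may be ``None``)::
+
+ >>> _mlp_init_settings(False, None)
+ ('Truncate', {'activations_log_dir': None})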
+ """ + if not has_moe_layers: + return "Truncate", {"activations_log_dir": None} + + mlp_init_config = OmegaConf.to_container( + hydra_cfg.pruning.get("mlp_init_config_yaml", OmegaConf.create({})), + resolve=True, + ) or {} + mlp_init_config["activations_log_dir"] = str(hydra_cfg.pruning.activations_log_dir) + return "ExpertRemoval", mlp_init_config + + +def _make_bypass_cfg_dict( + has_moe_layers: bool, + hydra_cfg, + *, + include_block_override: bool = True, + block_pruned: bool = True, + include_attention_override: bool = True, + attention_pruned: bool = True, + configs_list: list | None = None, +) -> dict: + """Return a plain-dict bypass config suitable for OmegaConf.update injection. + + Args: + has_moe_layers: Whether the model family is MoE (dispatches FFN override shape + and the mlp_init_mode). + hydra_cfg: The post-pruning hydra config — used to source the family's + ``mlp_init_config_yaml`` and ``activations_log_dir`` for MoE expert removal. + include_block_override / block_pruned: Whether to override the per-block FFN + sub-component, and whether to prune (vs. pin to teacher). + include_attention_override / attention_pruned: Same for the attention sub-component. + configs_list: If provided, populates bypass.configs for a multi-config sweep. + """ + overrides: dict = {} + if include_block_override: + overrides["ffn"] = [_block_override(has_moe_layers, pruned=block_pruned)] + if include_attention_override: + kv = PRUNED_NUM_KV_HEADS if attention_pruned else TEACHER_NUM_KV_HEADS + overrides["attention"] = [{"num_key_value_heads": kv, "no_op": None}] + + mlp_init_mode, mlp_init_config = _mlp_init_settings(has_moe_layers, hydra_cfg) + + cfg = { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + # The dummy test dataset stores conversations under the "conversation" column. + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + # Large eval_interval so validation is skipped during this short run. + # Validation is fully disabled anyway (disable_validation=True below). + "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": False, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + # Effectively disable step-interval saving; rely on save_checkpoint_when_done. 
+ "save_interval": 1_000_000_000,
+ "save_checkpoint_when_done": True,
+ },
+ "model_config_overrides": overrides,
+ },
+ "model_factory": {
+ "factory": "bypass_factory_fn",
+ "block_loss_func": "normalized_mse_loss",
+ "gqa_init_mode": "AverageKV",
+ "mlp_init_mode": mlp_init_mode,
+ "mlp_init_config": mlp_init_config,
+ "linear_init_mode": "FromTeacher",
+ "submodule_for_loss_calculation": None,
+ "keys_to_learn": "entire_block",
+ },
+ # Disable all validation to keep tests fast.
+ "disable_initial_validate": True,
+ "validate_teacher_model": False,
+ "validate_student_model": False,
+ "disable_validation": True,
+ "best_val_loss": 1e9,
+ "compile": False,
+ "disable_fa2": False,
+ "teacher_model_load_on_cpu": False,
+ "save_checkpoint_before_training": False,
+ "disable_checkpoint_save": False,
+ "save_best_ckpt": True,
+ # Do NOT use kill_after_first_save — it raises RuntimeError which becomes sys.exit(1).
+ # Instead let the short training run (2 steps) complete naturally.
+ "kill_after_first_save": False,
+ "realize_best_or_latest": "best",
+ "wandb_log": False,
+ "wandb": {"project": None, "entity": None},
+ }
+
+ if configs_list is not None:
+ cfg["configs"] = configs_list
+
+ return cfg
+
+
+def _expected_experiment_id(bypass_cfg_dict: dict) -> str:
+ """Compute the experiment_id that ``set_experiment_id`` will assign.
+
+ Avoids duplicating the formula in tests — uses the same function the runtime uses.
+ """
+ cfg = OmegaConf.create({"bypass": copy.deepcopy(bypass_cfg_dict)})
+ set_experiment_id(cfg)
+ return cfg.bypass.experiment_id
+
+
+def _setup_hydra_cfg_and_pruning(
+ project_root_path: Path,
+ tmp_path: Path,
+ rank: int,
+ size: int,
+ hf_model_name: str,
+ converter: str,
+ hybrid_override_pattern: str | None,
+) -> tuple:
+ """Set up the tiny model, convert it, score activations, and create pruning ckpts.
+
+ Returns ``(puzzle_dir, dataset_path, hydra_cfg)``.
+
+ Steps performed:
+ 1. Create a small HF model and dummy dataset via ``setup_test_model_and_data``.
+ 2. Convert the HF checkpoint to AnyModel/DeciLM format (rank 0 only).
+ 3. Load the per-family Hydra config with ``puzzle_dir`` and ``dataset_path`` overrides.
+ 4. Run ``score_pruning_activations`` (distributed).
+ 5. Run ``pruning_ckpts`` (rank 0 only) then barrier.
+ """
+ set_seed(SEED)
+ # Use an explicit 10-minute NCCL timeout; a bare `timedelta(10)` would mean 10 *days*.
+ dist.setup(timeout=timedelta(minutes=10))
+
+ puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data(
+ tmp_path, rank, hf_model_name, hybrid_override_pattern
+ )
+
+ hydra_config_dir = str(
+ project_root_path / "tests/gpu/torch/puzzletron/resources/configs"
+ )
+ # Per-family hydra config name follows the layout configs/<hf_model_name>/<model_name>.yaml.
+ hydra_config_name = f"{hf_model_name}/{Path(hf_model_name).name}"
+
+ # Step 0: Convert HF checkpoint to AnyModel/DeciLM format.
+ if rank == 0:
+ convert_model(
+ input_dir=str(hf_checkpoint_path),
+ output_dir=str(puzzle_dir / "ckpts/teacher"),
+ converter=converter,
+ )
+ dist.barrier()
+
+ # Step 1: Load Hydra config.
+ hydra_cfg = initialize_hydra_config_for_dir(
+ config_dir=hydra_config_dir,
+ config_name=hydra_config_name,
+ overrides=[
+ f"puzzle_dir={puzzle_dir}",
+ f"dataset_path={dataset_path}",
+ ],
+ )
+ hydra_cfg = hydra.utils.instantiate(hydra_cfg)
+
+ # Step 2: Score pruning activations (distributed).
+ score_pruning_activations.launch_score_activations(hydra_cfg)
+
+ # Step 3: Create pruning checkpoints (rank 0 only).
+ if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + return puzzle_dir, dataset_path, hydra_cfg + + +# --------------------------------------------------------------------------- +# Tests — each parametrized over PUZZLETRON_FAMILIES +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_block_pruning( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation with the per-block sub-component pruned. + + For dense families, prunes FFN intermediate (512 -> 256). For MoE families, + prunes num_local_experts (16 -> 8). KV heads are also halved (8 -> 4). + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_block_pruning_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_block_pruning_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_block_pruning[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_kv_head_compression( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation with KV heads halved (8 -> 4) and FFN block pinned to teacher. + + For dense, the experiment_id will be ``bypass_ffn_512_heads_4`` (FFN at teacher size, + attention halved). For MoE, ``bypass_experts_16_heads_4``. 
+ """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_kv_head_compression_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_kv_head_compression_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict( + has_moe_layers, + hydra_cfg, + block_pruned=False, # keep FFN/experts at teacher + attention_pruned=True, # halve KV heads + ) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_kv_head_compression[{hf_model_name}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_multi_config_sequential( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Bypass distillation sweep: two configs run sequentially via bypass.configs list. + + Config 0: block pruned + attention pruned + Config 1: block at teacher + attention pruned + Both checkpoint symlinks must exist after the sweep completes. 
+ """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_multi_config_sequential_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_multi_config_sequential_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + configs_list = [ + { + "model_config_overrides": { + "ffn": [_block_override(has_moe_layers, pruned=True)], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + { + "model_config_overrides": { + "ffn": [_block_override(has_moe_layers, pruned=False)], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + "keys_to_learn": "entire_block", + }, + ] + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg, configs_list=configs_list) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + # Compute expected IDs by running set_experiment_id against each sub-config. + expected_ids = [] + for sub in configs_list: + sub_cfg = copy.deepcopy(bypass_cfg_dict) + sub_cfg["model"]["model_config_overrides"] = sub["model_config_overrides"] + sub_cfg["experiment_id"] = None + expected_ids.append(_expected_experiment_id(sub_cfg)) + + for experiment_id in expected_ids: + experiment_dir = puzzle_dir / "bypass/bypass_runs" / experiment_id + ckpt_symlink = puzzle_dir / "ckpts" / experiment_id + + assert experiment_dir.exists(), ( + f"Expected bypass experiment directory to exist: {experiment_dir}" + ) + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink to exist: {ckpt_symlink}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_multi_config_sequential[{hf_model_name}] completed. 
" + f"Puzzle directory: {puzzle_dir}" + ) + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + PUZZLETRON_FAMILIES, +) +def test_bypass_checkpoint_contents( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Verify that a bypass checkpoint contains expected HuggingFace model files.""" + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_checkpoint_contents_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_checkpoint_contents_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + ckpt_symlink = puzzle_dir / "ckpts" / expected_experiment_id + + assert ckpt_symlink.exists() or ckpt_symlink.is_symlink(), ( + f"Expected bypass checkpoint symlink: {ckpt_symlink}" + ) + + # The symlink resolves to the latest checkpoint dir; verify HF config exists. + resolved = ckpt_symlink.resolve() + config_json = resolved / "config.json" + assert config_json.exists(), ( + f"Expected HuggingFace config.json inside checkpoint: {config_json}" + ) + + # The saving_completed marker must be present (set by save_bypass_checkpoint). + saving_completed = resolved / "saving_completed" + assert saving_completed.exists(), ( + f"Expected saving_completed marker inside checkpoint: {saving_completed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_checkpoint_contents[{hf_model_name}] completed. 
" + f"Puzzle directory: {puzzle_dir}" + ) diff --git a/tests/gpu/torch/puzzletron/test_puzzletron.py b/tests/gpu/torch/puzzletron/test_puzzletron.py index d44cbc71e9c..a00b81fd642 100644 --- a/tests/gpu/torch/puzzletron/test_puzzletron.py +++ b/tests/gpu/torch/puzzletron/test_puzzletron.py @@ -22,7 +22,7 @@ import transformers from _test_utils.torch.distributed.utils import spawn_multiprocess_job from _test_utils.torch.misc import set_seed -from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from _test_utils.torch.puzzletron.utils import PUZZLETRON_FAMILIES, setup_test_model_and_data from packaging.version import Version import modelopt.torch.puzzletron as mtpz @@ -39,17 +39,7 @@ @pytest.mark.parametrize( ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), - [ - ("meta-llama/Llama-3.1-8B-Instruct", "llama", None, False), - ("meta-llama/Llama-3.2-3B-Instruct", "llama", None, False), - ("mistralai/Mistral-Small-24B-Instruct-2501", "mistral_small", None, False), - ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16", "nemotron_h", "*E", True), - ("nvidia/NVIDIA-Nemotron-Nano-12B-v2", "nemotron_h_v2", "*-", False), - ("openai/gpt-oss-20b", "gpt_oss", None, True), - ("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False), - ("Qwen/Qwen3-8B", "qwen3", None, False), - ("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True), - ], + PUZZLETRON_FAMILIES, ) def test_puzzletron( project_root_path: Path, diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py new file mode 100644 index 00000000000..759fb5fa34a --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for normalized MSE loss functions in sewing_kit/utils.py.""" + +import pytest +import torch + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + batched_normalized_mse_loss, + normalized_mse_loss, + vectorwise_normalized_mse_loss, +) + + +# --------------------------------------------------------------------------- +# normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_normalized_mse_loss_identical_tensors(): + """Identical input and target should produce a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 8) + loss = normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +def test_normalized_mse_loss_basic(): + """Loss should be positive and finite for random, non-identical tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target) + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_normalized_mse_loss_reduction_none(): + """With reduction='none' the output shape should match the input shape.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="none") + assert loss.shape == input_.shape + + +def test_normalized_mse_loss_reduction_sum(): + """With reduction='sum' the output should be a scalar tensor.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = normalized_mse_loss(input_, target, reduction="sum") + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +# --------------------------------------------------------------------------- +# vectorwise_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_vectorwise_normalized_mse_loss_shape(): + """vectorwise_normalized_mse_loss should return a scalar for any 2-D input.""" + torch.manual_seed(42) + input_ = torch.randn(4, 16) + target = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + + +def test_vectorwise_normalized_mse_loss_identical(): + """Identical input and target should give a loss of approximately 0.""" + torch.manual_seed(42) + x = torch.randn(4, 16) + loss = vectorwise_normalized_mse_loss(x, x) + assert torch.allclose(loss, torch.zeros_like(loss), atol=1e-6) + + +# --------------------------------------------------------------------------- +# batched_normalized_mse_loss +# --------------------------------------------------------------------------- + + +def test_batched_normalized_mse_loss_basic(): + """Should return a scalar with a positive, finite value for random tensors.""" + torch.manual_seed(42) + input_ = torch.randn(4, 8) + target = torch.randn(4, 8) + loss = batched_normalized_mse_loss(input_, target) + assert loss.ndim == 0 # scalar + assert loss.item() > 0.0 + assert torch.isfinite(loss) + + +def test_batched_normalized_mse_loss_custom_dims(): + """Custom batch_dims=(0, 1) on a 3-D tensor should still return a scalar.""" + torch.manual_seed(42) + input_ = torch.randn(2, 3, 8) + target = torch.randn(2, 3, 8) + loss = batched_normalized_mse_loss(input_, target, batch_dims=(0, 1)) + assert loss.ndim == 0 # scalar + assert torch.isfinite(loss) + assert loss.item() > 0.0 diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py new file mode 
100644 index 00000000000..8e4551e24e3 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for get_distributed_modules_ownership in bypass_utils.py.""" + +import pytest + +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import ( + get_distributed_modules_ownership, +) + + +def test_single_gpu_all_to_rank_0(): + """With world_size=1, all 4 modules should be assigned to rank 0.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=1) + assert ownership == [0, 0, 0, 0] + + +def test_even_distribution(): + """With world_size=2 and 4 modules, each rank should own exactly 2 modules.""" + ownership = get_distributed_modules_ownership(module_count=4, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 2 + assert len(ownership) == 4 + + +def test_uneven_distribution(): + """With world_size=2 and 3 modules, rank 0 should own 2 and rank 1 should own 1.""" + ownership = get_distributed_modules_ownership(module_count=3, world_size=2) + assert ownership.count(0) == 2 + assert ownership.count(1) == 1 + assert len(ownership) == 3 + + +@pytest.mark.parametrize( + ("module_count", "world_size"), + [ + (1, 1), + (4, 1), + (4, 2), + (4, 4), + (7, 3), + (10, 4), + (1, 2), + ], +) +def test_total_equals_module_count(module_count, world_size): + """The length of the ownership list must always equal module_count.""" + ownership = get_distributed_modules_ownership( + module_count=module_count, world_size=world_size + ) + assert len(ownership) == module_count + + +def test_consecutive_ownership(): + """Each rank should own a contiguous block of indices (no interleaving).""" + ownership = get_distributed_modules_ownership(module_count=7, world_size=3) + # Verify that once we see a new rank, we never see the previous rank again. + seen_ranks = set() + prev_rank = ownership[0] + seen_ranks.add(prev_rank) + for rank in ownership[1:]: + if rank != prev_rank: + assert rank not in seen_ranks, ( + f"Rank {rank} appears non-consecutively in ownership list: {ownership}" + ) + seen_ranks.add(rank) + prev_rank = rank + + +def test_single_module(): + """With world_size=2 and only 1 module, rank 0 should be the sole owner.""" + ownership = get_distributed_modules_ownership(module_count=1, world_size=2) + assert ownership == [0] + assert len(ownership) == 1 From 5098ee5a70f5cc1df1fee336217063363a0ac980 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 04:06:50 -0700 Subject: [PATCH 02/13] Fix mypy attr-defined errors from top-level datasets imports MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit HF datasets uses lazy __getattr__ at the package level (PEP 562), so mypy can't resolve top-level names — `from datasets import DatasetDict` fails with attr-defined. 
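For context, a minimal sketch of the lazy-loading pattern that trips mypy (illustrative module body only, not the actual `datasets/__init__.py` source):

```python
# Illustrative PEP 562 lazy package __init__. `pkg.DatasetDict` resolves
# fine at runtime, but no static binding for `DatasetDict` exists in the
# module body, so mypy reports:
#   error: Module "pkg" has no attribute "DatasetDict"  [attr-defined]
import importlib
from typing import Any

_LAZY_SUBMODULES = {"DatasetDict": ".dataset_dict", "load_dataset": ".load"}


def __getattr__(name: str) -> Any:  # module-level __getattr__ (PEP 562)
    if name in _LAZY_SUBMODULES:
        submodule = importlib.import_module(_LAZY_SUBMODULES[name], __name__)
        return getattr(submodule, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```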
Switch to submodule imports (e.g. `from datasets.dataset_dict import DatasetDict`, `from datasets.load import load_dataset`) which bypass the lazy loader. Also folds in pre-commit cleanups across the bypass changeset: - ruff E501/N806/PT006 lint fixes (uppercase N in main.py, line length in PUZZLETRON_FAMILIES + main.py MIP-sweep mprint, parametrize tuple shape in test_bypass_utils.py). - markdownlint MD040 (fenced-code language tag in tutorial md). - ruff format auto-applied (PUZZLETRON_FAMILIES table, descriptor imports, etc.). - yamlfmt auto-applied to bypass YAML configs. - Drop dead StitchedModelFactoryFn type alias and its cast() call. - Annotate `descriptor` as ModelDescriptor (not type[ModelDescriptor]) to match the codebase convention used in init_child_from_parent.py (annotated as instance, called with the class — no-op at runtime, silences mypy). Signed-off-by: Sepehr Sameni --- .../Nemotron-3-Nano-30B-A3B-Base-BF16.md | 4 +- .../bypass/defaults.yaml | 16 +- .../bypass/defaults.yaml | 14 +- examples/puzzletron/main.py | 3 +- .../gpt_oss/gpt_oss_model_descriptor.py | 5 +- .../nemotron_h/nemotron_h_model_descriptor.py | 5 +- .../nemotron_h_v2_model_descriptor.py | 5 +- .../bypass_distillation/__init__.py | 2 + .../bypass_checkpoint_utils.py | 4 +- .../bypass_distillation/data_classes.py | 1 - .../stitched_model_factory.py | 161 +++++++----------- .../bypass_distillation/training_loop.py | 97 +++++------ .../puzzletron/dataset/prepare_dataset.py | 14 +- .../puzzletron/utils/data/dataloaders.py | 21 ++- modelopt/torch/utils/dataset_utils.py | 9 +- modelopt/torch/utils/vlm_dataset_utils.py | 8 +- tests/_test_utils/torch/puzzletron/utils.py | 6 +- tests/gpu/torch/puzzletron/test_bypass.py | 55 ++++-- .../torch/puzzletron/test_bypass_losses.py | 2 - .../torch/puzzletron/test_bypass_utils.py | 4 +- 20 files changed, 213 insertions(+), 223 deletions(-) diff --git a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md index 3f48460cb2d..f085f7165fc 100644 --- a/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md +++ b/examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md @@ -19,9 +19,11 @@ Both runs use the same MIP solver and the same constraint (`target_num_kv_heads: - Container: `nvcr.io/nvidia/nemo:26.04` or later. - `pip install -e ".[dev]"` from the modelopt repo root. - Mamba kernels (required by Nemotron-3-Nano's hybrid backbone): + ```bash pip install mamba-ssm[causal-conv1d] --no-build-isolation ``` + - HF auth set up so the model is downloadable: `huggingface-cli login`. ## Step A — pipeline without bypass @@ -37,7 +39,7 @@ This runs the 8-step puzzletron pipeline (convert → score pruning activations When the realize-model step finishes, the log lines at `${puzzle_dir}/log.txt` contain: -``` +```text validate_model_with_kl_div(model_name='teacher', ...) Average losses = {'lm_loss': ..., 'token_accuracy_top_1': ..., 'token_accuracy_top_5': ..., 'token_accuracy_top_10': ...} ... 
diff --git a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml index 7a0be378949..0545c6700d2 100644 --- a/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml +++ b/examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml @@ -27,7 +27,7 @@ data: keep_in_memory: false val_dataset_name: valid max_eval_samples: 4 - eval_samples_per_process: null # Samples per GPU during distributed eval (auto if null) + eval_samples_per_process: # Samples per GPU during distributed eval (auto if null) shuffle_train_data_seed: ${random_int:0,9999} # Seed for shuffling train data # Training Configuration @@ -53,10 +53,10 @@ training: eval_interval: 5 # Model Loading Configuration -resume_checkpoint_path: null # Path to resume training from checkpoint -find_last_ckpt_for_resume: True # Auto-resume by finding last checkpoint (bool) -parameter_count: null -init_checkpoint_path: null # Path to initialize weights from +resume_checkpoint_path: # Path to resume training from checkpoint +find_last_ckpt_for_resume: true # Auto-resume by finding last checkpoint (bool) +parameter_count: +init_checkpoint_path: # Path to initialize weights from model: student_weights_dtype: "bf16" # Student model weight precision @@ -83,10 +83,10 @@ model_factory: gqa_init_mode: AverageKV # How to initialize K/V heads in GQA. All options here: GQAInitMode mlp_init_mode: Truncate # MLP initialization. All options here: MlpInitMode mlp_init_config: # Configuration for MLP initialization (if needed) - activations_log_dir: null # Directory with activation statistics (required for PruneByActivationsLog) + activations_log_dir: # Directory with activation statistics (required for PruneByActivationsLog) linear_init_mode: FromTeacher # How to initialize linear layers: FromTeacher, Random, etc. - submodule_for_loss_calculation: null # Specific submodule for loss calc. - keys_to_learn: null # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically. + submodule_for_loss_calculation: # Specific submodule for loss calc. + keys_to_learn: # What parameters to train. Either "entire_block", or specific submodules. Computed dynamically. 
# Validation Configuration disable_initial_validate: false diff --git a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml index f9f744d31ce..d855dbf6244 100644 --- a/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml +++ b/examples/puzzletron/configs/nemotron-3-nano-30b-a3b/bypass/defaults.yaml @@ -32,7 +32,7 @@ data: keep_in_memory: false val_dataset_name: valid max_eval_samples: 4 - eval_samples_per_process: null + eval_samples_per_process: shuffle_train_data_seed: ${random_int:0,9999} # Training Configuration @@ -58,10 +58,10 @@ training: eval_interval: 100 # Model Loading Configuration -resume_checkpoint_path: null -find_last_ckpt_for_resume: True -parameter_count: null -init_checkpoint_path: null +resume_checkpoint_path: +find_last_ckpt_for_resume: true +parameter_count: +init_checkpoint_path: model: student_weights_dtype: "bf16" @@ -86,9 +86,9 @@ model_factory: gqa_init_mode: AverageKV mlp_init_mode: Truncate # FFN is frozen; this knob is dormant for KV-only tasks mlp_init_config: - activations_log_dir: null + activations_log_dir: linear_init_mode: FromTeacher - submodule_for_loss_calculation: null + submodule_for_loss_calculation: keys_to_learn: subblock_attention # train ONLY the attention sub-block # Validation Configuration diff --git a/examples/puzzletron/main.py b/examples/puzzletron/main.py index 990609da4ec..f0ff54e8be7 100644 --- a/examples/puzzletron/main.py +++ b/examples/puzzletron/main.py @@ -152,7 +152,8 @@ def run_mip_only(hydra_config_path: str): # Check if sweep mode is enabled if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False): mtpz.tools.mprint( - f"Puzzletron Progress {mip_step}/{total_steps}: running MIP sweep for multiple compression rates (multi-gpu)" + f"Puzzletron Progress {mip_step}/{total_steps}:" + " running MIP sweep for multiple compression rates (multi-gpu)" ) mtpz.mip.run_mip_sweep(hydra_cfg) else: diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index 342766c949c..eb2cfe68688 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -28,10 +28,7 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) -from ....pruning.kv_heads_pruning_mixin import ( - KVHeadsLayerDescriptor, - KVHeadsPruningMixIn, -) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn # Expert removal is supported for unquantized models (test models). 
# Production models use MXFP4 quantized MoE with combined tensors diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py index 52667b91f70..b3f33887367 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py @@ -29,10 +29,7 @@ ExpertRemovalLayerDescriptor, ExpertRemovalPruningMixIn, ) -from ....pruning.kv_heads_pruning_mixin import ( - KVHeadsLayerDescriptor, - KVHeadsPruningMixIn, -) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same diff --git a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py index aefe0919e9d..0c677f67542 100644 --- a/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/nemotron_h_v2/nemotron_h_v2_model_descriptor.py @@ -29,10 +29,7 @@ FFNIntermediateLayerDescriptor, FFNIntermediatePruningMixIn, ) -from ....pruning.kv_heads_pruning_mixin import ( - KVHeadsLayerDescriptor, - KVHeadsPruningMixIn, -) +from ....pruning.kv_heads_pruning_mixin import KVHeadsLayerDescriptor, KVHeadsPruningMixIn from ....pruning.pruning_mixin import PruningMixIn from ...model_descriptor import ModelDescriptor, ModelDescriptorFactory from ...puzzformer.no_op import MatchingZeros, Same diff --git a/modelopt/torch/puzzletron/bypass_distillation/__init__.py b/modelopt/torch/puzzletron/bypass_distillation/__init__.py index 790166b4519..119cbd5cdaf 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/__init__.py +++ b/modelopt/torch/puzzletron/bypass_distillation/__init__.py @@ -20,3 +20,5 @@ """ from .training_loop import launch_bypass_distillation + +__all__ = ["launch_bypass_distillation"] diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py index d1d95939282..673964b59b0 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -18,7 +18,7 @@ import re from collections import OrderedDict from pathlib import Path -from typing import Optional, Type, Union +from typing import Optional, Union import torch from omegaconf import DictConfig @@ -157,7 +157,7 @@ def _save_local_state( def save_bypass_checkpoint( cfg: DictConfig, - descriptor: Type[ModelDescriptor], + descriptor: ModelDescriptor, model: torch.nn.Module, stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], checkpoint_dir: Path | str, diff --git a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py index 7c169e9c427..a6b37099ceb 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/data_classes.py +++ b/modelopt/torch/puzzletron/bypass_distillation/data_classes.py @@ -18,7 +18,6 @@ import dataclasses from typing import TypeAlias - IterNum: TypeAlias = int GlobalRank: TypeAlias = int diff --git 
a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py index 815750a1919..5d90b572f2b 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -21,7 +21,7 @@ from argparse import Namespace from collections import OrderedDict from pathlib import Path -from typing import Any, Callable, Mapping, Optional, Sequence, Type +from typing import Any, Callable, Mapping, Optional, Sequence import torch from omegaconf import DictConfig, OmegaConf @@ -31,11 +31,7 @@ import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor -from modelopt.torch.puzzletron.pruning.pruning_utils import ( - GQAInitMode, - LinearInitMode, - MlpInitMode, -) +from modelopt.torch.puzzletron.pruning.pruning_utils import GQAInitMode, LinearInitMode, MlpInitMode from modelopt.torch.puzzletron.sewing_kit import ( ExternalTarget, FunctionTarget, @@ -75,31 +71,14 @@ class StitchedModuleDescriptor: grad_scaler: Optional[GradScaler] = None -def default_factory( - teacher_model: PreTrainedModel, - descriptor: Type[ModelDescriptor], - config: Config, - model_blocks_process_ownership: Sequence[int], - student_model: Optional[PreTrainedModel] = None, -) -> tuple[ - PreTrainedModel, - StitchedModule, - StitchedModule, - StitchedModule, - OrderedDict[str, StitchedModuleDescriptor], - PretrainedConfig, -]: - raise NotImplementedError() - - -StitchedModelFactoryFn = type(default_factory) - -_SUBBLOCK_KEYS_TO_LEARN = frozenset({"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"}) +_SUBBLOCK_KEYS_TO_LEARN = frozenset( + {"subblock_ffn", "subblock_attention", "subblock_mamba", "entire_block"} +) def _set_keys_to_learn( model: PreTrainedModel, - descriptor: Type[ModelDescriptor], + descriptor: ModelDescriptor, keys_to_learn: str | Sequence[str], ) -> None: """Set ``requires_grad=True`` on parameters selected by ``keys_to_learn``. @@ -114,66 +93,60 @@ def _set_keys_to_learn( if isinstance(keys_to_learn, Sequence) and not isinstance(keys_to_learn, str): param_names = set(keys_to_learn) # If keys_to_learn is a single string. - else: - # If keys_to_learn is a single string that is a subblock key. - if keys_to_learn in _SUBBLOCK_KEYS_TO_LEARN: - lm_config = descriptor.get_language_model_config(model.config) - weight_groups = descriptor.get_weight_groups( - model.state_dict().keys(), lm_config.num_hidden_layers - ) + # If keys_to_learn is a single string that is a subblock key. 
+ elif keys_to_learn in _SUBBLOCK_KEYS_TO_LEARN: + lm_config = descriptor.get_language_model_config(model.config) + weight_groups = descriptor.get_weight_groups( + model.state_dict().keys(), lm_config.num_hidden_layers + ) - attn_group_names = [ - group_name - for group_name in weight_groups.keys() - if group_name.endswith("_attention") - ] - ffn_group_names = [ - group_name - for group_name in weight_groups.keys() - if group_name.endswith("_ffn") - ] - if keys_to_learn == "subblock_attention": - group_names = attn_group_names - elif keys_to_learn == "subblock_ffn": - group_names = ffn_group_names - elif keys_to_learn == "subblock_mamba": - group_names = attn_group_names # Mamba params live in _attention groups - else: # entire_block - group_names = attn_group_names + ffn_group_names - - block_configs = getattr(lm_config, "block_configs", None) - - param_names = [] - for group_name in group_names: - # For hybrid models (e.g. NemotronH), a single "_attention" group - # name can contain either Mamba SSM params *or* GQA params depending - # on the block. Use the block config — not the keys_to_learn string - # — to decide whether each block belongs to the current subblock type. - if block_configs is not None: - m = re.match(r"block_(\d+)_attention", group_name) - if m: - block_idx = int(m.group(1)) - if block_idx < len(block_configs): - is_mamba = ( - getattr(block_configs[block_idx].attention, "mamba", None) - is not None - ) - # subblock_attention → GQA blocks only (not Mamba) - # subblock_mamba → Mamba blocks only (not GQA) - # entire_block → all blocks (no filtering) - if keys_to_learn == "subblock_attention" and is_mamba: - continue - if keys_to_learn == "subblock_mamba" and not is_mamba: - continue - param_names.extend(weight_groups[group_name]) - param_names = set(param_names) - # If keys_to_learn is a single string that is not a subblock key, treat as regex. - else: - param_names = { - param_name - for param_name, _ in model.named_parameters() - if re.search(keys_to_learn, param_name) - } + attn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_attention") + ] + ffn_group_names = [ + group_name for group_name in weight_groups.keys() if group_name.endswith("_ffn") + ] + if keys_to_learn == "subblock_attention": + group_names = attn_group_names + elif keys_to_learn == "subblock_ffn": + group_names = ffn_group_names + elif keys_to_learn == "subblock_mamba": + group_names = attn_group_names # Mamba params live in _attention groups + else: # entire_block + group_names = attn_group_names + ffn_group_names + + block_configs = getattr(lm_config, "block_configs", None) + + collected: list[str] = [] + for group_name in group_names: + # For hybrid models (e.g. NemotronH), a single "_attention" group + # name can contain either Mamba SSM params *or* GQA params depending + # on the block. Use the block config — not the keys_to_learn string + # — to decide whether each block belongs to the current subblock type. 
+ if block_configs is not None: + m = re.match(r"block_(\d+)_attention", group_name) + if m: + block_idx = int(m.group(1)) + if block_idx < len(block_configs): + is_mamba = ( + getattr(block_configs[block_idx].attention, "mamba", None) is not None + ) + # subblock_attention → GQA blocks only (not Mamba) + # subblock_mamba → Mamba blocks only (not GQA) + # entire_block → all blocks (no filtering) + if keys_to_learn == "subblock_attention" and is_mamba: + continue + if keys_to_learn == "subblock_mamba" and not is_mamba: + continue + collected.extend(weight_groups[group_name]) + param_names = set(collected) + # If keys_to_learn is a single string that is not a subblock key, treat as regex. + else: + param_names = { + param_name + for param_name, _ in model.named_parameters() + if re.search(keys_to_learn, param_name) + } # In pipeline-parallel training a rank may own only blocks that don't match # keys_to_learn (e.g. a rank with only Mamba blocks during subblock_attention # bypass has no GQA params after the _mamba rename). That is a valid state: @@ -198,7 +171,7 @@ def _get_all_non_persistent_buffers_set(module: torch.nn.Module) -> set[str]: def bypass_factory_fn( teacher_model: PreTrainedModel, - descriptor: Type[ModelDescriptor], + descriptor: ModelDescriptor, cfg: DictConfig, model_blocks_process_ownership: Sequence[int], student_model: Optional[PreTrainedModel] = None, @@ -241,11 +214,12 @@ def bypass_factory_fn( device = torch.device(f"cuda:{dist.local_rank()}") model_config_overrides = cfg.model.model_config_overrides - block_loss_func = { + _block_loss_funcs: dict[str, Callable[..., Any]] = { "normalized_mse_loss": normalized_mse_loss, "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss, "batched_normalized_mse_loss": batched_normalized_mse_loss, - }[cfg.model_factory.block_loss_func] + } + block_loss_func = _block_loss_funcs[cfg.model_factory.block_loss_func] mprint(f"{block_loss_func.__name__=}") owned_block_indexes = set( @@ -366,9 +340,7 @@ def bypass_factory_fn( # GQA init mode is optional: only relevant when the student has fewer KV heads than # the teacher. Defaults to AverageKV and is a no-op when head counts are equal. 
- gqa_init_mode = GQAInitMode( - cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV) - ) + gqa_init_mode = GQAInitMode(cfg.model_factory.get("gqa_init_mode", GQAInitMode.AverageKV)) student_state_dict = create_child_state_dict( pruning_mixin=pruning_mixin, @@ -595,9 +567,7 @@ def bypass_factory_fn( } trainable_params = { - p_name: p - for p_name, p in student_module_parameters.items() - if p.requires_grad + p_name: p for p_name, p in student_module_parameters.items() if p.requires_grad } optimizer = ( @@ -640,7 +610,6 @@ def bypass_factory_fn( ) - # Backward-compatible name aliases gqa_factory_fn = bypass_factory_fn moe_factory_fn = bypass_factory_fn diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index b3ca788888c..796974a5a95 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -31,18 +31,23 @@ from collections import OrderedDict, defaultdict from pathlib import Path from statistics import mean -from typing import Optional, Type, cast +from typing import Optional -import datasets import torch import torch.distributed import transformers from omegaconf import DictConfig from torch.utils.data.dataloader import DataLoader -from transformers import AutoTokenizer, PreTrainedTokenizerBase, PretrainedConfig +from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase +import datasets +import datasets.utils.logging +import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module import modelopt.torch.utils.distributed as dist -from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor, ModelDescriptorFactory +from modelopt.torch.puzzletron.anymodel.model_descriptor import ( + ModelDescriptor, + ModelDescriptorFactory, +) from modelopt.torch.puzzletron.sewing_kit import InputArgs, StitchedModule from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config @@ -56,8 +61,6 @@ from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership -import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module - time_start = time.time() os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -113,7 +116,7 @@ def launch_bypass_distillation(hydra_cfg: DictConfig) -> None: def train( cfg: DictConfig, - descriptor: Type[ModelDescriptor], + descriptor: ModelDescriptor, student_model: torch.nn.Module, student_stitched_model: StitchedModule, teacher_stitched_model: StitchedModule, @@ -145,9 +148,7 @@ def train( ] # Indices of stitched modules owned by the current process owned_stitched_module_indices = [ - i - for i, owner in enumerate(stitched_modules_process_ownership) - if owner == dist.rank() + i for i, owner in enumerate(stitched_modules_process_ownership) if owner == dist.rank() ] mprint(f"{global_stitched_modules_count=}") mprint(f"{num_stitched_modules_per_process=}") @@ -187,9 +188,7 @@ def train( min_owned_index = min(owned_stitched_module_indices) max_owned_index = max(owned_stitched_module_indices) prev_rank: Optional[int] = ( - None - if min_owned_index - 1 < 0 - else stitched_modules_process_ownership[min_owned_index - 1] + None if min_owned_index - 1 < 0 else 
stitched_modules_process_ownership[min_owned_index - 1] ) next_rank: Optional[int] = ( None @@ -199,7 +198,9 @@ def train( torch.cuda.synchronize() - mprint(f'Grad scaling status: {"enabled" if cfg.bypass.training.use_grad_scaling else "disabled"}') + mprint( + f"Grad scaling status: {'enabled' if cfg.bypass.training.use_grad_scaling else 'disabled'}" + ) train_iterator = iter(train_dataloader) @@ -295,13 +296,11 @@ def train( del stitched_module_output grad_scaler.scale(stitched_module_loss).backward() else: - stitched_module_loss = torch.full( - [1], fill_value=torch.nan, dtype=torch.float32 - ) + stitched_module_loss = torch.full([1], fill_value=torch.nan, dtype=torch.float32) - iter_stitched_module_losses[stitched_module_name] = ( - stitched_module_loss.to("cpu").item() - ) + iter_stitched_module_losses[stitched_module_name] = stitched_module_loss.to( + "cpu" + ).item() del stitched_module_loss @@ -334,9 +333,7 @@ def train( clip_value=grad_clip, ) else: - raise RuntimeError( - f"Invalid {cfg.bypass.training.grad_clip_type}" - ) + raise RuntimeError(f"Invalid {cfg.bypass.training.grad_clip_type}") assert grad_scaler is not None grad_scaler.step(optimizer) @@ -380,7 +377,10 @@ def train( if dist.is_master(): if cfg.bypass.model.model_overrides.save_interval_seconds is not None: time_now = time.time() - if time_now - time_last_save >= cfg.bypass.model.model_overrides.save_interval_seconds: + if ( + time_now - time_last_save + >= cfg.bypass.model.model_overrides.save_interval_seconds + ): mprint( f"Time to save! {cfg.bypass.model.model_overrides.save_interval_seconds=}, " f"{time_last_save=}, {time_now=}" @@ -409,14 +409,12 @@ def train( highest_iter = list(log_chunk.keys())[-1] highest_iter_stats = iter_stats_history[highest_iter] - losses_by_name = defaultdict[str, list[float]](lambda: []) + losses_by_name = defaultdict[str, list[float]](list) for losses in log_chunk.values(): for name, loss in losses.items(): losses_by_name[name].append(loss) - losses_by_name_avg = { - name: mean(losses) for name, losses in losses_by_name.items() - } + losses_by_name_avg = {name: mean(losses) for name, losses in losses_by_name.items()} # Update best losses tracking for name, current_loss in losses_by_name_avg.items(): @@ -516,9 +514,7 @@ def train( if old_ckpt_path.name != subdir_name: shutil.rmtree(str(old_ckpt_path)) if cfg.bypass.kill_after_first_save: - raise RuntimeError( - "Done saving checkpoint, kill_after_first_save=True" - ) + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") # Checkpoint saving (step-based or time-based) if not is_accumulating and ( @@ -552,14 +548,10 @@ def train( if cfg.bypass.kill_after_first_save: dist.barrier() - raise RuntimeError( - "Done saving checkpoint, kill_after_first_save=True" - ) + raise RuntimeError("Done saving checkpoint, kill_after_first_save=True") if cfg.bypass.model.model_overrides.delete_old_checkpoints and dist.is_master(): - existing_ckpt_paths = list( - Path(cfg.bypass.experiment_dir).glob("iter-*") - ) + existing_ckpt_paths = list(Path(cfg.bypass.experiment_dir).glob("iter-*")) for old_ckpt_path in existing_ckpt_paths: if old_ckpt_path.name != subdir_name: shutil.rmtree(str(old_ckpt_path)) @@ -691,7 +683,7 @@ def run_bypassed_training(cfg: DictConfig): if cfg.bypass.training.warmup_steps is None: cfg.bypass.training.warmup_steps = 0 - mprint(f'\n{format_global_config(cfg.bypass, "Bypass Configurations")}') + mprint(f"\n{format_global_config(cfg.bypass, 'Bypass Configurations')}") mprint(f"Max token count: 
{cfg.bypass.training.max_token_count:,}") seed = cfg.bypass.seed @@ -705,9 +697,7 @@ def run_bypassed_training(cfg: DictConfig): assert teacher_model_config is not None - mprint( - f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}" - ) + mprint(f"Load and shard model with: {owned_block_indexes=}, {cfg.teacher_dir=}") teacher_model = load_and_shard_model( descriptor=descriptor, checkpoint_path=cfg.teacher_dir, @@ -730,7 +720,9 @@ def run_bypassed_training(cfg: DictConfig): else: max_eval_samples = cfg.bypass.data.max_eval_samples - load_dataset_fn = load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn + load_dataset_fn = ( + load_streaming_fn if not cfg.bypass.data.load_from_disk else load_from_disk_fn + ) train_dataloader = create_train_dataloader( seed=seed, @@ -764,9 +756,7 @@ def run_bypassed_training(cfg: DictConfig): load_dataset_fn=load_dataset_fn, dataset_name=cfg.bypass.data.val_dataset_name, keep_in_memory=cfg.bypass.data.keep_in_memory, - source_datasets_to_discard=cfg.bypass.get( - "source_datasets_to_discard", tuple() - ), + source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()), bos_rate=cfg.bypass.data.bos_rate, ) @@ -778,9 +768,8 @@ def run_bypassed_training(cfg: DictConfig): dist.barrier() with torch.device(device): - stitched_model_factory_fn = cast( - stitched_model_factory_module.StitchedModelFactoryFn, - getattr(stitched_model_factory_module, cfg.bypass.model_factory.factory), + stitched_model_factory_fn = getattr( + stitched_model_factory_module, cfg.bypass.model_factory.factory ) ( student_model, @@ -804,9 +793,7 @@ def run_bypassed_training(cfg: DictConfig): elif cfg.bypass.find_last_ckpt_for_resume: _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir) if _ckpt_dir is None: - mprint( - "Couldn't find any run dir for resume, assuming this is the first job" - ) + mprint("Couldn't find any run dir for resume, assuming this is the first job") else: mprint( f"`cfg.bypass.find_last_ckpt_for_resume` is True. " @@ -882,9 +869,11 @@ def run_bypassed_training(cfg: DictConfig): dist.barrier() mprint("Performing dummy runs on stitched modules:") torch.cuda.synchronize() - with torch.no_grad(), torch.autocast( - device_type="cuda", dtype=torch.bfloat16 - ), torch.device(device): + with ( + torch.no_grad(), + torch.autocast(device_type="cuda", dtype=torch.bfloat16), + torch.device(device), + ): input_ids = torch.ones( (cfg.bypass.training.micro_batch_size, cfg.bypass.data.block_size), dtype=torch.long, diff --git a/modelopt/torch/puzzletron/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py index 0928b111afc..38ea015973e 100644 --- a/modelopt/torch/puzzletron/dataset/prepare_dataset.py +++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py @@ -15,10 +15,16 @@ import os -import datasets import fire import numpy as np +# Import via submodules: HF `datasets` uses lazy `__getattr__` at the package level +# (PEP 562), so mypy can't see top-level names — `from datasets import DatasetDict` +# fails with `attr-defined`. Submodule paths bypass the lazy loader. 
+from datasets.combine import concatenate_datasets +from datasets.dataset_dict import DatasetDict +from datasets.load import load_dataset + from ..tools.logger import mprint __all__ = ["process_and_save_dataset"] @@ -40,8 +46,8 @@ def process_and_save_dataset( ) return - ds = datasets.load_dataset(dataset_name, split=split) - ds = datasets.concatenate_datasets(ds) + ds = load_dataset(dataset_name, split=split) + ds = concatenate_datasets(ds) # Filter out samples with reasoning = on ds = ds.filter(lambda x: x["reasoning"] == "off") # Hardcoded for dynamically create a deterministic train-val split @@ -49,7 +55,7 @@ def process_and_save_dataset( generator = np.random.RandomState(seed=seed) ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator) # Rename dataset names to follow previous conventions - ds_dict = datasets.DatasetDict( + ds_dict = DatasetDict( { "train": ds_split["train"], "valid": ds_split["test"], diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index a90550b64ec..32e623d083b 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -19,7 +19,6 @@ from functools import partial from typing import Protocol, TypeVar -import datasets import torch import torch.distributed from accelerate import Accelerator @@ -28,6 +27,14 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase +# Import via submodules: HF `datasets` uses lazy `__getattr__` at the package level +# (PEP 562), so mypy can't see top-level names — `from datasets import DatasetDict` +# fails with `attr-defined`. Submodule paths bypass the lazy loader. +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.dataset_dict import DatasetDict +from datasets.features import Features, Value +from datasets.load import load_dataset, load_from_disk + from ...tools.logger import mprint from .dataset import ConstantLengthDataset @@ -53,18 +60,18 @@ def __call__( def load_from_disk_fn( dataset_path: str, content_field: str, keep_in_memory: bool = False ) -> Mapping[str, Dataset]: - return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + return load_from_disk(dataset_path, keep_in_memory=keep_in_memory) def load_streaming_fn( dataset_path: str, content_field: str, keep_in_memory: bool = False ) -> Mapping[str, Dataset]: - dataset = datasets.load_dataset( + dataset = load_dataset( dataset_path, streaming=True, - features=datasets.Features( + features=Features( { - content_field: datasets.Value(dtype="string"), + content_field: Value(dtype="string"), } ), keep_in_memory=keep_in_memory, @@ -147,13 +154,13 @@ def create_validation_dataloader( if isinstance(dataset, str): dataset = load_dataset_fn(dataset, content_field, keep_in_memory) - if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): + if isinstance(dataset, HFDataset | torch.utils.data.Dataset): valid_data = dataset mprint( "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" ) else: - assert isinstance(dataset, datasets.DatasetDict) + assert isinstance(dataset, DatasetDict) if dataset_name == "__auto__": val_split_options = [] for val_key_prefix in ("val", "test"): diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 4515d7eda32..409c2b0c230 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -249,7 +249,10 @@ def 
get_dataset_samples( if dataset_name.endswith(".jsonl"): return get_jsonl_text_samples(dataset_name, num_samples, key="text") - from datasets import load_dataset + # Import via submodule: HF `datasets` uses lazy `__getattr__` at the package + # level (PEP 562), so `from datasets import load_dataset` fails mypy with + # `attr-defined`. The submodule path bypasses the lazy loader. + from datasets.load import load_dataset local_dataset_path = None if os.path.exists(dataset_name): # Local path @@ -753,9 +756,11 @@ def download_hf_dataset_as_jsonl( Returns: List of paths to downloaded JSONL files. """ - from datasets import load_dataset + # See note above: import via submodule to satisfy mypy. from huggingface_hub.utils import build_hf_headers + from datasets.load import load_dataset + print(f"Downloading dataset {dataset_name} from Hugging Face") if isinstance(json_keys, str): json_keys = [json_keys] diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 9de40792e4b..5e75117291d 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -222,7 +222,10 @@ def _get_vlm_dataset( """ # Load the dataset if dataset_name in SUPPORTED_VLM_DATASET_CONFIG: - from datasets import load_dataset + # Import via submodule: HF `datasets` uses lazy `__getattr__` at the package + # level (PEP 562), so `from datasets import load_dataset` fails mypy with + # `attr-defined`. The submodule path bypasses the lazy loader. + from datasets.load import load_dataset cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy() streaming = bool(cfg.pop("streaming", False)) @@ -273,7 +276,8 @@ def _get_vlm_dataset( for subset in subsets ] try: - from datasets import interleave_datasets + # See note above: import via submodule to satisfy mypy. + from datasets.combine import interleave_datasets ds = interleave_datasets(streams) except Exception: diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index e56e93cc0e1..5bfcfe87b2e 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -19,11 +19,11 @@ import pytest import torch from _test_utils.torch.transformers_models import get_tiny_tokenizer -from datasets import Dataset, DatasetDict from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase import modelopt.torch.puzzletron as mtpz import modelopt.torch.utils.distributed as dist +from datasets import Dataset, DatasetDict from modelopt.torch.export import copy_hf_ckpt_remote_code # Shared parametrize tuple for puzzletron GPU integration tests. 
@@ -57,9 +57,7 @@ pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b"), pytest.param("Qwen/Qwen2.5-7B-Instruct", "qwen2", None, False, id="qwen2.5-7B"), pytest.param("Qwen/Qwen3-8B", "qwen3", None, False, id="qwen3-8B"), - pytest.param( - "Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True, id="qwen3-VL-30B-A3B" - ), + pytest.param("Qwen/Qwen3-VL-30B-A3B-Instruct", "qwen3_vl", None, True, id="qwen3-VL-30B-A3B"), ] diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index 3c77222c4d0..c9e77df87bc 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -120,10 +120,13 @@ def _mlp_init_settings(has_moe_layers: bool, hydra_cfg) -> tuple[str, dict]: if not has_moe_layers: return "Truncate", {"activations_log_dir": None} - mlp_init_config = OmegaConf.to_container( - hydra_cfg.pruning.get("mlp_init_config_yaml", OmegaConf.create({})), - resolve=True, - ) or {} + mlp_init_config = ( + OmegaConf.to_container( + hydra_cfg.pruning.get("mlp_init_config_yaml", OmegaConf.create({})), + resolve=True, + ) + or {} + ) mlp_init_config["activations_log_dir"] = str(hydra_cfg.pruning.activations_log_dir) return "ExpertRemoval", mlp_init_config @@ -293,9 +296,7 @@ def _setup_hydra_cfg_and_pruning( tmp_path, rank, hf_model_name, hybrid_override_pattern ) - hydra_config_dir = str( - project_root_path / "tests/gpu/torch/puzzletron/resources/configs" - ) + hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") # Per-family hydra config name follows the layout configs///. hydra_config_name = f"{hf_model_name}/{Path(hf_model_name).name}" @@ -378,8 +379,13 @@ def _test_bypass_block_pruning_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) @@ -451,15 +457,20 @@ def _test_bypass_kv_head_compression_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) bypass_cfg_dict = _make_bypass_cfg_dict( has_moe_layers, hydra_cfg, - block_pruned=False, # keep FFN/experts at teacher - attention_pruned=True, # halve KV heads + block_pruned=False, # keep FFN/experts at teacher + attention_pruned=True, # halve KV heads ) OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) @@ -530,8 +541,13 @@ def _test_bypass_multi_config_sequential_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) configs_list = [ @@ -623,8 +639,13 @@ def _test_bypass_checkpoint_contents_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) diff --git 
a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py index 759fb5fa34a..4f6869c4e6d 100644 --- a/tests/unit/torch/puzzletron/test_bypass_losses.py +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -15,7 +15,6 @@ """Unit tests for normalized MSE loss functions in sewing_kit/utils.py.""" -import pytest import torch from modelopt.torch.puzzletron.sewing_kit.utils import ( @@ -24,7 +23,6 @@ vectorwise_normalized_mse_loss, ) - # --------------------------------------------------------------------------- # normalized_mse_loss # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/puzzletron/test_bypass_utils.py b/tests/unit/torch/puzzletron/test_bypass_utils.py index 8e4551e24e3..63fee390174 100644 --- a/tests/unit/torch/puzzletron/test_bypass_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_utils.py @@ -58,9 +58,7 @@ def test_uneven_distribution(): ) def test_total_equals_module_count(module_count, world_size): """The length of the ownership list must always equal module_count.""" - ownership = get_distributed_modules_ownership( - module_count=module_count, world_size=world_size - ) + ownership = get_distributed_modules_ownership(module_count=module_count, world_size=world_size) assert len(ownership) == module_count From ed0eae5e3d180963b3db561a5263249f42027f1a Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 04:56:32 -0700 Subject: [PATCH 03/13] Address CodeRabbit review and revert non-bypass datasets-import workarounds MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bypass-distillation review fixes: - bypass_checkpoint_utils: persist + restore GradScaler state alongside optimizer state so resumed runs with use_grad_scaling=True don't lose the running scale + growth tracker. - stitched_model_factory: raise a clear RuntimeError when world_size exceeds num_hidden_layers (would otherwise crash trailing ranks with a bare `min() arg is an empty sequence`); same condition fires identically on every rank, so no NCCL hang. - training_loop: actually apply skip_first_batches by advancing the data iterator before the loop (parameter was previously accepted but unused). - training_loop: fix off-by-one in the max_steps exit condition (was >=, now >) so the final scheduled step actually runs; same fix applied to the in-loop save-when-done branch so the final checkpoint is saved exactly once. - training_loop: read source_datasets_to_discard from cfg.bypass.data (where the YAML actually nests it) instead of cfg.bypass root, where it always fell back to the empty tuple. - dataloaders: reject num_workers > 0 in create_train_dataloader with an explicit error since ConstantLengthDataset.__iter__ does not shard via torch.utils.data.get_worker_info(); guard removable once the dataset gains worker-aware iteration. - dataloaders: branch on isinstance(train_data, datasets.IterableDataset) before passing keep_in_memory=True to .shuffle(); streaming mode (load_from_disk=false) doesn't accept that kwarg. Revert the datasets submodule-import workarounds (`from datasets.load import load_dataset`, etc.) — CI's mypy resolves top-level `from datasets import X` correctly because it installs the package via uv, so these workarounds were only needed for the local-pre-commit env on a node without datasets installed. Reverting shrinks the MR back to bypass-only files. 
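For the off-by-one fix, a standalone sketch of the exit-condition semantics (hypothetical minimal loop, same 1-based convention as the training loop):

```python
# step_num is 1-based and incremented at the END of each iteration, so the
# top-of-loop exit test must use `>` for step max_steps itself to run.
max_steps = 3
step_num = 1
executed = []
while True:
    if step_num > max_steps:  # `>=` here would skip the final scheduled step
        break
    executed.append(step_num)  # stand-in for one training step
    step_num += 1
assert executed == [1, 2, 3]  # with `>=` this would have been [1, 2]
```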
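And a minimal sketch of the GradScaler state round-trip the checkpoint fix restores (file path and enabled flag are illustrative, not the exact names used by the checkpoint code):

```python
# Persisting only the optimizer loses the scaler's running scale and growth
# tracker; saving and restoring its state_dict lets a resumed run continue
# from the adapted scale instead of re-warming from the default.
import torch

scaler = torch.amp.GradScaler("cuda", enabled=True)
torch.save(scaler.state_dict(), "/tmp/block_0.grad_scaler.pth")

resumed = torch.amp.GradScaler("cuda", enabled=True)
state = torch.load("/tmp/block_0.grad_scaler.pth", weights_only=True)
resumed.load_state_dict(state)  # restores scale + growth/backoff counters
```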
Signed-off-by: Sepehr Sameni --- .../bypass_checkpoint_utils.py | 34 +++++++++++++++ .../stitched_model_factory.py | 13 ++++++ .../bypass_distillation/training_loop.py | 30 +++++++++---- .../puzzletron/dataset/prepare_dataset.py | 14 ++---- .../puzzletron/utils/data/dataloaders.py | 43 ++++++++++++------- modelopt/torch/utils/dataset_utils.py | 9 +--- modelopt/torch/utils/vlm_dataset_utils.py | 8 +--- tests/_test_utils/torch/puzzletron/utils.py | 2 +- 8 files changed, 106 insertions(+), 47 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py index 673964b59b0..fbb658d4e57 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -87,6 +87,7 @@ def load_local_state( for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): stitched_module = stitched_module_descriptor.stitched_module optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler state_dict_path = load_dir / "stitched" / f"{stitched_module_name}.state_dict.pth" mprint(f"Loading state dict for module {stitched_module_name} from {state_dict_path}") @@ -109,6 +110,25 @@ def load_local_state( optimizer.load_state_dict(loaded_optimizer_state) del loaded_optimizer_state + # Restore GradScaler state (only relevant when use_grad_scaling=True; for the + # default bf16 / use_grad_scaling=False path the scaler is disabled and its + # state is a no-op, but we still load it if present for forward-compatibility). + # Older checkpoints predating this save path won't have the file — skip silently. + if grad_scaler is not None: + grad_scaler_state_path = ( + load_dir / "stitched" / f"{stitched_module_name}.grad_scaler.pth" + ) + if grad_scaler_state_path.exists(): + mprint( + f"Loading grad_scaler state for module {stitched_module_name} " + f"from {grad_scaler_state_path}" + ) + loaded_scaler_state = torch.load( + grad_scaler_state_path, map_location=device, weights_only=True + ) + grad_scaler.load_state_dict(loaded_scaler_state) + del loaded_scaler_state + def _save_local_file(obj, save_path: Path | str, overwrite=True): save_path = Path(save_path) @@ -136,6 +156,7 @@ def _save_local_state( stitched_module_descriptors.items() ): optimizer = stitched_module_descriptor.optimizer + grad_scaler = stitched_module_descriptor.grad_scaler state_dict_path = save_dir / f"{stitched_module_name}.state_dict.pth" aprint(f"Saving state dict for module {stitched_module_name} to {state_dict_path}") @@ -152,6 +173,19 @@ def _save_local_state( ) _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) + # Persist GradScaler state. Required for correct resume when + # use_grad_scaling=True (state dict carries running scale + growth tracker). + # For the default bf16 / use_grad_scaling=False path the state dict is trivial + # but cheap, so save unconditionally whenever a scaler exists — keeps the + # save/load paths symmetric with the optimizer. 
+ if grad_scaler is not None: + grad_scaler_state_path = save_dir / f"{stitched_module_name}.grad_scaler.pth" + mprint( + f"Saving grad_scaler state for module {stitched_module_name} " + f"to {grad_scaler_state_path}" + ) + _save_local_file(grad_scaler.state_dict(), grad_scaler_state_path, overwrite=overwrite) + dist.barrier() diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py index 5d90b572f2b..1650cf9ad4c 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -390,6 +390,19 @@ def bypass_factory_fn( torch.cuda.empty_cache() dist.barrier() + # Every rank derives ownership from the same `model_blocks_process_ownership` + # list, so this guard fires identically on every rank when world_size exceeds + # num_hidden_layers — no NCCL hang from a single rank diverging. + ranks_with_blocks = set(model_blocks_process_ownership) + empty_ranks = [r for r in range(dist.size()) if r not in ranks_with_blocks] + if empty_ranks: + raise RuntimeError( + f"world_size ({dist.size()}) exceeds num_hidden_layers " + f"({len(all_block_indices)}); ranks {empty_ranks} would own 0 blocks. " + f"Pipeline-parallel bypass distillation does not support idle ranks — " + f"reduce nproc_per_node to at most num_hidden_layers." + ) + min_owned_index = min(owned_block_indexes) max_owned_index = max(owned_block_indexes) prev_rank: Optional[int] = ( diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 796974a5a95..3156439fd86 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -33,6 +33,7 @@ from statistics import mean from typing import Optional +import datasets import torch import torch.distributed import transformers @@ -40,8 +41,6 @@ from torch.utils.data.dataloader import DataLoader from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase -import datasets -import datasets.utils.logging import modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory as stitched_model_factory_module import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel.model_descriptor import ( @@ -204,6 +203,15 @@ def train( train_iterator = iter(train_dataloader) + # Advance past the first `skip_first_batches` batches before the training loop + # starts. Used either to skip a known-bad batch range during debugging, or to + # roll the data iterator forward when resuming a run (model + optimizer state + # are restored from the checkpoint, but the dataloader itself starts fresh). + if skip_first_batches > 0: + mprint(f"Skipping first {skip_first_batches} batches before training") + for _ in range(skip_first_batches): + next(train_iterator) + mprint("Waiting for everyone before training starts") dist.barrier() @@ -221,8 +229,10 @@ def train( # Train loop start while True: time_now = time.time() - # Check if we've reached the maximum number of steps - if cfg.bypass.step_num >= cfg.bypass.training.max_steps: + # Check if we've reached the maximum number of steps. `step_num` is 1-based + # and incremented at the END of each iteration, so we must use `>` (not `>=`) + # to ensure step `max_steps` itself runs before exiting. 
+ if cfg.bypass.step_num > cfg.bypass.training.max_steps: if ( cfg.bypass.model.model_overrides.save_checkpoint_when_done and not cfg.bypass.disable_checkpoint_save @@ -522,7 +532,7 @@ def train( or step_to_save == cfg.bypass.step_num or ( cfg.bypass.model.model_overrides.save_checkpoint_when_done - and cfg.bypass.step_num >= cfg.bypass.training.max_steps + and cfg.bypass.step_num > cfg.bypass.training.max_steps ) ): if not cfg.bypass.disable_checkpoint_save: @@ -532,7 +542,7 @@ def train( mprint("Saving time-based checkpoint") elif ( cfg.bypass.model.model_overrides.save_checkpoint_when_done - and cfg.bypass.step_num >= cfg.bypass.training.max_steps + and cfg.bypass.step_num > cfg.bypass.training.max_steps ): mprint("Saving final checkpoint") @@ -735,7 +745,9 @@ def run_bypassed_training(cfg: DictConfig): micro_batch_size=cfg.bypass.training.micro_batch_size, load_dataset_fn=load_dataset_fn, keep_in_memory=cfg.bypass.data.keep_in_memory, - source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()), + source_datasets_to_discard=cfg.bypass.data.get( + "source_datasets_to_discard", tuple() + ), bos_rate=cfg.bypass.data.bos_rate, shuffle_seed=cfg.bypass.data.shuffle_train_data_seed, ) @@ -756,7 +768,9 @@ def run_bypassed_training(cfg: DictConfig): load_dataset_fn=load_dataset_fn, dataset_name=cfg.bypass.data.val_dataset_name, keep_in_memory=cfg.bypass.data.keep_in_memory, - source_datasets_to_discard=cfg.bypass.get("source_datasets_to_discard", tuple()), + source_datasets_to_discard=cfg.bypass.data.get( + "source_datasets_to_discard", tuple() + ), bos_rate=cfg.bypass.data.bos_rate, ) diff --git a/modelopt/torch/puzzletron/dataset/prepare_dataset.py b/modelopt/torch/puzzletron/dataset/prepare_dataset.py index 38ea015973e..0928b111afc 100644 --- a/modelopt/torch/puzzletron/dataset/prepare_dataset.py +++ b/modelopt/torch/puzzletron/dataset/prepare_dataset.py @@ -15,16 +15,10 @@ import os +import datasets import fire import numpy as np -# Import via submodules: HF `datasets` uses lazy `__getattr__` at the package level -# (PEP 562), so mypy can't see top-level names — `from datasets import DatasetDict` -# fails with `attr-defined`. Submodule paths bypass the lazy loader. 
-from datasets.combine import concatenate_datasets -from datasets.dataset_dict import DatasetDict -from datasets.load import load_dataset - from ..tools.logger import mprint __all__ = ["process_and_save_dataset"] @@ -46,8 +40,8 @@ def process_and_save_dataset( ) return - ds = load_dataset(dataset_name, split=split) - ds = concatenate_datasets(ds) + ds = datasets.load_dataset(dataset_name, split=split) + ds = datasets.concatenate_datasets(ds) # Filter out samples with reasoning = on ds = ds.filter(lambda x: x["reasoning"] == "off") # Hardcoded for dynamically create a deterministic train-val split @@ -55,7 +49,7 @@ def process_and_save_dataset( generator = np.random.RandomState(seed=seed) ds_split = ds.train_test_split(test_size=0.05, shuffle=True, generator=generator) # Rename dataset names to follow previous conventions - ds_dict = DatasetDict( + ds_dict = datasets.DatasetDict( { "train": ds_split["train"], "valid": ds_split["test"], diff --git a/modelopt/torch/puzzletron/utils/data/dataloaders.py b/modelopt/torch/puzzletron/utils/data/dataloaders.py index 32e623d083b..4f02d9dcef9 100644 --- a/modelopt/torch/puzzletron/utils/data/dataloaders.py +++ b/modelopt/torch/puzzletron/utils/data/dataloaders.py @@ -19,6 +19,7 @@ from functools import partial from typing import Protocol, TypeVar +import datasets import torch import torch.distributed from accelerate import Accelerator @@ -27,14 +28,6 @@ from tqdm import tqdm from transformers import PreTrainedTokenizerBase -# Import via submodules: HF `datasets` uses lazy `__getattr__` at the package level -# (PEP 562), so mypy can't see top-level names — `from datasets import DatasetDict` -# fails with `attr-defined`. Submodule paths bypass the lazy loader. -from datasets.arrow_dataset import Dataset as HFDataset -from datasets.dataset_dict import DatasetDict -from datasets.features import Features, Value -from datasets.load import load_dataset, load_from_disk - from ...tools.logger import mprint from .dataset import ConstantLengthDataset @@ -60,18 +53,18 @@ def __call__( def load_from_disk_fn( dataset_path: str, content_field: str, keep_in_memory: bool = False ) -> Mapping[str, Dataset]: - return load_from_disk(dataset_path, keep_in_memory=keep_in_memory) + return datasets.load_from_disk(dataset_path, keep_in_memory=keep_in_memory) def load_streaming_fn( dataset_path: str, content_field: str, keep_in_memory: bool = False ) -> Mapping[str, Dataset]: - dataset = load_dataset( + dataset = datasets.load_dataset( dataset_path, streaming=True, - features=Features( + features=datasets.Features( { - content_field: Value(dtype="string"), + content_field: datasets.Value(dtype="string"), } ), keep_in_memory=keep_in_memory, @@ -98,6 +91,19 @@ def create_train_dataloader( num_workers: int = 0, ) -> DataLoader: """Create an infinite training DataLoader over ConstantLengthDataset.""" + # ConstantLengthDataset.__iter__ does not consult torch.utils.data.get_worker_info() + # to shard work across DataLoader workers, so num_workers > 0 would have every + # worker iterate the full dataset and emit duplicate samples. Reject explicitly + # until ConstantLengthDataset gains worker-aware iteration; the guard can then + # be removed. + if num_workers > 0: + raise ValueError( + f"create_train_dataloader: num_workers={num_workers} is not supported " + f"because ConstantLengthDataset.__iter__ does not shard via " + f"torch.utils.data.get_worker_info(). Use num_workers=0 (the default) " + f"or add worker-aware sharding to ConstantLengthDataset.__iter__." 
+ ) + if isinstance(dataset_path, str): dataset = load_dataset_fn(dataset_path, content_field, keep_in_memory) else: @@ -105,7 +111,14 @@ def create_train_dataloader( train_data = dataset[dataset_name] if shuffle_seed is not None: - train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) + # `keep_in_memory` is only valid on map-style HF Datasets; streaming + # `IterableDataset.shuffle()` only accepts `seed` (and an optional + # `buffer_size`). Branch on the dataset type so streaming users + # (`load_from_disk: false`) don't crash on this call. + if isinstance(train_data, datasets.IterableDataset): + train_data = train_data.shuffle(seed=shuffle_seed) + else: + train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) train_dataset = ConstantLengthDataset( tokenizer, @@ -154,13 +167,13 @@ def create_validation_dataloader( if isinstance(dataset, str): dataset = load_dataset_fn(dataset, content_field, keep_in_memory) - if isinstance(dataset, HFDataset | torch.utils.data.Dataset): + if isinstance(dataset, datasets.Dataset | torch.utils.data.Dataset): valid_data = dataset mprint( "#### Path to specific dataset was given (not DatasetDict), taking it as-is ####" ) else: - assert isinstance(dataset, DatasetDict) + assert isinstance(dataset, datasets.DatasetDict) if dataset_name == "__auto__": val_split_options = [] for val_key_prefix in ("val", "test"): diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 409c2b0c230..4515d7eda32 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -249,10 +249,7 @@ def get_dataset_samples( if dataset_name.endswith(".jsonl"): return get_jsonl_text_samples(dataset_name, num_samples, key="text") - # Import via submodule: HF `datasets` uses lazy `__getattr__` at the package - # level (PEP 562), so `from datasets import load_dataset` fails mypy with - # `attr-defined`. The submodule path bypasses the lazy loader. - from datasets.load import load_dataset + from datasets import load_dataset local_dataset_path = None if os.path.exists(dataset_name): # Local path @@ -756,11 +753,9 @@ def download_hf_dataset_as_jsonl( Returns: List of paths to downloaded JSONL files. """ - # See note above: import via submodule to satisfy mypy. + from datasets import load_dataset from huggingface_hub.utils import build_hf_headers - from datasets.load import load_dataset - print(f"Downloading dataset {dataset_name} from Hugging Face") if isinstance(json_keys, str): json_keys = [json_keys] diff --git a/modelopt/torch/utils/vlm_dataset_utils.py b/modelopt/torch/utils/vlm_dataset_utils.py index 5e75117291d..9de40792e4b 100644 --- a/modelopt/torch/utils/vlm_dataset_utils.py +++ b/modelopt/torch/utils/vlm_dataset_utils.py @@ -222,10 +222,7 @@ def _get_vlm_dataset( """ # Load the dataset if dataset_name in SUPPORTED_VLM_DATASET_CONFIG: - # Import via submodule: HF `datasets` uses lazy `__getattr__` at the package - # level (PEP 562), so `from datasets import load_dataset` fails mypy with - # `attr-defined`. The submodule path bypasses the lazy loader. - from datasets.load import load_dataset + from datasets import load_dataset cfg = SUPPORTED_VLM_DATASET_CONFIG[dataset_name]["config"].copy() streaming = bool(cfg.pop("streaming", False)) @@ -276,8 +273,7 @@ def _get_vlm_dataset( for subset in subsets ] try: - # See note above: import via submodule to satisfy mypy. 
- from datasets.combine import interleave_datasets + from datasets import interleave_datasets ds = interleave_datasets(streams) except Exception: diff --git a/tests/_test_utils/torch/puzzletron/utils.py b/tests/_test_utils/torch/puzzletron/utils.py index 5bfcfe87b2e..82a66c1bb00 100644 --- a/tests/_test_utils/torch/puzzletron/utils.py +++ b/tests/_test_utils/torch/puzzletron/utils.py @@ -19,11 +19,11 @@ import pytest import torch from _test_utils.torch.transformers_models import get_tiny_tokenizer +from datasets import Dataset, DatasetDict from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase import modelopt.torch.puzzletron as mtpz import modelopt.torch.utils.distributed as dist -from datasets import Dataset, DatasetDict from modelopt.torch.export import copy_hf_ckpt_remote_code # Shared parametrize tuple for puzzletron GPU integration tests. From ba0e6e824dd63c28409631d190b892afb3041086 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 05:08:11 -0700 Subject: [PATCH 04/13] Remove unused factory aliases; reflow source_datasets_to_discard kwargs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Drop the `gqa_factory_fn` / `moe_factory_fn` backward-compat aliases in stitched_model_factory.py — no remaining call sites or YAML configs reference them after the unified `bypass_factory_fn` migration. - Apply ruff-format reflow to the two `source_datasets_to_discard=...` kwargs in training_loop.py (collapse the train-side call onto one line; re-indent the val-side multi-line call body). Signed-off-by: Sepehr Sameni --- .../bypass_distillation/stitched_model_factory.py | 5 ----- .../torch/puzzletron/bypass_distillation/training_loop.py | 8 +++----- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py index 1650cf9ad4c..21a89d11762 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -621,8 +621,3 @@ def bypass_factory_fn( stitched_module_descriptors, student_model_config, ) - - -# Backward-compatible name aliases -gqa_factory_fn = bypass_factory_fn -moe_factory_fn = bypass_factory_fn diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 3156439fd86..f57d998d7b3 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -745,9 +745,7 @@ def run_bypassed_training(cfg: DictConfig): micro_batch_size=cfg.bypass.training.micro_batch_size, load_dataset_fn=load_dataset_fn, keep_in_memory=cfg.bypass.data.keep_in_memory, - source_datasets_to_discard=cfg.bypass.data.get( - "source_datasets_to_discard", tuple() - ), + source_datasets_to_discard=cfg.bypass.data.get("source_datasets_to_discard", tuple()), bos_rate=cfg.bypass.data.bos_rate, shuffle_seed=cfg.bypass.data.shuffle_train_data_seed, ) @@ -769,8 +767,8 @@ def run_bypassed_training(cfg: DictConfig): dataset_name=cfg.bypass.data.val_dataset_name, keep_in_memory=cfg.bypass.data.keep_in_memory, source_datasets_to_discard=cfg.bypass.data.get( - "source_datasets_to_discard", tuple() - ), + "source_datasets_to_discard", tuple() + ), bos_rate=cfg.bypass.data.bos_rate, ) From f85e800bfbc6c9a989ad862a2b45bbcede4a662e Mon Sep 17 
00:00:00 2001
From: Sepehr Sameni
Date: Wed, 6 May 2026 08:42:08 -0700
Subject: [PATCH 05/13] Add bypass test coverage and address remaining review
 findings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Test additions
--------------

Three new test files plus three new tests appended to test_bypass.py
covering paths the existing 4×9-family happy-path tests didn't reach:

- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py — pins all 8
  branches of `_set_keys_to_learn` (subblock_ffn / subblock_attention /
  subblock_mamba / entire_block / list / regex / hybrid block_configs
  filter / no-match silent-return). The hybrid Mamba-vs-GQA filter can
  be silently misrouted by descriptor refactors; this test catches that.

- tests/unit/torch/puzzletron/test_bypass_replacement_library.py —
  verifies `_get_last_checkpoint_from_each_experiment` discovers
  symlinked bypass + pruning checkpoints, and that the bypass-priority
  sort closure orders bypass-rooted paths before Truncate-init ones (a
  regression here would silently discard bypass-trained weights).

- tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py —
  save/load round-trip for stitched-module state, optimizer state, and
  (new) GradScaler state. The GradScaler round-trip is the regression
  test for the recent CodeRabbit-flagged bug where resumed fp16 +
  use_grad_scaling runs lost the running scale + growth tracker. Lives
  under tests/gpu/ rather than tests/unit/ because the production load
  path constructs `torch.device(f"cuda:{rank}")` and torch.load needs a
  real CUDA device to deserialize.

- tests/gpu/torch/puzzletron/test_bypass.py — three new tests:

  * test_bypass_resume_from_checkpoint: 2-phase train→save→resume on
    Llama-3.2-3B, asserts `iter_num` advances past the saved value.
    GradScaler resume is covered separately by the GPU checkpoint-utils
    tests above, because GradScaler.step() is fp16-only and the bypass
    infra is bf16.

  * test_bypass_subblock_modes: parametrized over Llama-3.2-3B (dense
    Truncate path) × GPT-OSS-20B (MoE ExpertRemoval +
    windowed-attention-with-sinks path) × {subblock_ffn,
    subblock_attention, entire_block}; diffs start-of-training vs
    end-of-training stitched-module weights and asserts only the
    expected param groups changed.

  * test_bypass_then_build_library: end-to-end smoke — runs bypass,
    then `_build_subblocks_df`, asserts the bypass experiment appears in
    the resulting subblocks DataFrame's checkpoint-source columns.

Review-driven fixes
-------------------

- bypass_distillation/training_loop.py:
  * AutoTokenizer.from_pretrained: pass
    `trust_remote_code=trust_remote_code` (was hardcoded `True`); the
    variable is already derived from the descriptor a few lines earlier.
  * The run-wrapping `try/except Exception` now re-raises (was
    `sys.exit(1)`) so pytest sees the real exception type instead of a
    generic SystemExit, and so distributed runs surface usable
    tracebacks.
- sewing_kit/utils.py: `normalized_mse_loss` is now re-exported from
  `tools.kd_model` instead of redefined; the two implementations were
  byte-for-byte identical.
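For reference, a condensed sketch of the now-single definition (type
hints and docstring elided; assumes the module's usual `torch` /
`torch.nn.functional as F` imports — the full version is visible in the
removal hunk below):

    def normalized_mse_loss(input, target, reduction="mean", epsilon=1e-6):
        # MSE normalized by the target's self-MSE, so blocks whose
        # activations have large magnitude do not dominate training.
        return F.mse_loss(input, target, reduction=reduction) / F.mse_loss(
            target, torch.zeros_like(target) + epsilon, reduction=reduction
        )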
Signed-off-by: Sepehr Sameni --- .../bypass_distillation/training_loop.py | 13 +- modelopt/torch/puzzletron/sewing_kit/utils.py | 23 +- tests/gpu/torch/puzzletron/test_bypass.py | 411 ++++++++++++++++++ .../test_bypass_checkpoint_utils.py | 255 +++++++++++ .../puzzletron/test_bypass_keys_to_learn.py | 245 +++++++++++ .../test_bypass_replacement_library.py | 193 ++++++++ 6 files changed, 1115 insertions(+), 25 deletions(-) create mode 100644 tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_replacement_library.py diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index f57d998d7b3..41abcc2281a 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -701,7 +701,7 @@ def run_bypassed_training(cfg: DictConfig): tokenizer = AutoTokenizer.from_pretrained( cfg.teacher_dir, - trust_remote_code=True, + trust_remote_code=trust_remote_code, token=True, ) @@ -948,12 +948,13 @@ def run_bypassed_training(cfg: DictConfig): aprint("Finished training successfully!") dist.barrier() - except Exception as e: + except Exception: + # Print the traceback explicitly so distributed runs surface it on every + # rank's stderr (workers under torchrun otherwise lose ordering), then + # re-raise so test frameworks see the real exception instead of a + # generic SystemExit(1). print(traceback.format_exc(), file=sys.stderr) - if isinstance(e, SystemExit): - raise e - else: - sys.exit(1) + raise dist.barrier() if dist.is_master(): diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index d8df3423a57..2820574bf69 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -23,7 +23,6 @@ Callable, ContextManager, Generic, - Literal, Optional, Protocol, TypeVar, @@ -459,24 +458,10 @@ def _get_group_kwarg_if_necessary() -> dict: # Loss functions for bypass distillation (blockwise local knowledge distillation) # ────────────────────────────────────────────────────────────────────────────── -Reduction = Literal["none", "mean", "sum"] - - -def normalized_mse_loss( - input: torch.Tensor, - target: torch.Tensor, - reduction: Reduction = "mean", - epsilon: float = 1e-6, -) -> torch.Tensor: - """MSE loss normalized by the variance of the target. - - Dividing by the target's self-MSE makes the loss scale-invariant, so that - blocks whose activations have large magnitude do not dominate training. - """ - loss = F.mse_loss(input, target, reduction=reduction) / F.mse_loss( - target, torch.zeros_like(target) + epsilon, reduction=reduction - ) - return loss +# `normalized_mse_loss` already lives in tools.kd_model — re-export it here so +# bypass-distillation imports stay co-located with the per-vector / per-batch +# variants below, without duplicating the implementation. 
+from modelopt.torch.puzzletron.tools.kd_model import normalized_mse_loss # noqa: E402, F401 def vectorwise_normalized_mse_loss( diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index c9e77df87bc..18eccdbb16c 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -47,6 +47,7 @@ """ import copy +import json from datetime import timedelta from functools import partial from pathlib import Path @@ -62,6 +63,7 @@ import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.puzzletron.replacement_library.build_replacement_library as build_lib import modelopt.torch.utils.distributed as dist from modelopt.torch.puzzletron.anymodel import convert_model from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id @@ -681,3 +683,412 @@ def _test_bypass_checkpoint_contents_job( f"PYTEST SUMMARY: test_bypass_checkpoint_contents[{hf_model_name}] completed. " f"Puzzle directory: {puzzle_dir}" ) + + +# --------------------------------------------------------------------------- +# Tests below this line target a single (or two) family deliberately — they +# exercise paths where parametrizing over all 9 families is overkill or +# requires extras (e.g. NemotronH's mamba-ssm dep). +# --------------------------------------------------------------------------- + +# Llama-3.2-3B is the smallest dense family and the canonical "FFN bypass" path. +LLAMA_FAMILY = pytest.param( + "meta-llama/Llama-3.2-3B-Instruct", "llama", None, False, id="llama-3.2-3B" +) +# GPT-OSS adds MoE expert pruning (mlp_init_mode="ExpertRemoval") and windowed +# attention with sinks — different code paths than dense Llama. +GPT_OSS_FAMILY = pytest.param("openai/gpt-oss-20b", "gpt_oss", None, True, id="gpt-oss-20b") + + +# --------------------------------------------------------------------------- +# Resume from checkpoint +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY], +) +def test_bypass_resume_from_checkpoint( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Two-phase test: train + save, then re-launch with resume and verify continuity. + + Phase 1: short bypass run (2 steps), checkpoint saved under + ``puzzle_dir/bypass/bypass_runs//iter-NNNNNN-ckpt/``. + Phase 2: same hydra_cfg + ``find_last_ckpt_for_resume=True`` + double the + training_tokens budget. The resume path in + ``training_loop.run_bypassed_training:805-840`` must restore + ``iter_num`` / ``step_num`` / ``token_count`` from the saved + ``args.json`` and load stitched-module + optimizer state from disk. + + The GradScaler save/load mechanism added in the recent CodeRabbit-driven + fix is tested separately in + ``tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py`` because + GradScaler is fp16-only and the bypass test infrastructure ships bf16, + which makes ``GradScaler.step()`` raise on the unscale path. 
+ """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_resume_from_checkpoint_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_resume_from_checkpoint_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + # ---- Phase 1: train + save --------------------------------------------- + phase1_cfg = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + phase1_cfg["find_last_ckpt_for_resume"] = False + OmegaConf.update(hydra_cfg, "bypass", phase1_cfg, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + expected_experiment_id = _expected_experiment_id(phase1_cfg) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + args_json_path = experiment_dir / "latest" / "args.json" + stitched_dir = experiment_dir / "latest" / "stitched" + + if rank == 0: + # Phase 1 must have produced the canonical artifacts. + assert args_json_path.exists(), f"Phase 1 missing args.json: {args_json_path}" + with open(args_json_path) as f: + phase1_state = json.load(f) + phase1_iter_num = phase1_state["iter_num"] + assert phase1_iter_num > 1, ( + f"Phase 1 should have advanced past iter 1, got {phase1_iter_num}" + ) + + # Optimizer state must be present (covers the resume path's load). + assert (stitched_dir / "block_0.optimizer_state.pth").exists(), stitched_dir + + dist.barrier() + + # ---- Phase 2: resume and continue -------------------------------------- + phase2_cfg = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + phase2_cfg["find_last_ckpt_for_resume"] = True + # Double the budget so the resumed run takes additional steps. + phase2_cfg["training"]["training_tokens"] = TRAINING_TOKENS * 2 + OmegaConf.update(hydra_cfg, "bypass", phase2_cfg, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + assert args_json_path.exists(), "Phase 2 should still have args.json" + with open(args_json_path) as f: + phase2_state = json.load(f) + phase2_iter_num = phase2_state["iter_num"] + # The resumed run must have moved past phase 1's last iter — proves + # both that resume happened (didn't restart at 1) and that further + # training executed. + assert phase2_iter_num > phase1_iter_num, ( + f"Resume did not advance: phase1={phase1_iter_num}, phase2={phase2_iter_num}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_resume_from_checkpoint[{hf_model_name}] completed. 
" + f"Puzzle directory: {puzzle_dir}" + ) + + +# --------------------------------------------------------------------------- +# Per-subblock training modes (Llama dense + GPT-OSS MoE/windowed-attn-sinks) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("keys_to_learn", ["subblock_ffn", "subblock_attention", "entire_block"]) +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY, GPT_OSS_FAMILY], +) +def test_bypass_subblock_modes( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + keys_to_learn: str, +): + """Verify that ``keys_to_learn`` correctly freezes the right param groups. + + For each (family, keys_to_learn) cell: + - Run bypass for 2 steps with that keys_to_learn. + - After training, load the saved stitched_module state dict. + - Compare against the teacher-derived initialization (``copied_dir`` of + the bypass experiment, which holds the post-init pre-train weights): + * subblock_ffn → only FFN keys differ from init; attention identical. + * subblock_attention → only attention keys differ; FFN identical. + * entire_block → both differ. + + GPT-OSS coverage matters because the MoE expert path uses + ``mlp_init_mode="ExpertRemoval"`` instead of ``"Truncate"``, and GPT-OSS's + windowed attention adds attention-sink parameters that the freeze must + correctly include in the "attention" group. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_subblock_modes_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + keys_to_learn, + ), + backend="nccl", + ) + + +def _test_bypass_subblock_modes_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + keys_to_learn: str, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + bypass_cfg_dict["model_factory"]["keys_to_learn"] = keys_to_learn + # Save start-of-training checkpoint so we can diff trained-vs-init. + bypass_cfg_dict["save_checkpoint_before_training"] = True + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + experiment_dir = puzzle_dir / "bypass/bypass_runs" / expected_experiment_id + # `start-iter-*` is the pre-training snapshot (saved when + # save_checkpoint_before_training=True). The post-training snapshot + # under this short-budget config lives at `final-iter-*` (saved by the + # early-exit branch in training_loop.py); the periodic `iter-*` save + # never fires because the budget is only 2 steps. `latest` is updated + # by every `save_bypass_checkpoint` call, so post-training it points at + # the most recent save (the `final-iter-*` one). 
+ start_dirs = sorted(experiment_dir.glob("start-iter-*-ckpt")) + assert start_dirs, f"Expected a start-iter-* checkpoint under {experiment_dir}" + start_dir = start_dirs[0] + end_dir = experiment_dir / "latest" + assert end_dir.exists(), f"Expected `latest` symlink under {experiment_dir}" + # Resolve to the real directory so glob below works regardless of the + # symlink-vs-directory distinction. + end_dir = end_dir.resolve() + assert end_dir != start_dir.resolve(), ( + f"`latest` still points at the pre-training snapshot {end_dir} — " + "no post-training checkpoint was written." + ) + + # Diff every saved stitched module's state dict between start (pre-train) + # and end (post-train). Block names look like ``block_0``, ``block_1``… + ffn_token_set = {".mlp.", ".experts."} # Llama vs GPT-OSS naming + attn_token = ".self_attn." + + def _key_kind(key: str) -> str: + if attn_token in key: + return "attn" + if any(t in key for t in ffn_token_set): + return "ffn" + return "other" + + ffn_changed = False + attn_changed = False + for state_dict_path in (start_dir / "stitched").glob("block_*.state_dict.pth"): + block_name = state_dict_path.stem.replace(".state_dict", "") + end_path = end_dir / "stitched" / state_dict_path.name + if not end_path.exists(): + continue + start_state = torch.load(state_dict_path, map_location="cpu", weights_only=True) + end_state = torch.load(end_path, map_location="cpu", weights_only=True) + for key in start_state.keys() & end_state.keys(): + kind = _key_kind(key) + if kind == "other": + continue + changed = not torch.equal(start_state[key], end_state[key]) + if kind == "ffn" and changed: + ffn_changed = True + if kind == "attn" and changed: + attn_changed = True + + if keys_to_learn == "subblock_ffn": + assert ffn_changed, f"subblock_ffn should change FFN weights ({hf_model_name})" + assert not attn_changed, ( + f"subblock_ffn should leave attention weights bit-identical ({hf_model_name})" + ) + elif keys_to_learn == "subblock_attention": + assert attn_changed, ( + f"subblock_attention should change attention weights ({hf_model_name})" + ) + assert not ffn_changed, ( + f"subblock_attention should leave FFN weights bit-identical ({hf_model_name})" + ) + else: # entire_block + assert ffn_changed and attn_changed, ( + f"entire_block should change both groups ({hf_model_name}); " + f"got ffn={ffn_changed}, attn={attn_changed}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_subblock_modes" + f"[{hf_model_name}, keys_to_learn={keys_to_learn}] completed. " + f"Puzzle directory: {puzzle_dir}" + ) + + +# --------------------------------------------------------------------------- +# End-to-end: bypass then build replacement library +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("hf_model_name", "converter", "hybrid_override_pattern", "has_moe_layers"), + [LLAMA_FAMILY], +) +def test_bypass_then_build_library( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, +): + """Run bypass, then build the replacement library; assert bypass entries appear. + + Verifies the wiring between the bypass step and the downstream NAS step: + - ``realize_bypass_checkpoints`` creates a symlink at ``ckpts/``. + - ``_get_last_checkpoint_from_each_experiment`` resolves it back to the + bypass run dir. 
+ - ``_build_subblocks_df``'s priority sort puts the bypass-rooted path + before non-bypass ones in the resulting DataFrame. + - The final ``replacement_library.json`` includes entries pointing at + the bypass experiment. + """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial( + _test_bypass_then_build_library_job, + project_root_path, + tmp_path, + hf_model_name, + converter, + hybrid_override_pattern, + has_moe_layers, + ), + backend="nccl", + ) + + +def _test_bypass_then_build_library_job( + project_root_path: Path, + tmp_path: Path, + hf_model_name: str, + converter: str, + hybrid_override_pattern: str | None, + has_moe_layers: bool, + rank: int, + size: int, +): + puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( + project_root_path, tmp_path, rank, size, + hf_model_name, converter, hybrid_override_pattern, + ) + + bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) + OmegaConf.update(hydra_cfg, "bypass", bypass_cfg_dict, merge=True) + + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + expected_experiment_id = _expected_experiment_id(bypass_cfg_dict) + ckpts_dir = puzzle_dir / "ckpts" + + # 1. The realize step must have created a symlink for this bypass run. + bypass_symlink = ckpts_dir / expected_experiment_id + assert bypass_symlink.is_symlink() or bypass_symlink.exists(), ( + f"Expected bypass symlink at {bypass_symlink}" + ) + + # 2. Discovery must find the bypass entry alongside the teacher (and any + # pruning-pipeline outputs from the setup helper). + discovered = build_lib._get_last_checkpoint_from_each_experiment(puzzle_dir) + bypass_resolved = bypass_symlink.resolve() + assert bypass_resolved in discovered, ( + f"Bypass run not discovered. Resolved={bypass_resolved}, " + f"discovered={discovered}" + ) + # The resolved bypass path must contain "bypass" + "bypass_runs" in its + # parts so the priority sort picks it up. + assert "bypass" in bypass_resolved.parts and "bypass_runs" in bypass_resolved.parts + + # 3. Build the replacement library and verify the bypass entry appears. + teacher_dir = ckpts_dir / "teacher" + subblocks_df = build_lib._build_subblocks_df( + master_puzzle_dir=puzzle_dir, + teacher_checkpoint_dir=teacher_dir, + add_ffn_no_ops=False, + add_attention_no_ops=False, + trust_remote_code=False, + ) + # Some subblock row's checkpoint_dir column must reference the bypass path. + # FFN-only rows leave attention_checkpoint_dir as NaN (and vice versa); we + # drop those before string-casting because pandas' .astype(str) doesn't + # reliably stringify NaN on object-dtype columns, and 'X' in float('nan') + # raises TypeError. + bypass_str = str(bypass_resolved) + attn_sources = subblocks_df["attention_checkpoint_dir"].dropna().astype(str).tolist() + ffn_sources = subblocks_df["ffn_checkpoint_dir"].dropna().astype(str).tolist() + assert any(bypass_str in s for s in attn_sources + ffn_sources), ( + f"replacement_library subblocks_df has no bypass-sourced rows. " + f"attn_sources={set(attn_sources)}, ffn_sources={set(ffn_sources)}" + ) + + dist.cleanup() + + print( + f"PYTEST SUMMARY: test_bypass_then_build_library[{hf_model_name}] completed. 
" + f"Puzzle directory: {puzzle_dir}" + ) diff --git a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py new file mode 100644 index 00000000000..9e2c9f43843 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -0,0 +1,255 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Save/load round-trip tests for ``bypass_checkpoint_utils``. + +These tests pin two correctness-critical pieces of ``bypass_checkpoint_utils``: + +1. ``_save_local_state`` persists the GradScaler state alongside the optimizer + state (regression coverage for the recent CodeRabbit-driven fix — without + it, fp16 + use_grad_scaling=True runs silently lost the running scale + + growth tracker on resume). +2. ``load_local_state`` restores it from disk. + +Lives under ``tests/gpu/`` because the production ``load_local_state`` builds +``torch.device(f"cuda:{rank}")`` for ``map_location``, so a real CUDA device +is required to round-trip ``torch.load`` without monkeypatching the device +machinery. The full bypass GPU integration test cannot cover this path +because the test infrastructure ships bf16 and ``GradScaler.step()`` is +fp16-only (raises ``NotImplementedError: +_amp_foreach_non_finite_check_and_unscale_cuda not implemented for 'BFloat16'``). +These tests sidestep that by hitting the save/load functions directly, +without ever invoking ``.step()``. +""" + +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from torch.amp.grad_scaler import GradScaler + +from modelopt.torch.puzzletron.bypass_distillation import bypass_checkpoint_utils as bcu +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + StitchedModuleDescriptor, +) + + +# --------------------------------------------------------------------------- +# Fixture: silence the dist helpers so the save/load functions run on a +# single GPU process without `torchrun` / NCCL setup. +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcu_no_dist(monkeypatch): + """Mock the dist helpers so ``bypass_checkpoint_utils`` runs without distributed init.""" + monkeypatch.setattr(bcu.dist, "local_rank", lambda: 0) + monkeypatch.setattr(bcu.dist, "is_master", lambda: True) + monkeypatch.setattr(bcu.dist, "barrier", lambda: None) + return bcu + + +def _make_descriptor( + *, + with_optimizer: bool = True, + with_scaler: bool = True, + grad_scaler_init_scale: float = 2.0**16, +): + """Build a minimal StitchedModuleDescriptor on CPU. + + ``stitched_module`` is a real ``nn.Linear`` so ``state_dict()`` / + ``load_state_dict()`` work without needing the actual ``StitchedModule`` + machinery (which depends on the sewing-kit graph, distributed init, etc.). 
+ + The GradScaler is created with ``enabled=True`` so that ``state_dict()`` + actually contains content (a disabled scaler returns ``{}``, making + round-trip tests vacuous). We never call ``.scale()`` / ``.step()`` so + none of the fp16-only kernels run — only the bookkeeping fields + (``scale``, ``growth_factor``, ``backoff_factor``, ``growth_interval``, + ``_growth_tracker``) go through save/load. + """ + module = nn.Linear(4, 4, bias=False) + owned_parameters = dict(module.named_parameters()) + owned_buffers: dict[str, torch.Tensor] = {} + optimizer = ( + torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None + ) + scaler = ( + GradScaler(device="cpu", enabled=True, init_scale=grad_scaler_init_scale) + if with_scaler + else None + ) + return StitchedModuleDescriptor( + stitched_module=module, + owned_parameters=owned_parameters, + owned_buffers=owned_buffers, + optimizer=optimizer, + grad_scaler=scaler, + ) + + +# --------------------------------------------------------------------------- +# Save: every relevant artifact lands on disk +# --------------------------------------------------------------------------- + + +def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( + tmp_path: Path, bcu_no_dist +): + bcu = bcu_no_dist + descriptor = _make_descriptor() + descriptors = OrderedDict([("block_0", descriptor)]) + + bcu._save_local_state(descriptors, tmp_path) + + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + assert (stitched / "block_0.optimizer_state.pth").exists() + # The CodeRabbit-driven fix added this third file. Without it, resuming + # an fp16 + grad-scaling run would default-init the scaler. + assert (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( + tmp_path: Path, bcu_no_dist +): + bcu = bcu_no_dist + descriptor = _make_descriptor(with_scaler=False) + descriptors = OrderedDict([("block_0", descriptor)]) + + bcu._save_local_state(descriptors, tmp_path) + + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + # No scaler in the descriptor → no .grad_scaler.pth file written. + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_skips_optimizer_when_descriptor_has_none( + tmp_path: Path, bcu_no_dist +): + """Pipeline-parallel idle ranks pass optimizer=None; no file should appear.""" + bcu = bcu_no_dist + descriptor = _make_descriptor(with_optimizer=False, with_scaler=False) + descriptors = OrderedDict([("block_0", descriptor)]) + + bcu._save_local_state(descriptors, tmp_path) + + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + assert not (stitched / "block_0.optimizer_state.pth").exists() + + +# --------------------------------------------------------------------------- +# Load: state survives the round-trip and lands back on the live scaler +# --------------------------------------------------------------------------- + + +def test_load_local_state_restores_grad_scaler_state(tmp_path: Path, bcu_no_dist): + """Round-trip: scaler with non-default init_scale → save → load into fresh scaler → state matches. + + This is the regression test for the CodeRabbit-flagged bug: prior to the + fix, ``load_local_state`` skipped the scaler entirely, so a resumed run + would silently start with a default scale (typically 65536.0) regardless + of where the previous run had grown the scale to. 
+ + We compare via ``state_dict()`` rather than poking at private attributes + because the canonical save/load contract is ``state_dict()`` <-> + ``load_state_dict()``; ``state_dict()['scale']`` is the field a real + bypass run would have grown over time. + """ + bcu = bcu_no_dist + + # 1. Save phase: scaler with a non-default init scale. + save_descriptor = _make_descriptor(grad_scaler_init_scale=12345.0) + saved_state = save_descriptor.grad_scaler.state_dict() + assert saved_state["scale"] == 12345.0 # sanity: state actually carries the value + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + + # 2. Load phase: a fresh descriptor with a different init scale; the load + # must overwrite it with the saved value. + load_descriptor = _make_descriptor(grad_scaler_init_scale=999.0) + pre_load_state = load_descriptor.grad_scaler.state_dict() + assert pre_load_state != saved_state # sanity: starts in a distinct state + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + assert load_descriptor.grad_scaler.state_dict() == saved_state + + +def test_load_local_state_handles_legacy_checkpoint_without_grad_scaler( + tmp_path: Path, bcu_no_dist +): + """Backward compat: a checkpoint saved before the GradScaler-fix must still load. + + Older bypass runs predating the GradScaler save did not write + ``block_0.grad_scaler.pth``. The current ``load_local_state`` must skip + silently in that case rather than raising — our deployed users have + legacy checkpoints they want to resume from. + """ + bcu = bcu_no_dist + + # First save with a scaler so we have a normal "complete" save… + save_descriptor = _make_descriptor() + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + # …then delete the grad_scaler artifact to mimic a legacy checkpoint. + (tmp_path / "stitched" / "block_0.grad_scaler.pth").unlink() + + # Loading must not raise. + load_descriptor = _make_descriptor() + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + +def test_load_local_state_restores_optimizer_state(tmp_path: Path, bcu_no_dist): + """End-to-end optimizer round-trip — covers the resume path's main job.""" + bcu = bcu_no_dist + + save_descriptor = _make_descriptor() + # Take an optimizer step so AdamW has non-default ``state`` (exp_avg etc). + for p in save_descriptor.stitched_module.parameters(): + p.grad = torch.ones_like(p) + save_descriptor.optimizer.step() + saved_state = save_descriptor.optimizer.state_dict() + descriptors_save = OrderedDict([("block_0", save_descriptor)]) + bcu._save_local_state(descriptors_save, tmp_path) + + load_descriptor = _make_descriptor() + # Fresh optimizer's state dict should differ from `saved_state` until load. + assert load_descriptor.optimizer.state_dict() != saved_state + descriptors_load = OrderedDict([("block_0", load_descriptor)]) + bcu.load_local_state(descriptors_load, tmp_path) + + # After load, AdamW step counter and exp_avg buffers must match. + # Production runs co-locate model + state on cuda:0, but this fixture has the + # model on CPU so the loaded state ends up split: exp_avg / exp_avg_sq follow + # the param device (CPU), while AdamW's `step` tensor is loaded via + # ``map_location='cuda:0'`` and stays there. Move both to CPU for the + # comparison — we're verifying value equality, not device placement. 
+ loaded_state = load_descriptor.optimizer.state_dict() + assert loaded_state["state"].keys() == saved_state["state"].keys() + for param_id in loaded_state["state"]: + for key, val in saved_state["state"][param_id].items(): + loaded_val = loaded_state["state"][param_id][key] + if torch.is_tensor(val): + assert torch.equal(loaded_val.to("cpu"), val.to("cpu")), ( + f"optimizer.state[{param_id}][{key}] not restored" + ) + else: + assert loaded_val == val diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py new file mode 100644 index 00000000000..da47d7e0001 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_set_keys_to_learn`` in stitched_model_factory.py. + +This function is the single source of truth for which parameters get trained +during a bypass run. Its branches (subblock_ffn / subblock_attention / +subblock_mamba / entire_block / list / regex) and its hybrid-model +``block_configs`` filter are all silent on misuse — a regression here would +freeze the wrong layers and produce a worse-than-teacher checkpoint with no +loud failure. +""" + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + _set_keys_to_learn, +) + + +# --------------------------------------------------------------------------- +# Fixtures: a minimal Llama-shaped model and a Llama-shaped descriptor stub +# --------------------------------------------------------------------------- + + +def _make_dense_model(num_layers: int = 2) -> nn.Module: + """Build a tiny model whose named_parameters mimic Llama's naming. + + Parameters live under ``model.layers.{i}.self_attn.{q,k,v,o}_proj.weight`` + and ``model.layers.{i}.mlp.{up,down}_proj.weight``. The function never reads + parameter shapes, so size doesn't matter — what matters is that the names + match what `_set_keys_to_learn` expects to see in `named_parameters()` and + `state_dict().keys()`. + """ + model = nn.Module() + model_inner = nn.Module() + layers = nn.ModuleList() + for _ in range(num_layers): + layer = nn.Module() + # attention + layer.self_attn = nn.Module() + for proj in ("q_proj", "k_proj", "v_proj", "o_proj"): + setattr(layer.self_attn, proj, nn.Linear(4, 4, bias=False)) + # feed-forward + layer.mlp = nn.Module() + for proj in ("up_proj", "down_proj"): + setattr(layer.mlp, proj, nn.Linear(4, 4, bias=False)) + layers.append(layer) + model_inner.layers = layers + model.model = model_inner + # `_set_keys_to_learn` reads `model.config` only to pass through to + # `descriptor.get_language_model_config` — a SimpleNamespace is enough. 
+ model.config = SimpleNamespace() + # Start with everything frozen so any True flag is something the function set. + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def _make_descriptor(num_layers: int, *, block_configs=None): + """Build a descriptor stub exposing only what ``_set_keys_to_learn`` calls. + + - ``get_language_model_config(config)`` returns an object with + ``num_hidden_layers`` and (optionally) ``block_configs``. + - ``get_weight_groups(state_dict_keys, num_hidden_layers)`` returns + ``{"block_{i}_attention": [...], "block_{i}_ffn": [...]}``. + """ + + def get_language_model_config(_config): + ns = SimpleNamespace(num_hidden_layers=num_layers) + if block_configs is not None: + ns.block_configs = block_configs + return ns + + def get_weight_groups(state_dict_keys, n): + groups: dict[str, list[str]] = {} + for i in range(n): + attn_prefix = f"model.layers.{i}.self_attn." + ffn_prefix = f"model.layers.{i}.mlp." + groups[f"block_{i}_attention"] = [ + k for k in state_dict_keys if k.startswith(attn_prefix) + ] + groups[f"block_{i}_ffn"] = [k for k in state_dict_keys if k.startswith(ffn_prefix)] + return groups + + return SimpleNamespace( + get_language_model_config=get_language_model_config, + get_weight_groups=get_weight_groups, + ) + + +def _trainable_names(model: nn.Module) -> set[str]: + return {n for n, p in model.named_parameters() if p.requires_grad} + + +# --------------------------------------------------------------------------- +# Single-string subblock keys (dense model) +# --------------------------------------------------------------------------- + + +def test_subblock_ffn_trains_only_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_ffn") + trainable = _trainable_names(model) + assert all(".mlp." in n for n in trainable), trainable + assert not any(".self_attn." in n for n in trainable), trainable + # Both layers' mlp params must be trainable, not just one. + assert any("model.layers.0.mlp." in n for n in trainable) + assert any("model.layers.1.mlp." in n for n in trainable) + + +def test_subblock_attention_trains_only_self_attn(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + assert all(".self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +def test_entire_block_trains_attention_and_mlp(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, "entire_block") + trainable = _trainable_names(model) + # Both groups present. + assert any(".self_attn." in n for n in trainable), trainable + assert any(".mlp." in n for n in trainable), trainable + # Equal to the union of every model parameter. + assert trainable == {n for n, _ in model.named_parameters()} + + +# --------------------------------------------------------------------------- +# Hybrid model: subblock_mamba vs subblock_attention should partition by +# block_configs[i].attention.mamba — this is the path most likely to +# silently misroute training under future descriptor changes. +# --------------------------------------------------------------------------- + + +def _hybrid_block_configs(): + """Block 0: Mamba. Block 1: GQA. 
Detected via ``attention.mamba is not None``.""" + return [ + SimpleNamespace(attention=SimpleNamespace(mamba=SimpleNamespace())), # Mamba + SimpleNamespace(attention=SimpleNamespace(mamba=None)), # GQA + ] + + +def test_subblock_mamba_on_hybrid_trains_only_mamba_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_mamba") + trainable = _trainable_names(model) + # Block 0 (Mamba) attention-group params should be trainable; block 1 (GQA) must not. + assert any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any("model.layers.1.self_attn." in n for n in trainable), trainable + # FFN params are never trainable under subblock_mamba. + assert not any(".mlp." in n for n in trainable), trainable + + +def test_subblock_attention_on_hybrid_trains_only_gqa_block(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2, block_configs=_hybrid_block_configs()) + _set_keys_to_learn(model, descriptor, "subblock_attention") + trainable = _trainable_names(model) + # Block 1 (GQA) attention-group params are trainable; block 0 (Mamba) must not. + assert any("model.layers.1.self_attn." in n for n in trainable), trainable + assert not any("model.layers.0.self_attn." in n for n in trainable), trainable + assert not any(".mlp." in n for n in trainable), trainable + + +# --------------------------------------------------------------------------- +# Free-form key forms: list, regex, no-match +# --------------------------------------------------------------------------- + + +def test_explicit_param_name_list_only_marks_listed_params(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + target = "model.layers.0.self_attn.q_proj.weight" + _set_keys_to_learn(model, descriptor, [target]) + assert _trainable_names(model) == {target} + + +def test_regex_string_uses_re_search(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + _set_keys_to_learn(model, descriptor, r"q_proj") + trainable = _trainable_names(model) + # Both layers' q_proj must match; nothing else should. + assert trainable == { + "model.layers.0.self_attn.q_proj.weight", + "model.layers.1.self_attn.q_proj.weight", + } + + +def test_no_match_regex_returns_silently_with_no_trainable_params(): + model = _make_dense_model(num_layers=2) + descriptor = _make_descriptor(num_layers=2) + # Pipeline-parallel idle-rank case: a regex that matches nothing on this rank. + _set_keys_to_learn(model, descriptor, r"nonexistent_param_pattern") + assert _trainable_names(model) == set() + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "keys_to_learn", + ["subblock_ffn", "subblock_attention", "entire_block"], +) +def test_subblock_keys_skip_non_floating_point_params(keys_to_learn): + """Integer / non-floating buffers exposed as parameters must stay frozen. + + The function explicitly guards on ``torch.is_floating_point(param)``; this + test pins that guard so a future refactor doesn't accidentally try to + enable grad on int tensors (which would raise at runtime). + """ + model = _make_dense_model(num_layers=2) + # Inject an int "param" alongside a real one. 
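+    # The integer dtype is only legal here because requires_grad=False:
+    # PyTorch refuses requires_grad=True on non-floating-point (and
+    # non-complex) tensors, which is why `_set_keys_to_learn` has to skip
+    # this param instead of blindly calling requires_grad_(True) on it.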
+ int_param = nn.Parameter(torch.zeros(2, dtype=torch.long), requires_grad=False) + model.model.layers[0].self_attn.register_parameter("int_counter", int_param) + descriptor = _make_descriptor(num_layers=2) + # Should not raise even though the int param's name matches the attention group. + _set_keys_to_learn(model, descriptor, keys_to_learn) + # The int counter must remain frozen regardless. + assert not model.model.layers[0].self_attn.int_counter.requires_grad diff --git a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py new file mode 100644 index 00000000000..ced076ac268 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for replacement-library checkpoint discovery + bypass priority. + +The ``build_replacement_library`` module is responsible for two correctness-critical +behaviors after a bypass run: + +1. ``_get_last_checkpoint_from_each_experiment`` must surface every valid + checkpoint under ``puzzle_dir/ckpts/``, including those that live there only + as symlinks (which is exactly how bypass writes its results). +2. When a bypass-trained subblock and a Truncate-init subblock would produce + the same architectural identifier, the bypass-trained one must be preferred + by the downstream ``drop_duplicates(keep="first")``. This is enforced by a + tuple-sort closure inside ``_build_subblocks_df`` that gives bypass paths + priority 0 and everything else priority 1. + +A regression in either path silently discards bypass-trained weights — exactly +the kind of bug that's invisible in normal CI runs. +""" + +from pathlib import Path + +import pytest + +from modelopt.torch.puzzletron.replacement_library import build_replacement_library as brl + + +# --------------------------------------------------------------------------- +# Filesystem fixture: tiny puzzle_dir with three checkpoints +# --------------------------------------------------------------------------- + + +def _write_minimal_config(checkpoint_dir: Path) -> None: + """Write a placeholder config.json so the discovery rglob finds the dir. + + The actual config contents don't matter — these tests monkeypatch + ``is_valid_decilm_checkpoint`` so no real config parsing happens. + """ + checkpoint_dir.mkdir(parents=True, exist_ok=True) + (checkpoint_dir / "config.json").write_text("{}") + + +@pytest.fixture +def puzzle_dir_with_three_ckpts(tmp_path: Path, monkeypatch) -> Path: + """Build a puzzle_dir tree mirroring a real post-bypass post-prune layout. 
+ + Layout:: + + puzzle_dir/ + ckpts/ + teacher/ # real dir + config.json + bypass_ffn_256_heads_4 -> ../bypass/bypass_runs/.../iter-000010-ckpt + pruned_intermediate_256 -> ../pruning/pruned_intermediate_256 + bypass/bypass_runs/bypass_ffn_256_heads_4/iter-000010-ckpt/ + config.json + pruning/pruned_intermediate_256/ + config.json + + The two non-teacher entries under ``ckpts/`` are symlinks — that is how + ``puzzletron_nas_plugin.realize_bypass_checkpoints`` and the pruning + pipeline actually write them. ``_get_last_checkpoint_from_each_experiment`` + must `.resolve()` these to see the real path under ``bypass/bypass_runs/`` + or ``pruning/`` — that resolution is what the priority sort later keys on. + """ + puzzle_dir = tmp_path / "puzzle_dir" + ckpts = puzzle_dir / "ckpts" + ckpts.mkdir(parents=True) + + # Teacher: real directory directly under ckpts. + _write_minimal_config(ckpts / "teacher") + + # Bypass: real dir under bypass/bypass_runs/, symlinked from ckpts/. + bypass_real = puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn_256_heads_4" / "iter-000010-ckpt" + _write_minimal_config(bypass_real) + (ckpts / "bypass_ffn_256_heads_4").symlink_to(bypass_real, target_is_directory=True) + + # Truncate-pruned: real dir under pruning/, symlinked from ckpts/. + pruning_real = puzzle_dir / "pruning" / "pruned_intermediate_256" + _write_minimal_config(pruning_real) + (ckpts / "pruned_intermediate_256").symlink_to(pruning_real, target_is_directory=True) + + # Make every config.json look "valid" without parsing — load_model_config + # would otherwise try to load these as real HF configs. + monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", lambda *a, **kw: True) + + return puzzle_dir + + +# --------------------------------------------------------------------------- +# Discovery +# --------------------------------------------------------------------------- + + +def test_get_last_checkpoint_from_each_experiment_finds_all_three( + puzzle_dir_with_three_ckpts: Path, +): + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + discovered_names = {p.name for p in discovered} + assert discovered_names == {"teacher", "iter-000010-ckpt", "pruned_intermediate_256"} + + +def test_get_last_checkpoint_from_each_experiment_resolves_symlinks( + puzzle_dir_with_three_ckpts: Path, +): + """The resolved paths must reflect the real filesystem location. + + This is what makes the bypass-priority sort work — the closure inside + ``_build_subblocks_df`` checks ``"bypass" in p.parts and "bypass_runs" + in p.parts``, which only succeeds on the resolved path. + """ + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + bypass_path = next(p for p in discovered if p.name == "iter-000010-ckpt") + assert "bypass" in bypass_path.parts + assert "bypass_runs" in bypass_path.parts + # And the pruning entry must NOT pick up "bypass" anywhere in its parts. + pruning_path = next(p for p in discovered if p.name == "pruned_intermediate_256") + assert "bypass" not in pruning_path.parts + + +def test_get_last_checkpoint_skips_invalid_checkpoints( + puzzle_dir_with_three_ckpts: Path, monkeypatch +): + """Only checkpoints that pass ``is_valid_decilm_checkpoint`` should appear. + + A regression where a malformed config.json silently slips through would + later raise inside ``_construct_subblock_rows_from_current_checkpoint`` + with a much less helpful traceback. 
+ """ + + def _only_teacher_is_valid(checkpoint_dir, trust_remote_code=False): + return Path(checkpoint_dir).name == "teacher" + + monkeypatch.setattr(brl, "is_valid_decilm_checkpoint", _only_teacher_is_valid) + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + assert {p.name for p in discovered} == {"teacher"} + + +# --------------------------------------------------------------------------- +# Bypass-priority sort +# --------------------------------------------------------------------------- + + +def _bypass_priority(p: Path) -> tuple[int, str]: + """Re-implementation of the closure inside ``_build_subblocks_df``. + + Kept identical to ``modelopt/torch/puzzletron/replacement_library/ + build_replacement_library.py:222-225``. If that closure is changed, + update this test mirror; this is intentional duplication so the unit + test stays cheap (no need to build an end-to-end DataFrame just to + verify a 3-line priority function). + """ + is_bypass = "bypass" in p.parts and "bypass_runs" in p.parts + return (0 if is_bypass else 1, str(p)) + + +def test_bypass_priority_orders_bypass_before_pruning(puzzle_dir_with_three_ckpts: Path): + """The same input set the real code receives must sort bypass first.""" + discovered = brl._get_last_checkpoint_from_each_experiment(puzzle_dir_with_three_ckpts) + teacher = next(p for p in discovered if p.name == "teacher") + non_teacher_sorted = sorted(discovered - {teacher}, key=_bypass_priority) + + # Bypass must come first; pruning must come second. + assert non_teacher_sorted[0].name == "iter-000010-ckpt" + assert non_teacher_sorted[1].name == "pruned_intermediate_256" + + +def test_bypass_priority_is_stable_for_two_bypass_checkpoints(tmp_path: Path): + """Multiple bypass checkpoints must sort deterministically by string. + + Without this, ``set`` iteration order changes the picked-first checkpoint + across Python invocations, defeating the whole point of the priority sort. + """ + p1 = tmp_path / "puzzle/bypass/bypass_runs/bypass_a/iter-000010-ckpt" + p2 = tmp_path / "puzzle/bypass/bypass_runs/bypass_b/iter-000020-ckpt" + paths = {p2, p1} # insert in non-sorted order + out = sorted(paths, key=_bypass_priority) + assert [p.name for p in out] == ["iter-000010-ckpt", "iter-000020-ckpt"] + # Repeated runs hit the same order. + assert sorted({p1, p2}, key=_bypass_priority) == out From fc4bcfc47dfc80c5b7348fd2d26d644e9b0fca74 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 08:50:44 -0700 Subject: [PATCH 06/13] Fix CI: F841, mypy attr-defined re-export, ruff format reflows - test_bypass.py: drop unused `block_name` (F841); ruff-format reflows for `_test_bypass_then_build_library_job`'s `_setup_hydra_cfg_and_pruning` args (one-arg-per-line) and the `Bypass run not discovered` assert message (single-line f-string). - sewing_kit/utils.py: re-export `normalized_mse_loss` via the `as normalized_mse_loss` form (PEP 484 explicit re-export). The prior `from X import Y # noqa: F401` form is treated by mypy as a private import, which surfaced as `attr-defined` at the call site in stitched_model_factory.py. - test_bypass_checkpoint_utils.py: remove blank line after the import block; collapse single-line optimizer ternary; collapse three function signatures whose argument list now fits on one line. - test_bypass_keys_to_learn.py: collapse single-symbol import and drop the trailing blank line. 
- test_bypass_replacement_library.py: drop a blank line and reflow `bypass_real = ...` to use parens for the line continuation. Signed-off-by: Sepehr Sameni --- modelopt/torch/puzzletron/sewing_kit/utils.py | 8 ++++++-- tests/gpu/torch/puzzletron/test_bypass.py | 13 ++++++++----- .../puzzletron/test_bypass_checkpoint_utils.py | 17 ++++------------- .../puzzletron/test_bypass_keys_to_learn.py | 5 +---- .../test_bypass_replacement_library.py | 5 +++-- 5 files changed, 22 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index 2820574bf69..b513715b223 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -460,8 +460,12 @@ def _get_group_kwarg_if_necessary() -> dict: # `normalized_mse_loss` already lives in tools.kd_model — re-export it here so # bypass-distillation imports stay co-located with the per-vector / per-batch -# variants below, without duplicating the implementation. -from modelopt.torch.puzzletron.tools.kd_model import normalized_mse_loss # noqa: E402, F401 +# variants below, without duplicating the implementation. The `as +# normalized_mse_loss` form is PEP 484's explicit re-export (mypy treats +# `from X import Y` as a private import otherwise). +from modelopt.torch.puzzletron.tools.kd_model import ( # noqa: E402 + normalized_mse_loss as normalized_mse_loss, +) def vectorwise_normalized_mse_loss( diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index 18eccdbb16c..0ca9d62f6c1 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -935,7 +935,6 @@ def _key_kind(key: str) -> str: ffn_changed = False attn_changed = False for state_dict_path in (start_dir / "stitched").glob("block_*.state_dict.pth"): - block_name = state_dict_path.stem.replace(".state_dict", "") end_path = end_dir / "stitched" / state_dict_path.name if not end_path.exists(): continue @@ -1032,8 +1031,13 @@ def _test_bypass_then_build_library_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) @@ -1057,8 +1061,7 @@ def _test_bypass_then_build_library_job( discovered = build_lib._get_last_checkpoint_from_each_experiment(puzzle_dir) bypass_resolved = bypass_symlink.resolve() assert bypass_resolved in discovered, ( - f"Bypass run not discovered. Resolved={bypass_resolved}, " - f"discovered={discovered}" + f"Bypass run not discovered. Resolved={bypass_resolved}, discovered={discovered}" ) # The resolved bypass path must contain "bypass" + "bypass_runs" in its # parts so the priority sort picks it up. diff --git a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py index 9e2c9f43843..a813d0060b7 100644 --- a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -47,7 +47,6 @@ StitchedModuleDescriptor, ) - # --------------------------------------------------------------------------- # Fixture: silence the dist helpers so the save/load functions run on a # single GPU process without `torchrun` / NCCL setup. 
@@ -85,9 +84,7 @@ def _make_descriptor( module = nn.Linear(4, 4, bias=False) owned_parameters = dict(module.named_parameters()) owned_buffers: dict[str, torch.Tensor] = {} - optimizer = ( - torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None - ) + optimizer = torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None scaler = ( GradScaler(device="cpu", enabled=True, init_scale=grad_scaler_init_scale) if with_scaler @@ -107,9 +104,7 @@ def _make_descriptor( # --------------------------------------------------------------------------- -def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): bcu = bcu_no_dist descriptor = _make_descriptor() descriptors = OrderedDict([("block_0", descriptor)]) @@ -124,9 +119,7 @@ def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( assert (stitched / "block_0.grad_scaler.pth").exists() -def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): bcu = bcu_no_dist descriptor = _make_descriptor(with_scaler=False) descriptors = OrderedDict([("block_0", descriptor)]) @@ -139,9 +132,7 @@ def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( assert not (stitched / "block_0.grad_scaler.pth").exists() -def test_save_local_state_skips_optimizer_when_descriptor_has_none( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): """Pipeline-parallel idle ranks pass optimizer=None; no file should appear.""" bcu = bcu_no_dist descriptor = _make_descriptor(with_optimizer=False, with_scaler=False) diff --git a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py index da47d7e0001..09b1c322421 100644 --- a/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py +++ b/tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py @@ -29,10 +29,7 @@ import torch import torch.nn as nn -from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( - _set_keys_to_learn, -) - +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import _set_keys_to_learn # --------------------------------------------------------------------------- # Fixtures: a minimal Llama-shaped model and a Llama-shaped descriptor stub diff --git a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py index ced076ac268..4a950247dbe 100644 --- a/tests/unit/torch/puzzletron/test_bypass_replacement_library.py +++ b/tests/unit/torch/puzzletron/test_bypass_replacement_library.py @@ -37,7 +37,6 @@ from modelopt.torch.puzzletron.replacement_library import build_replacement_library as brl - # --------------------------------------------------------------------------- # Filesystem fixture: tiny puzzle_dir with three checkpoints # --------------------------------------------------------------------------- @@ -84,7 +83,9 @@ def puzzle_dir_with_three_ckpts(tmp_path: Path, monkeypatch) -> Path: _write_minimal_config(ckpts / "teacher") # Bypass: real dir under bypass/bypass_runs/, symlinked from ckpts/. 
- bypass_real = puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn_256_heads_4" / "iter-000010-ckpt" + bypass_real = ( + puzzle_dir / "bypass" / "bypass_runs" / "bypass_ffn_256_heads_4" / "iter-000010-ckpt" + ) _write_minimal_config(bypass_real) (ckpts / "bypass_ffn_256_heads_4").symlink_to(bypass_real, target_is_directory=True) From ca90612685c8a8c767d4443b3c44e8b9c9f3c055 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Wed, 6 May 2026 08:58:12 -0700 Subject: [PATCH 07/13] Fix CI: ruff format remaining _setup_hydra_cfg_and_pruning call sites Two more call sites (in _test_bypass_resume_from_checkpoint_job and _test_bypass_subblock_modes_job) had the old multi-line-arg-pair format that ruff format wants reflowed to one-arg-per-line. The previous CI fix caught only the third call site. Signed-off-by: Sepehr Sameni --- tests/gpu/torch/puzzletron/test_bypass.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/gpu/torch/puzzletron/test_bypass.py b/tests/gpu/torch/puzzletron/test_bypass.py index 0ca9d62f6c1..4b7a1c1d1d8 100644 --- a/tests/gpu/torch/puzzletron/test_bypass.py +++ b/tests/gpu/torch/puzzletron/test_bypass.py @@ -759,8 +759,13 @@ def _test_bypass_resume_from_checkpoint_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) # ---- Phase 1: train + save --------------------------------------------- @@ -884,8 +889,13 @@ def _test_bypass_subblock_modes_job( size: int, ): puzzle_dir, _, hydra_cfg = _setup_hydra_cfg_and_pruning( - project_root_path, tmp_path, rank, size, - hf_model_name, converter, hybrid_override_pattern, + project_root_path, + tmp_path, + rank, + size, + hf_model_name, + converter, + hybrid_override_pattern, ) bypass_cfg_dict = _make_bypass_cfg_dict(has_moe_layers, hydra_cfg) From 89981f5f0bc3d04a8cb13d91bb6289b585b060d2 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 00:52:49 -0700 Subject: [PATCH 08/13] Add bypass+sewing_kit unit tests; fix _get_lr cosine off-by-one Pure-CPU unit tests covering the bypass-distillation surface that codecov flagged as uncovered: * puzzletron_nas_plugin progress helpers * dataloaders (split auto-detect, num_workers guard, pad helper, Printer fake accelerator, load_*_fn delegators) * launch_bypass_distillation sweep dispatcher * bypass_checkpoint_utils (find_latest_run_dir, _save_local_file, _save_local_state, save_bypass_checkpoint orchestration) * stitched_model_factory _get_all_non_persistent_buffers_set * sewing_kit InputArgs, ActivityContext, Needle validation Plus a GPU integration test pinning resume-from-latest end-to-end (test_bypass_resume.py). Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W) so the schedule reaches min_lr exactly at step==D instead of relying on the post-decay clamp at D+1 to mask a one-step plateau at base_lr right after warmup. 
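For example, with warmup_steps=10, lr_decay_steps=20, learning_rate=1.0,
min_lr=0.0 (the exact shape pinned by the new test_bypass_lr_scheduler.py):

    step 11: old ratio (11-10-1)/10 = 0.0 -> lr = 1.0 (warmup peak repeated)
             new ratio (11-10)/10   = 0.1 -> lr = 0.5*(1+cos(0.1*pi)) ~ 0.976
    step 20: old ratio (20-10-1)/10 = 0.9 -> lr ~ 0.024 (min_lr only via clamp)
             new ratio (20-10)/10   = 1.0 -> lr = 0.0, i.e. min_lr exactly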
Signed-off-by: Sepehr Sameni --- .../bypass_distillation/training_loop.py | 2 +- .../torch/puzzletron/test_bypass_resume.py | 250 +++++++++++++ .../test_bypass_checkpoint_utils.py | 347 ++++++++++++++++++ .../puzzletron/test_bypass_dataloaders.py | 280 ++++++++++++++ .../puzzletron/test_bypass_lr_scheduler.py | 127 +++++++ .../test_launch_bypass_distillation.py | 138 +++++++ .../puzzletron/test_puzzletron_progress.py | 113 ++++++ .../test_sewing_kit_activity_context.py | 181 +++++++++ .../puzzletron/test_sewing_kit_input_args.py | 162 ++++++++ .../puzzletron/test_sewing_kit_needle.py | 197 ++++++++++ .../test_stitched_model_factory_buffers.py | 76 ++++ 11 files changed, 1872 insertions(+), 1 deletion(-) create mode 100644 tests/gpu/torch/puzzletron/test_bypass_resume.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_dataloaders.py create mode 100644 tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py create mode 100644 tests/unit/torch/puzzletron/test_launch_bypass_distillation.py create mode 100644 tests/unit/torch/puzzletron/test_puzzletron_progress.py create mode 100644 tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py create mode 100644 tests/unit/torch/puzzletron/test_sewing_kit_input_args.py create mode 100644 tests/unit/torch/puzzletron/test_sewing_kit_needle.py create mode 100644 tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 41abcc2281a..8486584b1c6 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -591,7 +591,7 @@ def _get_lr(cfg: DictConfig, step: int) -> float: lr = cfg.bypass.training.min_lr # 3) in between, use cosine decay down to min learning rate else: - decay_ratio = (step - warmup_steps - 1) / (lr_decay_steps - warmup_steps) + decay_ratio = (step - warmup_steps) / (lr_decay_steps - warmup_steps) assert 0 <= decay_ratio <= 1 coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 lr = cfg.bypass.training.min_lr + coeff * ( diff --git a/tests/gpu/torch/puzzletron/test_bypass_resume.py b/tests/gpu/torch/puzzletron/test_bypass_resume.py new file mode 100644 index 00000000000..f6a7419d2a2 --- /dev/null +++ b/tests/gpu/torch/puzzletron/test_bypass_resume.py @@ -0,0 +1,250 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU integration test for the bypass-distillation resume path. + +The existing ``test_bypass.py`` covers the save side: a fresh bypass run +produces a checkpoint and a ``ckpts/`` symlink. 
What it doesn't cover is +the *resume* side: a re-launched job calling ``find_latest_run_dir`` against +a real experiment directory and loading optimizer / state via ``load_local_state``. + +That contract — between what training writes (``saving_completed`` marker, +``args.json``, ``stitched/*.pth``) and what the resume helpers read — is +exactly the kind of thing that quietly diverges as the save format evolves. +A unit test can pin the regex; only an integration test pins the byte-level +agreement between writer and reader. + +Single dense family (Llama-3.2-3B-Instruct) is enough — the resume code path +is family-agnostic. +""" + +from datetime import timedelta +from functools import partial +from pathlib import Path + +import pytest +import torch +from _test_utils.torch.distributed.utils import spawn_multiprocess_job +from _test_utils.torch.misc import set_seed +from _test_utils.torch.puzzletron.utils import setup_test_model_and_data +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.activation_scoring.score_pruning_activations as score_pruning_activations +import modelopt.torch.puzzletron.bypass_distillation as bypass_distillation +import modelopt.torch.puzzletron.pruning.pruning_ckpts as pruning_ckpts +import modelopt.torch.utils.distributed as dist +from modelopt.torch.puzzletron.anymodel import convert_model +from modelopt.torch.puzzletron.bypass_distillation.bypass_checkpoint_utils import ( + find_latest_run_dir, +) +from modelopt.torch.puzzletron.bypass_distillation.bypass_utils import set_experiment_id +from modelopt.torch.puzzletron.tools.hydra_utils import initialize_hydra_config_for_dir + +# Match the constants in test_bypass.py so the run completes in two steps. +SEED = 1234 +TRAINING_TOKENS = 128 +BLOCK_SIZE = 64 +PRUNED_INTERMEDIATE_SIZE = 256 +PRUNED_NUM_KV_HEADS = 4 + +# One dense family — resume path is family-agnostic, so a second parametrize +# row would only add runtime, not coverage. 
+HF_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct" +CONVERTER = "llama" + + +def _bypass_cfg_dict(*, find_last_ckpt_for_resume: bool) -> dict: + """Minimal bypass config — derived from test_bypass.py's _make_bypass_cfg_dict + for a dense family with FFN+KV pruning.""" + return { + "dtype": "bf16", + "seed": 42, + "experiment_id": None, + "experiment_dir": None, + "iter_num": 1, + "step_num": 1, + "token_count": 0, + "data": { + "data_column": "conversation", + "block_size": BLOCK_SIZE, + "bos_rate": 0.5, + "fim_rate": 0, + "fim_spm_rate": 0, + "source_datasets_to_discard": [], + "load_from_disk": True, + "keep_in_memory": False, + "val_dataset_name": "valid", + "max_eval_samples": 1, + "eval_samples_per_process": None, + "shuffle_train_data_seed": 42, + }, + "training": { + "learning_rate": 1e-4, + "training_tokens": TRAINING_TOKENS, + "micro_batch_size": 1, + "val_micro_batch_size": 1, + "warmup_ratio": 0.05, + "warmup_steps": None, + "min_lr_factor": 1e-5, + "grad_accumulation_steps": 1, + "skip_first_batches": 0, + "weight_decay": 0.1, + "decay_lr": True, + "beta1": 0.9, + "beta2": 0.95, + "use_grad_scaling": False, + "grad_clip": 1.0, + "grad_clip_type": "norm", + "clipping_count": 0, + "log_interval": 5, + "eval_interval": 100, + }, + "resume_checkpoint_path": None, + "find_last_ckpt_for_resume": find_last_ckpt_for_resume, + "parameter_count": None, + "init_checkpoint_path": None, + "model": { + "student_weights_dtype": "bf16", + "model_overrides": { + "delete_old_checkpoints": True, + "save_interval_seconds": None, + "save_interval": 1_000_000_000, + "save_checkpoint_when_done": True, + }, + "model_config_overrides": { + "ffn": [{"intermediate_size": PRUNED_INTERMEDIATE_SIZE, "no_op": None}], + "attention": [{"num_key_value_heads": PRUNED_NUM_KV_HEADS, "no_op": None}], + }, + }, + "model_factory": { + "factory": "bypass_factory_fn", + "block_loss_func": "normalized_mse_loss", + "gqa_init_mode": "AverageKV", + "mlp_init_mode": "Truncate", + "mlp_init_config": {"activations_log_dir": None}, + "linear_init_mode": "FromTeacher", + "submodule_for_loss_calculation": None, + "keys_to_learn": "entire_block", + }, + "disable_initial_validate": True, + "validate_teacher_model": False, + "validate_student_model": False, + "disable_validation": True, + "best_val_loss": 1e9, + "compile": False, + "disable_fa2": False, + "teacher_model_load_on_cpu": False, + "save_checkpoint_before_training": False, + "disable_checkpoint_save": False, + "save_best_ckpt": True, + "kill_after_first_save": False, + "realize_best_or_latest": "best", + "wandb_log": False, + "wandb": {"project": None, "entity": None}, + } + + +def _expected_experiment_dir(puzzle_dir: Path, bypass_cfg_dict: dict) -> Path: + """Compute the experiment directory the runtime will choose.""" + cfg = OmegaConf.create({"bypass": dict(bypass_cfg_dict)}) + set_experiment_id(cfg) + return puzzle_dir / "bypass/bypass_runs" / cfg.bypass.experiment_id + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU required") +def test_bypass_resume_finds_latest_checkpoint(project_root_path: Path, tmp_path: Path): + """Run bypass once, verify ``find_latest_run_dir`` locates the saved + checkpoint, then re-launch with ``find_last_ckpt_for_resume=True`` and + verify the second run resumes from the saved iter_num. 
+ """ + spawn_multiprocess_job( + size=torch.cuda.device_count(), + job=partial(_resume_job, project_root_path, tmp_path), + backend="nccl", + ) + + +def _resume_job(project_root_path: Path, tmp_path: Path, rank: int, size: int): + set_seed(SEED) + dist.setup(timeout=timedelta(10)) + + puzzle_dir, hf_checkpoint_path, dataset_path = setup_test_model_and_data( + tmp_path, rank, HF_MODEL_NAME, hybrid_override_pattern=None + ) + + hydra_config_dir = str(project_root_path / "tests/gpu/torch/puzzletron/resources/configs") + hydra_config_name = f"{HF_MODEL_NAME}/{Path(HF_MODEL_NAME).name}" + + if rank == 0: + convert_model( + input_dir=str(hf_checkpoint_path), + output_dir=str(puzzle_dir / "ckpts/teacher"), + converter=CONVERTER, + ) + dist.barrier() + + import hydra + + hydra_cfg = initialize_hydra_config_for_dir( + config_dir=hydra_config_dir, + config_name=hydra_config_name, + overrides=[f"puzzle_dir={puzzle_dir}", f"dataset_path={dataset_path}"], + ) + hydra_cfg = hydra.utils.instantiate(hydra_cfg) + + score_pruning_activations.launch_score_activations(hydra_cfg) + if rank == 0: + pruning_ckpts.launch_prune_ckpt(hydra_cfg) + dist.barrier() + + # First bypass run — produces a real checkpoint. + cfg_dict = _bypass_cfg_dict(find_last_ckpt_for_resume=False) + OmegaConf.update(hydra_cfg, "bypass", cfg_dict, merge=True) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + experiment_dir = _expected_experiment_dir(puzzle_dir, cfg_dict) + if rank == 0: + # The save side wrote what the resume side expects. + assert experiment_dir.exists(), f"Expected experiment dir at {experiment_dir}" + latest = find_latest_run_dir(experiment_dir) + assert latest is not None, f"find_latest_run_dir returned None for {experiment_dir}" + assert (Path(latest) / "saving_completed").exists(), ( + f"Resume target {latest} missing saving_completed marker" + ) + assert (Path(latest) / "args.json").exists(), ( + f"Resume target {latest} missing args.json — load path would crash" + ) + dist.barrier() + + # Second bypass run — re-uses the same experiment_dir, finds the latest + # checkpoint via ``find_last_ckpt_for_resume=True``, and resumes. + # Reset cfg.bypass to a fresh dict (experiment_id back to None so + # set_experiment_id recomputes the same id from model_config_overrides). + cfg_dict_resume = _bypass_cfg_dict(find_last_ckpt_for_resume=True) + OmegaConf.update(hydra_cfg, "bypass", cfg_dict_resume, merge=True) + bypass_distillation.launch_bypass_distillation(hydra_cfg) + dist.barrier() + + if rank == 0: + # After the second run, iter_num must have advanced past 1 — proving + # the run picked up state from the first run rather than starting fresh. + # (The resume code path overwrites iter_num from args.json on line 826.) + assert hydra_cfg.bypass.iter_num > 1, ( + f"Resume failed: iter_num={hydra_cfg.bypass.iter_num} suggests fresh start, " + f"not a resume from the saved checkpoint" + ) + + dist.cleanup() diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py new file mode 100644 index 00000000000..f0b967ea5e7 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -0,0 +1,347 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CPU unit tests for ``bypass_checkpoint_utils``. + +The save/resume contract here is the most important regression surface in the +bypass feature: a wrong checkpoint pick or a missing ``saving_completed`` +marker silently restarts training from the wrong iteration. + +What's covered here (CPU-only, codecov-visible): + * ``find_latest_run_dir`` — every branch of the regex/scan/symlink logic. + * ``_save_local_file`` — overwrite/skip semantics. + * ``_save_local_state`` — same three save-path assertions as the GPU file + (state_dict / optimizer / grad_scaler), but on CPU so codecov picks them + up. The GPU file's ``test_load_local_state_*`` cases stay there because + ``load_local_state`` constructs ``torch.device(f"cuda:{rank}")`` directly. + * ``save_bypass_checkpoint`` — orchestration: ``latest`` symlink update, + ``args.json`` dump, ``saving_completed`` marker, master-only gating. +""" + +from collections import OrderedDict +from pathlib import Path + +import pytest +import torch +import torch.nn as nn +from omegaconf import OmegaConf +from torch.amp.grad_scaler import GradScaler + +from modelopt.torch.puzzletron.bypass_distillation import bypass_checkpoint_utils as bcu +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + StitchedModuleDescriptor, +) + + +# --------------------------------------------------------------------------- +# Shared fixture: silence the dist helpers so these run single-process / CPU. +# Mirrors tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py:56-62. 
+# --------------------------------------------------------------------------- + + +@pytest.fixture +def bcu_no_dist(monkeypatch): + monkeypatch.setattr(bcu.dist, "local_rank", lambda: 0) + monkeypatch.setattr(bcu.dist, "is_master", lambda: True) + monkeypatch.setattr(bcu.dist, "barrier", lambda: None) + return bcu + + +def _make_descriptor(*, with_optimizer: bool = True, with_scaler: bool = True): + """Build a CPU-only StitchedModuleDescriptor — the GPU file's helper minus + the configurable init_scale (we don't round-trip the scaler here).""" + module = nn.Linear(4, 4, bias=False) + owned_parameters = dict(module.named_parameters()) + optimizer = torch.optim.AdamW(list(module.parameters()), lr=1e-3) if with_optimizer else None + scaler = GradScaler(device="cpu", enabled=True, init_scale=2.0**16) if with_scaler else None + return StitchedModuleDescriptor( + stitched_module=module, + owned_parameters=owned_parameters, + owned_buffers={}, + optimizer=optimizer, + grad_scaler=scaler, + ) + + +# --------------------------------------------------------------------------- +# find_latest_run_dir +# --------------------------------------------------------------------------- + + +def test_find_latest_run_dir_returns_none_for_empty_dir(tmp_path: Path): + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_picks_only_iter_with_marker(tmp_path: Path): + iter_dir = tmp_path / "iter-000010-ckpt" + iter_dir.mkdir() + (iter_dir / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(iter_dir) + + +def test_find_latest_run_dir_picks_highest_iter_number(tmp_path: Path): + """When several plain iter checkpoints have completed markers, the highest + integer wins — not lexicographic order, not insertion order.""" + for i in (5, 10, 20): + d = tmp_path / f"iter-{i:06d}-ckpt" + d.mkdir() + (d / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "iter-000020-ckpt") + + +def test_find_latest_run_dir_skips_iter_without_marker(tmp_path: Path): + """A partially-written checkpoint (no ``saving_completed``) must be skipped + even when it has a higher iter number — otherwise resume would crash on a + truncated state dict.""" + high = tmp_path / "iter-000099-ckpt" + high.mkdir() + # No saving_completed → must be ignored. + low = tmp_path / "iter-000050-ckpt" + low.mkdir() + (low / "saving_completed").touch() + assert bcu.find_latest_run_dir(tmp_path) == str(low) + + +def test_find_latest_run_dir_returns_none_when_no_iter_has_marker(tmp_path: Path): + (tmp_path / "iter-000010-ckpt").mkdir() + (tmp_path / "iter-000020-ckpt").mkdir() + # No saving_completed anywhere. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_excludes_non_plain_iter_names(tmp_path: Path): + """``best-iter-*`` / ``start-iter-*`` / ``final-iter-*`` aren't valid resume + targets — pinned by the docstring on lines 39-42.""" + for name in ("best-iter-000099-ckpt", "start-iter-000001-ckpt", "final-iter-000050-ckpt"): + d = tmp_path / name + d.mkdir() + (d / "saving_completed").touch() + # No plain iter-*-ckpt at all. + assert bcu.find_latest_run_dir(tmp_path) is None + + +def test_find_latest_run_dir_uses_latest_symlink_fast_path(tmp_path: Path): + """The ``latest`` symlink, when present and complete, short-circuits the + scan — even when a numerically higher iter dir also has a marker. 
This + matters because the scan branch can be slow on filesystems with many + iter dirs (NFS, lustre).""" + target = tmp_path / "iter-000010-ckpt" + target.mkdir() + (target / "saving_completed").touch() + (tmp_path / "latest").symlink_to(target.name) + + higher = tmp_path / "iter-000020-ckpt" + higher.mkdir() + (higher / "saving_completed").touch() + + # Symlink wins despite higher iter existing. + assert bcu.find_latest_run_dir(tmp_path) == str(tmp_path / "latest") + + +def test_find_latest_run_dir_falls_through_when_latest_lacks_marker(tmp_path: Path): + """A ``latest`` symlink whose target lacks ``saving_completed`` (interrupted + save) must be ignored, falling through to the highest completed iter.""" + incomplete = tmp_path / "iter-000020-ckpt" + incomplete.mkdir() + # No saving_completed. + (tmp_path / "latest").symlink_to(incomplete.name) + + completed = tmp_path / "iter-000010-ckpt" + completed.mkdir() + (completed / "saving_completed").touch() + + assert bcu.find_latest_run_dir(tmp_path) == str(completed) + + +# --------------------------------------------------------------------------- +# _save_local_file +# --------------------------------------------------------------------------- + + +def test_save_local_file_writes_object_to_disk(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"a": torch.tensor([1, 2, 3])}, target) + assert target.exists() + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["a"], torch.tensor([1, 2, 3])) + + +def test_save_local_file_overwrite_true_replaces_contents(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=True) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([99])) + + +def test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): + target = tmp_path / "blob.pth" + bcu._save_local_file({"v": torch.tensor([1])}, target) + # Second save should be a no-op. 
+ bcu._save_local_file({"v": torch.tensor([99])}, target, overwrite=False) + loaded = torch.load(target, weights_only=True) + assert torch.equal(loaded["v"], torch.tensor([1])) + + +# --------------------------------------------------------------------------- +# _save_local_state — CPU-mirror of the three GPU save tests so codecov sees them +# --------------------------------------------------------------------------- + + +def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( + tmp_path: Path, bcu_no_dist +): + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + assert (stitched / "block_0.optimizer_state.pth").exists() + assert (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( + tmp_path: Path, bcu_no_dist +): + descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() + + +def test_save_local_state_skips_optimizer_when_descriptor_has_none( + tmp_path: Path, bcu_no_dist +): + descriptors = OrderedDict( + [("block_0", _make_descriptor(with_optimizer=False, with_scaler=False))] + ) + bcu_no_dist._save_local_state(descriptors, tmp_path) + stitched = tmp_path / "stitched" + assert (stitched / "block_0.state_dict.pth").exists() + assert not (stitched / "block_0.optimizer_state.pth").exists() + + +# --------------------------------------------------------------------------- +# save_bypass_checkpoint — orchestration: symlink, args.json, marker +# --------------------------------------------------------------------------- + + +def _make_save_cfg(experiment_dir: Path, *, delete_old: bool = True): + """Minimal cfg shape used by ``save_bypass_checkpoint``. + + ``cfg.bypass`` is the object that gets dumped to ``args.json``, so it must + be JSON-serialisable (or DictConfig-with-primitives, which json_dump handles). + """ + return OmegaConf.create( + { + "bypass": { + "experiment_dir": str(experiment_dir), + "model": {"model_overrides": {"delete_old_checkpoints": delete_old}}, + "iter_num": 7, + } + } + ) + + +@pytest.fixture +def patched_save(monkeypatch, bcu_no_dist): + """Stub out the heavy callees so the test only exercises the orchestration + logic in ``save_bypass_checkpoint``.""" + monkeypatch.setattr(bcu_no_dist, "_save_local_state", lambda **kwargs: None) + monkeypatch.setattr(bcu_no_dist, "save_checkpoint", lambda **kwargs: None) + return bcu_no_dist + + +def test_save_bypass_checkpoint_creates_latest_symlink_and_marker( + tmp_path: Path, patched_save +): + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "iter-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + latest = experiment_dir / "latest" + assert latest.is_symlink() + # Symlink target is relative — just the dir name, so it resolves under experiment_dir. 
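+    # (os.readlink returns the raw stored target, so an absolute symlink would
+    # fail this check; the resolve() comparison alone could not tell them apart.)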
+ import os + + assert os.readlink(latest) == "iter-000007-ckpt" + assert latest.resolve() == checkpoint_dir.resolve() + assert (checkpoint_dir / "args.json").exists() + assert (checkpoint_dir / "saving_completed").exists() + + +def test_save_bypass_checkpoint_replaces_existing_latest_symlink( + tmp_path: Path, patched_save +): + """A stale ``latest`` from a prior save must be replaced, not appended to. + Without ``unlink(missing_ok=True)`` the symlink_to() call would raise + FileExistsError mid-save and leave the run unable to checkpoint.""" + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + old_target = experiment_dir / "iter-000003-ckpt" + old_target.mkdir() + new_target = experiment_dir / "iter-000007-ckpt" + new_target.mkdir() + (experiment_dir / "latest").symlink_to(old_target.name) + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=new_target, + ) + + import os + + assert os.readlink(experiment_dir / "latest") == "iter-000007-ckpt" + + +def test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master( + tmp_path: Path, monkeypatch, patched_save +): + """Non-master ranks must not write the symlink, args.json, or marker — + only rank 0 owns those files. The other ranks still call _save_local_state + (their owned blocks) but stop short of the per-experiment metadata.""" + monkeypatch.setattr(patched_save.dist, "is_master", lambda: False) + + experiment_dir = tmp_path / "exp" + experiment_dir.mkdir() + checkpoint_dir = experiment_dir / "iter-000007-ckpt" + checkpoint_dir.mkdir() + + cfg = _make_save_cfg(experiment_dir) + patched_save.save_bypass_checkpoint( + cfg=cfg, + descriptor=None, + model=None, + stitched_module_descriptors=OrderedDict(), + checkpoint_dir=checkpoint_dir, + ) + + assert not (experiment_dir / "latest").exists() + assert not (checkpoint_dir / "args.json").exists() + assert not (checkpoint_dir / "saving_completed").exists() diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py new file mode 100644 index 00000000000..00b5487dda1 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py @@ -0,0 +1,280 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for bypass-distillation dataloader utilities. + +Covers the pure-Python branches of ``utils/data/dataloaders.py`` that don't +need a real tokenizer / GPU / distributed init: the validation-split +auto-detect rules, the ``num_workers`` guard rail, the dataset-loader +delegators, the ``Printer`` fake accelerator, and the small numeric helpers +(``create_padded_tensor``, ``realize_dataset_in_memory``, ``collate_none_fn``). 
+""" + +import datasets +import pytest +import torch +from datasets import Dataset, DatasetDict + +import modelopt.torch.puzzletron.utils.data.dataloaders as dl +from modelopt.torch.puzzletron.utils.data.dataloaders import ( + Printer, + collate_fn_with_none_support, + collate_none_fn, + create_padded_tensor, + create_train_dataloader, + create_validation_dataloader, + load_from_disk_fn, + load_streaming_fn, + realize_dataset_in_memory, +) + + +# --------------------------------------------------------------------------- +# realize_dataset_in_memory: pure list materialisation with optional cap +# --------------------------------------------------------------------------- + + +def test_realize_dataset_in_memory_full(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=None) + assert out == items + + +def test_realize_dataset_in_memory_capped(): + items = [{"a": 1}, {"a": 2}, {"a": 3}] + out = realize_dataset_in_memory(iter(items), eval_samples=2) + assert out == [{"a": 1}, {"a": 2}] + + +# --------------------------------------------------------------------------- +# create_padded_tensor: identity, 1D pad, 2D pad with non-zero pad value +# --------------------------------------------------------------------------- + + +def test_create_padded_tensor_identity(): + t = torch.arange(6, dtype=torch.float32).reshape(2, 3) + out = create_padded_tensor(t, desired_shape=(2, 3)) + assert out is t # short-circuit, no copy + + +def test_create_padded_tensor_pads_1d_with_default_zero(): + t = torch.tensor([1, 2, 3], dtype=torch.int32) + out = create_padded_tensor(t, desired_shape=(5,)) + assert out.tolist() == [1, 2, 3, 0, 0] + assert out.dtype == torch.int32 + + +def test_create_padded_tensor_pads_2d_with_custom_value(): + t = torch.tensor([[1.0, 2.0]]) + out = create_padded_tensor(t, desired_shape=(2, 3), padding_value=-100.0) + assert out.tolist() == [[1.0, 2.0, -100.0], [-100.0, -100.0, -100.0]] + + +# --------------------------------------------------------------------------- +# Collate helpers: None-aware default collator +# --------------------------------------------------------------------------- + + +def test_collate_none_fn_returns_none(): + assert collate_none_fn([None, None]) is None + assert collate_none_fn([1, 2, 3]) is None # unconditional + + +def test_collate_fn_with_none_support_passes_none_through(): + """A label tensor of None should not be coerced to ``[None, None]`` — the + bypass val loop expects a single ``None`` so it can short-circuit loss + computation. 
This pins the ``type(None) -> collate_none_fn`` registration.""" + batch = [{"x": torch.tensor([1.0]), "y": None}, {"x": torch.tensor([2.0]), "y": None}] + out = collate_fn_with_none_support(batch) + assert out["y"] is None + assert torch.equal(out["x"], torch.tensor([[1.0], [2.0]])) + + +# --------------------------------------------------------------------------- +# Printer: degenerate "main process" stand-in for Accelerator +# --------------------------------------------------------------------------- + + +def test_printer_attributes_match_main_process_contract(): + assert Printer.is_main_process is True + assert Printer.process_index is None + Printer.print("hello world") # must not raise + + +# --------------------------------------------------------------------------- +# load_from_disk_fn / load_streaming_fn: thin wrappers around datasets.* +# --------------------------------------------------------------------------- + + +def test_load_from_disk_fn_delegates_to_datasets(monkeypatch): + captured = {} + + def fake_load_from_disk(path, keep_in_memory=False): + captured["path"] = path + captured["keep_in_memory"] = keep_in_memory + return "sentinel" + + monkeypatch.setattr(datasets, "load_from_disk", fake_load_from_disk) + out = load_from_disk_fn("/some/path", content_field="conversation", keep_in_memory=True) + assert out == "sentinel" + assert captured == {"path": "/some/path", "keep_in_memory": True} + + +def test_load_streaming_fn_uses_streaming_with_features(monkeypatch): + """``load_streaming_fn`` must request streaming and pin the content field's + feature schema — without ``features=`` HuggingFace would auto-infer types + per-shard, which has caused bypass jobs to crash on schema drift in the past. + """ + captured = {} + + def fake_load_dataset(path, streaming, features, keep_in_memory): + captured["path"] = path + captured["streaming"] = streaming + captured["features"] = features + captured["keep_in_memory"] = keep_in_memory + return "stream-sentinel" + + monkeypatch.setattr(datasets, "load_dataset", fake_load_dataset) + out = load_streaming_fn("hf-org/dataset", content_field="text", keep_in_memory=False) + assert out == "stream-sentinel" + assert captured["path"] == "hf-org/dataset" + assert captured["streaming"] is True + assert captured["keep_in_memory"] is False + # features must be a Features object keyed by the requested content_field + # with a string Value — schema-drift protection is the whole point of this fn. + assert isinstance(captured["features"], datasets.Features) + assert "text" in captured["features"] + assert captured["features"]["text"].dtype == "string" + + +# --------------------------------------------------------------------------- +# create_train_dataloader: ``num_workers > 0`` is a configuration error +# --------------------------------------------------------------------------- + + +def test_create_train_dataloader_rejects_num_workers_gt_zero(): + """ConstantLengthDataset doesn't shard work via ``get_worker_info`` — every + worker would emit the same samples. 
The guard fires before tokenizer or + dataset are touched, so bare-bones args are enough.""" + with pytest.raises(ValueError, match="num_workers"): + create_train_dataloader( + seed=0, + tokenizer=None, + block_size=8, + dataset_path={"train": []}, + content_field="text", + fim_rate=0.0, + fim_spm_rate=0.0, + micro_batch_size=1, + num_workers=2, + ) + + +# --------------------------------------------------------------------------- +# create_validation_dataloader: split auto-detect + explicit override +# --------------------------------------------------------------------------- + + +class _FakeConstantLengthDataset: + """Stub for ``ConstantLengthDataset`` that records its ``dataset`` arg. + + Yields one trivial item so ``realize_dataset_in_memory`` can iterate over + it without touching a tokenizer. + """ + + last_dataset = None # class-level capture so tests can read after construction + + def __init__(self, tokenizer, dataset, **kwargs): + type(self).last_dataset = dataset + self._dataset = dataset + + def __iter__(self): + yield {"input_ids": torch.tensor([0])} + + +@pytest.fixture +def patched_dataloader(monkeypatch): + """Replace the heavy bits inside ``create_validation_dataloader`` so the + function exercises only its pure split-selection logic + DataLoader build.""" + monkeypatch.setattr(dl, "ConstantLengthDataset", _FakeConstantLengthDataset) + # Force a tiny in-memory list so we don't drain a real iterable. + monkeypatch.setattr( + dl, + "realize_dataset_in_memory", + lambda dataset, eval_samples: [{"input_ids": torch.tensor([0])}], + ) + _FakeConstantLengthDataset.last_dataset = None + return _FakeConstantLengthDataset + + +def _make_dict_dataset(splits: dict[str, list]) -> DatasetDict: + return DatasetDict({k: Dataset.from_list(v) for k, v in splits.items()}) + + +def _kwargs(): + return { + "accelerator": None, # → Printer (single-process path) + "seed": 0, + "tokenizer": None, + "block_size": 4, + "content_field": "text", + "fim_rate": 0.0, + "fim_spm_rate": 0.0, + "micro_batch_size": 1, + } + + +def test_validation_split_auto_picks_validation_when_present(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "validation": [{"text": "v"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + # The "validation" split must have been the one passed to ConstantLengthDataset. 
+ assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_falls_back_to_test_when_no_val(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "test": [{"text": "te"}]}) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["test"] + + +def test_validation_split_auto_prefers_val_over_test(patched_dataloader): + """If both ``validation`` and ``test`` exist, the val* prefix must win — + bypass relies on this to score against held-out data, not test data.""" + dd = _make_dict_dataset( + {"train": [{"text": "t"}], "validation": [{"text": "v"}], "test": [{"text": "te"}]} + ) + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + assert patched_dataloader.last_dataset is dd["validation"] + + +def test_validation_split_auto_assertion_on_multiple_val_options(patched_dataloader): + """Ambiguity must fail loudly — silently picking one would be a footgun.""" + dd = _make_dict_dataset({"validation": [{"text": "a"}], "valtest": [{"text": "b"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_auto_assertion_on_no_val_or_test(patched_dataloader): + dd = _make_dict_dataset({"train": [{"text": "t"}], "extra": [{"text": "e"}]}) + with pytest.raises(AssertionError, match="exactly one validation split"): + create_validation_dataloader(dataset=dd, dataset_name="__auto__", **_kwargs()) + + +def test_validation_split_explicit_override_bypasses_auto(patched_dataloader): + """Explicit ``dataset_name`` must skip the auto-detect, even when the + chosen name doesn't match val* / test* prefixes.""" + dd = _make_dict_dataset({"my_eval": [{"text": "x"}]}) + create_validation_dataloader(dataset=dd, dataset_name="my_eval", **_kwargs()) + assert patched_dataloader.last_dataset is dd["my_eval"] diff --git a/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py new file mode 100644 index 00000000000..38701ba8be3 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_bypass_lr_scheduler.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the cosine-with-warmup LR scheduler used by bypass distillation. + +``_get_lr`` is the scheduler invoked every step inside ``train``. An off-by-one +in the cosine ramp would silently degrade convergence — bypass jobs run for +hours and produce subtly worse student weights. The degenerate-budget guard +matters for tests and short sweeps where ``training_tokens`` is small. 
+ +Schedule shape (warmup_steps=W, lr_decay_steps=D): + + step ∈ [0, W]: linear ramp 0 → base_lr (warmup branch) + step ∈ (W, D]: cosine decay base_lr → min_lr (cosine branch) + step > D: clamped to min_lr (post-decay branch) + +The cosine uses ``decay_ratio = (step - W) / (D - W)`` so the boundary cases +align: at step=W+1 the cosine has just started (decay_ratio = 1/(D-W)) and at +step=D it reaches min_lr exactly (decay_ratio=1, coeff=0). +""" + +import math + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.bypass_distillation.training_loop import _get_lr + + +def _make_cfg( + *, + warmup_steps: int, + lr_decay_steps: int, + learning_rate: float = 1.0, + min_lr: float = 0.1, +): + return OmegaConf.create( + { + "bypass": { + "training": { + "warmup_steps": warmup_steps, + "lr_decay_steps": lr_decay_steps, + "learning_rate": learning_rate, + "min_lr": min_lr, + } + } + } + ) + + +def test_degenerate_budget_returns_base_lr(): + """When ``lr_decay_steps <= warmup_steps`` (tiny test budgets), the scheduler + must short-circuit to ``learning_rate`` rather than divide by zero.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=10, learning_rate=0.5) + assert _get_lr(cfg, step=0) == 0.5 + assert _get_lr(cfg, step=1) == 0.5 + assert _get_lr(cfg, step=99) == 0.5 + + +def test_degenerate_budget_warmup_greater_than_decay(): + """``lr_decay_steps < warmup_steps`` is also caught by the same guard.""" + cfg = _make_cfg(warmup_steps=20, lr_decay_steps=10, learning_rate=0.7) + assert _get_lr(cfg, step=5) == 0.7 + + +def test_warmup_linear_ramp(): + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=100, learning_rate=1.0) + assert _get_lr(cfg, step=0) == pytest.approx(0.0) + assert _get_lr(cfg, step=5) == pytest.approx(0.5) + assert _get_lr(cfg, step=10) == pytest.approx(1.0) + + +def test_cosine_starts_decaying_immediately_after_warmup(): + """At ``step == warmup_steps + 1`` the cosine branch is entered with + ``decay_ratio = 1/(D-W)`` — already a small step below base LR, not a + duplicate plateau at base LR. This is the boundary the previous formula + got wrong (it used ``step - W - 1`` and gave ``decay_ratio == 0`` here).""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # decay_ratio = (11 - 10) / 10 = 0.1 + expected = 0.5 * (1.0 + math.cos(math.pi * 0.1)) + assert _get_lr(cfg, step=11) == pytest.approx(expected) + # Strictly below base LR — the cosine has begun. + assert _get_lr(cfg, step=11) < 1.0 + + +def test_cosine_endpoint_returns_min_lr(): + """At ``step == lr_decay_steps`` the cosine branch reaches its endpoint: + ``decay_ratio == 1`` → ``coeff == 0`` → returns ``min_lr`` exactly. The + post-decay clamp at ``step == lr_decay_steps + 1`` is then a no-op + continuation, not a correction for an off-by-one.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=20) == pytest.approx(0.1) + + +def test_cosine_midpoint_is_halfway(): + """At the cosine midpoint, ``coeff == 0.5`` → returns ``(lr + min_lr) / 2``.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.0) + # Midpoint of the post-warmup window: step such that decay_ratio == 0.5. + # decay_ratio = (step - 10) / (20 - 10) → step = 15 gives ratio 0.5. 
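+    # With min_lr=0 and learning_rate=1, the returned LR equals the coeff itself.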
+ expected_coeff = 0.5 * (1.0 + math.cos(math.pi * 0.5)) + assert _get_lr(cfg, step=15) == pytest.approx(expected_coeff) + + +def test_post_decay_clamps_to_min_lr(): + """``step > lr_decay_steps`` always returns ``min_lr`` exactly.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=20, learning_rate=1.0, min_lr=0.1) + assert _get_lr(cfg, step=21) == 0.1 + assert _get_lr(cfg, step=1000) == 0.1 + + +def test_min_lr_zero_decays_to_zero(): + """Common config: ``min_lr=0`` → cosine endpoint is exactly 0.""" + cfg = _make_cfg(warmup_steps=10, lr_decay_steps=30, learning_rate=2.0, min_lr=0.0) + assert _get_lr(cfg, step=30) == pytest.approx(0.0) + assert _get_lr(cfg, step=31) == 0.0 diff --git a/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py new file mode 100644 index 00000000000..3e08900fc5b --- /dev/null +++ b/tests/unit/torch/puzzletron/test_launch_bypass_distillation.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``launch_bypass_distillation`` (sweep dispatcher). + +The dispatcher's job is to iterate over ``bypass.configs``, apply each override +to the live ``hydra_cfg``, reset the per-run state machine, and invoke +``run_bypassed_training``. Reordering or dropping a reset would silently make +the second sweep entry resume from the first entry's iter counter — a bug +that would only surface as wasted compute and confused checkpoint dirs. + +We patch ``run_bypassed_training`` to a recorder so this stays a pure-Python +test (no GPU, no real training). +""" + +from omegaconf import OmegaConf + +import modelopt.torch.puzzletron.bypass_distillation.training_loop as tl + + +def _base_cfg(configs=None): + """Build a minimal cfg shape that ``launch_bypass_distillation`` reads. + + Includes only the keys touched by the dispatcher itself; ``run_bypassed_training`` + is mocked so its richer requirements are irrelevant here. + """ + cfg = { + "bypass": { + "model": {"model_config_overrides": {"intermediate_size": 1024}}, + "model_factory": {"keys_to_learn": "subblock_ffn"}, + "experiment_id": "stale-id", + "iter_num": 999, + "step_num": 999, + "token_count": 999_999, + "best_val_loss": 0.0, + "training": {"clipping_count": 42}, + } + } + if configs is not None: + cfg["bypass"]["configs"] = configs + return OmegaConf.create(cfg) + + +def _record_calls(monkeypatch): + """Patch ``run_bypassed_training`` to capture deep-copied cfg snapshots.""" + snapshots = [] + + def _recorder(cfg): + # Deep-copy via container conversion; the live cfg is mutated between calls. 
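+        # ``resolve=True`` additionally bakes any interpolations into plain
+        # values, so each snapshot is an independent dict that later in-place
+        # edits to ``cfg`` cannot retroactively change.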
+ snapshots.append(OmegaConf.to_container(cfg, resolve=True)) + + monkeypatch.setattr(tl, "run_bypassed_training", _recorder) + return snapshots + + +def test_no_configs_key_runs_once(monkeypatch): + """Absent ``bypass.configs`` is the single-config path — one call, no resets.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(configs=None) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + # Single-config path doesn't touch the state machine — values remain as supplied. + assert snapshots[0]["bypass"]["iter_num"] == 999 + assert snapshots[0]["bypass"]["training"]["clipping_count"] == 42 + + +def test_empty_configs_list_runs_once(monkeypatch): + """``configs: []`` must hit the same branch as missing — the truthiness + check on line 85 of training_loop.py treats both as 'no sweep'.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(configs=[]) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 1 + + +def test_two_configs_run_twice_with_distinct_overrides(monkeypatch): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ] + ) + tl.launch_bypass_distillation(cfg) + assert len(snapshots) == 2 + assert snapshots[0]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 256} + assert snapshots[1]["bypass"]["model"]["model_config_overrides"] == {"intermediate_size": 128} + + +def test_keys_to_learn_override_applied(monkeypatch): + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(configs=[{"keys_to_learn": "subblock_attention"}]) + tl.launch_bypass_distillation(cfg) + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_attention" + + +def test_per_run_state_reset_before_each_call(monkeypatch): + """Every sweep entry must see iter_num=1, step_num=1, token_count=0, + best_val_loss=1e9, clipping_count=0, experiment_id=None — even when the + previous entry left the cfg in some other state.""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg( + configs=[ + {"model_config_overrides": {"intermediate_size": 256}}, + {"model_config_overrides": {"intermediate_size": 128}}, + ] + ) + tl.launch_bypass_distillation(cfg) + for snap in snapshots: + assert snap["bypass"]["experiment_id"] is None + assert snap["bypass"]["iter_num"] == 1 + assert snap["bypass"]["step_num"] == 1 + assert snap["bypass"]["token_count"] == 0 + assert snap["bypass"]["best_val_loss"] == 1e9 + assert snap["bypass"]["training"]["clipping_count"] == 0 + + +def test_override_without_keys_to_learn_leaves_cfg_value_untouched(monkeypatch): + """A sweep entry that only sets ``model_config_overrides`` must not clobber + the inherited ``keys_to_learn`` (the dispatcher's `if "keys_to_learn" in override` + guard, line 99).""" + snapshots = _record_calls(monkeypatch) + cfg = _base_cfg(configs=[{"model_config_overrides": {"intermediate_size": 256}}]) + tl.launch_bypass_distillation(cfg) + # keys_to_learn was set to "subblock_ffn" in _base_cfg — must survive. + assert snapshots[0]["bypass"]["model_factory"]["keys_to_learn"] == "subblock_ffn" diff --git a/tests/unit/torch/puzzletron/test_puzzletron_progress.py b/tests/unit/torch/puzzletron/test_puzzletron_progress.py new file mode 100644 index 00000000000..f83b07cdbe2 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_puzzletron_progress.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_total_steps`` / ``_progress_step`` in ``puzzletron_nas_plugin``. + +These two helpers are the single source of truth for the user-facing +``Puzzletron Progress N/T`` log lines emitted by ``convert_puzzletron_model`` +and ``PuzzletronSearcher.run_search``. A regression that drops or reorders a +stage silently misnumbers every progress message; worse, an off-by-one would +hide which stage the pipeline crashed in. +""" + +import pytest +from omegaconf import OmegaConf + +from modelopt.torch.puzzletron.puzzletron_nas_plugin import ( + _STAGE_ORDER, + _progress_step, + _total_steps, +) + + +def _cfg_with_bypass(): + return OmegaConf.create({"bypass": {"experiment_dir": "/tmp/x"}}) + + +def _cfg_without_bypass(): + return OmegaConf.create({"some_other_key": True}) + + +def _cfg_with_null_bypass(): + return OmegaConf.create({"bypass": None}) + + +def test_total_steps_with_bypass_is_nine(): + assert _total_steps(_cfg_with_bypass()) == 9 + + +def test_total_steps_without_bypass_key_is_eight(): + assert _total_steps(_cfg_without_bypass()) == 8 + + +def test_total_steps_with_null_bypass_is_eight(): + """``bypass: null`` (typical override-to-disable) must read as 'no bypass'.""" + assert _total_steps(_cfg_with_null_bypass()) == 8 + + +def test_progress_step_walks_eight_stages_without_bypass(): + cfg = _cfg_without_bypass() + expected_no_bypass = [s for s in _STAGE_ORDER if s != "bypass"] + seen = [] + for stage in expected_no_bypass: + step, total = _progress_step(cfg, stage) + seen.append((stage, step, total)) + assert seen == [ + ("start", 1, 8), + ("convert", 2, 8), + ("score_activations", 3, 8), + ("prune", 4, 8), + ("build_library", 5, 8), + ("score_blocks", 6, 8), + ("mip", 7, 8), + ("complete", 8, 8), + ] + + +def test_progress_step_walks_nine_stages_with_bypass(): + cfg = _cfg_with_bypass() + seen = [(stage, *_progress_step(cfg, stage)) for stage in _STAGE_ORDER] + assert seen == [ + ("start", 1, 9), + ("convert", 2, 9), + ("score_activations", 3, 9), + ("prune", 4, 9), + ("bypass", 5, 9), + ("build_library", 6, 9), + ("score_blocks", 7, 9), + ("mip", 8, 9), + ("complete", 9, 9), + ] + + +def test_progress_step_bypass_stage_unknown_when_absent(): + """Asking for the bypass stage when bypass isn't configured is a programming + error — must raise, not silently return 0/8.""" + cfg = _cfg_without_bypass() + with pytest.raises(ValueError, match="Unknown pipeline stage"): + _progress_step(cfg, "bypass") + + +def test_progress_step_unknown_stage_raises(): + cfg = _cfg_with_bypass() + with pytest.raises(ValueError, match="Unknown pipeline stage"): + _progress_step(cfg, "definitely_not_a_real_stage") + + +def test_mip_step_shifts_when_bypass_added_or_removed(): + """Removing bypass must shift MIP from 8/9 to 7/8 — pinned by the docstring + on _progress_step which calls this out explicitly.""" + assert _progress_step(_cfg_with_bypass(), "mip") == (8, 9) + assert 
_progress_step(_cfg_without_bypass(), "mip") == (7, 8) diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py new file mode 100644 index 00000000000..e84db79f4ed --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.utils.ActivityContext``. + +``ActivityContext`` is the stack the ``Passage`` machinery uses to track which +passages are currently active inside a ``StitchedModule.forward`` call. A bug +in push/pop ordering or in the exception-safe cleanup would leak state across +forward passes — every subsequent block would see a stale "active passage" +and route inputs/outputs to the wrong module. +""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.utils import ( + ActivityContext, + ActivityContextDuplicateException, + ActivityContextMaxDepthException, +) + + +# --------------------------------------------------------------------------- +# Basic push/pop semantics via the ``with ctx(value):`` form +# --------------------------------------------------------------------------- + + +def test_starts_empty_and_inactive(): + ctx: ActivityContext[str] = ActivityContext() + assert len(ctx) == 0 + assert not ctx.is_active() + assert ctx.get_active() is None + + +def test_with_block_pushes_and_pops_value(): + ctx: ActivityContext[str] = ActivityContext() + with ctx("a"): + assert ctx.is_active() + assert ctx.get_active() == "a" + assert "a" in ctx + assert len(ctx) == 1 + # After the block: stack must be back to empty. + assert len(ctx) == 0 + assert ctx.get_active() is None + + +def test_nested_pushes_track_lifo_order(): + """``get_active`` returns the *most recent* push (LIFO) — Passage relies on + this to find the innermost active passage during forward.""" + ctx: ActivityContext[str] = ActivityContext() + with ctx("outer"): + assert ctx.get_active() == "outer" + with ctx("inner"): + assert ctx.get_active() == "inner" + assert ctx[0] == "outer" + assert ctx[1] == "inner" + # Inner pop returns to outer. + assert ctx.get_active() == "outer" + + +# --------------------------------------------------------------------------- +# max_depth: limits stack height +# --------------------------------------------------------------------------- + + +def test_max_depth_one_allows_single_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"): + assert ctx.get_active() == "a" + + +def test_max_depth_one_rejects_second_push(): + ctx: ActivityContext[str] = ActivityContext(max_depth=1) + with ctx("a"): + with pytest.raises(ActivityContextMaxDepthException): + with ctx("b"): + pass + # Stack must have unwound to empty even after the exception. 
+ assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# no_duplicates: same value can't appear twice +# --------------------------------------------------------------------------- + + +def test_no_duplicates_rejects_repeat_value(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"): + with pytest.raises(ActivityContextDuplicateException): + with ctx("x"): + pass + # Stack unwound; the still-active "x" was preserved through the failed push. + assert len(ctx) == 0 + + +def test_no_duplicates_allows_distinct_values(): + ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) + with ctx("x"): + with ctx("y"): + assert "x" in ctx and "y" in ctx + + +# --------------------------------------------------------------------------- +# reversed=True: insert at front, pop from front +# --------------------------------------------------------------------------- + + +def test_reversed_pushes_to_front_and_pops_from_front(): + """``Passage.active_passages_context`` uses ``reversed=True`` so the + *first* active passage in iteration order is the innermost. Pin both + insert position and pop position.""" + ctx: ActivityContext[str] = ActivityContext(reversed=True) + with ctx("a"): + with ctx("b"): + # b inserted at front of stack. + assert ctx[0] == "b" + assert ctx[1] == "a" + # Pop from front: only "a" left. + assert list(ctx[:]) == ["a"] + + +# --------------------------------------------------------------------------- +# Exception safety: stack unwinds even if the caller's body raises +# --------------------------------------------------------------------------- + + +def test_stack_unwinds_when_body_raises(): + """A bug here would leak stack frames — the next forward pass would see + a stale active passage. This is the silent-failure scenario.""" + ctx: ActivityContext[str] = ActivityContext() + with pytest.raises(ValueError, match="boom"): + with ctx("a"): + assert ctx.get_active() == "a" + raise ValueError("boom") + assert len(ctx) == 0 + + +# --------------------------------------------------------------------------- +# is_submodule_of / is_submodule_or_same — string predicates used by passage.py +# --------------------------------------------------------------------------- + + +from modelopt.torch.puzzletron.sewing_kit.utils import ( # noqa: E402 + is_submodule_of, + is_submodule_or_same, +) + + +def test_is_submodule_of_proper_descendant(): + assert is_submodule_of("model.layers.0.self_attn", "model.layers.0") + assert is_submodule_of("model.layers.0", "model") + # Empty string parent matches any non-empty name (root-of-everything case). + assert is_submodule_of("model", "") + + +def test_is_submodule_of_rejects_self_and_unrelated(): + assert not is_submodule_of("model.layers.0", "model.layers.0") + assert not is_submodule_of("model.layers.0", "model.layers.1") + # Empty == empty is not a submodule relationship. + assert not is_submodule_of("", "") + # Prefix collision: "model.layers" is NOT a submodule of "model.lay" — the + # predicate requires a literal "." separator after the parent. 
+ assert not is_submodule_of("model.layers", "model.lay") + + +def test_is_submodule_or_same_includes_equality(): + assert is_submodule_or_same("model.layers.0", "model.layers.0") + assert is_submodule_or_same("model.layers.0.attn", "model.layers.0") + assert not is_submodule_or_same("model.layers.0", "model.layers.1") diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py new file mode 100644 index 00000000000..15aa945f3bb --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py @@ -0,0 +1,162 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.passage.InputArgs``. + +``InputArgs`` is the workhorse args/kwargs container the bypass distillation +factory uses inside its stitching reducers — see ``bypass_factory_fn`` calls +like ``lambda acc, override, orig, *args: override + orig.drop_args(0)``. +A regression in ``__add__`` or ``drop_args`` would silently corrupt the +inputs passed into per-block forward passes, producing wrong loss values +without any loud failure. +""" + +import pytest + +from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + + +def test_init_accepts_positional_and_keyword_args(): + ia = InputArgs(1, 2, foo="bar") + assert ia.args == [1, 2] + assert ia.kwargs == {"foo": "bar"} + + +def test_init_with_no_args_is_empty(): + ia = InputArgs() + assert ia.args == [] + assert ia.kwargs == {} + + +# --------------------------------------------------------------------------- +# __add__: concatenates args, merges kwargs (right wins on collision) +# --------------------------------------------------------------------------- + + +def test_add_concatenates_positional_args_in_order(): + a = InputArgs(1, 2) + b = InputArgs(3, 4) + result = a + b + assert result.args == [1, 2, 3, 4] + assert result.kwargs == {} + + +def test_add_merges_kwargs_with_right_winning(): + """Bypass reducers chain ``override + orig.drop_args(0)`` — when both sides + happen to set the same kwarg, the right-side value (the original input) + must win, otherwise the override silently displaces the original kwarg.""" + a = InputArgs(foo="from_a", bar="only_a") + b = InputArgs(foo="from_b", baz="only_b") + result = a + b + assert result.kwargs == {"foo": "from_b", "bar": "only_a", "baz": "only_b"} + + +def test_add_does_not_mutate_operands(): + a = InputArgs(1, 2, x="a") + b = InputArgs(3, y="b") + _ = a + b + assert a.args == [1, 2] and a.kwargs == {"x": "a"} + assert b.args == [3] and b.kwargs == {"y": "b"} + + +def test_add_rejects_non_input_args(): + with pytest.raises(AssertionError): + InputArgs(1) + [2] # type: ignore[operator] + + +# 
--------------------------------------------------------------------------- +# drop_args: clears all positional args (default) or one by index/slice +# --------------------------------------------------------------------------- + + +def test_drop_args_default_clears_all_positional(): + """The ``drop_args(0)`` and ``drop_args()`` forms are both used by bypass + stitches — the default-no-arg form must wipe the entire positional tuple + (kwargs untouched).""" + ia = InputArgs(1, 2, 3, foo="bar") + out = ia.drop_args() + assert out.args == [] + assert out.kwargs == {"foo": "bar"} + # And the original is unmodified. + assert ia.args == [1, 2, 3] + + +def test_drop_args_with_index_drops_one(): + ia = InputArgs(10, 20, 30) + out = ia.drop_args(0) + assert out.args == [20, 30] + # Source preserved. + assert ia.args == [10, 20, 30] + + +def test_drop_args_with_slice_drops_range(): + ia = InputArgs(10, 20, 30, 40) + out = ia.drop_args(slice(1, 3)) + assert out.args == [10, 40] + + +# --------------------------------------------------------------------------- +# drop_kwargs: clears all kwargs (default) or specific keys +# --------------------------------------------------------------------------- + + +def test_drop_kwargs_default_clears_all(): + ia = InputArgs(1, foo="bar", baz="qux") + out = ia.drop_kwargs() + assert out.args == [1] + assert out.kwargs == {} + + +def test_drop_kwargs_with_keys_drops_only_those(): + ia = InputArgs(1, foo="bar", baz="qux", keep="this") + out = ia.drop_kwargs(["foo", "baz"]) + assert out.kwargs == {"keep": "this"} + + +def test_drop_kwargs_silently_ignores_missing_keys(): + """A key listed in ``drop_kwargs`` that isn't present must not raise — + bypass calls this against args from arbitrary upstream stitches and may + pass keys that only some sources produce.""" + ia = InputArgs(foo="bar") + out = ia.drop_kwargs(["nonexistent"]) # must not KeyError + assert out.kwargs == {"foo": "bar"} + + +# --------------------------------------------------------------------------- +# from_value: lifts assorted values into InputArgs +# --------------------------------------------------------------------------- + + +def test_from_value_passes_through_existing_input_args(): + ia = InputArgs(1, foo="bar") + out = InputArgs.from_value(ia) + assert out is ia + + +def test_from_value_lifts_sequence_to_positional_args(): + out = InputArgs.from_value([1, 2, 3]) + assert out.args == [1, 2, 3] + assert out.kwargs == {} + + +def test_from_value_lifts_scalar_to_single_positional(): + out = InputArgs.from_value(42) + assert out.args == [42] + assert out.kwargs == {} diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_needle.py b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py new file mode 100644 index 00000000000..d44c3e67cd8 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``sewing_kit.core.Needle`` graph construction and validation. + +The bypass factory builds three ``Needle``\\s per rank (teacher train, teacher +val, student val) and calls ``Needle.knot()`` on each. ``knot()`` runs +``_validate_nodes`` first; a regression in that validation would either crash +with an opaque NoneType error during forward, or — worse — silently allow a +malformed graph that produces incorrect activations. + +We test the validation contract on CPU without instantiating ``StitchedModule`` +itself (which requires Module patching). ``_validate_nodes`` is a private +method but it's the unit of behavior worth pinning; ``knot()`` is essentially +``_validate_nodes() + StitchedModule(...)``. +""" + +import pytest +import torch.nn as nn + +from modelopt.torch.puzzletron.sewing_kit.core import ( + ExternalTarget, + ModuleTarget, + Needle, + Node, + OnlyInternalNodesException, + StitchDescriptor, +) +from modelopt.torch.puzzletron.sewing_kit.core import ( + InputsLoopFoundException, +) + + +# --------------------------------------------------------------------------- +# get_node_for_target: lazy creation, cached lookup +# --------------------------------------------------------------------------- + + +def test_get_node_for_target_creates_node_on_first_call(): + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node = needle.get_node_for_target(target) + assert isinstance(node, Node) + assert node.target is target + assert needle.nodes[target] is node + + +def test_get_node_for_target_returns_same_node_on_repeat_call(): + """Re-getting the same target must NOT create a duplicate node — every + stitch involving that target must funnel into a single Node, otherwise + the validation/forward graph fragments.""" + needle = Needle() + target = ModuleTarget("a", nn.Linear(2, 2)) + node1 = needle.get_node_for_target(target) + node2 = needle.get_node_for_target(target) + assert node1 is node2 + assert len(needle.nodes) == 1 + + +# --------------------------------------------------------------------------- +# stitch: adds StitchDescriptor to source.stitches_from and dest.stitches_to +# --------------------------------------------------------------------------- + + +def test_stitch_records_descriptor_on_both_endpoints(): + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + needle.stitch(target_a.output("x"), target_b.input("y")) + + node_a = needle.nodes[target_a] + node_b = needle.nodes[target_b] + # Source endpoint: A has one outgoing stitch; B has one incoming stitch. + assert len(node_a.stitches_from) == 1 + assert len(node_a.stitches_to) == 0 + assert len(node_b.stitches_from) == 0 + assert len(node_b.stitches_to) == 1 + # Same StitchDescriptor object on both lists. 
+ assert node_a.stitches_from[0] is node_b.stitches_to[0] + assert isinstance(node_a.stitches_from[0], StitchDescriptor) + + +def test_stitch_returns_self_for_chaining(): + """Bypass factory chains ``.stitch(...).stitch(...)`` — the return type + must be the Needle itself so the second call sees the same graph.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + out = needle.stitch(target_a.output("x"), target_b.input("y")) + assert out is needle + + +# --------------------------------------------------------------------------- +# _validate_nodes: contract checks before knot() builds the StitchedModule +# --------------------------------------------------------------------------- + + +def test_validate_raises_when_only_internal_nodes_present(): + """A graph with no External and no Remote target has nothing for the + runtime to feed inputs through — must raise loudly rather than build a + dead StitchedModule.""" + needle = Needle() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + needle.stitch(target_a.output("x"), target_b.input("y")) + + with pytest.raises(OnlyInternalNodesException): + needle._validate_nodes() + + +def test_validate_passes_with_external_plus_dag(): + """Happy path: ExternalTarget + a small linear DAG. Must not raise.""" + needle = Needle() + ext = ExternalTarget() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + needle.stitch(ext.output("init"), target_a.input("entry")) + needle.stitch(target_a.output("x"), target_b.input("y")) + needle.stitch(target_b.output("z"), ext.input("final")) + + # No raise. + needle._validate_nodes() + + +def test_validate_raises_on_input_cycle_among_internal_nodes(): + """Detect a 2-node cycle A→B→A among internal nodes. + + The validation uses ``_search_loops`` walking ``stitches_to`` (incoming + edges); ExternalTarget short-circuits the recursion, so we add an + external feed to A so ``_validate_nodes`` doesn't bail out early on the + 'no external' check. + """ + needle = Needle() + ext = ExternalTarget() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + target_b = ModuleTarget("b", nn.Linear(2, 2)) + + # Anchor an external feed so we get past the OnlyInternalNodes check. + needle.stitch(ext.output("init"), target_a.input("entry")) + # Cycle: A -> B -> A. + needle.stitch(target_a.output("x"), target_b.input("y")) + needle.stitch(target_b.output("p"), target_a.input("q")) + + with pytest.raises(InputsLoopFoundException): + needle._validate_nodes() + + +def test_validate_passes_when_external_node_has_self_referential_loop_via_external(): + """``_search_loops`` short-circuits at ExternalTarget. So a 'loop' that + only goes through external (e.g. external→A and A→external) is fine — + and indeed required for normal stitching, where external is both the + input and output endpoint. + """ + needle = Needle() + ext = ExternalTarget() + target_a = ModuleTarget("a", nn.Linear(2, 2)) + + needle.stitch(ext.output("in"), target_a.input("entry")) + needle.stitch(target_a.output("x"), ext.input("out")) + + # Despite the external→A→external pattern, this is the canonical bypass + # shape and must validate clean. 
+ needle._validate_nodes() + + +# --------------------------------------------------------------------------- +# Sanity: ExternalTarget.input()/output() builds correctly typed descriptors +# --------------------------------------------------------------------------- + + +def test_module_target_descriptors_carry_target_and_name(): + """The ``.input("foo")`` and ``.output("bar")`` builders are what the + bypass factory uses to construct stitches. They must propagate the + target reference and the name into the resulting descriptor so the + runtime can route values correctly.""" + target = ModuleTarget("a", nn.Linear(2, 2)) + in_desc = target.input("foo") + out_desc = target.output("bar") + assert in_desc.target is target + assert in_desc.input_name == "foo" + assert out_desc.target is target + assert out_desc.output_name == "bar" diff --git a/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py new file mode 100644 index 00000000000..5fab764b565 --- /dev/null +++ b/tests/unit/torch/puzzletron/test_stitched_model_factory_buffers.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for ``_get_all_non_persistent_buffers_set``. + +This helper is what ``bypass_factory_fn`` uses to decide which buffers belong +to ``owned_buffers`` (and therefore get checkpointed) versus which are +recomputed on every forward (RoPE caches, attention masks, etc.). A regression +that drops the module-name prefix would cause the post-resume model to silently +load buffers under wrong names. 
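+
+A minimal mental model, inferred from the assertions below rather than the
+real helper (PyTorch tracks these names per module in the private
+``nn.Module._non_persistent_buffers_set`` attribute):
+
+    {f"{prefix}.{name}" if prefix else name
+     for prefix, mod in module.named_modules()
+     for name in mod._non_persistent_buffers_set}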
+""" + +import torch +import torch.nn as nn + +from modelopt.torch.puzzletron.bypass_distillation.stitched_model_factory import ( + _get_all_non_persistent_buffers_set, +) + + +def test_module_with_no_buffers_returns_empty_set(): + assert _get_all_non_persistent_buffers_set(nn.Module()) == set() + + +def test_persistent_buffer_excluded_non_persistent_included(): + m = nn.Module() + m.register_buffer("p", torch.zeros(1), persistent=True) + m.register_buffer("np", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"np"} + + +def test_nested_submodule_paths_are_fully_qualified(): + """Sub-module non-persistent buffers must surface as ``submodule_name.buffer_name`` + so the matching key in ``state_dict()`` and the bypass save/restore code agree.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("nb", torch.zeros(1), persistent=False) + outer.add_module("inner", inner) + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"inner.nb"} + + +def test_top_level_buffer_has_no_leading_dot(): + """Module name is "" at the root — fully-qualified name must not start + with a dot, otherwise it won't match any state_dict key.""" + m = nn.Module() + m.register_buffer("x", torch.zeros(1), persistent=False) + out = _get_all_non_persistent_buffers_set(m) + assert out == {"x"} + assert not any(name.startswith(".") for name in out) + + +def test_mix_of_persistent_and_non_persistent_in_nested_module(): + """The full discrimination: only the nested non-persistent buffer should + appear, with its fully-qualified path.""" + outer = nn.Module() + inner = nn.Module() + inner.register_buffer("keep", torch.zeros(1), persistent=True) # persistent → excluded + inner.register_buffer("rope_cache", torch.zeros(1), persistent=False) + outer.add_module("attn", inner) + outer.register_buffer("global_keep", torch.zeros(1), persistent=True) # → excluded + out = _get_all_non_persistent_buffers_set(outer) + assert out == {"attn.rope_cache"} From 4a685d2382d4333e2c63fa8aaf347f4abfa9ef0a Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 01:08:57 -0700 Subject: [PATCH 09/13] Fix pre-commit ruff issues in new sewing_kit tests * SIM117: combine nested `with` statements where semantically equivalent (test_max_depth, test_no_duplicates_*, test_stack_unwinds). test_reversed_pushes_to_front_and_pops_from_front keeps the nested form because the intermediate assertion between exits is the test's actual point. * ruff format: drop blank line after lone import, merge two `from ... import` blocks for sewing_kit.core. 
Signed-off-by: Sepehr Sameni --- .../test_sewing_kit_activity_context.py | 27 ++++++++----------- .../puzzletron/test_sewing_kit_input_args.py | 1 - .../puzzletron/test_sewing_kit_needle.py | 5 +--- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py index e84db79f4ed..3fc9ab0ddeb 100644 --- a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py +++ b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py @@ -82,10 +82,8 @@ def test_max_depth_one_allows_single_push(): def test_max_depth_one_rejects_second_push(): ctx: ActivityContext[str] = ActivityContext(max_depth=1) - with ctx("a"): - with pytest.raises(ActivityContextMaxDepthException): - with ctx("b"): - pass + with ctx("a"), pytest.raises(ActivityContextMaxDepthException), ctx("b"): + pass # Stack must have unwound to empty even after the exception. assert len(ctx) == 0 @@ -97,19 +95,16 @@ def test_max_depth_one_rejects_second_push(): def test_no_duplicates_rejects_repeat_value(): ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) - with ctx("x"): - with pytest.raises(ActivityContextDuplicateException): - with ctx("x"): - pass + with ctx("x"), pytest.raises(ActivityContextDuplicateException), ctx("x"): + pass # Stack unwound; the still-active "x" was preserved through the failed push. assert len(ctx) == 0 def test_no_duplicates_allows_distinct_values(): ctx: ActivityContext[str] = ActivityContext(no_duplicates=True) - with ctx("x"): - with ctx("y"): - assert "x" in ctx and "y" in ctx + with ctx("x"), ctx("y"): + assert "x" in ctx and "y" in ctx # --------------------------------------------------------------------------- @@ -127,7 +122,8 @@ def test_reversed_pushes_to_front_and_pops_from_front(): # b inserted at front of stack. assert ctx[0] == "b" assert ctx[1] == "a" - # Pop from front: only "a" left. + # Pop from front: only "a" left — runs between the inner and outer + # exits, which is why these withs can't be combined. assert list(ctx[:]) == ["a"] @@ -140,10 +136,9 @@ def test_stack_unwinds_when_body_raises(): """A bug here would leak stack frames — the next forward pass would see a stale active passage. 
This is the silent-failure scenario.""" ctx: ActivityContext[str] = ActivityContext() - with pytest.raises(ValueError, match="boom"): - with ctx("a"): - assert ctx.get_active() == "a" - raise ValueError("boom") + with pytest.raises(ValueError, match="boom"), ctx("a"): + assert ctx.get_active() == "a" + raise ValueError("boom") assert len(ctx) == 0 diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py index 15aa945f3bb..f5e52cb5032 100644 --- a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py +++ b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py @@ -27,7 +27,6 @@ from modelopt.torch.puzzletron.sewing_kit.passage import InputArgs - # --------------------------------------------------------------------------- # Construction # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_needle.py b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py index d44c3e67cd8..a3db5ef30b8 100644 --- a/tests/unit/torch/puzzletron/test_sewing_kit_needle.py +++ b/tests/unit/torch/puzzletron/test_sewing_kit_needle.py @@ -32,16 +32,13 @@ from modelopt.torch.puzzletron.sewing_kit.core import ( ExternalTarget, + InputsLoopFoundException, ModuleTarget, Needle, Node, OnlyInternalNodesException, StitchDescriptor, ) -from modelopt.torch.puzzletron.sewing_kit.core import ( - InputsLoopFoundException, -) - # --------------------------------------------------------------------------- # get_node_for_target: lazy creation, cached lookup From f3b52b0650978e3bf7b1de2cea00f73951c49a0e Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 01:19:37 -0700 Subject: [PATCH 10/13] Fix remaining ruff issues in new puzzletron tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * RUF005: noqa the `InputArgs(1) + [2]` line — the auto-fix `[*InputArgs(1), 2]` would replace the operator call we are testing (this test pins that ``__add__`` rejects non-InputArgs via its internal assert). * ruff format: drop spurious blank lines after import blocks, collapse a function signature back to one line. * PLC0415 (in-function imports): hoist the two ``import os`` calls inside ``test_save_bypass_checkpoint_*`` to the file's import block. * Move the late ``is_submodule_of`` / ``is_submodule_or_same`` imports in ``test_sewing_kit_activity_context.py`` to the top of the file so the # noqa: E402 marker isn't needed. Signed-off-by: Sepehr Sameni --- .../torch/puzzletron/test_bypass_checkpoint_utils.py | 10 ++-------- tests/unit/torch/puzzletron/test_bypass_dataloaders.py | 1 - .../puzzletron/test_sewing_kit_activity_context.py | 9 ++------- .../torch/puzzletron/test_sewing_kit_input_args.py | 5 ++++- 4 files changed, 8 insertions(+), 17 deletions(-) diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index f0b967ea5e7..518ff36869d 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -30,6 +30,7 @@ ``args.json`` dump, ``saving_completed`` marker, master-only gating. """ +import os from collections import OrderedDict from pathlib import Path @@ -44,7 +45,6 @@ StitchedModuleDescriptor, ) - # --------------------------------------------------------------------------- # Shared fixture: silence the dist helpers so these run single-process / CPU. 
# Mirrors tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py:56-62. @@ -284,17 +284,13 @@ def test_save_bypass_checkpoint_creates_latest_symlink_and_marker( latest = experiment_dir / "latest" assert latest.is_symlink() # Symlink target is relative — just the dir name, so it resolves under experiment_dir. - import os - assert os.readlink(latest) == "iter-000007-ckpt" assert latest.resolve() == checkpoint_dir.resolve() assert (checkpoint_dir / "args.json").exists() assert (checkpoint_dir / "saving_completed").exists() -def test_save_bypass_checkpoint_replaces_existing_latest_symlink( - tmp_path: Path, patched_save -): +def test_save_bypass_checkpoint_replaces_existing_latest_symlink(tmp_path: Path, patched_save): """A stale ``latest`` from a prior save must be replaced, not appended to. Without ``unlink(missing_ok=True)`` the symlink_to() call would raise FileExistsError mid-save and leave the run unable to checkpoint.""" @@ -315,8 +311,6 @@ def test_save_bypass_checkpoint_replaces_existing_latest_symlink( checkpoint_dir=new_target, ) - import os - assert os.readlink(experiment_dir / "latest") == "iter-000007-ckpt" diff --git a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py index 00b5487dda1..c140d94e7bd 100644 --- a/tests/unit/torch/puzzletron/test_bypass_dataloaders.py +++ b/tests/unit/torch/puzzletron/test_bypass_dataloaders.py @@ -40,7 +40,6 @@ realize_dataset_in_memory, ) - # --------------------------------------------------------------------------- # realize_dataset_in_memory: pure list materialisation with optional cap # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py index 3fc9ab0ddeb..58df5ffe327 100644 --- a/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py +++ b/tests/unit/torch/puzzletron/test_sewing_kit_activity_context.py @@ -28,9 +28,10 @@ ActivityContext, ActivityContextDuplicateException, ActivityContextMaxDepthException, + is_submodule_of, + is_submodule_or_same, ) - # --------------------------------------------------------------------------- # Basic push/pop semantics via the ``with ctx(value):`` form # --------------------------------------------------------------------------- @@ -147,12 +148,6 @@ def test_stack_unwinds_when_body_raises(): # --------------------------------------------------------------------------- -from modelopt.torch.puzzletron.sewing_kit.utils import ( # noqa: E402 - is_submodule_of, - is_submodule_or_same, -) - - def test_is_submodule_of_proper_descendant(): assert is_submodule_of("model.layers.0.self_attn", "model.layers.0") assert is_submodule_of("model.layers.0", "model") diff --git a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py index f5e52cb5032..a568fadc07b 100644 --- a/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py +++ b/tests/unit/torch/puzzletron/test_sewing_kit_input_args.py @@ -76,8 +76,11 @@ def test_add_does_not_mutate_operands(): def test_add_rejects_non_input_args(): + # ``__add__`` enforces InputArgs+InputArgs only via an internal assert. + # ruff's RUF005 auto-fix to ``[*InputArgs(1), 2]`` would silently replace + # the operator call we're testing — keep the explicit ``+`` form. 
with pytest.raises(AssertionError): - InputArgs(1) + [2] # type: ignore[operator] + InputArgs(1) + [2] # type: ignore[operator] # noqa: RUF005 # --------------------------------------------------------------------------- From 4ea32625595abfb46f7a4955b13d50f18ee39670 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 01:29:56 -0700 Subject: [PATCH 11/13] Collapse remaining multi-line test signatures (ruff format) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Four test signatures in test_bypass_checkpoint_utils.py fit on a single line within the 100-char limit; ruff format wants them collapsed. The fifth signature in the same file (test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master) has three args and stays multi-line. Verified with `awk 'length > 100'` and `grep -nE '^def test.*\(\s*$'` across all new test files — no other lint nits should remain. Signed-off-by: Sepehr Sameni --- .../puzzletron/test_bypass_checkpoint_utils.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index 518ff36869d..0b193d536e3 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -200,9 +200,7 @@ def test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): # --------------------------------------------------------------------------- -def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict([("block_0", _make_descriptor())]) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" @@ -211,9 +209,7 @@ def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler( assert (stitched / "block_0.grad_scaler.pth").exists() -def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" @@ -221,9 +217,7 @@ def test_save_local_state_skips_grad_scaler_when_descriptor_has_none( assert not (stitched / "block_0.grad_scaler.pth").exists() -def test_save_local_state_skips_optimizer_when_descriptor_has_none( - tmp_path: Path, bcu_no_dist -): +def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict( [("block_0", _make_descriptor(with_optimizer=False, with_scaler=False))] ) @@ -264,9 +258,7 @@ def patched_save(monkeypatch, bcu_no_dist): return bcu_no_dist -def test_save_bypass_checkpoint_creates_latest_symlink_and_marker( - tmp_path: Path, patched_save -): +def test_save_bypass_checkpoint_creates_latest_symlink_and_marker(tmp_path: Path, patched_save): experiment_dir = tmp_path / "exp" experiment_dir.mkdir() checkpoint_dir = experiment_dir / "iter-000007-ckpt" From 623429ec7252ffed2aecaf60ade3d1b738a83b21 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 05:47:36 -0700 Subject: [PATCH 12/13] Address Claude review and halve bypass checkpoint disk usage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit Drop the per-block weight save from `_save_local_state` — the same parameters were already on disk in the top-level HF checkpoint that `save_bypass_checkpoint` writes via `save_checkpoint(model, ...)`. The stitched/ directory now carries only optimizer + grad_scaler state; weights round-trip through the HF format. Resume now routes through `load_and_shard_model` (same path as `init_checkpoint_path`) so weight loading has a single entry point. Per-iter checkpoint disk footprint roughly halves. Other Claude review fixes: * Use modelopt.torch.utils.robust_json (canonical) and delete the duplicate puzzletron-local copy. * Remove dead num_trainable_params block in bypass_factory_fn (unused and the sum was a bool count, not a numel sum). * GptOssModelDescriptor: keep "expert_removal" as a deprecation alias for the new "experts_removal" key so existing string-API callers don't break. * Move time_start out of module scope into train() — at module level it became stale relative to actual training start, firing the first time-based save immediately. * Batch the per-block loss GPU->CPU copy into a single sync after the loop (was N sync points per training step). * puzzletron_nas_plugin staleness check: probe checkpoint config.json mtime instead of resolved-symlink directory mtime; handle dangling symlinks via try/except. * Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W) so the schedule reaches min_lr exactly at step==D. * batched_normalized_mse_loss: clamp the per-vector denominator to a floor of epsilon so an all-zero target slice doesn't explode the loss. Slightly diverges from the original Puzzle implementation (documented in the docstring). * Add HF-convention comment in stitched_model_factory: tuple-returning blocks always have hidden_states at index 0. Tests updated to match: stitched/{block}.state_dict.pth assertion flipped to assert-not-written; new regression tests for the LR schedule endpoint and the zero-target loss path. Signed-off-by: Sepehr Sameni --- .../gpt_oss/gpt_oss_model_descriptor.py | 12 ++- .../bypass_checkpoint_utils.py | 43 ++++---- .../stitched_model_factory.py | 11 +- .../bypass_distillation/training_loop.py | 100 ++++++++++++------ .../torch/puzzletron/puzzletron_nas_plugin.py | 37 +++++-- modelopt/torch/puzzletron/sewing_kit/utils.py | 10 ++ .../torch/puzzletron/tools/robust_json.py | 77 -------------- .../test_bypass_checkpoint_utils.py | 83 ++++----------- .../test_bypass_checkpoint_utils.py | 19 +++- .../torch/puzzletron/test_bypass_losses.py | 20 ++++ 10 files changed, 198 insertions(+), 214 deletions(-) delete mode 100644 modelopt/torch/puzzletron/tools/robust_json.py diff --git a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py index eb2cfe68688..1abecdec0c2 100644 --- a/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py +++ b/modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py @@ -178,8 +178,18 @@ def pruning_mixins() -> Dict[str, PruningMixIn]: Note: Expert removal works for unquantized models (test models). Production models use MXFP4 quantization which is not yet supported. """ + # Single instance shared between the canonical key and the legacy alias + # so resolve_pruning_mixin returns the same object regardless of which + # name a caller uses. 
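+        # (resolve_pruning_mixin("expert_removal", GptOssModelDescriptor) and
+        # resolve_pruning_mixin("experts_removal", GptOssModelDescriptor) thus
+        # yield one shared instance, not two independently constructed mixins.)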
+ expert_mixin = ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()) return { - "experts_removal": ExpertRemovalPruningMixIn(GptOssExpertRemovalLayerDescriptor()), + "experts_removal": expert_mixin, + # Backward-compat alias: this key was "expert_removal" before the + # bypass branch standardised on "experts_removal" (matching the + # NemotronH descriptor). Kept so external scripts that still call + # `resolve_pruning_mixin("expert_removal", GptOssModelDescriptor)` + # continue to work. Remove after a deprecation cycle. + "expert_removal": expert_mixin, "kv_heads": KVHeadsPruningMixIn(GptOssKVHeadsLayerDescriptor()), } diff --git a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py index fbb658d4e57..2acded1d06c 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py +++ b/modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py @@ -28,7 +28,7 @@ from modelopt.torch.puzzletron.anymodel.model_descriptor import ModelDescriptor from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import save_checkpoint from modelopt.torch.puzzletron.tools.logger import aprint, mprint -from modelopt.torch.puzzletron.tools.robust_json import json_dump +from modelopt.torch.utils.robust_json import json_dump from .stitched_model_factory import StitchedModuleDescriptor @@ -73,10 +73,16 @@ def load_local_state( stitched_module_descriptors: OrderedDict[str, StitchedModuleDescriptor], checkpoint_path: str | Path, ) -> None: - """Load local state from a checkpoint. + """Load optimizer and grad-scaler state for each stitched module. - Loads both optimizer and state dicts into stitched module descriptors. - Modifies stitched_module_descriptors in place. + Weights are NOT loaded here — they live in the HF checkpoint at + ``checkpoint_path`` and must be loaded into the student model via + ``load_and_shard_model`` before this function runs (typically by setting + ``init_checkpoint_path`` to the resume directory). This avoids + persisting the same parameters twice (once in ``stitched/*.pth`` and + once in the HF state dict). + + Modifies ``stitched_module_descriptors`` in place. """ device = torch.device(f"cuda:{dist.local_rank()}") load_dir = Path(checkpoint_path) @@ -85,18 +91,9 @@ def load_local_state( raise RuntimeError(f'Can\'t load local state. "{load_dir}" does not exist.') for stitched_module_name, stitched_module_descriptor in stitched_module_descriptors.items(): - stitched_module = stitched_module_descriptor.stitched_module optimizer = stitched_module_descriptor.optimizer grad_scaler = stitched_module_descriptor.grad_scaler - state_dict_path = load_dir / "stitched" / f"{stitched_module_name}.state_dict.pth" - mprint(f"Loading state dict for module {stitched_module_name} from {state_dict_path}") - loaded_state_dict = torch.load(state_dict_path, map_location=device, weights_only=True) - loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} - - stitched_module.load_state_dict(loaded_state_dict) - del loaded_state_dict - if optimizer is not None: optimizer_state_path = ( load_dir / "stitched" / f"{stitched_module_name}.optimizer_state.pth" @@ -144,6 +141,16 @@ def _save_local_state( checkpoint_dir: Path | str, overwrite=True, ) -> None: + """Persist optimizer and grad-scaler state for each stitched module. + + Weights are intentionally NOT saved here. 
The same trainable parameters + would otherwise land on disk twice — once as ``stitched/{block}.state_dict.pth`` + and once as part of the HF checkpoint that ``save_bypass_checkpoint`` + writes at the top level via ``save_checkpoint(model, ...)``. The HF + checkpoint is the single source of truth for weights; this directory + only carries the optimizer/scaler state that the HF format doesn't + cover. + """ save_dir = Path(checkpoint_dir) / "stitched" if dist.is_master(): @@ -158,17 +165,9 @@ def _save_local_state( optimizer = stitched_module_descriptor.optimizer grad_scaler = stitched_module_descriptor.grad_scaler - state_dict_path = save_dir / f"{stitched_module_name}.state_dict.pth" - aprint(f"Saving state dict for module {stitched_module_name} to {state_dict_path}") - state_dict = { - **stitched_module_descriptor.owned_parameters, - **stitched_module_descriptor.owned_buffers, - } - _save_local_file(state_dict, state_dict_path, overwrite=overwrite) - if optimizer is not None: optimizer_state_path = save_dir / f"{stitched_module_name}.optimizer_state.pth" - mprint( + aprint( f"Saving optimizer state for module {stitched_module_name} to {optimizer_state_path}" ) _save_local_file(optimizer.state_dict(), optimizer_state_path, overwrite=overwrite) diff --git a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py index 21a89d11762..7c6b090e36b 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py +++ b/modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py @@ -522,6 +522,12 @@ def bypass_factory_fn( ) student_stitched_module_name = f"block_{global_block_index}" student_submodule_target = ModuleTarget("student_submodule", module) + # When a block returns a tuple, ``v[0]`` is the hidden state by + # HF convention — every HF transformer block (Llama, Qwen, GPT-OSS, + # NemotronH, …) returns ``(hidden_states, *aux)``, with ``aux`` + # varying (attention weights, KV cache, router logits, …) but + # element 0 always being the hidden state. Puzzletron is HF-format- + # only, so this assumption holds across every supported family. student_stitched_module = ( Needle() .stitch( @@ -557,11 +563,6 @@ def bypass_factory_fn( ) assert "learning_rate" in cfg.training - num_trainable_params = sum( - p.requires_grad and submodule_name in p_name - for p_name, p in student_stitched_module.named_parameters() - if "dummy_param" not in p_name # exclude placeholder params - ) # Do NOT enable dummy params: blocks with no real trainable parameters # (e.g. 
Mamba blocks during an attention-only bypass run) should produce # NaN loss so they are excluded from statistics — identical to the diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index 8486584b1c6..daf62cf06ae 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -51,17 +51,15 @@ from modelopt.torch.puzzletron.sewing_kit.utils import fake_tensor from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import load_model_config from modelopt.torch.puzzletron.tools.logger import aprint, mprint -from modelopt.torch.puzzletron.tools.robust_json import json_load from modelopt.torch.puzzletron.tools.sharded_checkpoint_utils import load_and_shard_model from modelopt.torch.puzzletron.utils.parsing import format_global_config, format_stitched_losses +from modelopt.torch.utils.robust_json import json_load from .bypass_checkpoint_utils import find_latest_run_dir, load_local_state, save_bypass_checkpoint from .bypass_utils import get_distributed_modules_ownership, set_experiment_dir, set_experiment_id from .data_classes import GlobalRank, IterNum, IterStatistics, LocalTrainingStats, TimeToSaveSignal from .stitched_model_factory import StitchedModuleDescriptor, StitchedModulesProcessOwnership -time_start = time.time() - os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -132,7 +130,12 @@ def train( dist.barrier() - time_last_save = time_start + # Anchor the time-based save interval at training start, not module import. + # Earlier this was a module-level `time_start = time.time()`, which made + # the first time-based save fire immediately if the module was imported + # well before train() actually ran (e.g. via test collection or Hydra config + # resolution). + time_last_save = time.time() iter_t0 = time.time() resumed_iter_num = cfg.bypass.iter_num @@ -284,7 +287,11 @@ def train( input_overrides["teacher_inputs"] = InputArgs(fake_input_ids) - iter_stitched_module_losses: dict[str, float] = {} + # Collect per-block loss tensors and batch the GPU→CPU copy to a + # single sync point at the end of the per-block loop. Doing + # ``.to("cpu").item()`` per block forced one CUDA synchronization per + # block per iter, serialising the GPU pipeline across N blocks. + iter_loss_tensors: dict[str, torch.Tensor] = {} for local_stitched_module_index, ( stitched_module_name, @@ -306,11 +313,15 @@ def train( del stitched_module_output grad_scaler.scale(stitched_module_loss).backward() else: - stitched_module_loss = torch.full([1], fill_value=torch.nan, dtype=torch.float32) + # Match the device of the optimizer-yes branch so all per-block + # loss tensors can be stacked into a single GPU tensor below. + stitched_module_loss = torch.full( + [1], fill_value=torch.nan, dtype=torch.float32, device=device + ) - iter_stitched_module_losses[stitched_module_name] = stitched_module_loss.to( - "cpu" - ).item() + # Detach to drop the autograd graph (we already called backward + # above) and defer the GPU→CPU copy to after the per-block loop. + iter_loss_tensors[stitched_module_name] = stitched_module_loss.detach() del stitched_module_loss @@ -350,6 +361,19 @@ def train( grad_scaler.update() optimizer.zero_grad(set_to_none=True) + # Single GPU→CPU sync for all per-block losses collected above. Stacking + # into a 1-D tensor lets us issue exactly one ``.to("cpu")`` instead of + # one per block. 
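+        # NaN placeholders from optimizer-less blocks survive the stack
+        # unchanged (torch.stack copies elements as-is); ``format_stitched_losses``
+        # later drops those NaN entries, so no-op blocks never pollute the
+        # reported statistics.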
+ if iter_loss_tensors: + loss_stack = torch.stack( + [t.flatten()[0] for t in iter_loss_tensors.values()] + ) + iter_stitched_module_losses: dict[str, float] = dict( + zip(iter_loss_tensors.keys(), loss_stack.to("cpu").tolist()) + ) + else: + iter_stitched_module_losses = {} + # Collect losses from all ranks using all_gather_object local_training_stats = LocalTrainingStats( iter_num=cfg.bypass.iter_num, @@ -660,12 +684,39 @@ def run_bypassed_training(cfg: DictConfig): ): teacher_model_config.text_config.use_cache = False + # Resume detection has to run BEFORE the weight-loading branch below + # so a resume can route through ``load_and_shard_model`` (the HF + # checkpoint at ``resume_checkpoint_path`` is now the single source + # of truth for weights — see _save_local_state docstring). + # set_experiment_id / set_experiment_dir are idempotent and only + # depend on cfg.bypass.model.model_config_overrides + cfg.puzzle_dir, + # so it's safe to call them this early. + set_experiment_id(cfg) + set_experiment_dir(cfg) + resume_checkpoint_path: Optional[str] = None + if cfg.bypass.resume_checkpoint_path is not None: + resume_checkpoint_path = cfg.bypass.resume_checkpoint_path + elif cfg.bypass.find_last_ckpt_for_resume: + _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir) + if _ckpt_dir is None: + mprint("Couldn't find any run dir for resume, assuming this is the first job") + else: + mprint( + f"`cfg.bypass.find_last_ckpt_for_resume` is True. " + f"Auto-found a checkpoint to resume: `{_ckpt_dir}`" + ) + resume_checkpoint_path = _ckpt_dir + + # Both ``init_checkpoint_path`` and ``resume_checkpoint_path`` point at + # an HF-format directory; share the same loader. ``init_checkpoint_path`` + # wins if both are set (explicit user override beats auto-detect). + weight_load_path = cfg.bypass.init_checkpoint_path or resume_checkpoint_path student_model = None - if cfg.bypass.init_checkpoint_path is not None: - mprint(f"Loading student model from {cfg.bypass.init_checkpoint_path}") + if weight_load_path is not None: + mprint(f"Loading student model from {weight_load_path}") student_model = load_and_shard_model( descriptor=descriptor, - checkpoint_path=cfg.bypass.init_checkpoint_path, + checkpoint_path=weight_load_path, owned_block_indexes=owned_block_indexes, ) @@ -772,10 +823,8 @@ def run_bypassed_training(cfg: DictConfig): bos_rate=cfg.bypass.data.bos_rate, ) - # Set ID from experiment configuration - set_experiment_id(cfg) - # Set directory for experiment ID - set_experiment_dir(cfg) + # set_experiment_id / set_experiment_dir already ran above (before + # weight loading) so the resume detection could use experiment_dir. dist.barrier() @@ -798,21 +847,10 @@ def run_bypassed_training(cfg: DictConfig): student_model=student_model, ) - # Check whether to resume from checkpoint - resume_checkpoint_path = None - if cfg.bypass.resume_checkpoint_path is not None: - resume_checkpoint_path = cfg.bypass.resume_checkpoint_path - elif cfg.bypass.find_last_ckpt_for_resume: - _ckpt_dir = find_latest_run_dir(run_parent_dir=cfg.bypass.experiment_dir) - if _ckpt_dir is None: - mprint("Couldn't find any run dir for resume, assuming this is the first job") - else: - mprint( - f"`cfg.bypass.find_last_ckpt_for_resume` is True. " - f"Auto-found a checkpoint to resume: `{_ckpt_dir}`" - ) - resume_checkpoint_path = _ckpt_dir - + # ``resume_checkpoint_path`` was determined earlier (before weight + # loading); the student weights are already in place via + # ``load_and_shard_model``. 
Only the optimizer/scaler state needs to
+    # be restored from the per-block ``stitched/`` files.
     if resume_checkpoint_path:
         load_local_state(
             stitched_module_descriptors=stitched_module_descriptors,
diff --git a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py
index 438e5f4e298..e50f4effd4f 100644
--- a/modelopt/torch/puzzletron/puzzletron_nas_plugin.py
+++ b/modelopt/torch/puzzletron/puzzletron_nas_plugin.py
@@ -342,22 +342,41 @@ def run_search(self) -> None:
         puzzle_dir = Path(self.model.puzzle_dir)
         replacement_library_path = puzzle_dir / "replacement_library.json"
         subblock_stats_path = puzzle_dir / hydra_cfg.calc_subblock_stats.subblock_stats_filename
-        # Detect a stale library: any ckpts/* entry newer than the library file means
-        # a new replacement (e.g. bypass-trained subblocks) appeared after the last build
-        # and must be picked up. Without this check, our skip-if-done would happily reuse
-        # a no-bypass library even after bypass completes.
+        # Detect a stale library: any ckpts/* entry whose finalisation marker
+        # is newer than the library file means a new replacement (e.g. bypass-
+        # trained subblocks) appeared after the last build and must be picked
+        # up. Without this check, our skip-if-done would happily reuse a
+        # no-bypass library even after bypass completes.
+        #
+        # We probe ``config.json`` rather than the directory mtime because:
+        # 1. directory mtime tracks "an entry was added/removed", not "a file
+        #    inside was modified" — adding new shards to an existing checkpoint
+        #    dir wouldn't bump the dir mtime;
+        # 2. ``entry.resolve()`` on a dangling symlink quietly returns the
+        #    non-existent target path (it only raises under ``strict=True``),
+        #    so the previous code's ``resolved.exists()`` silently treated
+        #    dangling entries as "not stale";
+        # 3. ``config.json`` is written last when a checkpoint is finalised —
+        #    its mtime is the real "checkpoint ready" timestamp.
+        # The check is conservative: false positives just trigger a rebuild,
+        # which is safe.
         ckpts_dir = puzzle_dir / "ckpts"
         library_is_stale = False
         if replacement_library_path.exists() and ckpts_dir.exists():
             library_mtime = replacement_library_path.stat().st_mtime
             for entry in ckpts_dir.iterdir():
-                # Resolve symlinks (bypass + pruning checkpoints land here as symlinks
-                # to the real directories elsewhere under puzzle_dir).
-                resolved = entry.resolve() if entry.is_symlink() else entry
-                if resolved.exists() and resolved.stat().st_mtime > library_mtime:
+                # `Path.stat()` follows symlinks by default, so this works
+                # whether `entry` is a real dir or a symlink to one (the
+                # bypass and pruning pipelines both land here as symlinks).
+                # `try` guards against dangling symlinks (FileNotFoundError).
+                config_path = entry / "config.json"
+                try:
+                    config_mtime = config_path.stat().st_mtime
+                except (FileNotFoundError, OSError):
+                    continue
+                if config_mtime > library_mtime:
                     library_is_stale = True
                     mprint(
-                        f"Replacement library is stale: '{entry.name}' is newer than the existing library, will rebuild."
+                        f"Replacement library is stale: '{entry.name}/config.json' is newer than the existing library, will rebuild."
) break if dist.is_master(): diff --git a/modelopt/torch/puzzletron/sewing_kit/utils.py b/modelopt/torch/puzzletron/sewing_kit/utils.py index b513715b223..0c190ee10d1 100644 --- a/modelopt/torch/puzzletron/sewing_kit/utils.py +++ b/modelopt/torch/puzzletron/sewing_kit/utils.py @@ -487,10 +487,20 @@ def batched_normalized_mse_loss( Useful when activations within a batch item should be normalized independently rather than normalizing across the full batch. + + Note: this slightly diverges from the original Puzzle implementation. With + per-batch-element normalization, an all-zero target slice produces a + denominator of ``epsilon ** 2 ~= 1e-12``, which then explodes the loss for + that slice (the global-reduction variant in ``normalized_mse_loss`` dilutes + it across non-zero elements, hiding the issue). We clamp the denominator + to a floor of ``epsilon`` so the per-element minimum matches the intent of + the epsilon term. The clamp only triggers on near-zero target slices — + typical activations are unaffected. """ norm_dims = list(set(range(input.ndim)) - set(batch_dims)) norm_of_target_vectors = F.mse_loss( target, torch.zeros_like(target) + epsilon, reduction="none" ).mean(norm_dims) + norm_of_target_vectors = norm_of_target_vectors.clamp(min=epsilon) loss = F.mse_loss(input, target, reduction="none").mean(norm_dims) / norm_of_target_vectors return loss.mean() diff --git a/modelopt/torch/puzzletron/tools/robust_json.py b/modelopt/torch/puzzletron/tools/robust_json.py deleted file mode 100644 index 0b424dce95c..00000000000 --- a/modelopt/torch/puzzletron/tools/robust_json.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# mypy: ignore-errors - -""" -Provides a robust JSON encoder that can handle various types of objects, -including dataclasses, paths, enums, namespaces, and functions. 
-""" - -import argparse -import dataclasses -import datetime -import inspect -import json -from enum import Enum -from pathlib import Path -from typing import Any - -from omegaconf import DictConfig, ListConfig, OmegaConf - - -class RobustJSONEncoder(json.JSONEncoder): - def default(self, o): - if dataclasses.is_dataclass(o): - return dataclasses.asdict(o) - if isinstance(o, Path): - return str(o) - if isinstance(o, Enum): - return o.name - if isinstance(o, argparse.Namespace): - return vars(o) - if type(o).__name__ == "dtype": - return str(o) - if isinstance(o, (DictConfig, ListConfig)): - return OmegaConf.to_container(o, resolve=True) - if inspect.isfunction(o) or inspect.ismethod(o): - if o.__module__ == "__main__": - # User-defined function in main — fallback to just the name - return o.__name__ - return f"{o.__module__}.{o.__qualname__}" - if inspect.isclass(o): - return f"{o.__module__}.{o.__qualname__}" - if isinstance(o, datetime.timedelta): - return str(o) - # Fallback for arbitrary objects: return their class path - if hasattr(o, "__class__") and hasattr(o.__class__, "__module__"): - return f"{o.__class__.__module__}.{o.__class__.__qualname__}" - return super().default(o) - - -def json_dumps(obj: Any) -> str: - return json.dumps(obj, cls=RobustJSONEncoder, indent=2) - - -def json_dump(obj: Any, path: Path | str) -> None: - path = Path(path) - path.parent.mkdir(exist_ok=True, parents=True) - json_text = json_dumps(obj) - path.write_text(json_text) - - -def json_load(path: Path | str) -> dict: - path = Path(path) - text = path.read_text() - return json.loads(text) diff --git a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py index a813d0060b7..dc0df1b4f6b 100644 --- a/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -13,25 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Save/load round-trip tests for ``bypass_checkpoint_utils``. - -These tests pin two correctness-critical pieces of ``bypass_checkpoint_utils``: - -1. ``_save_local_state`` persists the GradScaler state alongside the optimizer - state (regression coverage for the recent CodeRabbit-driven fix — without - it, fp16 + use_grad_scaling=True runs silently lost the running scale + - growth tracker on resume). -2. ``load_local_state`` restores it from disk. - -Lives under ``tests/gpu/`` because the production ``load_local_state`` builds -``torch.device(f"cuda:{rank}")`` for ``map_location``, so a real CUDA device -is required to round-trip ``torch.load`` without monkeypatching the device -machinery. The full bypass GPU integration test cannot cover this path -because the test infrastructure ships bf16 and ``GradScaler.step()`` is -fp16-only (raises ``NotImplementedError: +"""Load round-trip tests for ``bypass_checkpoint_utils``. + +These pin that ``load_local_state`` correctly restores the optimizer and +grad-scaler state from disk into a fresh descriptor — the resume path's +main job after the recent dedupe (weights are now loaded from the HF +checkpoint via ``load_and_shard_model``, not from ``stitched/*.pth``). + +Lives under ``tests/gpu/`` because the production ``load_local_state`` +builds ``torch.device(f"cuda:{rank}")`` for ``map_location``, so a real CUDA +device is required to round-trip ``torch.load`` without monkeypatching the +device machinery. 
The full bypass GPU integration test cannot cover this +path because the test infrastructure ships bf16 and ``GradScaler.step()`` +is fp16-only (raises ``NotImplementedError: _amp_foreach_non_finite_check_and_unscale_cuda not implemented for 'BFloat16'``). -These tests sidestep that by hitting the save/load functions directly, -without ever invoking ``.step()``. +These tests sidestep that by hitting the load functions directly, without +ever invoking ``.step()``. + +The corresponding save tests live in tests/unit/torch/puzzletron/ +test_bypass_checkpoint_utils.py — ``_save_local_state`` no longer touches +CUDA, so it doesn't need a GPU lane. """ from collections import OrderedDict @@ -99,52 +100,6 @@ def _make_descriptor( ) -# --------------------------------------------------------------------------- -# Save: every relevant artifact lands on disk -# --------------------------------------------------------------------------- - - -def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): - bcu = bcu_no_dist - descriptor = _make_descriptor() - descriptors = OrderedDict([("block_0", descriptor)]) - - bcu._save_local_state(descriptors, tmp_path) - - stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() - assert (stitched / "block_0.optimizer_state.pth").exists() - # The CodeRabbit-driven fix added this third file. Without it, resuming - # an fp16 + grad-scaling run would default-init the scaler. - assert (stitched / "block_0.grad_scaler.pth").exists() - - -def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): - bcu = bcu_no_dist - descriptor = _make_descriptor(with_scaler=False) - descriptors = OrderedDict([("block_0", descriptor)]) - - bcu._save_local_state(descriptors, tmp_path) - - stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() - # No scaler in the descriptor → no .grad_scaler.pth file written. - assert not (stitched / "block_0.grad_scaler.pth").exists() - - -def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): - """Pipeline-parallel idle ranks pass optimizer=None; no file should appear.""" - bcu = bcu_no_dist - descriptor = _make_descriptor(with_optimizer=False, with_scaler=False) - descriptors = OrderedDict([("block_0", descriptor)]) - - bcu._save_local_state(descriptors, tmp_path) - - stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() - assert not (stitched / "block_0.optimizer_state.pth").exists() - - # --------------------------------------------------------------------------- # Load: state survives the round-trip and lands back on the live scaler # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py index 0b193d536e3..a69fa88fc77 100644 --- a/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py +++ b/tests/unit/torch/puzzletron/test_bypass_checkpoint_utils.py @@ -196,24 +196,33 @@ def test_save_local_file_overwrite_false_skips_existing(tmp_path: Path): # --------------------------------------------------------------------------- -# _save_local_state — CPU-mirror of the three GPU save tests so codecov sees them +# _save_local_state: optimizer + grad_scaler only. 
+# Weights deliberately do NOT land here — the HF checkpoint at the same +# directory carries the full student state dict via ``save_checkpoint``. +# Saving the per-block weights again would just double the disk footprint. # --------------------------------------------------------------------------- -def test_save_local_state_writes_state_dict_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): +def test_save_local_state_writes_optimizer_and_grad_scaler(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict([("block_0", _make_descriptor())]) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() assert (stitched / "block_0.optimizer_state.pth").exists() assert (stitched / "block_0.grad_scaler.pth").exists() +def test_save_local_state_does_not_write_weights_state_dict(tmp_path: Path, bcu_no_dist): + """Pin the de-duplication: weights live in the HF checkpoint, not here.""" + descriptors = OrderedDict([("block_0", _make_descriptor())]) + bcu_no_dist._save_local_state(descriptors, tmp_path) + assert not (tmp_path / "stitched" / "block_0.state_dict.pth").exists() + + def test_save_local_state_skips_grad_scaler_when_descriptor_has_none(tmp_path: Path, bcu_no_dist): descriptors = OrderedDict([("block_0", _make_descriptor(with_scaler=False))]) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() + assert (stitched / "block_0.optimizer_state.pth").exists() assert not (stitched / "block_0.grad_scaler.pth").exists() @@ -223,8 +232,8 @@ def test_save_local_state_skips_optimizer_when_descriptor_has_none(tmp_path: Pat ) bcu_no_dist._save_local_state(descriptors, tmp_path) stitched = tmp_path / "stitched" - assert (stitched / "block_0.state_dict.pth").exists() assert not (stitched / "block_0.optimizer_state.pth").exists() + assert not (stitched / "block_0.grad_scaler.pth").exists() # --------------------------------------------------------------------------- diff --git a/tests/unit/torch/puzzletron/test_bypass_losses.py b/tests/unit/torch/puzzletron/test_bypass_losses.py index 4f6869c4e6d..dad3a11adf5 100644 --- a/tests/unit/torch/puzzletron/test_bypass_losses.py +++ b/tests/unit/torch/puzzletron/test_bypass_losses.py @@ -113,3 +113,23 @@ def test_batched_normalized_mse_loss_custom_dims(): assert loss.ndim == 0 # scalar assert torch.isfinite(loss) assert loss.item() > 0.0 + + +def test_batched_normalized_mse_loss_zero_target_does_not_explode(): + """All-zero target slice would otherwise divide by epsilon**2 ~= 1e-12 and + blow the loss up to ~1e12; the clamp on the per-vector denominator floors + that at epsilon, keeping the loss bounded for the all-zero-target case. + + Without the clamp, this test asserts a value on the order of 1e12 instead + of a small finite number. + """ + # One batch element with all-zero target; non-zero input forces a positive + # numerator so the division actually exercises the denominator path. + input_ = torch.full((1, 8), 1.0) + target = torch.zeros(1, 8) + loss = batched_normalized_mse_loss(input_, target) + assert torch.isfinite(loss) + # With clamp(min=epsilon=1e-6), denominator is ≈ epsilon, numerator is + # mse(1.0, 0.0) = 1.0 → loss ≈ 1.0 / 1e-6 = 1e6 (not 1e12). Use a loose + # upper bound to pin "doesn't explode" without coupling to epsilon's value. 
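+    # For concreteness (not asserted exactly): with the default epsilon of
+    # 1e-6, the denominator before the clamp is mean((0 - 1e-6) ** 2) = 1e-12,
+    # i.e. loss = 1.0 / 1e-12 = 1e12; after clamp(min=epsilon) it becomes
+    # 1.0 / 1e-6 = 1e6, comfortably under the 1e9 bound below.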
+ assert loss.item() < 1e9 From 837937979e5c34fb3f0a7ac06c5953f0d0864c70 Mon Sep 17 00:00:00 2001 From: Sepehr Sameni Date: Thu, 7 May 2026 05:54:06 -0700 Subject: [PATCH 13/13] Collapse loss_stack call to one line (ruff format) The stacked-loss expression fits on a single line within the 100-char limit; ruff format unwrapped it. Signed-off-by: Sepehr Sameni --- .../torch/puzzletron/bypass_distillation/training_loop.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py index daf62cf06ae..7a496475fcf 100644 --- a/modelopt/torch/puzzletron/bypass_distillation/training_loop.py +++ b/modelopt/torch/puzzletron/bypass_distillation/training_loop.py @@ -365,9 +365,7 @@ def train( # into a 1-D tensor lets us issue exactly one ``.to("cpu")`` instead of # one per block. if iter_loss_tensors: - loss_stack = torch.stack( - [t.flatten()[0] for t in iter_loss_tensors.values()] - ) + loss_stack = torch.stack([t.flatten()[0] for t in iter_loss_tensors.values()]) iter_stitched_module_losses: dict[str, float] = dict( zip(iter_loss_tensors.keys(), loss_stack.to("cpu").tolist()) )
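
For readers who want the single-sync loss collection from the two commits above
in isolation, here is a minimal, self-contained sketch. The helper name
`collect_block_losses` and the surrounding scaffolding are illustrative only
and not part of the patch; the `torch.stack` plus single `.to("cpu")` step is
the pattern the patched loop uses.

    import torch

    def collect_block_losses(per_block_losses: dict[str, torch.Tensor]) -> dict[str, float]:
        """Batch all per-block GPU losses into one device-to-host copy.

        Illustrative stand-in for the patched loop body; NaN entries mark
        no-op blocks and are filtered downstream (format_stitched_losses).
        """
        if not per_block_losses:
            return {}
        # One stack + one .to("cpu") issues exactly one CUDA synchronization,
        # instead of one .item() sync per block per iteration.
        stacked = torch.stack([t.flatten()[0] for t in per_block_losses.values()])
        return dict(zip(per_block_losses.keys(), stacked.to("cpu").tolist()))

With CUDA tensors this turns N host syncs per iteration into one; with CPU
tensors it is behaviourally identical, just without the latency win.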