
Add bypass distillation (blockwise local KD) to puzzletron pipeline#1111

Open
Separius wants to merge 16 commits into main from
ssameni/puzzletron-bypass

Conversation

@Separius
Contributor

@Separius Separius commented Mar 24, 2026

Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.
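
The per-block idea can be sketched with a toy PyTorch example: each pruned student block is trained in isolation to reproduce its frozen teacher block's output. Everything here (block shapes, widths, optimizer settings, names) is illustrative, not the PR's actual training loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 32

# Frozen teacher FFN block and an aggressively narrowed student block;
# the student is trained only to match the teacher's block output.
teacher_block = nn.Sequential(
    nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Linear(4 * hidden, hidden)
)
student_block = nn.Sequential(
    nn.Linear(hidden, hidden // 2), nn.GELU(), nn.Linear(hidden // 2, hidden)
)
teacher_block.requires_grad_(False)

x_val = torch.randn(16, hidden)  # held-out activations for a before/after check
init_loss = F.mse_loss(student_block(x_val), teacher_block(x_val)).item()

opt = torch.optim.AdamW(student_block.parameters(), lr=1e-3)
for _ in range(300):
    x = torch.randn(16, hidden)   # stand-in for captured teacher input activations
    target = teacher_block(x)     # teacher forward: capture block output
    loss = F.mse_loss(student_block(x), target)
    opt.zero_grad()
    loss.backward()
    opt.step()

final_loss = F.mse_loss(student_block(x_val), teacher_block(x_val)).item()
```

In the real pipeline this per-block matching produces the alternative "puzzle piece" candidates the MIP solver then composes into a full model.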

Changes:

  • Add modelopt/torch/puzzletron/bypass_distillation/ module with full training loop, stitched model factory, checkpoint management, and data classes
  • Integrate bypass as optional Step 3 in puzzletron.py and puzzletron_nas_plugin.py (pipeline progress counter updates to 9 steps when bypass is enabled)
  • Add HuggingFace auto-download and skip-if-exists logic to puzzletron_nas_plugin.py for all pipeline steps
  • Add normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss to sewing_kit/utils.py
  • Fix child_init.py: support a list of pruning mixins; treat a None override as "keep original value" instead of raising TypeCheckError
  • Fix dataset.py: graceful fallback when tokenizer has no chat_template (base models)
  • Add _copy_auto_map_code_files to checkpoint_utils_hf.py so modeling Python files are copied alongside config.json (required for trust_remote_code checkpoints such as NemotronH)
  • Add create_train_dataloader to dataloaders.py
  • Add MoEChannelPruning to MlpInitMode enum
  • Add default pruning_mixins() to ModelDescriptor base class
  • Add NemotronHKVHeadsLayerDescriptor and kv_heads mixin to NemotronH descriptor; fix _set_keys_to_learn to skip Mamba blocks during subblock_attention bypass (based on block config)
  • Enable bypass in llama-3_1-8B_pruneffn_memory config; add example bypass/defaults.yaml
  • Update README with bypass documentation: when to use, time cost, sequential execution, W&B logging
  • Add unit tests for loss functions and distribution utilities
  • Add GPU integration tests for bypass (FFN pruning, KV compression, multi-config sweep, checkpoint validation)
  • Fix test_puzzletron.py assertion to handle variable GPU counts

Summary by CodeRabbit

Release Notes

  • New Features

    • Added bypass distillation training for blockwise local knowledge transfer in the PUZZLE framework
    • Introduced KV-heads pruning support for Nemotron models
    • Added Nemotron-3-Nano-30B-A3B model configuration and full pipeline support
    • Implemented normalized MSE loss functions for knowledge distillation training
  • Tests

    • Added comprehensive integration tests for bypass distillation workflows
  • Documentation

    • New tutorial for KV-heads pruning with bypass distillation accuracy recovery

@Separius Separius requested review from a team as code owners March 24, 2026 16:21
@Separius Separius requested a review from cjluo-nv March 24, 2026 16:21
@copy-pr-bot

copy-pr-bot Bot commented Mar 24, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 24, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds an optional bypass (blockwise local) distillation stage: new bypass package with stitched teacher–student factory, distributed training loop and checkpointing, model/pruning extensions, normalized-MSE losses, dataloader helper, example configs/docs, and unit/GPU tests; integrates bypass into Puzzletron control flow.

Changes

  • Bypass package & core logic
    Files: modelopt/torch/puzzletron/bypass_distillation/__init__.py, modelopt/torch/puzzletron/bypass_distillation/training_loop.py, modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py, modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py, modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py, modelopt/torch/puzzletron/bypass_distillation/data_classes.py
    New bypass distillation package: entrypoint, sweep/run orchestration, stitched teacher↔student factory, distributed per-block training loop, checkpoint save/load utilities, experiment id/dir helpers, and dataclasses.
  • Pipeline integration & orchestrator
    Files: modelopt/torch/puzzletron/puzzletron.py, modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py, examples/puzzletron/main.py
    Inserted optional bypass stage into main pipeline/NAS plugin; added _total_steps, dynamic progress reporting, restartable *.complete markers, HF auto-download path, and refactored setup with longer distributed timeouts.
  • Stitched-model & pruning extensions
    Files: modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py, modelopt/torch/puzzletron/anymodel/models/nemotron_h/...nemotron_h_model_descriptor.py, modelopt/torch/puzzletron/pruning/pruning_utils.py
    Added ModelDescriptor.pruning_mixins() hook; Nemotron‑H KV‑heads layer descriptor and KV‑heads pruning mixin; added MoEChannelPruning enum and dispatch into MLP init/pruning flow.
  • Loss utilities
    Files: modelopt/torch/puzzletron/sewing_kit/utils.py, modelopt/torch/puzzletron/tools/kd_model.py
    Added vectorwise_normalized_mse_loss and batched_normalized_mse_loss; adjusted normalized_mse_loss epsilon placement.
  • Data & dataloaders
    Files: modelopt/torch/puzzletron/utils/data/dataloaders.py, modelopt/torch/puzzletron/utils/data/dataset.py
    Added create_train_dataloader (infinite ConstantLengthDataset-backed loader) and safer chat-template fallback for tokenizers without chat_template.
  • Stitch/child-init & checkpoint tweaks
    Files: modelopt/torch/puzzletron/tools/bypassed_training/child_init.py, modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
    Child init accepts list-of-mixins and avoids nulling overrides; checkpoint helper copies auto_map Python files into HF checkpoints.
  • Checkpoint helpers & utils
    Files: modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py, .../bypass_utils.py
    Distributed-safe save/load of per-block stitched state, latest/best run discovery, and experiment-dir setup.
  • Examples, configs & docs
    Files: examples/puzzletron/BYPASS.md, examples/puzzletron/README.md, examples/puzzletron/configs/.../bypass/defaults.yaml, examples/puzzletron/configs/.../*.yaml
    New BYPASS documentation, shared bypass defaults, many example Hydra configs updated to reference or document enabling bypass, and numerous small YAML composition files.
  • Tests — GPU & unit
    Files: tests/gpu/torch/puzzletron/test_bypass.py, tests/gpu/.../resources/.../bypass/test_bypass.yaml, tests/unit/torch/puzzletron/test_bypass_losses.py, tests/unit/torch/puzzletron/test_bypass_utils.py, tests/gpu/.../test_puzzletron.py
    Added GPU integration tests for bypass workflows and checkpointing; unit tests for normalized-MSE losses and bypass utilities; adjusted distributed init timeouts and some test assertions.
  • Misc & small edits
    Files: various (tests/*, modelopt/*, examples/*)
    Import reorders, parsing/logging tweaks, minor test changes, and many YAML pointer files for config composition.

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant Launcher as "launch_bypass_distillation(hydra_cfg)"
    participant Orchestrator as "run_bypassed_training(cfg)"
    participant Factory as "stitched_model_factory()"
    participant Data as "DataLoader / Teacher"
    participant Trainer as "train()"
    participant Checkpoint as "save_bypass_checkpoint()"

    User->>Launcher: provide Hydra cfg (single or sweep)
    Launcher->>Orchestrator: start run(s)
    Orchestrator->>Data: load teacher model & dataloaders
    Orchestrator->>Factory: build stitched teacher & student modules
    Factory-->>Orchestrator: return stitched modules + descriptors
    Orchestrator->>Trainer: start training loop
    loop per iteration
        Trainer->>Data: fetch batch
        Trainer->>Trainer: teacher forward -> capture activations
        Trainer->>Trainer: student forward -> compute per-block losses
        Trainer->>Trainer: backward, grad scale/clip, optimizer step
        Trainer->>Checkpoint: conditional save, write markers, symlink
        Checkpoint-->>Trainer: sync / resume info
    end
    Trainer-->>Orchestrator: training complete
    Orchestrator-->>Launcher: run finished

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 73.68%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (3 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title clearly and concisely describes the main feature being added, bypass distillation (blockwise local KD) in the puzzletron pipeline, which is the primary objective across all file changes.
  • Security Anti-Patterns: ✅ Passed. No security anti-patterns found: no torch.load() with weights_only=False, no numpy.load() with allow_pickle=True, no trust_remote_code=True, no eval/exec calls, no nosec comments.



Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)

236-245: ⚠️ Potential issue | 🟡 Minor

The fallback printer still emits only rank-local values.

This branch now advertises num_layers={total_layers}, but it still prints only the contents of rank_{rank}.pth and is executed on rank 0 only. On multi-GPU runs the suggested EXPECTED_PRUNING_VALUES snippet will be incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, The
printer currently outputs only rank-local pruning_scores causing incomplete
EXPECTED_PRUNING_VALUES for multi-GPU runs; modify the logic so rank 0
aggregates pruning data from all ranks before printing: collect and merge
per-rank pruning_scores (or load all rank_{rank}.pth files) into a global
pruning_scores for each layer_name, compute the global score and channels (e.g.,
combine/average or gather channel indices across ranks) respecting total_layers,
and then have rank 0 iterate over layer_names using the aggregated values when
printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES
snippet.
modelopt/torch/puzzletron/pruning/pruning_utils.py (1)

40-47: ⚠️ Potential issue | 🟠 Major

MoEChannelPruning is exposed before the init path supports it.

modelopt/torch/puzzletron/tools/bypassed_training/child_init.py now branches on this enum and forwards it into _init_mlp_module(), but _init_mlp_module() still falls through to Unsupported mlp_init_mode for this value when expert widths change. Any config that selects MoEChannelPruning will fail during child initialization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The
enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still
treats that case as unsupported; update the _init_mlp_module implementation to
handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py
forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that
performs the correct initialization when expert widths change (e.g., adapt the
weight/activation shapes by slicing/reshaping or reuse the
ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer
falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

804-806: This change makes explicit null resets impossible.

Treating None as “keep original” fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to None. If callers need both behaviors, use a sentinel for “no override” and reserve None for explicit clearing.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
804 - 806, The current override function (override) treats item_overrides ==
None as "keep original", which prevents callers from explicitly clearing a value
to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new
unique object like NO_OVERRIDE) to represent "no override" and reserve None in
item_overrides to mean "set to None"/clear the field, updating the override
function to check against the sentinel (NO_OVERRIDE) instead of None and adjust
any callers that construct overrides to use the sentinel when they mean "leave
original".
modelopt/torch/puzzletron/utils/data/dataset.py (1)

123-130: Keep role markers in the no-template fallback.

Joining only content collapses system/user/assistant turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like "{role}: {content}" preserves the conversation structure without relying on a tokenizer template.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, The
fallback that builds sample when getattr(self.tokenizer, "chat_template", None)
is None should preserve role markers instead of joining only message["content"];
update the else branch in dataset.py (the block that currently sets sample =
"\n".join(m["content"] for m in sample)) to join messages using a lightweight
role-prefixed format like "{role}: {content}" so conversation turns
(system/user/assistant) are retained; keep using the same sample variable and
ensure this mirrors the structure expected by downstream code that consumes
apply_chat_template outputs.
modelopt/torch/puzzletron/utils/parsing.py (1)

337-345: Don’t silently treat every NaN as a no-op block.

This formatter now drops any NaN entry and can report No trainable blocks found. If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, The
current filtering silently drops any NaN in losses_dict (and prunes
best_steps_dict/best_values_dict to match), which hides diverging trainable
blocks; instead, update the logic around losses_dict, best_steps_dict and
best_values_dict so you only drop entries whose keys match known skipped block
types (e.g., the explicit list of no-op block names like "Mamba"), and for any
other NaN values emit a warning/error (via the existing logger) that a trainable
block produced NaN rather than removing it; ensure best_steps_dict and
best_values_dict are only pruned to match the filtered losses_dict after this
selective filtering and warning behavior.
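
A sketch of the selective filtering this prompt asks for, assuming NaN is the marker for skipped blocks; all names here are illustrative.

```python
# Drop NaN losses only for known no-op block types; warn (and keep the entry)
# for any other NaN so a diverging trainable block stays visible in the logs.
import math
import warnings

SKIPPED_BLOCK_TYPES = ("mamba",)  # blocks expected to report NaN (no optimizer)

def filter_losses(losses_dict):
    kept = {}
    for name, loss in losses_dict.items():
        if math.isnan(loss):
            if any(t in name.lower() for t in SKIPPED_BLOCK_TYPES):
                continue  # expected no-op block: drop silently
            warnings.warn(f"trainable block {name!r} produced a NaN loss")
        kept[name] = loss
    return kept

out = filter_losses({"Mamba.block3": float("nan"), "attn.block1": 0.42})

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    out2 = filter_losses({"attn.block2": float("nan")})  # unexpected NaN kept
```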
examples/puzzletron/main.py (1)

154-167: Progress messages in run_mip_only are hardcoded and inconsistent with the dynamic approach.

The run_full_puzzletron function now uses dynamic step counting (N = _total_steps(hydra_cfg)), but run_mip_only still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (should be 8/9 and 9/9).

Consider applying the same dynamic step count logic here for consistency.

♻️ Suggested fix
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)

     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)

     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 154 - 167, Update run_mip_only to
compute the total steps like run_full_puzzletron by calling
_total_steps(hydra_cfg) and use that N when formatting the progress messages
instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls
around the conditional that currently show "Puzzletron Progress 7/8" and "8/8"
with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}:
...") and ensure current_step increments are correct for both the sweep branch
(sweep.run_mip_sweep) and the mip branch
(mip_and_realize_models.launch_mip_and_realize_model) so progress displays
consistently with _total_steps.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

548-556: Unused variable num_trainable_params.

The variable num_trainable_params is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.

♻️ Proposed removal
             assert "learning_rate" in cfg.training
-            num_trainable_params = sum(
-                p.requires_grad and submodule_name in p_name
-                for p_name, p in student_stitched_module.named_parameters()
-                if "dummy_param" not in p_name  # exclude placeholder params
-            )
-            # Do NOT enable dummy params: blocks with no real trainable parameters
-            # (e.g. Mamba blocks during an attention-only bypass run) should produce
-            # NaN loss so they are excluded from statistics — identical to the
-            # optimizer=None path in the training loop.

             student_module_parameters = {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 548 - 556, Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
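
The (iter, step) sort key described above can be sketched as follows; the checkpoint directory naming is an assumption based on the regexes in the prompt.

```python
# Parse both iteration and step from a checkpoint directory name and sort by
# the (iter, step) tuple so same-iteration checkpoints order by step.
import re

def get_iter_and_step(dirname):
    it = re.search(r"iter-(\d+)", dirname)
    step = re.search(r"step-(\d+)", dirname)
    return (
        int(it.group(1)) if it else 0,
        int(step.group(1)) if step else 0,  # default 0 when no step suffix
    )

dirs = ["iter-2_step-100", "iter-2_step-500", "iter-1_step-900", "iter-2"]
newest_first = sorted(dirs, key=get_iter_and_step, reverse=True)
```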

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.

In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.

In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
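
The corrected epsilon placement can be sketched as a toy version (not the repository's actual normalized_mse_loss): epsilon guards the scalar denominator instead of shifting the zero reference tensor.

```python
import torch
import torch.nn.functional as F

def normalized_mse_loss(input, target, epsilon=1e-8):
    numerator = F.mse_loss(input, target)
    # epsilon added to the final scalar denominator, not to the zero tensor,
    # so the normalization scale is unbiased for small-magnitude targets
    denominator = F.mse_loss(target, torch.zeros_like(target)) + epsilon
    return numerator / denominator

t = torch.randn(4, 16)
perfect = normalized_mse_loss(t, t)  # exact match normalizes to zero
safe = normalized_mse_loss(torch.randn(4, 16), torch.zeros(4, 16))  # finite
```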

In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
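
The normalization this prompt describes could be sketched like this; the helper name is illustrative, and the value shapes (plain string, list/tuple, repo-qualified "org/repo--module.Class") come from the prompt's own description.

```python
# Reduce an auto_map value to the module source filename to copy.
def auto_map_module_file(value):
    if isinstance(value, (list, tuple)):
        value = next(v for v in value if v is not None)  # take first usable ref
    if "--" in value:
        value = value.split("--", 1)[1]   # drop the "repo_id--" qualifier
    module = value.split(".", 1)[0]       # module part before the class name
    return module + ".py"

f1 = auto_map_module_file("tokenization_my.MyTokenizer")
f2 = auto_map_module_file("org/repo--modeling_x.XModel")
f3 = auto_map_module_file([None, "modeling_y.YModel"])
```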

In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
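
The guard suggested above might look like this toy version; the dataset and factory names are illustrative stand-ins for the PR's ConstantLengthDataset and create_train_dataloader.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class ToyConstantLengthDataset(IterableDataset):
    def __iter__(self):
        # NOTE: no torch.utils.data.get_worker_info() sharding yet, so with
        # num_workers > 0 every worker would yield all four samples.
        yield from range(4)

def create_train_dataloader(dataset, batch_size, num_workers=0):
    if num_workers > 0:
        raise ValueError(
            "num_workers > 0 duplicates samples until the dataset shards "
            "via torch.utils.data.get_worker_info()"
        )
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

loader = create_train_dataloader(ToyConstantLengthDataset(), batch_size=2)
batches = [batch.tolist() for batch in loader]

try:
    create_train_dataloader(ToyConstantLengthDataset(), batch_size=2, num_workers=2)
    guard_raised = False
except ValueError:
    guard_raised = True
```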
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.
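
The branching this prompt describes can be sketched with stand-in classes that only model the two shuffle signatures (datasets.IterableDataset.shuffle takes no keep_in_memory argument); all names here are illustrative.

```python
class MapStyleDS:
    def shuffle(self, seed, keep_in_memory=False):
        return ("map", seed, keep_in_memory)

class StreamingDS:
    def shuffle(self, seed):  # mirrors datasets.IterableDataset.shuffle
        return ("stream", seed)

def shuffle_train_data(train_data, shuffle_seed, is_streaming):
    # real code would detect streaming via isinstance(train_data,
    # datasets.IterableDataset) rather than a flag
    if is_streaming:
        return train_data.shuffle(seed=shuffle_seed)
    return train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)

map_result = shuffle_train_data(MapStyleDS(), 42, is_streaming=False)
stream_result = shuffle_train_data(StreamingDS(), 42, is_streaming=True)
```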

In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
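
The pitfall is easy to demonstrate: the first positional argument of datetime.timedelta is days, not seconds, so timedelta(10) silently becomes a 10-day timeout.

```python
from datetime import timedelta

# timedelta(10) is ten DAYS, not ten seconds
assert timedelta(10) == timedelta(days=10)
ten_days = timedelta(10).total_seconds()        # 864000.0 seconds
ten_seconds = timedelta(seconds=10).total_seconds()
```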

---

Outside diff comments:
In `@modelopt/torch/puzzletron/pruning/pruning_utils.py`:
- Around line 40-47: The enum MlpInitMode now includes MoEChannelPruning but
_init_mlp_module still treats that case as unsupported; update the
_init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the
same call-site that child_init.py forwards into) by adding a branch for
MlpInitMode.MoEChannelPruning that performs the correct initialization when
expert widths change (e.g., adapt the weight/activation shapes by
slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where
appropriate), so the child init no longer falls through to the "Unsupported
mlp_init_mode" error for MoEChannelPruning.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 236-245: The printer currently outputs only rank-local
pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs;
modify the logic so rank 0 aggregates pruning data from all ranks before
printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth
files) into a global pruning_scores for each layer_name, compute the global
score and channels (e.g., combine/average or gather channel indices across
ranks) respecting total_layers, and then have rank 0 iterate over layer_names
using the aggregated values when printing the block that uses total_layers and
prints the EXPECTED_PRUNING_VALUES snippet.

---

Nitpick comments:
In `@examples/puzzletron/main.py`:
- Around line 154-167: Update run_mip_only to compute the total steps like
run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when
formatting the progress messages instead of hardcoded "7/8" and "8/8";
specifically, replace the two mprint calls around the conditional that currently
show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g.,
f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step
increments are correct for both the sweep branch (sweep.run_mip_sweep) and the
mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress
displays consistently with _total_steps.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 548-556: Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 804-806: The current override function (override) treats
item_overrides == None as "keep original", which prevents callers from
explicitly clearing a value to None via JSON/YAML; change the logic to use a
distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no
override" and reserve None in item_overrides to mean "set to None"/clear the
field, updating the override function to check against the sentinel
(NO_OVERRIDE) instead of None and adjust any callers that construct overrides to
use the sentinel when they mean "leave original".

In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 123-130: The fallback that builds sample when
getattr(self.tokenizer, "chat_template", None) is None should preserve role
markers instead of joining only message["content"]; update the else branch in
dataset.py (the block that currently sets sample = "\n".join(m["content"] for m
in sample)) to join messages using a lightweight role-prefixed format like
"{role}: {content}" so conversation turns (system/user/assistant) are retained;
keep using the same sample variable and ensure this mirrors the structure
expected by downstream code that consumes apply_chat_template outputs.

In `@modelopt/torch/puzzletron/utils/parsing.py`:
- Around line 337-345: The current filtering silently drops any NaN in
losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides
diverging trainable blocks; instead, update the logic around losses_dict,
best_steps_dict and best_values_dict so you only drop entries whose keys match
known skipped block types (e.g., the explicit list of no-op block names like
"Mamba"), and for any other NaN values emit a warning/error (via the existing
logger) that a trainable block produced NaN rather than removing it; ensure
best_steps_dict and best_values_dict are only pruned to match the filtered
losses_dict after this selective filtering and warning behavior.
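
The selective filtering described above can be sketched as follows; `SKIPPED_BLOCK_TYPES` and `filter_losses` are hypothetical names standing in for the real parsing.py logic:

```python
import math
import logging

logger = logging.getLogger(__name__)

# Hypothetical allowlist; real code would use the model's no-op block types.
SKIPPED_BLOCK_TYPES = ("Mamba",)

def filter_losses(losses_dict, best_steps_dict, best_values_dict):
    kept = {}
    for name, loss in losses_dict.items():
        if math.isnan(loss):
            if any(block_type in name for block_type in SKIPPED_BLOCK_TYPES):
                continue  # expected no-op block: drop silently
            # Trainable block diverged: keep it visible and warn loudly.
            logger.warning("Trainable block %s produced NaN loss", name)
        kept[name] = loss
    # Prune the companion dicts only after the selective filtering above.
    best_steps = {k: v for k, v in best_steps_dict.items() if k in kept}
    best_values = {k: v for k, v in best_values_dict.items() if k in kept}
    return kept, best_steps, best_values
```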

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40

📥 Commits

Reviewing files that changed from the base of the PR and between e508b76 and e018ca0.

📒 Files selected for processing (27)
  • examples/puzzletron/README.md
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
  • modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/puzzletron.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/utils/data/dataloaders.py
  • modelopt/torch/puzzletron/utils/data/dataset.py
  • modelopt/torch/puzzletron/utils/parsing.py
  • tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml
  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/unit/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py

Comment on lines +45 to +58
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
return None

def get_iter_num(dir_name):
match = re.search(r"iter-(\d+)", dir_name.name)
return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
if (latest_dir / "saving_completed").exists():
return str(latest_dir)

⚠️ Potential issue | 🟠 Major

Include step_num when picking the latest checkpoint.

This fallback only sorts on iter-(\d+). If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same run_parent_dir.

💡 Suggested fix
-    def get_iter_num(dir_name):
-        match = re.search(r"iter-(\d+)", dir_name.name)
-        return int(match.group(1)) if match else 0
-
-    checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+    def checkpoint_order(path: Path) -> tuple[int, int, float]:
+        match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+        if not match:
+            return (0, 0, path.stat().st_mtime)
+        return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+    checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 45 - 58, The fallback currently only sorts checkpoint directories
by iteration (get_iter_num) so when multiple checkpoints exist for the same iter
we may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
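
The (iter, step) ordering both suggestions describe boils down to a composite sort key; a standalone sketch with made-up directory names:

```python
import re

# Standalone sketch of the (iter, step) ordering described above.
# Directory names here are hypothetical; real runs may use other patterns.
def checkpoint_order(name):
    # Sort by iteration first, then step within the iteration;
    # both components default to 0 when the pattern is absent.
    it = re.search(r"iter-(\d+)", name)
    step = re.search(r"step-(\d+)", name)
    return (int(it.group(1)) if it else 0,
            int(step.group(1)) if step else 0)

dirs = ["iter-2-step-1", "iter-10", "iter-10-step-3", "iter-2-step-4"]
latest_first = sorted(dirs, key=checkpoint_order, reverse=True)
print(latest_first[0])  # iter-10-step-3
```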

Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Comment thread modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py Outdated
Comment thread modelopt/torch/puzzletron/sewing_kit/utils.py Outdated
Comment thread modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py Outdated
Comment thread modelopt/torch/puzzletron/utils/data/dataloaders.py
Comment thread modelopt/torch/puzzletron/utils/data/dataloaders.py Outdated
Comment thread tests/gpu/torch/puzzletron/test_bypass.py
Collaborator

@cjluo-nv cjluo-nv left a comment


Summary: Adds bypass distillation (blockwise local knowledge distillation) as an optional pipeline stage to puzzletron. Includes a full training loop, stitched model factory, checkpoint management, loss functions, data loader, configuration, and comprehensive tests. Also fixes bugs in child_init.py, dataset.py, and adds HF auto-download logic.

Issues Found:

  1. [Duplicated Code] normalized_mse_loss in sewing_kit/utils.py (diff lines 432-445) is an exact duplicate of the existing implementation in modelopt/torch/puzzletron/tools/kd_model.py:32-41. The new code should import and reuse the existing function rather than redefining it. The vectorwise_normalized_mse_loss and batched_normalized_mse_loss variants are new and fine, but they should build on the existing import.

  2. [Correctness / Security] training_loop.py:675 — AutoTokenizer.from_pretrained uses a hardcoded trust_remote_code=True. The variable trust_remote_code is already computed from the descriptor at line 648, so this call should use trust_remote_code=trust_remote_code instead. (Flagged by pre-merge checks as well.)

  3. [Correctness / Security] bypass_checkpoint_utils.py:85,99 — torch.load() calls lack weights_only=True. The codebase convention (e.g., checkpoint_utils.py:43,77) is to use weights_only=True for state-dict loading. These calls load state dicts and optimizer states respectively, which are pure tensor data, so they should pass weights_only=True.

  4. [Correctness] training_loop.py — The except Exception as e block at the end of run_bypassed_training (around line 870) catches all exceptions and calls sys.exit(1) for non-SystemExit exceptions. This swallows the actual exception type and prevents proper test framework error reporting. In GPU tests, a failing bypass run will produce SystemExit(1) instead of the real traceback. Consider re-raising or at least logging before exit.

  5. [Correctness] stitched_model_factory.py:370-373 — The lambda closures in the stitched module creation loop (adapter=lambda v: InputArgs(target=v) and adapter=lambda v: InputArgs(input=v)) capture v correctly since they're arguments, but the loss target/input naming ("target" and "input") relies on block_loss_func accepting exactly these keyword arguments. If someone changes block_loss_func to e.g. batched_normalized_mse_loss, the keyword args don't match (batched_normalized_mse_loss takes input and target positional args, not kwargs via InputArgs). This coupling is implicit and fragile — consider documenting the contract or adding a **kwargs adapter.

  6. [Correctness] bypass_checkpoint_utils.py:89loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict} merges the current state dict with the loaded one (loaded takes precedence). However, the current state dict is fetched before loading — if the model is on a different device, keys may contain tensors on the wrong device. The subsequent load_state_dict should handle this, but the intermediate merged dict is wasteful. Consider just using strict=False with load_state_dict directly.

  7. [Readability] stitched_model_factory.py — The bypass_factory_fn function is ~250 lines long with deeply nested logic. The student model initialization block (lines 200-305) could be extracted into a helper like _initialize_student_model(...).

  8. [Readability] training_loop.py — The train() function is ~300 lines with deeply nested control flow for logging, validation, checkpoint saving, and time-based signals. Consider extracting checkpoint-save logic and logging logic into separate functions.

  9. [Readability] stitched_model_factory.py:434-435 — Stray blank lines separate the end of the function from the backward-compatible aliases (gqa_factory_fn = bypass_factory_fn, moe_factory_fn = bypass_factory_fn), which have no callers in this PR and no documentation. If they're for backward compat with existing configs, add a comment; if they're unused, remove them.

  10. [Tests] The GPU tests are thorough for the happy path but don't test checkpoint resume (loading from a previous run). The find_last_ckpt_for_resume + load_local_state path is complex and untested. At minimum, a test that runs bypass, then runs it again with find_last_ckpt_for_resume=True to verify resume works would increase confidence.

  11. [Tests] No unit test for _set_keys_to_learn which has significant branching logic (subblock types, hybrid model block_configs filtering, regex fallback). This function is critical for correctness.

  12. [Correctness] puzzletron_nas_plugin.py — The new auto-download logic in convert_puzzletron_model (lines 152-165) runs snapshot_download only on rank 0 inside if dist.is_master(), but then all ranks call dist.barrier(). If the download takes a long time, the barrier timeout (set in main.py as timedelta(10) = 10 days) should be fine, but the input_model_path variable is only updated on rank 0 — other ranks never use it since only rank 0 does the conversion. This is correct but subtle; a comment would help.

  13. [Correctness] bypass_utils.py:50 — set_experiment_dir assigns a Path object to cfg.bypass.experiment_dir, but OmegaConf/DictConfig doesn't natively support Path objects. This works because OmegaConf stores it as-is with struct mode off, but it may cause serialization issues (e.g., json_dump in save_bypass_checkpoint). Consider converting to str.
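
For item 4, the log-then-reraise pattern can be sketched as follows; `run_with_reporting` is a hypothetical wrapper, not the actual run_bypassed_training code:

```python
import logging

logger = logging.getLogger(__name__)

# Hypothetical wrapper illustrating the suggested fix; the reviewed code
# currently catches Exception and calls sys.exit(1), which hides the
# original traceback from test frameworks.
def run_with_reporting(fn):
    try:
        return fn()
    except SystemExit:
        raise  # deliberate exits pass through unchanged
    except Exception:
        # Log for operators, then re-raise so pytest reports the real
        # exception type instead of SystemExit(1).
        logger.exception("bypass training failed")
        raise
```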

Suggestions:

  • The _copy_auto_map_code_files addition in checkpoint_utils_hf.py is a good fix for trust_remote_code models. Consider adding a brief unit test or at least a comment about which models require this (e.g., NemotronH).
  • The format_stitched_losses NaN filtering is a nice quality-of-life improvement for hybrid models. The import math inside the function body should be moved to the module top-level.
  • The dataset.py chat_template fallback is correct and handles base models gracefully.
  • The child_init.py fix (return item instead of return item_overrides when None) is a real bug fix — good catch.

Overall Assessment: This is a well-structured, substantial feature addition. The core architecture (stitched model factory, per-block KD, pipeline integration) is sound. However, the hardcoded trust_remote_code=True security issue and the duplicated normalized_mse_loss need to be addressed before merge. The torch.load calls should also use weights_only=True per project convention.

@Separius Separius requested a review from a team as a code owner April 2, 2026 13:23
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 8

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

811-843: ⚠️ Potential issue | 🟠 Major

null overrides still crash for optional nested dataclasses.

The new item_overrides is None branch is bypassed when previous_value is None and _is_dataclass_type(item_type) is true, so an override like ...: null still becomes _get_dataclass_type(item_type)(**item_overrides) and raises at runtime. This is easy to hit for optional sub-configs that default to None.

Suggested fix
-            if previous_value is None and _is_dataclass_type(item_type):
-                new_value = _get_dataclass_type(item_type)(**item_overrides)
+            if item_overrides is None:
+                new_value = previous_value
+            elif previous_value is None and _is_dataclass_type(item_type):
+                assert isinstance(item_overrides, dict)
+                new_value = _get_dataclass_type(item_type)(**item_overrides)
             else:
                 new_value = override(previous_value, item_overrides)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
811 - 843, The dataclass_override loop special-case instantiates a nested
dataclass even when the provided override is None, causing a crash; modify the
block in dataclass_override that handles "previous_value is None and
_is_dataclass_type(item_type)" to first check if item_overrides is None and in
that case set new_value = None (or call override(previous_value,
item_overrides)), otherwise instantiate with
_get_dataclass_type(item_type)(**item_overrides); keep the subsequent
check_type(new_value, item_type) and existing symbols (override,
dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so
optional nested dataclass overrides that are null no longer raise.
🧹 Nitpick comments (3)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)

96-115: Reject overlapping outputs from multiple pruning mixins.

layer_out_state_dict.update(_layer_out) silently makes the final checkpoint depend on mixin order if two mixins emit the same state-dict key. Failing fast here is safer than letting one mixin overwrite the other.

Suggested guard
-            layer_out_state_dict.update(_layer_out)
+            overlapping_keys = layer_out_state_dict.keys() & _layer_out.keys()
+            if overlapping_keys:
+                raise ValueError(
+                    f"Pruning mixins produced overlapping keys for layer {layer_idx}: "
+                    f"{sorted(overlapping_keys)}"
+                )
+            layer_out_state_dict.update(_layer_out)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines
96 - 115, The loop over pruning mixins currently does
layer_out_state_dict.update(_layer_out) which allows later mixins to silently
overwrite keys from earlier ones; change this to detect overlapping keys and
fail fast: for each _mixin when you get _layer_out from prune_single_layer,
compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys())
and if intersection is non-empty raise a ValueError (or AssertionError) listing
the conflicting keys and the mixin identity (use _mixin or its type) instead of
updating; only call layer_out_state_dict.update(_layer_out) when intersection is
empty to ensure deterministic, non-overlapping outputs from prune_single_layer
across mixins.
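
The fail-fast guard that prompt describes amounts to a set intersection before each update; a sketch with made-up mixin names and keys:

```python
# Sketch of the fail-fast guard for overlapping mixin outputs; the mixin
# names and state-dict keys here are made up for illustration.
def merge_layer_outputs(layer_outputs):
    """Merge per-mixin state dicts, rejecting overlapping keys."""
    merged = {}
    for mixin_name, layer_out in layer_outputs:
        overlap = merged.keys() & layer_out.keys()
        if overlap:
            raise ValueError(
                f"Pruning mixin {mixin_name} produced overlapping keys: "
                f"{sorted(overlap)}"
            )
        merged.update(layer_out)
    return merged

merged = merge_layer_outputs([
    ("ffn_mixin", {"mlp.weight": 1}),
    ("kv_mixin", {"attn.weight": 2}),
])
assert merged == {"mlp.weight": 1, "attn.weight": 2}
```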
modelopt/torch/puzzletron/tools/kd_model.py (1)

38-39: Add a zero/near-zero target regression test for this denominator change.

This adjustment mainly changes behavior when target has tiny norm, but tests/unit/torch/puzzletron/test_bypass_losses.py currently only covers random tensors. A focused zero-target case would keep this stabilization behavior from regressing unnoticed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/tools/kd_model.py` around lines 38 - 39, Add a unit
test in tests/unit/torch/puzzletron/test_bypass_losses.py (e.g.,
test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation
from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and
a near-zero target tensor, calls the code path that computes loss using the
expression containing F.mse_loss(input, target, reduction=reduction) /
(F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon),
and asserts the loss is finite and behaves stably (no division-by-zero, not
NaN/Inf) for both cases; use the same input tensor for both and check that
adding the epsilon in the denominator prevents regressions.
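
A pure-Python sketch of the epsilon-stabilized normalized MSE under discussion, showing why the zero-target case stays finite (the real implementation uses F.mse_loss on tensors):

```python
# Pure-Python sketch of the epsilon-stabilized normalized MSE from
# kd_model.py, illustrating the zero-target case the test should cover.
def normalized_mse(inputs, targets, epsilon=1e-8):
    mse = sum((i - t) ** 2 for i, t in zip(inputs, targets)) / len(inputs)
    target_energy = sum(t ** 2 for t in targets) / len(targets)
    # Without epsilon, an all-zero target would divide by zero.
    return mse / (target_energy + epsilon)

loss = normalized_mse([0.1, -0.2, 0.3], [0.0, 0.0, 0.0])
assert loss == loss            # finite, not NaN
assert loss != float("inf")    # epsilon prevented division by zero
```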
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

377-381: Consider more descriptive error handling for invalid block_loss_func.

If cfg.model_factory.block_loss_func is not one of the three supported values, a KeyError is raised with just the invalid key name. A more descriptive error would help users identify the misconfiguration quickly.

Suggested improvement
+    _BLOCK_LOSS_FUNCS = {
+        "normalized_mse_loss": normalized_mse_loss,
+        "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
+        "batched_normalized_mse_loss": batched_normalized_mse_loss,
+    }
+    loss_func_name = cfg.model_factory.block_loss_func
+    if loss_func_name not in _BLOCK_LOSS_FUNCS:
+        raise ValueError(
+            f"Unknown block_loss_func '{loss_func_name}'. "
+            f"Supported: {list(_BLOCK_LOSS_FUNCS.keys())}"
+        )
-    block_loss_func = {
-        "normalized_mse_loss": normalized_mse_loss,
-        "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
-        "batched_normalized_mse_loss": batched_normalized_mse_loss,
-    }[cfg.model_factory.block_loss_func]
+    block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 377 - 381, The current lookup for block_loss_func in
stitched_model_factory.py uses a direct dict index which raises an opaque
KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a
guarded lookup: retrieve via dict.get or check membership first and raise a
ValueError with a clear message that includes the invalid value and the allowed
options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss"); update the code around the block_loss_func
assignment (the dict and its use) so callers get a descriptive error instead of
a raw KeyError.
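
The guarded lookup can be sketched as below; the stand-in lambdas are placeholders for the real sewing_kit loss functions:

```python
# Sketch of the guarded lookup; the lambdas are stand-ins for the real
# sewing_kit loss implementations.
_BLOCK_LOSS_FUNCS = {
    "normalized_mse_loss": lambda i, t: 0.0,
    "vectorwise_normalized_mse_loss": lambda i, t: 0.0,
    "batched_normalized_mse_loss": lambda i, t: 0.0,
}

def resolve_block_loss_func(name):
    if name not in _BLOCK_LOSS_FUNCS:
        # A ValueError naming the options beats an opaque KeyError.
        raise ValueError(
            f"Unknown block_loss_func {name!r}. "
            f"Supported: {sorted(_BLOCK_LOSS_FUNCS)}"
        )
    return _BLOCK_LOSS_FUNCS[name]
```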
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 113-125: The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 597-600: Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
- Around line 417-428: The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 837-838: The code reads source_datasets_to_discard from cfg.bypass
root but the new config nests it under bypass.data; update the dataloader calls
that set source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.
- Around line 252-253: The parameter skip_first_batches is never applied: after
creating the batch iterator from the ConstantLengthDataset/dataloader you must
advance that iterator by skip_first_batches before entering the training loop
(e.g., consume the iterator with next(...) in a short loop or use
itertools.islice to drop the first N items); update the code paths where
skip_first_batches is accepted (the occurrences around skip_first_batches in
training_loop.py and the second occurrence at lines ~329-330) to consume the
iterator accordingly so resumed runs do not replay from batch 0.
- Around line 349-350: The loop exit condition uses a 1-based counter and
currently uses >=, causing it to stop one step too early; update the check in
the training loop that references cfg.bypass.step_num and
cfg.bypass.training.max_steps so it breaks only once step_num has passed the
budget (use > instead of >=) so the final scheduled step runs.
- Around line 103-107: The AutoConfig.from_pretrained call inside
run_bypassed_training bypasses the earlier trust_remote_code decision; update
the AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 209-215: The test incorrectly computes per-rank FFN counts from
hidden layer count (total_layers = max(2, size)); instead compute the actual
number of prunable FFN blocks (e.g., scan the model's layer names or modules to
count FFN/prunable blocks rather than using hidden-layer count) into
total_ffn_blocks, then compute layers_this_rank = total_ffn_blocks // size + (1
if rank < total_ffn_blocks % size else 0) and assert len(layer_names) ==
layers_this_rank (allowing 0 for ranks that only own Mamba blocks); update the
variables total_layers/layers_this_rank and reference layer_names when making
this change.
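
The skip_first_batches fix described for training_loop.py is the standard itertools consume recipe; a sketch with a hypothetical `resume_iterator` helper:

```python
from itertools import islice

# Sketch of the skip_first_batches fix: advance the batch iterator before
# entering the training loop so a resumed run does not replay consumed
# batches. `resume_iterator` is a hypothetical helper name.
def resume_iterator(batches, skip_first_batches):
    it = iter(batches)
    # The itertools "consume" recipe: islice lazily discards the first
    # N items without materializing them.
    next(islice(it, skip_first_batches, skip_first_batches), None)
    return it

remaining = list(resume_iterator(range(5), 2))
print(remaining)  # [2, 3, 4]
```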

---

Outside diff comments:
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 811-843: The dataclass_override loop special-case instantiates a
nested dataclass even when the provided override is None, causing a crash;
modify the block in dataclass_override that handles "previous_value is None and
_is_dataclass_type(item_type)" to first check if item_overrides is None and in
that case set new_value = None (or call override(previous_value,
item_overrides)), otherwise instantiate with
_get_dataclass_type(item_type)(**item_overrides); keep the subsequent
check_type(new_value, item_type) and existing symbols (override,
dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so
optional nested dataclass overrides that are null no longer raise.

---

Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func in
stitched_model_factory.py uses a direct dict index which raises an opaque
KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a
guarded lookup: retrieve via dict.get or check membership first and raise a
ValueError with a clear message that includes the invalid value and the allowed
options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss"); update the code around the block_loss_func
assignment (the dict and its use) so callers get a descriptive error instead of
a raw KeyError.

In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 96-115: The loop over pruning mixins currently does
layer_out_state_dict.update(_layer_out) which allows later mixins to silently
overwrite keys from earlier ones; change this to detect overlapping keys and
fail fast: for each _mixin when you get _layer_out from prune_single_layer,
compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys())
and if intersection is non-empty raise a ValueError (or AssertionError) listing
the conflicting keys and the mixin identity (use _mixin or its type) instead of
updating; only call layer_out_state_dict.update(_layer_out) when intersection is
empty to ensure deterministic, non-overlapping outputs from prune_single_layer
across mixins.

In `@modelopt/torch/puzzletron/tools/kd_model.py`:
- Around line 38-39: Add a unit test in
tests/unit/torch/puzzletron/test_bypass_losses.py (e.g.,
test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation
from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and
a near-zero target tensor, calls the code path that computes loss using the
expression containing F.mse_loss(input, target, reduction=reduction) /
(F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon),
and asserts the loss is finite and behaves stably (no division-by-zero, not
NaN/Inf) for both cases; use the same input tensor for both and check that
adding the epsilon in the denominator prevents regressions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 6cb35ef5-41ea-4f6d-990a-791e2c99b812

📥 Commits

Reviewing files that changed from the base of the PR and between e018ca0 and 2b99327.

📒 Files selected for processing (90)
  • examples/puzzletron/BYPASS.md
  • examples/puzzletron/README.md
  • examples/puzzletron/configs/bypass/defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/pruning/defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/validate_model_defaults.yaml
  • examples/puzzletron/configs/validate_solutions_defaults.yaml
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/bypass_distillation/training_loop.py
  • modelopt/torch/puzzletron/dataset/prepare_dataset.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • modelopt/torch/puzzletron/tools/kd_model.py
  • modelopt/torch/puzzletron/utils/data/dataloaders.py
  • modelopt/torch/puzzletron/utils/parsing.py
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • tests/_test_utils/torch/puzzletron/utils.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
  • tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
  • tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
  • tests/gpu/torch/puzzletron/test_bypass.py
  • tests/gpu/torch/puzzletron/test_puzzletron.py
  • tests/unit/torch/puzzletron/__init__.py
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • tests/unit/torch/puzzletron/test_bypass_utils.py
✅ Files skipped from review due to trivial changes (67)
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
  • tests/unit/torch/puzzletron/__init__.py
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
  • modelopt/torch/puzzletron/dataset/prepare_dataset.py
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
  • modelopt/torch/utils/plugins/transformers_dataset.py
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
  • tests/_test_utils/torch/puzzletron/utils.py
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
  • tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
  • examples/puzzletron/configs/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
  • examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/README.md
  • modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
  • examples/puzzletron/configs/validate_model_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
  • modelopt/torch/puzzletron/bypass_distillation/data_classes.py
  • modelopt/torch/puzzletron/bypass_distillation/__init__.py
  • examples/puzzletron/configs/bypass/defaults.yaml
  • examples/puzzletron/BYPASS.md
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
  • tests/unit/torch/puzzletron/test_bypass_losses.py
  • examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
  • examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
  • examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
🚧 Files skipped from review as they are similar to previous changes (6)
  • examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
  • modelopt/torch/puzzletron/utils/parsing.py
  • modelopt/torch/puzzletron/pruning/pruning_utils.py
  • modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
  • modelopt/torch/puzzletron/sewing_kit/utils.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py

Comment on lines +597 to +600
mprint(
f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
f"({sum(p.numel() for p in trainable_params.values()):,} params)"
)
Contributor

⚠️ Potential issue | 🟡 Minor

The log message will always show an empty block name.

submodule_name is initialized to "" at line 449 and never reassigned within the loop. The log message "Block : ..." will always display an empty block name. Consider using student_stitched_module_name (e.g., block_0) or module_name for clarity.

Suggested fix
             mprint(
-                f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
+                f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
                 f"({sum(p.numel() for p in trainable_params.values()):,} params)"
             )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
mprint(
f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
f"({sum(p.numel() for p in trainable_params.values()):,} params)"
)
mprint(
f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
f"({sum(p.numel() for p in trainable_params.values()):,} params)"
)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 597 - 600, Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.

Comment on lines +103 to +107
if do_ffn or do_attn or do_blk:
from transformers import AutoConfig as HFAutoConfig

teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
teacher_intermediate_size = getattr(teacher_hf_cfg, "intermediate_size", None)
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '95,120p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1567


🏁 Script executed:

rg "requires_trust_remote_code|trust_remote_code" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 614


🏁 Script executed:

rg "ModelDescriptorFactory" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 539


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '1,110p'

Repository: NVIDIA/Model-Optimizer

Length of output: 5786


🏁 Script executed:

rg "def requires_trust_remote_code" modelopt/ -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 1441


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '73,250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 9638


🏁 Script executed:

rg "def launch_bypass_distillation" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 50

Repository: NVIDIA/Model-Optimizer

Length of output: 2748


🏁 Script executed:

rg "hydra_cfg.descriptor|cfg.descriptor" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 344


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '240,300p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2827


🏁 Script executed:

rg "descriptor\s*=" modelopt/torch/puzzletron/bypass_distillation/training_loop.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 558


🏁 Script executed:

cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '290,360p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3106


🏁 Script executed:

rg "def run_bypassed_training" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 30

Repository: NVIDIA/Model-Optimizer

Length of output: 1259


Thread trust_remote_code through the auto-config probe.

run_bypassed_training() queries the descriptor for trust_remote_code, but this auto-config path at lines 103–107 bypasses that and calls AutoConfig.from_pretrained() with default (unsafe) behavior. Models requiring remote code execution will fail inconsistently depending on which path loads them.

💡 Suggested fix
         if do_ffn or do_attn or do_blk:
             from transformers import AutoConfig as HFAutoConfig
 
-            teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
+            descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
+            trust_remote_code = descriptor.requires_trust_remote_code()
+            teacher_hf_cfg = HFAutoConfig.from_pretrained(
+                str(hydra_cfg.teacher_dir),
+                trust_remote_code=trust_remote_code,
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
103 - 107, The AutoConfig.from_pretrained call inside run_bypassed_training
bypasses the earlier trust_remote_code decision; update the
AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.
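The single-decision pattern this fix describes can be exercised without transformers installed. In the sketch below, `load_config` is a stand-in for `AutoConfig.from_pretrained` and the descriptor is faked; apart from `requires_trust_remote_code` (named in the suggested diff), every identifier is hypothetical:

```python
# Sketch: route every config probe through one trust_remote_code decision.
# `load_config` stands in for transformers' AutoConfig.from_pretrained;
# FakeDescriptor mimics what a ModelDescriptorFactory lookup might return.
def load_teacher_config(teacher_dir, descriptor, load_config):
    trust = descriptor.requires_trust_remote_code()
    return load_config(str(teacher_dir), trust_remote_code=trust)

class FakeDescriptor:
    """Hypothetical descriptor for models that ship custom modeling code."""
    def requires_trust_remote_code(self):
        return True

captured = {}
def fake_load(path, trust_remote_code):
    captured["path"], captured["trust"] = path, trust_remote_code
    return {"intermediate_size": 11008}  # minimal stand-in for an HF config

cfg = load_teacher_config("teacher_dir", FakeDescriptor(), fake_load)
```

With this shape, the auto-config probe and `run_bypassed_training` would consult the same descriptor, which is the consistency the review is asking for.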

Comment on lines 209 to 215
# The test model has num_hidden_layers = max(2, size), so every rank owns at least
# one layer. Compute the actual expected count for *this* rank.
total_layers = max(2, size)
layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
assert len(layer_names) == layers_this_rank, (
f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
)
Contributor

⚠️ Potential issue | 🟠 Major

The per-rank FFN count is still wrong for hybrid models.

total_layers = max(2, size) counts hidden layers, not prunable FFN blocks. This file already documents nvidia/NVIDIA-Nemotron-Nano-12B-v2 as having only one FFN layer, so a rank that owns only Mamba blocks can legitimately have len(layer_names) == 0.

💡 Suggested fix
-        total_layers = max(2, size)
-        layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
-        assert len(layer_names) == layers_this_rank, (
-            f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
-        )
+        total_layers = max(2, size)
+        if len(expected) == total_layers:
+            layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
+            assert len(layer_names) == layers_this_rank, (
+                f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 209 - 215, The
test incorrectly computes per-rank FFN counts from hidden layer count
(total_layers = max(2, size)); instead compute the actual number of prunable FFN
blocks (e.g., scan the model's layer names or modules to count FFN/prunable
blocks rather than using hidden-layer count) into total_ffn_blocks, then compute
layers_this_rank = total_ffn_blocks // size + (1 if rank < total_ffn_blocks %
size else 0) and assert len(layer_names) == layers_this_rank (allowing 0 for
ranks that only own Mamba blocks); update the variables
total_layers/layers_this_rank and reference layer_names when making this change.
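The even-split arithmetic in the suggested fix is easy to check standalone. This pure-Python sketch (the function name is illustrative, not from the repo) shows the formula partitions any block count, including the hybrid case where a rank owns zero FFN blocks:

```python
def layers_on_rank(total_blocks: int, rank: int, world_size: int) -> int:
    """Blocks owned by `rank` when `total_blocks` are split as evenly as
    possible: the first `total_blocks % world_size` ranks get one extra."""
    base, extra = divmod(total_blocks, world_size)
    return base + (1 if rank < extra else 0)

# Every block is owned by exactly one rank, counts differ by at most 1,
# and a rank can own zero blocks (e.g. one FFN layer across four ranks).
for total in (1, 2, 7, 12):
    for size in (1, 2, 4):
        counts = [layers_on_rank(total, r, size) for r in range(size)]
        assert sum(counts) == total
        assert max(counts) - min(counts) <= 1
```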

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/puzzletron/main.py (1)

102-135: ⚠️ Potential issue | 🟠 Major

Ensure distributed cleanup runs on failure paths.

If convert/search/sweep/MIP raises, dist.cleanup() is skipped. In multi-GPU flows this can leave process groups hanging. Wrap execution in try/finally in both runners.

Proposed fix
 def run_full_puzzletron(hydra_config_path: str):
@@
     hydra_cfg, hydra_config_dir, hydra_config_name, n = _setup(hydra_config_path)
 
     mprint(f"Puzzletron Progress 1/{n}: starting puzzletron pipeline")
-
-    # Convert model (convert from HF to DeciLM, score pruning activations,
-    # prune the model and save pruned checkpoints)
-    input_model = PuzzletronModel()
-    converted_model = mtn.convert(
-        input_model,
-        mode=[
-            (
-                "puzzletron",
-                {
-                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
-                    "input_model_path": hydra_cfg.input_hf_model_path,
-                    "hydra_config_dir": hydra_config_dir,
-                    "hydra_config_name": hydra_config_name,
-                    "dataset_path": str(hydra_cfg.dataset_path),
-                },
-            )
-        ],
-    )
-
-    # Run NAS search (build replacement library and compute stats,
-    # compute one block scores, run MIP and realize models)
-    mtn.search(
-        converted_model,
-        constraints={},  # this is not used as the search space is defined in the hydra config
-        dummy_input=None,  # Not used
-        config={},  # this is not used as the search space is defined in the hydra config
-    )
-
-    dist.cleanup()
+    try:
+        # Convert model (convert from HF to DeciLM, score pruning activations,
+        # prune the model and save pruned checkpoints)
+        input_model = PuzzletronModel()
+        converted_model = mtn.convert(
+            input_model,
+            mode=[
+                (
+                    "puzzletron",
+                    {
+                        "puzzle_dir": str(hydra_cfg.puzzle_dir),
+                        "input_model_path": hydra_cfg.input_hf_model_path,
+                        "hydra_config_dir": hydra_config_dir,
+                        "hydra_config_name": hydra_config_name,
+                        "dataset_path": str(hydra_cfg.dataset_path),
+                    },
+                )
+            ],
+        )
+
+        # Run NAS search (build replacement library and compute stats,
+        # compute one block scores, run MIP and realize models)
+        mtn.search(
+            converted_model,
+            constraints={},  # this is not used as the search space is defined in the hydra config
+            dummy_input=None,  # Not used
+            config={},  # this is not used as the search space is defined in the hydra config
+        )
+    finally:
+        dist.cleanup()
     mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")
@@
 def run_mip_only(hydra_config_path: str):
@@
-    # Check if sweep mode is enabled
-    if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
-        mprint(
-            f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
-        )
-        sweep.run_mip_sweep(hydra_cfg)
-    else:
-        # mip_and_realize_models (distributed processing)
-        # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
-        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
-
-    dist.cleanup()
+    try:
+        # Check if sweep mode is enabled
+        if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
+            mprint(
+                f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
+            )
+            sweep.run_mip_sweep(hydra_cfg)
+        else:
+            # mip_and_realize_models (distributed processing)
+            # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
+            mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
+            mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+    finally:
+        dist.cleanup()
     mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")

Also applies to: 147-163

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 102 - 135, The current flow calls
dist.cleanup() after running mtn.convert and mtn.search but if
mtn.convert/mtn.search (or any subsequent step) raises an exception the cleanup
is skipped; wrap the multi-GPU pipeline (from _setup through
mtn.search/mtn.sweep/mtn.MIP calls around lines that create PuzzletronModel,
call mtn.convert and mtn.search) in a try/finally block so dist.cleanup() always
runs, and apply the same try/finally pattern to the other runner block
referenced around lines 147-163; ensure the try encompasses all work that
requires the distributed group and the finally calls dist.cleanup()
unconditionally.
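The guarantee the try/finally rewrite buys can be demonstrated with stand-ins (the stage and cleanup callables below are hypothetical; the real code wraps mtn.convert/mtn.search and calls dist.cleanup()):

```python
def run_pipeline(stages, cleanup):
    """Run stages in order; cleanup always runs, even if a stage raises."""
    try:
        for stage in stages:
            stage()
    finally:
        cleanup()

calls = []

def ok_stage():
    calls.append("convert")

def failing_stage():
    raise RuntimeError("search blew up")

try:
    run_pipeline([ok_stage, failing_stage], cleanup=lambda: calls.append("cleanup"))
except RuntimeError:
    pass  # the exception still propagates, but cleanup has already run
```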
🧹 Nitpick comments (1)
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)

377-381: Consider clearer error handling for unknown block_loss_func.

If cfg.model_factory.block_loss_func is not one of the three expected values, a KeyError is raised with a cryptic message. A more informative error would help users diagnose configuration issues.

♻️ Suggested improvement
-    block_loss_func = {
+    _BLOCK_LOSS_FUNCS = {
         "normalized_mse_loss": normalized_mse_loss,
         "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
         "batched_normalized_mse_loss": batched_normalized_mse_loss,
-    }[cfg.model_factory.block_loss_func]
+    }
+    loss_func_name = cfg.model_factory.block_loss_func
+    if loss_func_name not in _BLOCK_LOSS_FUNCS:
+        raise ValueError(
+            f"Unknown block_loss_func '{loss_func_name}'. "
+            f"Expected one of: {list(_BLOCK_LOSS_FUNCS.keys())}"
+        )
+    block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 377 - 381, The current lookup for block_loss_func using a dict
keyed by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update
the code around block_loss_func (in stitched_model_factory.py) to explicitly
validate cfg.model_factory.block_loss_func against the allowed names
("normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss") and raise a clear ValueError that includes the
invalid value and the list of valid options; reference the existing functions
normalized_mse_loss, vectorwise_normalized_mse_loss, and
batched_normalized_mse_loss when constructing the mapping and error message so
users can quickly see the supported choices.
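For context on the loss names in this mapping: a plausible definition (not verified against the repo's sewing_kit/utils.py) is the student-teacher MSE divided by the teacher's own energy, with the variants differing in whether the whole tensor, each vector, or each batch element gets its own normalizer. A pure-Python sketch of the first two under that assumption:

```python
def normalized_mse_loss(student, teacher, eps=1e-8):
    # ||s - t||^2 / (||t||^2 + eps): scale-invariant, so blocks with
    # small activations weigh the same as blocks with large ones.
    num = sum((s - t) ** 2 for s, t in zip(student, teacher))
    den = sum(t * t for t in teacher) + eps
    return num / den

def vectorwise_normalized_mse_loss(student, teacher, eps=1e-8):
    # Normalize each row (e.g. each token's hidden state) separately,
    # then average, so low-energy rows are not drowned out.
    rows = [normalized_mse_loss(s, t, eps) for s, t in zip(student, teacher)]
    return sum(rows) / len(rows)
```

Under this definition a perfect student scores 0 and a student that outputs zeros scores about 1, which keeps values comparable across blocks when ranking puzzle pieces.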
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/puzzletron/main.py`:
- Around line 99-101: The docstring parameter name is incorrect — replace the
documented `config_path` with the actual function parameter `hydra_config_path`
and update its description to match (e.g., "Path to the YAML configuration
file") so the `hydra_config_path` argument in the function signature and the
docstring are consistent; locate the docstring in examples/puzzletron/main.py
near the function that accepts `hydra_config_path` and make this single-name
correction.

---

Outside diff comments:
In `@examples/puzzletron/main.py`:
- Around line 102-135: The current flow calls dist.cleanup() after running
mtn.convert and mtn.search but if mtn.convert/mtn.search (or any subsequent
step) raises an exception the cleanup is skipped; wrap the multi-GPU pipeline
(from _setup through mtn.search/mtn.sweep/mtn.MIP calls around lines that create
PuzzletronModel, call mtn.convert and mtn.search) in a try/finally block so
dist.cleanup() always runs, and apply the same try/finally pattern to the other
runner block referenced around lines 147-163; ensure the try encompasses all
work that requires the distributed group and the finally calls dist.cleanup()
unconditionally.

---

Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func using a dict keyed
by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update the
code around block_loss_func (in stitched_model_factory.py) to explicitly
validate cfg.model_factory.block_loss_func against the allowed names
("normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss") and raise a clear ValueError that includes the
invalid value and the list of valid options; reference the existing functions
normalized_mse_loss, vectorwise_normalized_mse_loss, and
batched_normalized_mse_loss when constructing the mapping and error message so
users can quickly see the supported choices.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 99826f98-f6fb-41d2-a78c-5c40bec6c4c9

📥 Commits

Reviewing files that changed from the base of the PR and between 351b44e and 346408b.

📒 Files selected for processing (3)
  • examples/puzzletron/main.py
  • modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py

Comment on lines +99 to +101
Args:
config_path: Path to the YAML configuration file
"""
Contributor

⚠️ Potential issue | 🟡 Minor

Fix docstring argument name mismatch.

Line 100 documents config_path, but the function argument is hydra_config_path. Please align the docstring to avoid confusion.

Proposed fix
 def run_full_puzzletron(hydra_config_path: str):
@@
     Args:
-        config_path: Path to the YAML configuration file
+        hydra_config_path: Path to the YAML configuration file
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
Args:
config_path: Path to the YAML configuration file
"""
Args:
hydra_config_path: Path to the YAML configuration file
"""
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/puzzletron/main.py` around lines 99 - 101, The docstring parameter
name is incorrect — replace the documented `config_path` with the actual
function parameter `hydra_config_path` and update its description to match
(e.g., "Path to the YAML configuration file") so the `hydra_config_path`
argument in the function signature and the docstring are consistent; locate the
docstring in examples/puzzletron/main.py near the function that accepts
`hydra_config_path` and make this single-name correction.

@Separius
Contributor Author

Separius commented Apr 2, 2026

@cjluo-nv addressed all the points (thanks again for the great review)

Base automatically changed from feature/puzzletron to main April 15, 2026 19:18
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners April 15, 2026 19:18
Contributor

There are 250+ files in this review, many of which don't need to change; could you please clean up the MR?

Contributor Author

The 250+ files are probably because feature/puzzletron (the base branch) was deleted; I'll try to open another MR against main later this week.

Contributor

No need for another MR; just merge main into your feature branch and push it, and many of the changes should disappear.

Contributor

Only 76 files now 😅

This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model:

1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and hybrid Transformer-Mamba) models in the NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. It uses activation magnitudes to prune the embedding hidden size; MLP FFN hidden size; transformer attention heads; Mamba heads and head dimension; MoE number of experts, FFN hidden size, and shared-expert intermediate size; and number of layers of the model.
1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA that uses a Mixed Integer Programming (MIP) based NAS search algorithm.
Contributor

please add a tutorial for bypass

Contributor

also, please make the tests pass

@Separius Separius force-pushed the ssameni/puzzletron-bypass branch 2 times, most recently from b43af72 to aa4b534 Compare May 6, 2026 08:41
@codecov

codecov Bot commented May 6, 2026

Codecov Report

❌ Patch coverage is 32.79352% with 664 lines in your changes missing coverage. Please review.
✅ Project coverage is 66.75%. Comparing base (555be6c) to head (88792f4).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...ch/puzzletron/bypass_distillation/training_loop.py | 16.89% | 359 Missing ⚠️ |
| ...tron/bypass_distillation/stitched_model_factory.py | 39.00% | 122 Missing ⚠️ |
| modelopt/torch/puzzletron/puzzletron_nas_plugin.py | 17.85% | 69 Missing ⚠️ |
| ...ron/bypass_distillation/bypass_checkpoint_utils.py | 78.26% | 20 Missing ⚠️ |
| ...rch/puzzletron/bypass_distillation/bypass_utils.py | 39.39% | 20 Missing ⚠️ |
| modelopt/torch/puzzletron/pruning/pruning_utils.py | 9.52% | 19 Missing ⚠️ |
| ...lopt/torch/puzzletron/tools/checkpoint_utils_hf.py | 30.76% | 18 Missing ⚠️ |
| ...odelopt/torch/puzzletron/utils/data/dataloaders.py | 23.07% | 10 Missing ⚠️ |
| modelopt/torch/puzzletron/utils/parsing.py | 12.50% | 7 Missing ⚠️ |
| ...n/replacement_library/build_replacement_library.py | 0.00% | 5 Missing ⚠️ |
| ... and 7 more | | |
Additional details and impacted files
```
@@             Coverage Diff             @@
##             main    #1111       +/-   ##
===========================================
- Coverage   76.74%   66.75%   -10.00%
===========================================
  Files         476      483        +7
  Lines       51307    52304      +997
===========================================
- Hits        39377    34915     -4462
- Misses      11930    17389     +5459
```
| Flag | Coverage Δ |
|---|---|
| unit | 52.54% <32.79%> (+<0.01%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@Separius Separius force-pushed the ssameni/puzzletron-bypass branch from aa4b534 to c1a464d Compare May 6, 2026 08:57
Introduces blockwise local distillation (BLD) as an optional 5th step
between pruning and replacement-library construction. When `bypass:` is
present in the Hydra config, puzzletron trains alternative transformer
block configurations against the teacher via per-block knowledge
distillation, then surfaces the trained subblocks to MIP through symlinks
under `puzzle_dir/ckpts/<experiment_id>`.

New module `modelopt/torch/puzzletron/bypass_distillation/`:
- training_loop.py: pipeline-parallel KD loop with cosine LR schedule,
  per-block AdamW + GradScaler, validation, and time/step-based
  checkpointing.
- stitched_model_factory.py: unified factory (FFN/attention/MoE/Mamba/
  whole-block) driven by `mlp_init_mode` x `keys_to_learn`; composes
  multiple pruning mixins (experts_removal + kv_heads + ffn_intermediate)
  when student/teacher configs differ along multiple axes.
- bypass_checkpoint_utils.py: stitched-module state save/load with
  `latest` symlink and `saving_completed` marker; resume scans only
  plain `iter-NNNNNN-ckpt` directories.
- bypass_utils.py: experiment_id derivation encoding every override
  axis (FFN+KV / experts+KV / KV-only) so sweeps cannot collide.
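
  As a hypothetical illustration (the real derivation lives in bypass_utils.py; names here are made up), a collision-free experiment id can be built by encoding every non-None override axis into a sorted, stable string:

  ```python
  def derive_experiment_id(overrides: dict) -> str:
      """Encode every non-None override axis into a stable id.

      Hypothetical sketch of the idea described above: keys are sorted so
      the same config always maps to the same id, and any axis that
      differs produces a distinct id, so sweeps cannot collide.
      """
      parts = [
          f"{key}_{overrides[key]}"
          for key in sorted(overrides)
          if overrides[key] is not None
      ]
      return "--".join(parts) or "teacher"
  ```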

Pipeline integration (puzzletron_nas_plugin.py):
- _progress_step(hydra_cfg, stage) helper + canonical stage order
  produce coherent `Puzzletron Progress N/T` strings; total grows from
  8 to 9 when bypass is configured.
- Skip-if-done caching for every stage (convert / score / prune /
  bypass / library build).
- Stale-library detection: any `ckpts/*` entry newer than
  `replacement_library.json` triggers a rebuild so post-bypass weights
  are picked up automatically.
- build_replacement_library orders bypass-trained subblocks before
  Truncate-init variants so `drop_duplicates(keep="first")` is
  deterministic.
- Auto-download HF model when input_hf_model_path is not a local dir.
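
The stale-library rule above boils down to an mtime comparison; a minimal sketch under the paths named in this commit message (not the exact plugin code):

```python
import os
from pathlib import Path


def library_is_stale(puzzle_dir: Path) -> bool:
    """Return True if any ckpts/* entry is newer than replacement_library.json.

    Illustrative sketch of the staleness check described above; the file
    and directory names follow the commit message.
    """
    library = puzzle_dir / "replacement_library.json"
    if not library.exists():
        return True  # never built
    library_mtime = library.stat().st_mtime
    ckpts = puzzle_dir / "ckpts"
    if not ckpts.exists():
        return False
    for entry in ckpts.iterdir():
        try:
            if entry.stat().st_mtime > library_mtime:
                return True  # post-bypass weights are newer; rebuild
        except OSError:
            continue  # e.g. dangling symlink
    return False
```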

KV-heads pruning generalized beyond Llama:
- KVHeadsLayerDescriptor registered for GptOss, NemotronH,
  NemotronH-V2, Qwen3-VL.
- _lm_attrs() probes text_config and language_config so VL configs
  (Qwen3-VL, Llava, Llama-4) read num_attention_heads and head_dim
  from the right sub-config.
- _init_attention_biases falls back to state-dict probing when a
  config (e.g. GptOssConfig) doesn't expose o_proj_bias /
  attention_bias as top-level attributes.

MIP: new target_num_kv_heads constraint over stats.num_kv_heads
for KV-cache-only sweeps.

Tools:
- _copy_auto_map_code_files: copies custom modeling_*.py files
  alongside config.json so trust-remote-code models reload
  correctly; identifier-shape guard rejects malformed auto_map
  entries.
- tools/robust_json.py: JSON encoder for dataclasses, paths, enums,
  Namespaces, OmegaConf nodes, functions/classes, and timedeltas;
  used by bypass to serialize resume state.
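
Assuming HF-style auto_map entries of the form "modeling_foo.ClassName" (where the part before the dot names a Python file next to config.json), the copy step might look like the following sketch — illustrative, not the exact tool:

```python
import json
import shutil
from pathlib import Path


def copy_auto_map_code_files(src_dir: Path, dst_dir: Path) -> list:
    """Copy custom modeling *.py files referenced by config.json's auto_map.

    Sketch of the behavior described above. Non-identifier module names
    are skipped, mirroring the identifier-shape guard that rejects
    malformed auto_map entries.
    """
    config = json.loads((src_dir / "config.json").read_text())
    copied = []
    for value in config.get("auto_map", {}).values():
        module_name = value.split(".")[0]
        if not module_name.isidentifier():
            continue  # reject malformed auto_map entries
        src = src_dir / f"{module_name}.py"
        if src.exists():
            shutil.copy2(src, dst_dir / src.name)
            copied.append(src.name)
    return copied
```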

Multi-mixin and bug fixes in legacy paths:
- child_init._process_single_layer accepts a single mixin or a list,
  enabling experts_removal + kv_heads + ffn_intermediate stacking.
- update_model_config.override returns the original `item` (not
  None) when an override is None — preserves the original value
  instead of clobbering it.

Data utilities:
- create_train_dataloader for the infinite ConstantLengthDataset
  used by bypass training.
- ConstantLengthDataset falls back to newline-joined message
  contents when the tokenizer has no chat_template (base models).
- format_stitched_losses filters NaN entries from no-op blocks.
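
The chat_template fallback can be sketched as below (a hypothetical standalone helper; the real logic sits inside ConstantLengthDataset, and `tokenizer` only needs the two attributes used here):

```python
def render_conversation(tokenizer, messages):
    """Render chat messages to text, falling back for base models.

    Sketch of the fallback described above: use apply_chat_template when
    the tokenizer carries a chat_template, otherwise join the raw message
    contents with newlines.
    """
    if getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(messages, tokenize=False)
    # base model: no template, just concatenate the raw contents
    return "\n".join(m["content"] for m in messages)
```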

Tutorial + configs:
- examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md:
  end-to-end KV-heads-only demonstration on a real 30B-A3B MoE-Mamba
  teacher, showing bypass closes ~74% of the lm_loss regression gap.
- Per-family configs under configs/nemotron-3-nano-30b-a3b/.
- Generic bypass/defaults.yaml template under llama-3_1-8B.

Tests:
- Unit: normalized-MSE losses; get_distributed_modules_ownership.
- GPU: test_bypass.py parametrizes block-pruning, KV-head
  compression, multi-config sweep, and checkpoint-contents tests
  across 9 model families (extracted into PUZZLETRON_FAMILIES).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
@Separius Separius force-pushed the ssameni/puzzletron-bypass branch from c1a464d to 691f4f6 Compare May 6, 2026 10:19
HF datasets uses lazy __getattr__ at the package level (PEP 562), so mypy
can't resolve top-level names — `from datasets import DatasetDict` fails
with attr-defined. Switch to submodule imports (e.g. `from
datasets.dataset_dict import DatasetDict`, `from datasets.load import
load_dataset`) which bypass the lazy loader.

Also folds in pre-commit cleanups across the bypass changeset:
- ruff E501/N806/PT006 lint fixes (uppercase N in main.py, line length
  in PUZZLETRON_FAMILIES + main.py MIP-sweep mprint, parametrize tuple
  shape in test_bypass_utils.py).
- markdownlint MD040 (fenced-code language tag in tutorial md).
- ruff format auto-applied (PUZZLETRON_FAMILIES table, descriptor
  imports, etc.).
- yamlfmt auto-applied to bypass YAML configs.
- Drop dead StitchedModelFactoryFn type alias and its cast() call.
- Annotate `descriptor` as ModelDescriptor (not type[ModelDescriptor])
  to match the codebase convention used in init_child_from_parent.py
  (annotated as instance, called with the class — no-op at runtime,
  silences mypy).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
@github-actions
Contributor

github-actions Bot commented May 6, 2026

PR Preview Action v1.8.1


🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1111/

Built to branch gh-pages at 2026-05-07 13:48 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Separius added 2 commits May 6, 2026 04:56
…rounds

Bypass-distillation review fixes:
- bypass_checkpoint_utils: persist + restore GradScaler state alongside
  optimizer state so resumed runs with use_grad_scaling=True don't lose
  the running scale + growth tracker.
- stitched_model_factory: raise a clear RuntimeError when world_size
  exceeds num_hidden_layers (would otherwise crash trailing ranks with
  a bare `min() arg is an empty sequence`); same condition fires
  identically on every rank, so no NCCL hang.
- training_loop: actually apply skip_first_batches by advancing the
  data iterator before the loop (parameter was previously accepted
  but unused).
- training_loop: fix off-by-one in the max_steps exit condition (was
  >=, now >) so the final scheduled step actually runs; same fix
  applied to the in-loop save-when-done branch so the final checkpoint
  is saved exactly once.
- training_loop: read source_datasets_to_discard from cfg.bypass.data
  (where the YAML actually nests it) instead of cfg.bypass root, where
  it always fell back to the empty tuple.
- dataloaders: reject num_workers > 0 in create_train_dataloader with
  an explicit error since ConstantLengthDataset.__iter__ does not shard
  via torch.utils.data.get_worker_info(); guard removable once the
  dataset gains worker-aware iteration.
- dataloaders: branch on isinstance(train_data, datasets.IterableDataset)
  before passing keep_in_memory=True to .shuffle(); streaming mode
  (load_from_disk=false) doesn't accept that kwarg.
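
The skip_first_batches fix above amounts to draining the iterator before the loop starts; a minimal stand-in with hypothetical names (not the actual training_loop code):

```python
from itertools import islice


def make_resumed_iterator(dataloader, skip_first_batches):
    """Advance a (possibly infinite) batch iterator past consumed batches.

    Sketch of the resume behavior described above; `dataloader` is any
    iterable of batches, and islice consumes exactly `skip_first_batches`
    items before training resumes.
    """
    it = iter(dataloader)
    for _ in islice(it, skip_first_batches):
        pass  # discard already-seen batches
    return it
```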

Revert the datasets submodule-import workarounds (`from datasets.load
import load_dataset`, etc.) — CI's mypy resolves top-level `from datasets
import X` correctly because it installs the package via uv, so these
workarounds were only needed for the local-pre-commit env on a node
without datasets installed. Reverting shrinks the MR back to bypass-only
files.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
- Drop the `gqa_factory_fn` / `moe_factory_fn` backward-compat aliases
  in stitched_model_factory.py — no remaining call sites or YAML configs
  reference them after the unified `bypass_factory_fn` migration.
- Apply ruff-format reflow to the two `source_datasets_to_discard=...`
  kwargs in training_loop.py (collapse the train-side call onto one
  line; re-indent the val-side multi-line call body).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Contributor

I'm assuming not feasible to reuse boilerplate from transformers Trainer?

Contributor Author

I don't think so — PP via SewingKit makes everything complicated/unique. I have an implementation via AutoModel; once this MR lands, I can fully compare the two and we can decide whether to merge it. It's a bit complicated, but less so than SewingKit.

Separius added 9 commits May 6, 2026 08:42
Test additions
--------------

Three new test files plus three new tests appended to test_bypass.py
covering paths the existing 4×9-family happy-path tests didn't reach:

- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py — pins all 8
  branches of `_set_keys_to_learn` (subblock_ffn / subblock_attention /
  subblock_mamba / entire_block / list / regex / hybrid block_configs
  filter / no-match silent-return). The hybrid Mamba-vs-GQA filter is
  silently misroutable on descriptor refactors; this test catches it.

- tests/unit/torch/puzzletron/test_bypass_replacement_library.py —
  verifies `_get_last_checkpoint_from_each_experiment` discovers
  symlinked bypass + pruning checkpoints, and that the bypass-priority
  sort closure orders bypass-rooted paths before Truncate-init ones (a
  regression here would silently discard bypass-trained weights).

- tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py — save/load
  round-trip for stitched-module state, optimizer state, and (new)
  GradScaler state. The GradScaler round-trip is the regression test for
  the recent CodeRabbit-flagged bug where resumed fp16 + use_grad_scaling
  runs lost the running scale + growth tracker. Lives under tests/gpu/
  rather than tests/unit/ because the production load path constructs
  `torch.device(f"cuda:{rank}")` and torch.load needs a real CUDA device
  to deserialize.

- tests/gpu/torch/puzzletron/test_bypass.py — three new tests:
  * test_bypass_resume_from_checkpoint: 2-phase train→save→resume on
    Llama-3.2-3B, asserts `iter_num` advances past the saved value.
    GradScaler resume is covered separately at unit level (above)
    because GradScaler.step() is fp16-only and the bypass infra is bf16.
  * test_bypass_subblock_modes: parametrized over Llama-3.2-3B (dense
    Truncate path) × GPT-OSS-20B (MoE ExpertRemoval + windowed-attention-
    with-sinks path) × {subblock_ffn, subblock_attention, entire_block};
    diffs start-of-training vs end-of-training stitched-module weights
    and asserts only the expected param groups changed.
  * test_bypass_then_build_library: end-to-end smoke — runs bypass, then
    `_build_subblocks_df`, asserts the bypass experiment appears in the
    resulting subblocks DataFrame's checkpoint-source columns.

Review-driven fixes
-------------------

- bypass_distillation/training_loop.py:
  * AutoTokenizer.from_pretrained: pass `trust_remote_code=trust_remote_code`
    (was hardcoded `True`); the variable is already derived from the
    descriptor a few lines earlier.
  * Wrap-around `try/except Exception`: re-raise (was `sys.exit(1)`) so
    pytest sees the real exception type instead of a generic SystemExit,
    and so distributed runs surface usable tracebacks.

- sewing_kit/utils.py: `normalized_mse_loss` is now re-exported from
  `tools.kd_model` instead of redefined; the two implementations were
  byte-for-byte identical.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
- test_bypass.py: drop unused `block_name` (F841); ruff-format reflows
  for `_test_bypass_then_build_library_job`'s `_setup_hydra_cfg_and_pruning`
  args (one-arg-per-line) and the `Bypass run not discovered` assert
  message (single-line f-string).
- sewing_kit/utils.py: re-export `normalized_mse_loss` via the
  `as normalized_mse_loss` form (PEP 484 explicit re-export). The prior
  `from X import Y  # noqa: F401` form is treated by mypy as a private
  import, which surfaced as `attr-defined` at the call site in
  stitched_model_factory.py.
- test_bypass_checkpoint_utils.py: remove blank line after the import
  block; collapse single-line optimizer ternary; collapse three function
  signatures whose argument list now fits on one line.
- test_bypass_keys_to_learn.py: collapse single-symbol import and drop
  the trailing blank line.
- test_bypass_replacement_library.py: drop a blank line and reflow
  `bypass_real = ...` to use parens for the line continuation.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Two more call sites (in _test_bypass_resume_from_checkpoint_job and
_test_bypass_subblock_modes_job) had the old multi-line-arg-pair format
that ruff format wants reflowed to one-arg-per-line. The previous CI fix
caught only the third call site.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Pure-CPU unit tests covering the bypass-distillation surface that
codecov flagged as uncovered:
  * puzzletron_nas_plugin progress helpers
  * dataloaders (split auto-detect, num_workers guard, pad helper,
    Printer fake accelerator, load_*_fn delegators)
  * launch_bypass_distillation sweep dispatcher
  * bypass_checkpoint_utils (find_latest_run_dir, _save_local_file,
    _save_local_state, save_bypass_checkpoint orchestration)
  * stitched_model_factory _get_all_non_persistent_buffers_set
  * sewing_kit InputArgs, ActivityContext, Needle validation

Plus a GPU integration test pinning resume-from-latest end-to-end
(test_bypass_resume.py).

Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W)
so the schedule reaches min_lr exactly at step==D instead of relying
on the post-decay clamp at D+1 to mask a one-step plateau at base_lr
right after warmup.
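
The corrected schedule can be written as below — a sketch with hypothetical parameter names; only the decay_ratio expression is taken from the commit message:

```python
import math


def get_lr(step, base_lr, min_lr, warmup_steps, decay_steps):
    """Cosine decay with linear warmup, hitting min_lr exactly at step == decay_steps."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps  # linear warmup
    if step >= decay_steps:
        return min_lr  # post-decay clamp
    # the fix: measure decay progress from the end of warmup, not from step 0
    decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0
    return min_lr + coeff * (base_lr - min_lr)
```

With the old `step / decay_steps` ratio, the step right after warmup still returned base_lr (a one-step plateau); with this form the cosine starts decaying immediately and lands on min_lr at step == decay_steps.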

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
* SIM117: combine nested `with` statements where semantically
  equivalent (test_max_depth, test_no_duplicates_*, test_stack_unwinds).
  test_reversed_pushes_to_front_and_pops_from_front keeps the nested
  form because the intermediate assertion between exits is the test's
  actual point.
* ruff format: drop blank line after lone import, merge two
  `from ... import` blocks for sewing_kit.core.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
* RUF005: noqa the `InputArgs(1) + [2]` line — the auto-fix
  `[*InputArgs(1), 2]` would replace the operator call we are
  testing (this test pins that ``__add__`` rejects non-InputArgs
  via its internal assert).
* ruff format: drop spurious blank lines after import blocks,
  collapse a function signature back to one line.
* PLC0415 (in-function imports): hoist the two ``import os`` calls
  inside ``test_save_bypass_checkpoint_*`` to the file's import
  block.
* Move the late ``is_submodule_of`` / ``is_submodule_or_same``
  imports in ``test_sewing_kit_activity_context.py`` to the top
  of the file so the # noqa: E402 marker isn't needed.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Four test signatures in test_bypass_checkpoint_utils.py fit on
a single line within the 100-char limit; ruff format wants them
collapsed. The fifth signature in the same file
(test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master)
has three args and stays multi-line.

Verified with `awk 'length > 100'` and `grep -nE '^def test.*\(\s*$'`
across all new test files — no other lint nits should remain.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
@Separius
Contributor Author

Separius commented May 7, 2026

@AAnoosheh ready for review

Comment thread modelopt/torch/puzzletron/tools/robust_json.py Outdated
@kevalmorabia97
Collaborator

/ok to test 4ea3262

@kevalmorabia97
Collaborator

/claude review

Comment thread modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py Outdated
Comment thread modelopt/torch/puzzletron/anymodel/models/gpt_oss/gpt_oss_model_descriptor.py Outdated
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Comment thread modelopt/torch/puzzletron/sewing_kit/utils.py
Comment thread modelopt/torch/puzzletron/bypass_distillation/training_loop.py
Comment thread modelopt/torch/puzzletron/puzzletron_nas_plugin.py
Comment on lines +528 to +545
```python
        ExternalTarget().input(name=input_descriptor),
        student_submodule_target.input(name=submodule_input_descriptor),
    )
    .stitch(
        ExternalTarget().output(
            name=output_descriptor,
            adapter=lambda v: InputArgs(target=v)
            if not isinstance(v, tuple)
            else InputArgs(target=v[0]),
        ),
        student_stitched_module_loss_target.input(),
    )
    .stitch(
        student_submodule_target.output(
            name=submodule_output_descriptor,
            adapter=lambda v: InputArgs(input=v)
            if not isinstance(v, tuple)
            else InputArgs(input=v[0]),
```

[CRITICAL Algorithm] These lambda closures capture the loop variable v by reference at definition time — but the lambdas themselves are created fresh each iteration so they're safe. However, the adapter lambdas that call InputArgs(target=v) vs InputArgs(input=v) both receive the same variable name v. If the stitching framework ever evaluates both adapters in the same scope, they'd alias. This is likely fine as they're stored on different targets, but worth noting the naming collision for readability.

More importantly: the lambda v: InputArgs(target=v) if not isinstance(v, tuple) else InputArgs(target=v[0]) pattern assumes tuple outputs always have the hidden state as element 0. This is true for standard HuggingFace transformer blocks (which return (hidden_states, ...)) but could break for custom blocks that return tuples with a different layout. Consider adding a defensive shape/dtype assertion on v[0] during the dummy-run phase to catch this.
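
A minimal sketch of the suggested guard, assuming hidden states carry a `shape` attribute whose last dimension is the hidden size (names here are illustrative, not the sewing-kit API):

```python
def hidden_states_adapter(v, expected_hidden_size):
    """Unwrap a block output, asserting element 0 looks like hidden states.

    Hypothetical helper illustrating the defensive check suggested above.
    HF convention: tuple-returning blocks put hidden_states at index 0.
    """
    hidden = v[0] if isinstance(v, tuple) else v
    assert hasattr(hidden, "shape") and hidden.shape[-1] == expected_hidden_size, (
        f"Expected hidden states with last dim {expected_hidden_size}, "
        f"got {getattr(hidden, 'shape', type(hidden))}"
    )
    return hidden
```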

```python
def override(item, item_overrides):
    if item_overrides is None:
        return item  # None override means "keep original value"
    return item_overrides
```

[IMPORTANT Compatibility] Good fix — the previous code returned None literally (destroying the original item), which silently corrupted model configs when a layer had no override. This is a bugfix, but note it changes observable behavior: existing scripts that relied on None being passed through to update_model_config (perhaps incorrectly) will now silently keep the original value instead of producing a TypeCheckError or None in the block config. This is almost certainly the correct behavior, but document it in the PR as a behavioral change for existing users.

@claude

claude Bot commented May 7, 2026

Code Review Summary

CRITICAL: 1 | IMPORTANT: 3 | SUGGESTION: 3

Critical Findings

  1. [CRITICAL Algorithm] batched_normalized_mse_loss in sewing_kit/utils.py — the per-element denominator norm_of_target_vectors can be as small as epsilon^2 = 1e-12 when a target slice is all-zeros, producing loss values on the order of 1e12 and causing gradient explosion. Unlike the original normalized_mse_loss (which averages globally), per-batch normalization has no dilution from non-zero elements. Recommend adding .clamp(min=epsilon) on the denominator.
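
  The recommended fix is a floor on the per-vector denominator; a pure-Python sketch over lists of vectors (the production function in sewing_kit/utils.py operates on batched tensors):

  ```python
  def batched_normalized_mse_loss(pred, target, epsilon=1e-6):
      """Per-vector normalized MSE with a clamped denominator.

      Sketch of the clamp recommended above: max(norm, epsilon) keeps an
      all-zero target vector from driving the denominator down to
      ~epsilon**2 and exploding the loss.
      """
      per_vector = []
      for p_vec, t_vec in zip(pred, target):
          mse = sum((p - t) ** 2 for p, t in zip(p_vec, t_vec)) / len(t_vec)
          norm = sum(t * t for t in t_vec) / len(t_vec)
          per_vector.append(mse / max(norm, epsilon))  # clamped denominator
      return sum(per_vector) / len(per_vector)
  ```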

Important Findings

  1. [IMPORTANT Compatibility] GptOssModelDescriptor.pruning_mixins() renames the dict key from "expert_removal""experts_removal". While YAML configs use Hydra _target_ instantiation (unaffected), the resolve_pruning_mixin(string_key, descriptor) API supports string-based lookups and would now fail for the old key.

  2. [IMPORTANT Performance] training_loop.py:310.to("cpu").item() per block per iteration forces a CUDA sync point for every stitched module. With 32+ layers this creates significant GPU idle time. Consider batching the CPU transfer.

  3. [IMPORTANT Performance] training_loop.py:62time_start = time.time() at module-import time causes the first time-based checkpoint save to trigger immediately if there's a delay between import and training start.

Suggestions

  1. [SUGGESTION] Dead code: num_trainable_params in stitched_model_factory.py is computed but never used (and sums booleans instead of numel()).

  2. [SUGGESTION] _checkpoint_priority() in build_replacement_library.py — uses string-in-path heuristic; robust but brittle if directory layout changes.

  3. [SUGGESTION] Staleness detection for replacement library rebuild checks directory mtime rather than specific file mtime — conservative (safe to rebuild) but may cause unnecessary rebuilds.

Overall Assessment

Risk level: Medium. The PR adds a substantial new training subsystem (bypass distillation) to the puzzletron pipeline. The core architecture is well-designed — pipeline-parallel per-block KD using the existing sewing-kit abstraction, clean checkpoint management, and comprehensive test coverage across all model families.

The critical finding (loss denominator floor) is the most impactful — it can cause training divergence on models with sparse/zero activations in some blocks. The compatibility findings are low-probability impacts (the string-key path and None override behavior change) but worth documenting.

The child_init.py fix for None overrides is a genuine bugfix that corrects silent config corruption. The _lm_attrs helper for VL config nesting is a clean abstraction. The progress counter refactoring is well-structured.

The test suite is thorough: unit tests for losses, checkpoint utils, and LR scheduling, plus GPU integration tests parametrized across 9 model families covering FFN pruning, KV head compression, multi-config sweeps, and checkpoint validation.

Separius added 3 commits May 7, 2026 05:47
Drop the per-block weight save from `_save_local_state` — the same
parameters were already on disk in the top-level HF checkpoint that
`save_bypass_checkpoint` writes via `save_checkpoint(model, ...)`.
The stitched/ directory now carries only optimizer + grad_scaler
state; weights round-trip through the HF format. Resume now routes
through `load_and_shard_model` (same path as `init_checkpoint_path`)
so weight loading has a single entry point. Per-iter checkpoint
disk footprint roughly halves.

Other Claude review fixes:

* Use modelopt.torch.utils.robust_json (canonical) and delete the
  duplicate puzzletron-local copy.
* Remove dead num_trainable_params block in bypass_factory_fn (unused
  and the sum was a bool count, not a numel sum).
* GptOssModelDescriptor: keep "expert_removal" as a deprecation alias
  for the new "experts_removal" key so existing string-API callers
  don't break.
* Move time_start out of module scope into train() — at module level
  it became stale relative to actual training start, firing the first
  time-based save immediately.
* Batch the per-block loss GPU->CPU copy into a single sync after the
  loop (was N sync points per training step).
* puzzletron_nas_plugin staleness check: probe checkpoint config.json
  mtime instead of resolved-symlink directory mtime; handle dangling
  symlinks via try/except.
* Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W)
  so the schedule reaches min_lr exactly at step==D.
* batched_normalized_mse_loss: clamp the per-vector denominator to a
  floor of epsilon so an all-zero target slice doesn't explode the
  loss. Slightly diverges from the original Puzzle implementation
  (documented in the docstring).
* Add HF-convention comment in stitched_model_factory: tuple-returning
  blocks always have hidden_states at index 0.

Tests updated to match: stitched/{block}.state_dict.pth assertion
flipped to assert-not-written; new regression tests for the LR
schedule endpoint and the zero-target loss path.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
The stacked-loss expression fits on a single line within the 100-char
limit; ruff format unwrapped it.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
5 participants