Add bypass distillation (blockwise local KD) to puzzletron pipeline #1111
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in your CodeRabbit review settings.
📝 Walkthrough
Adds an optional bypass (blockwise local) distillation stage: a new bypass package with a stitched teacher–student factory, a distributed training loop and checkpointing, model/pruning extensions, normalized-MSE losses, a dataloader helper, example configs/docs, and unit/GPU tests; integrates bypass into the Puzzletron control flow.
Sequence Diagram(s)
sequenceDiagram
actor User
participant Launcher as "launch_bypass_distillation(hydra_cfg)"
participant Orchestrator as "run_bypassed_training(cfg)"
participant Factory as "stitched_model_factory()"
participant Data as "DataLoader / Teacher"
participant Trainer as "train()"
participant Checkpoint as "save_bypass_checkpoint()"
User->>Launcher: provide Hydra cfg (single or sweep)
Launcher->>Orchestrator: start run(s)
Orchestrator->>Data: load teacher model & dataloaders
Orchestrator->>Factory: build stitched teacher & student modules
Factory-->>Orchestrator: return stitched modules + descriptors
Orchestrator->>Trainer: start training loop
loop per iteration
Trainer->>Data: fetch batch
Trainer->>Trainer: teacher forward -> capture activations
Trainer->>Trainer: student forward -> compute per-block losses
Trainer->>Trainer: backward, grad scale/clip, optimizer step
Trainer->>Checkpoint: conditional save, write markers, symlink
Checkpoint-->>Trainer: sync / resume info
end
Trainer-->>Orchestrator: training complete
Orchestrator-->>Launcher: run finished
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks | ✅ 3 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (3 passed)
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tests/gpu/torch/puzzletron/test_puzzletron.py (1)
236-245: ⚠️ Potential issue | 🟡 Minor
The fallback printer still emits only rank-local values.
This branch now advertises `num_layers={total_layers}`, but it still prints only the contents of `rank_{rank}.pth` and is executed on rank 0 only. On multi-GPU runs the suggested `EXPECTED_PRUNING_VALUES` snippet will be incomplete.
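A minimal sketch of the aggregation this asks for, assuming a standard `torch.distributed` process group is already initialized; `local_scores` stands in for the test's rank-local `pruning_scores` dict:

```python
import torch.distributed as dist

def gather_pruning_scores(local_scores: dict, world_size: int) -> dict:
    """Collect each rank's {layer_name: scores} dict onto every rank, then merge."""
    gathered = [None] * world_size
    dist.all_gather_object(gathered, local_scores)  # collective: all ranks must call this
    merged: dict = {}
    for per_rank in gathered:
        merged.update(per_rank or {})
    return merged

# All ranks participate in the collective; only rank 0 prints the snippet:
# merged = gather_pruning_scores(pruning_scores, world_size)
# if rank == 0:
#     for layer_name in sorted(merged):
#         print(layer_name, merged[layer_name])
```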
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 236 - 245, The printer currently outputs only rank-local pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs; modify the logic so rank 0 aggregates pruning data from all ranks before printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth files) into a global pruning_scores for each layer_name, compute the global score and channels (e.g., combine/average or gather channel indices across ranks) respecting total_layers, and then have rank 0 iterate over layer_names using the aggregated values when printing the block that uses total_layers and prints the EXPECTED_PRUNING_VALUES snippet.
modelopt/torch/puzzletron/pruning/pruning_utils.py (1)
40-47: ⚠️ Potential issue | 🟠 Major
`MoEChannelPruning` is exposed before the init path supports it.
`modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` now branches on this enum and forwards it into `_init_mlp_module()`, but `_init_mlp_module()` still falls through to `Unsupported mlp_init_mode` for this value when expert widths change. Any config that selects `MoEChannelPruning` will fail during child initialization.
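A rough sketch of the missing branch, assuming `_init_mlp_module` dispatches on the enum; the `_prune_expert_channels` helper, the leading-channel slicing, and the expert-list representation are illustrative stand-ins, not the module's real API:

```python
import torch
from enum import Enum

class MlpInitMode(Enum):
    MoEChannelPruning = "moe_channel_pruning"  # the newly exposed mode

def _prune_expert_channels(child_w: torch.Tensor, parent_w: torch.Tensor) -> None:
    """Hypothetical helper: copy the leading parent channels into the smaller child.
    Real pruning would select scored channels rather than the leading ones."""
    child_w.copy_(parent_w[: child_w.shape[0], : child_w.shape[1]])

def _init_mlp_module(child_experts, parent_experts, mlp_init_mode: MlpInitMode):
    if mlp_init_mode is MlpInitMode.MoEChannelPruning:
        # New branch: shrink each expert instead of raising "Unsupported mlp_init_mode".
        for child_w, parent_w in zip(child_experts, parent_experts):
            _prune_expert_channels(child_w, parent_w)
        return child_experts
    raise ValueError(f"Unsupported mlp_init_mode: {mlp_init_mode}")

child = [torch.zeros(2, 4)]
parent = [torch.randn(4, 8)]
_init_mlp_module(child, parent, MlpInitMode.MoEChannelPruning)
```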
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/pruning/pruning_utils.py` around lines 40 - 47, The enum MlpInitMode now includes MoEChannelPruning but _init_mlp_module still treats that case as unsupported; update the _init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the same call-site that child_init.py forwards into) by adding a branch for MlpInitMode.MoEChannelPruning that performs the correct initialization when expert widths change (e.g., adapt the weight/activation shapes by slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where appropriate), so the child init no longer falls through to the "Unsupported mlp_init_mode" error for MoEChannelPruning.
🧹 Nitpick comments (5)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
804-806: This change makes explicit `null` resets impossible.
Treating `None` as "keep original" fixes the accidental overwrite, but it also removes the only way for JSON/YAML overrides to clear an optional field back to `None`. If callers need both behaviors, use a sentinel for "no override" and reserve `None` for explicit clearing.
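A minimal sketch of the sentinel pattern the comment suggests; the `NO_OVERRIDE` name is illustrative, not the module's actual API:

```python
# Unique module-level marker meaning "caller supplied no override at all".
NO_OVERRIDE = object()

def override(item, item_overrides=NO_OVERRIDE):
    if item_overrides is NO_OVERRIDE:
        return item          # nothing supplied: keep the original value
    if item_overrides is None:
        return None          # explicit JSON/YAML `null`: clear the field
    return item_overrides    # plain value: replace

assert override("x") == "x"        # no override -> original kept
assert override("x", None) is None  # explicit null -> cleared
assert override("x", "y") == "y"    # normal replacement
```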
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 804 - 806, The current override function (override) treats item_overrides == None as "keep original", which prevents callers from explicitly clearing a value to None via JSON/YAML; change the logic to use a distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no override" and reserve None in item_overrides to mean "set to None"/clear the field, updating the override function to check against the sentinel (NO_OVERRIDE) instead of None and adjust any callers that construct overrides to use the sentinel when they mean "leave original".
modelopt/torch/puzzletron/utils/data/dataset.py (1)
123-130: Keep role markers in the no-template fallback.
Joining only `content` collapses `system`/`user`/`assistant` turns into plain text, which changes the supervision for chat datasets. A lightweight fallback like `"{role}: {content}"` preserves the conversation structure without relying on a tokenizer template.
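A sketch of that fallback, assuming each sample is a list of `{"role": ..., "content": ...}` messages as described above; `render_chat` is an invented helper name:

```python
def render_chat(messages, tokenizer=None):
    """Render a chat sample, preserving role markers when no template exists."""
    if tokenizer is not None and getattr(tokenizer, "chat_template", None):
        return tokenizer.apply_chat_template(messages, tokenize=False)
    # Fallback for base models: keep the turn structure instead of bare content.
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

print(render_chat([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi!"},
]))
# system: You are helpful.
# user: Hi!
```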
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/data/dataset.py` around lines 123 - 130, The fallback that builds sample when getattr(self.tokenizer, "chat_template", None) is None should preserve role markers instead of joining only message["content"]; update the else branch in dataset.py (the block that currently sets sample = "\n".join(m["content"] for m in sample)) to join messages using a lightweight role-prefixed format like "{role}: {content}" so conversation turns (system/user/assistant) are retained; keep using the same sample variable and ensure this mirrors the structure expected by downstream code that consumes apply_chat_template outputs.
modelopt/torch/puzzletron/utils/parsing.py (1)
337-345: Don't silently treat every NaN as a no-op block.
This formatter now drops any NaN entry and can report `No trainable blocks found`. If a trainable block diverges, the failure disappears from the logs instead of surfacing. Filter only known skipped block types, or emit a separate warning for unexpected NaNs.
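A sketch of the selective filtering, assuming per-block losses arrive as a `{block_name: loss}` dict; the `SKIPPED_BLOCK_TYPES` tuple and logger setup are illustrative:

```python
import logging
import math

logger = logging.getLogger(__name__)
SKIPPED_BLOCK_TYPES = ("mamba",)  # blocks expected to report NaN (no trainable params)

def filter_losses(losses: dict[str, float]) -> dict[str, float]:
    kept = {}
    for name, loss in losses.items():
        if not math.isnan(loss):
            kept[name] = loss
        elif any(t in name.lower() for t in SKIPPED_BLOCK_TYPES):
            continue  # expected no-op block: drop silently
        else:
            # Unexpected NaN from a trainable block: surface it instead of hiding it.
            logger.warning("Trainable block %s produced NaN loss (possible divergence)", name)
    return kept
```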
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/utils/parsing.py` around lines 337 - 345, The current filtering silently drops any NaN in losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides diverging trainable blocks; instead, update the logic around losses_dict, best_steps_dict and best_values_dict so you only drop entries whose keys match known skipped block types (e.g., the explicit list of no-op block names like "Mamba"), and for any other NaN values emit a warning/error (via the existing logger) that a trainable block produced NaN rather than removing it; ensure best_steps_dict and best_values_dict are only pruned to match the filtered losses_dict after this selective filtering and warning behavior.
examples/puzzletron/main.py (1)
154-167: Progress messages in `run_mip_only` are hardcoded and inconsistent with the dynamic approach.
The `run_full_puzzletron` function now uses dynamic step counting (`N = _total_steps(hydra_cfg)`), but `run_mip_only` still uses hardcoded "7/8" and "8/8" progress messages. If bypass is configured, the step numbers would be incorrect (they should be 8/9 and 9/9). Consider applying the same dynamic step count logic here for consistency.
♻️ Suggested fix
 def run_mip_only(hydra_config_path: str):
     ...
     # Load hydra config
     hydra_cfg = initialize_hydra_config_for_dir(
         config_dir=hydra_config_dir,
         config_name=hydra_config_name,
         overrides=[],
     )
+    N = _total_steps(hydra_cfg)

     # Check if sweep mode is enabled
     if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
         mprint(
-            "Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
+            f"Puzzletron Progress {N-1}/{N}: running MIP sweep for multiple compression rates (multi-gpu)"
         )
         sweep.run_mip_sweep(hydra_cfg)
     else:
         # mip_and_realize_models (distributed processing)
         # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
+        mprint(f"Puzzletron Progress {N-1}/{N}: running MIP and realizing models (multi-gpu)")
         mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
     dist.cleanup()
-    mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
+    mprint(f"Puzzletron Progress {N}/{N}: puzzletron pipeline completed (multi-gpu)")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/puzzletron/main.py` around lines 154 - 167, Update run_mip_only to compute the total steps like run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when formatting the progress messages instead of hardcoded "7/8" and "8/8"; specifically, replace the two mprint calls around the conditional that currently show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g., f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step increments are correct for both the sweep branch (sweep.run_mip_sweep) and the mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress displays consistently with _total_steps.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
548-556: Unused variable `num_trainable_params`.
The variable `num_trainable_params` is computed but never used in this function or elsewhere. This appears to be residual code. Consider removing it to reduce unnecessary computation and improve code clarity.
♻️ Proposed removal
 assert "learning_rate" in cfg.training
-    num_trainable_params = sum(
-        p.requires_grad and submodule_name in p_name
-        for p_name, p in student_stitched_module.named_parameters()
-        if "dummy_param" not in p_name  # exclude placeholder params
-    )
-    # Do NOT enable dummy params: blocks with no real trainable parameters
-    # (e.g. Mamba blocks during an attention-only bypass run) should produce
-    # NaN loss so they are excluded from statistics — identical to the
-    # optimizer=None path in the training loop.
 student_module_parameters = {
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 548 - 556, Remove the unused computation of num_trainable_params: delete the sum(...) assignment that iterates over student_stitched_module.named_parameters() checking p.requires_grad and submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding explanatory comment about dummy params if still relevant, but eliminate the dead variable and its associated needless iteration to avoid wasted computation and clarify stitched_model_factory.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 45-58: The fallback currently only sorts checkpoint directories by
iteration (get_iter_num) so when multiple checkpoints exist for the same iter we
may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 673-677: Replace the hardcoded trust_remote_code=True in the
AutoTokenizer.from_pretrained call with the same caller-configurable
trust_remote_code flag you already read from the descriptor earlier (the
variable used for model config loading at lines ~597/631); specifically update
the tokenizer = AutoTokenizer.from_pretrained(...) invocation that uses
cfg.teacher_dir so it passes the descriptor-derived trust_remote_code value
instead of True, ensuring the flag remains configurable and defaults to False.
In `@modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py`:
- Around line 146-149: The pre-checks that treat presence of files like
(teacher_dir / "config.json"), any rank_*.pth, files under
pruned_ckpts_output_dir, or library outputs as sufficient to skip stages are
unsafe; change these guards to rely on durable completion markers (e.g., a .done
or .complete file) created at the successful end of
conversion/scoring/pruning/library build instead of existence-only checks, so
functions like the conversion branch around teacher_dir/config.json, the rank_*
checkpoint checks, and the pruned_ckpts_output_dir/library checks only skip when
their corresponding completion marker exists; ensure launch_score_activations()
remains the stricter gate for pruning-activation scoring but remove or weaken
the naive existence checks noted at the conversion lines (the block using
teacher_dir/config.json) and the other mentioned blocks (191-193, 286-289) to
check for the specific "<stage>.complete" marker before skipping.
In `@modelopt/torch/puzzletron/sewing_kit/utils.py`:
- Around line 452-454: The normalization denominator is computed as
F.mse_loss(target, torch.zeros_like(target) + epsilon, ...) which shifts the
target by epsilon and biases the scale; instead compute the denominator as
F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon (or
clamp_min the denominator to epsilon) so you add epsilon to the final scalar
denominator instead of to the zero tensor; update the occurrences around the
loss assignment (loss, input, target, epsilon, F.mse_loss) and the similar block
at lines 479-482 accordingly.
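A sketch of the corrected normalization described in this prompt, keeping `epsilon` as a floor on the scalar denominator instead of a shift of the target tensor; names and shapes are illustrative, not the repo's actual implementation:

```python
import torch
import torch.nn.functional as F

def normalized_mse_loss(input: torch.Tensor, target: torch.Tensor,
                        epsilon: float = 1e-8, reduction: str = "mean") -> torch.Tensor:
    # Biased form (what the comment flags): F.mse_loss(target, zeros + epsilon)
    # shifts every target element before squaring. Correct form: add epsilon
    # to the final scalar denominator only.
    denom = F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon
    return F.mse_loss(input, target, reduction=reduction) / denom

x, t = torch.randn(4, 8), torch.randn(4, 8)
print(normalized_mse_loss(x, t))                   # finite for random targets
print(normalized_mse_loss(x, torch.zeros(4, 8)))   # still finite: denom == epsilon
```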
In `@modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py`:
- Around line 380-396: The auto_map parsing in checkpoint_utils_hf.py
incorrectly assumes each model_config.auto_map value is a dotted string; update
the logic that builds module_files (and any usage of class_ref) to first
normalize each value by: if it's a list/tuple take the first element, if it
contains a repo qualifier split off the "repo_id--" prefix, then take the module
part before the first '.' and append ".py" (so "tokenization_my.py"); apply this
normalization where module_files is created and when iterating filenames so
lists/tuples and repo-qualified references are handled and the correct source
filenames are copied.
In `@modelopt/torch/puzzletron/utils/data/dataloaders.py`:
- Around line 89-90: The DataLoader factory allows num_workers>0 while
ConstantLengthDataset.__iter__ does not shard via get_worker_info(), causing
duplicate samples; update the dataloaders (the function returning DataLoader in
modelopt/torch/puzzletron/utils/data/dataloaders.py) to either reject or
override num_workers>0 (e.g., force num_workers=0 or raise an error) until
ConstantLengthDataset.__iter__ is updated to use
torch.utils.data.get_worker_info() for worker sharding, and apply the same guard
to the other DataLoader creation call sites noted around the 114-118 region;
reference ConstantLengthDataset.__iter__ when adding the guard so future
implementers know to remove it after adding proper worker-aware iteration.
- Around line 98-99: The call to train_data.shuffle(seed=shuffle_seed,
keep_in_memory=True) fails for streaming (Iterable) datasets because
IterableDataset.shuffle() doesn't accept keep_in_memory; update the code that
checks shuffle_seed to detect streaming datasets (e.g., via whatever marker
load_streaming_fn sets or by checking hasattr(train_data, "__iter__") vs
__len__/isinstance of IterableDataset) and branch: for non-streaming datasets
call train_data.shuffle(seed=shuffle_seed, keep_in_memory=True) as before, and
for streaming/iterable datasets call train_data.shuffle(seed=shuffle_seed)
without keep_in_memory; ensure you modify the block that references shuffle_seed
and train_data.shuffle so runtime errors are avoided when load_streaming_fn()
returns a streaming dataset.
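The two dataloader prompts above combine naturally into one guard; a sketch assuming Hugging Face `datasets` objects, with `make_loader` and its signature invented for illustration:

```python
from datasets import IterableDataset
from torch.utils.data import DataLoader

def make_loader(train_data, batch_size: int, num_workers: int, shuffle_seed=None):
    if num_workers > 0:
        # ConstantLengthDataset.__iter__ does not shard via get_worker_info(),
        # so multiple workers would yield duplicate samples. Remove this guard
        # once worker-aware iteration is implemented.
        raise ValueError("num_workers > 0 is unsupported until worker sharding lands")
    if shuffle_seed is not None:
        if isinstance(train_data, IterableDataset):
            # Streaming datasets do not accept keep_in_memory.
            train_data = train_data.shuffle(seed=shuffle_seed)
        else:
            train_data = train_data.shuffle(seed=shuffle_seed, keep_in_memory=True)
    return DataLoader(train_data, batch_size=batch_size, num_workers=num_workers)
```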
In `@tests/gpu/torch/puzzletron/test_bypass.py`:
- Line 213: The timeout passed to dist.setup uses timedelta(10) which means 10
days; change it to an explicit unit like timedelta(seconds=10) (or
timedelta(minutes=10) if intended) to avoid 10-day test hangs — locate the call
to dist.setup (symbol: dist.setup) in tests/gpu/torch/puzzletron/test_bypass.py
and the other listed files and replace timedelta(10) with timedelta(seconds=10)
(or the correct unit) in each occurrence.
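The unit pitfall is easy to confirm: `timedelta`'s first positional argument is days.

```python
from datetime import timedelta

print(timedelta(10))          # 10 days, 0:00:00 — a 10-day collective timeout
print(timedelta(seconds=10))  # 0:00:10 — presumably what the test intended
```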
---
Outside diff comments:
In `@modelopt/torch/puzzletron/pruning/pruning_utils.py`:
- Around line 40-47: The enum MlpInitMode now includes MoEChannelPruning but
_init_mlp_module still treats that case as unsupported; update the
_init_mlp_module implementation to handle MlpInitMode.MoEChannelPruning (the
same call-site that child_init.py forwards into) by adding a branch for
MlpInitMode.MoEChannelPruning that performs the correct initialization when
expert widths change (e.g., adapt the weight/activation shapes by
slicing/reshaping or reuse the ConcatExpertsIntoDenseFFN logic where
appropriate), so the child init no longer falls through to the "Unsupported
mlp_init_mode" error for MoEChannelPruning.
In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 236-245: The printer currently outputs only rank-local
pruning_scores causing incomplete EXPECTED_PRUNING_VALUES for multi-GPU runs;
modify the logic so rank 0 aggregates pruning data from all ranks before
printing: collect and merge per-rank pruning_scores (or load all rank_{rank}.pth
files) into a global pruning_scores for each layer_name, compute the global
score and channels (e.g., combine/average or gather channel indices across
ranks) respecting total_layers, and then have rank 0 iterate over layer_names
using the aggregated values when printing the block that uses total_layers and
prints the EXPECTED_PRUNING_VALUES snippet.
---
Nitpick comments:
In `@examples/puzzletron/main.py`:
- Around line 154-167: Update run_mip_only to compute the total steps like
run_full_puzzletron by calling _total_steps(hydra_cfg) and use that N when
formatting the progress messages instead of hardcoded "7/8" and "8/8";
specifically, replace the two mprint calls around the conditional that currently
show "Puzzletron Progress 7/8" and "8/8" with dynamic messages using N (e.g.,
f"Puzzletron Progress {current_step}/{N}: ...") and ensure current_step
increments are correct for both the sweep branch (sweep.run_mip_sweep) and the
mip branch (mip_and_realize_models.launch_mip_and_realize_model) so progress
displays consistently with _total_steps.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 548-556: Remove the unused computation of num_trainable_params:
delete the sum(...) assignment that iterates over
student_stitched_module.named_parameters() checking p.requires_grad and
submodule_name in p_name (and the "dummy_param" exclusion). Keep the surrounding
explanatory comment about dummy params if still relevant, but eliminate the dead
variable and its associated needless iteration to avoid wasted computation and
clarify stitched_model_factory.py.
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 804-806: The current override function (override) treats
item_overrides == None as "keep original", which prevents callers from
explicitly clearing a value to None via JSON/YAML; change the logic to use a
distinct sentinel (e.g., a new unique object like NO_OVERRIDE) to represent "no
override" and reserve None in item_overrides to mean "set to None"/clear the
field, updating the override function to check against the sentinel
(NO_OVERRIDE) instead of None and adjust any callers that construct overrides to
use the sentinel when they mean "leave original".
In `@modelopt/torch/puzzletron/utils/data/dataset.py`:
- Around line 123-130: The fallback that builds sample when
getattr(self.tokenizer, "chat_template", None) is None should preserve role
markers instead of joining only message["content"]; update the else branch in
dataset.py (the block that currently sets sample = "\n".join(m["content"] for m
in sample)) to join messages using a lightweight role-prefixed format like
"{role}: {content}" so conversation turns (system/user/assistant) are retained;
keep using the same sample variable and ensure this mirrors the structure
expected by downstream code that consumes apply_chat_template outputs.
In `@modelopt/torch/puzzletron/utils/parsing.py`:
- Around line 337-345: The current filtering silently drops any NaN in
losses_dict (and prunes best_steps_dict/best_values_dict to match), which hides
diverging trainable blocks; instead, update the logic around losses_dict,
best_steps_dict and best_values_dict so you only drop entries whose keys match
known skipped block types (e.g., the explicit list of no-op block names like
"Mamba"), and for any other NaN values emit a warning/error (via the existing
logger) that a trainable block produced NaN rather than removing it; ensure
best_steps_dict and best_values_dict are only pruned to match the filtered
losses_dict after this selective filtering and warning behavior.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 371acd83-77b9-4396-8a82-eddd5b11dd40
📒 Files selected for processing (27)
- examples/puzzletron/README.md
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/main.py
- modelopt/torch/puzzletron/anymodel/model_descriptor/model_descriptor.py
- modelopt/torch/puzzletron/anymodel/models/nemotron_h/nemotron_h_model_descriptor.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/puzzletron.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/utils/data/dataloaders.py
- modelopt/torch/puzzletron/utils/data/dataset.py
- modelopt/torch/puzzletron/utils/parsing.py
- tests/gpu/torch/puzzletron/resources/configs/meta-llama/Llama-3.1-8B-Instruct/bypass/test_bypass.yaml
- tests/gpu/torch/puzzletron/test_bypass.py
- tests/gpu/torch/puzzletron/test_puzzletron.py
- tests/unit/torch/puzzletron/__init__.py
- tests/unit/torch/puzzletron/test_bypass_losses.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
# If "latest" doesn't exist, look explicitly into directories with `*iter-*`
candidate_dirs = [d for d in run_parent_dir.glob("*iter-*") if d.is_dir()]

if not candidate_dirs:
    return None

def get_iter_num(dir_name):
    match = re.search(r"iter-(\d+)", dir_name.name)
    return int(match.group(1)) if match else 0

checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
for latest_dir in checkpoint_dirs:
    if (latest_dir / "saving_completed").exists():
        return str(latest_dir)
Include step_num when picking the latest checkpoint.
This fallback only sorts on iter-(\d+). If a run writes multiple checkpoints inside the same iteration, resume can load an older step even though a newer checkpoint exists in the same run_parent_dir.
💡 Suggested fix
- def get_iter_num(dir_name):
- match = re.search(r"iter-(\d+)", dir_name.name)
- return int(match.group(1)) if match else 0
-
- checkpoint_dirs = sorted(candidate_dirs, key=get_iter_num, reverse=True)
+ def checkpoint_order(path: Path) -> tuple[int, int, float]:
+ match = re.search(r"iter-(\d+)(?:.*step-(\d+))?", path.name)
+ if not match:
+ return (0, 0, path.stat().st_mtime)
+ return (int(match.group(1)), int(match.group(2) or 0), path.stat().st_mtime)
+
+ checkpoint_dirs = sorted(candidate_dirs, key=checkpoint_order, reverse=True)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`
around lines 45 - 58, The fallback currently only sorts checkpoint directories
by iteration (get_iter_num) so when multiple checkpoints exist for the same iter
we may pick an older step; update the selection to parse both iteration and step
and sort by (iter, step) descending. Modify get_iter_num (or replace with
get_iter_and_step) to extract iter via r"iter-(\d+)" and step via r"step-(\d+)"
(defaulting to 0 when not present), return a tuple of two ints, and use that
tuple as the sort key for checkpoint_dirs before iterating to find the first
directory with "saving_completed" and returning str(latest_dir).
cjluo-nv left a comment:
Summary: Adds bypass distillation (blockwise local knowledge distillation) as an optional pipeline stage to puzzletron. Includes a full training loop, stitched model factory, checkpoint management, loss functions, data loader, configuration, and comprehensive tests. Also fixes bugs in child_init.py, dataset.py, and adds HF auto-download logic.
Issues Found:
- [Duplicated Code] `normalized_mse_loss` in `sewing_kit/utils.py` (diff lines 432-445) is an exact duplicate of the existing implementation in `modelopt/torch/puzzletron/tools/kd_model.py:32-41`. The new code should import and reuse the existing function rather than redefining it. The `vectorwise_normalized_mse_loss` and `batched_normalized_mse_loss` variants are new and fine, but they should build on the existing import.
- [Correctness / Security] `training_loop.py:675` — `AutoTokenizer.from_pretrained` uses hardcoded `trust_remote_code=True`. The variable `trust_remote_code` is already computed from the descriptor at line 648. This should use `trust_remote_code=trust_remote_code` instead. (Flagged by pre-merge checks as well; see the sketch after this list.)
- [Correctness / Security] `bypass_checkpoint_utils.py:85,99` — `torch.load()` calls lack `weights_only=True`. The codebase convention (e.g., `checkpoint_utils.py:43,77`) is to use `weights_only=True` for state dict loading. These calls load state dicts and optimizer states respectively, which are pure tensor data and should use `weights_only=True` (see the sketch after this list).
- [Correctness] `training_loop.py` — The `except Exception as e` block at the end of `run_bypassed_training` (around line 870) catches all exceptions and calls `sys.exit(1)` for non-SystemExit exceptions. This swallows the actual exception type and prevents proper test framework error reporting. In GPU tests, a failing bypass run will produce `SystemExit(1)` instead of the real traceback. Consider re-raising or at least logging before exit.
- [Correctness] `stitched_model_factory.py:370-373` — The lambda closures in the stitched module creation loop (`adapter=lambda v: InputArgs(target=v)` and `adapter=lambda v: InputArgs(input=v)`) capture `v` correctly since they're arguments, but the loss target/input naming (`"target"` and `"input"`) relies on `block_loss_func` accepting exactly these keyword arguments. If someone changes `block_loss_func` to e.g. `batched_normalized_mse_loss`, the keyword args don't match (`batched_normalized_mse_loss` takes `input` and `target` positional args, not kwargs via `InputArgs`). This coupling is implicit and fragile — consider documenting the contract or adding a `**kwargs` adapter.
- [Correctness] `bypass_checkpoint_utils.py:89` — `loaded_state_dict = {**stitched_module.state_dict(), **loaded_state_dict}` merges the current state dict with the loaded one (loaded takes precedence). However, the current state dict is fetched before loading — if the model is on a different device, keys may contain tensors on the wrong device. The subsequent `load_state_dict` should handle this, but the intermediate merged dict is wasteful. Consider just using `strict=False` with `load_state_dict` directly.
- [Readability] `stitched_model_factory.py` — The `bypass_factory_fn` function is ~250 lines long with deeply nested logic. The student model initialization block (lines 200-305) could be extracted into a helper like `_initialize_student_model(...)`.
- [Readability] `training_loop.py` — The `train()` function is ~300 lines with deeply nested control flow for logging, validation, checkpoint saving, and time-based signals. Consider extracting checkpoint-save logic and logging logic into separate functions.
- [Readability] `stitched_model_factory.py:434-435` — Blank lines between the closing of the function and the backward-compatible aliases (`gqa_factory_fn = bypass_factory_fn`, `moe_factory_fn = bypass_factory_fn`). These aliases have no callers in this PR and no documentation. If they're for backward compat with existing configs, add a comment. If they're unused, remove them.
- [Tests] The GPU tests are thorough for the happy path but don't test checkpoint resume (loading from a previous run). The `find_last_ckpt_for_resume` + `load_local_state` path is complex and untested. At minimum, a test that runs bypass, then runs it again with `find_last_ckpt_for_resume=True` to verify resume works would increase confidence.
- [Tests] No unit test for `_set_keys_to_learn`, which has significant branching logic (subblock types, hybrid model block_configs filtering, regex fallback). This function is critical for correctness.
- [Correctness] `puzzletron_nas_plugin.py` — The new auto-download logic in `convert_puzzletron_model` (lines 152-165) runs `snapshot_download` only on rank 0 inside `if dist.is_master()`, but then all ranks call `dist.barrier()`. If the download takes a long time, the barrier timeout (set in `main.py` as `timedelta(10)` = 10 days) should be fine, but the `input_model_path` variable is only updated on rank 0 — other ranks never use it since only rank 0 does the conversion. This is correct but subtle; a comment would help.
- [Correctness] `bypass_utils.py:50` — `set_experiment_dir` assigns a `Path` object to `cfg.bypass.experiment_dir`, but `OmegaConf`/`DictConfig` doesn't natively support `Path` objects. This works because OmegaConf stores it as-is with struct mode off, but it may cause serialization issues (e.g., `json_dump` in `save_bypass_checkpoint`). Consider converting to `str`.
Suggestions:
- The `_copy_auto_map_code_files` addition in `checkpoint_utils_hf.py` is a good fix for trust_remote_code models. Consider adding a brief unit test or at least a comment about which models require this (e.g., NemotronH).
- The `format_stitched_losses` NaN filtering is a nice quality-of-life improvement for hybrid models. The `import math` inside the function body should be moved to the module top level.
- The `dataset.py` chat_template fallback is correct and handles base models gracefully.
- The `child_init.py` fix (`return item` instead of `return item_overrides` when `None`) is a real bug fix — good catch.
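For the two security bullets above, a minimal combined illustration; `load_teacher_assets` and its arguments are invented for the example, not the PR's actual code:

```python
import torch
from transformers import AutoTokenizer

def load_teacher_assets(teacher_dir: str, ckpt_path: str, trust_remote_code: bool = False):
    """Hypothetical helper mirroring the review's two fixes."""
    # Pass the caller-configurable flag through instead of hardcoding True.
    tokenizer = AutoTokenizer.from_pretrained(teacher_dir, trust_remote_code=trust_remote_code)
    # State dicts are pure tensor data: weights_only=True avoids unpickling
    # arbitrary objects from an untrusted checkpoint (the project convention).
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    return tokenizer, state_dict
```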
Overall Assessment: This is a well-structured, substantial feature addition. The core architecture (stitched model factory, per-block KD, pipeline integration) is sound. However, the hardcoded trust_remote_code=True security issue and the duplicated normalized_mse_loss need to be addressed before merge. The torch.load calls should also use weights_only=True per project convention.
Actionable comments posted: 8
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
811-843: ⚠️ Potential issue | 🟠 Major
`null` overrides still crash for optional nested dataclasses.
The new `item_overrides is None` branch is bypassed when `previous_value is None and _is_dataclass_type(item_type)` is true, so an override like `...: null` still becomes `_get_dataclass_type(item_type)(**item_overrides)` and raises at runtime. This is easy to hit for optional sub-configs that default to `None`.
Suggested fix
-if previous_value is None and _is_dataclass_type(item_type):
-    new_value = _get_dataclass_type(item_type)(**item_overrides)
+if item_overrides is None:
+    new_value = previous_value
+elif previous_value is None and _is_dataclass_type(item_type):
+    assert isinstance(item_overrides, dict)
+    new_value = _get_dataclass_type(item_type)(**item_overrides)
 else:
     new_value = override(previous_value, item_overrides)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 811 - 843, The dataclass_override loop special-case instantiates a nested dataclass even when the provided override is None, causing a crash; modify the block in dataclass_override that handles "previous_value is None and _is_dataclass_type(item_type)" to first check if item_overrides is None and in that case set new_value = None (or call override(previous_value, item_overrides)), otherwise instantiate with _get_dataclass_type(item_type)(**item_overrides); keep the subsequent check_type(new_value, item_type) and existing symbols (override, dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so optional nested dataclass overrides that are null no longer raise.
🧹 Nitpick comments (3)
modelopt/torch/puzzletron/tools/bypassed_training/child_init.py (1)
96-115: Reject overlapping outputs from multiple pruning mixins.
`layer_out_state_dict.update(_layer_out)` silently makes the final checkpoint depend on mixin order if two mixins emit the same state-dict key. Failing fast here is safer than letting one mixin overwrite the other.
Suggested guard
-layer_out_state_dict.update(_layer_out)
+overlapping_keys = layer_out_state_dict.keys() & _layer_out.keys()
+if overlapping_keys:
+    raise ValueError(
+        f"Pruning mixins produced overlapping keys for layer {layer_idx}: "
+        f"{sorted(overlapping_keys)}"
+    )
+layer_out_state_dict.update(_layer_out)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py` around lines 96 - 115, The loop over pruning mixins currently does layer_out_state_dict.update(_layer_out) which allows later mixins to silently overwrite keys from earlier ones; change this to detect overlapping keys and fail fast: for each _mixin when you get _layer_out from prune_single_layer, compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys()) and if intersection is non-empty raise a ValueError (or AssertionError) listing the conflicting keys and the mixin identity (use _mixin or its type) instead of updating; only call layer_out_state_dict.update(_layer_out) when intersection is empty to ensure deterministic, non-overlapping outputs from prune_single_layer across mixins.
modelopt/torch/puzzletron/tools/kd_model.py (1)
38-39: Add a zero/near-zero target regression test for this denominator change.
This adjustment mainly changes behavior when `target` has tiny norm, but `tests/unit/torch/puzzletron/test_bypass_losses.py` currently only covers random tensors. A focused zero-target case would keep this stabilization behavior from regressing unnoticed.
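A sketch of such a regression test, assuming the loss is exposed as `normalized_mse_loss` in `kd_model` (the review says the implementation lives there; the exact symbol name is an assumption):

```python
import pytest
import torch

from modelopt.torch.puzzletron.tools.kd_model import normalized_mse_loss  # assumed symbol

@pytest.mark.parametrize("scale", [0.0, 1e-12])
def test_kd_loss_zero_and_near_zero_target(scale):
    input = torch.randn(4, 16)
    target = torch.full((4, 16), scale)
    loss = normalized_mse_loss(input, target)
    # epsilon in the denominator must keep the loss finite even for zero targets
    assert torch.isfinite(loss)
```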
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/tools/kd_model.py` around lines 38 - 39, Add a unit test in tests/unit/torch/puzzletron/test_bypass_losses.py (e.g., test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and a near-zero target tensor, calls the code path that computes loss using the expression containing F.mse_loss(input, target, reduction=reduction) / (F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon), and asserts the loss is finite and behaves stably (no division-by-zero, not NaN/Inf) for both cases; use the same input tensor for both and check that adding the epsilon in the denominator prevents regressions.
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
377-381: Consider more descriptive error handling for invalid `block_loss_func`.
If `cfg.model_factory.block_loss_func` is not one of the three supported values, a `KeyError` is raised with just the invalid key name. A more descriptive error would help users identify the misconfiguration quickly.
Suggested improvement
+_BLOCK_LOSS_FUNCS = {
+    "normalized_mse_loss": normalized_mse_loss,
+    "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
+    "batched_normalized_mse_loss": batched_normalized_mse_loss,
+}
+loss_func_name = cfg.model_factory.block_loss_func
+if loss_func_name not in _BLOCK_LOSS_FUNCS:
+    raise ValueError(
+        f"Unknown block_loss_func '{loss_func_name}'. "
+        f"Supported: {list(_BLOCK_LOSS_FUNCS.keys())}"
+    )
-block_loss_func = {
-    "normalized_mse_loss": normalized_mse_loss,
-    "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
-    "batched_normalized_mse_loss": batched_normalized_mse_loss,
-}[cfg.model_factory.block_loss_func]
+block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 377 - 381, The current lookup for block_loss_func in stitched_model_factory.py uses a direct dict index which raises an opaque KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a guarded lookup: retrieve via dict.get or check membership first and raise a ValueError with a clear message that includes the invalid value and the allowed options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss", "batched_normalized_mse_loss"); update the code around the block_loss_func assignment (the dict and its use) so callers get a descriptive error instead of a raw KeyError.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py`:
- Around line 113-125: The checkpoint load/save is missing GradScaler state so
when use_grad_scaling=True resumed runs lose scaler state; update the save and
load paths around the StitchedModuleDescriptor handling to persist
grad_scaler.state_dict() (e.g., save to
stitched/{stitched_module_name}.grad_scaler.pth) when grad_scaler is not None
and on load (in the blocks that currently load optimizer state and in the
similar 165-171 block) call grad_scaler.load_state_dict(...) after constructing
or retrieving the module’s grad_scaler, using map_location=device, and guard
with the use_grad_scaling flag so scaler state is restored only when applicable.
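A sketch of persisting and restoring scaler state as this prompt describes; the path layout mirrors the prompt's `stitched/{name}.grad_scaler.pth` suggestion, and the helper names are invented:

```python
import torch

def save_scaler(grad_scaler, ckpt_dir: str, name: str, use_grad_scaling: bool = True):
    if use_grad_scaling and grad_scaler is not None:
        torch.save(grad_scaler.state_dict(), f"{ckpt_dir}/stitched/{name}.grad_scaler.pth")

def load_scaler(grad_scaler, ckpt_dir: str, name: str, device, use_grad_scaling: bool = True):
    if use_grad_scaling and grad_scaler is not None:
        state = torch.load(
            f"{ckpt_dir}/stitched/{name}.grad_scaler.pth",
            map_location=device,
            weights_only=True,
        )
        grad_scaler.load_state_dict(state)  # restore the dynamic loss scale on resume
```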
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 597-600: Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
- Around line 417-428: The code assumes owned_block_indexes is non-empty before
calling min()/max(), which will raise ValueError if a rank owns no blocks; in
the block around min_owned_index/max_owned_index in stitched_model_factory.py,
first check if not owned_block_indexes and handle it defensively (e.g., set
prev_rank and next_rank to None or raise a clear, explanatory error) instead of
calling min()/max(); update the logic that computes prev_rank and next_rank
using model_blocks_process_ownership and all_block_indices to only run when
owned_block_indexes is non-empty so misconfiguration yields a clear message or
safe defaults.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py`:
- Around line 837-838: The code reads source_datasets_to_discard from cfg.bypass
root but the new config nests it under bypass.data; update the dataloader calls
that set source_datasets_to_discard (and any similar occurrences) to read from
cfg.bypass.data.get("source_datasets_to_discard", tuple()) instead of
cfg.bypass.get(...), leaving bos_rate as cfg.bypass.data.bos_rate; search for
the occurrences that set source_datasets_to_discard (the two places mentioned
around the calls that also use bos_rate) and replace them to use cfg.bypass.data
so the discard list becomes configurable.
- Around line 252-253: The parameter skip_first_batches is never applied: after
creating the batch iterator from the ConstantLengthDataset/dataloader you must
advance that iterator by skip_first_batches before entering the training loop
(e.g., consume the iterator with next(...) in a short loop or use
itertools.islice to drop the first N items); update the code paths where
skip_first_batches is accepted (the occurrences around skip_first_batches in
training_loop.py and the second occurrence at lines ~329-330) to consume the
iterator accordingly so resumed runs do not replay from batch 0.
- Around line 349-350: The loop exit condition uses a 1-based counter and
currently uses >=, causing it to stop one step too early; update the check in
the training loop that references cfg.bypass.step_num and
cfg.bypass.training.max_steps so it breaks only once step_num has passed the
budget (use > instead of >=) so the final scheduled step runs.
- Around line 103-107: The AutoConfig.from_pretrained call inside
run_bypassed_training bypasses the earlier trust_remote_code decision; update
the AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.
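For the `skip_first_batches` item above, a minimal way to consume the leading batches with `itertools.islice`; the helper name is illustrative:

```python
from itertools import islice

def resume_iterator(dataloader, skip_first_batches: int = 0):
    """Return a batch iterator advanced past already-trained batches."""
    batch_iter = iter(dataloader)
    # Standard itertools "consume" recipe: lazily drop the first N batches.
    next(islice(batch_iter, skip_first_batches, skip_first_batches), None)
    return batch_iter

batches = resume_iterator(range(10), skip_first_batches=3)
print(next(batches))  # 3
```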
In `@tests/gpu/torch/puzzletron/test_puzzletron.py`:
- Around line 209-215: The test incorrectly computes per-rank FFN counts from
hidden layer count (total_layers = max(2, size)); instead compute the actual
number of prunable FFN blocks (e.g., scan the model's layer names or modules to
count FFN/prunable blocks rather than using hidden-layer count) into
total_ffn_blocks, then compute layers_this_rank = total_ffn_blocks // size + (1
if rank < total_ffn_blocks % size else 0) and assert len(layer_names) ==
layers_this_rank (allowing 0 for ranks that only own Mamba blocks); update the
variables total_layers/layers_this_rank and reference layer_names when making
this change.
---
Outside diff comments:
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 811-843: The dataclass_override loop special-case instantiates a
nested dataclass even when the provided override is None, causing a crash;
modify the block in dataclass_override that handles "previous_value is None and
_is_dataclass_type(item_type)" to first check if item_overrides is None and in
that case set new_value = None (or call override(previous_value,
item_overrides)), otherwise instantiate with
_get_dataclass_type(item_type)(**item_overrides); keep the subsequent
check_type(new_value, item_type) and existing symbols (override,
dataclass_override, _is_dataclass_type, _get_dataclass_type, check_type) so
optional nested dataclass overrides that are null no longer raise.
---
Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func in
stitched_model_factory.py uses a direct dict index which raises an opaque
KeyError when cfg.model_factory.block_loss_func is invalid; replace it with a
guarded lookup: retrieve via dict.get or check membership first and raise a
ValueError with a clear message that includes the invalid value and the allowed
options (e.g., "normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss"); update the code around the block_loss_func
assignment (the dict and its use) so callers get a descriptive error instead of
a raw KeyError.
In `@modelopt/torch/puzzletron/tools/bypassed_training/child_init.py`:
- Around line 96-115: The loop over pruning mixins currently does
layer_out_state_dict.update(_layer_out) which allows later mixins to silently
overwrite keys from earlier ones; change this to detect overlapping keys and
fail fast: for each _mixin when you get _layer_out from prune_single_layer,
compute intersection = set(layer_out_state_dict.keys()) & set(_layer_out.keys())
and if intersection is non-empty raise a ValueError (or AssertionError) listing
the conflicting keys and the mixin identity (use _mixin or its type) instead of
updating; only call layer_out_state_dict.update(_layer_out) when intersection is
empty to ensure deterministic, non-overlapping outputs from prune_single_layer
across mixins.
In `@modelopt/torch/puzzletron/tools/kd_model.py`:
- Around line 38-39: Add a unit test in
tests/unit/torch/puzzletron/test_bypass_losses.py (e.g.,
test_kd_loss_zero_and_near_zero_target) that imports the kd loss implementation
from modelopt.torch.puzzletron.tools.kd_model, constructs both a zero target and
a near-zero target tensor, calls the code path that computes loss using the
expression containing F.mse_loss(input, target, reduction=reduction) /
(F.mse_loss(target, torch.zeros_like(target), reduction=reduction) + epsilon),
and asserts the loss is finite and behaves stably (no division-by-zero, not
NaN/Inf) for both cases; use the same input tensor for both and check that
adding the epsilon in the denominator prevents regressions.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6cb35ef5-41ea-4f6d-990a-791e2c99b812
📒 Files selected for processing (90)
- examples/puzzletron/BYPASS.md
- examples/puzzletron/README.md
- examples/puzzletron/configs/bypass/defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- examples/puzzletron/configs/pruning/defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
- examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
- examples/puzzletron/configs/validate_model_defaults.yaml
- examples/puzzletron/configs/validate_solutions_defaults.yaml
- examples/puzzletron/main.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_checkpoint_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/bypass_distillation/training_loop.py
- modelopt/torch/puzzletron/dataset/prepare_dataset.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/tools/bypassed_training/child_init.py
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- modelopt/torch/puzzletron/tools/kd_model.py
- modelopt/torch/puzzletron/utils/data/dataloaders.py
- modelopt/torch/puzzletron/utils/parsing.py
- modelopt/torch/utils/plugins/transformers_dataset.py
- tests/_test_utils/torch/puzzletron/utils.py
- tests/gpu/torch/puzzletron/nas/plugins/test_nas_convert.py
- tests/gpu/torch/puzzletron/nas/plugins/test_nas_search.py
- tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
- tests/gpu/torch/puzzletron/test_bypass.py
- tests/gpu/torch/puzzletron/test_puzzletron.py
- tests/unit/torch/puzzletron/__init__.py
- tests/unit/torch/puzzletron/test_bypass_losses.py
- tests/unit/torch/puzzletron/test_bypass_utils.py
✅ Files skipped from review due to trivial changes (67)
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- tests/unit/torch/puzzletron/__init__.py
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- modelopt/torch/puzzletron/dataset/prepare_dataset.py
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/bypass/defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- modelopt/torch/utils/plugins/transformers_dataset.py
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/bypass/defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/Mistral-Small-24B.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/qwen2_5_7b_instruct.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/realize_model/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/bypass/defaults.yaml
- tests/_test_utils/torch/puzzletron/utils.py
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/hidden_dim_pruning.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/Llama-3_2-3B.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/gptoss-20b_remove_experts_memory/gptoss-20b.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/qwen3_8b.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/validate_model_defaults.yaml
- tests/gpu/torch/puzzletron/resources/configs/Qwen/Qwen3-VL-30B-A3B-Instruct/pruning/expert_pruning.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/pruning/attn_pruning.yaml
- examples/puzzletron/configs/validate_solutions_defaults.yaml
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/nemotron_nano_12b_v2.yaml
- examples/puzzletron/configs/scoring/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/bypass/defaults.yaml
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- examples/puzzletron/configs/mistral-small-24b-instruct-2501_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/README.md
- modelopt/torch/puzzletron/tools/checkpoint_utils_hf.py
- examples/puzzletron/configs/validate_model_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/llama-3_2-3B_pruneffn_memory/validate_solutions_defaults.yaml
- modelopt/torch/puzzletron/bypass_distillation/data_classes.py
- modelopt/torch/puzzletron/bypass_distillation/__init__.py
- examples/puzzletron/configs/bypass/defaults.yaml
- examples/puzzletron/BYPASS.md
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/validate_solutions_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_solutions_defaults.yaml
- tests/unit/torch/puzzletron/test_bypass_losses.py
- examples/puzzletron/configs/qwen3-8b_pruneffn_memory/pruning/pruning_defaults.yaml
- examples/puzzletron/configs/qwen2_5_7b_instruct_pruneffn_memory/validate_model_defaults.yaml
- examples/puzzletron/configs/nemotron-nano-12b-v2/pruning/pruning_defaults.yaml
🚧 Files skipped from review as they are similar to previous changes (6)
- examples/puzzletron/configs/llama-3_1-8B_pruneffn_memory/Llama-3_1-8B.yaml
- modelopt/torch/puzzletron/utils/parsing.py
- modelopt/torch/puzzletron/pruning/pruning_utils.py
- modelopt/torch/puzzletron/bypass_distillation/bypass_utils.py
- modelopt/torch/puzzletron/sewing_kit/utils.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
mprint(
    f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
    f"({sum(p.numel() for p in trainable_params.values()):,} params)"
)
Log message will always show empty block name.
submodule_name is initialized to "" at line 449 and never reassigned within the loop. The log message "Block : ..." will always display an empty block name. Consider using student_stitched_module_name (e.g., block_0) or module_name for clarity.
Suggested fix
  mprint(
-     f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
+     f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
      f"({sum(p.numel() for p in trainable_params.values()):,} params)"
  )
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
- mprint(
-     f"Block {submodule_name}: {len(trainable_params)} trainable parameter tensors "
-     f"({sum(p.numel() for p in trainable_params.values()):,} params)"
- )
+ mprint(
+     f"Block {student_stitched_module_name}: {len(trainable_params)} trainable parameter tensors "
+     f"({sum(p.numel() for p in trainable_params.values()):,} params)"
+ )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`
around lines 597 - 600, Log currently prints an empty block name because
submodule_name is never reassigned; update the mprint call in
stitched_model_factory.py (the one that prints Block {submodule_name}: ...) to
use a valid identifier such as student_stitched_module_name (or module_name) so
the block name is meaningful; locate the mprint invocation and replace
submodule_name with student_stitched_module_name (ensuring
student_stitched_module_name is in scope) and keep the rest of the
message/parameter counting intact.
if do_ffn or do_attn or do_blk:
    from transformers import AutoConfig as HFAutoConfig

    teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
    teacher_intermediate_size = getattr(teacher_hf_cfg, "intermediate_size", None)
🧩 Analysis chain
🏁 Script executed:
cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '95,120p'
Repository: NVIDIA/Model-Optimizer
Length of output: 1567
🏁 Script executed:
rg "requires_trust_remote_code|trust_remote_code" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2
Repository: NVIDIA/Model-Optimizer
Length of output: 614
🏁 Script executed:
rg "ModelDescriptorFactory" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2
Repository: NVIDIA/Model-Optimizer
Length of output: 539
🏁 Script executed:
cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '1,110p'
Repository: NVIDIA/Model-Optimizer
Length of output: 5786
🏁 Script executed:
rg "def requires_trust_remote_code" modelopt/ -A 3
Repository: NVIDIA/Model-Optimizer
Length of output: 1441
🏁 Script executed:
cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '73,250p'
Repository: NVIDIA/Model-Optimizer
Length of output: 9638
🏁 Script executed:
rg "def launch_bypass_distillation" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 50
Repository: NVIDIA/Model-Optimizer
Length of output: 2748
🏁 Script executed:
rg "hydra_cfg.descriptor|cfg.descriptor" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -B 2 -A 2
Repository: NVIDIA/Model-Optimizer
Length of output: 344
🏁 Script executed:
cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '240,300p'
Repository: NVIDIA/Model-Optimizer
Length of output: 2827
🏁 Script executed:
rg "descriptor\s*=" modelopt/torch/puzzletron/bypass_distillation/training_loop.py | head -20
Repository: NVIDIA/Model-Optimizer
Length of output: 558
🏁 Script executed:
cat -n modelopt/torch/puzzletron/bypass_distillation/training_loop.py | sed -n '290,360p'
Repository: NVIDIA/Model-Optimizer
Length of output: 3106
🏁 Script executed:
rg "def run_bypassed_training" modelopt/torch/puzzletron/bypass_distillation/training_loop.py -A 30
Repository: NVIDIA/Model-Optimizer
Length of output: 1259
Thread trust_remote_code through the auto-config probe.
run_bypassed_training() queries the descriptor for trust_remote_code, but this auto-config path at lines 103–107 bypasses that and calls AutoConfig.from_pretrained() with default (unsafe) behavior. Models requiring remote code execution will fail inconsistently depending on which path loads them.
💡 Suggested fix
  if do_ffn or do_attn or do_blk:
      from transformers import AutoConfig as HFAutoConfig
-     teacher_hf_cfg = HFAutoConfig.from_pretrained(str(hydra_cfg.teacher_dir))
+     descriptor = ModelDescriptorFactory.get(hydra_cfg.descriptor)
+     trust_remote_code = descriptor.requires_trust_remote_code()
+     teacher_hf_cfg = HFAutoConfig.from_pretrained(
+         str(hydra_cfg.teacher_dir),
+         trust_remote_code=trust_remote_code,
+     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/puzzletron/bypass_distillation/training_loop.py` around lines
103 - 107, The AutoConfig.from_pretrained call inside run_bypassed_training
bypasses the earlier trust_remote_code decision; update the
AutoConfig.from_pretrained invocation in training_loop.py (the block that
imports HFAutoConfig and sets teacher_hf_cfg/teacher_intermediate_size) to pass
the same trust_remote_code flag you query earlier (the
descriptor/trust_remote_code value used by run_bypassed_training) so
remote-code-required models load consistently and safely. Locate the
AutoConfig.from_pretrained usage and add the trust_remote_code argument (sourced
from the existing hydra_cfg/descriptor/trust_remote_code variable) when calling
from_pretrained, ensuring the call mirrors the trusted-remote-code decision made
elsewhere.
# The test model has num_hidden_layers = max(2, size), so every rank owns at least
# one layer. Compute the actual expected count for *this* rank.
total_layers = max(2, size)
layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
assert len(layer_names) == layers_this_rank, (
    f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
)
The per-rank FFN count is still wrong for hybrid models.
total_layers = max(2, size) counts hidden layers, not prunable FFN blocks. This file already documents nvidia/NVIDIA-Nemotron-Nano-12B-v2 as having only one FFN layer, so a rank that owns only Mamba blocks can legitimately have len(layer_names) == 0.
💡 Suggested fix
- total_layers = max(2, size)
- layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
- assert len(layer_names) == layers_this_rank, (
-     f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
- )
+ total_layers = max(2, size)
+ if len(expected) == total_layers:
+     layers_this_rank = total_layers // size + (1 if rank < total_layers % size else 0)
+     assert len(layer_names) == layers_this_rank, (
+         f"Expected {layers_this_rank} FFN layers on rank {rank}/{size}, got {len(layer_names)}"
+     )
Verify each finding against the current code and only fix it if needed.
In `@tests/gpu/torch/puzzletron/test_puzzletron.py` around lines 209 - 215, The
test incorrectly computes per-rank FFN counts from hidden layer count
(total_layers = max(2, size)); instead compute the actual number of prunable FFN
blocks (e.g., scan the model's layer names or modules to count FFN/prunable
blocks rather than using hidden-layer count) into total_ffn_blocks, then compute
layers_this_rank = total_ffn_blocks // size + (1 if rank < total_ffn_blocks %
size else 0) and assert len(layer_names) == layers_this_rank (allowing 0 for
ranks that only own Mamba blocks); update the variables
total_layers/layers_this_rank and reference layer_names when making this change.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/puzzletron/main.py (1)
102-135: ⚠️ Potential issue | 🟠 Major
Ensure distributed cleanup runs on failure paths.
If convert/search/sweep/MIP raises, dist.cleanup() is skipped. In multi-GPU flows this can leave process groups hanging. Wrap execution in try/finally in both runners.
Proposed fix
def run_full_puzzletron(hydra_config_path: str):
@@
    hydra_cfg, hydra_config_dir, hydra_config_name, n = _setup(hydra_config_path)
    mprint(f"Puzzletron Progress 1/{n}: starting puzzletron pipeline")
-
-    # Convert model (convert from HF to DeciLM, score pruning activations,
-    # prune the model and save pruned checkpoints)
-    input_model = PuzzletronModel()
-    converted_model = mtn.convert(
-        input_model,
-        mode=[
-            (
-                "puzzletron",
-                {
-                    "puzzle_dir": str(hydra_cfg.puzzle_dir),
-                    "input_model_path": hydra_cfg.input_hf_model_path,
-                    "hydra_config_dir": hydra_config_dir,
-                    "hydra_config_name": hydra_config_name,
-                    "dataset_path": str(hydra_cfg.dataset_path),
-                },
-            )
-        ],
-    )
-
-    # Run NAS search (build replacement library and compute stats,
-    # compute one block scores, run MIP and realize models)
-    mtn.search(
-        converted_model,
-        constraints={},  # this is not used as the search space is defined in the hydra config
-        dummy_input=None,  # Not used
-        config={},  # this is not used as the search space is defined in the hydra config
-    )
-
-    dist.cleanup()
+    try:
+        # Convert model (convert from HF to DeciLM, score pruning activations,
+        # prune the model and save pruned checkpoints)
+        input_model = PuzzletronModel()
+        converted_model = mtn.convert(
+            input_model,
+            mode=[
+                (
+                    "puzzletron",
+                    {
+                        "puzzle_dir": str(hydra_cfg.puzzle_dir),
+                        "input_model_path": hydra_cfg.input_hf_model_path,
+                        "hydra_config_dir": hydra_config_dir,
+                        "hydra_config_name": hydra_config_name,
+                        "dataset_path": str(hydra_cfg.dataset_path),
+                    },
+                )
+            ],
+        )
+
+        # Run NAS search (build replacement library and compute stats,
+        # compute one block scores, run MIP and realize models)
+        mtn.search(
+            converted_model,
+            constraints={},  # this is not used as the search space is defined in the hydra config
+            dummy_input=None,  # Not used
+            config={},  # this is not used as the search space is defined in the hydra config
+        )
+    finally:
+        dist.cleanup()
    mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")
@@
def run_mip_only(hydra_config_path: str):
@@
-    # Check if sweep mode is enabled
-    if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
-        mprint(
-            f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
-        )
-        sweep.run_mip_sweep(hydra_cfg)
-    else:
-        # mip_and_realize_models (distributed processing)
-        # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
-        mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
-        mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
-
-    dist.cleanup()
+    try:
+        # Check if sweep mode is enabled
+        if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
+            mprint(
+                f"Puzzletron Progress {mip_step}/{n}: running MIP sweep for multiple compression rates (multi-gpu)"
+            )
+            sweep.run_mip_sweep(hydra_cfg)
+        else:
+            # mip_and_realize_models (distributed processing)
+            # TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
+            mprint(f"Puzzletron Progress {mip_step}/{n}: running MIP and realizing models (multi-gpu)")
+            mip_and_realize_models.launch_mip_and_realize_model(hydra_cfg)
+    finally:
+        dist.cleanup()
    mprint(f"Puzzletron Progress {n}/{n}: puzzletron pipeline completed (multi-gpu)")

Also applies to: 147-163
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/puzzletron/main.py` around lines 102 - 135, The current flow calls dist.cleanup() after running mtn.convert and mtn.search but if mtn.convert/mtn.search (or any subsequent step) raises an exception the cleanup is skipped; wrap the multi-GPU pipeline (from _setup through mtn.search/mtn.sweep/mtn.MIP calls around lines that create PuzzletronModel, call mtn.convert and mtn.search) in a try/finally block so dist.cleanup() always runs, and apply the same try/finally pattern to the other runner block referenced around lines 147-163; ensure the try encompasses all work that requires the distributed group and the finally calls dist.cleanup() unconditionally.
🧹 Nitpick comments (1)
modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py (1)
377-381: Consider clearer error handling for unknown block_loss_func.
If cfg.model_factory.block_loss_func is not one of the three expected values, a KeyError is raised with a cryptic message. A more informative error would help users diagnose configuration issues.
♻️ Suggested improvement
- block_loss_func = {
+ _BLOCK_LOSS_FUNCS = {
      "normalized_mse_loss": normalized_mse_loss,
      "vectorwise_normalized_mse_loss": vectorwise_normalized_mse_loss,
      "batched_normalized_mse_loss": batched_normalized_mse_loss,
- }[cfg.model_factory.block_loss_func]
+ }
+ loss_func_name = cfg.model_factory.block_loss_func
+ if loss_func_name not in _BLOCK_LOSS_FUNCS:
+     raise ValueError(
+         f"Unknown block_loss_func '{loss_func_name}'. "
+         f"Expected one of: {list(_BLOCK_LOSS_FUNCS.keys())}"
+     )
+ block_loss_func = _BLOCK_LOSS_FUNCS[loss_func_name]
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py` around lines 377 - 381, The current lookup for block_loss_func using a dict keyed by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update the code around block_loss_func (in stitched_model_factory.py) to explicitly validate cfg.model_factory.block_loss_func against the allowed names ("normalized_mse_loss", "vectorwise_normalized_mse_loss", "batched_normalized_mse_loss") and raise a clear ValueError that includes the invalid value and the list of valid options; reference the existing functions normalized_mse_loss, vectorwise_normalized_mse_loss, and batched_normalized_mse_loss when constructing the mapping and error message so users can quickly see the supported choices.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/puzzletron/main.py`:
- Around line 99-101: The docstring parameter name is incorrect — replace the
documented `config_path` with the actual function parameter `hydra_config_path`
and update its description to match (e.g., "Path to the YAML configuration
file") so the `hydra_config_path` argument in the function signature and the
docstring are consistent; locate the docstring in examples/puzzletron/main.py
near the function that accepts `hydra_config_path` and make this single-name
correction.
---
Outside diff comments:
In `@examples/puzzletron/main.py`:
- Around line 102-135: The current flow calls dist.cleanup() after running
mtn.convert and mtn.search but if mtn.convert/mtn.search (or any subsequent
step) raises an exception the cleanup is skipped; wrap the multi-GPU pipeline
(from _setup through mtn.search/mtn.sweep/mtn.MIP calls around lines that create
PuzzletronModel, call mtn.convert and mtn.search) in a try/finally block so
dist.cleanup() always runs, and apply the same try/finally pattern to the other
runner block referenced around lines 147-163; ensure the try encompasses all
work that requires the distributed group and the finally calls dist.cleanup()
unconditionally.
---
Nitpick comments:
In `@modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py`:
- Around line 377-381: The current lookup for block_loss_func using a dict keyed
by cfg.model_factory.block_loss_func can raise a cryptic KeyError; update the
code around block_loss_func (in stitched_model_factory.py) to explicitly
validate cfg.model_factory.block_loss_func against the allowed names
("normalized_mse_loss", "vectorwise_normalized_mse_loss",
"batched_normalized_mse_loss") and raise a clear ValueError that includes the
invalid value and the list of valid options; reference the existing functions
normalized_mse_loss, vectorwise_normalized_mse_loss, and
batched_normalized_mse_loss when constructing the mapping and error message so
users can quickly see the supported choices.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 99826f98-f6fb-41d2-a78c-5c40bec6c4c9
📒 Files selected for processing (3)
- examples/puzzletron/main.py
- modelopt/torch/puzzletron/bypass_distillation/stitched_model_factory.py
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/puzzletron/nas/plugins/puzzletron_nas_plugin.py
Args:
    config_path: Path to the YAML configuration file
"""
Fix docstring argument name mismatch.
Line 100 documents config_path, but the function argument is hydra_config_path. Please align the docstring to avoid confusion.
Proposed fix
def run_full_puzzletron(hydra_config_path: str):
@@
    Args:
-       config_path: Path to the YAML configuration file
+       hydra_config_path: Path to the YAML configuration file
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
  Args:
-     config_path: Path to the YAML configuration file
+     hydra_config_path: Path to the YAML configuration file
  """
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@examples/puzzletron/main.py` around lines 99 - 101, The docstring parameter
name is incorrect — replace the documented `config_path` with the actual
function parameter `hydra_config_path` and update its description to match
(e.g., "Path to the YAML configuration file") so the `hydra_config_path`
argument in the function signature and the docstring are consistent; locate the
docstring in examples/puzzletron/main.py near the function that accepts
`hydra_config_path` and make this single-name correction.
@cjluo-nv addressed all the points (thanks again for the great review)
Force-pushed from 18e6fa6 to 41b8ca7
Force-pushed from 2669a26 to 977d60a
There are 250+ files in this review, many of which don't need any changes; could you please clean up the MR?
The 250+ files are because the feature/puzzletron base branch was deleted, I think; I'll try to open another MR against main later this week.
No need for another MR: just merge main into your feature branch and push it, and most of the extra changes should be gone.
This section focuses on applying Model Optimizer's state-of-the-art complementary pruning modes to enable you to search for the best subnet architecture from your provided base model:

1. [Minitron](https://arxiv.org/pdf/2408.11796): A pruning method developed by NVIDIA Research for pruning GPT (and later extended to Mamba, MoE, and Hybrid Transformer Mamba) models in the NVIDIA Megatron-LM (M-LM) or Megatron-Bridge (M-Bridge) framework. It uses activation magnitudes to prune the embedding hidden size; MLP FFN hidden size; transformer attention heads; Mamba heads and head dimension; MoE number of experts, FFN hidden size, and shared expert intermediate size; and the number of layers of the model.
1. [Puzzletron](../puzzletron/README.md): An advanced pruning method by NVIDIA using a Mixed Integer Programming (MIP) based NAS search algorithm.
please add a tutorial for bypass
also make the test pass
Force-pushed from b43af72 to aa4b534
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1111 +/- ##
===========================================
- Coverage 76.74% 66.75% -10.00%
===========================================
Files 476 483 +7
Lines 51307 52304 +997
===========================================
- Hits 39377 34915 -4462
- Misses 11930 17389 +5459
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Force-pushed from aa4b534 to c1a464d
Introduces blockwise local distillation (BLD) as an optional 5th step between pruning and replacement-library construction. When `bypass:` is present in the Hydra config, puzzletron trains alternative transformer block configurations against the teacher via per-block knowledge distillation, then surfaces the trained subblocks to MIP through symlinks under `puzzle_dir/ckpts/<experiment_id>`.

New module `modelopt/torch/puzzletron/bypass_distillation/`:
- training_loop.py: pipeline-parallel KD loop with cosine LR schedule, per-block AdamW + GradScaler, validation, and time/step-based checkpointing (a simplified step is sketched below).
- stitched_model_factory.py: unified factory (FFN/attention/MoE/Mamba/whole-block) driven by `mlp_init_mode` x `keys_to_learn`; composes multiple pruning mixins (experts_removal + kv_heads + ffn_intermediate) when student/teacher configs differ along multiple axes.
- bypass_checkpoint_utils.py: stitched-module state save/load with `latest` symlink and `saving_completed` marker; resume scans only plain `iter-NNNNNN-ckpt` directories.
- bypass_utils.py: experiment_id derivation encoding every override axis (FFN+KV / experts+KV / KV-only) so sweeps cannot collide.

Pipeline integration (puzzletron_nas_plugin.py):
- _progress_step(hydra_cfg, stage) helper + canonical stage order produce coherent `Puzzletron Progress N/T` strings; total grows from 8 to 9 when bypass is configured.
- Skip-if-done caching for every stage (convert / score / prune / bypass / library build).
- Stale-library detection: any `ckpts/*` entry newer than `replacement_library.json` triggers a rebuild so post-bypass weights are picked up automatically.
- build_replacement_library orders bypass-trained subblocks before Truncate-init variants so `drop_duplicates(keep="first")` is deterministic.
- Auto-download HF model when input_hf_model_path is not a local dir.

KV-heads pruning generalized beyond Llama:
- KVHeadsLayerDescriptor registered for GptOss, NemotronH, NemotronH-V2, Qwen3-VL.
- _lm_attrs() probes text_config and language_config so VL configs (Qwen3-VL, Llava, Llama-4) read num_attention_heads and head_dim from the right sub-config.
- _init_attention_biases falls back to state-dict probing when a config (e.g. GptOssConfig) doesn't expose o_proj_bias / attention_bias as top-level attributes.

MIP: new target_num_kv_heads constraint over stats.num_kv_heads for KV-cache-only sweeps.

Tools:
- _copy_auto_map_code_files: copies custom modeling_*.py files alongside config.json so trust-remote-code models reload correctly; identifier-shape guard rejects malformed auto_map entries.
- tools/robust_json.py: JSON encoder for dataclasses, paths, enums, Namespaces, OmegaConf nodes, functions/classes, and timedeltas; used by bypass to serialize resume state.

Multi-mixin and bug fixes in legacy paths:
- child_init._process_single_layer accepts a single mixin or a list, enabling experts_removal + kv_heads + ffn_intermediate stacking.
- update_model_config.override returns the original `item` (not None) when an override is None — preserves the original value instead of clobbering it.

Data utilities:
- create_train_dataloader for the infinite ConstantLengthDataset used by bypass training.
- ConstantLengthDataset falls back to newline-joined message contents when the tokenizer has no chat_template (base models).
- format_stitched_losses filters NaN entries from no-op blocks.
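For orientation, here is a minimal sketch of what a single blockwise-KD step amounts to. It is illustrative only: the real loop in training_loop.py runs pipeline-parallel over SewingKit-stitched modules and adds grad scaling, LR scheduling, validation, and checkpointing, so every name and shape below is an assumption.

import torch

def bypass_kd_step(teacher_block, student_block, hidden_states, optimizer, loss_fn):
    # The frozen teacher block produces the local distillation target.
    with torch.no_grad():
        target = teacher_block(hidden_states)
        target = target[0] if isinstance(target, tuple) else target
    # The student block (the pruned/alternative configuration) learns the same mapping.
    pred = student_block(hidden_states)
    pred = pred[0] if isinstance(pred, tuple) else pred
    loss = loss_fn(pred, target)  # e.g. a normalized-MSE variant
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.detach()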
Tutorial + configs:
- examples/puzzletron/Nemotron-3-Nano-30B-A3B-Base-BF16.md: end-to-end KV-heads-only demonstration on a real 30B-A3B MoE-Mamba teacher, showing bypass closes ~74% of the lm_loss regression gap.
- Per-family configs under configs/nemotron-3-nano-30b-a3b/.
- Generic bypass/defaults.yaml template under llama-3_1-8B.

Tests:
- Unit: normalized-MSE losses; get_distributed_modules_ownership.
- GPU: test_bypass.py parametrizes block-pruning, KV-head compression, multi-config sweep, and checkpoint-contents tests across 9 model families (extracted into PUZZLETRON_FAMILIES).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Force-pushed from c1a464d to 691f4f6
HF datasets uses lazy __getattr__ at the package level (PEP 562), so mypy can't resolve top-level names — `from datasets import DatasetDict` fails with attr-defined. Switch to submodule imports (e.g. `from datasets.dataset_dict import DatasetDict`, `from datasets.load import load_dataset`) which bypass the lazy loader (illustrated below).

Also folds in pre-commit cleanups across the bypass changeset:
- ruff E501/N806/PT006 lint fixes (uppercase N in main.py, line length in PUZZLETRON_FAMILIES + main.py MIP-sweep mprint, parametrize tuple shape in test_bypass_utils.py).
- markdownlint MD040 (fenced-code language tag in tutorial md).
- ruff format auto-applied (PUZZLETRON_FAMILIES table, descriptor imports, etc.).
- yamlfmt auto-applied to bypass YAML configs.
- Drop dead StitchedModelFactoryFn type alias and its cast() call.
- Annotate `descriptor` as ModelDescriptor (not type[ModelDescriptor]) to match the codebase convention used in init_child_from_parent.py (annotated as instance, called with the class — no-op at runtime, silences mypy).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
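A quick illustration of the import change described above. Note the submodule locations are internal to the datasets package and may move between releases, and the follow-up commit below reverts this workaround because CI's mypy resolves the top-level imports fine:

# Lazy, package-level names are resolved at runtime via PEP 562 __getattr__,
# which some mypy setups flag as [attr-defined]:
# from datasets import DatasetDict, load_dataset

# Direct submodule imports are statically resolvable:
from datasets.dataset_dict import DatasetDict
from datasets.load import load_dataset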
…rounds Bypass-distillation review fixes:
- bypass_checkpoint_utils: persist + restore GradScaler state alongside optimizer state so resumed runs with use_grad_scaling=True don't lose the running scale + growth tracker (sketched below).
- stitched_model_factory: raise a clear RuntimeError when world_size exceeds num_hidden_layers (would otherwise crash trailing ranks with a bare `min() arg is an empty sequence`); same condition fires identically on every rank, so no NCCL hang.
- training_loop: actually apply skip_first_batches by advancing the data iterator before the loop (parameter was previously accepted but unused).
- training_loop: fix off-by-one in the max_steps exit condition (was >=, now >) so the final scheduled step actually runs; same fix applied to the in-loop save-when-done branch so the final checkpoint is saved exactly once.
- training_loop: read source_datasets_to_discard from cfg.bypass.data (where the YAML actually nests it) instead of cfg.bypass root, where it always fell back to the empty tuple.
- dataloaders: reject num_workers > 0 in create_train_dataloader with an explicit error since ConstantLengthDataset.__iter__ does not shard via torch.utils.data.get_worker_info(); guard removable once the dataset gains worker-aware iteration.
- dataloaders: branch on isinstance(train_data, datasets.IterableDataset) before passing keep_in_memory=True to .shuffle(); streaming mode (load_from_disk=false) doesn't accept that kwarg.

Revert the datasets submodule-import workarounds (`from datasets.load import load_dataset`, etc.) — CI's mypy resolves top-level `from datasets import X` correctly because it installs the package via uv, so these workarounds were only needed for the local-pre-commit env on a node without datasets installed. Reverting shrinks the MR back to bypass-only files.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
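A self-contained sketch of the GradScaler persistence fix from the first bullet; the checkpoint layout and file name here are invented for illustration:

import torch

param = torch.nn.Parameter(torch.zeros(4))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scaler = torch.cuda.amp.GradScaler()  # holds the running scale and growth tracker

# Save both states side by side, as the fixed bypass_checkpoint_utils now does.
torch.save(
    {"optimizer": optimizer.state_dict(), "grad_scaler": scaler.state_dict()},
    "block_0.optim.pth",
)

# On resume, restoring only the optimizer would silently reset the scaler.
ckpt = torch.load("block_0.optim.pth")
optimizer.load_state_dict(ckpt["optimizer"])
scaler.load_state_dict(ckpt["grad_scaler"])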
- Drop the `gqa_factory_fn` / `moe_factory_fn` backward-compat aliases in stitched_model_factory.py — no remaining call sites or YAML configs reference them after the unified `bypass_factory_fn` migration.
- Apply ruff-format reflow to the two `source_datasets_to_discard=...` kwargs in training_loop.py (collapse the train-side call onto one line; re-indent the val-side multi-line call body).

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
I'm assuming it's not feasible to reuse the boilerplate from the transformers Trainer?
I don't think so; PP via SewingKit makes everything complicated/unique. (I have an implementation via AutoModel; once this MR lands, I can fully compare the two and we can decide whether to merge it. It is a bit complicated, but less so than SewingKit.)
Test additions
--------------
Three new test files plus three new tests appended to test_bypass.py
covering paths the existing 4×9-family happy-path tests didn't reach:
- tests/unit/torch/puzzletron/test_bypass_keys_to_learn.py — pins all 8
branches of `_set_keys_to_learn` (subblock_ffn / subblock_attention /
subblock_mamba / entire_block / list / regex / hybrid block_configs
filter / no-match silent-return). The hybrid Mamba-vs-GQA filter is
silently misroutable on descriptor refactors; this test catches it.
- tests/unit/torch/puzzletron/test_bypass_replacement_library.py —
verifies `_get_last_checkpoint_from_each_experiment` discovers
symlinked bypass + pruning checkpoints, and that the bypass-priority
sort closure orders bypass-rooted paths before Truncate-init ones (a
regression here would silently discard bypass-trained weights).
- tests/gpu/torch/puzzletron/test_bypass_checkpoint_utils.py — save/load
round-trip for stitched-module state, optimizer state, and (new)
GradScaler state. The GradScaler round-trip is the regression test for
the recent CodeRabbit-flagged bug where resumed fp16 + use_grad_scaling
runs lost the running scale + growth tracker. Lives under tests/gpu/
rather than tests/unit/ because the production load path constructs
`torch.device(f"cuda:{rank}")` and torch.load needs a real CUDA device
to deserialize.
- tests/gpu/torch/puzzletron/test_bypass.py — three new tests:
* test_bypass_resume_from_checkpoint: 2-phase train→save→resume on
Llama-3.2-3B, asserts `iter_num` advances past the saved value.
GradScaler resume is covered separately at unit level (above)
because GradScaler.step() is fp16-only and the bypass infra is bf16.
* test_bypass_subblock_modes: parametrized over Llama-3.2-3B (dense
Truncate path) × GPT-OSS-20B (MoE ExpertRemoval + windowed-attention-
with-sinks path) × {subblock_ffn, subblock_attention, entire_block};
diffs start-of-training vs end-of-training stitched-module weights
and asserts only the expected param groups changed (a sketch of this
diff follows the commit message).
* test_bypass_then_build_library: end-to-end smoke — runs bypass, then
`_build_subblocks_df`, asserts the bypass experiment appears in the
resulting subblocks DataFrame's checkpoint-source columns.
Review-driven fixes
-------------------
- bypass_distillation/training_loop.py:
* AutoTokenizer.from_pretrained: pass `trust_remote_code=trust_remote_code`
(was hardcoded `True`); the variable is already derived from the
descriptor a few lines earlier.
* Wrap-around `try/except Exception`: re-raise (was `sys.exit(1)`) so
pytest sees the real exception type instead of a generic SystemExit,
and so distributed runs surface usable tracebacks.
- sewing_kit/utils.py: `normalized_mse_loss` is now re-exported from
`tools.kd_model` instead of redefined; the two implementations were
byte-for-byte identical.
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
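The start-vs-end weight diff in test_bypass_subblock_modes could be implemented roughly like the helper below; the helper name, the mlp. prefix, and the exact comparison are assumptions, not the test's actual code:

import torch

def changed_param_names(before: dict, after: dict) -> set:
    # Names of parameters whose tensors moved between the two snapshots.
    return {name for name, t in after.items() if not torch.equal(before[name], t)}

# With keys_to_learn == "subblock_ffn", only FFN parameters should change, e.g.:
# assert all(name.startswith("mlp.") for name in changed_param_names(w0, w1))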
- test_bypass.py: drop unused `block_name` (F841); ruff-format reflows for `_test_bypass_then_build_library_job`'s `_setup_hydra_cfg_and_pruning` args (one-arg-per-line) and the `Bypass run not discovered` assert message (single-line f-string).
- sewing_kit/utils.py: re-export `normalized_mse_loss` via the `as normalized_mse_loss` form (PEP 484 explicit re-export). The prior `from X import Y  # noqa: F401` form is treated by mypy as a private import, which surfaced as `attr-defined` at the call site in stitched_model_factory.py.
- test_bypass_checkpoint_utils.py: remove blank line after the import block; collapse single-line optimizer ternary; collapse three function signatures whose argument list now fits on one line.
- test_bypass_keys_to_learn.py: collapse single-symbol import and drop the trailing blank line.
- test_bypass_replacement_library.py: drop a blank line and reflow `bypass_real = ...` to use parens for the line continuation.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Two more call sites (in _test_bypass_resume_from_checkpoint_job and _test_bypass_subblock_modes_job) had the old multi-line-arg-pair format that ruff format wants reflowed to one-arg-per-line. The previous CI fix caught only the third call site.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Pure-CPU unit tests covering the bypass-distillation surface that
codecov flagged as uncovered:
* puzzletron_nas_plugin progress helpers
* dataloaders (split auto-detect, num_workers guard, pad helper,
Printer fake accelerator, load_*_fn delegators)
* launch_bypass_distillation sweep dispatcher
* bypass_checkpoint_utils (find_latest_run_dir, _save_local_file,
_save_local_state, save_bypass_checkpoint orchestration)
* stitched_model_factory _get_all_non_persistent_buffers_set
* sewing_kit InputArgs, ActivityContext, Needle validation
Plus a GPU integration test pinning resume-from-latest end-to-end
(test_bypass_resume.py).
Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W)
so the schedule reaches min_lr exactly at step==D instead of relying
on the post-decay clamp at D+1 to mask a one-step plateau at base_lr
right after warmup.
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
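With W the warmup step count and D the decay horizon, the fixed schedule can be sketched like this (argument names are illustrative; the real _get_lr may differ):

import math

def get_lr(step: int, base_lr: float, min_lr: float, warmup: int, decay: int) -> float:
    if step < warmup:
        return base_lr * step / warmup  # linear warmup
    if step > decay:
        return min_lr  # post-decay clamp
    decay_ratio = (step - warmup) / (decay - warmup)  # hits 1.0 exactly at step == decay
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 1 -> 0 over the decay window
    return min_lr + coeff * (base_lr - min_lr)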
* SIM117: combine nested `with` statements where semantically equivalent (test_max_depth, test_no_duplicates_*, test_stack_unwinds); see the example after this commit message. test_reversed_pushes_to_front_and_pops_from_front keeps the nested form because the intermediate assertion between exits is the test's actual point.
* ruff format: drop blank line after lone import, merge two `from ... import` blocks for sewing_kit.core.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
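For reference, the SIM117 rewrite is just the combined with form, which is only equivalent when nothing meaningful happens between the two enters/exits:

from contextlib import nullcontext

# Nested form (flagged by SIM117 when the inner body starts immediately):
with nullcontext() as a:
    with nullcontext() as b:
        pass

# Combined form, semantically equivalent here:
with nullcontext() as a, nullcontext() as b:
    pass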
* RUF005: noqa the `InputArgs(1) + [2]` line — the auto-fix `[*InputArgs(1), 2]` would replace the operator call we are testing (this test pins that `__add__` rejects non-InputArgs via its internal assert).
* ruff format: drop spurious blank lines after import blocks, collapse a function signature back to one line.
* PLC0415 (in-function imports): hoist the two `import os` calls inside `test_save_bypass_checkpoint_*` to the file's import block.
* Move the late `is_submodule_of` / `is_submodule_or_same` imports in `test_sewing_kit_activity_context.py` to the top of the file so the `# noqa: E402` marker isn't needed.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Four test signatures in test_bypass_checkpoint_utils.py fit on a single line within the 100-char limit; ruff format wants them collapsed. The fifth signature in the same file (test_save_bypass_checkpoint_master_only_skips_symlink_on_non_master) has three args and stays multi-line.

Verified with `awk 'length > 100'` and `grep -nE '^def test.*\(\s*$'` across all new test files — no other lint nits should remain.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
@AAnoosheh ready for review
/ok to test 4ea3262
/claude review
    ExternalTarget().input(name=input_descriptor),
    student_submodule_target.input(name=submodule_input_descriptor),
)
.stitch(
    ExternalTarget().output(
        name=output_descriptor,
        adapter=lambda v: InputArgs(target=v)
        if not isinstance(v, tuple)
        else InputArgs(target=v[0]),
    ),
    student_stitched_module_loss_target.input(),
)
.stitch(
    student_submodule_target.output(
        name=submodule_output_descriptor,
        adapter=lambda v: InputArgs(input=v)
        if not isinstance(v, tuple)
        else InputArgs(input=v[0]),
[CRITICAL Algorithm] These lambda closures capture the loop variable v by reference at definition time — but the lambdas themselves are created fresh each iteration so they're safe. However, the adapter lambdas that call InputArgs(target=v) vs InputArgs(input=v) both receive the same variable name v. If the stitching framework ever evaluates both adapters in the same scope, they'd alias. This is likely fine as they're stored on different targets, but worth noting the naming collision for readability.
More importantly: the lambda v: InputArgs(target=v) if not isinstance(v, tuple) else InputArgs(target=v[0]) pattern assumes tuple outputs always have the hidden state as element 0. This is true for standard HuggingFace transformer blocks (which return (hidden_states, ...)) but could break for custom blocks that return tuples with a different layout. Consider adding a defensive shape/dtype assertion on v[0] during the dummy-run phase to catch this.
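A hedged sketch of the defensive check suggested above; hidden_state_adapter is an invented name, the 3-D rank check is an assumption about standard HF blocks, and the real adapters would still wrap the result in InputArgs:

import torch

def hidden_state_adapter(v):
    # Guard the HF convention that tuple-returning blocks carry
    # hidden_states at index 0; a mismatched layout fails loudly here
    # (ideally during the dummy-run phase) instead of corrupting the loss.
    if isinstance(v, tuple):
        assert torch.is_tensor(v[0]) and v[0].dim() == 3, (
            "expected (batch, seq, hidden) hidden_states at tuple index 0"
        )
        v = v[0]
    return v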
  def override(item, item_overrides):
      if item_overrides is None:
-         return item_overrides
+         return item  # None override means "keep original value"
[IMPORTANT Compatibility] Good fix — the previous code returned None literally (destroying the original item), which silently corrupted model configs when a layer had no override. This is a bugfix, but note it changes observable behavior: existing scripts that relied on None being passed through to update_model_config (perhaps incorrectly) will now silently keep the original value instead of producing a TypeCheckError or None in the block config. This is almost certainly the correct behavior, but document it in the PR as a behavioral change for existing users.
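Concretely, with a stand-in mirroring the quoted fix (the non-None branch is assumed, since it is not shown in the diff):

def override(item, item_overrides):
    if item_overrides is None:
        return item  # None override means "keep original value"
    return item_overrides

assert override(14336, None) == 14336  # previously returned None, clobbering the value
assert override(14336, 7168) == 7168   # explicit overrides still win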
Code Review Summary
CRITICAL: 1 | IMPORTANT: 3 | SUGGESTION: 3
Critical Findings
Important Findings
Suggestions
Overall Assessment
Risk level: Medium. The PR adds a substantial new training subsystem (bypass distillation) to the puzzletron pipeline. The core architecture is well-designed — pipeline-parallel per-block KD using the existing sewing-kit abstraction, clean checkpoint management, and comprehensive test coverage across all model families. The critical finding (loss denominator floor) is the most impactful — it can cause training divergence on models with sparse/zero activations in some blocks. The compatibility findings are low-probability impacts (the string-key path and …).
The test suite is thorough: unit tests for losses, checkpoint utils, and LR scheduling, plus GPU integration tests parametrized across 9 model families covering FFN pruning, KV head compression, multi-config sweeps, and checkpoint validation.
Drop the per-block weight save from `_save_local_state` — the same
parameters were already on disk in the top-level HF checkpoint that
`save_bypass_checkpoint` writes via `save_checkpoint(model, ...)`.
The stitched/ directory now carries only optimizer + grad_scaler
state; weights round-trip through the HF format. Resume now routes
through `load_and_shard_model` (same path as `init_checkpoint_path`)
so weight loading has a single entry point. Per-iter checkpoint
disk footprint roughly halves.
Other Claude review fixes:
* Use modelopt.torch.utils.robust_json (canonical) and delete the
duplicate puzzletron-local copy.
* Remove dead num_trainable_params block in bypass_factory_fn (unused
and the sum was a bool count, not a numel sum).
* GptOssModelDescriptor: keep "expert_removal" as a deprecation alias
for the new "experts_removal" key so existing string-API callers
don't break.
* Move time_start out of module scope into train() — at module level
it became stale relative to actual training start, firing the first
time-based save immediately.
* Batch the per-block loss GPU->CPU copy into a single sync after the
loop (was N sync points per training step).
* puzzletron_nas_plugin staleness check: probe checkpoint config.json
mtime instead of resolved-symlink directory mtime; handle dangling
symlinks via try/except.
* Fix off-by-one in _get_lr cosine: decay_ratio = (step - W) / (D - W)
so the schedule reaches min_lr exactly at step==D.
* batched_normalized_mse_loss: clamp the per-vector denominator to a
floor of epsilon so an all-zero target slice doesn't explode the
loss. Slightly diverges from the original Puzzle implementation
(documented in the docstring; a sketch follows after this commit message).
* Add HF-convention comment in stitched_model_factory: tuple-returning
blocks always have hidden_states at index 0.
Tests updated to match: stitched/{block}.state_dict.pth assertion
flipped to assert-not-written; new regression tests for the LR
schedule endpoint and the zero-target loss path.
Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
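A minimal sketch of the clamped per-vector normalized MSE from the bullet above; the signature, epsilon value, and reduction are assumptions, not the shipped batched_normalized_mse_loss:

import torch

def batched_normalized_mse_loss(pred, target, eps=1e-8):
    # pred/target: (batch, seq, hidden) block activations.
    err = (pred - target).float().pow(2).mean(dim=-1)  # per-vector squared error
    denom = target.float().pow(2).mean(dim=-1).clamp_min(eps)  # floor keeps all-zero targets finite
    return (err / denom).mean()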
The stacked-loss expression fits on a single line within the 100-char limit; ruff format unwrapped it.

Signed-off-by: Sepehr Sameni <ssameni@nvidia.com>
Bypass distillation trains alternative transformer block configurations using per-block knowledge distillation from the teacher model, producing a library of better "puzzle pieces" for the MIP solver. It is most beneficial for aggressive FFN pruning (≤ 1/8 of teacher width) or significant KV head compression.
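The stage is driven entirely by the Hydra config; invoked standalone, the entry point referenced in the review threads above looks roughly like this (loading of hydra_cfg is elided, and the call pattern is an assumption):

from modelopt.torch.puzzletron.bypass_distillation.training_loop import (
    launch_bypass_distillation,
)

# hydra_cfg is the composed Hydra/OmegaConf config whose `bypass:` section
# enables this stage; single runs and sweeps are both dispatched from here.
launch_bypass_distillation(hydra_cfg)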
Changes:
Summary by CodeRabbit
Release Notes
New Features
Tests
Documentation