Enable active-param and memory based Minitron pruning constraint #1377
kevalmorabia97 wants to merge 2 commits into main from
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough

Adds combinable NAS pruning targets (…)

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant CLI as prune_minitron.py
    participant Stats as megatron_model_stats
    participant Pruner as mcore_minitron.Searcher
    participant Model as Teacher / Candidate Builder
    participant Output as Pruned Model / Report
    User->>CLI: run prune command (targets or --prune_export_config)
    CLI->>Stats: compute teacher metrics (params/active/memory) as needed
    CLI->>Pruner: pass metric ceilings + batch_size/seq_length (if NAS)
    Pruner->>Stats: compute candidate metrics analytically for grid
    Pruner->>Pruner: filter candidates by all ceilings, rank by primary metric
    Pruner->>Pruner: validate top-k via score_func
    Pruner->>Model: build/export chosen pruned model (or use export_config)
    Model->>Output: emit pruned model artifact
    Pruner->>Output: render search results and stats
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here.
Actionable comments posted: 4
🧹 Nitpick comments (1)
examples/megatron_bridge/prune_minitron.py (1)
151-158: ⚡ Quick win: Expose the memory-budget batch size explicitly.

`memory_mb` in the searcher depends on both `seq_length` and `batch_size`, but this CLI only forwards `seq_length` and leaves `batch_size` at the default of 1. That makes the budget optimistic for any real generation batch greater than 1.

Suggested fix

```diff
+ parser.add_argument(
+     "--memory_batch_size",
+     type=int,
+     default=1,
+     help="Batch size to use when evaluating --prune_target_memory_mb.",
+ )
  ...
  pruning_config["top_k"] = args.top_k
  pruning_config["seq_length"] = args.seq_length
+ pruning_config["batch_size"] = args.memory_batch_size
```

Also applies to: 368-368
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/megatron_bridge/prune_minitron.py` around lines 151 - 158, The CLI's prune memory budget uses --seq_length but assumes batch_size=1, making memory_mb optimistic for real batches; add a new argument (e.g., --prune_memory_batch_size or --prune_batch_size) to the parser near the existing parser.add_argument("--prune_target_memory_mb", ...) with type=int and a default of 1, update its help text to note it affects the memory budget calculation, and then pass this value into the memory calculation used by the searcher (the code that computes memory_mb or calls searcher) so memory_mb accounts for seq_length * batch_size when invoking the NAS searcher. Ensure the new flag is also used in the other occurrence referenced (around the other prune call) so both places compute memory consistently.
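To make the batch-size sensitivity behind this comment concrete, here is a small, hypothetical sketch (not code from this PR) of how the KV-cache portion of a memory budget grows linearly with batch size; the function name and all dimension values are illustrative assumptions.

```python
def approx_kv_cache_mb(
    batch_size: int,
    seq_length: int,
    num_layers: int,
    num_query_groups: int,
    kv_channels: int,
    bytes_per_elem: int = 2,  # BF16
) -> float:
    """Rough KV-cache size: one K and one V tensor per attention layer."""
    elems = 2 * batch_size * seq_length * num_layers * num_query_groups * kv_channels
    return elems * bytes_per_elem / (1024 * 1024)


# Doubling batch_size doubles the KV-cache term, which is why evaluating
# memory_mb at the default batch_size=1 is optimistic for larger batches.
print(approx_kv_cache_mb(1, 8192, 32, 8, 128))  # ~1024 MB
print(approx_kv_cache_mb(8, 8192, 32, 8, 128))  # ~8192 MB
```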
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/megatron_bridge/README.md`:
- Around line 81-88: The documentation example is inconsistent: the prose says
"3B active params" but the command uses --prune_target_active_params 8e9 and the
output path contains "Pruned-8B-Active"; update them to match. Edit README.md so
that either the prose reads "8B active params" to match the prune_minitron.py
invocation and output_hf_path, or change the command flag value to 3e9 and the
output_hf_path to "Pruned-3B-Active" so all three references (the prose,
--prune_target_active_params flag, and the output_hf_path) consistently reflect
the same active-params target.
In `@modelopt/torch/nas/megatron_model_stats.py`:
- Around line 201-207: parse_main_layer_chars currently only strips PP boundary
'|' but must also skip GDN marker 'G' so hybrid patterns containing 'G' don't
propagate into mcore_param_count and trigger the unsupported-char error; update
parse_main_layer_chars to ignore 'G' (in addition to '|') when producing the
per-layer char list (ensure the function returns one char per actual layer by
filtering out '|' and 'G').
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 430-439: active_metric_keys is currently built from iterating the
unordered frozenset _METRIC_CONSTRAINTS which can select a different primary_key
across ranks; change construction of active_metric_keys to preserve the
constraints insertion order by iterating self.constraints and selecting keys
present in _METRIC_CONSTRAINTS (e.g., active_metric_keys = [k for k in
self.constraints if k in _METRIC_CONSTRAINTS]), ensure you handle the empty-case
before accessing primary_key, and keep primary_key deterministic. In
_compute_candidate_metrics(), replace direct access to
model.hybrid_layer_pattern with the same compatibility fallback used in _prune
(use getattr(model, "hybrid_layer_pattern", None) or
self.hybrid_override_pattern) to avoid AttributeError on older Megatron-Core
versions. Ensure both changes reference the existing symbols active_metric_keys,
primary_key, _METRIC_CONSTRAINTS, _compute_candidate_metrics,
model.hybrid_layer_pattern, hybrid_override_pattern and _prune so the fixes
align with the surrounding logic.
- Around line 650-653: The code directly accesses model.hybrid_layer_pattern for
MambaModel which can raise AttributeError on older MCore versions that only
expose hybrid_override_pattern; update the access to use the same fallback logic
as earlier (check for getattr(model, "hybrid_override_pattern", None) first,
then fallback to getattr(model, "hybrid_layer_pattern", None)) when computing
hybrid_layer_pattern for MambaModel so metric-based NAS candidate evaluation
won't fail on pre-0.17 MCore; modify the block handling MambaModel to use this
fallback lookup.
---
Nitpick comments:
In `@examples/megatron_bridge/prune_minitron.py`:
- Around line 151-158: The CLI's prune memory budget uses --seq_length but
assumes batch_size=1, making memory_mb optimistic for real batches; add a new
argument (e.g., --prune_memory_batch_size or --prune_batch_size) to the parser
near the existing parser.add_argument("--prune_target_memory_mb", ...) with
type=int and a default of 1, update its help text to note it affects the memory
budget calculation, and then pass this value into the memory calculation used by
the searcher (the code that computes memory_mb or calls searcher) so memory_mb
accounts for seq_length * batch_size when invoking the NAS searcher. Ensure the
new flag is also used in the other occurrence referenced (around the other prune
call) so both places compute memory consistently.
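As a standalone illustration of the deterministic-ordering fix requested for `active_metric_keys` in the `mcore_minitron.py` comment above, here is a hedged sketch; the symbol names follow the review prompt, and the surrounding searcher code is not reproduced.

```python
# Stand-in for the unordered set named in the review comment.
_METRIC_CONSTRAINTS = frozenset({"params", "active_params", "memory_mb"})


def pick_metric_keys(constraints: dict[str, float]) -> tuple[list[str], str]:
    """Select metric keys in the caller's insertion order.

    Iterating the constraints dict (insertion-ordered) instead of the
    frozenset guarantees every rank derives the same primary key.
    """
    active_metric_keys = [k for k in constraints if k in _METRIC_CONSTRAINTS]
    if not active_metric_keys:
        raise ValueError("No metric-based constraints were provided.")
    return active_metric_keys, active_metric_keys[0]


keys, primary = pick_metric_keys({"active_params": 3e9, "params": 28e9})
assert primary == "active_params"  # identical result on every rank
```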
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8c94f5fa-6801-4848-8f93-59cedf8ae6b2
📒 Files selected for processing (7)
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- modelopt/torch/nas/megatron_model_stats.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/nas/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
Codecov Report
❌ Patch coverage is …
Additional details and impacted files

```
@@ Coverage Diff @@
## main #1377 +/- ##
===========================================
+ Coverage 66.36% 76.78% +10.41%
===========================================
  Files 471 472 +1
  Lines 50510 50824 +314
===========================================
+ Hits 33522 39025 +5503
+ Misses 16988 11799 -5189
```
Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
caae65c to
8035a26
Compare
8035a26 to
11805a4
Compare
Actionable comments posted: 2
♻️ Duplicate comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
691-694: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win: Add fallback for `hybrid_override_pattern` to ensure compatibility with older MCore versions.

Line 694 directly accesses `model.hybrid_layer_pattern` without the compatibility fallback used elsewhere in this file (lines 373-377). On MCore versions that only expose `hybrid_override_pattern` (pre-0.17), this will raise an `AttributeError` when computing metrics for hybrid/Mamba models.

🔧 Suggested fix

```diff
  # Get hybrid layer pattern for MambaModel (None for pure GPT)
  hybrid_layer_pattern: str | None = None
  if isinstance(model, MambaModel):
-     hybrid_layer_pattern = model.hybrid_layer_pattern
+     hybrid_key = (
+         "hybrid_override_pattern"
+         if hasattr(model, "hybrid_override_pattern")
+         else "hybrid_layer_pattern"
+     )
+     hybrid_layer_pattern = getattr(model, hybrid_key)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 691 - 694, The code assumes model.hybrid_layer_pattern exists for MambaModel but older MCore exposes hybrid_override_pattern; update the MambaModel branch in compute metrics to use a compatibility fallback (e.g., set hybrid_layer_pattern = getattr(model, "hybrid_layer_pattern", None) or fallback to getattr(model, "hybrid_override_pattern", None)) so it mirrors the earlier compatibility logic used elsewhere in this file; ensure you reference MambaModel and the model variable when adding the fallback so AttributeError is avoided on pre-0.17 MCore.
🧹 Nitpick comments (1)
modelopt/torch/nas/plugins/megatron_model_stats.py (1)
41-44: Move optional dependency imports inside function scope using lazy loading.

Lines 41-44 import `megatron` and `rich` at module level. Per coding guidelines, optional dependencies must be loaded lazily via `import_plugin()` at the point of use, not hard-imported at module level. Even though the module itself is wrapped with `import_plugin()` in `__init__.py`, the imports within the file still execute eagerly, violating the lazy-loading convention.

Defer these imports to the functions that require them (e.g., `print_mcore_model_stats`, `_mamba_layer_params`) using `import_plugin()` context managers at call sites.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/nas/plugins/megatron_model_stats.py` around lines 41 - 44, Remove the module-level imports of megatron and rich (MambaModel, Console, Panel, Table) and instead perform lazy imports inside the functions that use them (e.g., print_mcore_model_stats and _mamba_layer_params) by calling import_plugin() as a context manager at the beginning of each function and importing megatron.core.models.mamba.mamba_model.MambaModel and rich.console.Console / rich.panel.Panel / rich.table.Table there; ensure no other code in the module relies on those names at import time and update any references to use the locally imported symbols within those functions.
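A minimal sketch of the lazy-loading pattern the reviewer asks for, assuming only that the optional dependency (`rich` here) is imported inside the function rather than at module scope; the repo's `import_plugin()` wrapper referenced in the comment would normally surround these imports.

```python
def print_mcore_model_stats(model) -> None:
    """Sketch: resolve optional dependencies at call time, not import time."""
    # In the actual module these would sit inside the repo's import_plugin()
    # context manager; plain local imports stand in for that pattern here.
    from rich.console import Console
    from rich.table import Table

    table = Table(title="Model stats")
    table.add_column("metric")
    table.add_column("value")
    table.add_row("total params", f"{sum(p.numel() for p in model.parameters()):,}")
    Console().print(table)
```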
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/nas/plugins/megatron_model_stats.py`:
- Around line 352-353: The code slices
parse_main_layer_chars(hybrid_layer_pattern) to num_layers (assigned to
layer_chars) without verifying the parsed pattern length; update the logic in
the places using parse_main_layer_chars (the assignment to layer_chars) to first
compute parsed = parse_main_layer_chars(hybrid_layer_pattern), check that
len(parsed) >= num_layers (or exactly equals num_layers as appropriate), and if
not raise a clear exception or log an error and fail fast so
hybrid_layer_pattern mismatches aren’t silently truncated; reference
parse_main_layer_chars, hybrid_layer_pattern, num_layers, and layer_chars when
adding this validation and error handling (apply the same change at the other
call site that slices to num_layers).
- Around line 488-501: kv_per_layer and kv_bytes are computed unconditionally
which can crash if kv_channels or num_query_groups are missing and Mamba uses
0/1 fallbacks that silently undercount; fix by gating the KV calculation behind
n_attn > 0 and validating required dimensions (kv_channels, num_query_groups)
before computing kv_per_layer/kv_bytes (raise a clear ValueError if missing),
and for Mamba ensure when n_mamba > 0 you validate mamba_num_heads,
mamba_head_dim, mamba_state_dim (and optionally mamba_num_groups) are not None
rather than defaulting to 0/1; only compute mamba_bytes when the required Mamba
dims are present, otherwise raise an error so the estimator fails fast instead
of under-reporting.
---
Duplicate comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 691-694: The code assumes model.hybrid_layer_pattern exists for
MambaModel but older MCore exposes hybrid_override_pattern; update the
MambaModel branch in compute metrics to use a compatibility fallback (e.g., set
hybrid_layer_pattern = getattr(model, "hybrid_layer_pattern", None) or fallback
to getattr(model, "hybrid_override_pattern", None)) so it mirrors the earlier
compatibility logic used elsewhere in this file; ensure you reference MambaModel
and the model variable when adding the fallback so AttributeError is avoided on
pre-0.17 MCore.
---
Nitpick comments:
In `@modelopt/torch/nas/plugins/megatron_model_stats.py`:
- Around line 41-44: Remove the module-level imports of megatron and rich
(MambaModel, Console, Panel, Table) and instead perform lazy imports inside the
functions that use them (e.g., print_mcore_model_stats and _mamba_layer_params)
by calling import_plugin() as a context manager at the beginning of each
function and importing megatron.core.models.mamba.mamba_model.MambaModel and
rich.console.Console / rich.panel.Panel / rich.table.Table there; ensure no
other code in the module relies on those names at import time and update any
references to use the locally imported symbols within those functions.
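For the KV/Mamba sizing finding above, a hedged sketch of the gate-and-validate idea using hypothetical local names (`n_attn`, `kv_channels`, `num_query_groups`) rather than the estimator's actual variables.

```python
def kv_cache_bytes(
    n_attn: int,
    batch_size: int,
    seq_length: int,
    kv_channels: int | None,
    num_query_groups: int | None,
    bytes_per_elem: int = 2,  # BF16
) -> int:
    """Compute the KV term only when attention layers exist, and fail fast on
    missing dimensions instead of silently under-counting."""
    if n_attn == 0:
        return 0
    if kv_channels is None or num_query_groups is None:
        raise ValueError("kv_channels and num_query_groups are required when n_attn > 0")
    per_layer = 2 * batch_size * seq_length * num_query_groups * kv_channels * bytes_per_elem
    return n_attn * per_layer
```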
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 5e1794a8-d42b-4444-abd1-542dd79d975d
📒 Files selected for processing (11)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/__init__.py
- modelopt/torch/nas/plugins/megatron_model_stats.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (2)
- examples/pruning/README.md
- CHANGELOG.rst
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/examples/megatron_bridge/test_prune_minitron.py`:
- Around line 53-56: The test calls mcore_param_count(pruned_model) with a
HuggingFace model but mcore_param_count requires the vocab_size when given an HF
model; update the call to pass the model's vocab size (use
pruned_model.config.vocab_size) so call mcore_param_count(pruned_model,
vocab_size=pruned_model.config.vocab_size) (mirror the same fix applied for the
teacher model) to ensure correct parameter counting for
AutoModelForCausalLM.from_pretrained(pruned_model_path).
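A hedged sketch of the corrected call from the comment above; it assumes `mcore_param_count` accepts a `vocab_size` keyword for HF models (as the reviewer states), and both the import path and checkpoint path are illustrative.

```python
from transformers import AutoModelForCausalLM

from modelopt.torch.nas.plugins.megatron_model_stats import mcore_param_count

# Hypothetical checkpoint directory, mirroring what the test fixture writes out.
pruned_model = AutoModelForCausalLM.from_pretrained("pruned_model_dir")

# Pass the HF config's vocab size explicitly, same as the teacher-model call;
# the exact shape of the returned counts follows the library.
counts = mcore_param_count(pruned_model, vocab_size=pruned_model.config.vocab_size)
print(counts)
```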
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 8432c79c-56ed-4bff-916a-f3d84fafcd02
📒 Files selected for processing (11)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/__init__.py
- modelopt/torch/nas/plugins/megatron_model_stats.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (4)
- examples/pruning/README.md
- modelopt/torch/nas/plugins/megatron_model_stats.py
- examples/megatron_bridge/README.md
- CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (2)
- modelopt/torch/nas/plugins/__init__.py
- examples/megatron_bridge/prune_minitron.py
5c832dd to
90ce94e
Compare
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/_test_utils/torch/transformers_models.py (1)
111-120: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win: The new tuple return type is unreachable.

`create_tiny_qwen3_moe_dir()` still always returns `qwen3_moe_dir`, so the widened annotation does not match the implementation. Either add the missing `return_model` path or keep the return type as `Path` until that behavior exists.

♻️ Suggested fix

```diff
  def create_tiny_qwen3_moe_dir(
-     tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs
+     tmp_path: Path | str, with_tokenizer: bool = False, return_model: bool = False, **config_kwargs
  ) -> Path | tuple[Path, PreTrainedModel]:
      qwen3_moe_dir = Path(tmp_path) / "tiny_qwen3_moe"
      if with_tokenizer:
-         tokenizer = tokenizer = get_tiny_tokenizer()
+         tokenizer = get_tiny_tokenizer()
          tokenizer.save_pretrained(qwen3_moe_dir)
          config_kwargs["vocab_size"] = tokenizer.vocab_size
-     get_tiny_qwen3_moe(**config_kwargs).save_pretrained(qwen3_moe_dir)
-     return qwen3_moe_dir
+     tiny_qwen3_moe = get_tiny_qwen3_moe(**config_kwargs)
+     tiny_qwen3_moe.save_pretrained(qwen3_moe_dir)
+     if return_model:
+         return qwen3_moe_dir, tiny_qwen3_moe
+     return qwen3_moe_dir
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/_test_utils/torch/transformers_models.py` around lines 111 - 120, The annotated return type claims the function can return Path or tuple[Path, PreTrainedModel], but the implementation always returns only Path; fix by either reverting the annotation to Path or by adding an explicit return_model: bool = False parameter and returning the model when requested: call model = get_tiny_qwen3_moe(**config_kwargs), model.save_pretrained(qwen3_moe_dir), and if return_model is True return (qwen3_moe_dir, model) else return qwen3_moe_dir; update the function signature and return type to match this behavior and adjust references to tokenizer variable duplication (tokenizer = tokenizer) if present.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/megatron_bridge/prune_minitron.py`:
- Around line 151-158: prune_target_memory_mb is currently evaluated using
pruning_config["batch_size"] wired to --calib_mbs (fixed to 1), so add a
dedicated CLI flag (e.g., --prune_batch_size or --prune_target_batch_size) in
the same argparse block where parser.add_argument("--prune_target_memory_mb",
...) is defined and use that new argument to set pruning_config["batch_size"]
instead of using --calib_mbs; keep --calib_mbs for calibration work but ensure
the pruning path reads args.prune_batch_size (with a sensible default, e.g., 1)
and update any place that sets pruning_config["batch_size"] (and related
mentions around lines 363-365) to reference this new flag so memory estimates
reflect the intended deployment batch size.
In `@tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py`:
- Around line 234-238: The current use of
DistributedProcessGroup.get_dist_syncd_obj with the lambda that flattens
all_rank_configs into width_ss_config can silently mask per-rank differences
(last-write-wins); before applying the flattening op, collect all_rank_configs
from get_dist_syncd_obj (or call the synchronization helper without the op),
iterate over each key across the per-rank dicts returned for local_config and
validate that every rank reports an identical value for that key, and if any key
has conflicting values raise/assert/log a clear error (or fail the test)
indicating which key and differing values were observed; then only call the
flattening op (the lambda used to build width_ss_config) after this agreement
check (references: DistributedProcessGroup.get_dist_syncd_obj, local_config,
width_ss_config, get_pipeline_model_parallel_group).
- Line 156: Replace the hard assert that enforces "size <= 4" with a pytest.skip
guard so tests don't hard-fail on larger pipeline-parallel sizes; specifically,
in the helper/test that currently does `assert size <= 4,
"test_param_num_dynamic_matches_formula only configured for upto 4 GPUs"`, check
`if size > 4:` and call `pytest.skip("...")` with a similar explanatory message
referencing the limitation, importing pytest if necessary so the test is skipped
rather than failing.
---
Outside diff comments:
In `@tests/_test_utils/torch/transformers_models.py`:
- Around line 111-120: The annotated return type claims the function can return
Path or tuple[Path, PreTrainedModel], but the implementation always returns only
Path; fix by either reverting the annotation to Path or by adding an explicit
return_model: bool = False parameter and returning the model when requested:
call model = get_tiny_qwen3_moe(**config_kwargs),
model.save_pretrained(qwen3_moe_dir), and if return_model is True return
(qwen3_moe_dir, model) else return qwen3_moe_dir; update the function signature
and return type to match this behavior and adjust references to tokenizer
variable duplication (tokenizer = tokenizer) if present.
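For the distributed-sync finding above, a hedged sketch of the per-rank agreement check the reviewer suggests, written against a plain list of per-rank config dicts instead of the actual `get_dist_syncd_obj` call.

```python
def assert_ranks_agree(all_rank_configs: list[dict]) -> dict:
    """Verify every rank reports identical values before flattening.

    A last-write-wins merge hides per-rank divergence; raise instead and
    name the offending key and the conflicting values.
    """
    merged: dict = {}
    for rank, config in enumerate(all_rank_configs):
        for key, value in config.items():
            if key in merged and merged[key] != value:
                raise AssertionError(
                    f"Rank {rank} reports {key}={value}, earlier ranks reported {merged[key]}"
                )
            merged[key] = value
    return merged


# Identical configs merge cleanly; any mismatch raises with the key named.
assert_ranks_agree([{"hidden_size": 2048}, {"hidden_size": 2048}])
```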
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 4d72e821-1f4b-4fcf-8aee-8442f5881c50
📒 Files selected for processing (12)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/__init__.py
- modelopt/torch/nas/plugins/megatron_model_stats.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (4)
- examples/pruning/README.md
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- modelopt/torch/nas/plugins/megatron_model_stats.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
- modelopt/torch/prune/plugins/mcore_minitron.py
90ce94e to
f061c19
Compare
Actionable comments posted: 1
🧹 Nitpick comments (1)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
507-509: ⚡ Quick win: Use all active metrics to break ties before slicing `top_k`.

Right now the cache is sorted only by `primary_key`, so candidates that tie on that metric keep the raw product-order from `_generate_search_space_combos()`. With combined constraints, that can arbitrarily push out subnets that are closer to the secondary ceilings before `top_k` validation ever runs.

Proposed change

```diff
- self.all_candidates_per_constraint[constraints_cache_key] = sorted(
-     selected, key=lambda x: x.metrics[primary_key], reverse=True
- )
+ self.all_candidates_per_constraint[constraints_cache_key] = sorted(
+     selected,
+     key=lambda x: tuple(x.metrics[k] for k in active_metric_keys),
+     reverse=True,
+ )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/prune/plugins/mcore_minitron.py` around lines 507 - 509, The cache currently sorts selected only by primary_key which leaves ties ordered by the generator's product order and can evict better candidates on secondary metrics; change the sort used when assigning self.all_candidates_per_constraint[constraints_cache_key] to order by a tuple of all active metric keys (primary_key first, then the other metric names in the same priority order) so ties are broken by subsequent metrics before any top_k slicing. Locate the assignment to self.all_candidates_per_constraint[constraints_cache_key] and replace the key=lambda x: x.metrics[primary_key] with a key that builds a tuple like tuple(x.metrics[m] for m in metrics_list) (use reverse=True if higher is better) where metrics_list contains primary_key followed by the remaining active metrics; keep the rest of the flow (e.g., _generate_search_space_combos() and later top_k validation) unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py`:
- Around line 319-325: The helper _assert_top_k_candidates currently uses exact
equality for metrics (actual.metrics == metrics) which causes flaky tests for
fractional float fields like memory_mb; update _assert_top_k_candidates to
compare numeric/float metric values using pytest.approx (import pytest) rather
than exact equality—e.g., when comparing actual.metrics to expected metrics,
iterate metric keys and for values that are floats use assert actual_value ==
pytest.approx(expected_value) and for non-floats keep exact equality; keep the
existing checks for ss_config and score but replace any direct float comparisons
(including score if it can be fractional) with pytest.approx to stabilize the
test.
---
Nitpick comments:
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 507-509: The cache currently sorts selected only by primary_key
which leaves ties ordered by the generator's product order and can evict better
candidates on secondary metrics; change the sort used when assigning
self.all_candidates_per_constraint[constraints_cache_key] to order by a tuple of
all active metric keys (primary_key first, then the other metric names in the
same priority order) so ties are broken by subsequent metrics before any top_k
slicing. Locate the assignment to
self.all_candidates_per_constraint[constraints_cache_key] and replace the
key=lambda x: x.metrics[primary_key] with a key that builds a tuple like
tuple(x.metrics[m] for m in metrics_list) (use reverse=True if higher is better)
where metrics_list contains primary_key followed by the remaining active
metrics; keep the rest of the flow (e.g., _generate_search_space_combos() and
later top_k validation) unchanged.
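A hedged sketch of the approximate-comparison helper suggested above; the attribute names (`metrics`, `score`) follow the review comment and may not match the real candidate objects exactly.

```python
import pytest


def assert_candidate_matches(actual, expected_metrics: dict, expected_score: float) -> None:
    """Compare float-valued metrics with a tolerance instead of exact equality."""
    assert actual.metrics.keys() == expected_metrics.keys()
    for key, expected in expected_metrics.items():
        value = actual.metrics[key]
        if isinstance(expected, float):
            assert value == pytest.approx(expected)  # e.g. fractional memory_mb
        else:
            assert value == expected
    assert actual.score == pytest.approx(expected_score)
```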
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: c55cb2a6-fe6f-49b5-9fab-31db8a3753f4
📒 Files selected for processing (12)
- CHANGELOG.rst
- examples/megatron_bridge/README.md
- examples/megatron_bridge/prune_minitron.py
- examples/pruning/README.md
- modelopt/torch/nas/plugins/__init__.py
- modelopt/torch/nas/plugins/megatron_model_stats.py
- modelopt/torch/prune/plugins/mcore_minitron.py
- tests/_test_utils/torch/transformers_models.py
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
- tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
✅ Files skipped from review due to trivial changes (3)
- CHANGELOG.rst
- modelopt/torch/nas/plugins/megatron_model_stats.py
- examples/megatron_bridge/README.md
🚧 Files skipped from review as they are similar to previous changes (6)
- tests/examples/megatron_bridge/test_prune_minitron.py
- tests/_test_utils/torch/transformers_models.py
- modelopt/torch/nas/plugins/__init__.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
- examples/megatron_bridge/prune_minitron.py
- tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py
…ch logging Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
38d82fb to
b2b3346
Compare
What does this PR do?
Type of change: New feature, new tests, documentation.
Extends the Minitron NAS pruner to support pruning by active parameter count (`active_params`) and memory footprint (`memory_mb`) in addition to the existing total parameter count (`params`) constraint. Also adds standalone utilities for analytical model stats.

Changes
New pruning constraint keys
- `active_params`: prune to a target number of active (routed) params — useful for MoE models where total ≫ active; when present, `active_params` is the primary sort/display metric for candidates (priority: `active_params` > `params` > `memory_mb`)
- `memory_mb`: prune to fit a memory budget (BF16 weights + KV-cache + Mamba state at a given sequence length and batch size)
- Constraints can be combined, e.g. `{"params": 6e9, "memory_mb": 12288}`

New standalone utilities (`modelopt.torch.nas.plugins.megatron_model_stats`)

- `mcore_param_count`: analytically computes total and active parameter counts for GPT and Mamba/hybrid MCore models
- `mcore_memory_footprint_mb`: estimates memory in MB (weights + KV-cache + Mamba state)
- `print_mcore_model_stats`: rich-formatted model stats panel

Rich-formatted pruning logs — search space, top-k candidate tables, and best subnet panel printed on rank 0

`prune_score_func` format update — now `mmlu_<N>pct_bs<bs>` (e.g. `mmlu_10pct_bs32`) to explicitly control batch size for MMLU evaluation; old `mmlu_<N>pct` format removed

Infrastructure

- `nvcr.io/nvidia/nemo:26.04` in CI and docs
- `examples/megatron_bridge/requirements.txt` with `transformers<5.0` (required for saving some Nemotron-3-Nano models)

Usage
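The original Usage snippet did not survive extraction, so below is a hedged sketch based only on the names and constraint format listed above; the exact signatures of `mcore_param_count` and `mcore_memory_footprint_mb` and their argument names are assumptions.

```python
from modelopt.torch.nas.plugins.megatron_model_stats import (
    mcore_memory_footprint_mb,
    mcore_param_count,
    print_mcore_model_stats,
)

# Combinable pruning ceilings introduced by this PR: total params, active
# (routed) params, and a memory budget in MB.
constraints = {"params": 28e9, "active_params": 3e9, "memory_mb": 12288}

model = ...  # an already-built Megatron-Core GPT or Mamba/hybrid model

# Standalone analytical stats (argument names are illustrative; check the
# module docstrings for the real API).
print_mcore_model_stats(model)
print(mcore_param_count(model))
print(mcore_memory_footprint_mb(model, seq_length=8192, batch_size=1))
```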
Testing
Pruned Nemotron-3-Nano-30B-A3B (31.6B, A3.6B) --> A3.0B. Takes <1hr on 8x H100 (more details in #1376)
```sh
torchrun --nproc_per_node 8 examples/megatron_bridge/prune_minitron.py \
    --pp_size 8 \
    --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \
    --trust_remote_code \
    --prune_target_params 28e9 \
    --prune_target_active_params 3e9 \
    --hparams_to_skip num_attention_heads \
    --seq_length 8192 \
    --output_hf_path pruned/Nemotron-3-Nano-30B-A3B-Pruned-28B-A3B-top20-max15depth-max30width-mmlu_10pct_bs32 \
    --top_k 20 \
    --max_depth_pruning 0.15 \
    --max_width_pruning 0.30 \
    --prune_score_func mmlu_10pct_bs32 \
    --num_layers_in_first_pipeline_stage 5 \
    --num_layers_in_last_pipeline_stage 5
```

Before your PR is "Ready for review"