From 7ad01c0bb9a58f0bb673101a1b46e2645ab260b5 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Thu, 30 Apr 2026 08:53:54 -0700 Subject: [PATCH 1/2] Enable active-param and memory based Minitron pruning constraint + rich logging Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- CHANGELOG.rst | 1 + examples/megatron_bridge/README.md | 38 +- examples/megatron_bridge/prune_minitron.py | 115 ++-- examples/pruning/README.md | 2 +- modelopt/torch/nas/plugins/__init__.py | 4 +- .../torch/nas/plugins/megatron_model_stats.py | 606 ++++++++++++++++++ .../torch/prune/plugins/mcore_minitron.py | 328 +++++++--- .../_test_utils/torch/transformers_models.py | 2 +- .../megatron_bridge/test_prune_minitron.py | 8 +- .../test_megatron_mamba_dynamic_modules.py | 124 +++- .../nas/plugins/test_megatron_model_stats.py | 445 +++++++++++++ .../test_mcore_mamba_minitron_pruning.py | 310 ++++++--- 12 files changed, 1757 insertions(+), 226 deletions(-) create mode 100644 modelopt/torch/nas/plugins/megatron_model_stats.py create mode 100644 tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d2369885431..efa830d1a10 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,7 @@ Changelog - Add offline DFlash speculative decoding training. Train the draft module from pre-computed base-model hidden states dumped by ``examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py``; base-model transformer layers are deleted after conversion to save memory. Controlled by the auto-derived ``dflash_offline`` flag on ``DFlashConfig`` (derived from ``data_args.offline_data_path``). The dump scripts now share ``collect_hidden_states/common.py`` for aux-layer selection (``--aux-layers eagle|dflash|``) and optional assistant-token ``loss_mask`` for answer-only-loss training. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. +- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. 0.44 (2026-05-xx) ^^^^^^^^^^^^^^^^^ diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index 1e384acfb19..ea7d2922810 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ -53,7 +53,18 @@ hf auth login --token This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). 
-Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults: +The script supports three NAS-based pruning targets and one manual export mode: + +| Mode | Flag | Description | +| :---: | :---: | :--- | +| NAS | `--prune_target_params` | Prune to a target total parameter count | +| NAS | `--prune_target_active_params` | Prune to a target active parameter count (useful for MoE models). For non-MoE models, this is equivalent to `--prune_target_params`. | +| NAS | `--prune_target_memory_mb` | Prune to a target memory footprint in MB (weights + KV-cache) for a given batch size and sequence length assuming BF16 precision | +| Manual | `--prune_export_config` | Prune directly to a specified architecture config (no NAS). Useful if you want to take top K candidates and do a short knowledge distillation before selecting the best model. | + +Multiple NAS targets can be combined — e.g. `--prune_target_params 6e9 --prune_target_memory_mb 12288` finds the best model with under 6B params and under 12GB memory footprint at (default) batch size 1 and sequence length 4096 assuming BF16 precision. + +**Prune by total parameter count** — prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of `num_attention_heads` using following defaults: 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration, at-most 20% depth (`num_layers`) and 40% width is pruned per prunable hparam (`hidden_size`, `ffn_hidden_size`, ...), top-10 candidates are evaluated for MMLU score (5% sampled data) to select the best model. @@ -67,8 +78,29 @@ torchrun --nproc_per_node 2 prune_minitron.py \ --output_hf_path /tmp/Qwen3-8B-Pruned-6B ``` -Example usage for manually pruning to a specific architecture using following defaults: - 1024 samples from [`nemotron-post-training-dataset-v2`](https://huggingface.co/datasets/nvidia/Nemotron-Post-Training-Dataset-v2) for calibration. +**Prune by active parameter count** — useful for MoE models where most experts are inactive per token (e.g. prune Nemotron-3-Nano-30B-A3B-BF16 (3.6B active params) to 3B active params): + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 \ + --prune_target_active_params 3e9 \ + --output_hf_path /tmp/Nemotron-3-Nano-30B-A3B-BF16-Pruned-3B-Active +``` + +**Prune by memory footprint** — prune to fit a target GPU memory budget (weights + KV-cache at the given sequence length and batch size, assuming BF16): + +```bash +torchrun --nproc_per_node 2 prune_minitron.py \ + --pp_size 2 \ + --hf_model_name_or_path Qwen/Qwen3-8B \ + --prune_target_memory_mb 12288 \ + --seq_length 4096 \ + --calib_mbs 1 \ + --output_hf_path /tmp/Qwen3-8B-Pruned-12GB +``` + +**Manual pruning** — prune directly to a specified architecture (no NAS, no score evaluation): ```bash torchrun --nproc_per_node 2 prune_minitron.py \ diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 0fa9a658ff2..1eff609fee6 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -14,6 +14,11 @@ # limitations under the License. """Example script for pruning a GPT / Mamba model using Minitron algorithm on a Megatron-Bridge model (load from HF). 
+Supports three NAS-based pruning targets (can be combined): + --prune_target_params Total parameter count (e.g. 6e9 for 6B total params) + --prune_target_active_params Active parameter count for MoE models (e.g. 3e9 for 3B active params) + --prune_target_memory_mb Memory footprint in MB (uses --seq_length for KV-cache estimate, assumes BF16) + Example usage to prune Qwen3-8B to 6B on 2-GPUs (Pipeline Parallelism = 2) while skipping pruning of num_attention_heads using following defaults: 1024 samples from nemotron-post-training-dataset-v2 for calibration, @@ -47,7 +52,7 @@ import modelopt.torch.opt as mto import modelopt.torch.prune as mtp import modelopt.torch.utils.distributed as dist -from modelopt.torch.utils import get_supported_datasets, num2hrb, print_rank_0, warn_rank_0 +from modelopt.torch.utils import get_supported_datasets, print_rank_0, warn_rank_0 from modelopt.torch.utils.plugins.mbridge import ( get_hf_mbridge_calibration_loop, load_mbridge_model_from_hf, @@ -105,7 +110,6 @@ def get_args() -> argparse.Namespace: ) parser.add_argument("--calib_gbs", type=int, default=1, help="Calibration global batch size") parser.add_argument("--seq_length", type=int, default=4096) - # Pruning parameters parser.add_argument( "--prune_intermediate_ckpt", @@ -117,23 +121,40 @@ def get_args() -> argparse.Namespace: ), ) - target_group = parser.add_mutually_exclusive_group(required=True) - target_group.add_argument( + parser.add_argument( "--prune_export_config", type=str, help=( 'Target pruned config as JSON e.g., \'{"hidden_size": 512, "ffn_hidden_size": 2048}\'. ' f"Supported hyperparameters: {mtp.mcore_minitron.SUPPORTED_HPARAMS}. " - "Cannot be used with --prune_target_params." + "Cannot be combined with NAS-based targets." ), ) - target_group.add_argument( + parser.add_argument( "--prune_target_params", type=float, help=( - "Target parameter count for pruning e.g., 6e9 for pruning to 6B params (total params, not active params). " - "Uses Neural Architecture Search (NAS) to find the best pruned model that maximizes the --prune_score_func." - "Cannot be used with --prune_export_config." + "Target total parameter count e.g., 6e9 for 6B params. " + "Uses NAS to find the best pruned model that maximizes --prune_score_func. " + "Can be combined with --prune_target_active_params and/or --prune_target_memory_mb." + ), + ) + parser.add_argument( + "--prune_target_active_params", + type=float, + help=( + "Target active parameter count e.g., 3e9 for 3B active params (useful for MoE models). " + "Uses NAS to find the best pruned model that maximizes --prune_score_func. " + "Can be combined with --prune_target_params and/or --prune_target_memory_mb." + ), + ) + parser.add_argument( + "--prune_target_memory_mb", + type=float, + help=( + "Target memory footprint in MB (weights + KV-cache estimated via seq_length and calib_mbs; assumes BF16). " + "Uses NAS to find the best pruned model that maximizes --prune_score_func. " + "Can be combined with --prune_target_params and/or --prune_target_active_params." ), ) @@ -142,7 +163,7 @@ def get_args() -> argparse.Namespace: type=str, default="mmlu_10pct", help=( - "Score function to use for NAS-based pruning (--prune_target_params). Only supports MMLU at the moment. " + "Score function to use for NAS-based pruning. Only supports MMLU at the moment. " "Format: mmlu_pct where is the percentage of MMLU data to sample per subject " "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)." 
), @@ -152,7 +173,7 @@ def get_args() -> argparse.Namespace: type=int, default=None, help=( - "hidden_size / ffn_hidden_size divisor for NAS-based pruning (--prune_target_params). " + "hidden_size / ffn_hidden_size divisor for NAS-based pruning. " "Leave as None to use default divisors." ), ) @@ -162,14 +183,14 @@ def get_args() -> argparse.Namespace: default=0.4, help=( f"Maximum width pruning percentage ({mtp.mcore_minitron.SUPPORTED_HPARAMS - {'num_layers'}}) " - "for NAS-based pruning (--prune_target_params)" + "for NAS-based pruning" ), ) parser.add_argument( "--max_depth_pruning", type=float, default=0.2, - help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning (--prune_target_params)", + help="Maximum depth pruning percentage ('num_layers') for NAS-based pruning", ) parser.add_argument( "--hparams_to_skip", @@ -178,7 +199,7 @@ def get_args() -> argparse.Namespace: default=[], choices=mtp.mcore_minitron.SUPPORTED_HPARAMS, help=( - "Space-separated list of hparams to skip for NAS-based pruning (--prune_target_params) " + "Space-separated list of hparams to skip for NAS-based pruning " "e.g. dont prune 'num_attention_heads'" ), ) @@ -187,13 +208,27 @@ def get_args() -> argparse.Namespace: type=int, default=10, help=( - "Number of top candidates to consider for NAS-based pruning (--prune_target_params). " + "Number of top candidates to consider for NAS-based pruning. " "Higher values will take longer to prune but may find a better model." ), ) args = parser.parse_args() + # Validate pruning target arguments + _nas_targets = [ + args.prune_target_params, + args.prune_target_active_params, + args.prune_target_memory_mb, + ] + if args.prune_export_config and any(t is not None for t in _nas_targets): + parser.error("--prune_export_config cannot be combined with NAS-based targets.") + if not args.prune_export_config and not any(t is not None for t in _nas_targets): + parser.error( + "At least one of --prune_export_config, --prune_target_params," + " --prune_target_active_params, or --prune_target_memory_mb is required." + ) + # Post-process arguments if args.prune_intermediate_ckpt is None: if args.output_megatron_path: @@ -250,11 +285,6 @@ def main(args: argparse.Namespace): init_model_parallel=True, moe_grouped_gemm=False, ) - print_rank_0(f"\nPruning model (showing PP rank0): {unwrapped_model}") - print_rank_0( - f"Original model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" - ) - forward_loop = get_hf_mbridge_calibration_loop( model=model, provider=provider, @@ -271,10 +301,20 @@ def main(args: argparse.Namespace): "forward_loop": forward_loop, "checkpoint": args.prune_intermediate_ckpt, } - if args.prune_target_params is not None: - # Restrict search space to a smaller set of candidates - # Allow more choices for MoE FFN as they are generally smaller - # NOTE: You can reduce the divisors and increase config['top_k'] to potentially find a better model. + if args.prune_export_config is not None: + # Less restrictive search space for manual pruning + ss_config = mtp.mcore_minitron.get_mcore_minitron_config( + hidden_size_divisor=64, + ffn_hidden_size_divisor=64, + mamba_head_dim_divisor=8, + num_moe_experts_divisor=8, + num_layers_divisor=1, + ) + pruning_constraints = {"export_config": args.prune_export_config} + else: + # NAS-based pruning: restrict search space to a smaller set of candidates. + # Allow more choices for MoE FFN as they are generally smaller. 
+ # NOTE: Reduce divisors and increase config['top_k'] to potentially find a better model. hidden_size_divisor = args.ss_channel_divisor if args.ss_channel_divisor else 256 ffn_hidden_size_divisor = ( args.ss_channel_divisor @@ -290,7 +330,14 @@ def main(args: argparse.Namespace): ) print_rank_0(f"Using search space config: {ss_config}") - pruning_constraints = {"params": args.prune_target_params} + pruning_constraints = {} + if args.prune_target_params is not None: + pruning_constraints["params"] = args.prune_target_params + if args.prune_target_active_params is not None: + pruning_constraints["active_params"] = args.prune_target_active_params + if args.prune_target_memory_mb is not None: + pruning_constraints["memory_mb"] = args.prune_target_memory_mb + print_rank_0( f"Using NAS-based automatic pruning with score function: {args.prune_score_func}. " "You can change this to be any other metric you want to maximize (e.g. negative validation loss)." @@ -313,17 +360,9 @@ def score_func(m): pruning_config["max_depth_pruning"] = args.max_depth_pruning pruning_config["hparams_to_skip"] = args.hparams_to_skip pruning_config["top_k"] = args.top_k - elif args.prune_export_config is not None: - # Less restrictive search space for manual pruning - ss_config = mtp.mcore_minitron.get_mcore_minitron_config( - hidden_size_divisor=64, - ffn_hidden_size_divisor=64, - mamba_head_dim_divisor=8, - num_moe_experts_divisor=8, - num_layers_divisor=1, - ) - - pruning_constraints = {"export_config": args.prune_export_config} + # memory_mb constraint requires batch_size and seq_length + pruning_config["batch_size"] = args.calib_mbs + pruning_config["seq_length"] = args.seq_length print_rank_0(f"Pruning constraints: {pruning_constraints}") unwrapped_model, pruning_scores = mtp.prune( # in-place pruning @@ -343,10 +382,6 @@ def score_func(m): else "hybrid_layer_pattern" ) setattr(provider, hybrid_key, getattr(unwrapped_model, hybrid_key)) - print_rank_0(f"\nPruned model (showing PP rank0): {unwrapped_model}") - print_rank_0( - f"Pruned model params: {num2hrb(mtp.mcore_minitron.get_mcore_param_count(unwrapped_model))}" - ) if args.output_megatron_path is not None: print_rank_0( diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 294f00031dd..895d3b8f182 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -179,7 +179,7 @@ If your model parameters are already sorted and you just want to prune the weigh | **Algorithm** | **Model** | **Pruning Constraints** | | :---: | :---: | :---: | -| Minitron | Megatron-core (M-LM, M-Bridge) based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** `params` (requires `score_func` in config) | +| Minitron | Megatron-core (M-LM, M-Bridge) based GPT / Mamba / MoE / Hybrid LLM Models1 | **Manual:** `export_config` with width (`hidden_size`, `ffn_hidden_size`, `num_attention_heads`, `mamba_num_heads`, `mamba_head_dim`, `num_moe_experts`, `moe_ffn_hidden_size`, `moe_shared_expert_intermediate_size`) and/or depth (`num_layers`) pruned values
**Auto:** one or more of `params`, `active_params`, `memory_mb` (requires `score_func` in config) | | FastNAS | Computer Vision models | `flops`, `params` | | GradNAS | HuggingFace BERT, GPT-J | `flops`, `params` | diff --git a/modelopt/torch/nas/plugins/__init__.py b/modelopt/torch/nas/plugins/__init__.py index 5b439a86b48..50666daab23 100644 --- a/modelopt/torch/nas/plugins/__init__.py +++ b/modelopt/torch/nas/plugins/__init__.py @@ -21,9 +21,7 @@ with import_plugin("megatron"): from .megatron import * - -with import_plugin("transformer engine"): - from .transformer_engine import * + from .megatron_model_stats import * with import_plugin("transformers"): from .transformers import * diff --git a/modelopt/torch/nas/plugins/megatron_model_stats.py b/modelopt/torch/nas/plugins/megatron_model_stats.py new file mode 100644 index 00000000000..4ca72241718 --- /dev/null +++ b/modelopt/torch/nas/plugins/megatron_model_stats.py @@ -0,0 +1,606 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Analytical parameter count and memory footprint utilities for MCore GPT and Mamba/Hybrid models. + +These are fast, model-free alternatives to forward-pass-based counters. + +Layer conventions (validated against ``TELayerNormColumnParallelLinear`` and ``TEColumnParallelLinear`` MCore specs): + +- Dense attention / MLP layers: ``input_layernorm`` / ``pre_mlp_layernorm`` are fused into + ``linear_qkv`` / ``linear_fc1`` via ``TELayerNormColumnParallelLinear`` (their weight—and bias + for LayerNorm—count as part of that linear module's parameters). +- MoE layers: ``pre_mlp_layernorm`` is a *separate* ``TENorm`` (not fused); routed expert + ``linear_fc1`` uses plain ``TEColumnParallelLinear`` (no fused LN). Shared experts never have + bias (``assert add_bias_linear == False`` in ``SharedExpertMLP``). +- Mamba layers: ``in_proj`` uses ``TELayerNormColumnParallelLinear`` (fused LN). The internal + ``norm`` on ``d_inner`` is always RMSNorm regardless of the global ``normalization`` setting. +- GDN (``G``) layers are not currently supported and raise an error. + +Hybrid pattern characters (from ``megatron.core.ssm.mamba_hybrid_layer_allocation.Symbols``): + ``M`` = Mamba, ``*`` = Attention-only TransformerLayer, ``-`` = MLP-only TransformerLayer, + ``E`` = MoE-only TransformerLayer, ``G`` = GDN (unsupported), ``|`` = PP boundary (ignored), + ``/`` = MTP separator (everything from ``/`` onward is MTP and ignored). 
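+
+Example (illustrative sketch; ``cfg`` stands for any MCore ``TransformerConfig``-like object and the
+vocabulary size below is hypothetical, not taken from a real checkpoint)::
+
+    from modelopt.torch.nas.plugins.megatron_model_stats import (
+        mcore_memory_footprint_mb,
+        mcore_param_count,
+    )
+
+    # Total vs. active parameter counts (identical for non-MoE configs).
+    total, active = mcore_param_count(cfg, vocab_size=131072)
+
+    # Same count for a hypothetical narrower variant via a per-call override.
+    total_small, _ = mcore_param_count(cfg, vocab_size=131072, hidden_size=3072)
+
+    # Weights + KV-cache (+ Mamba state) footprint in MB, assuming BF16 (2 bytes per element).
+    params_mb, kv_mb, mamba_mb, total_mb = mcore_memory_footprint_mb(
+        cfg, vocab_size=131072, sequence_length=4096, batch_size=1
+    )
+
+    # For an instantiated GPTModel / MambaModel, ``print_mcore_model_stats(model, "Original Model")``
+    # prints a formatted rank-0 summary of the same numbers.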
+""" + +import io +import sys +from typing import TYPE_CHECKING, Any + +from megatron.core.models.mamba.mamba_model import MambaModel +from rich.console import Console +from rich.panel import Panel +from rich.table import Table + +from modelopt.torch.utils import num2hrb, print_rank_0 + +if TYPE_CHECKING: + from megatron.core.models.gpt.gpt_model import GPTModel + + +__all__ = [ + "mcore_memory_footprint_mb", + "mcore_param_count", + "parse_main_layer_chars", + "print_mcore_model_stats", +] + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + +_HYBRID_MAMBA = "M" +_HYBRID_ATTN = "*" +_HYBRID_MLP = "-" +_HYBRID_MOE = "E" +_HYBRID_GDN = "G" + + +def _norm_params(size: int, normalization: str) -> int: + """Number of parameters for a norm layer (weight + optional bias).""" + return size * (2 if normalization == "LayerNorm" else 1) + + +def _attn_layer_params( + hidden_size: int, + num_attention_heads: int, + num_query_groups: int, + kv_channels: int, + add_bias_linear: bool, + normalization: str, + qk_layernorm: bool, + attention_output_gate: bool = False, +) -> int: + """Params for a single attention sublayer. + + Includes ``linear_qkv`` (with fused ``input_layernorm``), ``linear_proj``, and optional QK layernorms. + """ + # linear_qkv: hidden_size -> (Q + 2*KV) * kv_channels, with fused input_layernorm + qkv_out = (num_attention_heads + 2 * num_query_groups) * kv_channels + if attention_output_gate: + qkv_out += kv_channels * num_attention_heads + params = hidden_size * qkv_out + if add_bias_linear: + params += qkv_out + params += _norm_params(hidden_size, normalization) # fused input_layernorm + + # linear_proj: (num_attention_heads * kv_channels) -> hidden_size + params += num_attention_heads * kv_channels * hidden_size + if add_bias_linear: + params += hidden_size + + # optional per-head QK layernorm (q_layernorm + k_layernorm) + if qk_layernorm: + params += 2 * _norm_params(kv_channels, normalization) + + return params + + +def _dense_mlp_params( + hidden_size: int, + ffn_hidden_size: int, + gated_linear_unit: bool, + add_bias_linear: bool, + normalization: str, +) -> int: + """Params for a dense MLP sublayer. + + ``pre_mlp_layernorm`` is fused into ``linear_fc1`` (TELayerNormColumnParallelLinear). + """ + fc1_out = ffn_hidden_size * (2 if gated_linear_unit else 1) + params = hidden_size * fc1_out + if add_bias_linear: + params += fc1_out + params += _norm_params(hidden_size, normalization) # fused pre_mlp_layernorm + + params += ffn_hidden_size * hidden_size + if add_bias_linear: + params += hidden_size + + return params + + +def _moe_layer_params( + hidden_size: int, + num_moe_experts: int, + moe_router_topk: int, + moe_ffn_hidden_size: int, + gated_linear_unit: bool, + add_bias_linear: bool, + normalization: str, + moe_shared_expert_intermediate_size: int | None, + moe_shared_expert_gate: bool = False, +) -> tuple[int, int]: + """Params for a MoE sublayer, returned as (total, active). + + ``pre_mlp_layernorm`` is a *separate* TENorm (not fused into expert fc1). + Routed expert fc1/fc2 use ``TEColumnParallelLinear`` (no fused LN). + Shared experts never carry bias regardless of ``add_bias_linear``. + + ``total`` counts all ``num_moe_experts`` routed experts; ``active`` counts only + ``moe_router_topk`` (the experts actually used in each forward pass). The router, + pre-layernorm, and shared expert are always fully active and count equally in both. 
+ """ + # Always-active: pre_mlp_layernorm + router + shared expert + always = _norm_params(hidden_size, normalization) + always += num_moe_experts * hidden_size # router weight + if add_bias_linear: + always += num_moe_experts # router bias + + # Shared expert (SharedExpertMLP always has add_bias_linear=False) + if moe_shared_expert_intermediate_size: + s_fc1_out = moe_shared_expert_intermediate_size * (2 if gated_linear_unit else 1) + always += hidden_size * s_fc1_out + moe_shared_expert_intermediate_size * hidden_size + if moe_shared_expert_gate: + always += hidden_size # gate_weight: 1 x hidden_size + + # Per routed-expert params + fc1_out = moe_ffn_hidden_size * (2 if gated_linear_unit else 1) + per_expert = hidden_size * fc1_out + moe_ffn_hidden_size * hidden_size + if add_bias_linear: + per_expert += fc1_out + moe_ffn_hidden_size + + total = always + num_moe_experts * per_expert + active = always + moe_router_topk * per_expert + return total, active + + +def _mamba_layer_params( + hidden_size: int, + mamba_num_heads: int, + mamba_head_dim: int, + mamba_num_groups: int, + mamba_state_dim: int, + normalization: str, + d_conv: int = 4, +) -> int: + """Params for a single Mamba layer. + + ``in_proj`` uses TELayerNormColumnParallelLinear (fused input LN). + The internal ``norm`` on ``d_inner`` is always RMSNorm (1 weight, no bias). + """ + d_inner = mamba_num_heads * mamba_head_dim + + # in_proj: hidden_size -> (2*d_inner + 2*ngroups*d_state + nheads), no bias, fused input LN + in_proj_out = 2 * d_inner + 2 * mamba_num_groups * mamba_state_dim + mamba_num_heads + params = hidden_size * in_proj_out + params += _norm_params(hidden_size, normalization) # fused input_layernorm + + # out_proj: d_inner -> hidden_size, no bias + params += d_inner * hidden_size + + # conv1d (depthwise) on (d_inner + 2*ngroups*d_state) channels: weight + bias + conv_dim = d_inner + 2 * mamba_num_groups * mamba_state_dim + params += conv_dim * d_conv + conv_dim + + # Scalar per-head params: A_log + dt_bias + D + params += 3 * mamba_num_heads + + # Internal RMSNorm on d_inner (always RMSNorm, 1 weight only) + params += d_inner + + return params + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def parse_main_layer_chars(hybrid_layer_pattern: str, num_layers: int | None = None) -> list[str]: + """Extract per-layer characters from the main (non-MTP) part of a hybrid pattern. + + Strips the MTP suffix (``/...``) and PP boundaries (``|``), returning one char per layer. + When ``num_layers`` is provided the result length must equal it exactly. + """ + main = hybrid_layer_pattern.split("/")[0] + chars = [c for c in main if c != "|"] + if num_layers is not None and len(chars) != num_layers: + raise ValueError( + f"Hybrid pattern '{hybrid_layer_pattern}' has {len(chars)} layers " + f"but num_layers={num_layers}." + ) + return chars + + +def mcore_param_count( + config: Any, + vocab_size: int, + share_embeddings_and_output_weights: bool = False, + hybrid_layer_pattern: str | None = None, + **overrides: Any, +) -> tuple[int, int]: + """Compute total and active parameter counts for an MCore GPT or Mamba/Hybrid model. + + For non-MoE models ``total == active``. For MoE models, ``active`` counts only the + ``moe_router_topk`` routed experts actually used per forward pass (router weights, + pre-layernorm, and shared experts are always active and count in both figures). 
+ + Args: + config: MCore ``TransformerConfig`` (or any object exposing the same attributes). + vocab_size: Vocabulary size. + share_embeddings_and_output_weights: Whether the word-embedding and LM-head weights + are tied (output_layer excluded from the count when True). + hybrid_layer_pattern: Hybrid layer pattern string for Mamba/Hybrid models (e.g. + ``"M*M*E-"``) or ``None`` for a pure GPT model. Characters are the MCore + ``Symbols`` values; ``|`` PP boundaries and ``/...`` MTP suffixes are stripped. + **overrides: Per-call overrides for any ``config`` attribute. Useful for computing + counts for hypothetical configs without modifying the model. + + Returns: ``(total_params, active_params)`` + """ + + def _get(attr: str, default: Any = None) -> Any: + return overrides.get(attr, getattr(config, attr, default)) + + hidden_size: int = _get("hidden_size") + num_layers: int = _get("num_layers") + num_attention_heads: int = _get("num_attention_heads") + num_query_groups: int | None = _get("num_query_groups") + kv_channels: int | None = _get("kv_channels") + ffn_hidden_size: int | None = _get("ffn_hidden_size") + num_moe_experts: int | None = _get("num_moe_experts") + moe_router_topk: int = _get("moe_router_topk", 2) + moe_ffn_hidden_size: int | None = _get("moe_ffn_hidden_size") + moe_shared_expert_intermediate_size: int | None = _get("moe_shared_expert_intermediate_size") + moe_shared_expert_gate: bool = _get("moe_shared_expert_gate", False) + mamba_num_heads: int | None = _get("mamba_num_heads") + mamba_head_dim: int | None = _get("mamba_head_dim") + mamba_num_groups: int | None = _get("mamba_num_groups") + mamba_state_dim: int | None = _get("mamba_state_dim") + gated_linear_unit: bool = _get("gated_linear_unit", False) + add_bias_linear: bool = _get("add_bias_linear", False) + normalization: str = _get("normalization", "RMSNorm") + qk_layernorm: bool = _get("qk_layernorm", False) + attention_output_gate: bool = _get("attention_output_gate", False) + moe_layer_freq: int | list[int] = _get("moe_layer_freq", 1) + + # Fill in derived defaults + if num_query_groups is None: + num_query_groups = num_attention_heads + if kv_channels is None and num_attention_heads: + kv_channels = hidden_size // num_attention_heads + if moe_ffn_hidden_size is None and num_moe_experts is not None: + moe_ffn_hidden_size = ffn_hidden_size + + # Embedding + final norm + output layer (always active) + base = vocab_size * hidden_size + base += _norm_params(hidden_size, normalization) # final layernorm + if not share_embeddings_and_output_weights: + base += hidden_size * vocab_size + + total = base + active = base + + if hybrid_layer_pattern is None: + # ---- Pure GPT: all layers have attention + dense-MLP or attention + MoE ---- + assert kv_channels is not None, "kv_channels must be set for GPT attention layers" + if isinstance(moe_layer_freq, list): + moe_pattern = list(moe_layer_freq[:num_layers]) + else: + moe_pattern = [1 if (i % moe_layer_freq == 0) else 0 for i in range(num_layers)] + + for i in range(num_layers): + layer_t = layer_a = _attn_layer_params( + hidden_size, + num_attention_heads, + num_query_groups, + kv_channels, + add_bias_linear, + normalization, + qk_layernorm, + attention_output_gate, + ) + if moe_pattern[i] and num_moe_experts: + assert moe_ffn_hidden_size is not None, ( + "moe_ffn_hidden_size must be set for MoE layers" + ) + mt, ma = _moe_layer_params( + hidden_size, + num_moe_experts, + moe_router_topk, + moe_ffn_hidden_size, + gated_linear_unit, + add_bias_linear, + normalization, + 
moe_shared_expert_intermediate_size, + moe_shared_expert_gate, + ) + layer_t += mt + layer_a += ma + else: + assert ffn_hidden_size is not None, ( + "ffn_hidden_size must be set for dense MLP layers" + ) + mlp = _dense_mlp_params( + hidden_size, + ffn_hidden_size, + gated_linear_unit, + add_bias_linear, + normalization, + ) + layer_t += mlp + layer_a += mlp + total += layer_t + active += layer_a + else: + # ---- Hybrid / MambaModel: layer type is encoded in the pattern ---- + layer_chars = parse_main_layer_chars(hybrid_layer_pattern, num_layers) + + for char in layer_chars: + if char == _HYBRID_MAMBA: + assert mamba_num_heads is not None, "mamba_num_heads must be set for Mamba layers" + assert mamba_head_dim is not None, "mamba_head_dim must be set for Mamba layers" + assert mamba_num_groups is not None, "mamba_num_groups must be set for Mamba layers" + assert mamba_state_dim is not None, "mamba_state_dim must be set for Mamba layers" + t = a = _mamba_layer_params( + hidden_size, + mamba_num_heads, + mamba_head_dim, + mamba_num_groups, + mamba_state_dim, + normalization, + ) + elif char == _HYBRID_ATTN: + assert kv_channels is not None, "kv_channels must be set for attention layers" + t = a = _attn_layer_params( + hidden_size, + num_attention_heads, + num_query_groups, + kv_channels, + add_bias_linear, + normalization, + qk_layernorm, + attention_output_gate, + ) + elif char == _HYBRID_MLP: + assert ffn_hidden_size is not None, "ffn_hidden_size must be set for MLP layers" + t = a = _dense_mlp_params( + hidden_size, + ffn_hidden_size, + gated_linear_unit, + add_bias_linear, + normalization, + ) + elif char == _HYBRID_MOE: + assert num_moe_experts is not None, "num_moe_experts must be set for MoE layers" + assert moe_ffn_hidden_size is not None, ( + "moe_ffn_hidden_size must be set for MoE layers" + ) + t, a = _moe_layer_params( + hidden_size, + num_moe_experts, + moe_router_topk, + moe_ffn_hidden_size, + gated_linear_unit, + add_bias_linear, + normalization, + moe_shared_expert_intermediate_size, + moe_shared_expert_gate, + ) + else: + raise ValueError(f"Unsupported hybrid layer character: {char}") + total += t + active += a + + return total, active + + +def mcore_memory_footprint_mb( + config: Any, + vocab_size: int, + share_embeddings_and_output_weights: bool = False, + hybrid_layer_pattern: str | None = None, + dtype_bytes: int = 2, + kv_cache_dtype_bytes: int | None = None, + sequence_length: int = 4096, + batch_size: int = 1, + **overrides: Any, +) -> tuple[float, float, float, float]: + """Compute inference memory footprint in MB for an MCore model. + + Covers three components: + + * **params**: model weights at ``dtype_bytes`` precision. + * **kv_cache**: KV cache for all attention layers (2 tensors per layer at ``kv_cache_dtype_bytes`` precision). + * **mamba_state**: recurrent SSM sliding-window state stored for all Mamba layers during + generation (one buffer of size ``(d_inner + 2*ngroups*d_state) * d_conv`` per layer). + + Args: + config: MCore ``TransformerConfig`` (or any object exposing the same attributes). + vocab_size: Vocabulary size. + share_embeddings_and_output_weights: Tied embedding/LM-head flag. + hybrid_layer_pattern: Hybrid layer pattern (``None`` for pure GPT). + dtype_bytes: Bytes per parameter (2 for fp16/bf16, 4 for fp32). + kv_cache_dtype_bytes: Bytes per KV-cache element; defaults to ``dtype_bytes``. + sequence_length: Context length for KV-cache sizing. + batch_size: Batch size for KV-cache and Mamba-state sizing. 
+ **overrides: Config attribute overrides (same as :func:`mcore_param_count`). + + Returns: ``(params_mb, kv_cache_mb, mamba_state_mb, total_mb)`` + """ + if kv_cache_dtype_bytes is None: + kv_cache_dtype_bytes = dtype_bytes + + def _get(attr: str, default: Any = None) -> Any: + return overrides.get(attr, getattr(config, attr, default)) + + hidden_size: int = _get("hidden_size") + num_layers: int = _get("num_layers") + num_attention_heads: int = _get("num_attention_heads") + num_query_groups: int | None = _get("num_query_groups") + kv_channels: int | None = _get("kv_channels") + mamba_num_heads: int | None = _get("mamba_num_heads") + mamba_head_dim: int | None = _get("mamba_head_dim") + mamba_num_groups: int | None = _get("mamba_num_groups") + mamba_state_dim: int | None = _get("mamba_state_dim") + + if num_query_groups is None: + num_query_groups = num_attention_heads + if kv_channels is None and num_attention_heads: + kv_channels = hidden_size // num_attention_heads + + # Parameter memory + total_params, _ = mcore_param_count( + config, + vocab_size, + share_embeddings_and_output_weights, + hybrid_layer_pattern, + **overrides, + ) + params_bytes = total_params * dtype_bytes + + # Count attention and Mamba layers from pattern + if hybrid_layer_pattern is None: + n_attn = num_layers + n_mamba = 0 + else: + chars = parse_main_layer_chars(hybrid_layer_pattern, num_layers) + n_attn = chars.count(_HYBRID_ATTN) + n_mamba = chars.count(_HYBRID_MAMBA) + + # KV cache: 2 tensors (K, V) per attention layer + # each tensor: [batch_size, sequence_length, num_query_groups, kv_channels] + kv_bytes = 0 + if n_attn > 0: + if num_query_groups is None or kv_channels is None: + raise ValueError( + "num_query_groups and kv_channels must be set when attention layers exist." + ) + kv_per_layer = 2 * batch_size * sequence_length * num_query_groups * kv_channels + kv_bytes = n_attn * kv_per_layer * kv_cache_dtype_bytes + + # Mamba recurrent state per layer (both caches needed for autoregressive generation): + # conv1d sliding window: [batch, d_inner + 2*ngroups*d_state, d_conv - 1] + # SSM recurrent state: [batch, nheads, d_head, d_state] + mamba_bytes = 0 + if n_mamba > 0: + if None in (mamba_num_heads, mamba_head_dim, mamba_num_groups, mamba_state_dim): + raise ValueError( + "mamba_num_heads, mamba_head_dim, mamba_num_groups, and mamba_state_dim " + "must be set when Mamba layers exist." + ) + d_inner = mamba_num_heads * mamba_head_dim + d_conv = 4 # hardcoded in MambaMixer + conv_dim = d_inner + 2 * mamba_num_groups * mamba_state_dim + conv_state = batch_size * conv_dim * (d_conv - 1) + ssm_state = batch_size * mamba_num_heads * mamba_head_dim * mamba_state_dim + mamba_bytes = n_mamba * (conv_state + ssm_state) * dtype_bytes + + _mb = 1024**2 + params_mb = params_bytes / _mb + kv_cache_mb = kv_bytes / _mb + mamba_state_mb = mamba_bytes / _mb + total_mb = params_mb + kv_cache_mb + mamba_state_mb + return params_mb, kv_cache_mb, mamba_state_mb, total_mb + + +def print_mcore_model_stats( + model: "GPTModel | MambaModel", + label: str = "Model", + seq_length: int = 4096, + batch_size: int = 1, + dtype_bytes: int = 2, +) -> None: + """Print total params, active params, and memory footprint for an MCore model. + + Args: + model: GPTModel or MambaModel to print stats for. + label: Label prefix for the output line (e.g. ``"Original"``, ``"Pruned"``). + seq_length: Sequence length for KV-cache / Mamba-state memory estimate. + batch_size: Batch size for KV-cache / Mamba-state memory estimate. 
+ dtype_bytes: Bytes per parameter for memory estimation (default: 2 for BF16). + """ + hybrid_layer_pattern: str | None = None + config_overrides: dict = {} + if isinstance(model, MambaModel): + hybrid_key = ( + "hybrid_override_pattern" + if hasattr(model, "hybrid_override_pattern") + else "hybrid_layer_pattern" + ) + hybrid_layer_pattern = getattr(model, hybrid_key) + # mamba_num_heads may not be stored in config when derived from model architecture; + # fall back to reading it from the actual layer. + if getattr(model.config, "mamba_num_heads", None) is None: + for layer in model.decoder.layers: + if hasattr(layer, "mixer") and hasattr(layer.mixer, "nheads"): + config_overrides["mamba_num_heads"] = layer.mixer.nheads + break + + total, active = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + **config_overrides, + ) + params_mb, kv_cache_mb, mamba_state_mb, total_mb = mcore_memory_footprint_mb( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + dtype_bytes=dtype_bytes, + sequence_length=seq_length, + batch_size=batch_size, + **config_overrides, + ) + dtype_str = {1: "FP8", 2: "BF16", 4: "FP32"}.get(dtype_bytes, f"{dtype_bytes}B") + + grid = Table.grid(padding=(0, 2)) + grid.add_column(style="bold cyan", no_wrap=True) + grid.add_column() + grid.add_row("Total Parameters", num2hrb(total)) + if active != total: + grid.add_row("Active Parameters", num2hrb(active)) + + mem_items = [f"weights: {params_mb:.1f} MB", f"kv_cache: {kv_cache_mb:.1f} MB"] + if mamba_state_mb > 0: + mem_items.append(f"mamba_state: {mamba_state_mb:.1f} MB") + mem_items.append(f"[bold]Total: {total_mb:.1f} MB[/bold]") + grid.add_row(f"Memory ({dtype_str}, {seq_length=}, {batch_size=})", ", ".join(mem_items)) + + buf = io.StringIO() + Console(file=buf, highlight=False, force_terminal=sys.stdout.isatty()).print( + Panel( + grid, + title=f"[bold cyan]{label} Stats[/bold cyan]", + border_style="cyan", + padding=(0, 1), + expand=False, + ) + ) + print_rank_0() + print_rank_0(buf.getvalue()) diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index e99a44a7910..0aee519f137 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -24,6 +24,8 @@ Actual dynamic module implementations are at :mod:`modelopt.torch.nas.plugins.megatron`. 
""" +import io +import sys from collections.abc import Callable from dataclasses import dataclass from functools import partial @@ -38,16 +40,23 @@ from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.models.mamba.mamba_model import MambaModel from megatron.core.parallel_state import ( + get_expert_tensor_and_model_parallel_group, + get_expert_tensor_parallel_rank, get_pipeline_model_parallel_group, get_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size, get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, ) from megatron.core.tensor_parallel import ( gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, ) from pydantic import create_model +from rich.console import Console +from rich.markup import escape as rich_escape +from rich.panel import Panel +from rich.table import Table from tqdm import tqdm from modelopt.torch.nas.conversion import NASModeRegistry @@ -63,6 +72,12 @@ _DynamicSequentialMLP, _DynamicTransformerLayer, ) +from modelopt.torch.nas.plugins.megatron_model_stats import ( + mcore_memory_footprint_mb, + mcore_param_count, + parse_main_layer_chars, + print_mcore_model_stats, +) from modelopt.torch.nas.registry import DMRegistry from modelopt.torch.nas.utils import get_subnet_config, sample, sort_parameters from modelopt.torch.opt.config import ModeloptBaseConfig, get_kwargs_for_create_model_with_rules @@ -106,7 +121,6 @@ "MCoreMinitronSearcher", "drop_mcore_language_model_layers", "get_mcore_minitron_config", - "get_mcore_param_count", ] @@ -164,10 +178,26 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i model.config.num_layers = new_num_layers +def _rprint(*renderables: Any) -> None: + """Render rich renderables and print on rank 0 only.""" + buf = io.StringIO() + Console(file=buf, highlight=False, force_terminal=sys.stdout.isatty(), width=160).print( + *renderables + ) + print_rank_0() + print_rank_0(buf.getvalue()) + + +# Constraint keys that trigger the grid-search path in MCoreMinitronSearcher. +# Order defines priority: first active key is used as the primary display/sort metric. +_METRIC_CONSTRAINT_PRIORITY = ("params", "active_params", "memory_mb") +_METRIC_CONSTRAINTS = frozenset(_METRIC_CONSTRAINT_PRIORITY) + + @dataclass class CandidateSubnet: ss_config: dict - params: float + metrics: dict[str, float] score: float | None @@ -177,20 +207,26 @@ class CandidateSubnet: class MCoreMinitronSearcher(BaseSearcher): """Searcher for Minitron pruning algorithm. - Available additional config options (used when `params` constraint is provided): + Supported constraint keys: ``export_config``, ``params``, ``active_params``, ``memory_mb``. + + Available additional config options (used when a metric constraint is provided): - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.40). Only top (1 - max_width_pruning) choices will be considered. - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.20). Only top (1 - max_depth_pruning) choices will be considered. - `hparams_to_skip`: List of hparams to skip during the search (default: None). - `top_k`: Number of candidates to consider for score_func validation (default: 10). + - `seq_length`: Sequence length for KV-cache memory estimate (default: 4096). + Only used with the ``memory_mb`` constraint. + - `batch_size`: Batch size for KV-cache and Mamba-state memory estimate (default: 1). + Only used with the ``memory_mb`` constraint. 
""" local_activations: dict[str, torch.Tensor] layer_scores: dict[int, torch.Tensor] sorted_layers: list[int] | None # 1-indexed sorted list of layer numbers # Dict from params constraint to list of all CandidateSubnets fitting that constraint - all_candidates_per_constraint: dict[float, list[CandidateSubnet]] + all_candidates_per_constraint: dict[tuple, list[CandidateSubnet]] @property def default_search_config(self) -> SearchConfig: @@ -200,11 +236,14 @@ def default_search_config(self) -> SearchConfig: "max_iter_data_loader": 1024, "skip_sorting": False, "scores_path": None, - # Additional search config for parameter-based pruning + # Additional search config for metric-based pruning "max_width_pruning": 0.40, "max_depth_pruning": 0.20, "hparams_to_skip": None, "top_k": 10, + # Memory footprint config (only used with memory_mb constraint) + "seq_length": 4096, + "batch_size": 1, } @property @@ -229,11 +268,16 @@ def before_search(self) -> None: """Optional pre-processing steps before the search.""" super().before_search() - # Check that the constraint is valid - assert len(self.constraints) == 1 and next(iter(self.constraints.keys())) in { - "export_config", - "params", - }, "Only `export_config` or `params` constraint is supported!" + # Check that the constraint is valid. + # export_config must be the sole key; metric constraints can be combined freely. + active_metric_keys = self.constraints.keys() & _METRIC_CONSTRAINTS + assert self.constraints.keys() <= {"export_config"} | _METRIC_CONSTRAINTS, ( + f"Only {sorted({'export_config'} | _METRIC_CONSTRAINTS)} constraints are supported!" + ) + assert not ("export_config" in self.constraints and active_metric_keys), ( + "export_config cannot be combined with metric constraints!" + ) + assert self.constraints, "At least one constraint must be provided!" if "export_config" in self.constraints: export_config = self.constraints["export_config"] @@ -253,10 +297,11 @@ def before_search(self) -> None: # If a user only prunes depth, we should not sort width parameters self.hps_to_sort = set(export_config.keys()) else: - assert isinstance(self.constraints["params"], (int, float)), "params must be a float!" - assert self.has_score, "score_func (e.g. MMLU) is required for parameter-based pruning!" + for k in active_metric_keys: + assert isinstance(self.constraints[k], (int, float)), f"{k} must be a float!" + assert self.has_score, "score_func (e.g. MMLU) is required for metric-based pruning!" export_config = None - # Sort all parameters for parameter-based pruning + # Sort all parameters for metric-based pruning self.hps_to_sort = SUPPORTED_HPARAMS for n, hp in named_hparams(self.model, unique=True): @@ -276,6 +321,9 @@ def before_search(self) -> None: def run_search(self) -> None: """Run forward loop to collect activations, sort parameters, and prune the model.""" + print_mcore_model_stats( + self.model, "Original Model", self.config["seq_length"], self.config["batch_size"] + ) registry = ImportanceEstimatorRegistry(self.model) if self.local_activations and self.layer_scores: # Available from per-rank checkpoint registry.set_local_activations_and_layer_scores( @@ -315,8 +363,8 @@ def run_search(self) -> None: ), "Cannot prune `num_layers` without collecting layer scores!" 
self.sorted_layers = None - if "params" in self.constraints: - export_config = self.search_best_arch_by_params() + if self.constraints.keys() & _METRIC_CONSTRAINTS: + export_config = self.search_best_arch_by_metrics() else: export_config = self.constraints["export_config"] @@ -346,6 +394,10 @@ def run_search(self) -> None: ) print_rank_0(f"Pruned {hybrid_key}: {getattr(self.model, hybrid_key)}") + print_mcore_model_stats( + self.model, "Pruned Model", self.config["seq_length"], self.config["batch_size"] + ) + def _prune(self, export_config: dict, prune_depth: bool = True) -> None: """Prune the model homogeneously based on the export_config by setting active choices for configurable hparams. @@ -391,25 +443,27 @@ def _prune(self, export_config: dict, prune_depth: bool = True) -> None: if isinstance(m, _DynamicMoELayer): m._export_reinit_token_dispatcher() - def search_best_arch_by_params(self) -> dict: - """Search for the best architecture based on the given parameters constraints. + def search_best_arch_by_metrics(self) -> dict: + """Search for the best architecture based on the given metric constraint. - We perform a grid-search over the search space to find subnets (homogeneous) fitting the constraints. - Top-k candidates (sorted by param count) are then validated using the score_func (e.g. MMLU) - and the best subnet is returned. + Supports ``params``, ``active_params``, and ``memory_mb`` constraints. + Performs a grid-search over the search space to find subnets fitting the constraint, + then validates the top-k candidates using ``score_func`` (e.g. MMLU). Returns: export_config: Dictionary mapping hyperparameter names to their pruned values. """ assert self.sorted_layers is not None - max_params = float(self.constraints["params"]) # type: ignore[arg-type] + # Ordered list of active metric keys; primary (first) is used for sorting/display. + active_metric_keys = [k for k in _METRIC_CONSTRAINT_PRIORITY if k in self.constraints] + primary_key = active_metric_keys[0] + max_metrics: dict[str, float] = {k: float(self.constraints[k]) for k in active_metric_keys} # type: ignore[arg-type] max_width_pruning = self.config["max_width_pruning"] max_depth_pruning = self.config["max_depth_pruning"] hparams_to_skip = self.config["hparams_to_skip"] top_k = self.config["top_k"] - print_rank_0( - f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..." - ) + constraints_str = ", ".join(f"{self._fmt_metric(v, k)} {k}" for k, v in max_metrics.items()) + print_rank_0(f"\nSearching for the best pruned architecture under {constraints_str}...") # 1. Find available search space choices (across all PP ranks) hp_choices = {} @@ -425,8 +479,9 @@ def search_best_arch_by_params(self) -> dict: }, ) - # 2. Perform grid-search over the search space to find subnets fitting the constraints - if max_params not in self.all_candidates_per_constraint: + # 2. 
Perform grid-search over the search space to find subnets fitting all constraints + constraints_cache_key = tuple((k, max_metrics[k]) for k in active_metric_keys) + if constraints_cache_key not in self.all_candidates_per_constraint: max_num_layers = self.model.get_hparam("num_layers").max search_space_configs = MCoreMinitronSearcher._generate_search_space_combos( hp_choices, @@ -434,35 +489,39 @@ def search_best_arch_by_params(self) -> dict: max_depth_pruning, hparams_to_skip, ) - sample(self.model, sample_func=max) # reset to max subnet (for sanity) selected = [] for ss_config in tqdm( search_space_configs, desc="Finding all candidates fitting the constraints...", disable=not dist.is_master(), ): - self._prune(ss_config, prune_depth=False) - layer_ids = None - if "num_layers" in ss_config and ss_config["num_layers"] < max_num_layers: - layer_ids = self.sorted_layers[: ss_config["num_layers"]] - candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids) - if candidate_params <= max_params: - selected.append(CandidateSubnet(ss_config, candidate_params, None)) - sample(self.model, sample_func=max) # reset to max subnet + candidate_metrics = self._compute_candidate_metrics(ss_config, max_num_layers) + if all(candidate_metrics[k] <= max_metrics[k] for k in active_metric_keys): + selected.append( + CandidateSubnet( + ss_config, {k: candidate_metrics[k] for k in active_metric_keys}, None + ) + ) assert len(selected) > 0, "No subnets found fitting the constraints!" print_rank_0(f"Found {len(selected)} candidates fitting the constraints!") - self.all_candidates_per_constraint[max_params] = sorted( - selected, key=lambda x: x.params, reverse=True + self.all_candidates_per_constraint[constraints_cache_key] = sorted( + selected, key=lambda x: x.metrics[primary_key], reverse=True ) self.save_search_checkpoint(verbose=True) else: print_rank_0(f"\nUsing top {top_k} candidates from checkpoint") - top_k_candidates = self.all_candidates_per_constraint[max_params][:top_k] - - print_rank_0(f"\n====================\nTop {top_k} candidates:") - for candidate in top_k_candidates: - print_rank_0(f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params") - print_rank_0("====================\n") + top_k_candidates = self.all_candidates_per_constraint[constraints_cache_key][:top_k] + + table = Table(title=f"Top {top_k} Candidates", show_header=True, header_style="bold") + table.add_column("#", justify="right", style="dim", no_wrap=True) + table.add_column("export_config", overflow="fold") + for k in active_metric_keys: + table.add_column(k, justify="right") + for i, candidate in enumerate(top_k_candidates, 1): + row = [str(i), rich_escape(str(candidate.ss_config))] + row += [self._fmt_metric(candidate.metrics[k], k) for k in active_metric_keys] + table.add_row(*row) + _rprint(table) # 3. 
Optional Knowledge Distillation (KD) step for all top-k candidates print_rank_0( @@ -508,27 +567,47 @@ def search_best_arch_by_params(self) -> dict: layer.layer_number = start_layer_number start_layer_number += 1 self.model.decoder.layers = all_layers - print_rank_0( - f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n" + metrics_str = ", ".join( + f"{self._fmt_metric(v, k)} {k}" for k, v in candidate.metrics.items() ) + print_rank_0(f"\t{candidate.ss_config} -> {metrics_str}, {candidate.score:.4f} score\n") for m in _routers_with_expert_bias: m.enable_expert_bias = True - print_rank_0(f"\n====================\nTop {top_k} candidates with scores:") - for candidate in top_k_candidates: - print_rank_0( - f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score" - ) - print_rank_0("====================\n") + scored_table = Table( + title=f"Top {top_k} Candidates with Scores", show_header=True, header_style="bold" + ) + scored_table.add_column("#", justify="right", style="dim", no_wrap=True) + scored_table.add_column("export_config") + for k in active_metric_keys: + scored_table.add_column(k, justify="right") + scored_table.add_column("score", justify="right") + for i, candidate in enumerate(top_k_candidates, 1): + row = [str(i), rich_escape(str(candidate.ss_config))] + row += [self._fmt_metric(candidate.metrics[k], k) for k in active_metric_keys] + row.append(f"{candidate.score:.4f}") + scored_table.add_row(*row) + _rprint(scored_table) dist.barrier() best = max(top_k_candidates, key=lambda x: x.score) # type: ignore[arg-type, return-value] - print_rank_0( - f"\n[BEST SUBNET] {best.ss_config} -> {num2hrb(best.params)} params, {best.score:.4f} score\n" + best_grid = Table.grid(padding=(0, 2)) + best_grid.add_column(style="bold green", no_wrap=True) + best_grid.add_column() + best_grid.add_row("export_config", rich_escape(str(best.ss_config))) + for k, v in best.metrics.items(): + best_grid.add_row(k, self._fmt_metric(v, k)) + best_grid.add_row("score", f"{best.score:.4f}") + _rprint( + Panel(best_grid, title="[bold green]Best Subnet[/bold green]", border_style="green") ) return best.ss_config + def _fmt_metric(self, value: float, constraint_key: str) -> str: + """Format a metric value for display.""" + return f"{value:.3f} MB" if constraint_key == "memory_mb" else num2hrb(value) + @staticmethod def _generate_search_space_combos( search_space: dict[str, list], @@ -557,11 +636,6 @@ def _generate_search_space_combos( {"hidden_size": 4096, "num_layers": 32}, ] """ - print_rank_0( - f"\nOnly considering atmost {(max_width_pruning * 100):.0f}% for width and " - f"{max_depth_pruning * 100:.0f}% for depth pruning hparams" - ) - if hparams_to_skip: search_space = dict(search_space) # Avoid modifying the original search space print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...") @@ -582,10 +656,19 @@ def _generate_search_space_combos( } ss_size = 1 + table = Table( + title=f"Search Space \n(≤{max_width_pruning * 100:.0f}% width / ≤{max_depth_pruning * 100:.0f}% depth pruning)", # noqa: E501 + show_header=True, + header_style="bold", + ) + table.add_column("Hyperparameter") + table.add_column("Choices", overflow="fold") for k, v in filtered_ss.items(): - print_rank_0(f"\tSearch space for {k}: {v}") + table.add_row(k, rich_escape(str(v))) ss_size *= len(v) - print_rank_0(f"\tTotal search space in consideration: {ss_size}\n") + table.add_section() + table.add_row("Search space size", f"{ss_size}") + 
_rprint(table) hparam_names = list(filtered_ss.keys()) hparam_choices_lists = [filtered_ss[name] for name in hparam_names] @@ -597,29 +680,120 @@ def _generate_search_space_combos( return search_space_combos + def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> dict[str, float]: + """Compute all active metric constraint values for a candidate config analytically. -def get_mcore_param_count(model: GPTModel | MambaModel) -> float: - """Get the number of parameters in the MCore GPTModel or MambaModel (reduced across TP and PP ranks).""" + Calls ``mcore_param_count`` at most once (covers both ``params`` and ``active_params``) + and ``mcore_memory_footprint_mb`` at most once (for ``memory_mb``). + Replaces the slow ``_prune → _param_num_dynamic → sample(max)`` loop used during search. + Handles depth pruning by filtering the hybrid layer pattern to the kept (best) layers. + """ + model = self.model + active_metric_keys = self.constraints.keys() & _METRIC_CONSTRAINTS + + # Get hybrid layer pattern for MambaModel (None for pure GPT) + hybrid_layer_pattern: str | None = None + if isinstance(model, MambaModel): + hybrid_key = ( + "hybrid_override_pattern" + if hasattr(self.model, "hybrid_override_pattern") + else "hybrid_layer_pattern" + ) + hybrid_layer_pattern = getattr(model, hybrid_key) + + # If depth pruning on a hybrid model, filter the pattern to only the kept layers. + # sorted_layers gives layer numbers (1-indexed) ordered best-first; we keep the top N. + num_layers_target: int = ss_config.get("num_layers", max_num_layers) + if hybrid_layer_pattern is not None and num_layers_target < max_num_layers: + assert self.sorted_layers is not None + kept = set(self.sorted_layers[:num_layers_target]) + layer_chars = parse_main_layer_chars(hybrid_layer_pattern) + hybrid_layer_pattern = "".join(c for i, c in enumerate(layer_chars) if (i + 1) in kept) + + metrics: dict[str, float] = {} + + if active_metric_keys & {"params", "active_params"}: + total, active = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + **ss_config, + ) + if "params" in active_metric_keys: + metrics["params"] = total + if "active_params" in active_metric_keys: + metrics["active_params"] = active + + if "memory_mb" in active_metric_keys: + _, _, _, metrics["memory_mb"] = mcore_memory_footprint_mb( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + dtype_bytes=2, # assume BF16 input + sequence_length=self.config["seq_length"], + batch_size=self.config["batch_size"], + **ss_config, + ) + + return metrics + + +def get_mcore_param_count(model: GPTModel | MambaModel) -> int: + """Get the number of parameters in the MCore GPTModel or MambaModel (reduced across TP, EP, ETP, and PP ranks).""" assert isinstance(model, (GPTModel, MambaModel)), "Model must be a GPTModel or MambaModel" if isinstance(model, DynamicModule): - return _param_num_dynamic(model) + return int(_param_num_dynamic(model)) else: - return _param_num(model) + return int(_param_num(model)) def _param_num(model: GPTModel | MambaModel) -> float: - """Get the number of parameters in the model (reduced across TP and PP ranks).""" - # Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True - params = sum( - p.numel() - for name, p in model.named_parameters() - if not model.share_embeddings_and_output_weights or "output_layer.weight" not in name 
- ) + """Get the number of parameters in the model (reduced across TP, EP, ETP, and PP ranks). - reduced_params = torch.Tensor([params]).to(device=next(model.parameters()).device) - torch.distributed.all_reduce(reduced_params, group=get_pipeline_model_parallel_group()) - torch.distributed.all_reduce(reduced_params, group=get_tensor_model_parallel_group()) - return reduced_params.item() + Expert params (``allreduce=False``) are EP/ETP-sharded and require a separate reduction over + the joint EPxETP group rather than the regular TP group. Non-expert params are reduced over + PP and TP as before. When EP is not configured the expert path is a no-op. + """ + tp_rank = get_tensor_model_parallel_rank() + # get_expert_tensor_parallel_rank() falls back to tp_rank when ETP is not configured. + etp_rank = get_expert_tensor_parallel_rank() + + regular_params = 0 # allreduce=True (or unset): replicated / TP-sharded + expert_params = 0 # allreduce=False: EP-sharded (and possibly ETP-sharded) + + for name, p in model.named_parameters(): + if model.share_embeddings_and_output_weights and "output_layer.weight" in name: + continue + is_expert_parallel = not getattr(p, "allreduce", True) + is_tp_sharded = getattr(p, "tensor_model_parallel", False) + if is_expert_parallel: + # EP/ETP-sharded: ETP-sharded params are summed across all ETP ranks; non-ETP params + # are counted only on ETP rank 0 to avoid multiplying by ETP size in the EPxETP reduce. + if not is_tp_sharded and etp_rank != 0: + continue + expert_params += p.numel() + else: + # Non-expert: TP-sharded params are summed; replicated params counted on TP rank 0. + if not is_tp_sharded and tp_rank != 0: + continue + regular_params += p.numel() + + device = next(model.parameters()).device + + regular_tensor = torch.tensor([regular_params], device=device) + torch.distributed.all_reduce(regular_tensor, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(regular_tensor, group=get_tensor_model_parallel_group()) + + ep_etp_group = get_expert_tensor_and_model_parallel_group(check_initialized=False) + if ep_etp_group is not None: + expert_tensor = torch.tensor([expert_params], device=device) + torch.distributed.all_reduce(expert_tensor, group=get_pipeline_model_parallel_group()) + torch.distributed.all_reduce(expert_tensor, group=ep_etp_group) + return (regular_tensor + expert_tensor).item() + + return regular_tensor.item() def _param_num_dynamic( @@ -640,6 +814,10 @@ def get_param_count(mod, name) -> int: submodule = mod.get_submodule(module_path) if module_path else mod return getattr(submodule, param_name).numel() + assert model.config.tensor_model_parallel_size == 1, ( + "_param_num_dynamic does not support tensor parallelism (TP > 1)" + ) + # Account for depth pruning with uneven PP and hybrid models! 
# Dont double count output_layer parameters if model.share_embeddings_and_output_weights is True params = sum( diff --git a/tests/_test_utils/torch/transformers_models.py b/tests/_test_utils/torch/transformers_models.py index a6bfc4484a8..34bc96cd0ae 100644 --- a/tests/_test_utils/torch/transformers_models.py +++ b/tests/_test_utils/torch/transformers_models.py @@ -110,7 +110,7 @@ def get_tiny_qwen3_moe(**config_kwargs) -> PreTrainedModel: def create_tiny_qwen3_moe_dir( tmp_path: Path | str, with_tokenizer: bool = False, **config_kwargs -) -> Path: +) -> Path | tuple[Path, PreTrainedModel]: qwen3_moe_dir = Path(tmp_path) / "tiny_qwen3_moe" if with_tokenizer: tokenizer = tokenizer = get_tiny_tokenizer() diff --git a/tests/examples/megatron_bridge/test_prune_minitron.py b/tests/examples/megatron_bridge/test_prune_minitron.py index 57de7fe7985..b63a7720736 100644 --- a/tests/examples/megatron_bridge/test_prune_minitron.py +++ b/tests/examples/megatron_bridge/test_prune_minitron.py @@ -18,6 +18,7 @@ from _test_utils.examples.run_command import extend_cmd_parts, run_example_command from _test_utils.torch.transformers_models import create_tiny_qwen3_dir +from transformers import AutoModelForCausalLM def test_prune_minitron(tmp_path: Path, num_gpus): @@ -25,6 +26,7 @@ def test_prune_minitron(tmp_path: Path, num_gpus): tmp_path, with_tokenizer=True, return_model=True, num_hidden_layers=num_gpus ) teacher_params = sum(p.numel() for p in teacher_model.parameters()) + prune_target_params = int(teacher_params * 0.8) pruned_model_path = tmp_path / "pruned" prune_command_parts = extend_cmd_parts( @@ -35,7 +37,7 @@ def test_prune_minitron(tmp_path: Path, num_gpus): calib_dataset_name="cnn_dailymail", calib_num_samples=16, seq_length=32, - prune_target_params=teacher_params * 0.8, + prune_target_params=prune_target_params, prune_score_func="mmlu_1pct", ss_channel_divisor=4, hparams_to_skip="num_attention_heads", @@ -43,3 +45,7 @@ def test_prune_minitron(tmp_path: Path, num_gpus): ) run_example_command(prune_command_parts, example_path="megatron_bridge") assert (pruned_model_path / "config.json").exists() + + pruned_model = AutoModelForCausalLM.from_pretrained(pruned_model_path) + pruned_params = sum(p.numel() for p in pruned_model.parameters()) + assert pruned_params <= prune_target_params diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py index db8b9e10ba6..b639e34494a 100644 --- a/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py @@ -21,9 +21,14 @@ from _test_utils.torch.megatron.models import get_mcore_mamba_hybrid_model from _test_utils.torch.megatron.utils import run_mcore_inference -from megatron.core.parallel_state import is_pipeline_first_stage, is_pipeline_last_stage +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_group, + is_pipeline_first_stage, + is_pipeline_last_stage, +) import modelopt.torch.nas as mtn +import modelopt.torch.utils.distributed as dist from modelopt.torch.nas.modules.conv import _DynamicConvNd from modelopt.torch.nas.plugins.megatron import ( MambaDInnerHp, @@ -38,9 +43,13 @@ _DynamicTENorm, _DynamicTERowParallelLinear, ) +from modelopt.torch.nas.plugins.megatron_model_stats import mcore_param_count from modelopt.torch.nas.traced_hp import TracedHp -from modelopt.torch.opt.utils import named_dynamic_modules, search_space_size 
-from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config +from modelopt.torch.opt.utils import named_dynamic_modules, named_hparams, search_space_size +from modelopt.torch.prune.plugins.mcore_minitron import ( + _param_num_dynamic, + get_mcore_minitron_config, +) from modelopt.torch.utils.random import centroid SEED = 1234 @@ -137,6 +146,115 @@ def test_mamba_search_space(dist_workers): dist_workers.run(_test_mamba_search_space) +def _test_param_num_dynamic_matches_formula(rank, size): + """Sample min-width subnet and assert _param_num_dynamic matches the analytical formula. + + Uses "ME*-" to exercise all four block types (Mamba, MoE, Attention, dense MLP). + Depth pruning is excluded from the formula override because _param_num_dynamic counts all + physical layers on each PP rank (actual depth pruning requires drop_mcore_language_model_layers). + """ + assert size <= 4, "test_param_num_dynamic_matches_formula only configured for upto 4 GPUs" + channel_divisor = 4 + mamba_head_dim_divisor = 4 + + # 4-layer hybrid covering all block types + num_layers = 4 + hybrid_override_pattern = "ME*-" + hidden_size = 16 + ffn_hidden_size = 32 + num_attention_heads = 16 + num_query_groups = 4 + mamba_state_dim = 4 + mamba_num_heads = 8 + mamba_head_dim = 16 + mamba_num_groups = 2 + num_moe_experts = 8 + moe_ffn_hidden_size = 16 + moe_shared_expert_intermediate_size = 16 + vocab_size = 32 + + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=size, + initialize_megatron=True, + num_layers=num_layers, + hybrid_override_pattern=hybrid_override_pattern, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + ffn_hidden_size=ffn_hidden_size, + mamba_state_dim=mamba_state_dim, + mamba_num_heads=mamba_num_heads, + mamba_head_dim=mamba_head_dim, + mamba_num_groups=mamba_num_groups, + moe_ffn_hidden_size=moe_ffn_hidden_size, + moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, + num_moe_experts=num_moe_experts, + vocab_size=vocab_size, + transformer_impl="transformer_engine", + bf16=False, + ).cuda() + + mtn.convert( + model, + [ + ( + "mcore_minitron", + get_mcore_minitron_config( + hidden_size_divisor=channel_divisor, + ffn_hidden_size_divisor=channel_divisor, + mamba_head_dim_divisor=mamba_head_dim_divisor, + num_moe_experts_divisor=1, + num_layers_divisor=1, + ), + ) + ], + ) + + mtn.sample(model, min) + + hybrid_key = ( + "hybrid_override_pattern" + if hasattr(model, "hybrid_override_pattern") + else "hybrid_layer_pattern" + ) + hybrid_layer_pattern = getattr(model, hybrid_key) + + # Build a flat {hparam_name: active_value} dict (same convention as the searcher's ss_config). + # get_subnet_config() returns full-path keys, which mcore_param_count does not understand. + # With PP > 1 each rank only holds a subset of layer types, so gather across the PP group + # to get the complete set of hparam overrides for the global formula. + # Exclude num_layers: _param_num_dynamic counts all physical layers regardless of the depth + # hparam; actual depth pruning (drop_mcore_language_model_layers) is not called here. 
+ local_config = { + n.split(".")[-1]: hp.active + for n, hp in named_hparams(model, configurable=True) + if n.split(".")[-1] != "num_layers" + } + width_ss_config = dist.DistributedProcessGroup.get_dist_syncd_obj( + local_config, + dist.DistributedProcessGroup(get_pipeline_model_parallel_group()), + op=lambda all_rank_configs: {k: v for d in all_rank_configs for k, v in d.items()}, + ) + formula_total, _ = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + **width_ss_config, + ) + dynamic_count = int(_param_num_dynamic(model)) + + assert formula_total == dynamic_count, ( + f"Formula ({formula_total:,}) != _param_num_dynamic ({dynamic_count:,}) " + f"for min-width subnet {width_ss_config} (PP={size})" + ) + + +def test_param_num_dynamic_matches_formula(dist_workers): + dist_workers.run(_test_param_num_dynamic_matches_formula) + + def test_mamba_num_heads_hp(): num_heads = MambaNumHeadsHp(8, ngroups=2) # 4 heads per group assert num_heads.choices == [2, 4, 6, 8] diff --git a/tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py b/tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py new file mode 100644 index 00000000000..ada5c239b13 --- /dev/null +++ b/tests/gpu_megatron/torch/nas/plugins/test_megatron_model_stats.py @@ -0,0 +1,445 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for modelopt.torch.nas.plugins.megatron_model_stats. + +Two test groups: + - TestMcoreParamCountFormulas: pure arithmetic / no GPU needed. + - Live model tests (_test_formula_matches_*): build a real MCore model on GPU, + compare the analytical formula result against get_mcore_param_count(). 
+""" + +from types import SimpleNamespace + +import pytest +from _test_utils.torch.megatron.models import ( + HAS_MAMBA, + get_mcore_gpt_model, + get_mcore_mamba_hybrid_model, +) + +from modelopt.torch.nas.plugins.megatron_model_stats import ( + mcore_memory_footprint_mb, + mcore_param_count, +) +from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_param_count + +# --------------------------------------------------------------------------- +# Small reference dimensions - easy to verify by hand +# --------------------------------------------------------------------------- + +_H = 4 # hidden_size +_NH = 2 # num_attention_heads +_NKV = 1 # num_query_groups +_KV = 2 # kv_channels (== _H // _NH) +_FFN = 8 # ffn_hidden_size +_NE = 2 # num_moe_experts +_TOPK = 1 # moe_router_topk +_MOE_FFN = 8 # moe_ffn_hidden_size +_MNH = 2 # mamba_num_heads +_MDH = 2 # mamba_head_dim +_MNG = 2 # mamba_num_groups +_MDS = 2 # mamba_state_dim +_V = 10 # vocab_size + +_BASE_CFG = SimpleNamespace( + hidden_size=_H, + num_layers=1, + num_attention_heads=_NH, + num_query_groups=_NKV, + kv_channels=_KV, + ffn_hidden_size=_FFN, + num_moe_experts=_NE, + moe_router_topk=_TOPK, + moe_ffn_hidden_size=_MOE_FFN, + moe_shared_expert_intermediate_size=None, + moe_shared_expert_gate=False, + mamba_num_heads=_MNH, + mamba_head_dim=_MDH, + mamba_num_groups=_MNG, + mamba_state_dim=_MDS, + gated_linear_unit=False, + add_bias_linear=False, + normalization="RMSNorm", + qk_layernorm=False, + attention_output_gate=False, + moe_layer_freq=1, +) + +# Pre-computed expected component sizes (all verified by hand): + +_LN = _H # single RMSNorm weight + +# Embedding + final RMSNorm + output layer (untied) +_BASE_UNTIED = _V * _H + _LN + _H * _V # 40 + 4 + 40 = 84 + +# Attention sublayer (*): +# linear_qkv weight: H * (NH + 2*NKV) * KV + input_layernorm: H +# linear_proj weight: NH * KV * H +_QKV = _H * (_NH + 2 * _NKV) * _KV # 4 * 4 * 2 = 32 +_PROJ = _NH * _KV * _H # 2 * 2 * 4 = 16 +_ATTN = _QKV + _LN + _PROJ # 32 + 4 + 16 = 52 + +# Dense MLP sublayer (-): +# linear_fc1: H * FFN + pre_mlp_layernorm: H +# linear_fc2: FFN * H +_DENSE_MLP = _H * _FFN + _LN + _FFN * _H # 32 + 4 + 32 = 68 + +# MoE sublayer (E): +# always-active: pre_mlp_layernorm + router weight (NE * H) [not per-expert] +# per routed expert: fc1 (H * MOE_FFN) + fc2 (MOE_FFN * H) +_MOE_ALWAYS = _LN + _NE * _H # 4 + 8 = 12 +_MOE_PER_EXP = _H * _MOE_FFN + _MOE_FFN * _H # 32 + 32 = 64 +_MOE_TOTAL = _MOE_ALWAYS + _NE * _MOE_PER_EXP # 12 + 128 = 140 +_MOE_ACTIVE = _MOE_ALWAYS + _TOPK * _MOE_PER_EXP # 12 + 64 = 76 + +# Mamba sublayer (M): +# in_proj: H * (2*d_inner + 2*MNG*MDS + MNH) + input_layernorm: H +# out_proj: d_inner * H +# conv1d (depthwise): conv_dim * d_conv + conv_dim (weight + bias) +# scalars: A_log + dt_bias + D -> 3 * MNH +# internal RMSNorm on d_inner: d_inner +_D_INNER = _MNH * _MDH # 2 * 2 = 4 +_IN_PROJ_OUT = 2 * _D_INNER + 2 * _MNG * _MDS + _MNH # 8 + 8 + 2 = 18 +_CONV_DIM = _D_INNER + 2 * _MNG * _MDS # 4 + 8 = 12 +_MAMBA = ( + _H * _IN_PROJ_OUT # in_proj weight + + _LN # input_layernorm + + _D_INNER * _H # out_proj + + _CONV_DIM * 4 + + _CONV_DIM # conv weight + bias + + 3 * _MNH # scalars + + _D_INNER # internal RMSNorm +) # 72 + 4 + 16 + 60 + 6 + 4 = 162 + + +# --------------------------------------------------------------------------- +# Formula tests (no GPU required) +# --------------------------------------------------------------------------- + + +class TestMcoreParamCountFormulas: + def test_single_attention_layer(self): + total, active = 
mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="*") + expected = _BASE_UNTIED + _ATTN + assert total == expected + assert active == expected # attention has no MoE split + + def test_single_dense_mlp_layer(self): + total, active = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="-") + expected = _BASE_UNTIED + _DENSE_MLP + assert total == expected + assert active == expected + + def test_single_moe_layer_total_and_active(self): + total, active = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="E") + assert total == _BASE_UNTIED + _MOE_TOTAL + assert active == _BASE_UNTIED + _MOE_ACTIVE + assert active < total + + def test_single_mamba_layer(self): + total, active = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="M") + expected = _BASE_UNTIED + _MAMBA + assert total == expected + assert active == expected + + def test_hybrid_pattern_is_sum_of_per_layer_costs(self): + pattern = "MEM*E" + total, active = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern=pattern, num_layers=5) + assert total == _BASE_UNTIED + 2 * _MAMBA + 2 * _MOE_TOTAL + _ATTN + assert active == _BASE_UNTIED + 2 * _MAMBA + 2 * _MOE_ACTIVE + _ATTN + + def test_pipe_char_ignored(self): + base = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="ME", num_layers=2) + with_pipe = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="M|E", num_layers=2) + assert base == with_pipe + + def test_mtp_separator_strips_suffix(self): + # Everything from '/' onward is MTP and must be ignored. + base = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="M") + with_mtp = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="M/E*") + assert base == with_mtp + + def test_tied_vocab_excludes_output_layer(self): + untied, _ = mcore_param_count( + _BASE_CFG, _V, share_embeddings_and_output_weights=False, hybrid_layer_pattern="M" + ) + tied, _ = mcore_param_count( + _BASE_CFG, _V, share_embeddings_and_output_weights=True, hybrid_layer_pattern="M" + ) + assert untied - tied == _V * _H + + def test_moe_topk_equals_num_experts_gives_equal_total_active(self): + total, active = mcore_param_count( + _BASE_CFG, _V, hybrid_layer_pattern="E", moe_router_topk=_NE + ) + assert total == active + + def test_empty_pattern_only_base_params(self): + total, active = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="", num_layers=0) + assert total == _BASE_UNTIED + assert active == _BASE_UNTIED + + def test_layernorm_adds_bias_over_rmsnorm(self): + # LayerNorm has weight + bias vs RMSNorm weight-only, so LayerNorm models are larger. + rmsnorm, _ = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="*") + layernorm, _ = mcore_param_count( + _BASE_CFG, _V, hybrid_layer_pattern="*", normalization="LayerNorm" + ) + assert layernorm > rmsnorm + + def test_pure_gpt_dense_scales_with_num_layers(self): + # All-dense GPT (num_moe_experts=None): each layer = attn + dense MLP. + total_1, _ = mcore_param_count(_BASE_CFG, _V, num_layers=1, num_moe_experts=None) + total_2, _ = mcore_param_count(_BASE_CFG, _V, num_layers=2, num_moe_experts=None) + assert total_2 - total_1 == _ATTN + _DENSE_MLP + + def test_pure_gpt_moe_layer_freq(self): + # moe_layer_freq=2: layer indices 0, 2 are MoE; 1, 3 are dense. 
+ total, active = mcore_param_count(_BASE_CFG, _V, num_layers=4, moe_layer_freq=2) + assert total == _BASE_UNTIED + 4 * _ATTN + 2 * _MOE_TOTAL + 2 * _DENSE_MLP + assert active == _BASE_UNTIED + 4 * _ATTN + 2 * _MOE_ACTIVE + 2 * _DENSE_MLP + + +# --------------------------------------------------------------------------- +# Memory footprint tests (no GPU required) +# --------------------------------------------------------------------------- + +_MB = 1024**2 +_SEQ = 4 # sequence_length used in memory tests +_BSZ = 1 # batch_size used in memory tests + +# Mamba state size (per layer, batch=1, dtype=2 bytes): +# conv_state = batch * conv_dim * (d_conv - 1) where d_conv=4 +# ssm_state = batch * MNH * MDH * MDS +_MAMBA_CONV_DIM = _D_INNER + 2 * _MNG * _MDS # same as _CONV_DIM = 12 +_MAMBA_CONV_STATE = _BSZ * _MAMBA_CONV_DIM * (4 - 1) # 1 * 12 * 3 = 36 +_MAMBA_SSM_STATE = _BSZ * _MNH * _MDH * _MDS # 1 * 2 * 2 * 2 = 8 +_MAMBA_STATE_BYTES = (_MAMBA_CONV_STATE + _MAMBA_SSM_STATE) * 2 # * dtype_bytes=2 = 88 + + +class TestMcoreMemoryFootprint: + def _mem(self, pattern, num_layers=1, sequence_length=_SEQ, batch_size=_BSZ, **kw): + return mcore_memory_footprint_mb( + _BASE_CFG, + _V, + hybrid_layer_pattern=pattern, + dtype_bytes=2, + sequence_length=sequence_length, + batch_size=batch_size, + num_layers=num_layers, + **kw, + ) + + def test_total_equals_sum_of_components(self): + params_mb, kv_cache_mb, mamba_state_mb, total_mb = self._mem("M*", num_layers=2) + assert total_mb == pytest.approx(params_mb + kv_cache_mb + mamba_state_mb) + + def test_params_mb_consistent_with_param_count(self): + # params_mb must equal total_params * dtype_bytes / GB + params_mb, _, _, _ = self._mem("*") + total, _ = mcore_param_count(_BASE_CFG, _V, hybrid_layer_pattern="*") + assert params_mb == pytest.approx(total * 2 / _MB) + + def test_dtype_bytes_scales_params_mb(self): + params_mb2, _, _, _ = self._mem("*") + params_mb4, _, _, _ = mcore_memory_footprint_mb( + _BASE_CFG, + _V, + hybrid_layer_pattern="*", + dtype_bytes=4, + sequence_length=_SEQ, + batch_size=_BSZ, + ) + assert params_mb4 == pytest.approx(params_mb2 * 2) + + def test_kv_cache_mb_zero_for_pure_mamba(self): + _, kv_cache_mb, _, _ = self._mem("M") + assert kv_cache_mb == 0.0 + + def test_mamba_state_mb_zero_for_pure_gpt(self): + _, _, mamba_state_mb, _ = self._mem(None, num_layers=1) # no hybrid pattern -> pure GPT + assert mamba_state_mb == 0.0 + + def test_kv_cache_mb_exact_for_single_attention_layer(self): + # kv_per_layer = 2 * batch * seq * NKV * KV * dtype_bytes + expected_bytes = 2 * _BSZ * _SEQ * _NKV * _KV * 2 + _, kv_cache_mb, _, _ = self._mem("*") + assert kv_cache_mb == pytest.approx(expected_bytes / _MB) + + def test_kv_cache_scales_linearly_with_seq_and_batch(self): + # KV cache is linear in both sequence_length and batch_size. 
+ _, base_kv, _, _ = self._mem("*") + _, kv_seq4, _, _ = self._mem("*", sequence_length=_SEQ * 4) + _, kv_bsz4, _, _ = self._mem("*", batch_size=4) + assert kv_seq4 == pytest.approx(base_kv * 4) + assert kv_bsz4 == pytest.approx(base_kv * 4) + + def test_kv_cache_dtype_bytes_independent_of_param_dtype(self): + # kv_cache_dtype_bytes=4 doubles kv_cache_mb but not params_mb + params_mb2, kv_cache_mb2, _, _ = self._mem("*") + params_mb_kv4, kv_cache_mb_kv4, _, _ = mcore_memory_footprint_mb( + _BASE_CFG, + _V, + hybrid_layer_pattern="*", + dtype_bytes=2, + kv_cache_dtype_bytes=4, + sequence_length=_SEQ, + batch_size=_BSZ, + ) + assert kv_cache_mb_kv4 == pytest.approx(kv_cache_mb2 * 2) + assert params_mb_kv4 == pytest.approx(params_mb2) + + def test_mamba_state_mb_exact(self): + _, _, mamba_state_mb, _ = self._mem("M") + assert mamba_state_mb == pytest.approx(_MAMBA_STATE_BYTES / _MB) + + def test_mamba_state_scales_with_num_mamba_layers(self): + _, _, mamba_state1, _ = self._mem("M") + _, _, mamba_state2, _ = self._mem("MM", num_layers=2) + assert mamba_state2 == pytest.approx(mamba_state1 * 2) + + +# --------------------------------------------------------------------------- +# Live model tests: formula must match get_mcore_param_count() exactly +# --------------------------------------------------------------------------- + + +def _test_formula_matches_gpt_model(rank, size, parallelism): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1 if parallelism != "tp" else size, + pipeline_model_parallel_size=1 if parallelism != "pp" else size, + initialize_megatron=True, + num_layers=4, + hidden_size=64, + num_attention_heads=8, + num_query_groups=4, + ffn_hidden_size=128, + vocab_size=128, + normalization="RMSNorm", + activation_func="swiglu", + bf16=True, + ).cuda() + + expected_total, expected_active = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + ) + actual = get_mcore_param_count(model) + + assert expected_total == expected_active, "Non-MoE GPT: total must equal active" + assert expected_total == actual, ( + f"Formula ({expected_total:,}) != live model ({actual:,}) for {parallelism}" + ) + + +@pytest.mark.parametrize("parallelism", ["tp", "pp"]) +def test_formula_matches_gpt_model(dist_workers, parallelism, num_gpus): + """Builds a real GPTModel and asserts the analytical formula matches the live count.""" + if num_gpus == 1 and parallelism != "tp": + pytest.skip("Skipping as redundant test on 1 GPU") + dist_workers.run(_test_formula_matches_gpt_model, parallelism=parallelism) + + +def _test_formula_matches_gpt_moe_model(rank, size, parallelism): + model = get_mcore_gpt_model( + tensor_model_parallel_size=1 if parallelism != "tp" else size, + pipeline_model_parallel_size=1 if parallelism != "pp" else size, + expert_model_parallel_size=1 if parallelism != "ep" else size, + initialize_megatron=True, + num_layers=4, + hidden_size=64, + num_attention_heads=8, + ffn_hidden_size=128, + moe_grouped_gemm=True, + num_moe_experts=4, + moe_ffn_hidden_size=64, + moe_shared_expert_intermediate_size=16, + vocab_size=128, + normalization="RMSNorm", + activation_func="swiglu", + moe_layer_freq=2, + bf16=True, + ).cuda() + + expected_total, expected_active = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + ) + actual = get_mcore_param_count(model) + + assert expected_active < expected_total, "MoE model: active must be less than total" + assert expected_total == actual, ( + f"Formula total 
({expected_total:,}) != live model ({actual:,}) for {parallelism}" + ) + + +@pytest.mark.parametrize("parallelism", ["tp", "pp", "ep"]) +def test_formula_total_matches_gpt_moe_model(dist_workers, parallelism, num_gpus): + """Builds a GPTModel with MoE layers; formula total must match the live param count.""" + if num_gpus == 1 and parallelism != "tp": + pytest.skip("Skipping as redundant test on 1 GPU") + dist_workers.run(_test_formula_matches_gpt_moe_model, parallelism=parallelism) + + +def _test_formula_matches_mamba_model(rank, size, parallelism): + hidden_size = 64 + mamba_head_dim = 16 + mamba_num_heads = hidden_size // mamba_head_dim # 4 + pattern = "ME*-" # 4-layer hybrid + + model = get_mcore_mamba_hybrid_model( + tensor_model_parallel_size=1 if parallelism != "tp" else size, + pipeline_model_parallel_size=1 if parallelism != "pp" else size, + expert_model_parallel_size=1 if parallelism != "ep" else size, + initialize_megatron=True, + num_layers=4, + hidden_size=hidden_size, + num_attention_heads=8, + mamba_head_dim=mamba_head_dim, + mamba_num_heads=mamba_num_heads, + vocab_size=128, + hybrid_override_pattern=pattern, + moe_grouped_gemm=False, + num_moe_experts=4, + moe_ffn_hidden_size=64, + moe_shared_expert_intermediate_size=None, + bf16=True, + ).cuda() + + hybrid_layer_pattern = getattr(model, "hybrid_layer_pattern", pattern) + expected_total, _ = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=hybrid_layer_pattern, + ) + actual = get_mcore_param_count(model) + + assert expected_total == actual, ( + f"Formula ({expected_total:,}) != live model ({actual:,}) for {parallelism}" + ) + + +@pytest.mark.parametrize("parallelism", ["tp", "pp", "ep"]) +def test_formula_matches_mamba_model(dist_workers, parallelism, num_gpus): + """Builds a non-MoE hybrid MambaModel; formula must match the live param count.""" + if num_gpus == 1 and parallelism != "tp": + pytest.skip("Skipping as redundant test on 1 GPU") + if not HAS_MAMBA: + pytest.skip("Mamba not installed") + dist_workers.run(_test_formula_matches_mamba_model, parallelism=parallelism) diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 7a7ad90a180..0f56a370b94 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -15,7 +15,7 @@ import contextlib import io -from functools import partial +import re import pytest import torch @@ -34,11 +34,14 @@ from megatron.core.transformer.identity_op import IdentityOp import modelopt.torch.nas as mtn +from modelopt.torch.nas.plugins.megatron_model_stats import ( + mcore_memory_footprint_mb, + mcore_param_count, +) from modelopt.torch.prune.plugins.mcore_minitron import ( ImportanceEstimatorRegistry, _convert_model_to_dynamic_space, get_mcore_minitron_config, - get_mcore_param_count, ) SEED = 1234 @@ -120,7 +123,7 @@ def test_mcore_mamba_parameter_sorting(dist_workers): dist_workers.run(_test_mcore_mamba_parameter_sorting) -def _test_mcore_mamba_hybrid_pruning(ckpt_dir, rank, size): +def _test_mcore_mamba_hybrid_pruning(rank, size, ckpt_dir): channel_divisor = 4 num_layers = min(size * 2, 8) @@ -228,140 +231,249 @@ def forward_loop(m): def test_mcore_mamba_hybrid_pruning(dist_workers, tmp_path): - dist_workers.run(partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "minitron_scores")) - - 
-def _test_mcore_mamba_hybrid_pruning_nas(ckpt_dir, rank, size): - set_seed(SEED) - channel_divisor = 4 - - num_layers = 4 # Atleast one of "M, *, -, E" blocks - hybrid_pattern = "ME*-" - hidden_size = 16 - ffn_hidden_size = 32 - num_attention_heads = 16 - num_query_groups = 4 - mamba_state_dim = 4 - mamba_num_heads = 8 - mamba_head_dim = 16 - mamba_num_groups = 2 - num_moe_experts = 8 - moe_ffn_hidden_size = 16 - moe_shared_expert_intermediate_size = 16 - vocab_size = 32 - batch_size = 2 - - model = get_mcore_mamba_hybrid_model( + dist_workers.run(_test_mcore_mamba_hybrid_pruning, tmp_path / "minitron_scores") + + +# Shared parameters for the "ME*-" hybrid NAS tests +_NAS_CHANNEL_DIVISOR = 4 +_NAS_BATCH_SIZE = 2 +_NAS_MODEL_KWARGS = { + "num_layers": 4, + "hybrid_override_pattern": "ME*-", + "hidden_size": 16, + "ffn_hidden_size": 32, + "num_attention_heads": 16, + "num_query_groups": 4, + "mamba_state_dim": 4, + "mamba_num_heads": 8, + "mamba_head_dim": 16, + "mamba_num_groups": 2, + "moe_ffn_hidden_size": 16, + "moe_shared_expert_intermediate_size": 16, + "num_moe_experts": 8, + "vocab_size": 32, +} + + +def _make_nas_hybrid_model(size): + return get_mcore_mamba_hybrid_model( tensor_model_parallel_size=1, pipeline_model_parallel_size=size, initialize_megatron=True, - num_layers=num_layers, - hybrid_override_pattern=hybrid_pattern, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_query_groups=num_query_groups, - ffn_hidden_size=ffn_hidden_size, - mamba_state_dim=mamba_state_dim, - mamba_num_heads=mamba_num_heads, - mamba_head_dim=mamba_head_dim, - mamba_num_groups=mamba_num_groups, - moe_ffn_hidden_size=moe_ffn_hidden_size, - moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size, - num_moe_experts=num_moe_experts, - vocab_size=vocab_size, transformer_impl="transformer_engine", bf16=False, + **_NAS_MODEL_KWARGS, ).cuda() - param_count = get_mcore_param_count(model) - assert param_count == 14984.0, param_count - def forward_loop(m): - for _ in range(2): - run_mcore_inference_with_dummy_input(m, batch_size, hidden_size) +def _nas_forward_loop(m): + for _ in range(2): + run_mcore_inference_with_dummy_input(m, _NAS_BATCH_SIZE, _NAS_MODEL_KWARGS["hidden_size"]) - def score_func(m): - c = m.config - return ( - c.num_layers - + c.hidden_size - + c.ffn_hidden_size - + c.mamba_num_heads - + c.mamba_head_dim - + c.num_attention_heads - + c.num_moe_experts - + c.moe_ffn_hidden_size - + c.moe_shared_expert_intermediate_size - ) - constraints = {"params": int(param_count * 0.7)} - config = { - "forward_loop": forward_loop, +def _nas_score_func(m): + c = m.config + return ( + c.num_layers + + c.hidden_size + + c.ffn_hidden_size + + c.mamba_num_heads + + c.mamba_head_dim + + c.num_attention_heads + + c.num_moe_experts + + c.moe_ffn_hidden_size + + c.moe_shared_expert_intermediate_size + ) + + +def _base_nas_config(ckpt_dir): + return { + "forward_loop": _nas_forward_loop, "checkpoint": ckpt_dir, - "score_func": score_func, + "score_func": _nas_score_func, "max_width_pruning": 0.5, "max_depth_pruning": 0.5, "hparams_to_skip": ["num_attention_heads", "moe_shared_expert_intermediate_size"], "top_k": 10, } + +def _get_hybrid_layer_pattern(model): + key = ( + "hybrid_override_pattern" + if hasattr(model, "hybrid_override_pattern") + else "hybrid_layer_pattern" + ) + return getattr(model, key) + + +def _get_sorted_layers(searcher_state): + return [ + layer + for layer, _ in sorted( + searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True + ) + ] + 
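# --------------------------------------------------------------------------------------
# Illustrative aside (editorial sketch, not part of the patch): the depth-pruning
# behaviour these helpers exercise. The searcher keeps the top-N layers from the
# importance-sorted list and drops the matching characters of the hybrid pattern, e.g.
# keeping layers [1, 4, 3] of "ME*-" drops layer 2 ('E') and yields "M*-". The helper
# name below is hypothetical, and it ignores the '|' / '/' separators that the real
# code strips via parse_main_layer_chars.
def _filter_hybrid_pattern_sketch(pattern: str, sorted_layers: list[int], keep: int) -> str:
    kept = set(sorted_layers[:keep])  # layer numbers are 1-indexed, ordered best-first
    return "".join(c for i, c in enumerate(pattern) if (i + 1) in kept)


assert _filter_hybrid_pattern_sketch("ME*-", [1, 4, 3, 2], keep=3) == "M*-"
# --------------------------------------------------------------------------------------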
+ +def _assert_top_k_candidates(searcher_state, constraint_key, expected_top_k, k=10): + top_k = searcher_state["all_candidates_per_constraint"][constraint_key][:k] + assert len(top_k) == k + for actual, (ss_config, metrics, score) in zip(top_k, expected_top_k): + assert actual.ss_config == ss_config, (actual.ss_config, ss_config) + assert actual.metrics == metrics, (actual.metrics, metrics) + assert actual.score == score, (actual.score, score) + + +def _test_mcore_mamba_hybrid_pruning_nas_params(rank, size, ckpt_dir): + set_seed(SEED) + model = _make_nas_hybrid_model(size) + + baseline_params, baseline_active = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=_get_hybrid_layer_pattern(model), + ) + assert baseline_params == 14984, baseline_params + constraints = { + "params": int(baseline_params * 0.5), + "active_params": int(baseline_active * 0.7), + } + # Capture stdout to assert search space output stdout_capture = io.StringIO() with contextlib.redirect_stdout(stdout_capture): - model, searcher_state = prune_minitron(model, constraints, config, channel_divisor) + model, searcher_state = prune_minitron( + model, constraints, _base_nas_config(ckpt_dir), _NAS_CHANNEL_DIVISOR + ) - # Assert expected search space output is present + # Assert expected search space output is present (rich table format, strip ANSI codes first) captured_output = stdout_capture.getvalue() print(captured_output) + clean_output = re.sub(r"\x1b\[[0-9;]*[mGKH]", "", captured_output) if rank == 0: - assert "Search space for num_layers: [3, 4]" in captured_output - assert "Search space for hidden_size: [12, 16]" in captured_output - assert "Search space for mamba_num_heads: [6, 8]" in captured_output - assert "Search space for mamba_head_dim: [12, 16]" in captured_output - assert "Search space for num_moe_experts: [5, 6, 7, 8]" in captured_output - assert "Search space for moe_ffn_hidden_size: [12, 16]" in captured_output - assert "Total search space in consideration: 512" in captured_output + lines = clean_output.splitlines() + + def assert_row(key: str, value: str) -> None: + assert any(key in line and value in line for line in lines), ( + f"Expected row with {key!r} and {value!r} not found in search space table" + ) + + assert_row("num_layers", "[3, 4]") + assert_row("hidden_size", "[12, 16]") + assert_row("mamba_num_heads", "[6, 8]") + assert_row("num_moe_experts", "[5, 6, 7, 8]") + assert_row("moe_ffn_hidden_size", "[12, 16]") + assert_row("Search space size", "512") + + pruned_params, pruned_active_params = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=_get_hybrid_layer_pattern(model), + ) + assert pruned_params == 7154, pruned_params + assert pruned_active_params == 7154, pruned_active_params # NOTE: Slight variation in layer ordering for MoE / Attention / MLP depending on PP configuration # This affects param counts when num_layers is pruned - sorted_layers = [ - layer - for layer, _ in sorted( - searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True - ) - ] + sorted_layers = _get_sorted_layers(searcher_state) # fmt: off if sorted_layers == [1, 4, 3, 2]: # PP 1/2 + # Winner is 3-layer: keeps layers [1,4,3] from "ME*-" → drops 'E' (layer 2) → "M*-" + assert _get_hybrid_layer_pattern(model) == "M*-", _get_hybrid_layer_pattern(model) expected_top_k = [ - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, 
"num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10482.0, 112.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 24}, 10472.0, 118.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 8, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, 10400.0, 112.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, 10388.0, 123.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, 10376.0, 114.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10370.0, 117.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10338.0, 123.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, 10292.0, 119.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, 10268.0, 125.0], # noqa: E501 - [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, 10242.0, 113.0], # noqa: E501 + # 4 four-layer models qualifying under params_thresh=7492 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, {"params": 7418, "active_params": 6266}, 104], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7406, "active_params": 6542}, 115], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, {"params": 7310, "active_params": 6446}, 111], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, {"params": 7214, "active_params": 6350}, 107], # noqa: E501 + # 6 depth-pruned (num_layers=3) models; params==active_params since MoE layer is dropped + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 118], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 122], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 119], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, 
"num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 123], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 120], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 124], # noqa: E501 ] else: raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers=}") # fmt: on - assert get_mcore_param_count(model) == 10268.0 - - top_k = searcher_state["all_candidates_per_constraint"][constraints["params"]][:10] - assert len(top_k) == 10 - for actual, (ss_config, params, score) in zip(top_k, expected_top_k): - assert actual.ss_config == ss_config, (actual.ss_config, ss_config) - assert actual.params == params, (actual.params, params) - assert actual.score == score, (actual.score, score) + _assert_top_k_candidates( + searcher_state, + (("params", constraints["params"]), ("active_params", constraints["active_params"])), + expected_top_k, + ) + run_mcore_inference_with_dummy_input(model, _NAS_BATCH_SIZE, model.config.hidden_size) @pytest.mark.skipif( torch.cuda.device_count() > 2, reason="Assertions not configured for more than 2 GPUs" ) -def test_mcore_mamba_hybrid_pruning_nas(dist_workers, tmp_path): - dist_workers.run( - partial(_test_mcore_mamba_hybrid_pruning_nas, tmp_path / "minitron_scores"), +def test_mcore_mamba_hybrid_pruning_nas_params(dist_workers, tmp_path): + dist_workers.run(_test_mcore_mamba_hybrid_pruning_nas_params, tmp_path / "minitron_scores") + + +def _test_mcore_mamba_hybrid_pruning_nas_memory_mb(rank, size, ckpt_dir): + set_seed(SEED) + dtype_bytes = 2 + sequence_length = 128 + model = _make_nas_hybrid_model(size) + + _, _, _, baseline_memory_mb = mcore_memory_footprint_mb( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=_get_hybrid_layer_pattern(model), + dtype_bytes=dtype_bytes, + sequence_length=sequence_length, + batch_size=1, + ) + memory_threshold = baseline_memory_mb * 0.7 + + constraints = {"memory_mb": memory_threshold} + config = { + **_base_nas_config(ckpt_dir), + "seq_length": sequence_length, + "batch_size": 1, + } + model, searcher_state = prune_minitron(model, constraints, config, _NAS_CHANNEL_DIVISOR) + + pruned_params, _ = mcore_param_count( + model.config, + model.vocab_size, + model.share_embeddings_and_output_weights, + hybrid_layer_pattern=_get_hybrid_layer_pattern(model), ) + assert pruned_params == 10082, pruned_params + + sorted_layers = _get_sorted_layers(searcher_state) + # fmt: off + if sorted_layers == [1, 4, 3, 2]: + expected_top_k = [ + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, {"memory_mb": 0.0226287841796875}, 114], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 8, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"memory_mb": 0.022613525390625}, 124], # noqa: E501 + [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 16, "num_moe_experts": 8, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"memory_mb": 0.022556304931640625}, 126], # noqa: E501 + [{"num_layers": 4, "hidden_size": 16, 
"mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, {"memory_mb": 0.022541046142578125}, 113], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 112], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 116], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 113], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 117], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 114], # noqa: E501 + [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 8, "mamba_head_dim": 16, "num_moe_experts": 7, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 20}, {"memory_mb": 0.0225067138671875}, 118], # noqa: E501 + ] + else: + raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers=}") + # fmt: on + + _assert_top_k_candidates(searcher_state, (("memory_mb", memory_threshold),), expected_top_k) + run_mcore_inference_with_dummy_input(model, _NAS_BATCH_SIZE, model.config.hidden_size) + + +@pytest.mark.skipif( + torch.cuda.device_count() > 2, reason="Assertions not configured for more than 2 GPUs" +) +def test_mcore_mamba_hybrid_pruning_nas_memory_mb(dist_workers, tmp_path): + dist_workers.run(_test_mcore_mamba_hybrid_pruning_nas_memory_mb, tmp_path / "minitron_scores") From f905b6ee16446dd58f06ddcd5a1adeac1d829349 Mon Sep 17 00:00:00 2001 From: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> Date: Fri, 1 May 2026 12:20:47 -0700 Subject: [PATCH 2/2] Fixes based on Nemotron3 tests Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com> --- .github/workflows/example_tests.yml | 2 +- examples/megatron_bridge/README.md | 10 ++++- examples/megatron_bridge/prune_minitron.py | 13 ++++--- examples/megatron_bridge/requirements.txt | 2 + examples/pruning/README.md | 2 +- .../torch/prune/plugins/mcore_minitron.py | 11 ++---- .../megatron_bridge/test_prune_minitron.py | 2 +- .../test_mcore_mamba_minitron_pruning.py | 39 +++++++++++-------- 8 files changed, 47 insertions(+), 34 deletions(-) create mode 100644 examples/megatron_bridge/requirements.txt diff --git a/.github/workflows/example_tests.yml b/.github/workflows/example_tests.yml index 2e6bfa690eb..4bc892e35c3 100644 --- a/.github/workflows/example_tests.yml +++ b/.github/workflows/example_tests.yml @@ -86,7 +86,7 @@ jobs: uses: ./.github/workflows/_example_tests_runner.yml secrets: inherit with: - docker_image: "nvcr.io/nvidia/nemo:26.02" + docker_image: "nvcr.io/nvidia/nemo:26.04" example: megatron_bridge timeout_minutes: 30 pip_install_extras: "[hf,puzzletron,dev-test]" diff --git a/examples/megatron_bridge/README.md b/examples/megatron_bridge/README.md index ea7d2922810..9ad13424327 100644 --- a/examples/megatron_bridge/README.md +++ b/examples/megatron_bridge/README.md @@ 
-16,7 +16,7 @@ This directory contains examples of using Model Optimizer with [NeMo Megatron-Br ## Pre-Requisites -Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. +Running these examples requires many additional dependencies to be installed (e.g., Megatron-Bridge, Megatron-core, etc.), hence we strongly recommend directly using the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.04`) which has all the dependencies installed. To get the ModelOpt examples scripts, mount your Model-Optimizer repo to the container as follows: @@ -26,7 +26,7 @@ if [ ! -d "${MODELOPT_DIR}" ]; then git clone https://github.com/NVIDIA/Model-Optimizer.git ${MODELOPT_DIR} fi -export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.02 +export DOCKER_IMAGE=nvcr.io/nvidia/nemo:26.04 docker run \ --gpus all \ --shm-size=16GB \ @@ -49,6 +49,12 @@ hf auth login --token > [!WARNING] > Use `python -m pip` instead of `pip` to avoid conflicts with the system-wide installed packages in the NeMo containers. You may also refer to this [doc](https://github.com/NVIDIA-NeMo/Megatron-Bridge/blob/main/docker/common/README.md#installing-packages-inside-the-container) on how to correctly install packages in the NeMo containers without breaking existing torch installation. +Also install additional dependencies from the [requirements.txt](./requirements.txt) file. + +```bash +python -m pip install -r requirements.txt +``` + ## Pruning This section shows how to prune a HuggingFace model using Minitron algorithm in Megatron-Bridge framework. Checkout other available pruning algorithms, supported frameworks and models, and general pruning getting-started in the [pruning README](../pruning/README.md). diff --git a/examples/megatron_bridge/prune_minitron.py b/examples/megatron_bridge/prune_minitron.py index 1eff609fee6..5dfbc1b7e38 100644 --- a/examples/megatron_bridge/prune_minitron.py +++ b/examples/megatron_bridge/prune_minitron.py @@ -161,11 +161,11 @@ def get_args() -> argparse.Namespace: parser.add_argument( "--prune_score_func", type=str, - default="mmlu_10pct", + default="mmlu_10pct_bs1", help=( "Score function to use for NAS-based pruning. Only supports MMLU at the moment. " - "Format: mmlu_pct where is the percentage of MMLU data to sample per subject " - "(e.g. mmlu_10pct for 10%, mmlu_100pct for full eval)." + "Format: mmlu_pct_ where is the percentage of MMLU data to sample per subject and is " + "batch size for fast evaluation (default is mmlu_10pct_bs1)." ), ) parser.add_argument( @@ -343,16 +343,17 @@ def main(args: argparse.Namespace): "You can change this to be any other metric you want to maximize (e.g. negative validation loss)." ) - match = re.fullmatch(r"mmlu_(\d+)pct", args.prune_score_func) + match = re.fullmatch(r"mmlu_(\d+)pct_bs(\d+)", args.prune_score_func) if not match: raise ValueError( - f"Invalid score function: {args.prune_score_func}. Expected format: mmlu_pct (e.g. mmlu_10pct)" + f"Invalid score function: {args.prune_score_func}. 
Expected format: mmlu_pct_bs" ) mmlu_frac = float(match.group(1)) / 100.0 + batch_size = int(match.group(2)) def score_func(m): return megatron_mmlu( - m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=args.calib_mbs + m, tokenizer, few_shots=0, fraction=mmlu_frac, batch_size=batch_size ) pruning_config["score_func"] = score_func diff --git a/examples/megatron_bridge/requirements.txt b/examples/megatron_bridge/requirements.txt new file mode 100644 index 00000000000..ec38c2f7ee7 --- /dev/null +++ b/examples/megatron_bridge/requirements.txt @@ -0,0 +1,2 @@ +# Saving some pruned models (e.g. Nemotron-3-Nano-30B-A3B-BF16) have issues with transformers>=5.0 +transformers<5.0 diff --git a/examples/pruning/README.md b/examples/pruning/README.md index 895d3b8f182..3f0e4c3e33b 100644 --- a/examples/pruning/README.md +++ b/examples/pruning/README.md @@ -27,7 +27,7 @@ This section focuses on applying Model Optimizer's state-of-the-art complementar ## Pre-Requisites -For Minitron pruning for Megatron-Bridge / Megatron-LM models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.02`) which has all the dependencies installed. +For Minitron pruning for Megatron-Bridge / Megatron-LM models, use the NeMo container (e.g., `nvcr.io/nvidia/nemo:26.04`) which has all the dependencies installed. For FastNAS pruning for PyTorch Computer Vision models, no additional dependencies are required. diff --git a/modelopt/torch/prune/plugins/mcore_minitron.py b/modelopt/torch/prune/plugins/mcore_minitron.py index 0aee519f137..5dd84d1f3ad 100644 --- a/modelopt/torch/prune/plugins/mcore_minitron.py +++ b/modelopt/torch/prune/plugins/mcore_minitron.py @@ -190,7 +190,7 @@ def _rprint(*renderables: Any) -> None: # Constraint keys that trigger the grid-search path in MCoreMinitronSearcher. # Order defines priority: first active key is used as the primary display/sort metric. -_METRIC_CONSTRAINT_PRIORITY = ("params", "active_params", "memory_mb") +_METRIC_CONSTRAINT_PRIORITY = ("active_params", "params", "memory_mb") _METRIC_CONSTRAINTS = frozenset(_METRIC_CONSTRAINT_PRIORITY) @@ -524,15 +524,15 @@ def search_best_arch_by_metrics(self) -> dict: _rprint(table) # 3. Optional Knowledge Distillation (KD) step for all top-k candidates - print_rank_0( - "\nSkipping optional Knowledge Distillation (KD) step for candidates as it is a manual step. " + _rprint( + f"[yellow]\nSkipping optional Knowledge Distillation (KD) step for candidates as it is a manual step. " "As per the original paper (https://arxiv.org/pdf/2407.14679), ideally we need to perform a short " f"Knowledge Distillation on ~2B tokens for all top {top_k} candidates before evaluating the " "`score_func`, which will take a lot longer to prune, require splitting the pruning process into multiple " "stages and a lot more compute for pruning but can lead to better pruned model selection. If you are " f"interested to do this, you can take the top {top_k} candidates' `export_config` from the logs above and " "then export all models separately and perform Knowledge Distillation on each of them before evaluating " - "the `score_func`.\n" + f"the `score_func`.\n[/yellow]" ) # 4. Validate top-k candidates using the score_func and return the best subnet @@ -683,9 +683,6 @@ def _generate_search_space_combos( def _compute_candidate_metrics(self, ss_config: dict, max_num_layers: int) -> dict[str, float]: """Compute all active metric constraint values for a candidate config analytically. 
- Calls ``mcore_param_count`` at most once (covers both ``params`` and ``active_params``) - and ``mcore_memory_footprint_mb`` at most once (for ``memory_mb``). - Replaces the slow ``_prune → _param_num_dynamic → sample(max)`` loop used during search. Handles depth pruning by filtering the hybrid layer pattern to the kept (best) layers. """ model = self.model diff --git a/tests/examples/megatron_bridge/test_prune_minitron.py b/tests/examples/megatron_bridge/test_prune_minitron.py index b63a7720736..8e0dddcc5c1 100644 --- a/tests/examples/megatron_bridge/test_prune_minitron.py +++ b/tests/examples/megatron_bridge/test_prune_minitron.py @@ -38,7 +38,7 @@ def test_prune_minitron(tmp_path: Path, num_gpus): calib_num_samples=16, seq_length=32, prune_target_params=prune_target_params, - prune_score_func="mmlu_1pct", + prune_score_func="mmlu_1pct_bs32", ss_channel_divisor=4, hparams_to_skip="num_attention_heads", top_k=1, diff --git a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py index 0f56a370b94..93ef70ac40e 100644 --- a/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py +++ b/tests/gpu_megatron/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py @@ -321,7 +321,12 @@ def _assert_top_k_candidates(searcher_state, constraint_key, expected_top_k, k=1 assert len(top_k) == k for actual, (ss_config, metrics, score) in zip(top_k, expected_top_k): assert actual.ss_config == ss_config, (actual.ss_config, ss_config) - assert actual.metrics == metrics, (actual.metrics, metrics) + for metric_name, expected_value in metrics.items(): + actual_value = actual.metrics[metric_name] + if isinstance(expected_value, float): + assert actual_value == pytest.approx(expected_value), (actual.metrics, metrics) + else: + assert actual_value == expected_value, (actual.metrics, metrics) assert actual.score == score, (actual.score, score) @@ -338,7 +343,7 @@ def _test_mcore_mamba_hybrid_pruning_nas_params(rank, size, ckpt_dir): assert baseline_params == 14984, baseline_params constraints = { "params": int(baseline_params * 0.5), - "active_params": int(baseline_active * 0.7), + "active_params": int(baseline_active * 0.55), } # Capture stdout to assert search space output @@ -373,8 +378,8 @@ def assert_row(key: str, value: str) -> None: model.share_embeddings_and_output_weights, hybrid_layer_pattern=_get_hybrid_layer_pattern(model), ) - assert pruned_params == 7154, pruned_params - assert pruned_active_params == 7154, pruned_active_params + assert pruned_params == 6536, pruned_params + assert pruned_active_params == 6536, pruned_active_params # NOTE: Slight variation in layer ordering for MoE / Attention / MLP depending on PP configuration # This affects param counts when num_layers is pruned @@ -384,18 +389,20 @@ def assert_row(key: str, value: str) -> None: # Winner is 3-layer: keeps layers [1,4,3] from "ME*-" → drops 'E' (layer 2) → "M*-" assert _get_hybrid_layer_pattern(model) == "M*-", _get_hybrid_layer_pattern(model) expected_top_k = [ - # 4 four-layer models qualifying under params_thresh=7492 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 20}, {"params": 7418, "active_params": 6266}, 104], # noqa: E501 + # position 1: the one qualifying 4-layer model (active=6542 > 3-layer H=12 active), + # demonstrating that active_params-first ranking can elevate 4-layer above 3-layer models 
[{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7406, "active_params": 6542}, 115], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 28}, {"params": 7310, "active_params": 6446}, 111], # noqa: E501 - [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 24}, {"params": 7214, "active_params": 6350}, 107], # noqa: E501 - # 6 depth-pruned (num_layers=3) models; params==active_params since MoE layer is dropped - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 118], # noqa: E501 - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 122], # noqa: E501 - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 119], # noqa: E501 - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 123], # noqa: E501 - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 120], # noqa: E501 - [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 6, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 7154, "active_params": 7154}, 124], # noqa: E501 + # positions 2-9: 3-layer H=12 MNH=8 MHD=12 ffn=32 (active==params=6536, no MoE layer) + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 116], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 5, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 120], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 117], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 6, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 121], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 118], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 7, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 122], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 8, 
"moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 119], # noqa: E501 + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 8, "mamba_head_dim": 12, "num_moe_experts": 8, "moe_ffn_hidden_size": 16, "ffn_hidden_size": 32}, {"params": 6536, "active_params": 6536}, 123], # noqa: E501 + # position 10: first 3-layer H=12 MNH=6 MHD=16 ffn=32 candidate (active=6506) + [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 6, "mamba_head_dim": 16, "num_moe_experts": 5, "moe_ffn_hidden_size": 12, "ffn_hidden_size": 32}, {"params": 6506, "active_params": 6506}, 118], # noqa: E501 ] else: raise RuntimeError(f"FIXME: Non deterministic test, assertions may fail: {sorted_layers=}") @@ -403,7 +410,7 @@ def assert_row(key: str, value: str) -> None: _assert_top_k_candidates( searcher_state, - (("params", constraints["params"]), ("active_params", constraints["active_params"])), + (("active_params", constraints["active_params"]), ("params", constraints["params"])), expected_top_k, ) run_mcore_inference_with_dummy_input(model, _NAS_BATCH_SIZE, model.config.hidden_size)