
Add Gemma4 MoE quantization support#1219

Merged
yueshen2016 merged 9 commits into main from yueshen/gemma-4-moe
May 5, 2026

Conversation

Contributor

@yueshen2016 yueshen2016 commented Apr 9, 2026

Summary

  • Register Gemma4TextExperts with _QuantQwen35MoeExperts plugin to unfuse fused 3D expert tensors into per-expert nn.Linear layers for quantization
  • Add structural is_moe() detection for modules with router + experts attributes (Gemma4 has no dedicated SparseMoeBlock class — the decoder layer directly owns router and experts)
  • Add Gemma4TextDecoderLayer to get_expert_linear_names() returning ["gate_proj", "down_proj", "up_proj"]
  • Add "*.experts.*" pattern to NVFP4_MLP_ONLY_CFG and NVFP4_EXPERTS_ONLY_CFG to match Gemma4's expert path (model.layers.X.experts.*, not nested under mlp)

Context: Gemma4 MoE models (e.g. google/gemma-4-26B-A4B-it) store expert weights as fused 3D nn.Parameter tensors (gate_up_proj, down_proj) instead of nn.ModuleList of nn.Linear. Since ModelOpt's quantizer only discovers nn.Linear modules, it silently skips the expert weights — the bulk of the model remains unquantized.
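
For illustration, a minimal sketch of the unfusing idea (this is not the actual _QuantQwen35MoeExperts plugin, and it assumes the fused tensors are laid out as [num_experts, out_features, in_features]; the real layout may be transposed): each slice of the fused parameter becomes the weight of a per-expert nn.Linear that ModelOpt's quantizer can then discover.

import torch
import torch.nn as nn

def unfuse_experts(gate_up_proj: torch.Tensor, down_proj: torch.Tensor) -> nn.ModuleList:
    """Sketch: slice fused [num_experts, out_features, in_features] tensors into
    per-expert nn.Linear modules that the quantizer can discover and calibrate."""
    experts = nn.ModuleList()
    for e in range(gate_up_proj.shape[0]):
        expert = nn.ModuleDict(
            {
                "gate_up_proj": nn.Linear(gate_up_proj.shape[2], gate_up_proj.shape[1], bias=False),
                "down_proj": nn.Linear(down_proj.shape[2], down_proj.shape[1], bias=False),
            }
        )
        expert["gate_up_proj"].weight = nn.Parameter(gate_up_proj[e])  # slice e of the fused tensor
        expert["down_proj"].weight = nn.Parameter(down_proj[e])
        experts.append(expert)
    return experts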

Companion vLLM PR: vllm-project/vllm#39406 (robust quantized MoE weight loading for Gemma4)

Test plan

  • hf_ptq.py --pyt_ckpt_path google/gemma-4-26B-A4B-it --qformat nvfp4_mlp_only — 35k+ quantizers inserted, 17GB output (vs 49GB BF16)
  • vllm serve <path> --quantization modelopt — loads and serves successfully
  • Text generation: correct ("The capital of France is Paris.")
  • Vision: correct (describes image content accurately)

🤖 Generated with Claude Code

Summary by CodeRabbit

  • New Features

    • Support quantizing models with separate base/full components (handles heads present only on the full model)
    • Enhanced Mixture-of-Experts detection and explicit support for Gemma4 expert layer layouts
    • Extended NVFP4 selective quantization presets and recipes to include expert-layer patterns and enable FP8 for expert modules
  • Bug Fixes

    • Improved loss/logit handling and clearer errors for unsupported quantization methods

@yueshen2016 yueshen2016 requested review from a team as code owners April 9, 2026 10:18
@github-actions
Contributor

github-actions Bot commented Apr 9, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-05 16:10 UTC

@coderabbitai
Contributor

coderabbitai Bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds optional full_model plumbing to PTQ auto-quantization and a Gemma4-specific base-model path that projects hidden states with full_model.lm_head for loss/logit computation; extends MoE detection and Gemma4 expert-linear name mapping; expands NVFP4 selective quantization patterns and a PTQ recipe to include *.experts.*.
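
A minimal sketch of the base-model path described above, assuming the extracted language model returns hidden states (last_hidden_state) and the full model exposes an lm_head; the function name is illustrative, not the actual hf_ptq.py code.

def base_model_logits(base_model, full_model, batch):
    """Sketch: the extracted language model (e.g. Gemma4TextModel) has no lm_head,
    so project its hidden states with the full model's head to get logits."""
    outputs = base_model(**batch)                 # base model returns hidden states, not logits
    hidden_states = outputs.last_hidden_state
    return full_model.lm_head(hidden_states)      # [batch, seq, vocab], used for loss/logits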

Changes

Cohort / File(s) / Summary

  • PTQ two-part model handling (examples/llm_ptq/hf_ptq.py): Adds an optional `full_model` (`torch.nn.Module`) argument to the auto-quantization path so that, for extracted base models, hidden states can be projected with full_model.lm_head for loss/logit computation.
  • MoE detection & Gemma4 expert mapping (modelopt/torch/export/layer_utils.py): Extends is_moe detection to treat modules with a router and an experts nn.Module as MoE; adds Gemma4 handling in get_experts_list and maps Gemma4TextDecoderLayer expert linear names to gate_proj, down_proj, up_proj.
  • NVFP4 selective quantization patterns (modelopt/torch/quantization/config.py): Updates NVFP4_EXPERTS_ONLY_CFG and NVFP4_MLP_ONLY_CFG to include the *.experts.* wildcard in layer_patterns, broadening the matched module names.
  • PTQ recipe: enable expert FP8 quantizers (modelopt_recipes/general/ptq/nvfp4_mlp_only-fp8_kv.yaml): Adds two quantizer entries targeting *.experts.*weight_quantizer and *.experts.*input_quantizer, enabling FP8/dynamic scaling for expert modules while leaving other quantizer patterns unchanged.
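
To make the pattern change concrete, a hedged sketch of how a wildcard match plays out; the pattern list name and the fnmatch-based matching here are illustrative only, while the real configs live in modelopt/torch/quantization/config.py.

import fnmatch

# Hypothetical pattern list; the actual NVFP4_MLP_ONLY_CFG / NVFP4_EXPERTS_ONLY_CFG
# carry more than just layer_patterns.
EXPERT_LAYER_PATTERNS = ["*mlp*", "*block_sparse_moe*", "*.experts.*"]

def matches_expert_pattern(module_name: str) -> bool:
    """Gemma4 experts sit at model.layers.N.experts.*, not under mlp, so without
    the *.experts.* wildcard they fall through and stay in BF16."""
    return any(fnmatch.fnmatch(module_name, pattern) for pattern in EXPERT_LAYER_PATTERNS)

print(matches_expert_pattern("model.layers.0.mlp.up_proj"))        # True via *mlp*
print(matches_expert_pattern("model.layers.0.experts.gate_proj"))  # True via *.experts.*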

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 38.46%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (5 passed)

  • Description Check ✅ Passed: Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The title accurately summarizes the main change: adding support for Gemma4 MoE quantization across multiple files and components.
  • Linked Issues Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: Check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: Comprehensive security analysis of all Python files found no critical anti-patterns, including unsafe torch.load/numpy.load, hardcoded trust_remote_code, eval/exec usage, or # nosec comments.



Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_ptq/hf_ptq.py (1)

1092-1097: ⚠️ Potential issue | 🟠 Major

Pass the configured auto-quantize options through this call.

This hunk only forwards full_model, so --auto_quantize_method, --auto_quantize_score_size, and --auto_quantize_checkpoint are still ignored here. The helper falls back to its defaults instead of the CLI values.

Possible fix
         auto_quantize(
             args,
             language_model,
             calib_dataloader,
+            auto_quantize_method=args.auto_quantize_method,
+            auto_quantize_score_size=args.auto_quantize_score_size,
+            auto_quantize_checkpoint=args.auto_quantize_checkpoint,
             full_model=full_model,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 1092 - 1097, The call to
auto_quantize currently only forwards args, language_model, calib_dataloader and
full_model, so CLI options for auto-quantization are ignored; update the call to
pass the configured auto-quantize options from args (e.g.
args.auto_quantize_method, args.auto_quantize_score_size,
args.auto_quantize_checkpoint — match the actual arg names) into auto_quantize
so the helper receives the CLI values rather than falling back to defaults
(refer to auto_quantize, args, language_model, calib_dataloader, full_model in
the diff).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 352-379: The gradient/kl_div base-model path passes the raw batch
(which includes "labels") into the extracted base model, causing a TypeError
because base models like Gemma4TextModel don't accept labels; add a small helper
(e.g., sanitize_batch or strip_non_inputs) that removes "labels" and any
non-forward kwargs from the batch, then call that helper inside both
forward_step implementations referenced in the is_base_model branch (where
full_model, lm_head, loss_func, forward_step and auto_quantize_method are
defined) so the model receives only valid forward inputs while loss_func still
reads labels from the original batch.

---

Outside diff comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1092-1097: The call to auto_quantize currently only forwards args,
language_model, calib_dataloader and full_model, so CLI options for
auto-quantization are ignored; update the call to pass the configured
auto-quantize options from args (e.g. args.auto_quantize_method,
args.auto_quantize_score_size, args.auto_quantize_checkpoint — match the actual
arg names) into auto_quantize so the helper receives the CLI values rather than
falling back to defaults (refer to auto_quantize, args, language_model,
calib_dataloader, full_model in the diff).
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c05034eb-0929-498f-be1d-c052746d2eba

📥 Commits

Reviewing files that changed from the base of the PR and between 04cd596 and c79ebc0.

📒 Files selected for processing (4)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/layer_utils.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/plugins/huggingface.py

Comment thread examples/llm_ptq/hf_ptq.py Outdated
Comment thread examples/llm_ptq/hf_ptq.py Outdated
Comment thread modelopt/torch/quantization/plugins/huggingface.py Outdated
Edwardf0t1 added a commit that referenced this pull request Apr 15, 2026
PTQ can silently skip MoE expert quantization when config patterns
(*mlp*, *block_sparse_moe*) don't match the model's naming convention
(e.g., Gemma4 uses layers.N.experts.* instead of mlp.experts.*).
This causes deployment failures downstream when vLLM/SGLang tries to
load unquantized experts as quantized.

Add Step 5 validation to detect this:
- Compare exported weight names against scale params and exclude list
- Flag weights with no scales that aren't in exclude_modules
- Reference the deployment "quant/unquant layer confusion" pattern

Also add MoE expert verification to unsupported-models.md debugging tips.

Learned from: Gemma4-26B-A4B NVFP4 PTQ succeeded but experts were
BF16, causing vLLM FusedMoE shape mismatch at deployment time.
Fix tracked in PR #1219.

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
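
A rough sketch of the kind of check the commit describes, assuming a flat mapping of exported tensor names and scale names ending in weight_scale / input_scale; the actual Step 5 validation and its naming conventions may differ.

import fnmatch

def find_unquantized_weights(state_dict: dict, exclude_modules: list[str]) -> list[str]:
    """Flag exported weights that have no matching scale params and are not excluded,
    a sign that the quantization patterns silently skipped them."""
    scaled_modules = {
        name.rsplit(".", 1)[0]          # e.g. "model.layers.0.experts.0.down_proj"
        for name in state_dict
        if name.endswith(("weight_scale", "weight_scale_2", "input_scale"))
    }
    missing = []
    for name in state_dict:
        if not name.endswith(".weight"):
            continue
        module = name[: -len(".weight")]
        if module in scaled_modules:
            continue
        if any(fnmatch.fnmatch(module, pat) for pat in exclude_modules):
            continue
        missing.append(name)            # weight with no scales and no exclusion
    return missing
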
Comment thread modelopt/torch/quantization/plugins/huggingface.py Outdated
Edwardf0t1 added a commit that referenced this pull request Apr 15, 2026
Edwardf0t1 added a commit that referenced this pull request Apr 16, 2026
yueshen2016 and others added 4 commits April 29, 2026 21:07
For VLMs like Gemma4 where the extracted language_model lacks lm_head,
use the full_model's lm_head to compute logits/loss from hidden states.

How to run:
cd /opt/Model-Optimizer/examples/llm_ptq && python hf_ptq.py \
  --pyt_ckpt_path /lustre/fsw/portfolios/coreai/users/yueshen/models/gemma-4-31B-it \
  --qformat nvfp4,fp8 \
  --auto_quantize_bits 6.0 \
  --calib_size 512 \
  --dataset cnn_dailymail \
  --export_path /lustre/fsw/portfolios/coreai/users/yueshen/models/gemma-4-31B-it-autoquant-6.0

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
Add assert for full_model to satisfy mypy union-attr check, and add
blank lines before nested def statements per ruff formatting rules.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
Gemma4 MoE models (e.g. google/gemma-4-26B-A4B-it) store expert weights
as fused 3D nn.Parameter tensors instead of nn.ModuleList of nn.Linear,
causing the quantizer to silently skip expert weights.

- Register Gemma4TextExperts with _QuantQwen35MoeExperts plugin to unfuse
  3D tensors into per-expert nn.Linear layers for quantization
- Add structural is_moe() detection for modules with router + experts
  attributes (Gemma4 has no dedicated SparseMoeBlock class)
- Add Gemma4TextDecoderLayer to get_expert_linear_names()
- Add "*.experts.*" pattern to NVFP4_MLP_ONLY_CFG and
  NVFP4_EXPERTS_ONLY_CFG to match Gemma4's expert path
  (experts are at model.layers.X.experts, not under mlp)

Signed-off-by: Yue Shen <yueshen@nvidia.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
Example usage for Gemma4 MoE quantization:

cd /opt/Model-Optimizer/examples/llm_ptq && python hf_ptq.py \
    --pyt_ckpt_path /models/gemma-4-26B-A4B-it \
    --qformat nvfp4_mlp_only \
    --calib_size 512 \
    --dataset cnn_dailymail \
    --export_path /models/gemma-4-26B-A4B-it-nvfp4

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the yueshen/gemma-4-moe branch from c79ebc0 to 509d256 Compare April 30, 2026 04:11
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_ptq/hf_ptq.py (1)

1069-1076: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Assign the result of auto_quantize() back to language_model.

auto_quantize() returns the updated module, but this call site drops it. If the search path replaces the module instance, the rest of quantize_main() will keep exporting the stale reference, and extracted-VLM flows will leave full_model wired to the pre-quantized submodule. Please mirror the mono_quantize() pattern here and rebind/reattach the returned module.

Suggested fix
-        auto_quantize(
+        language_model = auto_quantize(
             args,
             language_model,
             calib_dataloader,
             auto_quantize_method=args.auto_quantize_method,
             auto_quantize_score_size=args.auto_quantize_score_size,
             auto_quantize_checkpoint=args.auto_quantize_checkpoint,
             full_model=full_model,
         )
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
+            language_model_lineage[-2].language_model = language_model
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_ptq/hf_ptq.py` around lines 1069 - 1076, The call to
auto_quantize(...) in quantize_main drops its return value so the
possibly-replaced module isn't reattached; rebind language_model to the returned
module (like mono_quantize does) so subsequent code and full_model reference the
updated submodule—i.e., assign language_model = auto_quantize(...) and ensure
any places that expect the updated module (e.g., full_model composition/export)
use this new reference.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/layer_utils.py`:
- Around line 318-325: The new structural branch in is_moe() causes
Gemma4TextDecoderLayer to be treated as MoE but get_experts_list() still rejects
Gemma4 types, so either extend get_experts_list() to return experts for Gemma4
layers (e.g., recognize Gemma4TextDecoderLayer and extract module.experts/router
similarly) or restrict is_moe() so it only returns True for structural MoE when
the module type is supported by get_experts_list(); update the logic in is_moe()
and/or get_experts_list() (referencing is_moe() and get_experts_list()) so the
two functions remain consistent and Gemma4 export no longer trips the MoE branch
without a compatible expert-list implementation.

---

Outside diff comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 1069-1076: The call to auto_quantize(...) in quantize_main drops
its return value so the possibly-replaced module isn't reattached; rebind
language_model to the returned module (like mono_quantize does) so subsequent
code and full_model reference the updated submodule—i.e., assign language_model
= auto_quantize(...) and ensure any places that expect the updated module (e.g.,
full_model composition/export) use this new reference.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d8897c78-c1a7-4df0-9a38-5ae33519815f

📥 Commits

Reviewing files that changed from the base of the PR and between c79ebc0 and 509d256.

📒 Files selected for processing (2)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/layer_utils.py

Comment thread modelopt/torch/export/layer_utils.py
- Strip `labels` from batch before passing to base text models that
  don't accept it (e.g. Gemma4TextModel) in auto_quantize forward_step
- Pass CLI auto-quantize options (method, score_size, checkpoint)
  through to auto_quantize() instead of falling back to defaults
- Drop explicit Gemma4TextExperts/Qwen3_5MoeExperts registration;
  handled by register_fused_experts_on_the_fly auto-detection
- Clarify base-model detection comment as VLM-generic, not Gemma4-specific

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
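
A minimal sketch of the labels-stripping step from the first bullet above; the helper name is hypothetical (the review suggested sanitize_batch or strip_non_inputs), and the real change lives inside the forward_step closures in hf_ptq.py.

def strip_non_inputs(batch: dict) -> dict:
    """Sketch: base text models such as Gemma4TextModel do not accept "labels" in
    forward(), so drop it before the call; the loss function still reads labels
    from the original batch."""
    return {k: v for k, v in batch.items() if k != "labels"}

# Inside forward_step (illustrative):
#     outputs = language_model(**strip_non_inputs(batch))
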
@yueshen2016 yueshen2016 force-pushed the yueshen/gemma-4-moe branch from 509d256 to 53de662 Compare April 30, 2026 04:23
Extend get_experts_list() to handle Gemma4 models so that the export
path doesn't hit NotImplementedError when is_moe() structurally detects
Gemma4TextDecoderLayer as MoE.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
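
A hedged sketch of the name mapping this commit and the PR description imply; the real get_expert_linear_names in layer_utils.py also covers other architectures (e.g. Mixtral, NemotronH), and whether it keys off the class name exactly like this is an assumption.

def get_expert_linear_names(module) -> list[str]:
    """Sketch: map an MoE decoder-layer type to the linear names inside each expert."""
    module_type = type(module).__name__.lower()
    if "gemma4" in module_type:
        # Gemma4TextDecoderLayer owns router/experts directly (no SparseMoeBlock);
        # each expert exposes these three linears after unfusing.
        return ["gate_proj", "down_proj", "up_proj"]
    raise NotImplementedError(f"No expert-linear mapping for {type(module).__name__}")
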
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

Small, focused PR (85 lines) adding Gemma4 MoE quantization support. All critical previous review comments have been addressed (labels stripping, args forwarding, is_moe/get_experts_list consistency, generic naming).

Remaining concerns:

  1. No unit tests — The PR includes only integration-level test evidence (manual runs). No automated tests for the new is_moe structural detection, the new get_expert_linear_names branch, or the base-model auto_quantize path. This is a recurring pattern in this repo but worth flagging.

  2. Broad structural is_moe() detection — The new structural check (hasattr(module, "router") and hasattr(module, "experts") and isinstance(module.experts, nn.Module)) is quite broad. Any future module with router + experts attributes would be classified as MoE. This works for Gemma4 today, but consider adding a comment noting this is intentionally broad, or tightening it (e.g., checking the module type name contains "decoder" or "layer").

  3. auto_quantize return value not captured (minor) — The CodeRabbit concern about not capturing auto_quantize()'s return in quantize_main was not addressed, but this follows the same in-place-modification pattern as mono_quantize, so it's likely fine in practice.

Overall the changes look correct and well-scoped. Nudging for human review primarily due to missing unit tests.
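
For reference, the structural fallback discussed in point 2 can be sketched as follows; the name-based keywords come from the unit-test description later in this thread, and the real is_moe() in layer_utils.py may order its checks differently.

import torch.nn as nn

def is_moe(module: nn.Module) -> bool:
    """Sketch: name-based detection first, then the broad structural fallback."""
    name = type(module).__name__.lower()
    if any(key in name for key in ("sparsemoeblock", "moelayer", "arcticmoe")):
        return True
    # Structural fallback: the standard HF MoE convention of router + experts submodules,
    # needed because Gemma4 has no dedicated SparseMoeBlock class.
    return (
        hasattr(module, "router")
        and hasattr(module, "experts")
        and isinstance(module.experts, nn.Module)
    )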

Sync the YAML recipe with NVFP4_MLP_ONLY_CFG which was updated to
include *.experts.* for Gemma4 MoE support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 requested a review from a team as a code owner April 30, 2026 14:11
@yueshen2016
Contributor Author

Thanks for the thorough review!

1. Unit tests — Agreed this deserves coverage. The is_moe structural detection and get_expert_linear_names branch are hard to test without pulling in the full Gemma4 model weights, but I can add a lightweight unit test with a mock module that has router + experts attributes. Will add in a follow-up or in this PR if you prefer.

2. Broad is_moe() detection — The broadness is intentional — the structural check is a fallback after the name-based checks, and the router + experts convention is the standard HuggingFace MoE pattern (used by Gemma4, and likely future models). I'll add a comment clarifying this. That said, it only affects the export path, and get_experts_list() already has its own validation — so even if a false positive slips through is_moe(), export would fail explicitly rather than silently produce wrong results.

3. auto_quantize return value — Correct, auto_quantize modifies language_model in-place (via mtq.quantize which mutates the module). Same pattern as mono_quantize. No action needed.

@codecov

codecov Bot commented Apr 30, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 2 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.10%. Comparing base (a492fa9) to head (01fc243).
⚠️ Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/export/layer_utils.py 71.42% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1219      +/-   ##
==========================================
+ Coverage   66.35%   76.10%   +9.75%     
==========================================
  Files         471      471              
  Lines       50500    51602    +1102     
==========================================
+ Hits        33508    39274    +5766     
+ Misses      16992    12328    -4664     
Flag Coverage Δ
examples 41.57% <44.44%> (+0.71%) ⬆️
gpu 59.74% <55.55%> (+32.66%) ⬆️
regression 14.90% <22.22%> (-0.01%) ⬇️
unit 52.88% <77.77%> (+0.11%) ⬆️


Comment thread examples/llm_ptq/hf_ptq.py Outdated
Unify forward_step definitions across is_base_model branches to reduce
duplication. loss_func remains split as the logic differs fundamentally.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the yueshen/gemma-4-moe branch from dbd31df to ea21ea4 Compare May 1, 2026 02:59
@yueshen2016 yueshen2016 requested a review from cjluo-nv May 1, 2026 03:11
@yueshen2016 yueshen2016 force-pushed the yueshen/gemma-4-moe branch 2 times, most recently from 18333f1 to 5238d63 Compare May 1, 2026 04:46
@yueshen2016 yueshen2016 requested review from h-guo18 and sychen52 May 1, 2026 08:28
Collaborator

@cjluo-nv cjluo-nv left a comment


Bot review — DM the bot to share feedback.

All critical previous review comments have been addressed. Unit tests are now included covering is_moe and get_expert_linear_names. The code correctly extends existing MoE quantization infrastructure for Gemma4. One minor issue: the new test file has copyright year 2024, should be 2025 (or 2026 per the current LICENSE_HEADER).

Comment thread tests/unit/torch/export/test_layer_utils.py Outdated
…names

Test is_moe with name-based detection (sparsemoeblock, moelayer, arcticmoe),
structural detection (router + experts), and negative cases. Test
get_expert_linear_names for Gemma4, Mixtral, and NemotronH module types.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: James Shen <yueshen@nvidia.com>
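
A trimmed-down sketch of what such tests can look like; the mock class names are illustrative and chosen so the name-based and structural branches described above fire, and the actual tests in tests/unit/torch/export/test_layer_utils.py may differ.

import torch.nn as nn
from modelopt.torch.export.layer_utils import get_expert_linear_names, is_moe

class FakeSparseMoeBlock(nn.Module):            # name-based detection
    pass

class FakeRouterExperts(nn.Module):             # structural detection: router + experts
    def __init__(self):
        super().__init__()
        self.router = nn.Linear(8, 4)
        self.experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])

class FakeGemma4TextDecoderLayer(nn.Module):    # drives the Gemma4 expert-name mapping
    pass

def test_is_moe_detection():
    assert is_moe(FakeSparseMoeBlock())
    assert is_moe(FakeRouterExperts())
    assert not is_moe(nn.Linear(8, 8))          # negative case

def test_gemma4_expert_linear_names():
    names = get_expert_linear_names(FakeGemma4TextDecoderLayer())
    assert set(names) == {"gate_proj", "down_proj", "up_proj"}
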
@yueshen2016 yueshen2016 force-pushed the yueshen/gemma-4-moe branch from 5238d63 to 01fc243 Compare May 4, 2026 07:58

metadata:
recipe_type: ptq
description: NVFP4 static weight and dynamic activation for all linear layers (W4A4), FP8 KV cache, max calibration.
Collaborator


nit: I think we want to update this description

Contributor Author


Do you have suggestions on what description should be adopted?

Collaborator


@cjluo-nv @meenchen might be the original authors who can update the description later.

@yueshen2016 yueshen2016 merged commit ef326c8 into main May 5, 2026
47 checks passed
@yueshen2016 yueshen2016 deleted the yueshen/gemma-4-moe branch May 5, 2026 16:09
