From 334353972a62e84818c5121e514433b3e7510520 Mon Sep 17 00:00:00 2001 From: ai-hpc Date: Tue, 16 Jun 2026 19:44:02 +0000 Subject: [PATCH] feat: Add Gemma 4 (gemma4_text) checkpoint recognition and config parsing (#72) Gemma 4 (google/gemma-4-E2B-it and siblings) departs from the Llama-style dense template the export frontend assumes: per-layer-type attention head dimensions (head_dim vs global_head_dim), dual/proportional RoPE, Per-Layer Embeddings (PLE), KV sharing across the trailing layers, GeGLU double-wide MLP, unit query scaling, and final-logit soft-capping. This is phase 1 of Gemma 4 support: recognize and parse a gemma4_text checkpoint into the new ModelConfig.gemma4_cfg (Gemma4Config), and register a Gemma4CausalLM that fails with an explicit, actionable NotImplementedError instead of being silently mis-built through the dense CausalLM path. ONNX export and the C++ runtime changes land in follow-up PRs (one concern per PR). Also harden ModelConfig.from_pretrained against optional MoE keys that are explicitly null (dense Gemma 4 sets moe_intermediate_size: null), which previously raised TypeError in int(None). Adds tests/python-unittests/test_gemma4_config.py covering the parser, the ModelConfig integration, the null-MoE regression, and the fail-loud class. Signed-off-by: ai-hpc --- CHANGELOG.md | 3 + tensorrt_edgellm/__init__.py | 5 + tensorrt_edgellm/config.py | 172 +++++++++++++++++- tensorrt_edgellm/models/__init__.py | 1 + tensorrt_edgellm/models/gemma4/__init__.py | 19 ++ .../models/gemma4/modeling_gemma4_text.py | 100 ++++++++++ tests/python-unittests/test_gemma4_config.py | 163 +++++++++++++++++ 7 files changed, 461 insertions(+), 2 deletions(-) create mode 100644 tensorrt_edgellm/models/gemma4/__init__.py create mode 100644 tensorrt_edgellm/models/gemma4/modeling_gemma4_text.py create mode 100644 tests/python-unittests/test_gemma4_config.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b2548e9..a44ef670 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Release Notes +## Unreleased +- Added Gemma 4 (`gemma4_text`) checkpoint recognition and architecture config parsing in the checkpoint export frontend; ONNX export and runtime support are in progress (see [#72](https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72)) + ## 0.8.0 - Externalized INT4 FFN, INT4 MoE, and LM-head weights to reduce engine build memory usage - Upgraded plugins to TensorRT Plugin V3 for TensorRT 11 readiness diff --git a/tensorrt_edgellm/__init__.py b/tensorrt_edgellm/__init__.py index 5816a460..b47b9f53 100644 --- a/tensorrt_edgellm/__init__.py +++ b/tensorrt_edgellm/__init__.py @@ -37,6 +37,7 @@ from .config import ModelConfig, QuantConfig from .model import AutoModel, register_model # Register model-type-specific implementations +from .models.gemma4.modeling_gemma4_text import Gemma4CausalLM from .models.nemotron_h.modeling_nemotron_h import NemotronHCausalLM from .models.qwen3_5.modeling_qwen3_5_text import Qwen3_5CausalLM from .models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeCausalLM @@ -46,6 +47,10 @@ from .onnx.export import export_onnx register_model("nemotron_h", NemotronHCausalLM) +# Gemma 4: ``gemma4`` is the multimodal root model_type, ``gemma4_text`` the +# promoted text sub-config; register both so either is dispatched here. +register_model("gemma4", Gemma4CausalLM) +register_model("gemma4_text", Gemma4CausalLM) register_model("qwen3_5_text", Qwen3_5CausalLM) register_model("qwen3_5_moe_text", Qwen3_5MoeCausalLM) register_model("qwen3_5_moe", Qwen3_5MoeCausalLM) diff --git a/tensorrt_edgellm/config.py b/tensorrt_edgellm/config.py index cfde4e8e..88907ee3 100644 --- a/tensorrt_edgellm/config.py +++ b/tensorrt_edgellm/config.py @@ -271,6 +271,85 @@ def conv_dim(self) -> int: return self.key_dim + self.key_dim + self.value_dim +@dataclass +class Gemma4Config: + """Gemma 4 (``gemma4_text``) architecture parameters. + + Gemma 4 interleaves local *sliding-window* attention with global *full* + attention, and the two layer types differ in more than the attention mask: + they use different head dimensions (``head_dim`` vs ``global_head_dim``) and + different RoPE parameters (a plain ``rope_theta`` for local layers, a + proportional RoPE with its own ``rope_theta`` and ``partial_rotary_factor`` + for global layers). Gemma 4 also adds Per-Layer Embeddings (PLE), key/value + sharing across the trailing decoder layers, and final-logit soft-capping. + + None of these fit the flat :class:`ModelConfig` shared by Llama-style dense + models, so they are collected here and attached as + :attr:`ModelConfig.gemma4_cfg`. See ``models/gemma4`` for the consuming + modeling code. + """ + + #: Per-layer attention type: ``True`` for global (full) attention, ``False`` + #: for local (sliding-window) attention. Length == ``num_hidden_layers``. + is_global_layer: List[bool] + #: Head dimension for local (sliding) attention layers (``head_dim``). + local_head_dim: int + #: Head dimension for global (full) attention layers (``global_head_dim``). + global_head_dim: int + #: Sliding-window size used by the local attention layers. + sliding_window: int + #: Number of trailing decoder layers that reuse a previous layer's KV + #: projections (``num_kv_shared_layers``); 0 means no sharing. + num_kv_shared_layers: int + #: KV heads for global layers (``num_global_key_value_heads``); falls back to + #: ``num_key_value_heads`` when the checkpoint leaves it unset. + num_global_key_value_heads: int + #: When ``True``, global layers reuse the key projection as the value + #: projection (``attention_k_eq_v``). + attention_k_eq_v: bool + #: RoPE base frequency for local (sliding) layers. + sliding_rope_theta: float + #: RoPE base frequency for global (full) layers. + global_rope_theta: float + #: Fraction of ``global_head_dim`` that receives RoPE on global layers + #: (Proportional RoPE; the global rope params' ``partial_rotary_factor``). + global_partial_rotary_factor: float + #: RoPE variant for global layers (e.g. ``"proportional"``). + global_rope_type: str + #: Per-Layer-Embedding hidden width (``hidden_size_per_layer_input``); 0 + #: disables PLE. + hidden_size_per_layer_input: int + #: Vocabulary size of the PLE table (``vocab_size_per_layer_input``). + vocab_size_per_layer_input: int + #: Final-logit soft-capping value (``final_logit_softcapping``); ``None`` when + #: the checkpoint does not soft-cap logits. + final_logit_softcapping: Optional[float] + #: When ``True``, the KV-shared layers use a double-width MLP intermediate + #: size (``use_double_wide_mlp``). + use_double_wide_mlp: bool + #: Feed-forward activation name (e.g. ``"gelu_pytorch_tanh"``). + hidden_activation: str + + def first_kv_shared_layer_idx(self, num_hidden_layers: int) -> int: + """Index of the first KV-sharing layer (== ``num_hidden_layers`` if none).""" + return num_hidden_layers - self.num_kv_shared_layers + + @property + def num_global_layers(self) -> int: + """Count of global (full) attention layers.""" + return sum(1 for is_global in self.is_global_layer if is_global) + + @property + def num_local_layers(self) -> int: + """Count of local (sliding-window) attention layers.""" + return sum(1 for is_global in self.is_global_layer if not is_global) + + @property + def uses_per_layer_embeddings(self) -> bool: + """True when the model carries a Per-Layer Embedding pathway.""" + return self.hidden_size_per_layer_input > 0 + + @dataclass class ModelConfig: """Flat model hyper-parameter config consumed by module builders.""" @@ -320,6 +399,11 @@ class ModelConfig: mamba_cfg: Optional[MambaConfig] = None # ------------------------------------------ gdn / hybrid config gdn_cfg: Optional[GdnConfig] = None + # ------------------------------------------ gemma4 (gemma4_text) config + # Populated only for Gemma 4 checkpoints; carries the per-layer-type + # attention, Per-Layer-Embedding, KV-sharing, and soft-capping parameters + # that do not fit the flat dense-model fields above. + gemma4_cfg: Optional[Gemma4Config] = None # ------------------------------------------ gated attention (Qwen3.5) attn_output_gate: bool = False # ------------------------------------------ MTP config @@ -383,6 +467,10 @@ def eagle3_target_hidden_size(self) -> int: def is_hybrid(self) -> bool: return self.mamba_cfg is not None or self.gdn_cfg is not None + @property + def is_gemma4(self) -> bool: + return self.gemma4_cfg is not None + @property def is_nemotron_h(self) -> bool: return (self.model_type or "").lower().startswith("nemotron_h") @@ -447,6 +535,7 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig": layer_types, model_dir=model_dir) gdn_cfg = _parse_gdn_cfg(llm_dict, layer_types) + gemma4_cfg = _parse_gemma4_cfg(llm_dict) has_qk_norm = _detect_has_qk_norm(model_dir) # MTP config @@ -474,8 +563,11 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig": num_experts = int( llm_dict.get("num_experts", llm_dict.get("num_local_experts", 0)) or 0) - num_experts_per_tok = int(llm_dict.get("num_experts_per_tok", 0)) - moe_intermediate_size = int(llm_dict.get("moe_intermediate_size", 0)) + # ``or 0`` (rather than a get-default) so configs that explicitly set + # these optional keys to ``null`` — e.g. dense Gemma 4 checkpoints with + # ``"moe_intermediate_size": null`` — parse as 0 instead of crashing. + num_experts_per_tok = int(llm_dict.get("num_experts_per_tok") or 0) + moe_intermediate_size = int(llm_dict.get("moe_intermediate_size") or 0) moe_shared_expert_intermediate_size = int( llm_dict.get("moe_shared_expert_intermediate_size", llm_dict.get("shared_expert_intermediate_size", 0)) @@ -523,6 +615,7 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig": quant=quant, mamba_cfg=mamba_cfg, gdn_cfg=gdn_cfg, + gemma4_cfg=gemma4_cfg, attn_output_gate=bool(llm_dict.get("attn_output_gate", False)), mtp_num_hidden_layers=mtp_num_hidden_layers, mtp_use_dedicated_embeddings=mtp_use_dedicated_embeddings, @@ -815,6 +908,81 @@ def _parse_gdn_cfg(config: dict, ) +# Default RoPE base for Gemma 4 global (full) attention layers when the +# checkpoint's ``rope_parameters`` omits it. +_GEMMA4_DEFAULT_GLOBAL_ROPE_THETA = 1_000_000.0 + + +def _parse_gemma4_cfg(llm_dict: Dict[str, Any]) -> Optional[Gemma4Config]: + """Return a :class:`Gemma4Config` for ``gemma4_text`` checkpoints, else None. + + Gemma 4 stores ``rope_parameters`` as a mapping keyed by layer type + (``"sliding_attention"`` / ``"full_attention"``) rather than a single flat + block, and ``layer_types`` marks each decoder layer as sliding (local) or + full (global). Both are parsed here, along with the per-layer-type head + dimensions, KV-sharing, Per-Layer-Embedding, and soft-capping fields. + + Returns ``None`` for non-Gemma-4 checkpoints so the parser is inert for + every other model family. + """ + model_type = str(llm_dict.get("model_type", "")).lower() + if not model_type.startswith("gemma4"): + return None + + raw_layer_types = llm_dict.get("layer_types") or [] + is_global_layer = [ + str(layer_type).lower() == "full_attention" + for layer_type in raw_layer_types + ] + + rope_params = llm_dict.get("rope_parameters") + if not isinstance(rope_params, dict): + rope_params = {} + sliding_rope = rope_params.get("sliding_attention") or {} + global_rope = rope_params.get("full_attention") or {} + + num_key_value_heads = int(llm_dict.get("num_key_value_heads", 1)) + raw_global_kv_heads = llm_dict.get("num_global_key_value_heads") + num_global_key_value_heads = num_key_value_heads + if raw_global_kv_heads is not None: + num_global_key_value_heads = int(raw_global_kv_heads) + + hidden_size = int(llm_dict["hidden_size"]) + num_attention_heads = int(llm_dict["num_attention_heads"]) + local_head_dim = int( + llm_dict.get("head_dim", hidden_size // num_attention_heads)) + global_head_dim = int(llm_dict.get("global_head_dim") or local_head_dim) + + softcap = llm_dict.get("final_logit_softcapping") + final_logit_softcapping = float(softcap) if softcap is not None else None + + return Gemma4Config( + is_global_layer=is_global_layer, + local_head_dim=local_head_dim, + global_head_dim=global_head_dim, + sliding_window=int(llm_dict.get("sliding_window") or 0), + num_kv_shared_layers=int(llm_dict.get("num_kv_shared_layers") or 0), + num_global_key_value_heads=num_global_key_value_heads, + attention_k_eq_v=bool(llm_dict.get("attention_k_eq_v", False)), + sliding_rope_theta=float( + sliding_rope.get("rope_theta", _DEFAULT_ROPE_THETA)), + global_rope_theta=float( + global_rope.get("rope_theta", _GEMMA4_DEFAULT_GLOBAL_ROPE_THETA)), + global_partial_rotary_factor=float( + global_rope.get("partial_rotary_factor", 1.0)), + global_rope_type=str(global_rope.get("rope_type", "default")), + hidden_size_per_layer_input=int( + llm_dict.get("hidden_size_per_layer_input") or 0), + vocab_size_per_layer_input=int( + llm_dict.get("vocab_size_per_layer_input") + or llm_dict.get("vocab_size") or 0), + final_logit_softcapping=final_logit_softcapping, + use_double_wide_mlp=bool(llm_dict.get("use_double_wide_mlp", False)), + hidden_activation=str( + llm_dict.get("hidden_activation", "gelu_pytorch_tanh")), + ) + + def _get_partial_rotary_factor(llm_dict: Dict[str, Any]) -> float: """Extract partial_rotary_factor from config dict. diff --git a/tensorrt_edgellm/models/__init__.py b/tensorrt_edgellm/models/__init__.py index 06275166..eff39aea 100644 --- a/tensorrt_edgellm/models/__init__.py +++ b/tensorrt_edgellm/models/__init__.py @@ -17,6 +17,7 @@ Model variants live in sub-packages named after the model family: ``default/`` - standard decoder transformer + Mamba hybrid + ``gemma4/`` - Gemma 4 (gemma4_text) export frontend (config parsing; export WIP) ``nemotron_h/`` - Nemotron-H hybrid (Mamba2 + attention) ``nemotron_omni/`` - Nemotron-Omni: RADIO visual + Parakeet audio (LLM reuses nemotron_h) ``qwen3_vl/`` - Qwen3-VL visual encoder + LLM diff --git a/tensorrt_edgellm/models/gemma4/__init__.py b/tensorrt_edgellm/models/gemma4/__init__.py new file mode 100644 index 00000000..c3f924f8 --- /dev/null +++ b/tensorrt_edgellm/models/gemma4/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Gemma 4 (``gemma4_text``) export frontend package.""" + +from .modeling_gemma4_text import Gemma4CausalLM + +__all__ = ["Gemma4CausalLM"] diff --git a/tensorrt_edgellm/models/gemma4/modeling_gemma4_text.py b/tensorrt_edgellm/models/gemma4/modeling_gemma4_text.py new file mode 100644 index 00000000..487513b0 --- /dev/null +++ b/tensorrt_edgellm/models/gemma4/modeling_gemma4_text.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Gemma 4 (``gemma4_text``) causal LM — checkpoint export frontend. + +This is the first phase of Gemma 4 support (issue +https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72). It teaches the frontend +to *recognize* a Gemma 4 text checkpoint and parse its architecture into +:class:`~tensorrt_edgellm.config.Gemma4Config`, but ONNX export and C++ runtime +support are not implemented yet. + +Why a dedicated (currently fail-loud) class instead of the dense +:class:`~tensorrt_edgellm.models.default.modeling_default.CausalLM`? Gemma 4 +departs from the Llama-style dense template in ways that produce silently wrong +results — not just errors — if exported through the default path: + +* Per-layer-type attention: local *sliding* layers use ``head_dim`` (256) while + global *full* layers use ``global_head_dim`` (512). +* Dual RoPE: local layers use a plain RoPE (``theta=10000``); global layers use + a *proportional* RoPE (``theta=1e6``) over only ``partial_rotary_factor`` of + the head. +* Per-Layer Embeddings (PLE): a second embedding pathway feeds a residual signal + into every decoder layer; the C++ runtime would have to compute and thread it + into the engine alongside ``inputs_embeds``. +* KV sharing: the trailing ``num_kv_shared_layers`` layers reuse a previous + layer's keys/values. +* Soft-capped final logits, GeGLU (``gelu_pytorch_tanh``) MLP with a double-wide + variant on the KV-shared layers, and ``(query) scaling = 1.0`` rather than the + usual ``1/sqrt(head_dim)``. + +Several of those live in the C++ runtime / CUDA kernels (per-layer head dim, +dual RoPE, KV sharing, attention scaling, PLE input plumbing) and land in +follow-up phases. Registering this class makes a Gemma 4 checkpoint fail with +an explicit, actionable message instead of being mis-built as a Llama-style +model with the wrong head dimensions, RoPE, and normalization. +""" + +import torch.nn as nn + +from ...config import ModelConfig + +__all__ = ["Gemma4CausalLM"] + +_TRACKING_ISSUE = "https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72" + + +def _detected_summary(config: ModelConfig) -> str: + """One-line summary of the parsed Gemma 4 architecture, for the error text.""" + cfg = config.gemma4_cfg + if cfg is None: + return f"model_type={config.model_type!r}, {config.num_hidden_layers} layers" + return ( + f"{config.num_hidden_layers} layers " + f"({cfg.num_local_layers} local sliding + {cfg.num_global_layers} global full), " + f"head_dim local={cfg.local_head_dim}/global={cfg.global_head_dim}, " + f"RoPE theta local={cfg.sliding_rope_theta:g}/global={cfg.global_rope_theta:g}, " + f"PLE dim={cfg.hidden_size_per_layer_input}, " + f"kv_shared_layers={cfg.num_kv_shared_layers}, " + f"final_logit_softcapping={cfg.final_logit_softcapping}") + + +def _unsupported_message(config: ModelConfig) -> str: + """Build the actionable NotImplementedError message for Gemma 4 export.""" + return ( + "Gemma 4 (gemma4_text) export is not implemented yet.\n" + f"Detected: {_detected_summary(config)}.\n" + "The frontend can parse this checkpoint, but ONNX export and the C++ " + "runtime still need: per-layer-type head dimensions, dual/proportional " + "RoPE, Per-Layer-Embedding (PLE) input plumbing, KV sharing across the " + "trailing layers, unit query scaling, and final-logit soft-capping.\n" + f"Track progress at {_TRACKING_ISSUE}.") + + +class Gemma4CausalLM(nn.Module): + """Gemma 4 causal LM placeholder that fails loudly until export lands. + + Constructing the model raises :class:`NotImplementedError` with a summary of + the parsed architecture and the remaining work. See the module docstring + for why Gemma 4 cannot reuse the dense export path. + + Args: + config: Parsed model configuration; expected to carry a populated + :attr:`~tensorrt_edgellm.config.ModelConfig.gemma4_cfg`. + """ + + def __init__(self, config: ModelConfig) -> None: + super().__init__() + raise NotImplementedError(_unsupported_message(config)) diff --git a/tests/python-unittests/test_gemma4_config.py b/tests/python-unittests/test_gemma4_config.py new file mode 100644 index 00000000..87ec98c2 --- /dev/null +++ b/tests/python-unittests/test_gemma4_config.py @@ -0,0 +1,163 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for Gemma 4 (``gemma4_text``) checkpoint config parsing. + +Covers :func:`tensorrt_edgellm.config._parse_gemma4_cfg`, the +:class:`tensorrt_edgellm.config.Gemma4Config` it produces, and the +``ModelConfig.from_pretrained`` integration — including the regression that a +dense Gemma 4 checkpoint (which sets optional MoE fields to ``null``) parses +without crashing. Also checks that the registered :class:`Gemma4CausalLM` +fails with an actionable message until export support lands. +""" + +import json +import os + +import pytest + +from tensorrt_edgellm.config import ModelConfig, _parse_gemma4_cfg + +# The text sub-config of google/gemma-4-E2B-it (the LLM backbone fields the +# export frontend consumes). Kept faithful to the published checkpoint. +_GEMMA4_E2B_TEXT_CONFIG = { + "attention_bias": False, + "attention_k_eq_v": False, + "enable_moe_block": False, + "expert_intermediate_size": None, + "final_logit_softcapping": 30.0, + "global_head_dim": 512, + "head_dim": 256, + "hidden_activation": "gelu_pytorch_tanh", + "hidden_size": 1536, + "hidden_size_per_layer_input": 256, + "intermediate_size": 6144, + "layer_types": (["sliding_attention"] * 4 + ["full_attention"]) * 7, + "max_position_embeddings": 131072, + "model_type": "gemma4_text", + "num_attention_heads": 8, + "num_experts": None, + "num_global_key_value_heads": None, + "num_hidden_layers": 35, + "num_key_value_heads": 1, + "num_kv_shared_layers": 20, + "moe_intermediate_size": None, + "rms_norm_eps": 1e-06, + "rope_parameters": { + "full_attention": { + "partial_rotary_factor": 0.25, + "rope_theta": 1000000.0, + "rope_type": "proportional", + }, + "sliding_attention": { + "rope_theta": 10000.0, + "rope_type": "default", + }, + }, + "sliding_window": 512, + "tie_word_embeddings": True, + "top_k_experts": None, + "use_double_wide_mlp": True, + "vocab_size": 262144, + "vocab_size_per_layer_input": 262144, +} + + +def _write_checkpoint(tmp_path) -> str: + """Write a minimal multimodal Gemma 4 config.json and return its directory.""" + root = { + "architectures": ["Gemma4ForConditionalGeneration"], + "model_type": "gemma4", + "tie_word_embeddings": True, + "text_config": dict(_GEMMA4_E2B_TEXT_CONFIG), + } + with open(os.path.join(tmp_path, "config.json"), "w") as f: + json.dump(root, f) + return str(tmp_path) + + +def test_parse_gemma4_cfg_fields(): + cfg = _parse_gemma4_cfg(_GEMMA4_E2B_TEXT_CONFIG) + assert cfg is not None + assert cfg.local_head_dim == 256 + assert cfg.global_head_dim == 512 + assert cfg.sliding_window == 512 + assert cfg.num_kv_shared_layers == 20 + # num_global_key_value_heads is null -> falls back to num_key_value_heads. + assert cfg.num_global_key_value_heads == 1 + assert cfg.attention_k_eq_v is False + assert cfg.sliding_rope_theta == 10000.0 + assert cfg.global_rope_theta == 1000000.0 + assert cfg.global_partial_rotary_factor == 0.25 + assert cfg.global_rope_type == "proportional" + assert cfg.hidden_size_per_layer_input == 256 + assert cfg.vocab_size_per_layer_input == 262144 + assert cfg.final_logit_softcapping == 30.0 + assert cfg.use_double_wide_mlp is True + assert cfg.hidden_activation == "gelu_pytorch_tanh" + + +def test_parse_gemma4_cfg_layer_pattern(): + cfg = _parse_gemma4_cfg(_GEMMA4_E2B_TEXT_CONFIG) + assert len(cfg.is_global_layer) == 35 + assert cfg.num_global_layers == 7 + assert cfg.num_local_layers == 28 + assert cfg.uses_per_layer_embeddings is True + assert cfg.first_kv_shared_layer_idx(35) == 15 + # Every 5th layer (the trailing layer of each 4-local + 1-global block). + assert [i for i, g in enumerate(cfg.is_global_layer) + if g] == [4, 9, 14, 19, 24, 29, 34] + + +@pytest.mark.parametrize("model_type", ["llama", "qwen3_moe", "gemma2", ""]) +def test_parse_gemma4_cfg_inert_for_non_gemma4(model_type): + assert _parse_gemma4_cfg({"model_type": model_type}) is None + + +def test_model_config_from_pretrained(tmp_path): + config = ModelConfig.from_pretrained(_write_checkpoint(tmp_path)) + assert config.is_gemma4 is True + assert config.gemma4_cfg is not None + assert config.model_type == "gemma4_text" + assert config.num_hidden_layers == 35 + assert config.hidden_size == 1536 + assert config.num_attention_heads == 8 + assert config.num_key_value_heads == 1 + assert config.head_dim == 256 + assert config.intermediate_size == 6144 + assert config.vocab_size == 262144 + assert config.tie_word_embeddings is True + # The flat rope_theta falls back to the local (sliding) theta; the precise + # per-layer-type thetas live on gemma4_cfg. + assert config.rope_theta == 10000.0 + assert config.gemma4_cfg.global_rope_theta == 1000000.0 + + +def test_dense_gemma4_null_moe_fields_do_not_crash(tmp_path): + """Dense Gemma 4 sets optional MoE keys to null; parsing must not raise.""" + config = ModelConfig.from_pretrained(_write_checkpoint(tmp_path)) + assert config.num_experts == 0 + assert config.moe_intermediate_size == 0 + + +def test_registered_gemma4_class_fails_loudly(tmp_path): + """A Gemma 4 checkpoint dispatches to Gemma4CausalLM, which fails loudly.""" + pytest.importorskip("torch") + from tensorrt_edgellm import AutoModel + + with pytest.raises(NotImplementedError) as excinfo: + AutoModel.from_pretrained(_write_checkpoint(tmp_path)) + message = str(excinfo.value) + assert "gemma4_text" in message + assert "issues/72" in message