Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Release Notes

## Unreleased
- Added Gemma 4 (`gemma4_text`) checkpoint recognition and architecture config parsing in the checkpoint export frontend; ONNX export and runtime support are in progress (see [#72](https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72))

## 0.8.0
- Externalized INT4 FFN, INT4 MoE, and LM-head weights to reduce engine build memory usage
- Upgraded plugins to TensorRT Plugin V3 for TensorRT 11 readiness
Expand Down
5 changes: 5 additions & 0 deletions tensorrt_edgellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from .config import ModelConfig, QuantConfig
from .model import AutoModel, register_model
# Register model-type-specific implementations
from .models.gemma4.modeling_gemma4_text import Gemma4CausalLM
from .models.nemotron_h.modeling_nemotron_h import NemotronHCausalLM
from .models.qwen3_5.modeling_qwen3_5_text import Qwen3_5CausalLM
from .models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeCausalLM
Expand All @@ -46,6 +47,10 @@
from .onnx.export import export_onnx

register_model("nemotron_h", NemotronHCausalLM)
# Gemma 4: ``gemma4`` is the multimodal root model_type, ``gemma4_text`` the
# promoted text sub-config; register both so either is dispatched here.
register_model("gemma4", Gemma4CausalLM)
register_model("gemma4_text", Gemma4CausalLM)
register_model("qwen3_5_text", Qwen3_5CausalLM)
register_model("qwen3_5_moe_text", Qwen3_5MoeCausalLM)
register_model("qwen3_5_moe", Qwen3_5MoeCausalLM)
Expand Down
172 changes: 170 additions & 2 deletions tensorrt_edgellm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,85 @@ def conv_dim(self) -> int:
return self.key_dim + self.key_dim + self.value_dim


@dataclass
class Gemma4Config:
    """Architecture parameters specific to Gemma 4 (``gemma4_text``).

    Gemma 4 mixes two kinds of decoder layer: local *sliding-window* attention
    and global *full* attention. The two kinds differ beyond the attention
    mask. Each has its own head dimension (``head_dim`` for local layers,
    ``global_head_dim`` for global layers) and its own RoPE setup (local layers
    use a plain ``rope_theta``; global layers use a proportional RoPE with a
    separate ``rope_theta`` and a ``partial_rotary_factor``). On top of that
    Gemma 4 carries Per-Layer Embeddings (PLE), key/value sharing across the
    trailing decoder layers, and soft-capping of the final logits.

    The flat :class:`ModelConfig` used for Llama-style dense models has no
    place for any of this, so the values are grouped here and exposed through
    :attr:`ModelConfig.gemma4_cfg`. The consuming modeling code lives under
    ``models/gemma4``.
    """

    #: One flag per decoder layer (length == ``num_hidden_layers``): ``True``
    #: marks a global (full) attention layer, ``False`` a local
    #: (sliding-window) attention layer.
    is_global_layer: List[bool]
    #: ``head_dim``: head dimension of the local (sliding) attention layers.
    local_head_dim: int
    #: ``global_head_dim``: head dimension of the global (full) attention
    #: layers.
    global_head_dim: int
    #: Window size applied by the local (sliding) attention layers.
    sliding_window: int
    #: ``num_kv_shared_layers``: how many trailing decoder layers reuse the KV
    #: projections of an earlier layer. 0 disables sharing.
    num_kv_shared_layers: int
    #: ``num_global_key_value_heads``: KV head count of the global layers. The
    #: parser substitutes ``num_key_value_heads`` when the checkpoint does not
    #: set it.
    num_global_key_value_heads: int
    #: ``attention_k_eq_v``: when ``True``, global layers use the key
    #: projection as the value projection as well.
    attention_k_eq_v: bool
    #: Base frequency of the RoPE used by local (sliding) layers.
    sliding_rope_theta: float
    #: Base frequency of the RoPE used by global (full) layers.
    global_rope_theta: float
    #: Share of ``global_head_dim`` that is rotated on global layers
    #: (Proportional RoPE; ``partial_rotary_factor`` of the global rope params).
    global_partial_rotary_factor: float
    #: Name of the RoPE variant on global layers, e.g. ``"proportional"``.
    global_rope_type: str
    #: ``hidden_size_per_layer_input``: hidden width of the Per-Layer
    #: Embedding; 0 means PLE is disabled.
    hidden_size_per_layer_input: int
    #: ``vocab_size_per_layer_input``: vocabulary size of the PLE table.
    vocab_size_per_layer_input: int
    #: ``final_logit_softcapping``: soft-cap applied to the final logits, or
    #: ``None`` when the checkpoint does not soft-cap them.
    final_logit_softcapping: Optional[float]
    #: ``use_double_wide_mlp``: when ``True``, the KV-shared layers use a
    #: double-width MLP intermediate size.
    use_double_wide_mlp: bool
    #: Name of the feed-forward activation, e.g. ``"gelu_pytorch_tanh"``.
    hidden_activation: str

    @property
    def num_global_layers(self) -> int:
        """Number of layers flagged as global (full) attention."""
        return len([flag for flag in self.is_global_layer if flag])

    @property
    def num_local_layers(self) -> int:
        """Number of layers flagged as local (sliding-window) attention."""
        return len([flag for flag in self.is_global_layer if not flag])

    @property
    def uses_per_layer_embeddings(self) -> bool:
        """Whether a Per-Layer Embedding pathway is present (PLE width > 0)."""
        ple_width = self.hidden_size_per_layer_input
        return ple_width > 0

    def first_kv_shared_layer_idx(self, num_hidden_layers: int) -> int:
        """Return the index of the first layer that shares KV.

        Args:
            num_hidden_layers: Total number of decoder layers in the model.

        Returns:
            ``num_hidden_layers - num_kv_shared_layers``; this equals
            ``num_hidden_layers`` when no layer shares KV.
        """
        shared_count = self.num_kv_shared_layers
        return num_hidden_layers - shared_count


@dataclass
class ModelConfig:
"""Flat model hyper-parameter config consumed by module builders."""
Expand Down Expand Up @@ -320,6 +399,11 @@ class ModelConfig:
mamba_cfg: Optional[MambaConfig] = None
# ------------------------------------------ gdn / hybrid config
gdn_cfg: Optional[GdnConfig] = None
# ------------------------------------------ gemma4 (gemma4_text) config
# Populated only for Gemma 4 checkpoints; carries the per-layer-type
# attention, Per-Layer-Embedding, KV-sharing, and soft-capping parameters
# that do not fit the flat dense-model fields above.
gemma4_cfg: Optional[Gemma4Config] = None
# ------------------------------------------ gated attention (Qwen3.5)
attn_output_gate: bool = False
# ------------------------------------------ MTP config
Expand Down Expand Up @@ -383,6 +467,10 @@ def eagle3_target_hidden_size(self) -> int:
def is_hybrid(self) -> bool:
return self.mamba_cfg is not None or self.gdn_cfg is not None

@property
def is_gemma4(self) -> bool:
    """True when a Gemma 4 config (``gemma4_cfg``) is attached to this config."""
    return self.gemma4_cfg is not None

@property
def is_nemotron_h(self) -> bool:
    """True when ``model_type`` starts with ``nemotron_h``.

    The comparison is case-insensitive, and an unset ``model_type`` is
    treated as an empty string.
    """
    return (self.model_type or "").lower().startswith("nemotron_h")
Expand Down Expand Up @@ -447,6 +535,7 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig":
layer_types,
model_dir=model_dir)
gdn_cfg = _parse_gdn_cfg(llm_dict, layer_types)
gemma4_cfg = _parse_gemma4_cfg(llm_dict)
has_qk_norm = _detect_has_qk_norm(model_dir)

# MTP config
Expand Down Expand Up @@ -474,8 +563,11 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig":
num_experts = int(
llm_dict.get("num_experts", llm_dict.get("num_local_experts", 0))
or 0)
num_experts_per_tok = int(llm_dict.get("num_experts_per_tok", 0))
moe_intermediate_size = int(llm_dict.get("moe_intermediate_size", 0))
# ``or 0`` (rather than a get-default) so configs that explicitly set
# these optional keys to ``null`` — e.g. dense Gemma 4 checkpoints with
# ``"moe_intermediate_size": null`` — parse as 0 instead of crashing.
num_experts_per_tok = int(llm_dict.get("num_experts_per_tok") or 0)
moe_intermediate_size = int(llm_dict.get("moe_intermediate_size") or 0)
moe_shared_expert_intermediate_size = int(
llm_dict.get("moe_shared_expert_intermediate_size",
llm_dict.get("shared_expert_intermediate_size", 0))
Expand Down Expand Up @@ -523,6 +615,7 @@ def from_pretrained(cls, model_dir: str) -> "ModelConfig":
quant=quant,
mamba_cfg=mamba_cfg,
gdn_cfg=gdn_cfg,
gemma4_cfg=gemma4_cfg,
attn_output_gate=bool(llm_dict.get("attn_output_gate", False)),
mtp_num_hidden_layers=mtp_num_hidden_layers,
mtp_use_dedicated_embeddings=mtp_use_dedicated_embeddings,
Expand Down Expand Up @@ -815,6 +908,81 @@ def _parse_gdn_cfg(config: dict,
)


# Default RoPE base for Gemma 4 global (full) attention layers when the
# checkpoint's ``rope_parameters`` omits it.
_GEMMA4_DEFAULT_GLOBAL_ROPE_THETA = 1_000_000.0


def _gemma4_non_null(mapping: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``mapping[key]``, or *default* when the key is absent or ``null``.

    ``dict.get(key, default)`` only covers the *absent* case; a checkpoint that
    writes ``"key": null`` would otherwise hand ``None`` to ``int()`` /
    ``float()`` (crash) or to ``str()`` (the literal string ``"None"``).
    """
    found = mapping.get(key)
    return default if found is None else found


def _parse_gemma4_cfg(llm_dict: Dict[str, Any]) -> Optional[Gemma4Config]:
    """Return a :class:`Gemma4Config` for ``gemma4_text`` checkpoints, else None.

    Gemma 4 stores ``rope_parameters`` as a mapping keyed by layer type
    (``"sliding_attention"`` / ``"full_attention"``) rather than a single flat
    block, and ``layer_types`` marks each decoder layer as sliding (local) or
    full (global). Both are parsed here, along with the per-layer-type head
    dimensions, KV-sharing, Per-Layer-Embedding, and soft-capping fields.

    Every optional key is read so that an explicit ``null`` behaves exactly
    like a missing key and falls back to the documented default.

    Args:
        llm_dict: The language-model section of the checkpoint config.

    Returns:
        The parsed :class:`Gemma4Config`, or ``None`` when ``model_type`` does
        not start with ``gemma4`` (case-insensitive), so the parser is inert
        for every other model family.

    Raises:
        KeyError: If a Gemma 4 config lacks ``hidden_size`` or
            ``num_attention_heads``.
    """
    model_type = str(llm_dict.get("model_type", "")).lower()
    if not model_type.startswith("gemma4"):
        return None

    # Anything other than the literal "full_attention" counts as a local layer.
    is_global_layer = [
        str(layer_type).lower() == "full_attention"
        for layer_type in (llm_dict.get("layer_types") or [])
    ]

    # ``rope_parameters`` and each per-layer-type entry must be mappings;
    # anything else (missing, null, wrong type) is treated as "no overrides".
    rope_params = llm_dict.get("rope_parameters")
    if not isinstance(rope_params, dict):
        rope_params = {}
    sliding_rope = rope_params.get("sliding_attention")
    if not isinstance(sliding_rope, dict):
        sliding_rope = {}
    global_rope = rope_params.get("full_attention")
    if not isinstance(global_rope, dict):
        global_rope = {}

    num_key_value_heads = int(
        _gemma4_non_null(llm_dict, "num_key_value_heads", 1))
    # Global layers reuse the regular KV head count unless overridden.
    num_global_key_value_heads = int(
        _gemma4_non_null(llm_dict, "num_global_key_value_heads",
                         num_key_value_heads))

    hidden_size = int(llm_dict["hidden_size"])
    num_attention_heads = int(llm_dict["num_attention_heads"])
    raw_head_dim = llm_dict.get("head_dim")
    if raw_head_dim is None:
        # Derive the head dimension only when the checkpoint does not give it.
        raw_head_dim = hidden_size // num_attention_heads
    local_head_dim = int(raw_head_dim)
    global_head_dim = int(llm_dict.get("global_head_dim") or local_head_dim)

    softcap = llm_dict.get("final_logit_softcapping")
    final_logit_softcapping = float(softcap) if softcap is not None else None

    return Gemma4Config(
        is_global_layer=is_global_layer,
        local_head_dim=local_head_dim,
        global_head_dim=global_head_dim,
        sliding_window=int(llm_dict.get("sliding_window") or 0),
        num_kv_shared_layers=int(llm_dict.get("num_kv_shared_layers") or 0),
        num_global_key_value_heads=num_global_key_value_heads,
        attention_k_eq_v=bool(llm_dict.get("attention_k_eq_v", False)),
        sliding_rope_theta=float(
            _gemma4_non_null(sliding_rope, "rope_theta", _DEFAULT_ROPE_THETA)),
        global_rope_theta=float(
            _gemma4_non_null(global_rope, "rope_theta",
                             _GEMMA4_DEFAULT_GLOBAL_ROPE_THETA)),
        global_partial_rotary_factor=float(
            _gemma4_non_null(global_rope, "partial_rotary_factor", 1.0)),
        global_rope_type=str(
            _gemma4_non_null(global_rope, "rope_type", "default")),
        hidden_size_per_layer_input=int(
            llm_dict.get("hidden_size_per_layer_input") or 0),
        vocab_size_per_layer_input=int(
            llm_dict.get("vocab_size_per_layer_input")
            or llm_dict.get("vocab_size") or 0),
        final_logit_softcapping=final_logit_softcapping,
        use_double_wide_mlp=bool(llm_dict.get("use_double_wide_mlp", False)),
        hidden_activation=str(
            _gemma4_non_null(llm_dict, "hidden_activation",
                             "gelu_pytorch_tanh")),
    )


def _get_partial_rotary_factor(llm_dict: Dict[str, Any]) -> float:
"""Extract partial_rotary_factor from config dict.

Expand Down
1 change: 1 addition & 0 deletions tensorrt_edgellm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

Model variants live in sub-packages named after the model family:
``default/`` - standard decoder transformer + Mamba hybrid
``gemma4/`` - Gemma 4 (gemma4_text) export frontend (config parsing; export WIP)
``nemotron_h/`` - Nemotron-H hybrid (Mamba2 + attention)
``nemotron_omni/`` - Nemotron-Omni: RADIO visual + Parakeet audio (LLM reuses nemotron_h)
``qwen3_vl/`` - Qwen3-VL visual encoder + LLM
Expand Down
19 changes: 19 additions & 0 deletions tensorrt_edgellm/models/gemma4/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gemma 4 (``gemma4_text``) export frontend package."""

from .modeling_gemma4_text import Gemma4CausalLM

__all__ = ["Gemma4CausalLM"]
100 changes: 100 additions & 0 deletions tensorrt_edgellm/models/gemma4/modeling_gemma4_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Gemma 4 (``gemma4_text``) causal LM — checkpoint export frontend.

This is the first phase of Gemma 4 support (issue
https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72). It teaches the frontend
to *recognize* a Gemma 4 text checkpoint and parse its architecture into
:class:`~tensorrt_edgellm.config.Gemma4Config`, but ONNX export and C++ runtime
support are not implemented yet.

Why a dedicated (currently fail-loud) class instead of the dense
:class:`~tensorrt_edgellm.models.default.modeling_default.CausalLM`? Gemma 4
departs from the Llama-style dense template in ways that produce silently wrong
results — not just errors — if exported through the default path:

* Per-layer-type attention: local *sliding* layers use ``head_dim`` (256) while
global *full* layers use ``global_head_dim`` (512).
* Dual RoPE: local layers use a plain RoPE (``theta=10000``); global layers use
a *proportional* RoPE (``theta=1e6``) over only ``partial_rotary_factor`` of
the head.
* Per-Layer Embeddings (PLE): a second embedding pathway feeds a residual signal
into every decoder layer; the C++ runtime would have to compute and thread it
into the engine alongside ``inputs_embeds``.
* KV sharing: the trailing ``num_kv_shared_layers`` layers reuse a previous
layer's keys/values.
* Soft-capped final logits, a GeGLU (``gelu_pytorch_tanh``) MLP with a
double-wide variant on the KV-shared layers, and an attention (query) scaling
of ``1.0`` rather than the usual ``1/sqrt(head_dim)``.

Several of those live in the C++ runtime / CUDA kernels (per-layer head dim,
dual RoPE, KV sharing, attention scaling, PLE input plumbing) and land in
follow-up phases. Registering this class makes a Gemma 4 checkpoint fail with
an explicit, actionable message instead of being mis-built as a Llama-style
model with the wrong head dimensions, RoPE, and normalization.
"""

import torch.nn as nn

from ...config import ModelConfig

__all__ = ["Gemma4CausalLM"]

_TRACKING_ISSUE = "https://github.com/NVIDIA/TensorRT-Edge-LLM/issues/72"


def _detected_summary(config: ModelConfig) -> str:
    """Describe the parsed Gemma 4 architecture on a single line.

    The text is embedded in the "not implemented" error so the user can see
    what the frontend recognized. When ``config.gemma4_cfg`` is ``None`` only
    the model type and the layer count are reported.
    """
    gemma4 = config.gemma4_cfg
    layer_count = config.num_hidden_layers
    if gemma4 is None:
        return f"model_type={config.model_type!r}, {layer_count} layers"

    # One entry per reported property; joined with ", " below.
    segments = [
        f"{layer_count} layers "
        f"({gemma4.num_local_layers} local sliding + {gemma4.num_global_layers} global full)",
        f"head_dim local={gemma4.local_head_dim}/global={gemma4.global_head_dim}",
        f"RoPE theta local={gemma4.sliding_rope_theta:g}/global={gemma4.global_rope_theta:g}",
        f"PLE dim={gemma4.hidden_size_per_layer_input}",
        f"kv_shared_layers={gemma4.num_kv_shared_layers}",
        f"final_logit_softcapping={gemma4.final_logit_softcapping}",
    ]
    return ", ".join(segments)


def _unsupported_message(config: ModelConfig) -> str:
    """Compose the actionable text carried by the Gemma 4 NotImplementedError.

    The message has four lines: the headline, the detected architecture
    summary, the list of features still missing from ONNX export and the C++
    runtime, and a pointer to the tracking issue.
    """
    headline = "Gemma 4 (gemma4_text) export is not implemented yet."
    detected = f"Detected: {_detected_summary(config)}."
    remaining_work = (
        "The frontend can parse this checkpoint, but ONNX export and the C++ "
        "runtime still need: per-layer-type head dimensions, dual/proportional "
        "RoPE, Per-Layer-Embedding (PLE) input plumbing, KV sharing across the "
        "trailing layers, unit query scaling, and final-logit soft-capping.")
    tracking = f"Track progress at {_TRACKING_ISSUE}."
    return "\n".join([headline, detected, remaining_work, tracking])


class Gemma4CausalLM(nn.Module):
    """Placeholder Gemma 4 causal LM that refuses to build until export exists.

    Instantiation always raises :class:`NotImplementedError`. The error text
    summarizes the architecture that was parsed from the checkpoint and lists
    the work that is still outstanding. The module docstring explains why
    Gemma 4 cannot go through the dense export path.

    Args:
        config: Parsed model configuration; expected to carry a populated
            :attr:`~tensorrt_edgellm.config.ModelConfig.gemma4_cfg`.

    Raises:
        NotImplementedError: Always, on construction.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        # Export is not wired up yet: report what was detected and stop here.
        message = _unsupported_message(config)
        raise NotImplementedError(message)
Loading