
Add Gemma 4 assistant (MTP drafter) model class #1276

Open

broomva wants to merge 11 commits into ml-explore:main from broomva:feat/gemma4-assistant

Conversation

broomva commented May 14, 2026

What

Adds the gemma4_assistant model class — the lightweight drafter companion released by Google alongside Gemma 4 for multi-token-prediction (MTP) speculative decoding. Loading mlx-community/gemma-4-*-it-assistant-bf16 via mlx_lm.load(...) now succeeds and the forward pass produces correctly-shaped output.
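For concreteness, a minimal sketch of the forward contract (the call signature, the structure of the per-layer-type K/V entries, the kv-head count, and the sequence length below are illustrative assumptions, not the exact test code):

import mlx.core as mx
from mlx_lm import load

model, tokenizer = load("mlx-community/gemma-4-E4B-it-assistant-bf16")

B, ctx, n_kv_heads = 1, 16, 2  # assumed shapes for the synthetic target K/V
shared_kv_states = {
    # full-attention layers read 512-wide heads, sliding layers 256-wide (see Architecture below)
    "full_attention": (mx.zeros((B, n_kv_heads, ctx, 512)), mx.zeros((B, n_kv_heads, ctx, 512))),
    "sliding_attention": (mx.zeros((B, n_kv_heads, ctx, 256)), mx.zeros((B, n_kv_heads, ctx, 256))),
}

# drafter input: concat(target_embed(last_token), target_last_hidden_state) -> 5120 dims
x = mx.zeros((B, 1, 5120))
last_hidden, logits = model(x, shared_kv_states=shared_kv_states)
print(last_hidden.shape, logits.shape)  # expect (1, 1, 2560) and (1, 1, 262144)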

Architecture

  • Q-only cross-attention. Each of the 4 layers has only q_proj/q_norm/o_proj. K and V come from the target model via a shared_kv_states dict keyed by layer_type (full_attention or sliding_attention); see the sketch after this list.
  • Pre/post projection (5120 → 256 → 2560). Caller feeds concat(target_embed(last_token), target_last_hidden_state); the drafter returns its own last_hidden_state_2560 for the next draft step.
  • Centroid logit head. 2048-cluster top-K=32 gating reduces logit computation ~64× over the 262144-token vocab. Optional via use_ordered_embeddings config flag.
  • Per-layer global_head_dim dispatch. Full-attention layers use global_head_dim (512 in E4B), sliding-attention layers use head_dim (256) — mirrors gemma4_text.Attention.
  • layer_scalar (per-layer 11th tensor) applied after the FFN residual.
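To make the Q-only layer concrete, here is a minimal sketch of the idea (module and argument names are illustrative, not the exact classes in this PR); the head_dim passed in is 512 for the full_attention layer and 256 for the sliding_attention layers, per the dispatch above:

import mlx.core as mx
import mlx.nn as nn

class QOnlyCrossAttention(nn.Module):
    def __init__(self, hidden_size, n_heads, head_dim):
        super().__init__()
        self.n_heads, self.head_dim = n_heads, head_dim
        self.scale = head_dim ** -0.5
        # only query-side weights exist; the K/V projections live in the target model
        self.q_proj = nn.Linear(hidden_size, n_heads * head_dim, bias=False)
        self.q_norm = nn.RMSNorm(head_dim)
        self.o_proj = nn.Linear(n_heads * head_dim, hidden_size, bias=False)

    def __call__(self, x, keys, values, mask=None):
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        q = self.q_norm(q)
        # keys/values arrive pre-projected (and pre-RoPE'd) from the target's
        # shared_kv_states; GQA expansion to the drafter's query-head count
        # happens before (or inside) the attention call
        out = mx.fast.scaled_dot_product_attention(q, keys, values, scale=self.scale, mask=mask)
        return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))

On the centroid head, the ~64x figure is just the cluster ratio: selecting the top 32 of 2048 clusters keeps 32/2048 = 1/64 of the vocabulary, about 4096 of the 262144 tokens, assuming roughly equal-sized clusters.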

Scope

This PR lands ONLY the model class. It loads cleanly via mlx_lm.load(...) and produces correctly-shaped output given synthetic shared_kv_states. Integration into the existing stream_generate speculative-decoding loop (the target must extract per-layer-type K/V and forward its last hidden state to the drafter each step, roughly as sketched below) is intended as a follow-up; happy to take guidance from maintainers on whether to land that as a separate PR or roll it into this one.
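For orientation only, a rough sketch of the draft loop the follow-up would wire up (every name here is hypothetical, not an existing mlx_lm API, and greedy selection is used just to keep the sketch short):

import mlx.core as mx

def draft_step(target, drafter, last_token, hidden, shared_kv_states, n_draft=4):
    # hidden starts as the target's last hidden state; after the first step it is
    # the drafter's own 2560-dim last_hidden, as described in Architecture above
    drafted = []
    for _ in range(n_draft):
        x = mx.concatenate([target.embed(last_token), hidden], axis=-1)  # (B, 1, 5120)
        hidden, logits = drafter(x, shared_kv_states=shared_kv_states)
        last_token = mx.argmax(logits[:, -1:], axis=-1)  # (B, 1)
        drafted.append(last_token)
    return drafted  # the target then verifies these in one batched forward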

Tests

  • test_gemma4_assistant: synthetic forward in float32 + float16, covers all submodules, GQA expansion, centroid scatter, copy.deepcopy, make_cache raise, quant_predicate exclusion list
  • test_gemma4_assistant_no_ordered_embeddings: non-clustered logit path (tied embedding matmul) — verifies the alternate code path
  • test_gemma4_assistant_published_checkpoint_forward_shapes: opt-in via MLX_LM_RUN_NETWORK_TESTS=1, default-skipped in CI. Loads the real 159MB mlx-community/gemma-4-E4B-it-assistant-bf16 and runs one forward pass with synthetic K/V matching the real target's shapes
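A minimal sketch of the opt-in gating pattern used for that network test (illustrative, not the exact test code):

import os
import unittest

@unittest.skipUnless(
    os.environ.get("MLX_LM_RUN_NETWORK_TESTS") == "1",
    "network test; set MLX_LM_RUN_NETWORK_TESTS=1 to run",
)
class TestAssistantCheckpoint(unittest.TestCase):
    def test_published_checkpoint_forward_shapes(self):
        ...  # downloads the 159MB checkpoint and checks (B, 1, 2560) / (B, 1, 262144) output shapes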

All existing test_gemma4_* tests still pass (10 default + 1 opt-in network test).

Verified locally

$ python -c "from mlx_lm import load; m, t = load('mlx-community/gemma-4-E4B-it-assistant-bf16'); print(type(m).__name__, len(m.layers))"
Model 4

$ pytest tests/test_models.py -k gemma4 -q
10 passed, 1 skipped

$ MLX_LM_RUN_NETWORK_TESTS=1 pytest tests/test_models.py::TestModels::test_gemma4_assistant_published_checkpoint_forward_shapes
1 passed

Notes on the design

  • Why make_cache raises NotImplementedError — the drafter owns no KV cache of its own (it cross-attends to the target's shared_kv_states each forward). Returning [] would silently no-op in zip-based cache iteration paths, leading to attention without cached context. Raising loudly is the safer contract.
  • Why _token_ordering (leading underscore) — the centroid lookup is an int32 buffer, not a parameter. Without the underscore prefix it shows up in Module.parameters() and gets corrupted by model.update(tree_map(astype, ...)). The leading underscore excludes it from parameters(). To make load_weights (which is strict and rejects underscored keys) happy, the buffer is installed directly on the submodule inside Model.sanitize and then stripped from the returned weight dict; a sketch of that side-channel follows these notes.
  • No .item() calls in the forward path. MaskedEmbedder's mask_value is constructed on-device; AssistantAttention's RoPE offset is passed as an mx.array. Both are in the MTP per-token hot loop, where any GPU→host sync would defeat the speedup.
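A minimal sketch of the sanitize side-channel for the centroid buffer (the checkpoint key and attribute names are assumptions for illustration):

import mlx.core as mx
import mlx.nn as nn

class MaskedEmbedder(nn.Module):
    pass  # placeholder; in this sketch it just holds the _token_ordering buffer

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.masked_embedding = MaskedEmbedder()

    def sanitize(self, weights):
        # the int32 centroid lookup is a buffer, not a parameter: install it under a
        # leading-underscore name so Module.parameters() skips it, and pop the key
        # so the strict load_weights never sees an underscored entry
        key = "masked_embedding.token_ordering"  # assumed checkpoint key
        if key in weights:
            self.masked_embedding._token_ordering = weights.pop(key).astype(mx.int32)
        return weights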


broomva added 11 commits May 14, 2026 07:29

  1. First step toward Gemma 4 MTP drafter support. ModelArgs parses the real google/gemma-4-*-it-assistant config. Forward pass + submodules land in follow-up commits within the same PR.
  2. 2048-centroid clustering with top-K=32 token selection per draft step. ~64x cheaper than full-vocab softmax over 262144 tokens — the defining cost-saving trick of the MTP drafter.
  3. AssistantAttention has no k_proj/v_proj weights — it reads K/V straight from target's shared_kv_states. GQA expansion and matching RoPE so target's pre-RoPE'd K and assistant's Q stay aligned.
  4. SwiGLU MLP + Gemma-family double-norm (pre/post around each sublayer). layer_scalar is the 11th per-layer tensor in the checkpoint.
  5. Embedding + 4 DecoderLayers + final norm. shared_kv_states dispatch by layer_type so each layer fetches the right (full vs sliding) target K/V.
  6. Pre/post projection wiring + masked_embedding logit head + tuple return. Plus sanitize (token_ordering int32 cast), make_cache ([], the drafter has no cache of its own), and quant_predicate that excludes the centroid Linear from quantization.
  7. Real-weight load against mlx-community/gemma-4-E4B-it-assistant-bf16 revealed that layer 3 (the lone full_attention layer) uses global_head_dim=512, not head_dim=256. Mirror gemma4_text.Attention's head_dim dispatch logic. With this fix, mlx_lm.load() succeeds and the forward pass produces (B, 1, 2560) last_hidden and (B, 1, 262144) logits — matching the spec contract.
  8. Mirrors the existing model_test_runner pattern, adapted for the two-tuple return signature (last_hidden, logits) and dict-typed shared_kv_states. Covers dtype propagation, GQA expansion, centroid scatter, copy.deepcopy compatibility, the make_cache contract, and quant_predicate.
  9. Confirms the non-clustered logit path (tied embedding matmul) works when the assistant config disables masked_embedding.
  10. Loads mlx-community/gemma-4-E4B-it-assistant-bf16 (159MB) and runs one forward pass with synthetic K/V matching the real target's shapes. Gated by MLX_LM_SKIP_NETWORK_TESTS so air-gapped CI skips.
  11. Round 2 fixes from cross-model adversarial review (B/C strata):

     - B1: drop private workspace path from module docstring
     - B2: invert network-test polarity (MLX_LM_RUN_NETWORK_TESTS=1 opt-in, default-skip in CI); use public huggingface_hub.snapshot_download instead of private _download; rename test to describe behavior
     - I1: remove unused List import
     - I2: rename quant_predicate's _module to _ (matches sibling style)
     - I3: pass position_ids[-1] as mx.array offset to RoPE (no .item() sync; MTP hot path was paying a GPU->host roundtrip per forward)
     - I4: build MaskedEmbedder fill on-device (no .item() sync inside the per-draft-step path)
     - ml-explore#6: rename masked_embedding.token_ordering -> _token_ordering so the int32 gather buffer is excluded from Module.parameters(); install via Model.sanitize side-channel so load_weights (which rejects underscored keys) doesn't trip; eliminates the fragile post-tree_map int32 restore in the synthetic Tier 1 test
     - ml-explore#9: make_cache() now raises NotImplementedError instead of returning [] (zip-based cache iteration would silently no-op the previous return)

     77 model tests pass (10 gemma4 tests + 1 opt-in real-weight test).