Add Gemma 4 assistant (MTP drafter) model class #1276
Open
broomva wants to merge 11 commits into
First step toward Gemma 4 MTP drafter support. ModelArgs parses the real google/gemma-4-*-it-assistant config. Forward pass + submodules land in follow-up commits within the same PR.
2048-centroid clustering with top-K=32 token selection per draft step. ~64x cheaper than full-vocab softmax over 262144 tokens — the defining cost-saving trick of the MTP drafter.
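A minimal sketch of that clustered logit path, assuming a two-stage scheme: score the 2048 centroids, compute real logits only for tokens owned by the top-scoring centroids, and scatter those into a mask-filled full-vocab tensor. The selection rule and the names below (clustered_logits, token_ordering, tied_embedding) are illustrative, not lifted from the actual implementation:

```python
# Hedged sketch of a clustered (masked-embedding) logit head, NOT the exact
# algorithm in this PR: only tokens of the top-K scoring centroids get a real
# logit; every other vocab slot keeps a large negative mask value.
import mlx.core as mx


def clustered_logits(h, centroid_proj, token_ordering, tied_embedding,
                     top_k=32, vocab_size=262144, mask_value=-1e9):
    # h: (hidden,) drafter hidden state for one draft position
    centroid_scores = centroid_proj(h)                  # (2048,)
    # indices of the top_k centroids (argsort for clarity; a real implementation
    # would use a cheaper partial selection)
    top_centroids = mx.argsort(centroid_scores)[-top_k:]
    # vocab ids owned by those centroids; token_ordering: (2048, tokens_per_centroid)
    candidate_ids = token_ordering[top_centroids].reshape(-1)
    # exact logits only for the candidates, against tied embedding rows
    candidate_logits = tied_embedding[candidate_ids] @ h
    # scatter into a full-vocab tensor pre-filled on-device with the mask value
    logits = mx.full((vocab_size,), mask_value, dtype=h.dtype)
    logits[candidate_ids] = candidate_logits
    return logits
```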
AssistantAttention has no k_proj/v_proj weights — it reads K/V straight from target's shared_kv_states. GQA expansion and matching RoPE so target's pre-RoPE'd K and assistant's Q stay aligned.
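Roughly, that attention block looks like the sketch below, assuming the target hands over K/V that already have RoPE applied; head counts, dims, and the RoPE base are placeholders rather than the checkpoint's values:

```python
# Hedged sketch: an attention block with no k_proj/v_proj that reads K/V from
# the target model's shared state. Sizes here are placeholders.
import mlx.core as mx
import mlx.nn as nn


class AssistantAttentionSketch(nn.Module):
    def __init__(self, hidden=2560, n_heads=8, n_kv_heads=2, head_dim=256):
        super().__init__()
        self.n_heads, self.n_kv_heads, self.head_dim = n_heads, n_kv_heads, head_dim
        self.q_proj = nn.Linear(hidden, n_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, hidden, bias=False)
        self.q_norm = nn.RMSNorm(head_dim)
        # base is a placeholder; the real model's RoPE parameters come from config
        self.rope = nn.RoPE(head_dim, traditional=False, base=10000)

    def __call__(self, x, keys, values, offset):
        # x: (B, L, hidden); keys/values: (B, n_kv_heads, S, head_dim) from the target
        B, L, _ = x.shape
        q = self.q_proj(x).reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        q = self.q_norm(q)
        # the target's K is already rotated, so only Q gets RoPE, at the same offset
        # (the PR passes the offset as an mx.array to avoid a host sync)
        q = self.rope(q, offset=offset)
        # GQA: expand the shared K/V heads to match the query head count
        rep = self.n_heads // self.n_kv_heads
        keys = mx.repeat(keys, rep, axis=1)
        values = mx.repeat(values, rep, axis=1)
        out = mx.fast.scaled_dot_product_attention(
            q, keys, values, scale=self.head_dim**-0.5, mask=None
        )
        return self.o_proj(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
```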
SwiGLU MLP + Gemma-family double-norm (pre/post around each sublayer). layer_scalar is the 11th per-layer tensor in the checkpoint.
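A hedged sketch of that sublayer pattern; the exact placement of layer_scalar and the norm names are assumptions based on the description above:

```python
# Hedged sketch of a Gemma-style double-norm decoder layer with a per-layer
# scalar applied after the FFN residual. Shapes and names are illustrative.
import mlx.core as mx
import mlx.nn as nn


class SwiGLUSketch(nn.Module):
    def __init__(self, hidden=2560, intermediate=8192):
        super().__init__()
        self.gate_proj = nn.Linear(hidden, intermediate, bias=False)
        self.up_proj = nn.Linear(hidden, intermediate, bias=False)
        self.down_proj = nn.Linear(intermediate, hidden, bias=False)

    def __call__(self, x):
        # SwiGLU: silu(gate) * up, projected back to the hidden size
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class DecoderLayerSketch(nn.Module):
    def __init__(self, hidden=2560):
        super().__init__()
        self.input_layernorm = nn.RMSNorm(hidden)
        self.post_attention_layernorm = nn.RMSNorm(hidden)
        self.pre_feedforward_layernorm = nn.RMSNorm(hidden)
        self.post_feedforward_layernorm = nn.RMSNorm(hidden)
        self.mlp = SwiGLUSketch(hidden)
        self.layer_scalar = mx.ones((1,))  # the per-layer 11th checkpoint tensor

    def __call__(self, x, attn_fn):
        # Gemma-family double-norm: normalize before AND after each sublayer
        h = x + self.post_attention_layernorm(attn_fn(self.input_layernorm(x)))
        out = h + self.post_feedforward_layernorm(self.mlp(self.pre_feedforward_layernorm(h)))
        # layer_scalar applied after the FFN residual (placement assumed from the
        # wording above)
        return out * self.layer_scalar
```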
Embedding + 4 DecoderLayers + final norm. shared_kv_states dispatch by layer_type so each layer fetches the right (full vs sliding) target K/V.
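The per-layer K/V dispatch reduces to something like this sketch; the 3-sliding + 1-full layer layout and the argument names are assumptions, not the real module structure:

```python
# Hedged sketch of the drafter trunk forward: embedding, per-layer K/V dispatch
# by layer_type, final norm. `layers` is any sequence of callables taking
# (hidden_states, keys, values).
def assistant_trunk_forward(tokens, embed_tokens, layers, layer_types, norm,
                            shared_kv_states):
    # shared_kv_states: {"full_attention": (K, V), "sliding_attention": (K, V)},
    # both produced by the target model on its own forward pass
    h = embed_tokens(tokens)
    for layer, layer_type in zip(layers, layer_types):
        keys, values = shared_kv_states[layer_type]  # fetch the matching target K/V
        h = layer(h, keys, values)
    return norm(h)
```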
Pre/post projection wiring + masked_embedding logit head + tuple return. Plus sanitize (int32 cast for token_ordering), make_cache (returns []; the drafter has no cache of its own), and a quant_predicate that excludes the centroid Linear from quantization.
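A sketch of what such a quantization predicate might look like; the hook signature and the parameter path are assumptions, since the PR only states that the centroid Linear is excluded:

```python
# Hedged sketch of a quantization predicate that skips the centroid projection.
import mlx.nn as nn


def quant_predicate(path: str, module: nn.Module) -> bool:
    # keep the small centroid projection in full precision; quantize the rest of
    # the Linear/Embedding layers ("centroids" is a hypothetical attribute name)
    if "masked_embedding.centroids" in path:
        return False
    return isinstance(module, (nn.Linear, nn.Embedding))
```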
Real-weight load against mlx-community/gemma-4-E4B-it-assistant-bf16 revealed that layer 3 (the lone full_attention layer) uses global_head_dim=512, not head_dim=256. Mirror gemma4_text.Attention's head_dim dispatch logic. With this fix, mlx_lm.load() succeeds and forward pass produces (B, 1, 2560) last_hidden and (B, 1, 262144) logits — matching the spec contract.
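In other words, the per-layer dispatch is roughly:

```python
# Minimal illustration of the head_dim dispatch described above.
def head_dim_for(layer_type: str, head_dim: int = 256, global_head_dim: int = 512) -> int:
    # full_attention layers use the wider global_head_dim; sliding layers use head_dim
    return global_head_dim if layer_type == "full_attention" else head_dim


assert head_dim_for("full_attention") == 512    # layer 3 in E4B
assert head_dim_for("sliding_attention") == 256
```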
Mirrors existing model_test_runner pattern, adapted for the two-tuple return signature (last_hidden, logits) and dict-typed shared_kv_states. Covers dtype propagation, GQA expansion, centroid scatter, copy.deepcopy compatibility, make_cache contract, and quant_predicate.
Confirms the non-clustered logit path (tied embedding matmul) works when the assistant config disables masked_embedding.
Loads mlx-community/gemma-4-E4B-it-assistant-bf16 (159MB) and runs one forward pass with synthetic K/V matching the real target's shapes. Gated by MLX_LM_SKIP_NETWORK_TESTS so air-gapped CI skips it.
Round 2 fixes from cross-model adversarial review (B/C strata):

- B1: drop the private workspace path from the module docstring.
- B2: invert the network-test polarity (MLX_LM_RUN_NETWORK_TESTS=1 opt-in, default-skip in CI); use the public huggingface_hub.snapshot_download instead of the private _download; rename the test to describe behavior.
- I1: remove the unused List import.
- I2: rename quant_predicate's _module to _ (matches sibling style).
- I3: pass position_ids[-1] as an mx.array offset to RoPE (no .item() sync; the MTP hot path was paying a GPU->host roundtrip per forward).
- I4: build the MaskedEmbedder fill on-device (no .item() sync inside the per-draft-step path).
- ml-explore#6: rename masked_embedding.token_ordering -> _token_ordering so the int32 gather buffer is excluded from Module.parameters(); install it via the Model.sanitize side-channel so load_weights (which rejects underscored keys) doesn't trip; eliminates the fragile post-tree_map int32 restore in the synthetic Tier 1 test.
- ml-explore#9: make_cache() now raises NotImplementedError instead of returning [] (zip-based cache iteration would have silently no-op'd the previous return value).

77 model tests pass (10 gemma4 tests + 1 opt-in real-weight test).
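For reference, the opt-in gating described in B2 boils down to something like the following; the test body and class name here are illustrative, not the PR's actual test code:

```python
# Hedged sketch of an opt-in network test gated by an environment variable.
import os
import unittest

from huggingface_hub import snapshot_download


@unittest.skipUnless(
    os.environ.get("MLX_LM_RUN_NETWORK_TESTS") == "1",
    "network tests are opt-in; set MLX_LM_RUN_NETWORK_TESTS=1 to run",
)
class TestPublishedCheckpointDownload(unittest.TestCase):
    def test_checkpoint_is_fetchable(self):
        # downloads the ~159MB checkpoint via the public huggingface_hub API
        path = snapshot_download("mlx-community/gemma-4-E4B-it-assistant-bf16")
        self.assertTrue(os.path.isdir(path))
```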
What
Adds the gemma4_assistant model class — the lightweight drafter companion released by Google alongside Gemma 4 for multi-token-prediction (MTP) speculative decoding. Loading mlx-community/gemma-4-*-it-assistant-bf16 via mlx_lm.load(...) now succeeds and the forward pass produces correctly-shaped output.

Architecture
- Attention keeps only q_proj/q_norm/o_proj. K and V come from the target model via a shared_kv_states dict keyed by layer_type (full_attention or sliding_attention).
- Input is concat(target_embed(last_token), target_last_hidden_state); the drafter returns its own last_hidden_state_2560 for the next draft step (see the sketch after this list).
- Masked-embedding (centroid) logit head gated by the use_ordered_embeddings config flag.
- global_head_dim dispatch: full-attention layers use global_head_dim (512 in E4B), sliding-attention layers use head_dim (256) — mirrors gemma4_text.Attention.
- layer_scalar (per-layer 11th tensor) applied after the FFN residual.
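A sketch of that input wiring, with placeholder names for the projection and shapes:

```python
# Hedged sketch: the draft step consumes the target's embedding of the last
# accepted token concatenated with the target's last hidden state.
import mlx.core as mx
import mlx.nn as nn


def draft_step_input(pre_proj: nn.Linear, target_embed: mx.array,
                     target_last_hidden: mx.array) -> mx.array:
    # target_embed: (B, 1, D_embed); target_last_hidden: (B, 1, D_target)
    x = mx.concatenate([target_embed, target_last_hidden], axis=-1)
    return pre_proj(x)  # project down to the drafter's 2560-dim hidden size
```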
Scope

This PR lands ONLY the model class. It loads cleanly via mlx_lm.load(...) and produces correctly-shaped output given synthetic shared_kv_states. Integration into the existing stream_generate speculative-decoding loop (the target must extract per-layer-type K/V and forward its last hidden state to the drafter each step) is intended as a follow-up; happy to take guidance from maintainers on whether to land that as a separate PR or roll it into this one.

Tests
- test_gemma4_assistant: synthetic forward in float32 + float16; covers all submodules, GQA expansion, centroid scatter, copy.deepcopy, the make_cache raise, and the quant_predicate exclusion list.
- test_gemma4_assistant_no_ordered_embeddings: non-clustered logit path (tied embedding matmul) — verifies the alternate code path.
- test_gemma4_assistant_published_checkpoint_forward_shapes: opt-in via MLX_LM_RUN_NETWORK_TESTS=1, default-skipped in CI. Loads the real 159MB mlx-community/gemma-4-E4B-it-assistant-bf16 and runs one forward pass with synthetic K/V matching the real target's shapes.

All existing test_gemma4_* tests still pass (10 default + 1 opt-in network test).

Verified locally
Notes on the design
- make_cache raises NotImplementedError — the drafter owns no KV cache of its own (it cross-attends to the target's shared_kv_states each forward). Returning [] would silently no-op in zip-based cache iteration paths, leading to attention without cached context. Raising loudly is the safer contract.
- _token_ordering (leading underscore) — the centroid lookup is an int32 buffer, not a parameter. Without the underscore prefix it shows up in Module.parameters() and gets corrupted by model.update(tree_map(astype, ...)). The leading underscore excludes it from parameters(). To keep load_weights (which is strict and rejects underscored keys) happy, the buffer is installed directly on the submodule inside Model.sanitize and then stripped from the returned weight dict.
- No .item() calls in the forward path — MaskedEmbedder's mask_value is constructed on-device; AssistantAttention's RoPE offset is passed as an mx.array. Both are in the MTP per-token hot loop where any GPU→host sync would defeat the speedup.
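A sketch pulling these three notes together; the surrounding class, the weight key, and the helper below are illustrative rather than the PR's actual code:

```python
# Hedged sketch of the design notes above: loud make_cache, sanitize-installed
# int32 buffer, and an on-device mask fill.
import mlx.core as mx
import mlx.nn as nn


class ModelSketch(nn.Module):
    def __init__(self, masked_embedding: nn.Module):
        super().__init__()
        self.masked_embedding = masked_embedding

    def make_cache(self):
        # no drafter-owned KV cache: raise instead of returning [], which
        # zip-based cache iteration would silently treat as "no layers"
        raise NotImplementedError("gemma4_assistant reads the target's shared_kv_states")

    def sanitize(self, weights: dict) -> dict:
        # side-channel install: _token_ordering is an int32 buffer, not a
        # parameter; the leading underscore keeps it out of Module.parameters(),
        # and popping it here keeps strict load_weights from seeing the key
        # (the weight key below is a guess at the checkpoint layout)
        ordering = weights.pop("masked_embedding.token_ordering", None)
        if ordering is not None:
            self.masked_embedding._token_ordering = ordering.astype(mx.int32)
        return weights


def on_device_mask_fill(vocab_size: int, dtype=mx.float16) -> mx.array:
    # the fill is built directly on-device; no .item() round trip in the hot loop
    return mx.full((vocab_size,), mx.finfo(dtype).min, dtype=dtype)
```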
References

transformers/models/gemma4_assistant/modeling_gemma4_assistant.py