
Native MTP speculative decoding (Qwen3.5/3.6 reference implementation)#990

Open
AirRunner wants to merge 29 commits into ml-explore:main from AirRunner:feat/mtp-native

Conversation

@AirRunner

@AirRunner AirRunner commented Mar 13, 2026

Summary

Qwen3.5 checkpoints ship with a built-in Multi-Token Prediction head (mtp_num_hidden_layers: 1 in config) that predicts token t+2 from the backbone hidden state at t and the embedding of token t+1. This PR adds support for using it as a native speculative decoding mechanism: no separate draft model is needed, and the extra compute is minimal (one extra transformer layer).

Changes

  • mlx_lm/generate.py: MTP generation loop with draft/verify and probabilistic acceptance, --mtp CLI flag
  • mlx_lm/models/cache.py: rollback_state slot for conv/SSM snapshot on draft rejection
  • mlx_lm/sample_utils.py: p_draw parameter added to apply_xtc to share the XTC draw across draft and verify
  • mlx_lm/models/qwen3_5.py: MTP head module; self.norm moved to TextModel to expose pre-norm hidden states for MTP; n_confirmed parameter for SSM rollback; sanitize: the norm +1 shift is now triggered only on raw HF checkpoints (unsanitized conv1d), not on the presence of MTP weights; and MoE gate weights at 8-bit group=64
  • mlx_lm/models/qwen3_5_moe.py: MTP checkpoint sanitization for MoE variants + handling both Qwen3.5 and Qwen3.6 (fused gate_up_proj)
  • mlx_lm/server.py: --mtp flag, dynamic MTP/batch switching + fix xtc_special_tokens construction
  • tests/test_mtp.py: 10 unit tests

How it works

Each backbone forward pass returns both logits and pre-norm hidden states. The MTP head fuses pre_fc_norm_hidden(h_t) and pre_fc_norm_embedding(embed(t+1)) via a linear projection, runs one full-attention transformer layer, and produces draft logits through the shared lm_head.

The generation loop verifies drafts by feeding [confirmed_tok, draft_tok] to the backbone with n_confirmed=1. This causes GatedDeltaNet to snapshot its conv/SSM state after the confirmed token. On acceptance, both tokens are emitted. On rejection, the SSM state is rolled back to the snapshot and KV caches are trimmed.
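
A minimal sketch of one draft/verify round, assuming illustrative helper names (mtp_round is hypothetical; the actual logic lives in mtp_generate_step and the signatures may differ):

import mlx.core as mx

def mtp_round(model, confirmed_tok, draft_tok, cache, mtp_cache):
    # Schematic only. Verify pass: the backbone sees [confirmed, draft];
    # n_confirmed=1 makes GatedDeltaNet snapshot its conv/SSM state after
    # the confirmed token.
    y = mx.array([[confirmed_tok, draft_tok]])
    logits, hidden = model(y, cache=cache, n_confirmed=1, return_hidden=True)

    # Greedy verification: the backbone's prediction at the confirmed
    # position is the target for the draft position.
    target = mx.argmax(logits[:, 0, :], axis=-1).item()

    if target == draft_tok:
        # Accepted: emit the draft plus the bonus token from the verify pass,
        # then draft the next token from the hidden state at the draft
        # position fused with the embedding of the bonus token.
        bonus = mx.argmax(logits[:, 1, :], axis=-1).item()
        draft_logits = model.mtp_forward(hidden[:, 1:, :], mx.array([[bonus]]), mtp_cache)
        next_draft = mx.argmax(draft_logits[:, -1, :], axis=-1).item()
        return [draft_tok, bonus], next_draft

    # Rejected: the caller restores the conv/SSM snapshot (rollback_state),
    # trims the rejected position from the KV caches, and emits the verified
    # token in place of the draft.
    return [target], None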

Results (Qwen3.6-27B 4-bit, M4 Pro)

Pooled tok/s, 3 runs × 8 prompts, conditions interleaved. See bench_mtp.py.

Acceptance is reported as A/V (drafts accepted / drafts proposed), the standard speculative decoding metric. The benchmark script also reports A/(V+A) (~46% for Qwen3.6-27B at temp=0.6), which equals A/V ÷ (1 + A/V) and is used in some implementations.
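
For example, the temp=0.6 row in the table below reports A/V = 84.8%; the corresponding A/(V+A) is 0.848 / 1.848 ≈ 45.9%, which matches the ~46% figure above.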

Condition Tok/s Speedup Accept (A/V)
Baseline 15.7 1.00x
MTP temp=0 24.6 1.57x 88.3%
MTP temp=0.6 24.0 1.53x 84.8%
MTP temp=1.0 23.5 1.50x 80.5%

Identity check: greedy MTP output == standard generate_step output.

Usage

mlx_lm.generate --model <path> --mtp
mlx_lm.server   --model <path> --mtp

Checkpoint conversion

This requires a checkpoint converted with MTP weights (the default sanitize() previously stripped them). Re-convert from HF with this branch to preserve mtp.* weights.

Note on M1/M2: M1 and M2 lack native BF16 GPU support (MTLDataType.bfloat requires Apple8+). If you choose not to quantise mtp.fc on M1/M2, you need to add the --dtype float16 flag to the convert command. Without it, MTP may drastically slow down on M1 despite positive acceptance rates.

Questions for reviewers

  1. sampler is None as greedy signal: I use sampler is None to distinguish greedy from stochastic and apply exact-match vs probabilistic acceptance accordingly. Is this the right signal, or would you prefer an explicit greedy: bool parameter for instance?
  2. Dynamic MTP/batch switching: the server now auto-switches based on self.requests.empty(): MTP for solo requests and BatchGenerator for concurrent ones. Is a best-effort queue check the right approach, or is there a preferred pattern in the server architecture?

Addressed in feat/mtp-batched where GenerationBatch supports MTP natively for B > 1

Future work

DRY refactor + SamplerConfig

A follow-up PR independent of MTP would address:

  1. Code duplication across what are now three generator functions:
  • _prefill logic: 3 variants across generate_step, speculative_generate_step, and mtp_generate_step
  • _process_and_sample: almost same pattern in speculative_generate_step and mtp_generate_step
  • quantize_cache_fn = functools.partial(...): same pattern in all three
  2. SamplerConfig: currently mtp_generate_step cannot accept a pre-built sampler= callable and produce correct acceptance logprobs simultaneously. A sampler today returns only a token, but for MTP the acceptance criterion also needs the log-probability distribution the token was drawn from. The fix is a richer sampler interface that returns (token, lp_distribution), allowing both generate_step and mtp_generate_step to share the same interface without passing a dozen individual parameters.
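
To illustrate the idea (make_sampler_with_logprobs is purely hypothetical, not the PR's API), a sampler variant that also returns the distribution it sampled from could look like:

import mlx.core as mx

def make_sampler_with_logprobs(temp: float = 0.0):
    # Hypothetical sketch: return both the sampled token and the log-probability
    # distribution it was actually drawn from, so the MTP acceptance test can
    # reuse that distribution directly.
    def sample(logits: mx.array):
        if temp == 0:
            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            return mx.argmax(logits, axis=-1), logprobs
        scaled = logits * (1 / temp)
        logprobs = scaled - mx.logsumexp(scaled, axis=-1, keepdims=True)
        return mx.random.categorical(scaled), logprobs
    return sample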

Beyond DRY, SamplerConfig unlocks a potential performance gain: sparse residual sampling.
On rejection at temp > 0, the current implementation samples from max(p_target - p_draft, 0) / Z over the full vocabulary (a 151K-token vocabulary for Qwen3.5, ~580 µs/call). With top_k > 0, the sampler already computes a top-k partition over the vocabulary, so exposing those indices lets the rejection path work on a K-token support instead.
Without a SamplerConfig, re-running argpartition specifically for the rejection path is slower or equal to the full-vocab path.
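
As a reference point, a full-vocab residual sampling step could be sketched like this (p_target and p_draft are probability vectors over the vocabulary; the sparse variant would restrict all arrays to the sampler's top-k indices):

import mlx.core as mx

def residual_sample(p_target: mx.array, p_draft: mx.array) -> mx.array:
    # On rejection at temp > 0, sample from the normalized positive residual
    # max(p_target - p_draft, 0) / Z, the standard speculative-sampling
    # correction that preserves the target distribution overall.
    residual = mx.maximum(p_target - p_draft, 0.0)
    z = residual.sum(axis=-1, keepdims=True)
    # If the residual is numerically zero, fall back to the target distribution.
    probs = mx.where(z > 0, residual / z, p_target)
    return mx.random.categorical(mx.log(probs + 1e-12), axis=-1)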

Batched MTP

This PR brings MTP for the solo-request path only.
However, per-sequence selective rollback (restore SSM state + trim KV only for rejected sequences) is already implemented in AirRunner/mlx-lm · feat/mtp-batched; it is left out of this PR to keep the diff reviewable.

Test plan

  • Unit tests (10/10 passing) — module existence, cache creation, shapes, pre-norm hidden states, quant predicate, generation identity, end-to-end
  • Manual validation on Qwen3.5-27B, Qwen3.5-0.8B and Qwen3.5-35B-A3B (all 4-bit)

Relates to #872 — cc @janhilgard


Update - probabilistic acceptance and MoE benchmarks

Integrated probabilistic draft acceptance with two cases:

  1. Greedy (sampler=None): exact-match acceptance, mathematically correct for deterministic argmax sampling
  2. Stochastic (temp > 0): accept with probability min(1, p_target / p_draft), which recovers the greedy acceptance level at any temperature
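
A hedged sketch of the two paths (the helper name accept_draft is illustrative; the log-probabilities are over the same tempered distributions the tokens were drawn from):

import mlx.core as mx

def accept_draft(draft_tok: int, target_logprobs: mx.array,
                 draft_logprobs: mx.array, greedy: bool) -> bool:
    # Greedy path: exact match against the backbone's argmax.
    if greedy:
        return int(mx.argmax(target_logprobs, axis=-1).item()) == draft_tok
    # Stochastic path: accept with probability min(1, p_target / p_draft).
    ratio = float(mx.exp(target_logprobs[..., draft_tok]
                         - draft_logprobs[..., draft_tok]).item())
    return mx.random.uniform().item() < min(1.0, ratio)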

Benchmarks on M4 Pro, with 8 diverse prompts:

A reproducible benchmark script is available: bench_mtp.py

Qwen3.5-27B 4-bit

Tok/s Speedup Acceptance (A/V)
No MTP 15.3 1.00x
MTP, temp=0 24.0 1.57x 85.2%
MTP, temp=0.6, exact match 22.7 1.49x 75.4%
MTP, temp=0.6, probabilistic 22.9 1.51x 85.2%

Qwen3.5-35B-A3B 4-bit

Tok/s Speedup Acceptance (A/V)
No MTP 85.3 1.00x
MTP, temp=0 87.9 1.04x 85.2%
MTP, temp=0.6, exact match 84.5 0.98x 78.6%
MTP, temp=0.6, probabilistic 86.5 1.03x 85.2%

On M4 Pro MoE speedup is marginal regardless of acceptance rate. MTP benefit scales with baseline decode time, so at 85 tok/s (3B active params) the MTP overhead is proportionally too large to yield meaningful speedup. With probabilistic acceptance, acceptance rates are consistent with the dense model (~85%).

Bandwidth model

The cross-hardware speedup variation is explained by speedup = (1+p) / (β+δ), where p is the per-round acceptance probability, β = T_verify_backbone / T_baseline, and δ = T_mtp_head / T_baseline. Full derivation and per-component bandwidth estimates in this comment.
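
For example, plugging in the dense-27B M4 Pro numbers reported later in this thread (p ≈ 0.85, β ≈ 1.12, δ ≈ 0.07) gives speedup ≈ 1.85 / 1.19 ≈ 1.55x, in line with the 1.50–1.57x measured above.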

For reference:

  • @Thump604's MoE results (M2 Ultra, 8-bit, temp=0, exact match): 35B-A3B 1.11x, 122B-A10B 1.09x.
  • @sammcj results (M5 Max, 4-bit, temp=0): 9B +11.3%, 27B +35.5%, 122B +12.4%.
  • @Anionex results (M5 Pro, 4-bit, temp=0): 27B +31.4%, 79.5% acceptance (A/V).

@vlbosch

vlbosch commented Mar 15, 2026

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

@AirRunner
Author

Great work! Would this also be possible for models like GLM5? As in, does each model require its own implementation of MTP, or can we reuse your mtp_generate_step function for other models? Thanks for your work so far!

Thanks!

Yes, mtp_generate_step() is fully reusable, but each model still needs its own model-side interface.

The Qwen3.5-specific parts are MTPDecoderLayer, mtp_forward (producing draft logits), make_mtp_cache, and the backbone's __call__ (with n_confirmed for SSM state rollback on hybrid models).

So the speculative-decoding logic lives in one place, and adding a new model is just a matter of exposing the right interface.
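
Roughly, the model-side surface a new architecture would need to expose looks like the following (a schematic sketch based on the description above, not the literal PR signatures):

import mlx.core as mx

class MTPCapableModel:
    # Schematic interface consumed by mtp_generate_step; concrete signatures
    # in the PR may differ.

    def __call__(self, tokens: mx.array, cache=None, n_confirmed: int = 0,
                 return_hidden: bool = False):
        # Backbone forward pass. With return_hidden=True it also returns the
        # pre-norm hidden states; n_confirmed tells hybrid (SSM) layers where
        # to snapshot state for a possible rollback.
        ...

    def mtp_forward(self, hidden: mx.array, next_tokens: mx.array, mtp_cache):
        # Fuse hidden[t] with embed(next_tokens) and return draft logits for
        # token t+2 through the shared lm_head.
        ...

    def make_mtp_cache(self):
        # KV cache for the MTP head's attention layer.
        ...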

For GLM5 specifically, it would certainly be feasible yeah. But I don't think there is even a glm5.py currently.

@Thump604

Great work on this! We've been using it on M2 Ultra (128GB) with all three Qwen3.5 sizes and it works well.

MoE fix needed

The PR works out of the box for the dense 27B, but MoE models (35B-A3B, 122B-A10B) fail conversion with "768 parameters not in model". The MTP layer's expert weights use unfused per-expert format (mtp.layers.{l}.mlp.experts.{i}.gate_proj.weight) unlike the backbone which uses pre-fused gate_up_proj. The existing sanitize() in qwen3_5_moe.py only handles backbone expert stacking.

Fix (add to qwen3_5_moe.py sanitize(), after the backbone expert stacking loop):

# Stack per-expert MTP weights into switch_mlp format.
mtp_num = getattr(self.language_model.args, "mtp_num_hidden_layers", 0)
num_experts = self.language_model.args.num_experts
for l in range(mtp_num):
    prefix = f"language_model.mtp.layers.{l}.mlp"
    test_key = f"{prefix}.experts.0.gate_proj.weight"
    if test_key in new_weights:
        for n in ["gate_proj", "up_proj", "down_proj"]:
            to_join = [
                new_weights.pop(f"{prefix}.experts.{e}.{n}.weight")
                for e in range(num_experts)
            ]
            new_weights[f"{prefix}.switch_mlp.{n}.weight"] = mx.stack(to_join)

Also needs import mlx.core as mx at the top of the file.

Full fix on our fork: Thump604/mlx-lm@04a4383

Benchmark results (M2 Ultra, greedy)

Model Baseline MTP Speedup
27B-8bit (dense) 20.6 tok/s 27.1 tok/s 1.32x
35B-A3B-8bit (MoE) 74.4 tok/s 82.3 tok/s 1.11x
122B-A10B-5bit (MoE) 43.0 tok/s 46.7 tok/s 1.09x

Pre-converted models with MTP weights: Thump604/Qwen3.5-27B-MLX-8bit, 35B, 122B

@AirRunner
Author

@Thump604 Thanks for the report and the fix! I've integrated it in AirRunner/mlx-lm@8d06796 with a credit.

Also, what acceptance rates did you get with MoE? I'm curious if it's somehow correlated to the speedup.

@Thump604

Thanks for the quick integration!

Here are the acceptance rates derived from our benchmarks (M2 Ultra 128GB, greedy/temp=0):

Model Baseline tok/s MTP tok/s Speedup Implied Accept Rate
27B dense 8-bit 20.6 27.1 1.32x ~32%
35B-A3B MoE 8-bit 74.4 82.3 1.11x ~11%
122B-A10B MoE 5-bit 43.0 46.7 1.09x ~9%

At temp=0.6 (production sampling), 122B drops to 1.05x (~5% acceptance).

So yes — it does correlate with architecture. MoE acceptance rates are significantly lower than dense. My hypothesis: the MTP layer contains a full 256-expert MoE routing step (same expert count as the backbone), but with only a single layer of context depth it struggles to predict the correct expert routing. The dense 27B's MTP layer is a standard transformer layer — much simpler prediction task, much higher acceptance.

The fp16 27B was actually 0.61x (slower) — bandwidth-saturated, the MTP overhead exceeds the savings. 8-bit quantization is the sweet spot where MTP helps most.

@Thump604

Hey @AirRunner — thanks for integrating the MoE sanitize fix! The PR has merge conflicts with main now though. Would you be able to rebase? Happy to help if needed.

Also, any thoughts on tagging a maintainer for review? This has been open since March 13 with zero maintainer engagement. The implementation is solid (8 tests, code review feedback addressed, MoE fix integrated), just needs someone to look at it.

@AirRunner
Author

AirRunner commented Mar 21, 2026

Hey @Goekdeniz-Guelmez, would you be able to take a look when you get a chance?

Quick summary: 8 unit tests, code review feedback from @janhilgard and @Thump604, rebased on main.
Results: 1.52x token generation on Qwen3.5-27B dense on M4 Pro, validated independently on M2 Ultra across three Qwen3.5 sizes (MoE and dense).

@layer4down

Subject: Successfully running Qwen3.5-27B locally with workaround

Transparency Note: This comment was drafted with the assistance of an AI assistant to help document the troubleshooting process. All technical details and findings are from actual testing.


Thanks for this PR! I was able to get Qwen3.5-27B working locally with MLX, but encountered an issue that might help others.

The Bug I Was Addressing

When trying to use the model with a client that passes short model IDs, I encountered:

401 Client Error. (Request ID: Root=1-69bfb0a8...)
Repository Not Found for url: https://huggingface.co/api/models/qwen3_5-27b_4bit/revision/main.
Please make sure you specified the correct `repo_id` and `repo_type`.
User Access Token "Claude-flow-ro" is expired

The error message was misleading - it suggested an expired token, but the real issue was a config/weight mismatch described below.

Issue Encountered

The model failed to load with:

ValueError: Missing 15 parameters: 
language_model.mtp.fc.weight,
language_model.mtp.layers.0.input_layernorm.weight,
...

Root Cause

The model's config.json (from mlx-community/Qwen3.5-27B-4bit on HuggingFace) has:

{
  "text_config": {
    "mtp_num_hidden_layers": 1
  }
}

However, the actual .safetensors weights do not contain any MTP parameters. The PR code correctly expects MTP weights when mtp_num_hidden_layers > 0, but this particular model's config claims MTP support that isn't present in the weights.

Workaround

Set mtp_num_hidden_layers to 0 in the model's config:

cat config.json | jq '.text_config.mtp_num_hidden_layers = 0' > config_fixed.json
mv config_fixed.json config.json

Other Configuration Notes

For anyone trying this setup:

  • Context length: Model supports 98K+ context; works with --max-tokens 98304
  • KV cache quantization: Works with MLX_KV_CACHE_QUANT=true environment variable
  • Model path as ID: The server uses the full local path as the model ID in API calls. For example:
    // Request to /v1/chat/completions
    {
      "model": "/path/to/local/models/mlx-community/Qwen3.5-27B-4bit",
      "messages": [...]
    }
    Short names like "Qwen3.5-27B" will trigger a HuggingFace lookup (and fail if the repo doesn't exist or auth is expired).

Suggestion

It might be helpful to add a check/warning when:

  1. mtp_num_hidden_layers > 0 in config
  2. But MTP weights are missing from the loaded model

This would help users identify config/weight mismatches more quickly and avoid confusing auth error messages.
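
A sketch of such a check, assuming access to the parsed text_config (args) and the flat weight dict at load time (names are illustrative):

def check_mtp_weights(args, weights):
    # Raise a clear error when the config advertises an MTP head but the
    # checkpoint ships no mtp.* parameters.
    if getattr(args, "mtp_num_hidden_layers", 0) > 0:
        if not any(".mtp." in k or k.startswith("mtp.") for k in weights):
            raise ValueError(
                "Config sets mtp_num_hidden_layers > 0 but the checkpoint "
                "contains no MTP weights. Re-convert the model with MTP "
                "weights preserved, or set mtp_num_hidden_layers to 0."
            )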

@AirRunner
Author

@layer4down thanks for the write-up!

You're right: mlx-community/Qwen3.5-27B-4bit was quantized without the MTP head weights. The mtp_num_hidden_layers: 1 in the config is inherited from the original Qwen3.5 config, but the MTP parameters were not included when quantizing.

To actually use MTP acceleration, the model needs to be re-quantized including the MTP layers using this branch.

As you suggested, I just pushed a fix that raises a clear ValueError instead of the cryptic "Missing N parameters" crash.

@Thump604

@angeloskath -- this PR has been open 11 days with no maintainer review. AirRunner rebased on 2026-03-21, all conflicts resolved, 8 unit tests passing.

We've been running this in production on M2 Ultra 128GB since day one. Qwen3.5-122B-A10B-VLM-MTP-5bit, 24/7 inference serving coding agents. MTP acceptance rates:

  • 27B dense 8-bit: 1.32x (32% acceptance, best fit)
  • 35B MoE 8-bit: 1.11x (11% acceptance)
  • 122B MoE 5-bit: 1.09x (9% acceptance)

MoE acceptance rates are lower because a single MTP layer can't predict expert routing well. Still a net win for the latency-sensitive use case.

The MoE sanitize fix (commit 8d06796) is essential for Qwen3.5 MoE models -- without it, 768 MTP parameters are silently missing. We've also published pre-converted VLM+MTP models on HuggingFace that depend on this code path.

Would be great to get this reviewed and merged so the community models work out of the box.

@cresseelia

cresseelia commented Mar 29, 2026

Can we @ the reviewer again? It's an important update for Qwen3.5.

@Thump604

@angeloskath @awni — this PR has been open 17 days with no maintainer review or feedback. Multiple community members have asked for review (AirRunner, ourselves, cresseelia).

Is there a concern with the approach, scope, or implementation that's blocking review? We're happy to help address any issues — split the PR, rework the API surface, add tests, whatever is needed.

We're running this in production on 122B and have validated it across three Qwen3.5 model sizes. The community is actively hitting the config/weight mismatch that AirRunner already fixed in this branch (layer4down's report above). Without this merged, users have to manually patch config.json to use MTP on Qwen3.5 models.

If the PR needs changes or a different direction, we'd rather know than wait. Let us know how we can help move this forward.

@Goekdeniz-Guelmez
Contributor

As you can see, 6 files have been changed/added alongside 700 lines of added code. This is a PR with big changes in the codebase itself. Reviewing and (correctly) implementing it will take time; 17 days is not long enough. My full weight fine-tuning PR took multiple weeks to be merged. Just keep it open, update it, and please be patient. Adding completely new features will take a long time.

@janhilgard

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@AirRunner
Author

AirRunner commented Apr 1, 2026

@Goekdeniz-Guelmez — fair point, thanks for the perspective. We appreciate you taking the time to look at it.

To make the review easier, we can split this into two smaller PRs:

PR 1 — Model architecture (~260 lines): MTPModule, MTPDecoderLayer, SSM rollback support in GatedDeltaNet, MoE weight stacking, cache rollback field. Pure model-side changes, reviewable independently.

PR 2 — Generation + tests (~420 lines): mtp_generate_step() function, --mtp CLI flag, 8 unit tests. Depends on PR 1 but much easier to review once the model interface is established.

Would splitting it this way help with the review process? Happy to do the work if so.

@AirRunner — would you be open to splitting the PR this way?

@janhilgard I'm not sure splitting would actually help the review here.

The PRs you suggest wouldn't be reviewable in isolation, because the architecture changes only make sense in the context of how mtp_generate_step uses them. Also, the changes in generate would be dead code until the other PR lands, so one would need to review both PRs together anyway.

(Also, 183 of the 683 added lines are just unit tests).

That said, I'm open to whatever helps, happy to reorganize if it does :).

@Thump604

Thump604 commented Apr 1, 2026

@angeloskath @awni — this PR has been open 20+ days with no maintainer review. It is the foundation for MTP speculative decoding on Qwen3.5 models, which several of us are using in production. My PR #1085 (probabilistic acceptance, 2.3x throughput on 122B) builds directly on top of it.

AirRunner's implementation is solid: 8 tests, 80.6% acceptance on M4 Pro. Is there a concern about scope or approach blocking review?

@gyzerok

gyzerok commented Apr 1, 2026

@Thump604 can you stop pinging people? The more annoying you are the less likely anyone is going to respond.

@janhilgard

Great work — I've been running MTP on Qwen3.5 MoE models in production (M3 Ultra, 256 GB) and wanted to share findings that might explain the low MoE acceptance rates.

BF16 MTP weights are critical for MoE acceptance

Your quant_predicate excludes only mtp.fc:

if path.endswith("mtp.fc"):
    return False

But the MTP transformer layer (attention, MLP, norms) still gets quantized. We found that quantized MTP weights give near-0% acceptance on MoE models — the quantization error compounds through the expert routing prediction.

Fix: exclude ALL MTP weights from quantization:

if "mtp." in path:
    return False

Our MoE results with BF16 MTP weights

Model Quantization MTP weights Acceptance Speedup
35B-A3B 4-bit BF16 79-85% 1.18x
122B-A10B 4-bit BF16 77-78% 1.12x
35B-A3B 4-bit dequantized 4→BF16 ~0%

vs your MoE benchmarks (quantized MTP weights):

Model MTP weights Implied acceptance Speedup
35B-A3B 8-bit quantized ~11% 1.11x
122B-A10B 5-bit quantized ~5% 1.09x

The difference is stark: BF16 MTP weights → 79-85% acceptance, quantized → 5-11%.

Batch auto-skip

Your PR sets is_batchable = False when MTP is active. In our vllm-mlx integration (#245 on waybarrios/vllm-mlx) we auto-skip MTP when batch_size > 1:

if len(active_batch) > 1:
    # Skip MTP, fall back to standard generation
    return _orig_step(input_tokens, cache)

This gives the best of both worlds:

  • 1 request: MTP active → 86 tok/s (1.18x)
  • 8 requests: MTP skipped → 307 tok/s (full batching throughput)

Instead of disabling batching entirely, you could dynamically switch.

Weight extraction

We extract BF16 MTP weights from the original HF model (not the quantized MLX model) with a dedicated script. See vllm-mlx PR #245 for the add_mtp_weights_qwen35.py script that:

  • Downloads only MTP-containing shards (not entire model)
  • Stacks per-expert weights into SwitchLinear format
  • Applies RMSNorm +1.0 shift
  • Outputs native BF16

Happy to collaborate on getting BF16 MTP weights into the standard conversion pipeline.

@Thump604

Thump604 commented Apr 2, 2026

I tested your BF16 MTP finding on our models. Sharing the data since it tells a different story on 5-bit and 8-bit backbones.

I extracted fresh BF16 MTP weights from the original HF models (not dequantized from quantized), applied the RMSNorm +1.0 shift, stacked MoE experts into SwitchLinear format, and re-quantized to match the backbone (5-bit gs=64 for 122B, 4-bit gs=64 for 4B). This matches the process you describe in your extraction script.

Results (probabilistic acceptance, temp=0.6):

Model Backbone Original quantized MTP BF16-source re-quantized MTP
4B dense 4-bit gs=64 44.9%, 91.8 tok/s 43.8%, 86.5 tok/s
122B MoE 5-bit gs=64 47.3%, 21.5 tok/s 47.3%, 21.0 tok/s

No measurable difference. Re-quantizing MTP from the BF16 source produces the same acceptance as the original quantized weights on these models.

I also tested with fully unquantized BF16 MTP (no re-quantization, just raw BF16 + norm shift). This gave 0% acceptance across all models. The BF16 MTP forward pass produces a different logit distribution than the quantized backbone expects. Once I re-quantize to match the backbone, the acceptance rate converges to the same ~46%.

Your 79-85% acceptance at 4-bit is significantly higher than what I see. A few questions:

  • Are you running the MTP layer entirely in BF16 (unquantized), or does your script quantize it to match the backbone?
  • Which mlx-lm generate path are you using? Our probabilistic acceptance is from PR feat: probabilistic MTP acceptance (speculative sampling) #1085 (min(1, p_target/p_draft)). Exact match at temp=0.6 gives ~5%.
  • Are your numbers from greedy (temp=0) or sampled (temp=0.6)?

Our acceptance ceiling appears to be ~47% with probabilistic sampling regardless of how the MTP weights are prepared, as long as they match the backbone's quantization. If you are getting 79-85%, there may be a difference in the generation loop or sampling strategy that accounts for the gap.

@JJJYmmm

Previously _prefill only populated the backbone cache, leaving the MTP KVCache cold at the start of decode. The MTP head was trained with full prefix context, so starting from an empty cache is misaligned with training.

Now each prefill chunk passes return_hidden=True and immediately calls mtp_forward(hidden, y[1:n+1], mtp_cache). The hidden tensor is transient: consumed within the same iteration before mx.clear_cache().
@AirRunner
Author

@JJJYmmm Thanks, you're right! To simplify _prefill I constrained the last backbone forward pass to exactly 1 token so that the hidden state [1, N, H] kept alive by return_hidden=True is minimal.
From there it was a short step to not prefilling the MTP cache at all during prompt processing. The MTP cache warms up naturally over the first few decode steps, but I hadn't measured the difference starting from a cold cache. So thanks for your numbers!

It's now fixed, model.mtp_forward() is called immediately (discarding the logits).

It's worth noting though that in multi-turn usage the MTP cache from the previous turn is already populated, so the cold-start effect would only apply to the very first turn. Also, at temp>0 with probabilistic acceptance, the criterion would partially compensate for cold-cache predictions, so the real-world delta will likely be small.
That said, prefilling is still the correct behavior.

VRAM considerations

This adds a bit of overhead, but the [1, N, H] hidden tensor is transient, as it's consumed by mtp_forward in the same iteration and freed before the next chunk.

For Qwen3.5 it just adds about 4 KB/token permanent VRAM cost for the MTP KV cache (4 KV heads × 256 head_dim × BF16), so about 40 MB for a 10K-token prompt.
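(That is 4 KV heads × 256 head_dim × 2 bytes for BF16 × 2 for K and V ≈ 4 KB per token, and 10,000 tokens × 4 KB ≈ 40 MB.)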

generate_step calls mx.clear_cache() every 256 tokens to bound the Metal allocator's free list.

Introduce _CACHE_CLEAR_INTERVAL = 256 shared by both generate_step and mtp_generate_step to add the equivalent cache-clearing logic to the MTP decode loop.  The block-based counter (ntoks // _CACHE_CLEAR_INTERVAL) handles MTP iterations that could emit multiple tokens at once, where a '% interval == 0' check could skip a boundary.
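
Schematically (ntoks and the per-round emitted count are assumptions about the loop's bookkeeping, not the PR's exact variable names):

import mlx.core as mx

_CACHE_CLEAR_INTERVAL = 256

def maybe_clear_cache(ntoks_before: int, num_emitted: int) -> int:
    # Compare which 256-token block we are in before and after this round; an
    # MTP round can emit 2 tokens and jump over a boundary, which a plain
    # `ntoks % interval == 0` check would miss.
    ntoks_after = ntoks_before + num_emitted
    if ntoks_after // _CACHE_CLEAR_INTERVAL != ntoks_before // _CACHE_CLEAR_INTERVAL:
        mx.clear_cache()
    return ntoks_after
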
@atelepov

atelepov commented May 8, 2026

@AirRunner
I tested it on an M1 Max 32GB, but there's no acceleration; it's actually slowing down.
Could you tell me if this is a limitation of the M1 Max 32GB specifically, or are some additional parameters incorrectly specified?

Configuration
M1 MAX 32Gb

Convert

uv run python -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp \
  -q --q-mode affine --q-bits 4 --q-group-size 64

no mtp

uv run python -m mlx_lm generate --model Qwen3.6-27B-mtp \
  --prompt "Write a quicksort in Python." --max-tokens 100

result

Prompt: 17 tokens, 5.002 tokens-per-sec
Generation: 100 tokens, 19.447 tokens-per-sec
Peak memory: 15.715 GB

mtp

uv run python -m mlx_lm generate --model Qwen3.6-27B-mtp \  
  --prompt "Write a quicksort in Python." --max-tokens 100 --mtp

result

Prompt: 17 tokens, 5.552 tokens-per-sec
Generation: 100 tokens, 17.553 tokens-per-sec
Peak memory: 15.814 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@atelepov Your setup looks correct. Are the MTP weights present in the quantized version?
M1 Max shouldn't be a limitation, as we tested it on M2 Ultra, M4 Pro, M5 Pro/Max, etc.

It's known that on MoE there is not much gain, especially on the 35B-A3B, where it could even drop by 1 or 2%.
But the 27B should definitely benefit from MTP.

It might simply be the benchmark length (might be too much variance at 100 tokens).
Though I just tested on M4 Pro with the same model and prompt with --max-tokens 100, I still get +51.4%.

Could you retry with --max-tokens 256, or even more (like 2048)?

@s-n-t

s-n-t commented May 8, 2026

For M1/M2 you always want "--dtype float16" when quantizing, as bfloat16 goes through a software path I think? I still don't see any improvement at 4-bit though, but 8-bit shows a decent boost.

python3 -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp-8-bit \
  -q --q-mode affine --q-bits 8 --q-group-size 64 --dtype float16
python3 -m mlx_lm convert --hf-path Qwen/Qwen3.6-27B \
  --mlx-path Qwen3.6-27B-mtp-4-bit \
  -q --q-mode affine --q-bits 4 --q-group-size 64 --dtype float16
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
  ==========
Prompt: 17 tokens, 32.623 tokens-per-sec
Generation: 1024 tokens, 20.213 tokens-per-sec
Peak memory: 15.900 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
 ==========
Prompt: 17 tokens, 30.828 tokens-per-sec
Generation: 1024 tokens, 20.601 tokens-per-sec
Peak memory: 15.737 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 23.865 tokens-per-sec
Generation: 1024 tokens, 14.406 tokens-per-sec
Peak memory: 29.513 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...  
==========
Prompt: 17 tokens, 22.529 tokens-per-sec
Generation: 1024 tokens, 11.346 tokens-per-sec
Peak memory: 29.351 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@s-n-t Interesting findings. So collecting different data points across the comments of this PR, we have:

Hardware Model dtype Baseline MTP Delta
M1 Max 27B 4-bit fp16 20.6 tok/s 20.2 tok/s -2%
M1 Max 27B 8-bit fp16 11.3 tok/s 14.4 tok/s +27%
M2 Ultra 27B 8-bit bf16 20.6 tok/s 27.1 tok/s +32%
M4 Pro 27B 4-bit bf16 15.6 tok/s 23.7 tok/s +51%
M5 Max 27B 4-bit bf16 32 tok/s 44 tok/s +35%

Important note: the explicit --dtype float16 override seems to be important for M1, otherwise the default bf16 goes through a software path.
mtp.fc is kept in full precision by our quant_predicate, so if BF16 ops go through a software path on M1, that layer adds disproportionate overhead.


The M1 Max result is the outlier. The speedup differences across chips could be better explained by the β+δ framework from my reply to Anionex: speedup = (1+p) / (β+δ), where p is the per-round acceptance probability, β is the 2-token backbone overhead, and δ is the MTP head cost, both relative to a single baseline step.

Also from what I understand, there is no dual-datapath execution (FP32/BF16 and int4 ops serialized) or dedicated matrix accelerators on M1. M3 introduced both. So the M1 GPU generation likely adds disproportionate compute overhead on the MTP head forward pass (short sequence, more compute-bound than the memory-bound backbone).

With p≈0.85 (from α=0.46 via α=p/(1+p)), breakeven is at β+δ = 1+p ≈ 1.85. On M4 Pro β+δ = 1.190, well below that. On M1 Max, the MTP head forward pass likely pushes β+δ closer to 1.85, which would explain the near-zero result. Running this bench script would give a measured β+δ for M1 Max.

@s-n-t

s-n-t commented May 8, 2026

Without the "--dtype float16" on the same hardware (confirms @atelepov numbers):

python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 24.655 tokens-per-sec
Generation: 1024 tokens, 15.337 tokens-per-sec
Peak memory: 15.895 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-4-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
==========
Prompt: 17 tokens, 23.971 tokens-per-sec
Generation: 1024 tokens, 19.700 tokens-per-sec
Peak memory: 15.737 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024 --mtp
...
==========
Prompt: 17 tokens, 19.169 tokens-per-sec
Generation: 1024 tokens, 14.032 tokens-per-sec
Peak memory: 29.513 GB
python3 -m mlx_lm generate --model Qwen3.6-27B-mtp-8-bit \
  --prompt "Write a quicksort in Python." --max-tokens 1024
...
==========
Prompt: 17 tokens, 18.781 tokens-per-sec
Generation: 1024 tokens, 11.231 tokens-per-sec
Peak memory: 29.351 GB

@atelepov

atelepov commented May 8, 2026

@AirRunner
Thank you very much.
Specifying the --dtype float16 parameter during conversion speeds up generation.

--dtype float16
MTP

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.502 tokens-per-sec
Generation: 256 tokens, 21.729 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.758 tokens-per-sec
Generation: 1628 tokens, 22.117 tokens-per-sec
Peak memory: 15.945 GB

NO MTP

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.501 tokens-per-sec
Generation: 256 tokens, 20.285 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.690 tokens-per-sec
Generation: 1628 tokens, 19.700 tokens-per-sec
Peak memory: 15.777 GB

--dtype DEFAULT

MTP

--max-tokens 256
Prompt: 17 tokens, 4.490 tokens-per-sec
Generation: 256 tokens, 18.103 tokens-per-sec
Peak memory: 15.834 GB

--max-tokens 256 --temp 0
Prompt: 17 tokens, 5.511 tokens-per-sec
Generation: 256 tokens, 17.529 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 256 --temp 0.6
Prompt: 17 tokens, 5.454 tokens-per-sec
Generation: 256 tokens, 17.555 tokens-per-sec
Peak memory: 15.833 GB

--max-tokens 2048
Prompt: 17 tokens, 6.118 tokens-per-sec
Generation: 1806 tokens, 15.917 tokens-per-sec
Peak memory: 15.957 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.797 tokens-per-sec
Generation: 1806 tokens, 15.092 tokens-per-sec
Peak memory: 15.961 GB

NO MTP

--max-tokens 256
Prompt: 17 tokens, 5.449 tokens-per-sec
Generation: 256 tokens, 11.092 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 256 --temp 0
Prompt: 17 tokens, 4.499 tokens-per-sec
Generation: 256 tokens, 19.318 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 256 --temp 0.6
Prompt: 17 tokens, 5.136 tokens-per-sec
Generation: 256 tokens, 19.694 tokens-per-sec
Peak memory: 15.715 GB

--max-tokens 2048
Prompt: 17 tokens, 5.194 tokens-per-sec
Generation: 1830 tokens, 12.277 tokens-per-sec
Peak memory: 15.796 GB

--max-tokens 2048 --temp 0
Prompt: 17 tokens, 5.472 tokens-per-sec
Generation: 1830 tokens, 17.354 tokens-per-sec
Peak memory: 15.797 GB

@AirRunner
Author

AirRunner commented May 8, 2026

@s-n-t Ok then your -22% (bf16) vs -2% (fp16) gap is almost entirely explained by BF16 emulation on M1.

M1 has no native BF16 GPU support, MTLDataType.bfloat requires Apple8+ (M1 is Apple7), confirmed by hw.optional.arm.FEAT_BF16: 0 on M1.
The quant_predicate keeps mtp.fc in full precision, so that layer runs through a software path on M1, adding disproportionate overhead to every MTP step.

So for M1 users: --dtype float16 is strongly recommended.

The 8-bit case is much less affected by dtype (+25% vs +27%) because the baseline is slower. Same absolute BF16 overhead on mtp.fc, smaller fraction of a slower step.

With fp16, the residual -2% on 4-bit likely reflects the β+δ compute overhead on M1's GPU architecture (see here).


@atelepov Nice, around +10% on M1 Max seems to be the expected range on this architecture.

@heykb

heykb commented May 9, 2026

M4 pro 48G. Qwen3.6-27 4bit
============================================================================
SUMMARY  (decode tok/s, first token excluded)
============================================================================
Condition                     Pooled        Mean+-SD    Accept     Prefill
----------------------------------------------------------------------------
  baseline  temp=0             15.26    15.3+-0.2        --        30.0
  MTP       temp=0             22.02    22.1+-1.3       46.9%      31.2
----------------------------------------------------------------------------

Speedup (pooled decode tok/s, MTP vs baseline at matching temperature):
  MTP       temp=0        1.444x  (alpha=0.469)  beta+delta=1.3038

Model  : /Users/a1-6/models/Qwen3.6-27B-mtp
Config : max_tokens=256, runs=3, warmup=1, prompts=8

@TomLucidor

TomLucidor commented May 9, 2026

@heykb did MTP manage to accelerate prefill by accident?

Declare u = mx.random.uniform() immediately before its first use (mx.eval) rather than before the unrelated _step_backbone call.
@crazyi

crazyi commented May 10, 2026

@s-n-t Ok then your -22% (bf16) vs -2% (fp16) gap is almost entirely explained by BF16 emulation on M1.

M1 has no native BF16 GPU support, MTLDataType.bfloat requires Apple8+ (M1 is Apple7), confirmed by hw.optional.arm.FEAT_BF16: 0 on M1. The quant_predicate keeps mtp.fc in full precision, so that layer runs through a software path on M1, adding disproportionate overhead to every MTP step.

So for M1 users: --dtype float16 is strongly recommended.

The 8-bit case is much less affected by dtype (+25% vs +27%) because the baseline is slower. Same absolute BF16 overhead on mtp.fc, smaller fraction of a slower step.

With fp16, the residual -2% on 4-bit likely reflects the β+δ compute overhead on M1's GPU architecture (see here).

@atelepov Nice, around +10% on M1 Max seems to be the expected range on this architecture.

I found that the M2 chip reports hw.optional.arm.FEAT_BF16: 1. If the M2 chip supports BF16, why does converting to FP16 still result in performance improvements?
jundot/omlx#604
Why the M2 is more advanced than it seemed

@deepsweet

If the M2 chip supports BF16, why does converting to FP16 still result in performance improvements?

See my detailed benchmark and conclusions.

Empirical benchmarks (Qwen3.6-27B 4-bit, M4 Pro, temp=0/0.6/1.0) show no measurable impact on MTP acceptance rate when mtp.fc is quantized to 4-bit: acceptance delta is within noise (−0.2 to +0.3 pp), speedup delta within noise (−0.003 to +0.026x).

Additionally, keeping mtp.fc in BF16 penalizes M1 users where BF16 has no native GPU support.
@AirRunner
Author

Following up on these: comment1, comment2, comment3.

The original rationale for excluding mtp.fc from quantization was that it projects directly onto the vocab logits, so quantizing it could shift the draft distribution before the acceptance test.

However, I just tested this empirically on Qwen3.6-27B 4-bit with two models: mtp.fc in BF16 (105 MB) vs. mtp.fc quantized to 4-bit group=64 (26 MB). M4 Pro, 8 prompts, temp=0/0.6/1.0.
See bench_mtp.py with benchmark 'standard'.

Results

Condition delta tok/s delta accept delta speedup
temp=0 −0.03 −0.2 pp +0.003x
temp=0.6 +0.35 +0.3 pp +0.026x
temp=1.0 −0.07 −0.1 pp 0.000x

No measurable degradation at any temperature. Acceptance delta is ±0.3 pp across 16 samples/condition, within noise.
So I just removed mtp.fc exclusion from quant_predicate.

Also, as a side effect, keeping mtp.fc in BF16 penalized M1 and M2 chips and required --dtype float16 as a workaround (which means a dedicated model checkpoint).
Quantizing mtp.fc fixes it at the source.


Raw logs here.
(run 3 was discarded due to unrelated throttling)

The test was checking mtp.fc exclusion, which was removed in c47c1cb after empirical benchmarks.
@deepsweet

M2 Max @ 30 GPU cores here.

Qwen3.6-27B + Q4 + FP16:

================================================================================
SUMMARY  (decode tok/s, first token excluded)
================================================================================
Condition                     Pooled        Mean+-SD    Accept     Prefill
--------------------------------------------------------------------------------
  baseline_temp0_6             13.48    13.6+-1.2        --        25.6
  mtp_temp0_6                  15.71    15.8+-1.5       46.2%      22.6
--------------------------------------------------------------------------------

Speedup (pooled decode tok/s, MTP vs baseline at matching temperature):
  mtp_temp0_6             1.165x  (alpha=0.462, beta+delta=1.5941)

GoodOlClint added a commit to GoodOlClint/mlx-lm that referenced this pull request May 15, 2026
… the ledger gap)

Brings the MTP machinery fully current with the PR ml-explore#990 tip, closing
the temp>0 sampling-correctness gap that the corrected ledger flagged.
Adopted commits not previously in our tree:

- 13f157b fix(mtp): use residual sampling on rejection at temp>0
- 6594348 fix(mtp): reduce residual sampling to 1 sync, correct z=0 fallback
- 87f1b09 feat(mtp): native sampling params, XTC draw sharing, correct lp_accept
- a2f1374 quality: from functools import partial
- b1dad14 fix(mtp): prefill MTP cache during prompt prefill
- 8a52379 fix(mtp): input_embeddings + logits_processors dimensionality
- 32fdaa3 fix(mtp): remove spurious mtp_cache trim on draft rejection
- a5a82a9 style(mtp): move u after _step_backbone
- c47c1cb qwen3_5: remove mtp.fc exclusion from quant_predicate
- 6222938 test(mtp): remove stale quant_predicate test
(48e1fca / fae9fa1 net content was already present; cache.py and
qwen3_5_moe.py were already byte-identical to tip.)

Method: checked out generate.py / qwen3_5.py / qwen3_5_moe.py /
cache.py / test_mtp.py at pr/990 tip, then re-applied Patch 3's three
disjoint additive generate.py hunks (import os, --json-schema arg,
the json_schema logits_processors block in main()). The MTP sampling
refactor and Patch 3's CLI plumbing touch non-overlapping regions, so
the re-application is exact: `git diff pr/990 -- generate.py` is now
precisely Patch 3's 29 insertions; the other MTP files are
byte-identical to tip.

Subsumes the standalone 68b2cd4 (ffac433 cache-clear, cherry-picked
out of order with a 2-var adaptation because 87f1b09 was then absent)
and the original Patch 1 squash; the final tree state equals tip
regardless of that earlier adaptation.

Also reconciled the two ml-explore#990 commits that touch files outside the
core MTP set:
- mlx_lm/sample_utils.py taken wholesale at tip (unpatched by us;
  adds make_sampler_chain / native sampling params — the missing
  import that broke collection until reconciled).
- mlx_lm/server.py: applied only the two ml-explore#990 hunks we lacked
  (_xtc_special_tokens helper + native sampling kwargs threaded into
  the _serve_single MTP generate call). Patches 2/4/6/7/8 regions
  untouched; the ml-explore#990 server.py delta is now fully reconciled.

Test status: tests/test_structured.py + tests/test_mtp.py (now the
full tip suite) + tests/test_server.py green (88); full suite 263
passed, only the 3 pre-existing test_tokenizers BPE-whitespace
environment flakes (unrelated, untouched module).

VALIDATION GATE (per CLAUDE.md "Validation requirements before
tagging"): this changes bench-validated Patch 1 behavior at temp>0.
Unit suite is green here, but the case-project benchmark harness +
MTP-preserving converted weights are NOT on this box, so the
mandated re-bench has NOT been run. Do not treat any tag built on
this commit as bench-valid until that run completes.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The accepted draft token was never processed by the MTP head, causing the cache to drift behind the backbone cache by one entry per accept. After k accepts the MTP head operates on k tokens of missing context. Empirically the impact was negligible though (backbone hidden dominates MTP head conditioning).

Fix: extend _step_mtp with an optional cache_commit=(hidden, tok) parameter. When set, the alignment position and the draft position are processed in a single 2-token batched mtp_forward, committing the accepted token to mtp_cache at no extra forward-pass cost.
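
Schematically, the batched commit could look like this (hypothetical shapes and names; commit_hidden stands for the backbone hidden state at the previously accepted position):

import mlx.core as mx

def _step_mtp(model, hidden, next_tokens, mtp_cache, cache_commit=None):
    # Sketch only. When cache_commit=(commit_hidden, commit_tok) is provided,
    # prepend the alignment position so a single 2-token mtp_forward both
    # commits the accepted token to mtp_cache and produces the new draft.
    if cache_commit is not None:
        commit_hidden, commit_tok = cache_commit
        hidden = mx.concatenate([commit_hidden, hidden], axis=1)
        next_tokens = mx.concatenate([mx.array([[commit_tok]]), next_tokens], axis=1)
    draft_logits = model.mtp_forward(hidden, next_tokens, mtp_cache)
    return draft_logits[:, -1:, :]  # only the last position is the new draft
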
@AirRunner
Author

M2 Max @ 30 GPU cores here.

Qwen3.6-27B + Q4 + FP16:

================================================================================
SUMMARY  (decode tok/s, first token excluded)
================================================================================
Condition                     Pooled        Mean+-SD    Accept     Prefill
--------------------------------------------------------------------------------
  baseline_temp0_6             13.48    13.6+-1.2        --        25.6
  mtp_temp0_6                  15.71    15.8+-1.5       46.2%      22.6
--------------------------------------------------------------------------------

Speedup (pooled decode tok/s, MTP vs baseline at matching temperature):
  mtp_temp0_6             1.165x  (alpha=0.462, beta+delta=1.5941)

@deepsweet Interesting to see quite high β+δ on M2 Max given that you have FP16 weights. It's actually consistent with atelepov's M1 Max results (+7-10% with FP16 weights).

I believe even with FP16 weights MLX intermediate activations likely remain in BF16, which still goes through the FP32 path in Metal's matrix multiplication on M1/M2.

To confirm, could you run bench_mtp_timing.py on your M2 Max? It measures β and δ directly and also includes a --synth mode that benchmarks quantized_matmul at L=1/2/4 to check whether the Metal kernel dispatch differs on your chip (on M4 Pro L=1 and L=2 cost the same at 3.3ms, but that may not hold on M2).

python bench_mtp_timing.py --model <path>  # full β/δ measurement
python bench_mtp_timing.py --model <path> --synth  # lm_head kernel dispatch

One thing that also stands out: your baseline SD is ±1.2 tok/s, whereas I get ±0.1. That seems unusual for a baseline; how many runs did you have? Maybe worth a rerun to confirm the β+δ is stable.


(Note: the "46.2% acceptance" shown actually means 85.9% in the A/V metric. The old script only reported A/(V+A)).

@deepsweet

deepsweet commented May 16, 2026

@AirRunner same Qwen3.6-27B + Q4 + FP16 as above:

=== Baseline (generate_step, L=1 backbone) ===
  T_baseline : mean=57.8ms  median=62.2ms  sd=8.9ms  n=256

=== MTP verify pass ===
  T_mtp_head    : mean=5.8ms  median=6.3ms  sd=1.1ms  n=139
  T_round       : mean=105.6ms  median=105.6ms  sd=2.1ms  n=138
  T_backbone_L2 : mean=99.7ms  (T_round - T_mtp)
  Acceptance    : p=0.803  (44.5% as A/(V+A))

=======================================================
  beta       = T_backbone_L2 / T_baseline = 1.725
  delta      = T_mtp_head    / T_baseline = 0.101
  beta+delta = 1.826
  predicted speedup = (1+p)/(beta+delta)  = 0.987x

  Reference (M4 Pro): beta=1.123  delta=0.072  beta+delta=1.195
=======================================================

and

=== Synth: quantized_matmul (V=248320, H=5120, group=64, bits=4) ===
  L=1: avg=2.37ms  min=2.35ms
  L=2: avg=3.20ms  min=3.17ms
  L=4: avg=5.90ms  min=5.87ms

One thing that also stands out: your baseline SD is ±1.2 tok/s, whereas I get ±0.1. That seems unusual for a baseline; how many runs did you have? Maybe worth a rerun to confirm the β+δ is stable.

I've ran the default "quick" preset without any extra arguments.

@AirRunner
Author

@deepsweet Hmm, ok so your β=1.725 means the verify backbone (L=2) costs 72% more than baseline (L=1), which is pretty high. I also see that your baseline SD is unusually high (±8.9ms where I get ±0.4ms).
(I use L where the Metal notation is M, to differentiate from chip names.)

The synth goes in this direction with 1.35x, but doesn't really explain β=1.725. The full backbone ratio is worse than lm_head alone, meaning MLP and attention layers have an even larger L=2/L=1 penalty...

Actually this is the same family of issue I documented in ml-explore/mlx#3553, but there the non-linear step appears at L=3 on my M4 Pro. On M2 it seems to be at L=2.

Looking at the dispatch code, qmv uses the same kernel for all L values below the vector_limit, which is 6 for applegpu_g13/14 and 10 for applegpu_g15/16.
So neither the L=2 nor the L=3 bump is in the C++ dispatch logic. It has to be in GPU scheduling (wave occupancy or similar), and the threshold differs by chip generation.


So to sum up, on M1/M2 the MTP verify pass (L=2) seems to hit a non-linear cost increase that isn't present on M3+ (where L=2 costs about the same as L=1). This caps the MTP benefit at roughly 1.1x to 1.2x on M1/M2 (on the 27B at least), regardless of acceptance ratio. FP16 weights help recover from a degraded state to ~1.1x, but beyond that the bottleneck might be GPU scheduling at L=2.
