Skip to content

perf(mtp): Gemma4 MTP sliding-window fix + 8-bit benchmark results#109

Open
solderzzc wants to merge 1 commit into
mainfrom
feat/mtp-window-fix-8bit-bench
Open

perf(mtp): Gemma4 MTP sliding-window fix + 8-bit benchmark results#109
solderzzc wants to merge 1 commit into
mainfrom
feat/mtp-window-fix-8bit-bench

Conversation

@solderzzc
Copy link
Copy Markdown
Member

Summary

Resolves throughput regression in Gemma4 MTP speculative decoding at long-context lengths, and documents precision-dependent MTP behaviour from empirical benchmarking.

mlx-swift-lm submodule (→ fix/compiler-warnings-mtp-optim @ c552b4d)

  • Sliding-window KV cap: runMTPHead caps shared-KV cross-attention to last 16 backbone positions (O(T) → O(16)). Eliminates 2–4× throughput regression at 40K–100K context.
  • MTPPartialRollback protocol: stores lastBackboneHiddenStateAll for partial-rejection rollback without re-running the main model.
  • callMTPHeadOnly: re-seeds MTP head from cached backbone state at near-zero cost.
  • numMTPDraftTokens=2 (depth=4 empirically slower on Metal).

Benchmark Results (M5 Pro 64 GB, gemma-4-26b-a4b-it-8bit)

8-bit is bandwidth-bound → KV reads amortize across the batch → MTP provides real gains:

Config 512 ctx 40K ctx 100K ctx
Vanilla 53.7 32.4 14.9
Vanilla + MTP 47.1 38.8 (+20%) 22.5 (+51%)
Vanilla + TurboQuant 53.5 50.1 48.3
TQ + MTP 47.4 31.0 23.3

4-bit MoE is compute-bound; MTP neutral. TQ+MTP counterproductive at both precisions.

Other

  • mtp_bench.py: Metal warmup request before first timed run (fixes inflated 1.77s TTFT)
  • README: split 4-bit/8-bit tables, precision-specific guidance
  • Server/CLI: --mtp / --num-mtp-tokens wired through GenerationConfig

User Guidance

Precision Best long-context config
8-bit --mtp alone (+20–51%)
8-bit, memory-critical --turbo-kv alone
4-bit MoE --turbo-kv alone
4-bit + both Skip MTP

…rmup

## Changes

### mlx-swift-lm (submodule bump → c552b4d)
- Cap Gemma4 MTP shared-KV cross-attention to last 16 backbone positions
  (O(T) → O(16)), eliminating throughput regression at 40K-100K context
- MTPPartialRollback protocol: callMTPHeadOnly re-seeds MTP draft from
  cached backbone state without re-running the main model
- numMTPDraftTokens=2 per pass (depth=4 empirically worse on Metal)

### Benchmark results (gemma-4-26b-a4b-it-8bit, M5 Pro 64GB)
8-bit is bandwidth-bound (2× heavier weights). KV reads amortize across
the verification batch → MTP provides real throughput gains at 8-bit:
  40K ctx:  38.8 tok/s MTP vs 32.4 vanilla  (+20%)
  100K ctx: 22.5 tok/s MTP vs 14.9 vanilla  (+51%)
4-bit MoE is compute-bound (MoE FFN dominates); MTP neutral/overhead.
TQ+MTP counterproductive at both precisions (TQ removes bandwidth
bottleneck, making MTP's batch cost proportional again).

### scripts/profiling/mtp_bench.py
- Add Metal shader warmup request before first timed run per config
  (fixes inflated TTFT on first 512-token measurement — 1.77s → ~0.3s)

### README.md
- Split 4-bit / 8-bit benchmark tables with accurate current numbers
- Add precision-specific guidance: --mtp for 8-bit at 40K+; --turbo-kv
  alone for max throughput; don't combine TQ+MTP

### Server.swift / CLICommandBuilder / GenerationConfig
- --mtp / --num-mtp-tokens wired through CLI → GenerationConfig
- Architectural note in MTP dispatch path explaining compute-bound
  behaviour on 4-bit MoE (documents why TQ+MTP underperforms TQ)
Copilot AI review requested due to automatic review settings May 19, 2026 01:35
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR wires Gemma4 MTP assistant-model speculative decoding into the SwiftLM server/CLI flow and adds benchmarking/documentation for precision-dependent long-context performance.

Changes:

  • Adds Gemma4 MTP assistant auto-resolution/loading and MTP acceptance logging in server responses.
  • Persists/exports MTP generation settings via GenerationConfig and CLI command builder.
  • Adds a new Test 13 profiling script and updates benchmark guidance in README/run script.

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
Sources/SwiftLM/Server.swift Loads Gemma4 MTP assistant models, passes them into generation, and logs MTP acceptance.
Sources/MLXInferenceCore/GenerationConfig.swift Adds explicit decoding defaults for persisted MTP-related generation settings.
Sources/MLXInferenceCore/CLICommandBuilder.swift Emits --mtp and --num-mtp-tokens in generated CLI commands.
scripts/profiling/mtp_bench.py Adds HTTP-based MTP/TurboQuant benchmark runner.
run_benchmark.sh Routes Test 13 to the new Python benchmark script.
README.md Updates Gemma 4 benchmark tables and MTP/TurboQuant guidance.
Comments suppressed due to low confidence (1)

Sources/MLXInferenceCore/CLICommandBuilder.swift:60

  • This omits --num-mtp-tokens when the UI config is 1, but the SwiftLM server flag defaults to 3 (Server.swift:286-287). A copied command for enableMTP=true, numMTPTokens=1 would therefore run with 3 draft tokens instead of the configured value.
        if config.numMTPTokens != 1 {
            parts.append("--num-mtp-tokens \(config.numMTPTokens)")

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +678 to +684
let asstConfig = ModelConfiguration(id: asstId)
let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot))
let asstContainer = try await LLMModelFactory.shared.loadContainer(
from: asstDownloader,
using: TransformersTokenizerLoader(),
configuration: asstConfig
) { _ in }
Comment on lines +116 to +117
self.enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false
self.numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1
Comment on lines +57 to +61
if config.enableMTP {
parts.append("--mtp")
if config.numMTPTokens != 1 {
parts.append("--num-mtp-tokens \(config.numMTPTokens)")
}
Comment on lines +676 to +692
if let asstId = asstModelId, resolveModelDirectory(modelId: asstId) != nil {
print("[SwiftLM] Loading Gemma4 MTP assistant model: \(asstId)")
let asstConfig = ModelConfiguration(id: asstId)
let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot))
let asstContainer = try await LLMModelFactory.shared.loadContainer(
from: asstDownloader,
using: TransformersTokenizerLoader(),
configuration: asstConfig
) { _ in }
mtpAsstRef = await asstContainer.extractDraftModel()
print("[SwiftLM] MTP assistant loaded (\(self.numMtpTokens) draft tokens/round). Using DualModelMTP speculative path.")
} else {
if let asstId = asstModelId {
print("[SwiftLM] ⚠️ Gemma4 MTP: assistant model '\(asstId)' not found in HF cache. Run: python -m mlx_lm.convert --hf-path \(asstId) to download it.")
} else {
print("[SwiftLM] ⚠️ --mtp: model '\(modelId)' is not a known Gemma4 MTP model. MTP requires a Gemma4 assistant checkpoint.")
}
Comment thread README.md
**Key takeaways (4-bit):**
- 🚀 **TurboQuant is the headline win**: At 100K context, `Vanilla + TurboQuant` delivers **66.9 tok/s** vs **27.5 tok/s** Vanilla — a **2.43× speedup**.
- 💾 **Massive memory savings**: OS RAM at 40K context drops from **48.7 GB → 18.2 GB** with TurboQuant (63% reduction).
- ⚡ **MTP neutral on 4-bit MoE**: The 4-bit model is compute-bound (MoE expert dispatch). Batch verification scales linearly with token count, so MTP provides no net throughput gain over vanilla at 4-bit.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants