perf(mtp): Gemma4 MTP sliding-window fix + 8-bit benchmark results #109
Open
solderzzc wants to merge 1 commit into
…rmup

## Changes

### mlx-swift-lm (submodule bump → c552b4d)
- Cap Gemma4 MTP shared-KV cross-attention to last 16 backbone positions (O(T) → O(16)), eliminating throughput regression at 40K-100K context
- MTPPartialRollback protocol: callMTPHeadOnly re-seeds MTP draft from cached backbone state without re-running the main model
- numMTPDraftTokens=2 per pass (depth=4 empirically worse on Metal)

### Benchmark results (gemma-4-26b-a4b-it-8bit, M5 Pro 64GB)
8-bit is bandwidth-bound (2× heavier weights). KV reads amortize across the verification batch → MTP provides real throughput gains at 8-bit:
- 40K ctx: 38.8 tok/s MTP vs 32.4 vanilla (+20%)
- 100K ctx: 22.5 tok/s MTP vs 14.9 vanilla (+51%)

4-bit MoE is compute-bound (MoE FFN dominates); MTP neutral/overhead. TQ+MTP counterproductive at both precisions (TQ removes bandwidth bottleneck, making MTP's batch cost proportional again).

### scripts/profiling/mtp_bench.py
- Add Metal shader warmup request before first timed run per config (fixes inflated TTFT on first 512-token measurement: 1.77s → ~0.3s)

### README.md
- Split 4-bit / 8-bit benchmark tables with accurate current numbers
- Add precision-specific guidance: --mtp for 8-bit at 40K+; --turbo-kv alone for max throughput; don't combine TQ+MTP

### Server.swift / CLICommandBuilder / GenerationConfig
- --mtp / --num-mtp-tokens wired through CLI → GenerationConfig
- Architectural note in MTP dispatch path explaining compute-bound behaviour on 4-bit MoE (documents why TQ+MTP underperforms TQ)
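The window cap from the first mlx-swift-lm bullet can be pictured as taking a fixed-length suffix of the backbone positions, so the work per draft step stops growing with context length. This is only an illustrative sketch: `cappedWindow` and the integer position array are hypothetical stand-ins, not the code in the submodule.

```swift
// Hypothetical helper: keep at most `limit` trailing positions.
// Assumes positions are ordered oldest to newest.
func cappedWindow(_ positions: [Int], limit: Int = 16) -> ArraySlice<Int> {
    // suffix(_:) returns the whole array when it is shorter than `limit`.
    positions.suffix(limit)
}

let longContext = Array(0..<40_000)   // stand-in for 40K backbone positions
let window = cappedWindow(longContext)

let shortContext = Array(0..<5)       // shorter than the cap
let shortWindow = cappedWindow(shortContext)

print(window.count, shortWindow.count)
```

With the cap in place the slice length is bounded by `limit` regardless of how long the context is, which is the O(T) → O(16) change the bullet describes.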
Pull request overview
This PR wires Gemma4 MTP assistant-model speculative decoding into the SwiftLM server/CLI flow and adds benchmarking/documentation for precision-dependent long-context performance.
Changes:
- Adds Gemma4 MTP assistant auto-resolution/loading and MTP acceptance logging in server responses.
- Persists/exports MTP generation settings via `GenerationConfig` and CLI command builder.
- Adds a new Test 13 profiling script and updates benchmark guidance in README/run script.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `Sources/SwiftLM/Server.swift` | Loads Gemma4 MTP assistant models, passes them into generation, and logs MTP acceptance. |
| `Sources/MLXInferenceCore/GenerationConfig.swift` | Adds explicit decoding defaults for persisted MTP-related generation settings. |
| `Sources/MLXInferenceCore/CLICommandBuilder.swift` | Emits `--mtp` and `--num-mtp-tokens` in generated CLI commands. |
| `scripts/profiling/mtp_bench.py` | Adds HTTP-based MTP/TurboQuant benchmark runner. |
| `run_benchmark.sh` | Routes Test 13 to the new Python benchmark script. |
| `README.md` | Updates Gemma 4 benchmark tables and MTP/TurboQuant guidance. |
Comments suppressed due to low confidence (1)

Sources/MLXInferenceCore/CLICommandBuilder.swift:60
- This omits `--num-mtp-tokens` when the UI config is `1`, but the SwiftLM server flag defaults to `3` (Server.swift:286-287). A copied command for `enableMTP=true, numMTPTokens=1` would therefore run with 3 draft tokens instead of the configured value.

```swift
if config.numMTPTokens != 1 {
    parts.append("--num-mtp-tokens \(config.numMTPTokens)")
```
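One way to address the mismatch described in this comment is to emit `--num-mtp-tokens` whenever MTP is enabled, so the copied command never relies on the server's default. The sketch below is an assumption about how that could look, not the change made in the PR; `MTPOptions` and `mtpArguments(for:)` are hypothetical names standing in for the relevant part of `GenerationConfig` and the command builder.

```swift
// Hypothetical, trimmed-down config holding only the two MTP fields from the diff.
struct MTPOptions {
    var enableMTP: Bool
    var numMTPTokens: Int
}

// Always pass the draft-token count explicitly when MTP is on, so the
// generated command does not depend on the server-side flag default.
func mtpArguments(for config: MTPOptions) -> [String] {
    guard config.enableMTP else { return [] }
    return ["--mtp", "--num-mtp-tokens \(config.numMTPTokens)"]
}

let args = mtpArguments(for: MTPOptions(enableMTP: true, numMTPTokens: 1))
let disabled = mtpArguments(for: MTPOptions(enableMTP: false, numMTPTokens: 4))
print(args.joined(separator: " "))   // prints "--mtp --num-mtp-tokens 1"
```

The alternative is to align the two defaults (UI and server) on the same value; emitting the flag unconditionally keeps the copied command correct even if the defaults later drift apart.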
Comment on lines +678 to +684

```swift
let asstConfig = ModelConfiguration(id: asstId)
let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot))
let asstContainer = try await LLMModelFactory.shared.loadContainer(
    from: asstDownloader,
    using: TransformersTokenizerLoader(),
    configuration: asstConfig
) { _ in }
```
Comment on lines +116 to +117

```swift
self.enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false
self.numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1
```
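The two lines above fall back to defaults when a persisted config predates the MTP fields. A minimal self-contained sketch of that pattern follows; `MTPSettings` is a hypothetical stand-in for the MTP-related part of `GenerationConfig`, and the JSON payloads are made up for illustration.

```swift
import Foundation

// Hypothetical stand-in for the MTP-related fields of GenerationConfig.
struct MTPSettings: Codable {
    var enableMTP: Bool
    var numMTPTokens: Int

    enum CodingKeys: String, CodingKey {
        case enableMTP, numMTPTokens
    }

    init(from decoder: Decoder) throws {
        let container = try decoder.container(keyedBy: CodingKeys.self)
        // decodeIfPresent returns nil for a missing key, so older payloads
        // decode with the defaults instead of throwing keyNotFound.
        self.enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false
        self.numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1
    }
}

let decoder = JSONDecoder()
// Payload saved before the MTP fields existed.
let legacy = try decoder.decode(MTPSettings.self, from: Data("{}".utf8))
// Payload that sets both fields explicitly.
let current = try decoder.decode(
    MTPSettings.self,
    from: Data(#"{"enableMTP": true, "numMTPTokens": 2}"#.utf8)
)
print(legacy.enableMTP, legacy.numMTPTokens, current.enableMTP, current.numMTPTokens)
```

Note that the decoding default of `1` here is the same value the suppressed review comment contrasts with the server flag default of `3`.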
Comment on lines +57 to +61

```swift
if config.enableMTP {
    parts.append("--mtp")
    if config.numMTPTokens != 1 {
        parts.append("--num-mtp-tokens \(config.numMTPTokens)")
    }
```
Comment on lines +676 to +692

```swift
if let asstId = asstModelId, resolveModelDirectory(modelId: asstId) != nil {
    print("[SwiftLM] Loading Gemma4 MTP assistant model: \(asstId)")
    let asstConfig = ModelConfiguration(id: asstId)
    let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot))
    let asstContainer = try await LLMModelFactory.shared.loadContainer(
        from: asstDownloader,
        using: TransformersTokenizerLoader(),
        configuration: asstConfig
    ) { _ in }
    mtpAsstRef = await asstContainer.extractDraftModel()
    print("[SwiftLM] MTP assistant loaded (\(self.numMtpTokens) draft tokens/round). Using DualModelMTP speculative path.")
} else {
    if let asstId = asstModelId {
        print("[SwiftLM] ⚠️ Gemma4 MTP: assistant model '\(asstId)' not found in HF cache. Run: python -m mlx_lm.convert --hf-path \(asstId) to download it.")
    } else {
        print("[SwiftLM] ⚠️ --mtp: model '\(modelId)' is not a known Gemma4 MTP model. MTP requires a Gemma4 assistant checkpoint.")
    }
```
**Key takeaways (4-bit):**
- 🚀 **TurboQuant is the headline win**: At 100K context, `Vanilla + TurboQuant` delivers **66.9 tok/s** vs **27.5 tok/s** Vanilla, a **2.43× speedup**.
- 💾 **Massive memory savings**: OS RAM at 40K context drops from **48.7 GB → 18.2 GB** with TurboQuant (63% reduction).
- ⚡ **MTP neutral on 4-bit MoE**: The 4-bit model is compute-bound (MoE expert dispatch). Batch verification scales linearly with token count, so MTP provides no net throughput gain over vanilla at 4-bit.
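The ratios quoted in these takeaways can be re-derived from the raw figures. The snippet below is just that arithmetic, using only the numbers stated in the bullets; the helper names `speedup` and `reductionPercent` are made up for this sketch.

```swift
// Ratio of new throughput to baseline throughput.
func speedup(_ new: Double, over base: Double) -> Double {
    new / base
}

// Percentage reduction going from `from` down to `to`.
func reductionPercent(from: Double, to: Double) -> Double {
    (from - to) / from * 100
}

// 100K context, 4-bit: 66.9 tok/s with TurboQuant vs 27.5 tok/s vanilla.
let tqSpeedup = speedup(66.9, over: 27.5)

// 40K context, 4-bit: OS RAM 48.7 GB vanilla vs 18.2 GB with TurboQuant.
let ramReduction = reductionPercent(from: 48.7, to: 18.2)

print(tqSpeedup, ramReduction)
```

Both values round to the figures in the README excerpt (about 2.43× and about 63%).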
## Summary
Resolves a throughput regression in Gemma4 MTP speculative decoding at long context lengths, and documents precision-dependent MTP behaviour from empirical benchmarking.

### mlx-swift-lm submodule (→ fix/compiler-warnings-mtp-optim @ c552b4d)

### Benchmark Results (M5 Pro 64 GB, gemma-4-26b-a4b-it-8bit)
8-bit is bandwidth-bound → KV reads amortize across the batch → MTP provides real gains:
- 40K ctx: 38.8 tok/s MTP vs 32.4 vanilla (+20%)
- 100K ctx: 22.5 tok/s MTP vs 14.9 vanilla (+51%)

4-bit MoE is compute-bound; MTP neutral. TQ+MTP counterproductive at both precisions.

### Other

### User Guidance