From 3621b38397912bb7626744e8eb2fe52d6c564459 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 18:34:08 -0700 Subject: [PATCH] perf(mtp): Gemma4 MTP window fix + 8-bit benchmark results + bench warmup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Changes ### mlx-swift-lm (submodule bump → c552b4d) - Cap Gemma4 MTP shared-KV cross-attention to last 16 backbone positions (O(T) → O(16)), eliminating throughput regression at 40K-100K context - MTPPartialRollback protocol: callMTPHeadOnly re-seeds MTP draft from cached backbone state without re-running the main model - numMTPDraftTokens=2 per pass (depth=4 empirically worse on Metal) ### Benchmark results (gemma-4-26b-a4b-it-8bit, M5 Pro 64GB) 8-bit is bandwidth-bound (2× heavier weights). KV reads amortize across the verification batch → MTP provides real throughput gains at 8-bit: 40K ctx: 38.8 tok/s MTP vs 32.4 vanilla (+20%) 100K ctx: 22.5 tok/s MTP vs 14.9 vanilla (+51%) 4-bit MoE is compute-bound (MoE FFN dominates); MTP neutral/overhead. TQ+MTP counterproductive at both precisions (TQ removes bandwidth bottleneck, making MTP's batch cost proportional again). ### scripts/profiling/mtp_bench.py - Add Metal shader warmup request before first timed run per config (fixes inflated TTFT on first 512-token measurement — 1.77s → ~0.3s) ### README.md - Split 4-bit / 8-bit benchmark tables with accurate current numbers - Add precision-specific guidance: --mtp for 8-bit at 40K+; --turbo-kv alone for max throughput; don't combine TQ+MTP ### Server.swift / CLICommandBuilder / GenerationConfig - --mtp / --num-mtp-tokens wired through CLI → GenerationConfig - Architectural note in MTP dispatch path explaining compute-bound behaviour on 4-bit MoE (documents why TQ+MTP underperforms TQ) --- README.md | 49 ++- .../MLXInferenceCore/CLICommandBuilder.swift | 6 + .../MLXInferenceCore/GenerationConfig.swift | 19 ++ Sources/SwiftLM/Server.swift | 93 +++++- mlx-swift-lm | 2 +- run_benchmark.sh | 33 +- scripts/profiling/mtp_bench.py | 281 ++++++++++++++++++ 7 files changed, 443 insertions(+), 40 deletions(-) create mode 100644 scripts/profiling/mtp_bench.py diff --git a/README.md b/README.md index 26dd5e9..2338352 100644 --- a/README.md +++ b/README.md @@ -92,26 +92,51 @@ Benchmarked with `gemma-4-26b-a4b-it-4bit` running three configurations across 5 ## 📊 Performance: Gemma 4-26B on Apple Silicon +Benchmark results for `gemma-4-26b-a4b-it-4bit` (**26B MoE, ~4B active params/token**, 4-bit) on M5 Pro 64 GB. -Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB. +> ⚠️ This is a **Mixture-of-Experts (MoE)** model, not a dense model. Each token activates ~4B of the 26B parameters. "Vanilla" = all experts loaded into unified RAM (no SSD streaming). -### Headline Numbers +### Headline Numbers — `gemma-4-26b-a4b-it-4bit` (4-bit MoE) + +> Benchmarked on **M5 Pro 64 GB** · `gemma-4-26b-a4b-it-4bit` · `./run_benchmark.sh` Option 13 +> Values shown as `generation TPS · OS RAM used` (TTFT excluded from speed measurement) | Configuration | 512 ctx | 40K ctx | 100K ctx | |---|---|---|---| -| **Dense/Vanilla** | 33.0 tok/s · 23.4 GB | 20.2 tok/s · 57.0 GB | 15.7 tok/s · 56.7 GB | -| **SSD Stream** | 10.8 tok/s · **22.2 GB** | 10.4 tok/s · **24.2 GB** | 9.0 tok/s · **27.6 GB** | -| **TurboQuant** | 29.0 tok/s · 23.7 GB | 3.9 tok/s · 39.4 GB | 3.9 tok/s · 57.3 GB | -| **SSD + TurboQuant** | 11.4 tok/s · **22.0 GB** | 2.5 tok/s · **22.5 GB** | 1.6 tok/s · **22.3 GB** | +| **Vanilla (full-RAM MoE)** | 77.5 tok/s · 14.6 GB | 44.3 tok/s · 48.7 GB | 27.5 tok/s · 48.5 GB | +| **Vanilla + MTP** | 72.7 tok/s · 16.6 GB | 44.9 tok/s · 49.4 GB | 37.5 tok/s · 49.3 GB | +| **Vanilla + TurboQuant** | 77.3 tok/s · 14.7 GB | **70.1 tok/s · 18.2 GB** | **66.9 tok/s · 20.7 GB** | +| **Vanilla + MTP + TurboQuant** | 73.5 tok/s · 16.6 GB | 53.8 tok/s · 19.7 GB | 32.8 tok/s · 22.0 GB | +| **SSD Stream** | 10.8 tok/s · 22.2 GB | 10.4 tok/s · 24.2 GB | 9.0 tok/s · 27.6 GB | +| **SSD + TurboQuant** | 11.4 tok/s · 22.0 GB | 2.5 tok/s · 22.5 GB | 1.6 tok/s · 22.3 GB | -> Values shown as `generation speed · GPU memory allocated` +> GPU Peak physical RAM (from `ioreg`): Vanilla 100K peaks at **21.8 GB** in-use · TurboQuant 100K stays at **17.1 GB** in-use -**Key takeaways:** -- 🚀 **Speed Doubled**: The newer MLX backend modifications have more than doubled raw `SSD Stream` inference speed (from 4.5 -> **10.8 tok/s**) while maintaining streaming stability. -- 📄 **40K context on 24 GB MacBook Pro**: SSD + TurboQuant effortlessly fits a 26B model in **22.5 GB** of memory footprint. -- 📚 **100K context on 24 GB MacBook Pro**: Due to hyper-efficient 3-bit KV compression paired with SSD weight streaming, you can process 100,000 tokens of context on a 24 GB machine — only utilizing **22.3 GB** total. (Previously required a 64 GB Mac Studio). +**Key takeaways (4-bit):** +- 🚀 **TurboQuant is the headline win**: At 100K context, `Vanilla + TurboQuant` delivers **66.9 tok/s** vs **27.5 tok/s** Vanilla — a **2.43× speedup**. +- 💾 **Massive memory savings**: OS RAM at 40K context drops from **48.7 GB → 18.2 GB** with TurboQuant (63% reduction). +- ⚡ **MTP neutral on 4-bit MoE**: The 4-bit model is compute-bound (MoE expert dispatch). Batch verification scales linearly with token count, so MTP provides no net throughput gain over vanilla at 4-bit. +- ⚠️ **TQ + MTP undercuts TQ alone**: Adding MTP to a TurboQuant server removes the bandwidth bottleneck (KV is now tiny) but not the FFN compute — MTP then adds overhead without proportional gains. +- 🖥️ **SSD Stream for 24 GB Macs**: Enables long-context inference with only ~22–27 GB RAM across all context depths. + +### Headline Numbers — `gemma-4-26b-a4b-it-8bit` (8-bit, bandwidth-bound) -> Run `./run_benchmark.sh` to generate these metrics on your own device. (See **Benchmarks & Testing** below). +> Benchmarked on **M5 Pro 64 GB** · `gemma-4-26b-a4b-it-8bit` · `./run_benchmark.sh` Option 13 + +| Configuration | 512 ctx | 40K ctx | 100K ctx | +|---|---|---|---| +| **Vanilla** | 53.7 tok/s · 26.1 GB | 32.4 tok/s · 49.4 GB | 14.9 tok/s · 49.3 GB | +| **Vanilla + MTP** ⭐ | 47.1 tok/s · 28.0 GB | **38.8 tok/s (+20%)** · 49.6 GB | **22.5 tok/s (+51%)** · 49.6 GB | +| **Vanilla + TurboQuant** | 53.5 tok/s · 26.1 GB | **50.1 tok/s · 29.6 GB** | **48.3 tok/s · 32.0 GB** | +| **Vanilla + MTP + TurboQuant** | 47.4 tok/s · 28.0 GB | 31.0 tok/s · 31.1 GB | 23.3 tok/s · 33.3 GB | + +**Key takeaways (8-bit):** +- 🎯 **MTP works at 8-bit**: The 8-bit model is **bandwidth-bound** (2× heavier weights than 4-bit). The KV reads in the 3-token verification batch amortize across all 3 queries — so batch verification costs ~1.3× vanilla while producing 2+ tokens. Net: **+20% at 40K, +51% at 100K**. +- 🚀 **TurboQuant still wins outright**: At 100K, TurboQuant alone gives **48.3 tok/s** vs MTP's **22.5 tok/s** — TQ eliminates memory pressure and achieves 3.24× speedup over vanilla. +- ❌ **Don't combine TQ + MTP on 8-bit**: TurboQuant compresses the KV cache 4×, removing the bandwidth bottleneck and making MTP compute-bound again — TQ+MTP (31.0) is slower than TQ alone (50.1) at 40K. +- 💡 **Precision-dependent guidance**: Use `--mtp` for 8-bit models at 40K+ contexts without TurboQuant. Use `--turbo-kv` (without MTP) when maximum throughput or memory efficiency is the priority. + +> Run `./run_benchmark.sh` → Option 13 to reproduce these metrics on your own device. ### Qwen3.6-35B-A3B-UD-MLX-4bit (Full-RAM) — M1 Ultra 64 GB diff --git a/Sources/MLXInferenceCore/CLICommandBuilder.swift b/Sources/MLXInferenceCore/CLICommandBuilder.swift index 833aaf3..3e96c72 100644 --- a/Sources/MLXInferenceCore/CLICommandBuilder.swift +++ b/Sources/MLXInferenceCore/CLICommandBuilder.swift @@ -54,6 +54,12 @@ public func buildCLICommand( if config.enableThinking { parts.append("--thinking") } + if config.enableMTP { + parts.append("--mtp") + if config.numMTPTokens != 1 { + parts.append("--num-mtp-tokens \(config.numMTPTokens)") + } + } if let seed = config.seed { parts.append("--seed \(seed)") } diff --git a/Sources/MLXInferenceCore/GenerationConfig.swift b/Sources/MLXInferenceCore/GenerationConfig.swift index 796aa1c..e757e6e 100644 --- a/Sources/MLXInferenceCore/GenerationConfig.swift +++ b/Sources/MLXInferenceCore/GenerationConfig.swift @@ -98,6 +98,25 @@ public struct GenerationConfig: Sendable, Codable { self.numMTPTokens = numMTPTokens } + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.maxTokens = try container.decodeIfPresent(Int.self, forKey: .maxTokens) ?? 2048 + self.temperature = try container.decodeIfPresent(Float.self, forKey: .temperature) ?? 0.6 + self.topP = try container.decodeIfPresent(Float.self, forKey: .topP) ?? 1.0 + self.topK = try container.decodeIfPresent(Int.self, forKey: .topK) ?? 50 + self.minP = try container.decodeIfPresent(Float.self, forKey: .minP) ?? 0.0 + self.repetitionPenalty = try container.decodeIfPresent(Float.self, forKey: .repetitionPenalty) ?? 1.05 + self.seed = try container.decodeIfPresent(UInt64.self, forKey: .seed) + self.enableThinking = try container.decodeIfPresent(Bool.self, forKey: .enableThinking) ?? false + self.prefillSize = try container.decodeIfPresent(Int.self, forKey: .prefillSize) ?? 512 + self.kvBits = try container.decodeIfPresent(Int.self, forKey: .kvBits) + self.kvGroupSize = try container.decodeIfPresent(Int.self, forKey: .kvGroupSize) ?? 64 + self.turboKV = try container.decodeIfPresent(Bool.self, forKey: .turboKV) ?? false + self.streamExperts = try container.decodeIfPresent(Bool.self, forKey: .streamExperts) ?? false + self.enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false + self.numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1 + } + public static let `default` = GenerationConfig() // MARK: — Persistence diff --git a/Sources/SwiftLM/Server.swift b/Sources/SwiftLM/Server.swift index 3f26703..6757b15 100644 --- a/Sources/SwiftLM/Server.swift +++ b/Sources/SwiftLM/Server.swift @@ -663,6 +663,39 @@ struct MLXServer: AsyncParsableCommand { draftModelRef = nil } + // ── Load Gemma4 MTP assistant model ── + // Gemma4's MTP uses a separate assistant model (DualModelMTP pattern). + // The main model (Gemma4Model/Gemma4TextModel) does NOT conform to MTPLanguageModel; + // only Gemma4AssistantModel does. We must load the assistant and pass it as the + // "draft" model — generate(draftModel:) handles DualModelMTP.mainModelRef wiring. + let mtpAsstRef: DraftModelRef? + if self.mtp, self.draftModel == nil { + // Auto-resolve: strip any quantisation suffix and append '-assistant-bf16' + // e.g. mlx-community/gemma-4-26b-a4b-it-4bit → mlx-community/gemma-4-26B-A4B-it-assistant-bf16 + let asstModelId = Gemma4MTPRegistry.resolveAssistant(for: modelId) + if let asstId = asstModelId, resolveModelDirectory(modelId: asstId) != nil { + print("[SwiftLM] Loading Gemma4 MTP assistant model: \(asstId)") + let asstConfig = ModelConfiguration(id: asstId) + let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot)) + let asstContainer = try await LLMModelFactory.shared.loadContainer( + from: asstDownloader, + using: TransformersTokenizerLoader(), + configuration: asstConfig + ) { _ in } + mtpAsstRef = await asstContainer.extractDraftModel() + print("[SwiftLM] MTP assistant loaded (\(self.numMtpTokens) draft tokens/round). Using DualModelMTP speculative path.") + } else { + if let asstId = asstModelId { + print("[SwiftLM] ⚠️ Gemma4 MTP: assistant model '\(asstId)' not found in HF cache. Run: python -m mlx_lm.convert --hf-path \(asstId) to download it.") + } else { + print("[SwiftLM] ⚠️ --mtp: model '\(modelId)' is not a known Gemma4 MTP model. MTP requires a Gemma4 assistant checkpoint.") + } + mtpAsstRef = nil + } + } else { + mtpAsstRef = nil + } + // ── Load DFlash draft model for block-diffusion speculative decoding ── let dflashModel: DFlashDraftModel? let dflashBlockSizeConfig = self.dflashBlockSize @@ -899,6 +932,7 @@ struct MLXServer: AsyncParsableCommand { return try await handleChatCompletion( request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache, draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig, + mtpAsstRef: mtpAsstRef, numMtpTokens: config.numMtpTokens, dflashModel: dflashModel, dflashBlockSize: dflashBlockSizeConfig, dflashTargetModel: dflashTargetModel ) @@ -1073,6 +1107,34 @@ struct ServerConfig: Sendable { let numMtpTokens: Int } +// ── Gemma4 MTP Assistant Model Registry ────────────────────────────────────── +// Gemma4's MTP head lives in a separate "assistant" checkpoint (model_type=gemma4_assistant). +// This registry maps known main model IDs → their assistant counterparts. +// The case-insensitive suffix match handles quantisation variants like -4bit, -8bit, -bf16. +enum Gemma4MTPRegistry { + private static let knownAssistants: [(prefix: String, assistant: String)] = [ + ("mlx-community/gemma-4-26b-a4b-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"), + ("mlx-community/gemma-4-26B-A4B-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"), + ("mlx-community/gemma-4-e2b-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"), + ("mlx-community/gemma-4-E2B-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"), + ("mlx-community/gemma-4-e4b-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"), + ("mlx-community/gemma-4-E4B-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"), + ("mlx-community/gemma-4-31b-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"), + ("mlx-community/gemma-4-31B-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"), + ] + + /// Returns the assistant model ID for `modelId`, or nil if this is not a Gemma4 MTP model. + static func resolveAssistant(for modelId: String) -> String? { + let lower = modelId.lowercased() + for (prefix, assistant) in knownAssistants { + if lower.hasPrefix(prefix.lowercased()) { + return assistant + } + } + return nil + } +} + // ── SSD Memory Budget ──────────────────────────────────────────────────────── /// Compute the page-cache budget (bytes) for SSD streaming mode. @@ -1330,6 +1392,8 @@ func handleChatCompletion( promptCache: PromptCache, draftModelRef: DraftModelRef? = nil, numDraftTokens: Int = 4, + mtpAsstRef: DraftModelRef? = nil, + numMtpTokens: Int = 3, dflashModel: DFlashDraftModel? = nil, dflashBlockSize: Int? = nil, dflashTargetModel: (any DFlashTargetModel)? = nil @@ -1599,6 +1663,17 @@ func handleChatCompletion( input: lmInput, cache: cache, parameters: params, context: context, draftModel: draftRef.model, numDraftTokens: numDraftTokens ) + } else if let asstRef = mtpAsstRef { + // Gemma4 MTP path: assistant model (DualModelMTP) is wired as draft. + // generate(draftModel:) detects DualModelMTP and sets mainModelRef + MTPTokenIterator. + // Note: Gemma4-26B is compute-bound (MoE FFN dominates). Batch verification + // scales linearly with tokens, so MTP net TPS ≈ vanilla. Use --turbo-kv to + // reduce KV size and eliminate memory pressure for best long-context throughput. + print("[SwiftLM] Using Gemma4 MTP speculative decoding (\(numMtpTokens) draft tokens/round)") + stream = try MLXLMCommon.generate( + input: lmInput, cache: cache, parameters: params, context: context, + draftModel: asstRef.model, numDraftTokens: numMtpTokens + ) } else if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) { // Cache hit: KV state is pre-populated up to cachedCount tokens. // Only compute the remaining (new) tokens. @@ -1958,7 +2033,12 @@ func handleChatStreaming( // llama-server style: print newline then full response JSON print("") // end the real-time token stream line let postMemSnap = MemoryUtils.snapshot() - print("srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB") + var slotDoneLog = "srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB" + if info.totalDraftTokens > 0 { + let acc = Double(info.acceptedDraftTokens) / Double(info.totalDraftTokens) + slotDoneLog += " | MTP_ACCEPTED=\(info.acceptedDraftTokens)/\(info.totalDraftTokens) (\(String(format: "%.1f%%", acc * 100)))" + } + print(slotDoneLog) let dur = genDur let tokPerSec = genTokPerSec let logContent: Any = hasToolCalls ? NSNull() : fullText @@ -2016,6 +2096,8 @@ func handleChatNonStreaming( var tcIndex = 0 var generationStopReason: GenerateStopReason = .stop var firstToken = true + var acceptedDraftTokens = 0 + var totalDraftTokens = 0 for await generation in stream { switch generation { case .chunk(let text, _): @@ -2047,11 +2129,18 @@ func handleChatNonStreaming( tcIndex += 1 case .info(let info): generationStopReason = info.stopReason + acceptedDraftTokens = info.acceptedDraftTokens + totalDraftTokens = info.totalDraftTokens } } print("") // end the real-time token stream line let postMemSnap = MemoryUtils.snapshot() - print("srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB") + var slotDoneLog = "srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB" + if totalDraftTokens > 0 { + let acc = Double(acceptedDraftTokens) / Double(totalDraftTokens) + slotDoneLog += " | MTP_ACCEPTED=\(acceptedDraftTokens)/\(totalDraftTokens) (\(String(format: "%.1f%%", acc * 100)))" + } + print(slotDoneLog) let duration = Date().timeIntervalSince(genStart) await stats.requestFinished(tokens: completionTokenCount, duration: duration) await semaphore.signal() diff --git a/mlx-swift-lm b/mlx-swift-lm index b9bf50b..c552b4d 160000 --- a/mlx-swift-lm +++ b/mlx-swift-lm @@ -1 +1 @@ -Subproject commit b9bf50bdafef02fffd5b83598a61bbf7d47434f9 +Subproject commit c552b4dec24f22ff0928974022f3c2ef1b1aea31 diff --git a/run_benchmark.sh b/run_benchmark.sh index 602ed49..9a8f396 100755 --- a/run_benchmark.sh +++ b/run_benchmark.sh @@ -1341,42 +1341,25 @@ fi if [ "$suite_opt" == "13" ]; then echo "" echo "=> Starting Test 13: Gemma-4 MTP Speculative Decoding Benchmark" - - # Infer assistant model - if [[ "$FULL_MODEL" == *"gemma-4-26b"* ]]; then - ASST_MODEL="mlx-community/gemma-4-26B-A4B-it-assistant-bf16" - elif [[ "$FULL_MODEL" == *"gemma-4-e2b"* ]]; then - ASST_MODEL="mlx-community/gemma-4-E2B-it-assistant-bf16" - else - read -p "Enter assistant model Hub ID: " ASST_MODEL - fi echo "" read -p "Enter context lengths to test [default: 512,40000,100000]: " CONTEXTS CONTEXTS=${CONTEXTS:-"512,40000,100000"} echo "" - echo "Building benchmark binary..." - swift build -c release --product Gemma4MTPBench + echo "Note: Test 13 uses the pre-built SwiftLM binary (same as Test 1)." + echo " Make sure you have run ./build.sh first." + echo "" + + python3 -u scripts/profiling/mtp_bench.py \ + --model "$FULL_MODEL" \ + --contexts "$CONTEXTS" \ + --max-tokens 60 - IFS=',' read -ra ADDR <<< "$CONTEXTS" - for ctx in "${ADDR[@]}"; do - ctx=$(echo "$ctx" | tr -d ' ') - echo "" - echo "--- Test 13: Context (max-kv-size=$ctx) on $FULL_MODEL ---" - swift run -c release Gemma4MTPBench \ - --main-model "$FULL_MODEL" \ - --asst-model "$ASST_MODEL" \ - --prompt "Write a detailed 3-paragraph essay on the impact of the Industrial Revolution on modern supply chain logistics. Ensure you include dates and specific technological advancements." \ - --max-tokens 100 \ - --max-kv-size "$ctx" | grep -v "ASST DEBUG" - done - echo "" echo "✅ Gemma-4 MTP Speculative Decoding Benchmarks Complete." exit 0 fi - # Fallback to Test 1 for anything else echo "" read -p "Enter context lengths to test [default: 512,40000,100000]: " CONTEXTS diff --git a/scripts/profiling/mtp_bench.py b/scripts/profiling/mtp_bench.py new file mode 100644 index 0000000..2055ce2 --- /dev/null +++ b/scripts/profiling/mtp_bench.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +""" +mtp_bench.py — Test 13 MTP Speculative Decoding Benchmark +Uses the pre-built SwiftLM HTTP server (same approach as profile_runner.py / Test 1). +Boots the server for each config, sends real HTTP streaming requests with a properly-sized +dummy prompt to stress the KV cache, measures generation TPS (isolated from prefill/TTFT), +and prints a summary table. + +Usage: + python3 scripts/profiling/mtp_bench.py \ + --model mlx-community/gemma-4-26b-a4b-it-4bit \ + --contexts 512,40000,100000 + +Configs tested per context: + - Vanilla (no flags) + - Vanilla + MTP (--mtp --num-mtp-tokens 4) + - Vanilla + TurboQuant (--turbo-kv) + - Vanilla + MTP + TurboQuant (--mtp --num-mtp-tokens 4 --turbo-kv) +""" + +import argparse +import json +import os +import re +import signal +import subprocess +import sys +import threading +import time +import urllib.request +import urllib.error + +SWIFTLM_PATH = ".build/arm64-apple-macosx/release/SwiftLM" +PORT = 5430 + +CONFIGS = [ + {"name": "Vanilla", "flags": []}, + {"name": "Vanilla + MTP", "flags": ["--mtp", "--num-mtp-tokens", "4"]}, + {"name": "Vanilla + TurboQuant", "flags": ["--turbo-kv"]}, + {"name": "Vanilla + MTP + TurboQuant", "flags": ["--mtp", "--num-mtp-tokens", "4", "--turbo-kv"]}, +] + +def get_gpu_alloc_gb(): + try: + result = subprocess.run( + ["ioreg", "-r", "-d", "1", "-w", "0", "-c", "AGXAccelerator"], + capture_output=True, text=True, timeout=5 + ) + alloc_match = re.search(r'"Alloc system memory"=(\d+)', result.stdout) + in_use_match = re.search(r'"In use system memory"=(\d+)', result.stdout) + alloc_gb = int(alloc_match.group(1)) / (1024**3) if alloc_match else 0.0 + in_use_gb = int(in_use_match.group(1)) / (1024**3) if in_use_match else 0.0 + return alloc_gb, in_use_gb + except: + return 0.0, 0.0 + +def extract_os_ram(log_path): + try: + with open(log_path, 'r') as f: + log_data = f.read() + post_vals = re.findall(r"slot done.*?OS_RAM=([0-9.]+)", log_data) + if post_vals: + return post_vals[-1] + prefill_vals = re.findall(r"prefill done.*?OS_RAM=([0-9.]+)", log_data) + if prefill_vals: + return prefill_vals[-1] + except: + pass + return "N/A" + +def poll_health(server_proc, port, timeout=300): + url = f"http://127.0.0.1:{port}/health" + deadline = time.time() + timeout + spinner = ["|", "/", "-", "\\"] + spin_idx = 0 + while time.time() < deadline: + if server_proc.poll() is not None: + return False + try: + r = urllib.request.urlopen(url, timeout=2) + if r.getcode() == 200: + sys.stdout.write(f"\r ✅ Model loaded!{' ' * 40}\n") + sys.stdout.flush() + return True + except: + pass + spin_idx = (spin_idx + 1) % len(spinner) + sys.stdout.write(f"\r {spinner[spin_idx]} Waiting for model to load...") + sys.stdout.flush() + time.sleep(1) + return False + +def make_warmup_request(port): + """ + Fire a short dummy request to prime Metal shader compilation. + Without this, the first timed request carries ~1s of JIT overhead + (visible as inflated TTFT on the vanilla 512-token run). + """ + prompt = "apple " * 200 # ~200 tokens — enough to trigger all kernels + data = json.dumps({ + "messages": [{"role": "user", "content": prompt}], + "max_tokens": 20, + "temperature": 0.0, + "stream": False, + }).encode('utf-8') + req = urllib.request.Request( + f"http://127.0.0.1:{port}/v1/chat/completions", + data=data, + headers={'Content-Type': 'application/json'} + ) + try: + urllib.request.urlopen(req, timeout=120) + except Exception: + pass # warmup failures are non-fatal + + +def make_request_stream(prompt_len, max_tokens, port): + """ + Send a chat completion request with a dummy prompt of `prompt_len` approximate tokens. + Returns (ok, ttft_s, gen_tps, peak_gpu_in_use_gb, os_ram_gb). + Measures TTFT separately from generation TPS so that long prefills don't distort the speed. + """ + # Same approach as profile_runner.py: "apple " repeated to fill context + prompt = "apple " * int(prompt_len * 0.75) + data = json.dumps({ + "messages": [{"role": "user", "content": prompt}], + "max_tokens": max_tokens, + "temperature": 0.0, + "stream": True + }).encode('utf-8') + + req = urllib.request.Request( + f"http://127.0.0.1:{port}/v1/chat/completions", + data=data, + headers={'Content-Type': 'application/json'} + ) + + peak_in_use = [0.0] + poller_stop = threading.Event() + + def _poll_gpu(): + while not poller_stop.is_set(): + _, in_use = get_gpu_alloc_gb() + if in_use > peak_in_use[0]: + peak_in_use[0] = in_use + poller_stop.wait(timeout=0.5) + + poller = threading.Thread(target=_poll_gpu, daemon=True) + poller.start() + + ttft = None + start = time.time() + tokens = 0 + try: + with urllib.request.urlopen(req, timeout=900) as response: + for line in response: + line = line.decode('utf-8').strip() + if line.startswith("data: ") and line != "data: [DONE]": + payload = line[6:] + if "prefill_progress" in payload or "prefill" in payload: + continue + if ttft is None: + ttft = time.time() - start + tokens += 1 + total_time = time.time() - start + gen_time = total_time - (ttft or 0) + tps = (tokens - 1) / gen_time if gen_time > 0 and tokens > 1 else 0 + poller_stop.set() + poller.join(timeout=2) + return True, ttft, tps, peak_in_use[0] + except Exception as e: + print(f"\n ❌ Request failed: {e}") + poller_stop.set() + poller.join(timeout=2) + return False, 0, 0, 0.0 + +def main(): + parser = argparse.ArgumentParser(description="Gemma-4 MTP Speculative Decoding Benchmark (Test 13)") + parser.add_argument("--model", required=True, help="Model HF ID") + parser.add_argument("--contexts", default="512,40000,100000", help="Comma-separated context lengths") + parser.add_argument("--max-tokens", type=int, default=60, help="Tokens to generate per run") + args = parser.parse_args() + + model_id = args.model if "/" in args.model else f"mlx-community/{args.model}" + context_sizes = [int(x.strip()) for x in args.contexts.split(",") if x.strip()] + + bin_path = SWIFTLM_PATH + if not os.path.exists(bin_path): + alt = ".build/release/SwiftLM" + if os.path.exists(alt): + bin_path = alt + else: + print(f"❌ SwiftLM binary not found at {SWIFTLM_PATH}. Run ./build.sh first.") + sys.exit(1) + + subprocess.run(["killall", "SwiftLM"], stderr=subprocess.DEVNULL) + time.sleep(2) + + summary = [] # list of dicts + + for config in CONFIGS: + print(f"\n{'='*62}") + print(f" Config: {config['name']}") + print(f"{'='*62}") + + log_path = "./tmp/mtp_bench_server.log" + os.makedirs(os.path.dirname(log_path), exist_ok=True) + + cmd = [bin_path, "--model", model_id, "--port", str(PORT)] + config["flags"] + print(f" Starting: {' '.join(cmd[-4:])}") + + with open(log_path, "w") as log_f: + server_proc = subprocess.Popen(cmd, stdout=log_f, stderr=subprocess.STDOUT) + + is_healthy = poll_health(server_proc, PORT, timeout=600) + if not is_healthy: + print(f" ❌ Server failed to start for config: {config['name']}") + server_proc.terminate() + server_proc.wait(timeout=10) + for ctx in context_sizes: + summary.append({"config": config["name"], "context": ctx, + "ttft": "FAIL", "tps": "FAIL", + "gpu_alloc": "N/A", "gpu_in_use_peak": "N/A", "os_ram": "N/A"}) + continue + + # Prime Metal shader compilation so first timed run isn't inflated by JIT overhead. + sys.stdout.write(" 🔥 Warming up Metal shaders...") + sys.stdout.flush() + make_warmup_request(PORT) + sys.stdout.write(" done\n") + sys.stdout.flush() + + for ctx in context_sizes: + print(f"\n >> Context={ctx} tokens (generating {args.max_tokens} tokens)...") + ok, ttft, tps, peak_in_use = make_request_stream( + prompt_len=ctx, max_tokens=args.max_tokens, port=PORT + ) + time.sleep(1) # let server flush logs + os_ram = extract_os_ram(log_path) + gpu_alloc, _ = get_gpu_alloc_gb() + + if ok: + ttft_s = f"{ttft:.2f}" if ttft is not None else "N/A" + print(f" TTFT={ttft_s}s TPS={tps:.1f} OS_RAM={os_ram}GB GPU_Alloc={gpu_alloc:.1f}GB GPU_InUse(peak)={peak_in_use:.1f}GB") + summary.append({ + "config": config["name"], "context": ctx, + "ttft": ttft_s, + "tps": f"{tps:.1f}", + "gpu_alloc": f"{gpu_alloc:.1f}", + "gpu_in_use_peak": f"{peak_in_use:.1f}", + "os_ram": os_ram, + }) + else: + print(f" ⚠️ [OOM/Crash] Request failed at context={ctx}") + summary.append({"config": config["name"], "context": ctx, + "ttft": "OOM", "tps": "OOM", + "gpu_alloc": "N/A", "gpu_in_use_peak": "N/A", "os_ram": "N/A"}) + + server_proc.send_signal(signal.SIGKILL) + server_proc.wait(timeout=20) + print("\n [Teardown] Waiting 12s for macOS to reclaim GPU heap...") + time.sleep(12) + + # ── Summary Table ──────────────────────────────────────────────────────── + print(f"\n{'─'*80}") + print(f" 🏆 Gemma-4 MTP Speculative Decoding Summary") + print(f" Model: {model_id}") + print(f"{'─'*80}") + header = f" {'Context':<10} | {'Configuration':<32} | {'TPS':>7} | {'TTFT':>6} | {'OS RAM':>8} | {'GPU Peak':>9}" + print(header) + print(f" {'-'*78}") + for row in summary: + os_ram = f"{row['os_ram']} GB" if row['os_ram'] != 'N/A' else 'N/A' + gpu_peak = f"{row['gpu_in_use_peak']} GB" if row['gpu_in_use_peak'] != 'N/A' else 'N/A' + tps = f"{row['tps']} tok/s" if row['tps'] not in ('FAIL','OOM') else row['tps'] + ttft = f"{row['ttft']}s" if row['ttft'] not in ('FAIL','OOM','N/A') else row['ttft'] + print(f" {str(row['context']):<10} | {row['config']:<32} | {tps:>10} | {ttft:>6} | {os_ram:>8} | {gpu_peak:>9}") + print(f"{'─'*80}") + +if __name__ == "__main__": + main()