Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 37 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,26 +92,51 @@ Benchmarked with `gemma-4-26b-a4b-it-4bit` running three configurations across 5

## 📊 Performance: Gemma 4-26B on Apple Silicon

Benchmark results for `gemma-4-26b-a4b-it-4bit` (**26B MoE, ~4B active params/token**, 4-bit) on M5 Pro 64 GB.

Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB.
> ⚠️ This is a **Mixture-of-Experts (MoE)** model, not a dense model. Each token activates ~4B of the 26B parameters. "Vanilla" = all experts loaded into unified RAM (no SSD streaming).

### Headline Numbers
### Headline Numbers — `gemma-4-26b-a4b-it-4bit` (4-bit MoE)

> Benchmarked on **M5 Pro 64 GB** · `gemma-4-26b-a4b-it-4bit` · `./run_benchmark.sh` Option 13
> Values shown as `generation TPS · OS RAM used` (TTFT excluded from speed measurement)

| Configuration | 512 ctx | 40K ctx | 100K ctx |
|---|---|---|---|
| **Dense/Vanilla** | 33.0 tok/s · 23.4 GB | 20.2 tok/s · 57.0 GB | 15.7 tok/s · 56.7 GB |
| **SSD Stream** | 10.8 tok/s · **22.2 GB** | 10.4 tok/s · **24.2 GB** | 9.0 tok/s · **27.6 GB** |
| **TurboQuant** | 29.0 tok/s · 23.7 GB | 3.9 tok/s · 39.4 GB | 3.9 tok/s · 57.3 GB |
| **SSD + TurboQuant** | 11.4 tok/s · **22.0 GB** | 2.5 tok/s · **22.5 GB** | 1.6 tok/s · **22.3 GB** |
| **Vanilla (full-RAM MoE)** | 77.5 tok/s · 14.6 GB | 44.3 tok/s · 48.7 GB | 27.5 tok/s · 48.5 GB |
| **Vanilla + MTP** | 72.7 tok/s · 16.6 GB | 44.9 tok/s · 49.4 GB | 37.5 tok/s · 49.3 GB |
| **Vanilla + TurboQuant** | 77.3 tok/s · 14.7 GB | **70.1 tok/s · 18.2 GB** | **66.9 tok/s · 20.7 GB** |
| **Vanilla + MTP + TurboQuant** | 73.5 tok/s · 16.6 GB | 53.8 tok/s · 19.7 GB | 32.8 tok/s · 22.0 GB |
| **SSD Stream** | 10.8 tok/s · 22.2 GB | 10.4 tok/s · 24.2 GB | 9.0 tok/s · 27.6 GB |
| **SSD + TurboQuant** | 11.4 tok/s · 22.0 GB | 2.5 tok/s · 22.5 GB | 1.6 tok/s · 22.3 GB |

> Values shown as `generation speed · GPU memory allocated`
> GPU Peak physical RAM (from `ioreg`): Vanilla 100K peaks at **21.8 GB** in-use · TurboQuant 100K stays at **17.1 GB** in-use

**Key takeaways:**
- 🚀 **Speed Doubled**: The newer MLX backend modifications have more than doubled raw `SSD Stream` inference speed (from 4.5 -> **10.8 tok/s**) while maintaining streaming stability.
- 📄 **40K context on 24 GB MacBook Pro**: SSD + TurboQuant effortlessly fits a 26B model in **22.5 GB** of memory footprint.
- 📚 **100K context on 24 GB MacBook Pro**: Due to hyper-efficient 3-bit KV compression paired with SSD weight streaming, you can process 100,000 tokens of context on a 24 GB machine — only utilizing **22.3 GB** total. (Previously required a 64 GB Mac Studio).
**Key takeaways (4-bit):**
- 🚀 **TurboQuant is the headline win**: At 100K context, `Vanilla + TurboQuant` delivers **66.9 tok/s** vs **27.5 tok/s** Vanilla — a **2.43× speedup**.
- 💾 **Massive memory savings**: OS RAM at 40K context drops from **48.7 GB → 18.2 GB** with TurboQuant (63% reduction).
- ⚡ **MTP neutral on 4-bit MoE**: The 4-bit model is compute-bound (MoE expert dispatch). Batch verification scales linearly with token count, so MTP provides no net throughput gain over vanilla at 4-bit.
- ⚠️ **TQ + MTP undercuts TQ alone**: Adding MTP to a TurboQuant server removes the bandwidth bottleneck (KV is now tiny) but not the FFN compute — MTP then adds overhead without proportional gains.
- 🖥️ **SSD Stream for 24 GB Macs**: Enables long-context inference with only ~22–27 GB RAM across all context depths.

### Headline Numbers — `gemma-4-26b-a4b-it-8bit` (8-bit, bandwidth-bound)

> Run `./run_benchmark.sh` to generate these metrics on your own device. (See **Benchmarks & Testing** below).
> Benchmarked on **M5 Pro 64 GB** · `gemma-4-26b-a4b-it-8bit` · `./run_benchmark.sh` Option 13

| Configuration | 512 ctx | 40K ctx | 100K ctx |
|---|---|---|---|
| **Vanilla** | 53.7 tok/s · 26.1 GB | 32.4 tok/s · 49.4 GB | 14.9 tok/s · 49.3 GB |
| **Vanilla + MTP** ⭐ | 47.1 tok/s · 28.0 GB | **38.8 tok/s (+20%)** · 49.6 GB | **22.5 tok/s (+51%)** · 49.6 GB |
| **Vanilla + TurboQuant** | 53.5 tok/s · 26.1 GB | **50.1 tok/s · 29.6 GB** | **48.3 tok/s · 32.0 GB** |
| **Vanilla + MTP + TurboQuant** | 47.4 tok/s · 28.0 GB | 31.0 tok/s · 31.1 GB | 23.3 tok/s · 33.3 GB |

**Key takeaways (8-bit):**
- 🎯 **MTP works at 8-bit**: The 8-bit model is **bandwidth-bound** (2× heavier weights than 4-bit). The KV reads in the 3-token verification batch amortize across all 3 queries — so batch verification costs ~1.3× vanilla while producing 2+ tokens. Net: **+20% at 40K, +51% at 100K**.
- 🚀 **TurboQuant still wins outright**: At 100K, TurboQuant alone gives **48.3 tok/s** vs MTP's **22.5 tok/s** — TQ eliminates memory pressure and achieves 3.24× speedup over vanilla.
- ❌ **Don't combine TQ + MTP on 8-bit**: TurboQuant compresses the KV cache 4×, removing the bandwidth bottleneck and making MTP compute-bound again — TQ+MTP (31.0) is slower than TQ alone (50.1) at 40K.
- 💡 **Precision-dependent guidance**: Use `--mtp` for 8-bit models at 40K+ contexts without TurboQuant. Use `--turbo-kv` (without MTP) when maximum throughput or memory efficiency is the priority.

> Run `./run_benchmark.sh` → Option 13 to reproduce these metrics on your own device.

### Qwen3.6-35B-A3B-UD-MLX-4bit (Full-RAM) — M1 Ultra 64 GB

Expand Down
6 changes: 6 additions & 0 deletions Sources/MLXInferenceCore/CLICommandBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ public func buildCLICommand(
if config.enableThinking {
parts.append("--thinking")
}
if config.enableMTP {
parts.append("--mtp")
if config.numMTPTokens != 1 {
parts.append("--num-mtp-tokens \(config.numMTPTokens)")
}
Comment on lines +57 to +61
}
if let seed = config.seed {
parts.append("--seed \(seed)")
}
Expand Down
19 changes: 19 additions & 0 deletions Sources/MLXInferenceCore/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,25 @@ public struct GenerationConfig: Sendable, Codable {
self.numMTPTokens = numMTPTokens
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
self.maxTokens = try container.decodeIfPresent(Int.self, forKey: .maxTokens) ?? 2048
self.temperature = try container.decodeIfPresent(Float.self, forKey: .temperature) ?? 0.6
self.topP = try container.decodeIfPresent(Float.self, forKey: .topP) ?? 1.0
self.topK = try container.decodeIfPresent(Int.self, forKey: .topK) ?? 50
self.minP = try container.decodeIfPresent(Float.self, forKey: .minP) ?? 0.0
self.repetitionPenalty = try container.decodeIfPresent(Float.self, forKey: .repetitionPenalty) ?? 1.05
self.seed = try container.decodeIfPresent(UInt64.self, forKey: .seed)
self.enableThinking = try container.decodeIfPresent(Bool.self, forKey: .enableThinking) ?? false
self.prefillSize = try container.decodeIfPresent(Int.self, forKey: .prefillSize) ?? 512
self.kvBits = try container.decodeIfPresent(Int.self, forKey: .kvBits)
self.kvGroupSize = try container.decodeIfPresent(Int.self, forKey: .kvGroupSize) ?? 64
self.turboKV = try container.decodeIfPresent(Bool.self, forKey: .turboKV) ?? false
self.streamExperts = try container.decodeIfPresent(Bool.self, forKey: .streamExperts) ?? false
self.enableMTP = try container.decodeIfPresent(Bool.self, forKey: .enableMTP) ?? false
self.numMTPTokens = try container.decodeIfPresent(Int.self, forKey: .numMTPTokens) ?? 1
Comment on lines +116 to +117
}

public static let `default` = GenerationConfig()

// MARK: — Persistence
Expand Down
93 changes: 91 additions & 2 deletions Sources/SwiftLM/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,39 @@ struct MLXServer: AsyncParsableCommand {
draftModelRef = nil
}

// ── Load Gemma4 MTP assistant model ──
// Gemma4's MTP uses a separate assistant model (DualModelMTP pattern).
// The main model (Gemma4Model/Gemma4TextModel) does NOT conform to MTPLanguageModel;
// only Gemma4AssistantModel does. We must load the assistant and pass it as the
// "draft" model — generate(draftModel:) handles DualModelMTP.mainModelRef wiring.
let mtpAsstRef: DraftModelRef?
if self.mtp, self.draftModel == nil {
// Auto-resolve: strip any quantisation suffix and append '-assistant-bf16'
// e.g. mlx-community/gemma-4-26b-a4b-it-4bit → mlx-community/gemma-4-26B-A4B-it-assistant-bf16
let asstModelId = Gemma4MTPRegistry.resolveAssistant(for: modelId)
if let asstId = asstModelId, resolveModelDirectory(modelId: asstId) != nil {
print("[SwiftLM] Loading Gemma4 MTP assistant model: \(asstId)")
let asstConfig = ModelConfiguration(id: asstId)
let asstDownloader = HubDownloader(hub: HubApi(downloadBase: cacheRoot))
let asstContainer = try await LLMModelFactory.shared.loadContainer(
from: asstDownloader,
using: TransformersTokenizerLoader(),
configuration: asstConfig
) { _ in }
Comment on lines +678 to +684
mtpAsstRef = await asstContainer.extractDraftModel()
print("[SwiftLM] MTP assistant loaded (\(self.numMtpTokens) draft tokens/round). Using DualModelMTP speculative path.")
} else {
if let asstId = asstModelId {
print("[SwiftLM] ⚠️ Gemma4 MTP: assistant model '\(asstId)' not found in HF cache. Run: python -m mlx_lm.convert --hf-path \(asstId) to download it.")
} else {
print("[SwiftLM] ⚠️ --mtp: model '\(modelId)' is not a known Gemma4 MTP model. MTP requires a Gemma4 assistant checkpoint.")
}
Comment on lines +676 to +692
mtpAsstRef = nil
}
} else {
mtpAsstRef = nil
}

// ── Load DFlash draft model for block-diffusion speculative decoding ──
let dflashModel: DFlashDraftModel?
let dflashBlockSizeConfig = self.dflashBlockSize
Expand Down Expand Up @@ -899,6 +932,7 @@ struct MLXServer: AsyncParsableCommand {
return try await handleChatCompletion(
request: request, bodyData: bodyData, config: config, container: container, semaphore: semaphore, stats: stats, promptCache: promptCache,
draftModelRef: draftModelRef, numDraftTokens: numDraftTokensConfig,
mtpAsstRef: mtpAsstRef, numMtpTokens: config.numMtpTokens,
dflashModel: dflashModel, dflashBlockSize: dflashBlockSizeConfig,
dflashTargetModel: dflashTargetModel
)
Expand Down Expand Up @@ -1073,6 +1107,34 @@ struct ServerConfig: Sendable {
let numMtpTokens: Int
}

// ── Gemma4 MTP Assistant Model Registry ──────────────────────────────────────
// Gemma4's MTP head lives in a separate "assistant" checkpoint (model_type=gemma4_assistant).
// This registry maps known main model IDs → their assistant counterparts.
// The case-insensitive suffix match handles quantisation variants like -4bit, -8bit, -bf16.
enum Gemma4MTPRegistry {
private static let knownAssistants: [(prefix: String, assistant: String)] = [
("mlx-community/gemma-4-26b-a4b-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"),
("mlx-community/gemma-4-26B-A4B-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"),
("mlx-community/gemma-4-e2b-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"),
("mlx-community/gemma-4-E2B-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"),
("mlx-community/gemma-4-e4b-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"),
("mlx-community/gemma-4-E4B-it", "mlx-community/gemma-4-E2B-it-assistant-bf16"),
("mlx-community/gemma-4-31b-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"),
("mlx-community/gemma-4-31B-it", "mlx-community/gemma-4-26B-A4B-it-assistant-bf16"),
]

/// Returns the assistant model ID for `modelId`, or nil if this is not a Gemma4 MTP model.
static func resolveAssistant(for modelId: String) -> String? {
let lower = modelId.lowercased()
for (prefix, assistant) in knownAssistants {
if lower.hasPrefix(prefix.lowercased()) {
return assistant
}
}
return nil
}
}

// ── SSD Memory Budget ────────────────────────────────────────────────────────

/// Compute the page-cache budget (bytes) for SSD streaming mode.
Expand Down Expand Up @@ -1330,6 +1392,8 @@ func handleChatCompletion(
promptCache: PromptCache,
draftModelRef: DraftModelRef? = nil,
numDraftTokens: Int = 4,
mtpAsstRef: DraftModelRef? = nil,
numMtpTokens: Int = 3,
dflashModel: DFlashDraftModel? = nil,
dflashBlockSize: Int? = nil,
dflashTargetModel: (any DFlashTargetModel)? = nil
Expand Down Expand Up @@ -1599,6 +1663,17 @@ func handleChatCompletion(
input: lmInput, cache: cache, parameters: params, context: context,
draftModel: draftRef.model, numDraftTokens: numDraftTokens
)
} else if let asstRef = mtpAsstRef {
// Gemma4 MTP path: assistant model (DualModelMTP) is wired as draft.
// generate(draftModel:) detects DualModelMTP and sets mainModelRef + MTPTokenIterator.
// Note: Gemma4-26B is compute-bound (MoE FFN dominates). Batch verification
// scales linearly with tokens, so MTP net TPS ≈ vanilla. Use --turbo-kv to
// reduce KV size and eliminate memory pressure for best long-context throughput.
print("[SwiftLM] Using Gemma4 MTP speculative decoding (\(numMtpTokens) draft tokens/round)")
stream = try MLXLMCommon.generate(
input: lmInput, cache: cache, parameters: params, context: context,
draftModel: asstRef.model, numDraftTokens: numMtpTokens
)
} else if !skipPromptCache, let cachedCount = await promptCache.restore(newTokens: promptTokens, into: cache) {
// Cache hit: KV state is pre-populated up to cachedCount tokens.
// Only compute the remaining (new) tokens.
Expand Down Expand Up @@ -1958,7 +2033,12 @@ func handleChatStreaming(
// llama-server style: print newline then full response JSON
print("") // end the real-time token stream line
let postMemSnap = MemoryUtils.snapshot()
print("srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB")
var slotDoneLog = "srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB"
if info.totalDraftTokens > 0 {
let acc = Double(info.acceptedDraftTokens) / Double(info.totalDraftTokens)
slotDoneLog += " | MTP_ACCEPTED=\(info.acceptedDraftTokens)/\(info.totalDraftTokens) (\(String(format: "%.1f%%", acc * 100)))"
}
print(slotDoneLog)
let dur = genDur
let tokPerSec = genTokPerSec
let logContent: Any = hasToolCalls ? NSNull() : fullText
Expand Down Expand Up @@ -2016,6 +2096,8 @@ func handleChatNonStreaming(
var tcIndex = 0
var generationStopReason: GenerateStopReason = .stop
var firstToken = true
var acceptedDraftTokens = 0
var totalDraftTokens = 0
for await generation in stream {
switch generation {
case .chunk(let text, _):
Expand Down Expand Up @@ -2047,11 +2129,18 @@ func handleChatNonStreaming(
tcIndex += 1
case .info(let info):
generationStopReason = info.stopReason
acceptedDraftTokens = info.acceptedDraftTokens
totalDraftTokens = info.totalDraftTokens
}
}
print("") // end the real-time token stream line
let postMemSnap = MemoryUtils.snapshot()
print("srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB")
var slotDoneLog = "srv slot done: id 0 | gen_tokens=\(completionTokenCount) | OS_RAM=\(String(format: "%.1f", postMemSnap.os))GB | MEM_DEMAND=\(String(format: "%.1f", postMemSnap.demand))GB | GPU_MEM=\(String(format: "%.1f", postMemSnap.gpu))GB"
if totalDraftTokens > 0 {
let acc = Double(acceptedDraftTokens) / Double(totalDraftTokens)
slotDoneLog += " | MTP_ACCEPTED=\(acceptedDraftTokens)/\(totalDraftTokens) (\(String(format: "%.1f%%", acc * 100)))"
}
print(slotDoneLog)
let duration = Date().timeIntervalSince(genStart)
await stats.requestFinished(tokens: completionTokenCount, duration: duration)
await semaphore.signal()
Expand Down
33 changes: 8 additions & 25 deletions run_benchmark.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1341,42 +1341,25 @@ fi
if [ "$suite_opt" == "13" ]; then
echo ""
echo "=> Starting Test 13: Gemma-4 MTP Speculative Decoding Benchmark"

# Infer assistant model
if [[ "$FULL_MODEL" == *"gemma-4-26b"* ]]; then
ASST_MODEL="mlx-community/gemma-4-26B-A4B-it-assistant-bf16"
elif [[ "$FULL_MODEL" == *"gemma-4-e2b"* ]]; then
ASST_MODEL="mlx-community/gemma-4-E2B-it-assistant-bf16"
else
read -p "Enter assistant model Hub ID: " ASST_MODEL
fi

echo ""
read -p "Enter context lengths to test [default: 512,40000,100000]: " CONTEXTS
CONTEXTS=${CONTEXTS:-"512,40000,100000"}

echo ""
echo "Building benchmark binary..."
swift build -c release --product Gemma4MTPBench
echo "Note: Test 13 uses the pre-built SwiftLM binary (same as Test 1)."
echo " Make sure you have run ./build.sh first."
echo ""

python3 -u scripts/profiling/mtp_bench.py \
--model "$FULL_MODEL" \
--contexts "$CONTEXTS" \
--max-tokens 60

IFS=',' read -ra ADDR <<< "$CONTEXTS"
for ctx in "${ADDR[@]}"; do
ctx=$(echo "$ctx" | tr -d ' ')
echo ""
echo "--- Test 13: Context (max-kv-size=$ctx) on $FULL_MODEL ---"
swift run -c release Gemma4MTPBench \
--main-model "$FULL_MODEL" \
--asst-model "$ASST_MODEL" \
--prompt "Write a detailed 3-paragraph essay on the impact of the Industrial Revolution on modern supply chain logistics. Ensure you include dates and specific technological advancements." \
--max-tokens 100 \
--max-kv-size "$ctx" | grep -v "ASST DEBUG"
done

echo ""
echo "✅ Gemma-4 MTP Speculative Decoding Benchmarks Complete."
exit 0
fi

# Fallback to Test 1 for anything else
echo ""
read -p "Enter context lengths to test [default: 512,40000,100000]: " CONTEXTS
Expand Down
Loading
Loading