Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@
ReferencedContainer = "container:CoreMLLLMChat.xcodeproj">
</BuildableReference>
</BuildableProductRunnable>
<EnvironmentVariables>
<EnvironmentVariable
key = "LLM_SHOW_EXPERIMENTAL"
value = "1"
isEnabled = "YES">
</EnvironmentVariable>
<EnvironmentVariable
key = "LLM_VISION_FORCE_ANE"
value = "1"
isEnabled = "YES">
</EnvironmentVariable>
</EnvironmentVariables>
</LaunchAction>
<ProfileAction
buildConfiguration = "Release"
Expand Down
234 changes: 227 additions & 7 deletions Examples/CoreMLLLMChat/CoreMLLLMChat/LLMRunner.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ final class LLMRunner {
private var gemma4StatefulEngine: Gemma4StatefulEngine?
private var gemma4StatefulTokenizer: (any Tokenizer)?

// Gemma 4 stateful + multimodal path (Stage 8): same 3-chunk merged
// Linear decode as the text-only stateful entry, plus a separate
// T=288 single-function prefill set under `prefill_T288/` and the
// vision/video/audio encoders. Selected when the bundle has both
// `chunk_{1..3}` and a `prefill_T288/` subdir alongside.
private var gemma4StatefulMultimodalEngine: Gemma4StatefulMultimodalEngine?
private var gemma4StatefulMultimodalTokenizer: (any Tokenizer)?
/// Cache the last image/audio features so a same-attachment follow-up
/// turn skips encoder cost (mirrors the legacy gemma4 multimodal path).
private var cachedGemma4MMImage: CGImage?
private var cachedGemma4MMImageFeatures: MLMultiArray?
private var cachedGemma4MMAudioSig: [Float]?
private var cachedGemma4MMAudioFeatures: MLMultiArray?
private var cachedGemma4MMAudioTokens: Int = 0

// Qwen3-VL 2B path: separate generator + tokenizer, selected when
// the downloaded folder contains `qwen3_vl_2b_decode_chunks/`.
// Plain GQA architecture (not the Qwen3.5 hybrid SSM), so it gets
Expand Down Expand Up @@ -91,6 +106,7 @@ final class LLMRunner {
|| qwen3vl2bGenerator != nil
|| qwen3vl2bStatefulGenerator != nil
|| gemma4StatefulEngine != nil
|| gemma4StatefulMultimodalEngine != nil
{
llm = nil
qwen35Generator = nil
Expand All @@ -102,6 +118,13 @@ final class LLMRunner {
qwen3vl2bVisionEncoder = nil
gemma4StatefulEngine = nil
gemma4StatefulTokenizer = nil
gemma4StatefulMultimodalEngine = nil
gemma4StatefulMultimodalTokenizer = nil
cachedGemma4MMImage = nil
cachedGemma4MMImageFeatures = nil
cachedGemma4MMAudioSig = nil
cachedGemma4MMAudioFeatures = nil
cachedGemma4MMAudioTokens = 0
cachedVisionImage = nil
cachedVisionFeatures = nil
isLoaded = false
Expand Down Expand Up @@ -187,13 +210,15 @@ final class LLMRunner {
return
}

// Gemma 4 E2B STATEFUL detection: chunk_{1..4}.mlpackage/.mlmodelc
// + embed_tokens_q8.bin under gemma4_e2b_stateful_chunks/. Both
// the Conv2d wrapper variant (folder=gemma4-e2b-stateful) and the
// Linear variant (folder=gemma4-e2b-stateful-linear, Plan 3 A/B)
// share the same internal layout — Gemma4StatefulEngine handles
// both transparently because the only difference is the MIL graph
// inside each chunk_*.mlpackage.
// Gemma 4 STATEFUL detection: chunk_{1..4}.mlpackage/.mlmodelc
// + embed_tokens_q8.bin under gemma4_e2b_stateful_chunks/. The
// subdir name is shared across all six published variants —
// E2B: gemma4-e2b-stateful{,-linear} (Conv2d / Plan 3 Linear)
// E4B: gemma4-e4b-stateful{,-linear} (Stage 2 port)
// — because Gemma4StatefulEngine reads hidden_size / num_layers /
// num_kv_heads from model_config.json, so per-model differences
// (E2B 35 layers / HKV=1 vs E4B 42 layers / HKV=2) need no
// engine code change.
// Require either:
// - chunks 1-3 (3-chunk or 4-chunk bundle — chunk_4 optional)
// - model.{mlpackage,mlmodelc} (1-chunk all-in-one)
Expand All @@ -212,7 +237,26 @@ final class LLMRunner {
let gemma4StatefulPresent = fm.fileExists(atPath:
gemma4StatefulDir.appendingPathComponent("embed_tokens_q8.bin").path)
&& (hasChunks || has1Chunk)
// Stage 8 multimodal-stateful detection: same 3-chunk decode
// bundle plus a `prefill_T288/` subdir holding the three single-
// function prefill models (.mlpackage or .mlmodelc). Only the
// prefill set is checked here; vision/audio encoder presence is
// not part of this check. Route to Gemma4StatefulMultimodalEngine
// when present — falls through to the text-only stateful path
// when only the decode chunks are installed.
if gemma4StatefulPresent {
let prefillT288Dir = gemma4StatefulDir.appendingPathComponent("prefill_T288")
let hasPrefillT288 = ["chunk_1_prefill_T288",
"chunk_2_3way_prefill_T288",
"chunk_3_prefill_T288"].allSatisfy { name in
fm.fileExists(atPath:
prefillT288Dir.appendingPathComponent("\(name).mlpackage").path)
|| fm.fileExists(atPath:
prefillT288Dir.appendingPathComponent("\(name).mlmodelc").path)
}
if hasPrefillT288 {
try await loadGemma4StatefulMultimodal(folder: gemma4StatefulDir)
return
}
try await loadGemma4Stateful(folder: gemma4StatefulDir)
return
}
Expand Down Expand Up @@ -297,6 +341,10 @@ final class LLMRunner {
return try await generateQwen3VL2BStateful(
messages: messages, image: image)
}
if gemma4StatefulMultimodalEngine != nil {
return try await generateGemma4StatefulMultimodal(
messages: messages, image: image, audio: audio)
}
if gemma4StatefulEngine != nil {
return try await generateGemma4Stateful(messages: messages)
}
Expand Down Expand Up @@ -1149,6 +1197,178 @@ final class LLMRunner {
}
}

// MARK: - Gemma 4 stateful + multimodal (Stage 8)

/// Loads the tokenizer and the stateful multimodal engine from `folder`,
/// then publishes the loaded state (model name, capability flags, status).
///
/// - Parameter folder: Directory handed to the engine's `load(modelDirectory:)`;
///   the tokenizer is read from its `hf_model` subdirectory.
/// - Throws: Whatever the tokenizer or engine load throws. Nothing is stored
///   on the runner unless both loads succeed.
private func loadGemma4StatefulMultimodal(folder: URL) async throws {
    loadingStatus = "Loading Gemma 4 multimodal tokenizer..."
    let tokenizer = try await AutoTokenizer.from(
        modelFolder: folder.appendingPathComponent("hf_model"))

    loadingStatus = "Compiling Gemma 4 stateful multimodal chunks (first run only)..."
    let mmEngine = Gemma4StatefulMultimodalEngine()
    try await mmEngine.load(modelDirectory: folder)

    // Both loads succeeded: store engine and tokenizer together.
    gemma4StatefulMultimodalEngine = mmEngine
    gemma4StatefulMultimodalTokenizer = tokenizer

    // The display name is derived from the name of the directory that
    // contains `folder`: anything containing "e4b" (case-insensitive) is
    // labelled E4B, everything else E2B.
    let variantFolderName = folder.deletingLastPathComponent().lastPathComponent
    if variantFolderName.lowercased().contains("e4b") {
        modelName = "Gemma 4 E4B (stateful, multimodal)"
    } else {
        modelName = "Gemma 4 E2B (stateful, multimodal)"
    }

    // Capability flags come straight from the engine.
    hasVision = mmEngine.hasVision
    hasAudio = mmEngine.hasAudio
    isLoaded = true
    loadingStatus = "Ready"

    let capabilities = "vision=\(hasVision) video=\(mmEngine.hasVideoVision) audio=\(hasAudio)"
    print("[LLMRunner] Gemma 4 stateful multimodal loaded — \(modelName) " + capabilities)
}

/// Runs one chat turn through the Gemma 4 stateful multimodal engine and
/// streams the decoded text back incrementally.
///
/// - Parameters:
///   - messages: Conversation so far. System messages are skipped when the
///     prompt is built.
///   - image: Optional image attachment; its placeholder block is attached
///     to the last user turn.
///   - audio: Optional audio samples; their placeholder block is attached
///     to the last user turn.
/// - Returns: A stream of text deltas. Errors raised during generation are
///   reported in-band as a final `[Error: …]` chunk.
/// - Throws: An `NSError` (domain `LLMRunner`, code 42) when the engine or
///   tokenizer is not loaded, or whatever `processImage` / `processAudio`
///   throws while encoding the attachments.
private func generateGemma4StatefulMultimodal(messages: [ChatMessage],
                                              image: CGImage?,
                                              audio: [Float]?
) async throws -> AsyncStream<String> {
    guard let engine = gemma4StatefulMultimodalEngine,
          let tok = gemma4StatefulMultimodalTokenizer
    else {
        throw NSError(domain: "LLMRunner", code: 42,
                      userInfo: [NSLocalizedDescriptionKey:
                                    "Gemma 4 stateful multimodal not loaded"])
    }
    isGenerating = true
    tokensPerSecond = 0

    // The streaming task below resets `isGenerating` when it finishes. If
    // attachment encoding throws before that task exists, nothing else
    // would clear the flag, so clear it here on that early-exit path.
    var streamStarted = false
    defer { if !streamStarted { isGenerating = false } }

    // Number of `<|image|>` pad tokens emitted per image; also passed to
    // the engine as `imageNumTokens`.
    let imagePadTokenCount = 256

    // Encode image once per distinct attachment. Cache hit (same
    // CGImage instance) skips the ~30 s vision graph + lets the
    // engine's cross-turn KV reuse hit the LCP fast path.
    // Cache writes are postponed until both encoders have succeeded (see
    // the commit step below).
    let imageFeatures: MLMultiArray?
    let imageChanged: Bool
    if let img = image {
        if cachedGemma4MMImage === img, let cached = cachedGemma4MMImageFeatures {
            imageFeatures = cached
            imageChanged = false
        } else {
            imageFeatures = try engine.processImage(img)
            imageChanged = true
        }
    } else {
        imageFeatures = nil
        // Attachment removed since the previous turn.
        imageChanged = cachedGemma4MMImage != nil
    }
    let imageNumTokens = image == nil ? 0 : imagePadTokenCount

    let audioFeatures: MLMultiArray?
    let audioNumTokens: Int
    let audioSig: [Float]?
    let audioChanged: Bool
    if let pcm = audio {
        // Cheap fingerprint: [count, first, last]. Re-encode on
        // any mismatch.
        let sig: [Float] = pcm.isEmpty
            ? [0, 0, 0]
            : [Float(pcm.count), pcm.first ?? 0, pcm.last ?? 0]
        audioSig = sig
        if cachedGemma4MMAudioSig == sig, let cached = cachedGemma4MMAudioFeatures {
            audioFeatures = cached
            audioNumTokens = cachedGemma4MMAudioTokens
            audioChanged = false
        } else {
            let (feat, n) = try engine.processAudio(pcm)
            audioFeatures = feat
            audioNumTokens = n
            audioChanged = true
        }
    } else {
        audioSig = nil
        audioFeatures = nil
        audioNumTokens = 0
        // Attachment removed since the previous turn.
        audioChanged = cachedGemma4MMAudioFeatures != nil
    }

    // Commit step. Both encoders succeeded, so update the caches and drop
    // persisted KV together. Doing this only now keeps the caches and the
    // engine state in sync if one encoder throws after the other ran:
    // otherwise the next turn could hit the image cache, report "no
    // change", and reuse KV rows from the previous attachment.
    if imageChanged {
        cachedGemma4MMImage = image
        cachedGemma4MMImageFeatures = imageFeatures
    }
    if audioChanged {
        cachedGemma4MMAudioSig = audioSig
        cachedGemma4MMAudioFeatures = audioFeatures
        cachedGemma4MMAudioTokens = audioNumTokens
    }
    // Attachment changed → drop persisted KV so the LCP match
    // doesn't reuse stale image/audio rows from a prior turn.
    if imageChanged || audioChanged { engine.resetPersistedState() }

    // Build the Gemma 4 prompt. Image / audio blocks are pinned to
    // the LAST user turn so cross-turn resume keeps the pad span at
    // a fixed offset (same trick as the legacy gemma4 path).
    let imageBlock = "<|image>"
        + String(repeating: "<|image|>", count: imagePadTokenCount)
        + "<image|>"
    let audioBlock = "<|audio>"
        + String(repeating: "<|audio|>", count: audioNumTokens)
        + "<audio|>"
    let lastUserIdx = messages.lastIndex { $0.role == .user }
    var prompt = "<bos>"
    for (i, m) in messages.enumerated() {
        switch m.role {
        case .user:
            let isLast = i == lastUserIdx
            var mediaPrefix = ""
            if imageFeatures != nil && isLast { mediaPrefix += imageBlock + "\n" }
            if audioFeatures != nil && isLast && audioNumTokens > 0 {
                mediaPrefix += audioBlock + "\n"
            }
            prompt += "<|turn>user\n\(mediaPrefix)\(m.content)<turn|>\n"
        case .assistant:
            prompt += "<|turn>model\n\(m.content)<turn|>\n"
        case .system:
            continue
        }
    }
    prompt += "<|turn>model\n"
    let inputIds = tok.encode(text: prompt).map { Int32($0) }

    // Token ids that end generation, plus the tokenizer's own EOS id when
    // it reports one.
    var eosSet: Set<Int32> = [1, 106]
    if let eid = tok.eosTokenId { eosSet.insert(Int32(eid)) }
    // Token ids that are never surfaced in the streamed text.
    let skipSet: Set<Int32> = [1, 105, 106]

    let genStart = Date()
    // From here on the streaming task owns resetting `isGenerating`.
    streamStarted = true
    return AsyncStream { continuation in
        Task { [weak self] in
            defer { Task { @MainActor in self?.isGenerating = false } }
            var accum: [Int] = []
            var emittedString = ""
            var totalEmitted = 0
            do {
                _ = try await engine.generate(
                    inputIds: inputIds,
                    imageFeatures: imageFeatures,
                    imageNumTokens: imageNumTokens,
                    audioFeatures: audioFeatures,
                    audioNumTokens: audioNumTokens,
                    maxNewTokens: 256,
                    eosTokenIds: eosSet,
                    onToken: { tokenId in
                        if skipSet.contains(tokenId) { return }
                        accum.append(Int(tokenId))
                        // Re-decode the whole accumulated sequence and emit
                        // only the newly appeared suffix.
                        let current = tok.decode(tokens: accum)
                        if current.count > emittedString.count {
                            let delta = String(
                                current.suffix(current.count - emittedString.count))
                            continuation.yield(delta)
                            emittedString = current
                        }
                        totalEmitted += 1
                    })
                let dt = Date().timeIntervalSince(genStart)
                if dt > 0 {
                    let tps = Double(totalEmitted) / dt
                    Task { @MainActor in
                        self?.tokensPerSecond = tps
                    }
                }
            } catch {
                // Report generation errors in-band so the consumer sees them.
                continuation.yield("[Error: \(error.localizedDescription)]")
            }
            continuation.finish()
        }
    }
}

/// Build the token ID sequence for a vision-augmented Qwen3-VL 2B
/// prompt. Emits the same prefix the HF processor would produce for
/// `[{role:"user", content:[{type:"image"},{type:"text", text:...}]}]`
Expand Down
12 changes: 12 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ let package = Package(
.executable(name: "determinism-oracle", targets: ["DeterminismOracle"]),
.executable(name: "verify-k8-probe", targets: ["VerifyK8Probe"]),
.executable(name: "ane-residency-gate", targets: ["AneResidencyGate"]),
.executable(name: "gemma4mm-smoke", targets: ["Gemma4MMSmoke"]),
// Standalone samples for the two Gemma-3-based models. These live in
// the same package on purpose — a LocalAIKit-style wrapper can depend
// on the `CoreMLLLM` library and use `FunctionGemma` / `EmbeddingGemma`
Expand Down Expand Up @@ -91,6 +92,17 @@ let package = Package(
path: "Sources/verify-k8-probe",
swiftSettings: [.swiftLanguageMode(.v5)]
),
// Mac smoke test for Gemma4StatefulMultimodalEngine — text-only
// generate to catch engine bugs without an iPhone trip.
.executableTarget(
name: "Gemma4MMSmoke",
dependencies: [
"CoreMLLLM",
.product(name: "Tokenizers", package: "swift-transformers"),
],
path: "Sources/gemma4mm-smoke",
swiftSettings: [.swiftLanguageMode(.v5)]
),
// FunctionGemma-270M standalone CLI. Does NOT combine with Gemma 4 —
// multi-model orchestration belongs in the LocalAIKit wrapper.
.executableTarget(
Expand Down
Loading
Loading