From 3c1f91590e78273e644566f14aeaa0d7eb998c88 Mon Sep 17 00:00:00 2001 From: John Rocky Date: Thu, 30 Apr 2026 08:53:22 +0900 Subject: [PATCH] docs(bonsai): post-mortem + ANE decode-state lessons MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the post-mortem for the ternary Bonsai investigation and the ANE decode-path lessons it produced. The Qwen3 architecture support that came out of the same investigation lands separately. Why Bonsai didn't ship: - prism-ml/Ternary-Bonsai-1.7B's compression depends on per-(row, block) independent scales (g=64). ANEC rejects that LUT granularity with error -14, and the stock-API per-block approximation factorizes scales into a rank-1 outer product, defeating the model's design. - For Apple Silicon, the GPU path (mlx-lm with the official Ternary-Bonsai-1.7B-mlx-2bit) is the only honest option. What lands as reusable infrastructure: - docs/TERNARY_BONSAI.md: full post-mortem (what was tried, why each failed, what the right path is) - docs/DECODE_STATE_LAYOUTS.md: ANE decode-state catalog — mask-based rotating buffer pattern, ctx > weight bandwidth result, palettization traps, ternary-on-ANE checklist - docs/GEMMA4_ROTATING_BUFFER_PORT.md: design note for porting our mask-based rotating buffer to Gemma 4's full-attention layers - docs/NEXT_MODELS.md: shortlist (Qwen3-1.7B, Gemma 3 4B QAT, Llama-3.2, SmolLM3) for the next port - docs/ADDING_MODELS.md: §4.5 KV state-layout checklist - docs/ANE_OPTIMIZATION_SURVEY.md: cross-reference to the ctx>weights finding - conversion/experiments/bonsai/: research scripts (oracle, ternary surgery, SWA comparisons, decode-chunks builder) retained as breadcrumbs in case anyone retraces the path - conversion/config.py: NOTE comment in MODEL_REGISTRY explaining why Bonsai is intentionally absent and pointing readers to the doc Extracted from feat/qwen3-bonsai-investigation (commit 56ee545). --- conversion/config.py | 8 + conversion/experiments/bonsai/README.md | 56 ++ .../bonsai/bonsai_reference_oracle.py | 168 +++++ .../bonsai/build_bonsai_17b_decode_chunks.py | 639 ++++++++++++++++++ .../bonsai/compare_swa_long_range.py | 187 +++++ .../experiments/bonsai/compare_swa_vs_full.py | 163 +++++ .../experiments/bonsai/ternary_surgery.py | 302 +++++++++ .../bonsai/test_bonsai_chunks_inference.py | 144 ++++ .../bonsai/test_bonsai_inference.py | 103 +++ .../bonsai/verify_bonsai_ternary.py | 123 ++++ docs/ADDING_MODELS.md | 21 + docs/ANE_OPTIMIZATION_SURVEY.md | 7 + docs/DECODE_STATE_LAYOUTS.md | 252 +++++++ docs/GEMMA4_ROTATING_BUFFER_PORT.md | 195 ++++++ docs/NEXT_MODELS.md | 116 ++++ docs/TERNARY_BONSAI.md | 107 +++ 16 files changed, 2591 insertions(+) create mode 100644 conversion/experiments/bonsai/README.md create mode 100644 conversion/experiments/bonsai/bonsai_reference_oracle.py create mode 100644 conversion/experiments/bonsai/build_bonsai_17b_decode_chunks.py create mode 100644 conversion/experiments/bonsai/compare_swa_long_range.py create mode 100644 conversion/experiments/bonsai/compare_swa_vs_full.py create mode 100644 conversion/experiments/bonsai/ternary_surgery.py create mode 100644 conversion/experiments/bonsai/test_bonsai_chunks_inference.py create mode 100644 conversion/experiments/bonsai/test_bonsai_inference.py create mode 100644 conversion/experiments/bonsai/verify_bonsai_ternary.py create mode 100644 docs/DECODE_STATE_LAYOUTS.md create mode 100644 docs/GEMMA4_ROTATING_BUFFER_PORT.md create mode 100644 docs/NEXT_MODELS.md create mode 100644 docs/TERNARY_BONSAI.md diff --git a/conversion/config.py b/conversion/config.py index d29f461..17beb27 100644 --- a/conversion/config.py +++ b/conversion/config.py @@ -82,6 +82,14 @@ class ConversionConfig: max_context_length=32_768, description="Liquid AI LFM2 350M - first-gen LFM2 350M (architecturally identical to 2.5)", ), + # NOTE: prism-ml/{,Ternary-}Bonsai-1.7B were investigated and intentionally + # not registered here. Their per-(row, block) ternary structure cannot be + # faithfully represented on ANE — Apple's ANEC rejects per-block LUT + # palettization with error -14, and any stock-API approximation collapses + # the per-block scales into a rank-1 outer product, defeating the model's + # core compression. See `docs/TERNARY_BONSAI.md` for the post-mortem. + # To run Bonsai on Apple Silicon, use mlx-lm with + # `prism-ml/Ternary-Bonsai-1.7B-mlx-2bit` (GPU, native ternary matmul). } diff --git a/conversion/experiments/bonsai/README.md b/conversion/experiments/bonsai/README.md new file mode 100644 index 0000000..e1e3233 --- /dev/null +++ b/conversion/experiments/bonsai/README.md @@ -0,0 +1,56 @@ +# Bonsai (1.58-bit ternary) — Investigation Artifacts + +These scripts attempted to bring `prism-ml/Ternary-Bonsai-1.7B` to Apple +Neural Engine via Core ML. **The investigation concluded that ANE cannot +faithfully run Bonsai's per-(row, block) ternary structure.** Apple's ANE +compiler rejects per-block LUT palettization (`error code: -14`), and +working around it (per-tensor / per-channel kmeans) collapses Bonsai's +core design — the per-block independent scales — into a rank-1 outer +product. At that point we'd be shipping "Qwen3-1.7B with palette quant", +not Bonsai. So we don't ship. + +The full post-mortem and the path forward (MLX Swift for Bonsai-class +models) is in `docs/TERNARY_BONSAI.md`. + +## What's here, briefly + +| File | Purpose | Result | +|---|---|---| +| `bonsai_reference_oracle.py` | HF vs our `Qwen3Model` parity, 5-token greedy match | **Pass** — confirmed `models/qwen3.py` correctness | +| `build_bonsai_17b_decode_chunks.py` | 2-chunk INT4/INT8 + optional SWA decode build | **Pass** — produced ANE-running INT4 at 24 tok/s, but quality is approximate Qwen3, not faithful Bonsai | +| `verify_bonsai_ternary.py` | Validates per-128-block ternary structure of unpacked FP16 | **Pass** — 100% of sampled 128-groups have exactly 3 unique values | +| `ternary_surgery.py` | Custom MIL pass: per-(row, block) `constexpr_lut_to_dense` palettization | **Pass to save, fail at load** — ANE compiler -14 | +| `test_bonsai_inference.py`, `test_bonsai_chunks_inference.py` | Smoke + benchmark | Used during investigation | +| `compare_swa_vs_full.py`, `compare_swa_long_range.py` | SWA-vs-full divergence measurements | Found long-range recall regression with sinks=0 SWA | + +## Reusable bits that escaped to `conversion/` + +These are the parts of the work that landed in the main codebase: + +- `models/qwen3.py` — Qwen3 architecture support (QK-norm, tied embed, + no attention bias). Useful for Qwen3-1.7B / 4B / 8B and any QK-normed + Qwen-family model. +- `base_model.py` — `ModelConfig.has_qk_norm` flag and conditional + `q_norm` / `k_norm` modules in `ANEAttention`. Backward-compatible + default (`has_qk_norm=False`) so Qwen2 / Gemma builds are unchanged. +- `exporter.py` — `MonolithicWrapper` applies QK-norm when the layer + has `has_qk_norm=True`. +- `convert.py` — `qwen3` architecture routes to `Qwen3Model`. +- `docs/DECODE_STATE_LAYOUTS.md` — captured ANE decode-path lessons + including the per-block palette finding. + +## If you want to actually run Bonsai on Apple Silicon + +Use MLX, not Core ML / ANE: + +```bash +pip install mlx-lm +mlx_lm.generate \ + --model prism-ml/Ternary-Bonsai-1.7B-mlx-2bit \ + --prompt "..." +``` + +`mlx-lm` natively supports the 2-bit packed ternary format with `mx.quantized_matmul`, +preserving the per-block scale structure. Runs on Apple Silicon GPU at full fidelity. + +For Swift integration, see [`mlx-swift-examples`](https://github.com/ml-explore/mlx-swift-examples). diff --git a/conversion/experiments/bonsai/bonsai_reference_oracle.py b/conversion/experiments/bonsai/bonsai_reference_oracle.py new file mode 100644 index 0000000..817238c --- /dev/null +++ b/conversion/experiments/bonsai/bonsai_reference_oracle.py @@ -0,0 +1,168 @@ +"""Parity test: HF Qwen3ForCausalLM vs our Qwen3Model (ANE-optimized Conv2d). + +Validates that weight loading + QK-norm are wired correctly by comparing: +- last-token logits cosine similarity +- top-1 next token prediction +- first N greedy tokens + +Usage: + python bonsai_reference_oracle.py --model-path /path/to/ternary-bonsai-1.7b + python bonsai_reference_oracle.py --model-path ... --max-new-tokens 5 +""" + +from __future__ import annotations + +import argparse +import os +import sys +from pathlib import Path + +import torch + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from ane_ops import MODEL_DTYPE, apply_rotary_pos_emb, repeat_kv, stable_attention +from models.qwen3 import Qwen3Model + + +DEFAULT_PROMPTS = [ + "The capital of France is", + "Hello, my name is", + "def fibonacci(n):\n if n <= 1:", +] + + +def cos_sim(a: torch.Tensor, b: torch.Tensor, eps: float = 1e-8) -> float: + a = a.flatten().to(torch.float32) + b = b.flatten().to(torch.float32) + return float((a @ b) / (a.norm() * b.norm() + eps)) + + +@torch.no_grad() +def hf_next_tokens(model, tokenizer, prompt: str, n: int) -> tuple[list[int], torch.Tensor]: + """Greedy decode n tokens with HF; return token ids + last-token logits at step 0.""" + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + first_out = model(input_ids=input_ids, use_cache=False) + last_logits = first_out.logits[0, -1, :].float().cpu() + + generated = input_ids.clone() + for _ in range(n): + out = model(input_ids=generated, use_cache=False) + next_id = out.logits[0, -1, :].argmax().item() + generated = torch.cat([generated, torch.tensor([[next_id]])], dim=1) + + return generated[0, input_ids.shape[1]:].tolist(), last_logits + + +@torch.no_grad() +def ours_prefill_last_logits( + our_model: Qwen3Model, input_ids: torch.Tensor +) -> torch.Tensor: + """Run one prefill through our Qwen3Model; return (vocab,) logits of final position. + + NOTE: bypasses ANETransformerModel.forward_transformer_prefill because that path + reads from the (empty) KV cache for attention instead of using the freshly computed + K/V for the current seq. For a standalone parity test we want a cache-free prefill: + attend current tokens to themselves with a seq x seq causal mask. + """ + seq_len = input_ids.shape[1] + positions = torch.arange(seq_len) + + # Cache-free prefill: seq x seq causal mask + mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), dtype=torch.float32), diagonal=1 + ).view(1, 1, seq_len, seq_len) + + hidden = our_model.forward_embeddings(input_ids) + + for layer in our_model.layers: + residual = hidden + hidden = layer.input_layernorm(hidden) + + attn = layer.self_attn + q, k, v = attn._project_qkv(hidden) # q_norm/k_norm applied inside + cos, sin = attn.rotary_emb.forward_range(positions) + cos = cos.permute(0, 2, 1, 3) + sin = sin.permute(0, 2, 1, 3) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + k_exp = repeat_kv(k, attn.n_rep) + v_exp = repeat_kv(v, attn.n_rep) + attn_out = stable_attention(q, k_exp, v_exp, attn.scale, mask) + attn_out = attn._output_proj(attn_out) + + hidden = residual + attn_out + + residual = hidden + hidden = layer.post_attention_layernorm(hidden) + hidden = layer.mlp(hidden) + hidden = residual + hidden + + hidden = our_model.norm(hidden) + + last = hidden[:, -1:, :] + x = last.permute(0, 2, 1).unsqueeze(2).to(hidden.dtype) + logits = our_model.lm_head(x).squeeze(2).permute(0, 2, 1) # (1, 1, vocab) + return logits[0, 0, :].float().cpu() + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--model-path", required=True, help="HF model dir with config.json + safetensors") + ap.add_argument("--context-length", type=int, default=1024) + ap.add_argument("--max-new-tokens", type=int, default=5, help="Greedy decode tokens for HF side") + ap.add_argument("--prompts", nargs="*", default=DEFAULT_PROMPTS) + args = ap.parse_args() + + # Lazy imports so the script prints a clean error if transformers is missing. + from transformers import AutoModelForCausalLM, AutoTokenizer + + model_path = Path(args.model_path).expanduser() + print(f"Loading HF model from {model_path}") + hf_model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float32, low_cpu_mem_usage=True + ) + hf_model.eval() + tokenizer = AutoTokenizer.from_pretrained(model_path) + + print(f"Loading our Qwen3Model (context_length={args.context_length})") + our_model = Qwen3Model.from_pretrained(str(model_path), context_length=args.context_length) + our_model.eval() + + total = 0 + passed = 0 + + for prompt in args.prompts: + print(f"\nprompt: {prompt!r}") + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + seq_len = input_ids.shape[1] + print(f" tokens ({seq_len}): {input_ids[0].tolist()}") + + hf_next, hf_last_logits = hf_next_tokens(hf_model, tokenizer, prompt, args.max_new_tokens) + + our_last_logits = ours_prefill_last_logits(our_model, input_ids) + our_top1 = int(our_last_logits.argmax().item()) + + cs = cos_sim(hf_last_logits, our_last_logits) + hf_top1 = hf_next[0] + match = our_top1 == hf_top1 + + hf_text = tokenizer.decode(hf_next) + print(f" HF next token: {hf_top1} ({tokenizer.decode([hf_top1])!r})") + print(f" our top-1: {our_top1} ({tokenizer.decode([our_top1])!r})") + print(f" cos(last_logits): {cs:.6f} match: {match}") + print(f" HF {args.max_new_tokens}-token continuation: {hf_text!r}") + + total += 1 + if match and cs >= 0.95: + passed += 1 + + print(f"\nparity summary: {passed}/{total} prompts passed (top1 match + cos>=0.95)") + if passed < total: + print("FAIL — investigate QK-norm, weight map, or attention scale") + sys.exit(1) + print("PASS") + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/build_bonsai_17b_decode_chunks.py b/conversion/experiments/bonsai/build_bonsai_17b_decode_chunks.py new file mode 100644 index 0000000..1fd20da --- /dev/null +++ b/conversion/experiments/bonsai/build_bonsai_17b_decode_chunks.py @@ -0,0 +1,639 @@ +"""Ternary-Bonsai-1.7B decode: 2-chunk INT8 palettized build for iPhone ANE. + +Why chunks: monolithic INT8 is ~1.9 GB — same class as Qwen3.5-2B, which jetsam-killed +on iPhone (see `docs/QWEN35_2B_CHUNKED_HANDOFF.md`). Per-mlpackage ANE compile budget +is ~1.4 GB; splitting gets us under it and also stops the silent GPU fallback that +kills throughput even on Mac. + +Split (28 layers total): + chunk_a: input_ids → embed → layers [0..14) → hidden_out (fp16) + chunk_b: hidden_in → layers [14..28) + norm + lm_head → token_id, token_logit + +Each chunk ships its own `kv_cache` StateType for its 14 layers. + +Uses the same attention / KV-write pattern as `exporter.py::MonolithicWrapper` +(mask-based cache update, per-channel `index_select` for RoPE, Conv2d Q/K/V/O, +QK-norm before RoPE). That wrapper is parity-verified against HF. + +Output layout (Swift-friendly): + /bonsai_1_7b_decode_chunks/ + chunk_a.mlpackage + chunk_b.mlpackage + model_config.json +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import sys +import time +from collections import Counter +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn + +import coremltools as ct +from coremltools.optimize.coreml import ( + OpPalettizerConfig, + OptimizationConfig, + palettize_weights, +) + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from ane_ops import MODEL_DTYPE, apply_rotary_pos_emb +from models.qwen3 import Qwen3Model + + +# ---- Shared forward helpers (match MonolithicWrapper) --------------------- + + +def _decode_layer_step( + layer, + hidden_states: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + causal_mask: torch.Tensor, + update_mask: torch.Tensor, + K_cache: torch.Tensor, + V_cache: torch.Tensor, + num_heads: int, + num_kv_heads: int, + head_dim: int, + n_rep: int, + scale: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Single decoder-layer decode step with mask-based KV write. + + Returns: (hidden_out, K_new, V_new) where K_new/V_new are the (1, kv, ctx, d) + caches with the current position written in. + """ + residual = hidden_states + hidden_states = layer.input_layernorm(hidden_states) + + x = hidden_states.permute(0, 2, 1).unsqueeze(2).to(MODEL_DTYPE) + q = ( + layer.self_attn.q_proj(x) + .view(1, num_heads, head_dim, 1) + .permute(0, 1, 3, 2) + .to(MODEL_DTYPE) + ) + k = ( + layer.self_attn.k_proj(x) + .view(1, num_kv_heads, head_dim, 1) + .permute(0, 1, 3, 2) + .to(MODEL_DTYPE) + ) + v = ( + layer.self_attn.v_proj(x) + .view(1, num_kv_heads, head_dim, 1) + .permute(0, 1, 3, 2) + .to(MODEL_DTYPE) + ) + + # Qwen3 QK-norm (per-head RMSNorm) before RoPE + if getattr(layer.self_attn, "has_qk_norm", False): + q = layer.self_attn.q_norm(q) + k = layer.self_attn.k_norm(k) + + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # Mask-based KV write — broadcast current (1,1,head_dim) to (1,ctx,head_dim) + # and blend with cache at `update_mask`-selected position. + k_broadcast = k.expand_as(K_cache) + v_broadcast = v.expand_as(V_cache) + K_new = K_cache * (1 - update_mask) + k_broadcast * update_mask + V_new = V_cache * (1 - update_mask) + v_broadcast * update_mask + + K_expanded = K_new.repeat_interleave(n_rep, dim=1) + V_expanded = V_new.repeat_interleave(n_rep, dim=1) + + q_f = q.to(torch.float32) + k_f = K_expanded.to(torch.float32) + attn_weights = torch.matmul(q_f, k_f.transpose(-1, -2)) * scale + attn_weights = attn_weights + causal_mask.to(torch.float32) + attn_weights = torch.softmax(attn_weights, dim=-1).to(MODEL_DTYPE) + attn_output = torch.matmul( + attn_weights.to(torch.float32), V_expanded.to(torch.float32) + ).to(MODEL_DTYPE) + + attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(1, 1, -1) + attn_output = ( + layer.self_attn.o_proj(attn_output.permute(0, 2, 1).unsqueeze(2)) + .squeeze(2) + .permute(0, 2, 1) + ) + hidden_states = residual + attn_output + + residual = hidden_states + hidden_states = layer.post_attention_layernorm(hidden_states) + hidden_states = layer.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, K_new, V_new + + +# SWA uses the exact same `_decode_layer_step` as full attention. The only +# differences are external: state buffer is sized W (not ctx), and the host +# passes an `update_mask` with the 1.0 at `pos % W` (circular slot), plus a +# causal_mask sized (1,1,1,W). This keeps the ops identical to the ANE-proven +# non-SWA path (mask-based blend + standard matmul attention) and avoids the +# ANEC -14 compile rejection that `cat([K[:,:,1:,:], k])` hits. +# +# Attention order invariance: softmax+weighted sum is permutation-invariant +# over the keys, so the scrambled slot order in the circular buffer is fine. +# RoPE is baked into K at write time, so positional information is preserved. + + +class ChunkBase(nn.Module): + """Base class holding shared precomputed RoPE + KV cache state.""" + + def __init__(self, config, layer_indices: list[int], + sliding_window: int | None = None) -> None: + super().__init__() + self.config = config + self.num_chunk_layers = len(layer_indices) + self.sliding_window = sliding_window + + # KV cache size: + # full attention → (..., ctx, head_dim) + # SWA (shift-based rotating buffer) → (..., W, head_dim) + state_len = sliding_window if sliding_window is not None else config.context_length + cache_shape = ( + 2 * self.num_chunk_layers, + config.num_key_value_heads, + state_len, + config.head_dim, + ) + self.register_buffer("kv_cache", torch.zeros(cache_shape, dtype=MODEL_DTYPE)) + + # RoPE cos/sin (shared identical buffer across chunks — small, ~4 MB total) + head_dim = config.head_dim + base = config.rope_theta + inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim)) + max_len = config.context_length * 2 + t = torch.arange(max_len).float() + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(MODEL_DTYPE)) + self.register_buffer("sin_cached", emb.sin().to(MODEL_DTYPE)) + + +class ChunkA(ChunkBase): + """Head chunk: input_ids → embed → layers [0, split) → hidden_out. + + Inputs are the same whether full-attention or SWA — the only difference is + the state buffer size (ctx vs W) and the semantics of update_mask (absolute + position vs `pos % W` slot). + """ + + def __init__(self, full_model: Qwen3Model, split_at: int, + sliding_window: int | None = None) -> None: + super().__init__(full_model.config, list(range(split_at)), + sliding_window=sliding_window) + self.embed_tokens = full_model.embed_tokens + self.layers = nn.ModuleList([full_model.layers[i] for i in range(split_at)]) + + def forward(self, input_ids, position_ids, causal_mask, update_mask): + cfg = self.config + num_heads = cfg.num_attention_heads + num_kv_heads = cfg.num_key_value_heads + head_dim = cfg.head_dim + n_rep = num_heads // num_kv_heads + scale = 1.0 / (head_dim ** 0.5) + num_layers = self.num_chunk_layers + + hidden_states = self.embed_tokens(input_ids).to(MODEL_DTYPE) + + cos = torch.index_select(self.cos_cached, 0, position_ids).view(1, 1, 1, head_dim) + sin = torch.index_select(self.sin_cached, 0, position_ids).view(1, 1, 1, head_dim) + + for i in range(num_layers): + layer = self.layers[i] + k_idx = i + v_idx = num_layers + i + K_cache = self.kv_cache[k_idx].unsqueeze(0) + V_cache = self.kv_cache[v_idx].unsqueeze(0) + + hidden_states, K_new, V_new = _decode_layer_step( + layer, hidden_states, cos, sin, causal_mask, update_mask, + K_cache, V_cache, num_heads, num_kv_heads, head_dim, n_rep, scale, + ) + + self.kv_cache[k_idx] = K_new.squeeze(0) + self.kv_cache[v_idx] = V_new.squeeze(0) + + return hidden_states + + +class ChunkB(ChunkBase): + """Tail chunk: hidden_in → layers [split, end) → norm → lm_head → (token, logit).""" + + def __init__(self, full_model: Qwen3Model, split_at: int, + sliding_window: int | None = None) -> None: + cfg = full_model.config + tail_indices = list(range(split_at, cfg.num_hidden_layers)) + super().__init__(cfg, tail_indices, sliding_window=sliding_window) + self.layers = nn.ModuleList([full_model.layers[i] for i in tail_indices]) + self.norm = full_model.norm + self.lm_head = full_model.lm_head + self.argmax = full_model.argmax + + def forward(self, hidden_in, position_ids, causal_mask, update_mask): + cfg = self.config + num_heads = cfg.num_attention_heads + num_kv_heads = cfg.num_key_value_heads + head_dim = cfg.head_dim + n_rep = num_heads // num_kv_heads + scale = 1.0 / (head_dim ** 0.5) + num_layers = self.num_chunk_layers + + hidden_states = hidden_in.to(MODEL_DTYPE) + + cos = torch.index_select(self.cos_cached, 0, position_ids).view(1, 1, 1, head_dim) + sin = torch.index_select(self.sin_cached, 0, position_ids).view(1, 1, 1, head_dim) + + for i in range(num_layers): + layer = self.layers[i] + k_idx = i + v_idx = num_layers + i + K_cache = self.kv_cache[k_idx].unsqueeze(0) + V_cache = self.kv_cache[v_idx].unsqueeze(0) + + hidden_states, K_new, V_new = _decode_layer_step( + layer, hidden_states, cos, sin, causal_mask, update_mask, + K_cache, V_cache, num_heads, num_kv_heads, head_dim, n_rep, scale, + ) + + self.kv_cache[k_idx] = K_new.squeeze(0) + self.kv_cache[v_idx] = V_new.squeeze(0) + + hidden_states = self.norm(hidden_states) + x = hidden_states.permute(0, 2, 1).unsqueeze(2).to(MODEL_DTYPE) + logits = self.lm_head(x).squeeze(2).permute(0, 2, 1) + return self.argmax(logits.squeeze(0)) + + +# ---- ANE placement audit -------------------------------------------------- + + +def audit_ane(pkg_path: Path) -> float: + """Print per-op device placement and return ANE percentage. + + Returns -1.0 if audit fails (e.g. the save-time ANEC warning left the model + without a cached compiled path). This is a diagnostic; never fatal. + """ + try: + reloaded = ct.models.MLModel(str(pkg_path), compute_units=ct.ComputeUnit.CPU_AND_NE) + compiled = reloaded.get_compiled_model_path() + plan = ct.models.compute_plan.MLComputePlan.load_from_path( + path=str(compiled), compute_units=ct.ComputeUnit.CPU_AND_NE, + ) + except Exception as e: + print(f" ANE audit skipped: {e}") + return -1.0 + dev = Counter() + for fn in plan.model_structure.program.functions.values(): + for op in fn.block.operations: + a = plan.get_compute_device_usage_for_mlprogram_operation(op) + d = ("const" if (a is None and op.operator_name == "const") + else (a.preferred_compute_device.__class__.__name__ if a else "unknown")) + dev[d] += 1 + total = sum(dev.values()) + compute = total - dev.get("const", 0) + ane = dev.get("MLNeuralEngineComputeDevice", 0) + pct = 100.0 * ane / compute if compute else 0.0 + print(f" ANE placement: {ane}/{compute} ({pct:.1f}%) " + f"dev breakdown={dict(dev)}") + return pct + + +# ---- Conversion ----------------------------------------------------------- + + +def convert_chunk( + chunk: nn.Module, + ctx: int, + hidden_size: int, + cache_shape: tuple, + out_path: Path, + *, + is_head: bool, + sliding_window: int | None = None, +) -> ct.models.MLModel: + """Trace + convert one chunk to fp16 mlpackage. + + If sliding_window is set, the chunk expects an extra (W,) int32 `gather_idx` + input and the causal_mask shape is (1,1,1,W) instead of (1,1,1,ctx). + """ + label = "chunk_a (head)" if is_head else "chunk_b (tail)" + w = sliding_window if sliding_window is not None else ctx + print(f"\n--- {label} → {out_path.name} " + f"(ctx={ctx}, window={w}{' SWA' if sliding_window else ''}) ---") + + if is_head: + sample_input = torch.zeros((1, 1), dtype=torch.int32) + input_spec = ct.TensorType(name="input_ids", shape=(1, 1), dtype=np.int32) + else: + sample_input = torch.zeros((1, 1, hidden_size), dtype=torch.float16) + input_spec = ct.TensorType( + name="hidden_in", shape=(1, 1, hidden_size), dtype=np.float16 + ) + + # State buffer / mask length: + # full attention → buffer = ctx, update_mask over ctx, causal over ctx + # SWA → buffer = W, update_mask over W, causal over W + # This keeps the op pattern identical to the proven non-SWA path. SWA just + # uses a smaller rotating buffer with `pos % W` slot selection on the host. + sample_position = torch.zeros((1,), dtype=torch.int32) + sample_causal = torch.zeros((1, 1, 1, w), dtype=torch.float16) + sample_update = torch.zeros((1, 1, w, 1), dtype=torch.float16) + sample_update[0, 0, 0, 0] = 1.0 + + sample_args = [sample_input, sample_position, sample_causal, sample_update] + + inputs = [ + input_spec, + ct.TensorType(name="position_ids", shape=(1,), dtype=np.int32), + ct.TensorType(name="causal_mask", shape=(1, 1, 1, w), dtype=np.float16), + ct.TensorType(name="update_mask", shape=(1, 1, w, 1), dtype=np.float16), + ] + + with torch.no_grad(): + chunk.kv_cache.zero_() + + print(" tracing...") + t0 = time.time() + with torch.no_grad(): + # strict=False: the module mutates `kv_cache` buffer, so JIT's trace- + # validation re-run sees different state and complains. The graph itself + # is correct (mask-based or shift-based write is functional); this matches + # the pattern in `build_qwen35_2b_decode_chunks.py`. + traced = torch.jit.trace(chunk, tuple(sample_args), strict=False) + print(f" traced in {time.time()-t0:.1f}s") + + if is_head: + outputs = [ct.TensorType(name="hidden", dtype=np.float16)] + else: + outputs = [ + ct.TensorType(name="token_id", dtype=np.int32), + ct.TensorType(name="token_logit", dtype=np.float16), + ] + + states = [ + ct.StateType( + wrapped_type=ct.TensorType(shape=cache_shape, dtype=np.float16), + name="kv_cache", + ), + ] + + print(" converting to CoreML...") + t0 = time.time() + mlmodel = ct.convert( + traced, + convert_to="mlprogram", + inputs=inputs, + outputs=outputs, + states=states, + compute_units=ct.ComputeUnit.CPU_AND_NE, + minimum_deployment_target=ct.target.iOS18, + ) + print(f" converted in {time.time()-t0:.1f}s") + + if out_path.exists(): + shutil.rmtree(out_path) + mlmodel.save(str(out_path)) + size_mb = sum(f.stat().st_size for f in out_path.rglob("*") if f.is_file()) / 1e6 + print(f" saved fp16 {out_path.name} ({size_mb:.0f} MB)") + return mlmodel + + +def palettize( + src: Path, + dst: Path, + nbits: int, + mode: str = "kmeans", + granularity: str = "per_tensor", + group_size: int | None = None, +) -> None: + """Palettize weights via Core ML OpPalettizerConfig. + + Useful combos: + • kmeans / per_tensor / nbits=4 — default lossy approximation + • unique / per_grouped_channel / nbits=2 / group_size=128 + bit-exact for Bonsai-style ternary: each 128-group has only 3 + distinct values so nbits=2 palette is lossless. No quality drop. + """ + label = f"{mode}-{granularity}" + if group_size is not None: + label += f"-g{group_size}" + print(f"\n--- palettize INT{nbits} ({label}): {src.name} → {dst.name} ---") + m = ct.models.MLModel(str(src)) + kwargs: dict = dict(mode=mode, granularity=granularity) + # `unique` mode derives nbits itself from the unique-value count; passing + # nbits is explicitly rejected by OpPalettizerConfig. + if mode != "unique": + kwargs["nbits"] = nbits + if granularity == "per_grouped_channel" and group_size is not None: + kwargs["group_size"] = group_size + op_cfg = OpPalettizerConfig(**kwargs) + opt = OptimizationConfig(global_config=op_cfg) + t0 = time.time() + m = palettize_weights(m, opt) + print(f" palettize in {time.time()-t0:.1f}s") + if dst.exists(): + shutil.rmtree(dst) + m.save(str(dst)) + src_mb = sum(f.stat().st_size for f in src.rglob("*") if f.is_file()) / 1e6 + dst_mb = sum(f.stat().st_size for f in dst.rglob("*") if f.is_file()) / 1e6 + print(f" {src_mb:.0f} MB (fp16) → {dst_mb:.0f} MB (nbits={nbits} {label}) " + f"[{100*dst_mb/src_mb:.1f}%]") + audit_ane(dst) + + +# Back-compat alias for earlier scripts that imported palettize_kmeans. +def palettize_kmeans(src: Path, dst: Path, nbits: int) -> None: + palettize(src, dst, nbits, mode="kmeans", granularity="per_tensor") + + +# ---- Main ----------------------------------------------------------------- + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--model-path", required=True, + help="HF Bonsai model dir (config.json + safetensors)") + ap.add_argument("--output", required=True, help="Output root dir") + ap.add_argument("--context-length", type=int, default=2048) + ap.add_argument("--split-at", type=int, default=14, + help="Split layers before this index (14 → layers 0-13 / 14-27)") + ap.add_argument("--quantize", + choices=["fp16", "int8", "int4", "ternary"], + default="int8", + help="Quantization preset. 'ternary' = nbits=2 + mode=unique + " + "per_grouped_channel + group_size=128, which is bit-exact for " + "Bonsai's {-s, 0, +s} per-128-group weights (see " + "`verify_bonsai_ternary.py`). Lossless vs INT8/INT4 kmeans " + "approximations.") + ap.add_argument("--nbits", type=int, default=None, + help="Override quantize nbits (1,2,3,4,6,8). Takes precedence over --quantize") + ap.add_argument("--palette-mode", default=None, + choices=[None, "kmeans", "uniform", "unique"], + help="Override palettization mode. Default depends on --quantize.") + ap.add_argument("--palette-granularity", default=None, + choices=[None, "per_tensor", "per_grouped_channel"], + help="Override granularity. Default depends on --quantize.") + ap.add_argument("--palette-group-size", type=int, default=None, + help="Group size for per_grouped_channel. Default 128 for ternary.") + ap.add_argument("--keep-fp16", action="store_true", + help="Keep the fp16 intermediates under _fp16_intermediate/ for re-palettizing") + ap.add_argument("--sliding-window", type=int, default=None, + help="Enable SWA decode: state buffer = context-length but per-step " + "attention is over the last W slots selected via gather_idx input. " + "Expected to preserve ~ctx=W speed while allowing ctx=context-length " + "prefill. Host computes gather_idx + windowed causal mask per step. " + "Set to e.g. 1024 while --context-length 4096.") + args = ap.parse_args() + + out_root = Path(args.output).resolve() + out_dir = out_root / "bonsai_1_7b_decode_chunks" + tmp_dir = out_root / "_fp16_intermediate" + out_dir.mkdir(parents=True, exist_ok=True) + tmp_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading Qwen3Model from {args.model_path}") + t0 = time.time() + model = Qwen3Model.from_pretrained(args.model_path, context_length=args.context_length) + model.eval() + cfg = model.config + print(f" loaded in {time.time()-t0:.1f}s: {cfg.num_hidden_layers} layers, " + f"hidden={cfg.hidden_size}, heads={cfg.num_attention_heads}/" + f"{cfg.num_key_value_heads}, head_dim={cfg.head_dim}, " + f"vocab={cfg.vocab_size}, tie_embed={cfg.tie_word_embeddings}") + + assert 0 < args.split_at < cfg.num_hidden_layers, \ + f"split_at must be in (0, {cfg.num_hidden_layers}), got {args.split_at}" + + if args.sliding_window is not None: + assert 0 < args.sliding_window <= args.context_length, \ + f"--sliding-window must be in (0, {args.context_length}], got {args.sliding_window}" + print(f" SWA: state_buffer=W={args.sliding_window} (circular, pos % W slot). " + f"context_length={args.context_length} only bounds RoPE max position.") + chunk_a = ChunkA(model, args.split_at, sliding_window=args.sliding_window).eval() + chunk_b = ChunkB(model, args.split_at, sliding_window=args.sliding_window).eval() + chunk_a_shape = tuple(chunk_a.kv_cache.shape) + chunk_b_shape = tuple(chunk_b.kv_cache.shape) + print(f" chunk_a: embed + layers [0..{args.split_at}) state {chunk_a_shape}") + print(f" chunk_b: layers [{args.split_at}..{cfg.num_hidden_layers}) + head " + f"state {chunk_b_shape}") + + # Free full-model param refs we don't need for tracing + # (chunks hold direct references to the relevant submodules already). + del model + + fp16_a = tmp_dir / "chunk_a.mlpackage" + fp16_b = tmp_dir / "chunk_b.mlpackage" + final_a = out_dir / "chunk_a.mlpackage" + final_b = out_dir / "chunk_b.mlpackage" + + convert_chunk( + chunk_a, args.context_length, cfg.hidden_size, chunk_a_shape, + fp16_a, is_head=True, sliding_window=args.sliding_window, + ) + audit_ane(fp16_a) + del chunk_a + + convert_chunk( + chunk_b, args.context_length, cfg.hidden_size, chunk_b_shape, + fp16_b, is_head=False, sliding_window=args.sliding_window, + ) + audit_ane(fp16_b) + del chunk_b + + if args.quantize == "fp16" and args.nbits is None: + if final_a.exists(): + shutil.rmtree(final_a) + if final_b.exists(): + shutil.rmtree(final_b) + shutil.copytree(fp16_a, final_a) + shutil.copytree(fp16_b, final_b) + else: + # Defaults per --quantize preset + preset_nbits = {"int8": 8, "int4": 4, "ternary": 2} + preset_mode = {"int8": "kmeans", "int4": "kmeans", "ternary": "unique"} + preset_granularity = { + "int8": "per_tensor", + "int4": "per_tensor", + "ternary": "per_grouped_channel", + } + preset_group_size = {"ternary": 128} + + nbits = args.nbits if args.nbits is not None else preset_nbits[args.quantize] + mode = args.palette_mode if args.palette_mode else preset_mode[args.quantize] + granularity = (args.palette_granularity + if args.palette_granularity + else preset_granularity[args.quantize]) + group_size = (args.palette_group_size + if args.palette_group_size is not None + else preset_group_size.get(args.quantize)) + + palettize(fp16_a, final_a, nbits, mode=mode, + granularity=granularity, group_size=group_size) + palettize(fp16_b, final_b, nbits, mode=mode, + granularity=granularity, group_size=group_size) + + # Manifest for Swift + manifest = { + "architecture": "qwen3", + "model": "ternary-bonsai-1.7b", + "split_at": args.split_at, + "context_length": args.context_length, + "num_hidden_layers": cfg.num_hidden_layers, + "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, + "head_dim": cfg.head_dim, + "hidden_size": cfg.hidden_size, + "vocab_size": cfg.vocab_size, + "rms_norm_eps": cfg.rms_norm_eps, + "rope_theta": cfg.rope_theta, + "tie_word_embeddings": cfg.tie_word_embeddings, + "bos_token_id": cfg.bos_token_id, + "eos_token_id": cfg.eos_token_id, + "quantization": args.quantize if args.nbits is None else f"int{args.nbits}", + "palette_mode": args.palette_mode or ( + {"int8": "kmeans", "int4": "kmeans", "ternary": "unique"}.get(args.quantize) + ), + "palette_granularity": args.palette_granularity or ( + {"int8": "per_tensor", "int4": "per_tensor", + "ternary": "per_grouped_channel"}.get(args.quantize) + ), + "palette_group_size": (args.palette_group_size if args.palette_group_size + else (128 if args.quantize == "ternary" else None)), + "sliding_window": args.sliding_window, + "parts": { + "chunk_a": "chunk_a.mlpackage", + "chunk_b": "chunk_b.mlpackage", + }, + } + with open(out_dir / "model_config.json", "w") as f: + json.dump(manifest, f, indent=2) + + if not args.keep_fp16: + shutil.rmtree(tmp_dir, ignore_errors=True) + + print(f"\n✓ shipping artifacts under {out_dir}") + for p in sorted(out_dir.iterdir()): + size = ( + sum(f.stat().st_size for f in p.rglob("*") if f.is_file()) / 1e6 + if p.is_dir() else p.stat().st_size / 1e6 + ) + print(f" {p.name}: {size:.0f} MB") + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/compare_swa_long_range.py b/conversion/experiments/bonsai/compare_swa_long_range.py new file mode 100644 index 0000000..425110f --- /dev/null +++ b/conversion/experiments/bonsai/compare_swa_long_range.py @@ -0,0 +1,187 @@ +"""Long-range dependency test: does SWA forget setup info >W tokens back? + +Prompt structure: + : a short memorable fact + : N tokens of neutral text padding the context to push SETUP out of window + : re-mentions SETUP, asking the model to complete it + +If Full retains SETUP (within its ctx-sized attention), its next token after TRIGGER +should match the setup. SWA with W=1024 may have dropped SETUP if (prompt_len + +generated_so_far) - setup_pos > W, and would then continue with something else. + +We measure: + - top-1 agreement between FULL and SWA for the first N generated tokens + - first divergence + - whether Full's continuation matches the setup substring (qualitative) +""" +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +import numpy as np +import coremltools as ct +from transformers import AutoTokenizer + + +def load_chunks(bundle: Path, cu: ct.ComputeUnit): + cfg = json.load(open(bundle / "bonsai_1_7b_decode_chunks/model_config.json")) + chunk_a = ct.models.MLModel(str(bundle / "bonsai_1_7b_decode_chunks/chunk_a.mlpackage"), + compute_units=cu) + chunk_b = ct.models.MLModel(str(bundle / "bonsai_1_7b_decode_chunks/chunk_b.mlpackage"), + compute_units=cu) + return chunk_a, chunk_b, cfg + + +def build_feeds(cfg, tok_id: int, pos: int) -> dict: + swa = cfg.get("sliding_window") + ctx = cfg["context_length"] + pos_arr = np.array([pos], dtype=np.int32) + if swa is None: + L = ctx + write_slot = pos + valid_range = range(pos + 1) + else: + L = swa + write_slot = pos % swa + valid_count = min(pos + 1, swa) + valid_range = [((pos - i) % swa) for i in range(valid_count)] + causal = np.full((1, 1, 1, L), -1e4, dtype=np.float16) + for s in valid_range: + causal[0, 0, 0, s] = 0.0 + update = np.zeros((1, 1, L, 1), dtype=np.float16) + update[0, 0, write_slot, 0] = 1.0 + return { + "input_ids": np.array([[tok_id]], dtype=np.int32), + "position_ids": pos_arr, + "causal_mask": causal, + "update_mask": update, + } + + +def step(chunk_a, chunk_b, cfg, state_a, state_b, tok_id: int, pos: int) -> int: + feed = build_feeds(cfg, tok_id, pos) + out_a = chunk_a.predict(feed, state=state_a) + hidden = out_a["hidden"].astype(np.float16) + feed_b = {**feed, "hidden_in": hidden} + del feed_b["input_ids"] + out_b = chunk_b.predict(feed_b, state=state_b) + return int(out_b["token_id"].item()) + + +def run_bundle(bundle: Path, tok, prompt_ids: list[int], max_new: int, cu): + chunk_a, chunk_b, cfg = load_chunks(bundle, cu) + sa = chunk_a.make_state() + sb = chunk_b.make_state() + + nxt = 0 + for i, tid in enumerate(prompt_ids): + nxt = step(chunk_a, chunk_b, cfg, sa, sb, tid, i) + gen: list[int] = [] + cur = len(prompt_ids) - 1 + for _ in range(max_new): + cur += 1 + tok_in = gen[-1] if gen else nxt + nxt = step(chunk_a, chunk_b, cfg, sa, sb, tok_in, cur) + gen.append(nxt) + return gen, cfg + + +# Filler text that doesn't reference the setup. Pure lorem-ipsum-like continuation +# style. Generated from a typical base-model-friendly topic. +FILLER_PARA = ( + "The river flows through the valley, carving paths into the rocks over time. " + "Birds sing in the trees, their melodies carried by the wind. Farmers tend to " + "their fields, growing crops that feed the village. In the evening, lanterns " + "are lit, casting warm light on the cobblestone streets. Children play by the " + "fountain, laughing and running. The baker opens early every morning, the smell " + "of fresh bread drifting down the lane. Merchants arrive from distant lands, " + "bringing goods and stories from faraway places. The old clock tower rings at " + "noon, marking the time of day. Travelers stop at the inn to rest. " +) + + +def build_test_prompt(tok, setup: str, trigger: str, min_filler_tokens: int = 1100): + """Build: , return input_ids.""" + filler = "" + while len(tok(filler, return_tensors="np").input_ids[0]) < min_filler_tokens: + filler += FILLER_PARA + prompt = setup + " " + filler + " " + trigger + return tok(prompt, return_tensors="np").input_ids[0].tolist() + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--full-bundle", required=True) + ap.add_argument("--swa-bundle", required=True) + ap.add_argument("--tokenizer", required=True) + ap.add_argument("--compute-units", default="CPU_AND_NE") + ap.add_argument("--filler-tokens", type=int, default=1100, + help="Push setup past W=1024 boundary") + ap.add_argument("--max-new-tokens", type=int, default=30) + args = ap.parse_args() + + cu = getattr(ct.ComputeUnit, args.compute_units) + tok = AutoTokenizer.from_pretrained(args.tokenizer) + + # Stable, memorable setup that's easy to verify in continuation. + setup = "My favorite color is chartreuse, which is a vibrant yellow-green." + trigger = "To remind you, my favorite color is" + + ids = build_test_prompt(tok, setup, trigger, args.filler_tokens) + print(f"prompt shape: {len(ids)} tokens " + f"(~{args.filler_tokens} filler + setup + trigger)") + setup_ids = tok(setup, return_tensors="np").input_ids[0].tolist() + trigger_ids = tok(" " + trigger, return_tensors="np").input_ids[0].tolist() + trigger_starts_at = len(ids) - len(trigger_ids) + dist_setup_to_end = len(ids) - len(setup_ids) + print(f" setup ends around token {len(setup_ids)}, trigger begins around " + f"token {trigger_starts_at}") + print(f" distance from setup to end of prompt ≈ {dist_setup_to_end} tokens " + f"(> W=1024 means SWA should have forgotten setup)") + + print("\n=== FULL (ctx=4096 full-attn) ===") + t0 = time.time() + full_gen, full_cfg = run_bundle(Path(args.full_bundle), tok, ids, + args.max_new_tokens, cu) + print(f" done in {time.time()-t0:.1f}s") + full_text = tok.decode(full_gen) + print(f" continuation: {full_text!r}") + + print("\n=== SWA (W=1024) ===") + t0 = time.time() + swa_gen, swa_cfg = run_bundle(Path(args.swa_bundle), tok, ids, + args.max_new_tokens, cu) + print(f" done in {time.time()-t0:.1f}s") + swa_text = tok.decode(swa_gen) + print(f" continuation: {swa_text!r}") + + # Scoring: does continuation contain "chartreuse"? + full_recall = "chartreuse" in full_text.lower() + swa_recall = "chartreuse" in swa_text.lower() + + agree = sum(1 for a, b in zip(full_gen, swa_gen) if a == b) + first_div = next((i for i, (a, b) in enumerate(zip(full_gen, swa_gen)) if a != b), + None) + + print(f"\n=== comparison ===") + print(f" top-1 agreement: {agree}/{len(full_gen)} " + f"({100*agree/len(full_gen):.1f}%)") + print(f" first divergence at gen index: " + f"{first_div if first_div is not None else 'NONE'}") + print(f" Full recalls 'chartreuse': {full_recall}") + print(f" SWA recalls 'chartreuse': {swa_recall}") + if full_recall and not swa_recall: + print(f" → long-range recall regression confirmed for this prompt") + elif full_recall and swa_recall: + print(f" → both recall: either setup is in window (filler too short) " + f"or model is robust via other cues") + elif not full_recall: + print(f" → full model didn't recall either; prompt may be too hard " + f"even for full attention") + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/compare_swa_vs_full.py b/conversion/experiments/bonsai/compare_swa_vs_full.py new file mode 100644 index 0000000..689d97f --- /dev/null +++ b/conversion/experiments/bonsai/compare_swa_vs_full.py @@ -0,0 +1,163 @@ +"""Side-by-side divergence test: SWA W=1024 vs full-attention ctx=4096. + +Runs greedy decode on both bundles with the same prompt and prints: +- first divergence position (where top-1 tokens differ) +- per-step top-1 agreement rate +- final decoded outputs for eyeball comparison + +Expected: +- positions 0..W-1: 100% agreement (no wraparound yet) +- position >= W: SWA forgets tokens < (pos - W + 1); divergence accumulates +""" +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +import numpy as np +import coremltools as ct +from transformers import AutoTokenizer + + +def load_chunks(bundle: Path, cu: ct.ComputeUnit): + cfg = json.load(open(bundle / "bonsai_1_7b_decode_chunks/model_config.json")) + chunk_a = ct.models.MLModel(str(bundle / "bonsai_1_7b_decode_chunks/chunk_a.mlpackage"), + compute_units=cu) + chunk_b = ct.models.MLModel(str(bundle / "bonsai_1_7b_decode_chunks/chunk_b.mlpackage"), + compute_units=cu) + return chunk_a, chunk_b, cfg + + +def build_feeds(cfg: dict, tok_id: int, pos: int) -> dict: + swa = cfg.get("sliding_window") + ctx = cfg["context_length"] + pos_arr = np.array([pos], dtype=np.int32) + if swa is None: + L = ctx + write_slot = pos + valid_range = range(pos + 1) + else: + L = swa + write_slot = pos % swa + valid_count = min(pos + 1, swa) + valid_range = [((pos - i) % swa) for i in range(valid_count)] + causal = np.full((1, 1, 1, L), -1e4, dtype=np.float16) + for s in valid_range: + causal[0, 0, 0, s] = 0.0 + update = np.zeros((1, 1, L, 1), dtype=np.float16) + update[0, 0, write_slot, 0] = 1.0 + return { + "input_ids": np.array([[tok_id]], dtype=np.int32), + "position_ids": pos_arr, + "causal_mask": causal, + "update_mask": update, + } + + +def step(chunk_a, chunk_b, cfg, state_a, state_b, tok_id: int, pos: int) -> int: + feed = build_feeds(cfg, tok_id, pos) + out_a = chunk_a.predict(feed, state=state_a) + hidden = out_a["hidden"].astype(np.float16) + feed_b = {**feed, "hidden_in": hidden} + del feed_b["input_ids"] + out_b = chunk_b.predict(feed_b, state=state_b) + return int(out_b["token_id"].item()) + + +def run_bundle(bundle: Path, tok, prompt: str, max_new: int, cu): + chunk_a, chunk_b, cfg = load_chunks(bundle, cu) + sa = chunk_a.make_state() + sb = chunk_b.make_state() + ids = tok(prompt, return_tensors="np").input_ids[0].tolist() + + nxt = 0 + for i, tid in enumerate(ids): + nxt = step(chunk_a, chunk_b, cfg, sa, sb, tid, i) + # After prefill, next model position = len(ids); feed the next token predicted from + # the last prompt step (nxt) as input for that position, and so on greedily. + gen: list[int] = [] + cur = len(ids) - 1 + for _ in range(max_new): + cur += 1 + tok_in = gen[-1] if gen else nxt + nxt = step(chunk_a, chunk_b, cfg, sa, sb, tok_in, cur) + gen.append(nxt) + return ids, gen, cfg + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--full-bundle", required=True, help="non-SWA full-attention bundle dir") + ap.add_argument("--swa-bundle", required=True, help="SWA bundle dir") + ap.add_argument("--tokenizer", required=True) + ap.add_argument("--prompt", default="Once upon a time there was a small village in the mountains where") + ap.add_argument("--max-new-tokens", type=int, default=1100) + ap.add_argument("--compute-units", default="CPU_AND_NE") + args = ap.parse_args() + + cu = getattr(ct.ComputeUnit, args.compute_units) + tok = AutoTokenizer.from_pretrained(args.tokenizer) + + print(f"prompt: {args.prompt!r}") + + print("\n=== running FULL (non-SWA) ===") + t0 = time.time() + full_prompt, full_gen, full_cfg = run_bundle(Path(args.full_bundle), tok, + args.prompt, args.max_new_tokens, cu) + print(f" done in {time.time()-t0:.1f}s, ctx={full_cfg['context_length']}, " + f"swa={full_cfg.get('sliding_window')}") + + print("\n=== running SWA ===") + t0 = time.time() + swa_prompt, swa_gen, swa_cfg = run_bundle(Path(args.swa_bundle), tok, + args.prompt, args.max_new_tokens, cu) + print(f" done in {time.time()-t0:.1f}s, ctx={swa_cfg['context_length']}, " + f"swa={swa_cfg.get('sliding_window')}") + + W = swa_cfg.get("sliding_window", 0) + + # Find first divergence + first_div = None + agree = 0 + for i, (a, b) in enumerate(zip(full_gen, swa_gen)): + if a == b: + agree += 1 + elif first_div is None: + first_div = i + + print(f"\n=== comparison ===") + print(f" W={W}, prompt_tokens={len(full_prompt)}, gen_tokens={len(full_gen)}") + print(f" first divergence at gen index: " + f"{first_div if first_div is not None else 'NONE (bit-identical)'}") + if first_div is not None: + gen_pos_at_div = len(full_prompt) + first_div + print(f" → that's model position {gen_pos_at_div} (W={W}; " + f"tokens seen in non-SWA but not in SWA at this step: " + f"{max(0, gen_pos_at_div - W + 1)})") + print(f" total agreement: {agree}/{len(full_gen)} " + f"({100*agree/len(full_gen):.1f}%)") + + # Agreement buckets + buckets = [(0, W // 2), (W // 2, W), (W, W + W // 2), (W + W // 2, 2 * W)] + print(f"\n top-1 agreement by position bucket:") + for lo, hi in buckets: + # bucket is model position; gen index = pos - len(prompt) + lo_i = max(0, lo - len(full_prompt)) + hi_i = min(len(full_gen), hi - len(full_prompt)) + if lo_i >= hi_i: + continue + b_agree = sum(1 for a, b in zip(full_gen[lo_i:hi_i], swa_gen[lo_i:hi_i]) if a == b) + b_total = hi_i - lo_i + print(f" pos [{lo:>5}, {hi:>5}): {b_agree}/{b_total} " + f"({100*b_agree/b_total:.1f}%)") + + print(f"\n=== FULL decoded (first 300 chars) ===") + print(tok.decode(full_gen)[:300]) + print(f"\n=== SWA decoded (first 300 chars) ===") + print(tok.decode(swa_gen)[:300]) + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/ternary_surgery.py b/conversion/experiments/bonsai/ternary_surgery.py new file mode 100644 index 0000000..5c49dd5 --- /dev/null +++ b/conversion/experiments/bonsai/ternary_surgery.py @@ -0,0 +1,302 @@ +"""Bit-exact ternary weight surgery for Bonsai Core ML mlpackages. + +Replaces each FP16 weight const in a Core ML mlpackage with a two-op compressed +chain that exactly reproduces the {-s_{r,b}, 0, +s_{r,b}} per-128-block structure +of Bonsai's native 1.58-bit encoding: + + weight[r, b*128+k] + = scale[r, b] * sign_codebook[ indices[r, b*128+k] ] + = constexpr_blockwise_shift_scale( + data = constexpr_lut_to_dense(indices=uint2(...), lut=[0,+1,-1,0]), + scale = fp16 per-(row,block) scale + ) + +This keeps the sign LUT small and shared (1,1,..,4,1) — ANE-friendly — and +factors per-row scale into a separate blockwise op. `reorder_lut_per_channel_scale` +(the coremltools pass that moves scale post-matmul) will apply at compile time if +the downstream is a linear/matmul/conv, letting the ANE run the quantized matmul. + +Usage: + python ternary_surgery.py --src --dst + python ternary_surgery.py --src /chunk_a.mlpackage \\ + --dst /chunk_a.mlpackage \\ + --block-size 128 +""" + +from __future__ import annotations + +import argparse +import shutil +import time +from pathlib import Path + +import numpy as np + +import coremltools as ct +from coremltools.converters.mil.mil import Builder as mb +from coremltools.converters.mil.mil import types +from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass +from coremltools.converters.mil.mil.passes.helper import block_context_manager +from coremltools.models.utils import _apply_graph_pass + + +SIGN_LUT_VALUES = np.array([0.0, 1.0, -1.0, 0.0], dtype=np.float16) + + +def encode_ternary(w: np.ndarray, block_size: int = 128): + """Encode a 2D or 4D-with-trailing-1s weight as (indices uint8, scale fp16). + + Returns: + indices: uint8 array same shape as `w`, values in {0,1,2,3}: + 0 → zero, 1 → +s, 2 → -s, 3 → unused. + scale: fp16 array shape (out, in/block_size) for 2D weights, or + (out, in/block_size, 1, 1) for (out, in, 1, 1) 4D weights. + Both in same numeric type as the data path uses. + """ + orig_shape = w.shape + if w.ndim == 4: + assert orig_shape[-1] == 1 and orig_shape[-2] == 1, ( + f"only trailing-1 4D weights supported, got {orig_shape}" + ) + w2 = w.reshape(orig_shape[0], orig_shape[1]) + elif w.ndim == 2: + w2 = w + else: + raise ValueError(f"unsupported weight rank {w.ndim}: shape {orig_shape}") + + out_dim, in_dim = w2.shape + if in_dim % block_size != 0: + raise ValueError( + f"in_dim {in_dim} not divisible by block_size {block_size}" + ) + num_blocks = in_dim // block_size + + w_blocks = w2.reshape(out_dim, num_blocks, block_size).astype(np.float32) + absval = np.abs(w_blocks) + scale_2d = absval.max(axis=-1) # (out, num_blocks) + + # Indices via sign + magnitude > 0.5 * scale + safe_scale = np.where(scale_2d == 0.0, 1.0, scale_2d).reshape(out_dim, num_blocks, 1) + normalized = w_blocks / safe_scale # ∈ [-1, 1] approximately, with values {-1, 0, +1} + # Threshold loosely to absorb any fp-noise. + indices_b = np.where( + normalized > 0.5, 1, + np.where(normalized < -0.5, 2, 0), + ).astype(np.uint8) # (out, num_blocks, block_size) + indices = indices_b.reshape(out_dim, in_dim) + + if w.ndim == 4: + indices = indices.reshape(orig_shape) + scale = scale_2d.reshape(orig_shape[0], num_blocks, 1, 1).astype(np.float16) + else: + scale = scale_2d.astype(np.float16) + + # Sanity: reconstruction round-trip + recon = SIGN_LUT_VALUES[indices_b].astype(np.float32) # (out, nb, bs) + recon = recon * safe_scale + recon = recon.reshape(out_dim, in_dim) + diff = np.abs(recon - w2.astype(np.float32)) + max_diff = float(diff.max()) + + return indices, scale, max_diff + + +def _is_target_weight(op, block_size: int = 128, min_numel: int = 1024) -> bool: + """Is this `const` op a Bonsai weight tensor we can ternarize?""" + if op.op_type != "const": + return False + # Access the materialized value via the op's output Var. + arr = op.outputs[0].val + if arr is None or not isinstance(arr, np.ndarray): + return False + if arr.dtype not in (np.float16, np.float32): + return False + if arr.ndim not in (2, 4): + return False + if arr.ndim == 4 and not (arr.shape[-1] == 1 and arr.shape[-2] == 1): + return False + # For a 2D or (out, in, 1, 1) 4D weight, axis=1 is the "in" axis along which + # blocks run — same convention as Bonsai's per-128-block scales. + in_dim = arr.shape[1] if arr.ndim >= 2 else 0 + if in_dim == 0 or in_dim % block_size != 0: + return False + if arr.size < min_numel: + return False + # Ternary probe: for ≥3 sampled rows, each 128-group must have + # exactly {-s, 0, +s} structure (3 unique values with +s == -(-s)). + flat = arr if arr.ndim == 2 else arr.reshape(arr.shape[0], arr.shape[1]) + out_dim = flat.shape[0] + sample_rows = [0, out_dim // 2, out_dim - 1] + for r in sample_rows: + probe = flat[r, :block_size] + u = np.unique(probe) + # Allow all-zero group (padding) only if we have a positive example elsewhere + nz = u[u != 0] + if len(u) > 3: + return False + if len(nz) == 2: + # Must be opposite-signed pair (ternary structure) + if not np.isclose(nz[0], -nz[1], rtol=1e-3): + return False + elif len(nz) == 1: + # Single non-zero means the whole block is {0, v} — still ternary-compatible + pass + elif len(nz) == 0: + # All zeros — could be padding row of a real ternary weight or a non-weight table + continue + else: + return False + # At least one of the sampled rows must have a non-zero structure; otherwise + # this is likely a trivial tensor (embed padding etc.) or a trig table. + any_nonzero = False + for r in sample_rows: + if np.any(flat[r] != 0): + # And a stronger check: the full row should be buildable from at most + # 16 unique values (matches Bonsai per-row ≤ 16 block scales × 3). + row_u = np.unique(flat[r]) + if len(row_u) <= 64: # generous; ternary rows typically ~33-49 unique + any_nonzero = True + break + return any_nonzero + + +def _make_uint2_indices(indices_uint8: np.ndarray): + """Convert uint8 {0,1,2,3} indices to the coremltools uint2 numpy dtype.""" + uint2_dt = types.nptype_from_builtin(types.string_to_builtin("uint2")) + return indices_uint8.astype(uint2_dt) + + +class TernaryPalettizePass(AbstractGraphPass): + """Pass that replaces Bonsai weight consts with bit-exact ternary constexpr chains.""" + + def __init__(self, block_size: int = 128, verbose: bool = True): + self.block_size = block_size + self.verbose = verbose + self.replaced = 0 + self.skipped = 0 + self.max_max_diff = 0.0 + self.bytes_before = 0 + self.bytes_after = 0 + + def apply(self, prog): + # `mb.()` requires a live block context; the decorator pushes each block + # onto the Builder's stack so our in-place ops insert correctly. + pass_self = self + + @block_context_manager + def _visit_block(block): + for op in list(block.operations): + for nested in op.blocks: + _visit_block(nested) + try: + if not _is_target_weight(op, pass_self.block_size): + continue + pass_self._replace_one(op) + except Exception as e: + pass_self.skipped += 1 + if pass_self.verbose: + print(f" skip {op.name}: {type(e).__name__}: {e}") + + for func in prog.functions.values(): + _visit_block(func) + + if self.verbose: + size_mb_before = self.bytes_before / 1e6 + size_mb_after = self.bytes_after / 1e6 + print(f"\nternary pass summary: replaced {self.replaced}, skipped {self.skipped}") + print(f" max reconstruction |diff|: {self.max_max_diff:.6f}") + print(f" weight bytes {size_mb_before:.0f} MB → {size_mb_after:.0f} MB " + f"({100 * size_mb_after / max(size_mb_before, 1):.1f}%)") + + def _replace_one(self, op): + w = op.outputs[0].val + self.bytes_before += w.nbytes + + indices, scale, max_diff = encode_ternary(w, self.block_size) + if max_diff > 1e-3 and self.verbose: + print(f" {op.name}: shape={w.shape} max_diff={max_diff:.6f} " + f"w.max={np.abs(w).max():.6f}") + self.max_max_diff = max(self.max_max_diff, max_diff) + + indices_u2 = _make_uint2_indices(indices) + + # Build a per-row-per-block LUT with the scale baked in, so each + # (row, block) has its own 4-entry codebook = [0, +s, -s, 0]. This is + # a single-op replacement (no constexpr_blockwise_shift_scale), which + # avoids ANE compile rejection we hit with the 2-op chain. + # + # scale shape: (out, num_blocks) for 2D, (out, num_blocks, 1, 1) for 4D. + # LUT rank = indices_rank + 2, with per-(row,block) palette: + # 2D indices (out, in) → lut (out, num_blocks, 4, 1) + # 4D indices (out, in, 1, 1) → lut (out, num_blocks, 1, 1, 4, 1) + if w.ndim == 2: + s = scale.astype(np.float16) # (out, num_blocks) + out_dim, num_blocks = s.shape + lut = np.zeros((out_dim, num_blocks, 4, 1), dtype=np.float16) + lut[..., 1, 0] = s # +s + lut[..., 2, 0] = -s # -s + # entries 0 and 3 stay 0.0 + else: # 4D (out, in, 1, 1) + s2d = scale.reshape(scale.shape[0], scale.shape[1]).astype(np.float16) + out_dim, num_blocks = s2d.shape + lut = np.zeros((out_dim, num_blocks, 1, 1, 4, 1), dtype=np.float16) + lut[..., 1, 0] = s2d.reshape(out_dim, num_blocks, 1, 1) + lut[..., 2, 0] = -s2d.reshape(out_dim, num_blocks, 1, 1) + + new_var = mb.constexpr_lut_to_dense( + indices=indices_u2, lut=lut, + before_op=op, name=op.name + "_tern", + ) + + op.enclosing_block.replace_uses_of_var_after_op( + anchor_op=op, old_var=op.outputs[0], new_var=new_var, + no_check_var_types=True, force_replace=True, + ) + op.enclosing_block.remove_ops([op]) + + self.replaced += 1 + # bytes_after: uint2 indices + per-(row,block) fp16 LUT (4 entries of which + # 2 are meaningful; still compact compared to fp16 weights). + self.bytes_after += (indices.size * 2 + 7) // 8 # uint2 + self.bytes_after += lut.nbytes + + +def run(src: Path, dst: Path, block_size: int = 128) -> None: + print(f"Loading {src}") + m = ct.models.MLModel(str(src)) + + pass_inst = TernaryPalettizePass(block_size=block_size, verbose=True) + print("Running ternary MIL surgery...") + t0 = time.time() + out_model = _apply_graph_pass( + m, pass_inst, + skip_model_load=True, # avoid forcing compile before save + ) + print(f" pass applied in {time.time() - t0:.1f}s") + + if dst.exists() and dst != src: + shutil.rmtree(dst) + elif dst == src: + tmp = dst.with_suffix(".mlpackage.new") + if tmp.exists(): + shutil.rmtree(tmp) + out_model.save(str(tmp)) + shutil.rmtree(dst) + shutil.move(str(tmp), str(dst)) + return + out_model.save(str(dst)) + size_mb = sum(f.stat().st_size for f in dst.rglob("*") if f.is_file()) / 1e6 + print(f"Saved {dst.name} ({size_mb:.0f} MB)") + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--src", required=True) + ap.add_argument("--dst", required=True) + ap.add_argument("--block-size", type=int, default=128) + args = ap.parse_args() + run(Path(args.src), Path(args.dst), args.block_size) + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/test_bonsai_chunks_inference.py b/conversion/experiments/bonsai/test_bonsai_chunks_inference.py new file mode 100644 index 0000000..1799e1b --- /dev/null +++ b/conversion/experiments/bonsai/test_bonsai_chunks_inference.py @@ -0,0 +1,144 @@ +"""Chained decode test for the 2-chunk Bonsai build. + +Measures tok/s end-to-end on Mac ANE (and by proxy the iPhone ceiling). +Runs per-token: chunk_a(input_id, pos, ...) → hidden → chunk_b(hidden, pos, ...) → token. +Prefills the prompt then greedily decodes N tokens. +""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +import numpy as np +import coremltools as ct +from transformers import AutoTokenizer + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--bundle", required=True, + help="Dir containing bonsai_1_7b_decode_chunks/{chunk_a,chunk_b}.mlpackage") + ap.add_argument("--tokenizer", required=True, help="HF model dir for tokenizer") + ap.add_argument("--prompt", default="The capital of France is") + ap.add_argument("--max-new-tokens", type=int, default=30) + ap.add_argument("--compute-units", default="CPU_AND_NE", + choices=["CPU_ONLY", "CPU_AND_NE", "CPU_AND_GPU", "ALL"]) + args = ap.parse_args() + + bundle = Path(args.bundle).expanduser() / "bonsai_1_7b_decode_chunks" + cfg = json.load(open(bundle / "model_config.json")) + ctx = cfg["context_length"] + swa = cfg.get("sliding_window") # None = full attention; int = SWA window size + mask_len = swa if swa is not None else ctx + print(f" ctx={ctx}, sliding_window={swa}, attn_len={mask_len}") + + cu = getattr(ct.ComputeUnit, args.compute_units) + print(f"Loading chunks from {bundle}, compute_units={args.compute_units}") + t0 = time.time() + chunk_a = ct.models.MLModel(str(bundle / "chunk_a.mlpackage"), compute_units=cu) + t_a = time.time() - t0 + t0 = time.time() + chunk_b = ct.models.MLModel(str(bundle / "chunk_b.mlpackage"), compute_units=cu) + t_b = time.time() - t0 + print(f" chunk_a loaded in {t_a:.1f}s, chunk_b in {t_b:.1f}s") + + tok = AutoTokenizer.from_pretrained(args.tokenizer) + prompt_ids = tok(args.prompt, return_tensors="np").input_ids[0].tolist() + print(f"prompt: {args.prompt!r}") + print(f" tokens ({len(prompt_ids)}): {prompt_ids}") + + state_a = chunk_a.make_state() + state_b = chunk_b.make_state() + + def step(tok_id: int, pos: int) -> tuple[int, float]: + pos_arr = np.array([pos], dtype=np.int32) + + if swa is None: + # Full-attention: state buffer = ctx. Write at absolute `pos`. + # Valid positions = [0, pos]. + L = ctx + write_slot = pos + valid_range = range(pos + 1) + else: + # SWA: state buffer = W. Write at `pos % W` (circular slot). After + # the first W steps, every slot holds a valid (position-encoded) K/V + # so attention attends to the whole buffer — ordering doesn't + # matter because softmax is permutation-invariant and RoPE is + # already baked into cached K at write time. + L = swa + write_slot = pos % swa + # Early prefill: only slots that have been written are valid. + # After pos >= W-1, all W slots are valid. + valid_count = min(pos + 1, swa) + # Valid slots are the `valid_count` most-recently-written positions, + # i.e. all slots whose last write was at position > pos - valid_count. + # For simplicity, during warm-up we mark all W slots valid (including + # the freshly-overwritten zeros) once we have written >= 1 token; + # this is equivalent to StreamingLLM without sinks and matches the + # speed characteristics. For strict correctness during warm-up, use + # only `valid_count` slots: + valid_range = [((pos - i) % swa) for i in range(valid_count)] + + causal = np.full((1, 1, 1, L), -1e4, dtype=np.float16) + for s in valid_range: + causal[0, 0, 0, s] = 0.0 + update = np.zeros((1, 1, L, 1), dtype=np.float16) + update[0, 0, write_slot, 0] = 1.0 + + feed_a = { + "input_ids": np.array([[tok_id]], dtype=np.int32), + "position_ids": pos_arr, + "causal_mask": causal, + "update_mask": update, + } + out_a = chunk_a.predict(feed_a, state=state_a) + hidden = out_a["hidden"] + + feed_b = { + "hidden_in": hidden.astype(np.float16), + "position_ids": pos_arr, + "causal_mask": causal, + "update_mask": update, + } + out_b = chunk_b.predict(feed_b, state=state_b) + return int(out_b["token_id"].item()), float(out_b["token_logit"].item()) + + # Prefill: step through prompt, recording per-step time + prefill_times: list[float] = [] + for pos, tid in enumerate(prompt_ids): + t1 = time.time() + next_id, next_logit = step(tid, pos) + prefill_times.append(time.time() - t1) + + print(f" prefill {len(prompt_ids)} steps, avg {np.mean(prefill_times)*1000:.1f} ms") + print(f" first gen token: {next_id} ({tok.decode([next_id])!r}), " + f"logit={next_logit:.3f}") + + generated = [next_id] + decode_times: list[float] = [] + cur = len(prompt_ids) - 1 + for _ in range(args.max_new_tokens - 1): + cur += 1 + t1 = time.time() + next_id, next_logit = step(generated[-1], cur) + decode_times.append(time.time() - t1) + generated.append(next_id) + + print(f"\ngenerated {len(generated)} tokens: {generated[:20]}...") + print(f"decoded: {tok.decode(generated)!r}") + print(f"\noverall:") + print(f" prefill: {np.mean(prefill_times)*1000:.1f} ms/tok " + f"({1/np.mean(prefill_times):.1f} tok/s)") + if decode_times: + print(f" decode: {np.mean(decode_times)*1000:.1f} ms/tok " + f"({1/np.mean(decode_times):.1f} tok/s)") + p50 = float(np.median(decode_times)) + p95 = float(np.percentile(decode_times, 95)) + print(f" decode p50/p95: {p50*1000:.1f} / {p95*1000:.1f} ms/tok") + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/test_bonsai_inference.py b/conversion/experiments/bonsai/test_bonsai_inference.py new file mode 100644 index 0000000..ff1a0ee --- /dev/null +++ b/conversion/experiments/bonsai/test_bonsai_inference.py @@ -0,0 +1,103 @@ +"""Quick smoke test of the converted Bonsai model. + +Loads the saved .mlpackage, runs a single decode step, and prints the predicted token. +Validates: model loads, predict() works, output shape / dtype correct, first-token +prediction is sane. +""" + +from __future__ import annotations + +import argparse +import json +import time +from pathlib import Path + +import numpy as np +import coremltools as ct +from transformers import AutoTokenizer + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--bundle", required=True, help="Dir with model.mlpackage + model_config.json") + ap.add_argument("--tokenizer", required=True, help="HF model dir for tokenizer") + ap.add_argument("--prompt", default="The capital of France is") + ap.add_argument("--max-new-tokens", type=int, default=10) + ap.add_argument("--compute-units", default="CPU_AND_NE", + choices=["CPU_ONLY", "CPU_AND_NE", "CPU_AND_GPU", "ALL"]) + args = ap.parse_args() + + bundle = Path(args.bundle).expanduser() + pkg_path = bundle / "model.mlpackage" + cfg_path = bundle / "model_config.json" + with open(cfg_path) as f: + cfg = json.load(f) + + ctx = cfg["context_length"] + print(f"Loading {pkg_path}") + cu = getattr(ct.ComputeUnit, args.compute_units) + t0 = time.time() + model = ct.models.MLModel(str(pkg_path), compute_units=cu) + print(f" loaded in {time.time()-t0:.1f}s, compute_units={args.compute_units}") + + tok = AutoTokenizer.from_pretrained(args.tokenizer) + input_ids = tok(args.prompt, return_tensors="np").input_ids[0].tolist() + print(f"prompt: {args.prompt!r}") + print(f" tokens ({len(input_ids)}): {input_ids}") + + # Per-step decode (no batched prefill for this smoke test). + state = model.make_state() + generated: list[int] = [] + cur_pos = 0 + step_times = [] + + for pos, tok_id in enumerate(input_ids): + causal_mask = np.full((1, 1, 1, ctx), -1e4, dtype=np.float16) + causal_mask[0, 0, 0, : pos + 1] = 0.0 + update_mask = np.zeros((1, 1, ctx, 1), dtype=np.float16) + update_mask[0, 0, pos, 0] = 1.0 + feed = { + "input_ids": np.array([[tok_id]], dtype=np.int32), + "position_ids": np.array([pos], dtype=np.int32), + "causal_mask": causal_mask, + "update_mask": update_mask, + } + t1 = time.time() + out = model.predict(feed, state=state) + step_times.append(time.time() - t1) + cur_pos = pos + next_id = int(out["token_id"].item()) + next_logit = float(out["token_logit"].item()) + print(f" prefilled {len(input_ids)} tokens, avg {np.mean(step_times)*1000:.1f} ms/step") + print(f" first gen token: {next_id} ({tok.decode([next_id])!r}), logit={next_logit:.3f}") + + generated.append(next_id) + # Continue greedy decode + for i in range(args.max_new_tokens - 1): + cur_pos += 1 + causal_mask = np.full((1, 1, 1, ctx), -1e4, dtype=np.float16) + causal_mask[0, 0, 0, : cur_pos + 1] = 0.0 + update_mask = np.zeros((1, 1, ctx, 1), dtype=np.float16) + update_mask[0, 0, cur_pos, 0] = 1.0 + feed = { + "input_ids": np.array([[generated[-1]]], dtype=np.int32), + "position_ids": np.array([cur_pos], dtype=np.int32), + "causal_mask": causal_mask, + "update_mask": update_mask, + } + t1 = time.time() + out = model.predict(feed, state=state) + step_times.append(time.time() - t1) + next_id = int(out["token_id"].item()) + generated.append(next_id) + + cont = tok.decode(generated) + print(f"\ngenerated {len(generated)} tokens: {generated}") + print(f"decoded: {cont!r}") + total_toks = len(input_ids) + len(generated) + print(f"\noverall: {total_toks} steps, avg {np.mean(step_times)*1000:.1f} ms/step " + f"({1.0 / np.mean(step_times):.1f} tok/s decode throughput)") + + +if __name__ == "__main__": + main() diff --git a/conversion/experiments/bonsai/verify_bonsai_ternary.py b/conversion/experiments/bonsai/verify_bonsai_ternary.py new file mode 100644 index 0000000..b5f863a --- /dev/null +++ b/conversion/experiments/bonsai/verify_bonsai_ternary.py @@ -0,0 +1,123 @@ +"""Verify that Bonsai's FP16 unpacked weights really are ternary per 128-group. + +If yes → nbits=2 + mode="unique" + per_grouped_channel + group_size=128 gives +a bit-exact Core ML palettization. If the groups have float noise around the +3 ternary centroids, we need a custom LUT builder. + +For each sampled weight tensor we report the distribution of unique-values-per-group +and the actual values in a few example groups. +""" +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import safetensors.torch + + +SAMPLE_TENSORS = [ + "model.embed_tokens.weight", + "model.layers.0.self_attn.q_proj.weight", + "model.layers.0.self_attn.k_proj.weight", + "model.layers.0.self_attn.v_proj.weight", + "model.layers.0.self_attn.o_proj.weight", + "model.layers.5.mlp.gate_proj.weight", + "model.layers.5.mlp.up_proj.weight", + "model.layers.5.mlp.down_proj.weight", + "model.layers.13.self_attn.q_proj.weight", + "model.layers.27.self_attn.o_proj.weight", + "model.layers.27.mlp.down_proj.weight", +] + + +def analyze_tensor(name: str, w: np.ndarray, group_size: int, sample_rows: int, + axis: int) -> None: + print(f"\n=== {name} shape={w.shape} dtype={w.dtype} " + f"group along axis={axis} ===") + if w.ndim != 2: + print(f" skipping ({w.ndim}d, only 2d supported)") + return + + # Always group along axis; bring it to last for simplicity. + if axis == 0: + w2 = w.T + else: + w2 = w + rows, cols = w2.shape + if cols % group_size != 0: + print(f" skipping ({cols} % {group_size} != 0)") + return + n_groups_per_row = cols // group_size + rows_to_sample = min(sample_rows, rows) + + uniq_counts: list[int] = [] + example_group_values: list[np.ndarray] = [] + + for ri in np.linspace(0, rows - 1, rows_to_sample).astype(int): + row = w2[ri] + for gi in range(n_groups_per_row): + group = row[gi * group_size : (gi + 1) * group_size] + u = np.unique(group) + uniq_counts.append(len(u)) + if len(example_group_values) < 3: + example_group_values.append(u) + + cnt = np.array(uniq_counts) + total = len(cnt) + print(f" sampled {total} groups across {rows_to_sample} rows × {n_groups_per_row} groups") + # Distribution + for k in [1, 2, 3, 4, 5, 6, 8, 16, 32, 64, 128]: + n = int((cnt == k).sum()) + if n > 0: + print(f" exactly {k:>3} unique: {n:>6} ({100 * n / total:5.1f}%)") + # Buckets for "weird" groups + le3 = int((cnt <= 3).sum()) + le4 = int((cnt <= 4).sum()) + le8 = int((cnt <= 8).sum()) + more = int((cnt > 8).sum()) + print(f" cumulative: ≤3 unique = {100 * le3 / total:.2f}%, " + f"≤4 = {100 * le4 / total:.2f}%, " + f"≤8 = {100 * le8 / total:.2f}%, " + f">8 = {100 * more / total:.2f}%") + print(f" max unique in any group: {int(cnt.max())}, " + f"min: {int(cnt.min())}, mean: {cnt.mean():.2f}") + for i, u in enumerate(example_group_values): + vals_str = ", ".join(f"{v:+.6f}" for v in u[: min(10, len(u))]) + if len(u) > 10: + vals_str += ", ..." + print(f" example group {i}: {len(u)} unique — [{vals_str}]") + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--model-path", default="../output/bonsai/hf_model") + ap.add_argument("--group-size", type=int, default=128) + ap.add_argument("--sample-rows", type=int, default=20, + help="rows sampled per tensor (cost vs coverage)") + ap.add_argument("--axis", type=int, default=1, choices=[0, 1], + help="axis to group along; 1=last dim (in-channel), 0=first dim") + args = ap.parse_args() + + weights_path = Path(args.model_path) / "model.safetensors" + print(f"Loading {weights_path}") + state = safetensors.torch.load_file(str(weights_path)) + print(f" {len(state)} tensors loaded") + + for name in SAMPLE_TENSORS: + if name not in state: + print(f"\n=== {name}: NOT FOUND ===") + continue + t = state[name].float().cpu().numpy() + analyze_tensor(name, t, args.group_size, args.sample_rows, args.axis) + + print("\n=== interpretation ===") + print(" If most tensors show ≤3 unique values per group (>99%) → bit-exact") + print(" palettization via nbits=2 + mode='unique' + per_grouped_channel +") + print(f" group_size={args.group_size} should work losslessly.") + print(" If 4+ unique values appear often, the unpacked FP16 has numerical") + print(" noise around the ternary centroids → needs custom LUT construction.") + + +if __name__ == "__main__": + main() diff --git a/docs/ADDING_MODELS.md b/docs/ADDING_MODELS.md index 7447411..00fefe0 100644 --- a/docs/ADDING_MODELS.md +++ b/docs/ADDING_MODELS.md @@ -89,6 +89,27 @@ print(model.model.layers[0].self_attn.scaling) **Getting this wrong produces coherent but completely wrong text.** +### 4.5 Decode-Time KV State Layout — Critical for ANE + +Before writing the wrapper, decide the decode-path state layout. + +- **Monolithic INT4/INT8 > ~1.4 GB silently falls back to GPU on ANE.** + Plan chunking upfront if your param count × nbits/8 exceeds this. +- **Default to mask-based state writes, not shift-based `cat`.** ANEC + rejects shift-based `cat([K[:,:,1:,:], k], dim=2)` for Qwen3 + Stateful + + tied-embedding combinations (error code -14). Mask-based rotating + buffer works on the same arch patterns that ship today plus the ones + that fail. See `docs/DECODE_STATE_LAYOUTS.md` §3 for the pattern. +- **Per-step cost is `O(state_length)` on ANE.** Start with + `context_length=1024`, measure, then decide. For longer effective ctx, + use sliding-window attention (mask-based rotating, W=1024 default). +- **Palettize with `mode="kmeans"` first.** Linear INT8 div-by-zero on + sparse tensors; kmeans is the safer default. See `docs/DECODE_STATE_LAYOUTS.md` §4. +- **Parity test before CoreML conversion** (PyTorch HF vs your ANE model). + Pattern: `conversion/experiments/bonsai/bonsai_reference_oracle.py`. + +Full checklist: `docs/DECODE_STATE_LAYOUTS.md` §7. + ### 5. Create Wrapper (if needed) If the model uses the same structure as Qwen2 (standard GQA, 2 norms, SiLU, no special features), the default `MonolithicWrapper` in `exporter.py` works. diff --git a/docs/ANE_OPTIMIZATION_SURVEY.md b/docs/ANE_OPTIMIZATION_SURVEY.md index a7b875f..64056cd 100644 --- a/docs/ANE_OPTIMIZATION_SURVEY.md +++ b/docs/ANE_OPTIMIZATION_SURVEY.md @@ -10,6 +10,13 @@ ReDrafter (arXiv 2403.09919), SwiftKV ## Executive Summary — Top 5 New Findings +0. **[2026-04-24 update] Per-step decode cost on ANE is `O(state_length)`, + not weight-bandwidth** — Ternary-Bonsai-1.7B measurements: halving ctx + (2048 → 1024) gave 2.56× decode speedup; halving weights (INT8 → INT4 + at the same ctx) gave only +12%. For any new 1-2B model, start at + ctx=1024 and switch to mask-based rotating SWA if you need longer + effective context. See `docs/DECODE_STATE_LAYOUTS.md`. + 1. **Prefill bypass (TTFT -40%)** — Apple's AFM paper reveals: L15-34 never produce KV during prefill. Skip chunk3+4 for all prompt tokens except the last one. Zero model changes, zero quality loss. (AFM / SwiftKV) diff --git a/docs/DECODE_STATE_LAYOUTS.md b/docs/DECODE_STATE_LAYOUTS.md new file mode 100644 index 0000000..8a789ea --- /dev/null +++ b/docs/DECODE_STATE_LAYOUTS.md @@ -0,0 +1,252 @@ +# Decode-Time KV State Layouts — Lessons + +**Last updated:** 2026-04-24 (from Ternary-Bonsai-1.7B port) + +Actionable knowledge about how Core ML / ANE compiles KV cache update patterns. +Read this before designing the decode path of a new model or reworking an +existing one. + +## TL;DR + +1. On ANE, per-step cost is dominated by `O(state_length)` in attention (not + weight bandwidth). Shrinking state length is the biggest single lever. +2. For Stateful models, **mask-based circular rotating buffer** is a strictly + safer default than shift-based `cat`. Same semantics, strictly wider op + compatibility with ANEC. Use it unless you have a measured reason not to. +3. Palettize with `mode="kmeans"` by default. Linear quantization is a + div-by-zero hazard on sparse tensors. +4. If you're tracing a module that mutates a registered buffer (KV cache), + pass `strict=False` to `torch.jit.trace`. The warning is benign; the graph + is correct. + +## 1. `O(state_length)` is the decode-time bottleneck on ANE + +Measured on Ternary-Bonsai-1.7B INT4 kmeans, 2-chunk, Mac ANE (M-series), +"The capital of France is" decode: + +| ctx / state | decode tok/s | +|---|---| +| 2048 | 9.4 | +| 1024 | 24.1 (2.56× vs 2048) | +| 4096 | 4.9 | +| 4096 + SWA W=1024 | 25.6 | + +Halving state gave **2.56×**, which is much larger than the weight-bandwidth +explanation (INT8 → INT4 only gave +12% for the same ctx). Attention +softmax-over-state and KV-cache state read/write dominate per-step latency +for 1–2B class models on ANE. + +**Corollary**: for any new model, **start with ctx=1024** and measure before +pushing higher. If you need larger effective context, switch to SWA — see §3. + +Qwen3.5-2B handoff (`docs/QWEN35_2B_CHUNKED_HANDOFF.md` §2) made the same +observation for the 2B class; Bonsai-1.7B confirms it for the 1.7B class. + +## 2. Monolithic bundles > ~1.4 GB silently fall back to GPU + +Confirmed a third time (after Gemma 4 E4B and Qwen3.5-2B): Ternary-Bonsai-1.7B +monolithic INT8 at 1.94 GB runs at 8.3 tok/s on Mac ANE, but audit shows 0% +ANE placement — Core ML routes to GPU without error because +`MLComputeUnits.cpuAndNeuralEngine` is a preference, not a requirement. + +**Rule of thumb**: if `du -sh model.mlpackage` > ~1.0–1.4 GB, +you will hit silent GPU fallback on iPhone ANE and likely jetsam-kill on +load. Split into chunks. See `docs/QWEN35_2B_CHUNKED_HANDOFF.md` §3 for the +chunking pattern. + +## 3. Mask-based circular rotating buffer vs shift-based `cat` + +**The problem** (discovered on Bonsai / Qwen3 port): + +```python +# gemma4_swa_chunks.py-style shift-based SWA update: +K_new = torch.cat([K_cache[:, :, 1:, :], k], dim=2) +``` + +For Qwen3ForCausalLM + `ct.StateType` + `tie_word_embeddings=True`, Apple's +ANEC compiler **rejects this pattern** with `error code: -14`. The produced +mlpackage loads but: +- `ct.models.MLModel(path).get_compiled_model_path()` throws. +- `MLModel.make_state()` throws with "This model was not loaded with the Core + ML Framework." +- Opening with `CPU_ONLY` compute units succeeds, so the graph itself is + valid; the failure is ANE-specific op-lowering. + +The exact conditions that trigger -14 are architecture-sensitive. Gemma 4's +sliding layers use the same `cat`-shift pattern and **do** ship on ANE. Our +best current guess: interaction between the shift op, the tied-embedding +lm_head, and Stateful buffer aliasing. + +**The fix**: mask-based circular rotating buffer. + +```python +# SWA, state buffer sized = W (not ctx): +# Host passes update_mask with 1.0 at `pos % W`, 0 elsewhere. +k_broadcast = k.expand_as(K_cache) # (1, kv, W, d) +K_new = K_cache * (1 - update_mask) + k_broadcast * update_mask +``` + +This is the exact op pattern used for non-SWA mask-based absolute-position +writes — so it inherits the ANE-proven lowering path. With this pattern the +Bonsai-1.7B build hits 92% ANE placement at INT4, 25.6 tok/s at ctx=4096 +effective / W=1024. + +**Correctness**: +- RoPE is applied to K *before* the blend, so cached K holds position-encoded + values. The slot index in the buffer is independent of position — just a + physical location. +- After wraparound, slot order is scrambled (e.g. slot 0 holds pos W, slot 1 + holds pos 1, …). Attention softmax is permutation-invariant over the key + axis, so this is fine. +- During warm-up (pos < W-1), unfilled slots are masked to `-1e4` in + `causal_mask`. Mask value -1e4 (not `-inf`) is the ANE FP16 convention. + +**Host-side per step**: +```python +W = sliding_window +write_slot = pos % W +update_mask = np.zeros((1, 1, W, 1), dtype=np.float16) +update_mask[0, 0, write_slot, 0] = 1.0 +causal_mask = np.full((1, 1, 1, W), -1e4, dtype=np.float16) +valid_count = min(pos + 1, W) +for i in range(valid_count): + causal_mask[0, 0, 0, (pos - i) % W] = 0.0 # most-recent valid slots +``` + +Reference build: `conversion/build_bonsai_17b_decode_chunks.py` +Reference host-side: `conversion/test_bonsai_chunks_inference.py` + +### When to prefer shift-based `cat` anyway + +- Your model is Gemma 4 (already ships with shift, production-proven). +- You measured a speedup from shift over mask-based on your specific arch. +- You are not using `ct.StateType` (non-stateful, KV passed as I/O tensors). + +Otherwise, default to mask-based rotating. + +## 4. Palettization: kmeans INT4 first, not linear INT8 + +Bonsai-1.7B INT8 linear quantization logged multiple +`RuntimeWarning: invalid value encountered in divide / cast` during +`linear_quantize_weights`, caused by zero-valued tensors in some layers +(scale = max_abs/127 → 0 → div-by-zero NaN → cast to int8 produces garbage). +The model still loaded but first-token logits were less stable. + +k-means palettization (`mode="kmeans"`) is centroid-based, doesn't hit +div-by-zero on zero tensors, and compresses to the same disk size. INT4 +kmeans is usually the sweet spot for ANE — further quantization (INT3, INT2) +has dramatically less compiler / kernel support. + +Reference: `ct.optimize.coreml.OpPalettizerConfig(mode="kmeans", nbits=4, granularity="per_tensor")`. + +Observed first-token logit stability for "Paris" prompt (INT8 linear → +INT4 kmeans): 17.938 → 18.234 (~1.6% delta, top-1 preserved, top-3 preserved). + +## 5. Tracing gotchas for Stateful models + +**Problem**: `torch.jit.trace` re-runs the traced function to validate outputs. +If the module mutates a registered buffer (kv_cache), the second run sees +different state and the tracer logs: + +``` +TracerWarning: Output nr 1. of the traced function does not match... +``` + +The graph is still correct — the mutation is captured, the validation just +runs a different branch. + +**Fix**: pass `strict=False`. + +```python +traced = torch.jit.trace(module, sample_inputs, strict=False) +``` + +Match the pattern in `conversion/build_qwen35_2b_decode_chunks.py`. + +## 6. `audit_ane` can throw on -14 models; wrap it + +If conversion logs `error code: -14` (ANE compiler rejection), the saved +mlpackage is a CPU-stub. `get_compiled_model_path()` will throw with "This +model was not loaded or compiled with the Core ML Framework." Don't let that +kill your build pipeline — audit is diagnostic only. + +```python +def audit_ane(pkg_path: Path) -> float: + try: + m = ct.models.MLModel(str(pkg_path), compute_units=ct.ComputeUnit.CPU_AND_NE) + compiled = m.get_compiled_model_path() + plan = ct.models.compute_plan.MLComputePlan.load_from_path( + path=str(compiled), compute_units=ct.ComputeUnit.CPU_AND_NE, + ) + except Exception as e: + print(f" ANE audit skipped: {e}") + return -1.0 + # ... placement counting +``` + +The real signal (the model is bad, usually from shift-cat or unsupported +op pattern) is already in the saved-time warning; the audit crash is just +noise. + +## 7. Checklist for a new model's decode path + +- [ ] Reference HF arch. If it has `tie_word_embeddings=True` and you plan + to use `ct.StateType`, **default to mask-based rotating, not shift**. +- [ ] Measure monolithic INT4 bundle size. If > ~1.4 GB, chunk (see + `docs/QWEN35_2B_CHUNKED_HANDOFF.md` §3). +- [ ] Start with `context_length=1024`. Decide on larger ctx only after + measuring per-step latency at 1024. +- [ ] For ctx > 1024: use SWA (mask-based rotating) with W=1024 by default. + State buffer size = W, not ctx. +- [ ] Palettize with `mode="kmeans"` first. Only try linear quant if kmeans + produces measured quality regression. +- [ ] `torch.jit.trace(..., strict=False)` for stateful graphs. +- [ ] Wrap `audit_ane` in try/except; -14 warnings kill the check. +- [ ] Parity test (PyTorch vs our-model) before CoreML conversion. + See `conversion/bonsai_reference_oracle.py` for the pattern. + +## 8. Per-(row, block) palette is rejected by ANEC + +For weight quantization schemes where each (row, block) of a matrix needs its +own scale — e.g., 1.58-bit ternary BitNet-style models like Bonsai — Core ML +has the right MIL ops (`constexpr_lut_to_dense` with multi-axis LUT, or +`constexpr_lut_to_dense` + `constexpr_blockwise_shift_scale` two-op chain), +and the resulting mlpackage saves and validates fine. + +**But Apple's ANE compiler rejects both forms** with `error code: -14`. The +saved model loads as a CPU-stub: `MLModel(...)` returns successfully, but +`get_compiled_model_path()` and `make_state()` both throw "This model was not +loaded with the Core ML Framework." Tested with iOS 18 / coremltools 9.0 on +Qwen3 + Stateful + tied embed. + +ANE in iOS 18 supports `per_tensor` and `per_grouped_channel` palette +granularities (one LUT shared across the tensor, or one LUT per N +output-channels). It does **not** support a separate LUT per `(row, block)` +pair, which is what BitNet/Bonsai need to preserve their per-block +independent scales. + +The available approximation — `nbits=2 per_grouped_channel + enable_per_channel_scale` +— compiles for ANE but factorizes the scale matrix as `s_{r,b} ≈ c_b · d_r` +(rank-1 outer product), losing the per-(row, block) independence. For models +whose value is precisely that independence, this defeats the purpose. + +**Practical guidance**: +- Don't try to bit-exact ternary / 1.58-bit on ANE today. Either ship an + approximation (per-tensor / per-channel kmeans) and accept that you've + effectively quantized a Qwen3/Llama equivalent — or ship via MLX (Apple + Silicon GPU), which natively executes packed ternary via `mx.quantized_matmul`. +- This applies to: Bonsai, BitNet b1.58, Era of 1-bit LLMs and any + derivative with per-block scales. + +Investigation details: [`TERNARY_BONSAI.md`](TERNARY_BONSAI.md). + +## 9. Related docs + +- `docs/TERNARY_BONSAI.md` — the Bonsai-specific landing page. +- `docs/QWEN35_2B_CHUNKED_HANDOFF.md` — original chunking + jetsam findings. +- `docs/ANE_OPTIMIZATION_SURVEY.md` — broader ANE tricks (prefill bypass, + ping-pong buffers, etc.). +- `docs/ADDING_MODELS.md` — the end-to-end "I want to add a new model" + walkthrough; this doc is its decode-path companion. +- `docs/GEMMA4_ROTATING_BUFFER_PORT.md` — applies this decode-path knowledge + to Gemma 4. diff --git a/docs/GEMMA4_ROTATING_BUFFER_PORT.md b/docs/GEMMA4_ROTATING_BUFFER_PORT.md new file mode 100644 index 0000000..fa900a2 --- /dev/null +++ b/docs/GEMMA4_ROTATING_BUFFER_PORT.md @@ -0,0 +1,195 @@ +# Porting the Bonsai Mask-Based Rotating Buffer to Gemma 4 + +**Last updated:** 2026-04-24 + +This document analyzes whether the circular mask-based rotating KV buffer +technique proven on Ternary-Bonsai-1.7B (see `DECODE_STATE_LAYOUTS.md` §3) +is worth porting to Gemma 4 E2B / E4B. Short answer: **yes, for the 7 +full-attention layers** — that's where Gemma 4's decode-time bottleneck sits +and where the technique maps cleanly. + +## Where Gemma 4 is today + +From `conversion/models/gemma4.py:70-77` and `gemma4_swa_chunks.py`: + +- **E2B**: 35 layers = 28 sliding + 7 full. `layer_types[i] == "full_attention"` + for i ∈ {4, 9, 14, 19, 24, 29, 34} (every 5th layer). +- **E4B**: 42 layers = 35 sliding + 7 full. Same 1-in-5 cadence. +- **Sliding layers** (`W=512`, `head_dim=256`): use shift-based update + `K_new = cat([K_cache[:, :, 1:, :], k], dim=2)` + (`gemma4_swa_chunks.py:105-108`). Ships on iOS ANE without ANEC -14. +- **Full-attention layers** (`W=ctx`, `head_dim=512`): use mask-based update + `K_new = K_cache * (1 - update_mask) + k * update_mask` + with state buffer sized to full context. +- **KV sharing**: layers 15-34 read from L13 (sliding producer) and L14 + (full producer) via explicit I/O tensors, not `ct.StateType` + (`gemma4_swa_chunks.py:112-129`). + +## Where Gemma 4 hurts + +From `docs/SPEED_8K.md` and `docs/HANDOFF.md:145`: + +- **ctx=2048 → 31.4 tok/s** decode, iPhone 17 Pro ANE +- **ctx=8192 → 14.9 tok/s** — 2.1× regression from 2K to 8K +- Chunk2 (the chunk containing the 7 full-attention layers) measures + **~2.96 ms / full layer** at ctx=8192 vs **1.5–1.7 ms / sliding layer**. + Full-attn state read/write is ~2× the cost of a sliding-W=512 state read + per step. +- chunk2 is where the budget goes. Reducing full-attn per-step cost has + the largest leverage on ctx=8K tok/s. + +## Application vectors + +### A. Convert 7 full-attention layers to mask-based rotating SWA (recommended trial) + +**What**: Replace the `ctx`-sized state buffer on each full layer with a +W-sized rotating buffer (W configurable, default 1024). Write slot = +`pos % W`, attention over W slots. Same mask-based blend op already used +for the write — just smaller state. + +**Expected speedup** (back-of-envelope): full-layer per-step cost scales +with state_length. 8192 → 1024 = 8× reduction in state reads/softmax. At +2.96 ms/layer × 7 layers = 20.7 ms → ~2.6 ms total. chunk2 budget drops by +~18 ms/step. At the 14.9 tok/s / 67 ms baseline, 18 ms saved → ~25 tok/s +(+68%) at ctx=8192. That's the single biggest lever available. + +**Quality risk**: full-attention layers retain long-range context. Converting +them to SWA discards tokens older than W. Precedent: `gemma4_swa_wfa.py` +attempted this exact semantic change (with shift-based update, W=2048) and +was **shelved for quality regression** on prompts that need attention beyond +the window (see `docs/EXPERIMENTS.md` "WFA section"). + +**Two things make this worth trying again**: + +1. **Attention sinks** (StreamingLLM-style). Reserve the first 4 slots of + the W-sized buffer permanently for the first 4 prompt positions. The + full-attention layers regain a global anchor at the cost of 4 slots + (~0.4% of W=1024). Mask-based rotating trivially supports this: just + fix `update_mask` to never write to slots 0-3 once positions 0-3 are + captured. WFA didn't have this — so its quality regression is not + directly predictive of our ceiling. +2. **Mask-based vs shift-based is strictly better for ANEC**: even if we + end up wanting W=2048 (matching WFA) or ctx=4096, mask-based pattern + works on Qwen3 + Stateful + tied (proven). Shift-based doesn't (ANEC -14). + So this gives future flexibility Gemma 4 doesn't currently have. + +**Code changes** (small, localized): + +- `conversion/models/gemma4_swa_chunks.py:99-101` — the existing mask-based + write for full-attention. Change the state buffer shape from + `(1, num_kv_heads, ctx, head_dim)` to `(1, num_kv_heads, W, head_dim)`. +- Host-side (Swift `Gemma4Chunk2.swift` or equivalent): update_mask becomes + `(1, 1, W, 1)` sized, write index = `pos % W`; causal_mask becomes + `(1, 1, 1, W)` and the valid-slot fill logic handles wraparound. +- (Optional) Sink retention: first 4 slots are reserved. Host sets + update_mask[0, 0, s, 0] = 0 for s ∈ {0,1,2,3} when `pos >= 4`. +- Reuse `conversion/build_bonsai_17b_decode_chunks.py` helpers + (`_decode_layer_step`) — the op pattern is identical, just different + state size. Or copy the pattern into Gemma 4's wrapper directly. + +**Risk-mitigation plan**: + +1. Gate the behavior behind a `--full-layer-window W` CLI flag (default None = + keep current ctx-sized full attention). +2. Quality sweep on Gemma 4 E2B at ctx=2K, 4K, 8K with W ∈ {512, 1024, 2048, + 4096}. Use the Qwen3.5-2B acceptance prompts (factual, multilingual, + code-switch) + long-context retrieval probes (needle in haystack). +3. Ship only if quality regression < measurable threshold AND speedup + materializes. + +### B. Convert 28 sliding layers from shift-based to mask-based + +**What**: Replace `gemma4_swa_chunks.py:105-108` shift-based `cat` with +mask-based rotating write. Semantics identical, op pattern different. + +**Expected speedup**: zero to marginal. Sliding layers already run at +~1.5-1.7 ms/layer on ctx=8K, which is mostly the GQA + MLP, not the state +write. The shift pattern already ships, so ANEC accepts it for Gemma 4's +specific config. + +**Worth doing anyway?** Three reasons it might: + +1. **Future-proofing**: when Apple changes ANEC lowering in an iOS update, + the shift-based path on Gemma 4 could become another ANEC -14 surprise. + Mask-based has strictly more coverage (Qwen3 + Gemma 4 both). +2. **Consistency**: `gemma4_lite_chunks.py` already uses mask-based for all + layers (sliding + full) — converging the shipping variant removes a + divergence. +3. **Measurement first**: convert one sliding layer, trace, measure. If + mask-based is faster by any margin, switch. If not, leave it. + +**Verdict**: low priority. Do it after (A) if (A) ships. Or leave indefinitely +if Gemma 4 ships don't break. + +### C. Extend sliding window from 512 to 1024+ + +**What**: Bump `config.sliding_window` from 512 to 1024 or 2048, keep shift +pattern. Gains quality (sliding layers see more context per step), costs +per-step time (scales linearly with W). + +**Expected impact**: at W=1024 the sliding layers' state read doubles, +costing ~1 ms/layer × 28 layers = ~28 ms/step extra. That drops tok/s from +14.9 → ~11 at ctx=8K. Not worth it unless quality measurements show a +meaningful retrieval gain. + +**Verdict**: not a rotating-buffer port, orthogonal. Don't conflate. + +### D. Apply our full Bonsai pipeline to Qwen3-4B / Qwen3-8B + +Out of scope for this doc but worth flagging: the entire mask-based rotating +pipeline works directly on larger Qwen3 variants (4B, 8B) with no +architectural changes. Register the model in `conversion/config.py`, adjust +`--split-at` based on parameter count, and rebuild. Gemma 4 is the harder +port; Qwen3-4B is the easier extension. + +## Minimum viable experiment + +Only (A) is a meaningful decode-time change. Scope: + +1. **New CLI flag on `gemma4_swa_chunks.py`** (or new file + `build_gemma4_full_rotating.py` if surgery in-place is too invasive): + `--full-layer-window W`. Default None → keep ctx-sized full attention. +2. **Runtime state buffer for full layers** becomes `(1, nkv_full, W, 512)` + instead of `(1, nkv_full, ctx, 512)`. (`gemma4_swa_chunks.py` around + the `SWAChunk2` constructor / state-tensor declarations.) +3. **Optional sink retention**: reserve slots 0-3, never overwrite after + pos 3. +4. **Swift side**: chunk2's caller builds `update_mask` and `causal_mask` + sized to W, with `write_slot = pos % W` and wrap-around causal. +5. **Parity harness**: adapt `conversion/bonsai_reference_oracle.py` pattern + to Gemma 4 — compare first-5-token greedy from HF vs our model. Not for + quality eval (long-context retrieval is a separate eval), but for + "did I wire up the cache index right." + +Estimated effort: 1–2 days for a prototype, 2–3 days including an iPhone +tok/s measurement and a minimal quality sweep (factual + multilingual + +one needle-in-haystack prompt). If the sweep shows quality cliff, stop and +revisit with sinks or hybrid (e.g. keep 2 of 7 full layers unchanged, apply +rotating to the other 5). + +## What NOT to do + +- **Do not swap shift→mask on sliding layers without measurement** (vector B). + Low upside, risk of ANEC surprise. +- **Do not enlarge the sliding window** (vector C) hoping for speed. That + makes things slower; it's a quality knob, not a speed one. +- **Do not remove the full-attention layers entirely** (i.e. all-sliding + Gemma 4). Precedent: WFA. Quality shelved. Don't re-run that experiment + without sinks + a different W. + +## References + +- `docs/DECODE_STATE_LAYOUTS.md` — the decode-path knowledge base this port + is based on. +- `docs/TERNARY_BONSAI.md` §SWA — the proven pattern for Qwen3-class. +- `conversion/build_bonsai_17b_decode_chunks.py` — the reference build + with `--sliding-window W`. The `_decode_layer_step` function there is + copy-pasteable into Gemma 4's per-layer decode. +- `conversion/models/gemma4_swa_chunks.py:99-108` — current full-attention + and sliding write patterns, side-by-side. +- `conversion/models/gemma4_swa_wfa.py` — prior full→SWA attempt (shift-based, + no sinks) that was shelved. Read its header comments first. +- `docs/SPEED_8K.md` — the measurement baseline for chunk2's + 2.96 ms/full-layer figure. +- `docs/EXPERIMENTS.md` WFA section — quality regression narrative. +- `docs/HANDOFF.md:145` — ctx=8K / 14.9 tok/s shipping number. diff --git a/docs/NEXT_MODELS.md b/docs/NEXT_MODELS.md new file mode 100644 index 0000000..b5e9bda --- /dev/null +++ b/docs/NEXT_MODELS.md @@ -0,0 +1,116 @@ +# Next ANE-Friendly Models to Port + +**Last updated:** 2026-04-25 + +Shortlist of small (1–4B) decoder-only LLMs that map cleanly onto our existing +Core ML / ANE conversion infrastructure. Ranked by ease-of-port × ecosystem value. +Excluded: models we already support (Qwen2.5, Gemma 3 / 4, FunctionGemma, +EmbeddingGemma) and models that fundamentally don't fit ANE (per-block +quantized BitNet/Bonsai-class — see `TERNARY_BONSAI.md`). + +## Top picks + +### 1. Qwen3-1.7B-Instruct / Qwen3-4B-Instruct (top pick) + +- HF: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-4B-Instruct-2507` +- 1.7B (28 layers, hidden=2048, GQA 16/8, head_dim=128, tied embed, 32K ctx) +- 4B (36 layers, GQA 32/8, similar tied / RoPE / QK-norm story) +- Apache 2.0; bf16 official; community AWQ/GGUF abundant +- Why ANE-friendly: dense decoder, GQA, RoPE, **QK-norm** — exactly what + `conversion/models/qwen3.py` already supports. Drop-in via existing + `convert.py --model qwen3-1.7b` once registered. +- Speed estimate: 1.7B ≥ 25 tok/s INT4 chunked, 4B ~12-14 tok/s (Gemma 4 E4B class). +- ANEMLL has a working CoreML port that validates feasibility. + +**Effort to port:** add 1-2 entries to `MODEL_REGISTRY` in `config.py`. Done. + +### 2. Gemma 3 1B / 4B (best architectural match) + +- HF: `google/gemma-3-1b-it`, `google/gemma-3-4b-it`, `google/gemma-3-270m` +- 1B is text-only (32K ctx); 4B is multimodal (128K ctx) +- Same **5-local-SWA : 1-global** pattern as Gemma 4 — our `gemma4_swa_*.py` + chunking infra reuses ~80% +- Google ships **QAT INT4** weights officially: + `google/gemma-3-4b-it-qat-q4_0-unquantized` — palette-friendly, no per-block + scales (unlike Bonsai), so they actually run on ANE +- Gemma terms; commercial OK with usage policy; released Mar 25, 2025 + +**Effort to port:** Gemma 3 has FunctionGemma support already +(`conversion/models/gemma3.py`); 1B / 4B differ mostly in size + dual SWA/full +window. Reuse Gemma 4 SWA chunking; ~1-2 days. + +### 3. Llama-3.2-1B-Instruct / 3B-Instruct (lowest risk) + +- HF: `meta-llama/Llama-3.2-1B-Instruct`, `meta-llama/Llama-3.2-3B-Instruct` +- 1B: 16 layers, hidden=2048, GQA 32/8; 3B: 28 layers, hidden=3072, GQA 24/8 +- 128K ctx via Llama-3 RoPE scaling +- Vanilla decoder-only — no QK-norm, no SWA — simplest possible ANE port +- Llama 3.2 Community License (commercial OK with MAU caps; not EU) +- Reported: ~47–62 tok/s on iPhone 17 Pro for Llama-3.2-1B via ANEMLL → our + build should hit similar +- Existing reference: `smpanaro/Llama-3.2-1B-Instruct-CoreML` on HF +- Released Sep 2024 — older but mature & widely requested + +**Effort to port:** new `conversion/models/llama.py` (mirror `qwen2.py`), +verify Llama-3 RoPE scaling. Half a day. + +### 4. SmolLM3-3B (best quality/size in 3B class) + +- HF: `HuggingFaceTB/SmolLM3-3B` +- 3B, decoder-only, GQA, **NoPE 3:1** (no positional embed every 4th layer) +- 64K native ctx (128K via YaRN) +- Apache 2.0 +- Outperforms Llama-3.2-3B and Qwen2.5-3B at 3B scale (HF benchmarks Jul 2025) +- Risk: 3:1 NoPE pattern needs a small wrapper change (skip RoPE on every 4th layer) + +**Effort to port:** ~1 day, mostly the NoPE wrapper. + +## Honorable mentions + +| model | HF | why interesting | why not first | +|---|---|---|---| +| Phi-4-mini-instruct | `microsoft/Phi-4-mini-instruct` | 3.8B, MIT, popular; INT4 ONNX exists | fractional RoPE (25% NoPE per head) → custom RoPE op. Qwen3-4B fills same niche | +| Ministral-3-3B-Instruct-2512 | `mistralai/Ministral-3-3B-Instruct-2512` | 3B, Apache 2.0, multimodal, 256K ctx | released Dec 2025, too new to validate | +| DeepSeek-R1-Distill-Qwen-1.5B | `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` | reasoning-tuned | comes free via existing `qwen2.py` path; just register | +| Apple OpenELM-3B | `apple/OpenELM-3B-Instruct` | Apple-built | layer-wise variable-width FFN doesn't fit our chunking; weak instruct quality | + +## Skip / deprioritize + +- **Phi-3.5-mini** — superseded by Phi-4-mini +- **TinyLlama** — too old, dwarfed by SmolLM3 +- **Liquid LFM2 / IBM Granite 4.0 / Hymba / Falcon-Mamba** — Mamba/SSM hybrid; + needs new SSM kernel beyond our Gated-DeltaNet code (different op set) +- **Apple Foundation Models 3B (Apple Intelligence)** — weights not released +- **Cohere R7B / Mistral 7B / Qwen3-8B** — exceed iPhone ANE compile budget; + Mac-only would still need 4-chunk split + +## Recommended porting order + +1. **Qwen3-1.7B-Instruct** — almost zero code; validates the new `qwen3.py` + path on a real instruct model; ships fastest. +2. **Gemma 3 4B (QAT)** — reuses Gemma 4 SWA chunking; native INT4 weights; + Google brand recognition. +3. **Llama-3.2-3B-Instruct** — most-requested by community; ANEMLL parity check. +4. **SmolLM3-3B** — best quality/size at 3B class; Apache 2.0; differentiator. +5. (later) Qwen3-4B-Instruct-2507, Phi-4-mini, Ministral-3-3B. + +## What infrastructure we have ready + +- `conversion/models/qwen3.py` — Qwen3 architecture (QK-norm, tied embed) + ready for Qwen3-1.7B / 4B / 8B +- `conversion/models/qwen2.py` — Qwen2.5 architecture, also fits + DeepSeek-R1-Distill-Qwen-* finetunes +- `conversion/models/gemma3.py` + Gemma 4 SWA chunks — Gemma family backbone +- `docs/DECODE_STATE_LAYOUTS.md` — decode-time state pattern checklist for + any new model +- `docs/ADDING_MODELS.md` — end-to-end walkthrough for adding a new arch + +## Sources for the shortlist + +- [Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B), [Qwen3 tech report](https://arxiv.org/pdf/2505.09388) +- [Gemma 3 1B](https://huggingface.co/google/gemma-3-1b-it), [4B](https://huggingface.co/google/gemma-3-4b-it), [tech report](https://arxiv.org/html/2503.19786v1) +- [Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), [3B](https://huggingface.co/meta-llama/Llama-3.2-3B-Instruct) +- [SmolLM3-3B](https://huggingface.co/HuggingFaceTB/SmolLM3-3B), [SmolLM3 blog](https://huggingface.co/blog/smollm3) +- [Phi-4-mini-instruct](https://huggingface.co/microsoft/Phi-4-mini-instruct) +- [Ministral-3-3B](https://huggingface.co/mistralai/Ministral-3-3B-Instruct-2512) +- [smpanaro/Llama-3.2-1B-Instruct-CoreML](https://huggingface.co/smpanaro/Llama-3.2-1B-Instruct-CoreML), [ANEMLL](https://www.anemll.com/) diff --git a/docs/TERNARY_BONSAI.md b/docs/TERNARY_BONSAI.md new file mode 100644 index 0000000..fac7b6b --- /dev/null +++ b/docs/TERNARY_BONSAI.md @@ -0,0 +1,107 @@ +# Bonsai 1.58-bit on Core ML / ANE — Investigation Post-Mortem + +**Status:** investigated and **not shipped**. Use MLX for Bonsai instead. + +## Goal + +Bring [`prism-ml/Ternary-Bonsai-1.7B`](https://huggingface.co/prism-ml/Ternary-Bonsai-1.7B) +to Apple Neural Engine (iPhone 17 Pro), preserving the author's 1.58-bit ternary +weight encoding so the model's structural compression advantage carries over +into Core ML. + +## What we built + +A complete Qwen3 conversion path that didn't exist before: + +- `conversion/models/qwen3.py` — `Qwen3Model` (QK-norm, tied embed, no attention bias) +- `conversion/base_model.py` — optional QK-norm in `ANEAttention`, off by default +- `conversion/exporter.py` — `MonolithicWrapper` honors QK-norm +- `conversion/convert.py` — `qwen3` architecture routing +- `docs/DECODE_STATE_LAYOUTS.md` — generalized decode-path lessons that came out + of this work (chunking, SWA, palette traps) + +These all stay in the codebase; they're useful for any Qwen3 derivative. + +## Bonsai-specific work, summarized + +Verified the per-128-block structure of the unpacked FP16 weights — 100% of +sampled groups have exactly 3 unique values `{−s, 0, +s}`, scale `s` varies +per (row, block). Built three Core ML variants and measured them on Mac ANE: + +| variant | size | speed | top-1 vs Bonsai/MLX | ANE OK | +|---|---|---|---|---| +| INT4 k-means per_tensor (chunked) | 1.0 GB | 24 tok/s | matches "Paris" but logits are coarse | yes | +| nbits=6 unique per-row palette | 1.7 GB | 11 tok/s | drifts (different output) | yes | +| **bit-exact per-(row,block) LUT** (custom MIL surgery) | **0.4 GB / 1.0 GB** | **N/A** | **would be exact** | **NO — ANEC error -14** | +| SWA mask-based rotating buffer at ctx=4096/W=1024 | 1.0 GB | 25 tok/s | matches at short ctx, forgets long-range | yes | + +Reference build: `conversion/experiments/bonsai/build_bonsai_17b_decode_chunks.py`. +Bit-exact MIL surgery: `conversion/experiments/bonsai/ternary_surgery.py`. + +## The blocking finding + +Bonsai's compression depends on **per-(row, block) independent scales** — for +a (2048, 2048) layer, that's 32,768 distinct scales arranged as a (2048, 16) +matrix. Two ways to express this in Core ML: + +1. **Single-op `constexpr_lut_to_dense`** with LUT shape `(2048, 16, 4, 1)`, + where each (row, block) has its own 4-entry codebook `[0, +s, -s, 0]`. +2. **Two-op chain** — `constexpr_lut_to_dense` with shared sign codebook + `[0, +1, -1, 0]` followed by `constexpr_blockwise_shift_scale` carrying + the (2048, 16) scale matrix. + +Both **load as MLModel and serialize fine**. Both **fail Apple's ANE compiler +(`error code: -14`)** when iOS tries to build the execution plan. The model +loads as a CPU-stub and `make_state()` throws "This model was not loaded +with the Core ML Framework." + +The granularity ANE accepts in iOS18 is **per-tensor or per-grouped-channel +along a single axis** — there is no current ANE kernel that handles a +per-(row, block) palette layout. Until Apple adds support, Bonsai's structure +cannot be faithfully run on ANE. + +## Why "approximate but ANE-running" doesn't work + +The next-best approximation in stock coremltools is `nbits=2 per_grouped_channel ++ enable_per_channel_scale`: per-block LUT (16 LUTs, 4 codes each) plus a +per-row scale factor. This is rank-1: `s_{r,b} ≈ c_b · d_r`. It compiles for +ANE, but **discards the per-(row, block) scale independence** — the very thing +that justifies Bonsai's training procedure. + +If you ship that, you're shipping "Qwen3-1.7B with structured palette quant", +not Bonsai. There is no point in pulling Bonsai's weights specifically; any +Qwen3-1.7B variant gives equivalent results through the same path. So we don't +ship it. + +## What to do if you want Bonsai + +Use [MLX](https://github.com/ml-explore/mlx) and `mlx-lm`. The published +`prism-ml/Ternary-Bonsai-1.7B-mlx-2bit` weights run on Apple Silicon GPU +via `mx.quantized_matmul` with native 2-bit packed ternary, preserving the +per-block scales. Reported speed: ~27 tok/s on iPhone 17 Pro Max for the 8B +class; the 1.7B should be substantially faster. + +For Swift integration, [`mlx-swift-examples`](https://github.com/ml-explore/mlx-swift-examples) +provides drop-in patterns. ANE is not used; this is a GPU path. + +## Knowledge harvested for the rest of the codebase + +The Bonsai investigation produced reusable lessons recorded in +[`docs/DECODE_STATE_LAYOUTS.md`](DECODE_STATE_LAYOUTS.md): + +- Mask-based circular rotating KV buffer (replaces shift-based `cat([K[:,1:], k])` + which ANEC rejects on Qwen3 + Stateful + tied embed) +- ANE per-step decode cost is `O(state_length)`, not weight bandwidth, in this + model class — context reduction beats more aggressive quant +- `mode="kmeans"` palettization is the safe default; `mode="unique"` falls back + silently when global tensor uniqueness exceeds nbits range +- Trace-time `TracerWarning` on stateful modules is suppressed by `strict=False` +- `audit_ane` must wrap `get_compiled_model_path()` in try/except to survive + ANEC -14 saves + +## Files + +- `conversion/experiments/bonsai/README.md` — manifest of the experiment scripts +- `conversion/experiments/bonsai/*` — full set of build, parity, and surgery + scripts that we walked through during the investigation +- This doc — the human-readable summary