From a04a2ef454b8e2d96237cc82863426faed6b054c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Andr=C3=A9=20Gomes=20Marques?= Date: Fri, 19 Jun 2026 22:56:59 +0200 Subject: [PATCH] Add Mixtral-8x7B MoE KV-compression kernel and results Adds K3V2/K4V2 quant-only KV compression coverage for Mixtral-8x7B-Instruct (NF4 weights) on a single T4 via per-layer CPU/GPU streaming. PPL deltas (AQUA-iso paired, n=80): K3V2 +0.63%, K4V2 +0.32% off the NF4-weights baseline. Needle retrieval at 4K context holds at 5/5 for FP16, K3V2, and K4V2. Claude-Session: https://claude.ai/code/session_012T2q1cWGCTY963GGFXRZA7 --- .../nq_mixtral_stream/kernel-metadata.json | 19 + .../nq_mixtral_stream/nq_mixtral_stream.py | 1523 +++++++++++++++++ .../kaggle/results/nq_mixtral_stream.json | 488 ++++++ 3 files changed, 2030 insertions(+) create mode 100644 experiments/kaggle/nq_mixtral_stream/kernel-metadata.json create mode 100644 experiments/kaggle/nq_mixtral_stream/nq_mixtral_stream.py create mode 100644 experiments/kaggle/results/nq_mixtral_stream.json diff --git a/experiments/kaggle/nq_mixtral_stream/kernel-metadata.json b/experiments/kaggle/nq_mixtral_stream/kernel-metadata.json new file mode 100644 index 0000000..8d5cec3 --- /dev/null +++ b/experiments/kaggle/nq_mixtral_stream/kernel-metadata.json @@ -0,0 +1,19 @@ +{ + "id": "jagmarques/nq-mixtral-stream", + "title": "nq-mixtral-stream", + "code_file": "nq_mixtral_stream.py", + "language": "python", + "kernel_type": "script", + "is_private": true, + "enable_gpu": true, + "enable_tpu": false, + "enable_internet": true, + "keywords": [ + "gpu" + ], + "dataset_sources": [], + "kernel_sources": [], + "competition_sources": [], + "model_sources": [], + "machine_shape": "NvidiaTeslaT4" +} diff --git a/experiments/kaggle/nq_mixtral_stream/nq_mixtral_stream.py b/experiments/kaggle/nq_mixtral_stream/nq_mixtral_stream.py new file mode 100644 index 0000000..062736f --- /dev/null +++ b/experiments/kaggle/nq_mixtral_stream/nq_mixtral_stream.py @@ -0,0 +1,1523 @@ +# NexusQuant Mixtral-8x7B -- manual per-layer GPU streaming, bypassing accelerate entirely. +# +# APPROACH (genuinely new vs all prior C6 attempts): +# Prior walls: v1-v4 all used accelerate device_map, hitting either the bitsandbytes +# "Tensor.item() on meta tensor" offload-serialization bug (v3/v4 CPU offload) or the +# "modules dispatched on CPU/disk" dispatch wall (v5/v6 GPU-only, 26GB > 29.2GB budget). +# Both failures are accelerate-driven. +# +# This kernel bypasses accelerate entirely: +# 1. Load with device_map={"":"cpu"}: ALL NF4 weights land on CPU RAM as normal +# bitsandbytes Linear4bit objects. No meta tensors. No accelerate offload hooks. +# No dispatch wall (device_map="auto" is never called). +# 2. Register PyTorch forward pre/post hooks on each MixtralDecoderLayer: +# pre_hook -> dequantize all Linear4bit submodules to fp16 nn.Linear on GPU, +# move non-bnb parameters (norms, gates) to GPU +# post_hook -> restore all Linear4bit submodules (NF4 stays CPU-resident), +# move non-bnb params back to CPU, free fp16 temporaries +# 3. Embeddings + lm_head + norms stay on CPU and are moved around their use. +# 4. The residual stream tensor (hidden_states) and KV cache stay on GPU throughout +# the full forward, crossing layer boundaries normally. +# 5. model.generate() and model.forward() work unchanged -- hooks fire automatically +# inside transformers' standard forward path. +# 6. MoE routing (gate network + top-2 expert selection) is INSIDE MixtralDecoderLayer, +# so the hook covers the full MoE block including all 8 experts. +# +# FIX (v2): bitsandbytes Linear4bit does NOT survive GPU->CPU->GPU round-trips via +# plain .to(device). The packed 4-bit weight blob (shape ~1x16384) is matmul'd raw +# because the QuantState/dequant path is lost after the migration, producing the +# "mat1 and mat2 shapes cannot be multiplied (64x4096 and 1x16384)" crash. +# Solution: for each forward, dequantize all Linear4bit weights to fp16 ON GPU via +# bitsandbytes.functional.dequantize_4bit, replace the module temporarily with a +# plain nn.Linear(fp16), run the forward, then restore and free. The NF4 weights +# stay on CPU untouched throughout. Numerically identical to bnb's normal 4-bit +# matmul (which dequantizes internally). Peak GPU = ~1.6GB per layer (fp16 weights) +# + activations + KV cache << 14.6GB single T4. +# +# FIX (v3): double-quantization (bnb_4bit_use_double_quant=True) nests a second +# QuantState inside quant_state.state2 whose absmax+code are also CPU tensors. +# qs.to(DEVICE) does NOT deep-move state2.absmax, state2.code, or qs.offset in +# all bnb versions, leaving top-level absmax==None and triggering: +# assert absmax is not None and out is not None (bnb/functional.py:1026) +# Fix: after qs.to(DEVICE), explicitly move every nested tensor to GPU. +# Also print bnb version + QuantState attrs once so future iterations are grounded. +# +# FIX (v4): v3 DIAGNOSTICS revealed that on the CPU-resident model, qs.absmax is +# literally None -- the absmax tensor lives as a SEPARATE module buffer, not inside +# the qs object. The buffers appear as state_dict keys like: +# "weight.absmax", "weight.nested_absmax", "weight.quant_map", +# "weight.nested_quant_map", "weight.quant_state.bitsandbytes__nf4", "weight.offset" +# _move_qs_to_device then moves None->None (no-op) and dequantize_4bit still asserts. +# Real fix: reconstruct QuantState entirely from those raw buffers using +# bnb.functional.QuantState.from_dict(qs_dict, device=DEVICE), stripping the +# "weight." prefix from the state_dict keys. This mirrors what bnb's own +# _load_from_state_dict / from_prequantized does internally. +# Fused expert weights (experts.gate_up_proj / experts.down_proj) have one +# QuantState for the whole batched tensor; from_dict handles them identically. +# Diagnostic: print state_dict keys for one attn Linear4bit AND one fused expert +# so the next iteration is fully grounded if this still fails. +# +# Memory budget: +# GPU peak = 1 layer fp16 weights (~1.6GB) + residual (~34MB) + KV (0.54GB at 4K) < 2.2GB +# CPU RAM = 26.3GB NF4 model + ~2GB Python overhead = ~28.3GB (Kaggle gives ~26-29GB) +# Risk: CPU RAM OOM is the primary failure mode. If it hits, we report and stop. +# +# C6 PASS requirements (ALL required; do not relax any): +# - errors.model_load EMPTY +# - PPL: paired K3V2_pb0 AND K4V2_pb0 deltas, n>=60, AQUA-iso prefix=1024 +# - NIAH: FP16(NF4-weights) baseline hits>0 at ctx>=4096 (2048 fallback NOT counted) +# +# Timing estimate: +# Per forward: ~6-10s (32 layers x (PCIe move ~54ms + compute ~100-200ms)) +# PPL: 80 segs x 4 forwards = ~34-56min +# NIAH: 24 trials x (1 prefill + 40 gen steps) = ~1.5-2.5h +# Total: ~2-3h. Fits 9h Kaggle wall. + +import sys, os, gc, math, time, json, re, traceback, random, subprocess + +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") +# Single GPU target -- everything goes to cuda:0, nothing to cuda:1. +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +print("Installing packages ...", flush=True) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "numpy<2"]) +# PIN transformers to late-4.x: Mixtral experts are per-expert nn.Linear +# (experts.{i}.w1/w2/w3) that bitsandbytes wraps as Linear4bit natively. The 5.x +# fused 3D nn.Parameter (experts.gate_up_proj/down_proj) has no bnb quantizer +# integration -> per-expert NF4 quant_state dropped at load -> dequant skipped. +# Let a dep conflict surface (do NOT silently fall back). +subprocess.run([sys.executable, "-m", "pip", "install", "-q", + "transformers==4.56.2"], check=True) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U", + "accelerate>=1.1.1", "bitsandbytes>=0.43.0"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", + "git+https://github.com/jagmarques/nexusquant.git@main"]) + +import transformers +print(f"[pin] transformers.__version__ = {transformers.__version__}", flush=True) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoConfig, + DynamicCache, BitsAndBytesConfig) +from datasets import load_dataset + +from nexusquant.core.e8_lattice import E8Lattice +from nexusquant.core.hadamard import hadamard_matrix +from nexusquant.core.rope_utils import inverse_rope, forward_rope + +SEED = 42 +random.seed(SEED) +torch.manual_seed(SEED) + +# Public ungated NF4 mirrors (no HF_TOKEN needed for jagmardrop secondary account). +MODEL_CANDIDATES = [ + "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit", + "unsloth/Mixtral-8x7B-v0.1-bnb-4bit", +] + +PREFIX_LEN = 1024 +CONT_LEN = 512 +SEG_LEN = PREFIX_LEN + CONT_LEN +N_SEGS_TARGET = 80 +N_SEGS_MIN = 60 +LOGIT_CHUNK = 256 +CONT_CHUNK = 512 +OUT_PATH = "/kaggle/working/nq_mixtral_stream.json" + +NIAH_CONTEXT = 4096 +N_PAIRS = 8 +# v26: the model echoes the 7-digit key before the value ("The value of +# is ."), ~10 tokens, so 12 truncated the answer to "...is" with no value +# (v25 FP16=0/5 artifact). 20 tokens give room for the value to emit. The manual +# greedy loop allows EOS only after step>=3, so it still stops naturally after. +# n_trials=5 (0..5 hits/config) for a ~1.5-3h wall. +N_TRIALS = 5 +MAX_NEW_TOKENS = 20 + +PPL_CONFIGS = [("K3V2_pb0", 3, 2, 0), ("K4V2_pb0", 4, 2, 0)] +NIAH_CONFIGS = [("FP16", 0, 0, 0, False), ("K3V2_pb0", 3, 2, 0, True), + ("K4V2_pb0", 4, 2, 0, True)] + +cfg_state = {} +DEVICE = "cuda:0" # single GPU; cuda:1 intentionally excluded + +# v21: full run. v20 validated the per-expert-Linear4bit load + dequant + native +# MoE forward on the 4.56.2 pin; the 64-tok smoke-check still runs as a preflight +# gate before the PPL/NIAH loops, but the loops now execute (SMOKE_ONLY=False). +SMOKE_ONLY = False + +# v23: NIAH-only split. T4 per-layer streaming is ~121s/forward, so a combined +# PPL(n>=60)+NIAH run cannot fit the 9h Kaggle wall. v21 produces the PPL half; +# this kernel runs ONLY the NIAH phase (skips PPL entirely) so it finishes in +# ~1-2h, in parallel with v21 on the second concurrent GPU slot. +NIAH_ONLY = True + + +# --------------------------------------------------------------------------- +# GPU streaming infrastructure +# --------------------------------------------------------------------------- + +_hook_handles = [] + + +def _is_linear4bit(module): + """True if module is a bitsandbytes Linear4bit (any variant).""" + cls_name = type(module).__name__ + return cls_name in ("Linear4bit", "LinearNF4", "LinearFP4") + + +def _is_params4bit(param): + """True if param is a bitsandbytes Params4bit (packed NF4 blob).""" + return type(param).__name__ == "Params4bit" + + +# Diagnostic gates: print state_dict key structure once per weight category. +_qs_diag_attn_done = False # first attention Linear4bit +_qs_diag_expert_done = False # first fused expert Linear4bit +_gate_diag_done = False # gate Params4bit diagnostic (printed at most once) +_expert_param_diag_done = False # fused MoE expert-param dequant diagnostic (v19) + + +def _build_qs_from_module(module, dev): + """Reconstruct a QuantState on `dev` from the module's raw state_dict buffers. + + When a model is loaded CPU-resident (device_map={'':'cpu'}), bitsandbytes stores + the quant_state tensors as SEPARATE module buffers rather than populating + qs.absmax etc. in memory. The state_dict() keys follow the pattern: + "weight.absmax", "weight.nested_absmax", "weight.quant_map", + "weight.nested_quant_map", "weight.quant_state.bitsandbytes__nf4", + "weight.offset" + Strip "weight." and pass the dict to QuantState.from_dict(), which handles + both the packed format (quant_state.bitsandbytes__nf4 present) and the + unpacked format (individual keys). This mirrors bnb's own from_prequantized. + """ + from bitsandbytes.functional import QuantState as BnbQuantState + sd = module.state_dict() + PREFIX = "weight." + qs_dict = {} + for k, v in sd.items(): + if k.startswith(PREFIX): + stripped = k[len(PREFIX):] + # Only include quant-state keys; skip the weight tensor itself. + if stripped != "" and not stripped.startswith("bias"): + qs_dict[stripped] = v + # from_dict expects tensors on the target device for the data tensors; + # it moves them internally, but we move upfront to be safe. + qs_dict_dev = {} + for k, v in qs_dict.items(): + if isinstance(v, torch.Tensor): + qs_dict_dev[k] = v.to(dev) + else: + qs_dict_dev[k] = v + qs = BnbQuantState.from_dict(qs_dict_dev, device=dev) + return qs, qs_dict # return raw dict for diagnostics + + +def _dequant_layer_to_gpu(layer): + """Replace every Linear4bit in layer with a temporary fp16 nn.Linear on GPU. + Returns a dict {name: original_module} for later restore. + Non-bnb params (norms, gates, etc.) are moved to GPU in-place.""" + import bitsandbytes as bnb + import bitsandbytes.functional as bnbF + global _qs_diag_attn_done, _qs_diag_expert_done, _gate_diag_done + global _expert_param_diag_done + + saved = {} + for name, module in list(layer.named_modules()): + if not _is_linear4bit(module): + continue + + # Classify: fused expert weights contain "experts." in their path. + is_expert = "experts." in name + + # Diagnostic: print buffer keys once per category (attn + expert). + diag_needed = (is_expert and not _qs_diag_expert_done) or \ + (not is_expert and not _qs_diag_attn_done) + if diag_needed: + if not _qs_diag_expert_done and not _qs_diag_attn_done: + # First ever call: also print bnb version. + print(f" [bnb-diag] bitsandbytes=={bnb.__version__}", flush=True) + sd_keys = list(module.state_dict().keys()) + w_cpu_pre = module.weight + qs_pre = w_cpu_pre.quant_state + cat = "expert" if is_expert else "attn" + print(f" [qs-diag/{cat}] name={name!r} sd_keys={sd_keys}", flush=True) + print(f" [qs-diag/{cat}] weight.shape={w_cpu_pre.data.shape} " + f"qs.type={type(qs_pre).__name__} " + f"qs.absmax={type(getattr(qs_pre,'absmax',None)).__name__} " + f"qs.quant_type={getattr(qs_pre,'quant_type',None)}", flush=True) + if is_expert: + _qs_diag_expert_done = True + else: + _qs_diag_attn_done = True + + # Reconstruct QuantState from the module's raw state_dict buffers. + # This is the v4 fix: qs.absmax is None on CPU-resident models; the + # real absmax lives as a separate buffer (state_dict key "weight.absmax"). + w_cpu = module.weight + try: + qs_gpu, _raw_dict = _build_qs_from_module(module, DEVICE) + except Exception as e_qs: + # Fallback: try the old deep-move path (keeps v3 behavior as safety net). + print(f" [qs-warn] from_dict failed ({e_qs}); falling back to deep-move", + flush=True) + qs_pre = w_cpu.quant_state + qs_gpu = qs_pre.to(DEVICE) + for attr in ("absmax", "offset", "code"): + t = getattr(qs_gpu, attr, None) + if isinstance(t, torch.Tensor): + setattr(qs_gpu, attr, t.to(DEVICE)) + s2 = getattr(qs_gpu, "state2", None) + if s2 is not None: + for attr in ("absmax", "code", "offset"): + t = getattr(s2, attr, None) + if isinstance(t, torch.Tensor): + setattr(s2, attr, t.to(DEVICE)) + + # Dequantize: CPU uint8 packed -> fp16 on GPU. + w_gpu = bnbF.dequantize_4bit( + w_cpu.data.to(DEVICE), qs_gpu, + quant_type=qs_gpu.quant_type, + ).to(torch.float16) + + # Build a plain fp16 Linear on GPU (no bias for bnb layers). + has_bias = (module.bias is not None) + out_feat, in_feat = w_gpu.shape + fp16_lin = nn.Linear(in_feat, out_feat, bias=has_bias, device=DEVICE, + dtype=torch.float16) + fp16_lin.weight = nn.Parameter(w_gpu, requires_grad=False) + if has_bias and module.bias is not None: + fp16_lin.bias = nn.Parameter( + module.bias.to(DEVICE, dtype=torch.float16), requires_grad=False) + + parent, attr = _get_parent_attr(layer, name) + saved[name] = (parent, attr, module) + setattr(parent, attr, fp16_lin) + del qs_gpu, w_gpu + + # FIX (v12): transformers>=5.5.3 uses MixtralTopKRouter with a bare nn.Parameter + # weight (not wrapped in Linear4bit). bitsandbytes wraps that parameter as Params4bit. + # F.linear(hidden_states, self.weight) with a Params4bit on GPU uses the raw uint8 + # blob, producing shape-mismatch errors. + # + # There are TWO distinct cases for Params4bit modules: + # (A) Actually quantized (e.g. attention proj, experts): w.quant_state is NOT None. + # These must be dequantized via dequantize_4bit. + # (B) Router gate (mlp.gate): w.quant_state IS None. bitsandbytes wraps the tiny + # gate weight in Params4bit but leaves it in full precision (the data IS the + # fp16/fp32 weight; no packed uint8). Just move data to GPU as a plain Parameter. + # Shape is small (~[8, 4096]), NOT a packed [N,1] blob. + params4bit_saved = {} + for mod_name, mod in layer.named_modules(): + w = getattr(mod, "weight", None) + if w is None or not _is_params4bit(w): + continue + # Skip modules already handled as Linear4bit above. + if mod_name in saved: + continue + + # Case (B): unquantized Params4bit (router gate). Robust discriminator: + # genuinely-4bit weights are packed uint8 blobs (shape [N,1]); the router + # gate is left full-precision, so its weight.data is a float matrix (NOT + # uint8). Keying on quant_state is unreliable: the gate's quant_state can be + # a degenerate non-None object whose .to() returns None (the v11/v13 crash). + if w.data.dtype != torch.uint8: + if not _gate_diag_done: + _gate_diag_done = True + nd = w.data.ndim + sh = list(w.data.shape) + dt = str(w.data.dtype) + print(f" [gate-unquant] mod={mod_name!r} shape={sh} dtype={dt}", + flush=True) + if nd != 2: + raise RuntimeError( + f"[gate-unquant] unexpected shape {sh} for {mod_name!r}; " + f"expected 2D float matrix for router gate. Aborting.") + # Gate weight is already a usable fp16/fp32 matrix -- just move to GPU. + w_gpu = nn.Parameter(w.data.to(DEVICE, dtype=torch.float16), + requires_grad=False) + params4bit_saved[mod_name] = (mod, w) # store original for restore + mod.weight = w_gpu + continue + + # Case (A): genuinely quantized Params4bit -- reconstruct QuantState and dequantize. + qs_gpu = None + try: + qs_gpu, _ = _build_qs_from_module(mod, DEVICE) + except Exception as e_qs: + # from_dict failed; try w.quant_state attrs directly. + print(f" [p4b-warn] from_dict failed for {mod_name!r} ({e_qs}); " + f"falling back to quant_state attr path", flush=True) + qs_pre = getattr(w, "quant_state", None) + if qs_pre is not None: + # bitsandbytes QuantState.to() moves tensors IN PLACE and returns + # None; capturing its return left qs_gpu=None (the v11/v13/v15 crash). + moved = qs_pre.to(DEVICE) + qs_gpu = moved if moved is not None else qs_pre + for attr in ("absmax", "offset", "code"): + t = getattr(qs_gpu, attr, None) + if isinstance(t, torch.Tensor): + setattr(qs_gpu, attr, t.to(DEVICE)) + s2 = getattr(qs_gpu, "state2", None) + if s2 is not None: + for attr in ("absmax", "code", "offset"): + t = getattr(s2, attr, None) + if isinstance(t, torch.Tensor): + setattr(s2, attr, t.to(DEVICE)) + if getattr(qs_gpu, "quant_type", None) is None: + qs_gpu.quant_type = "nf4" + + # Guard: qs_gpu must be non-None with a valid quant_type before dequantize. + if qs_gpu is None or getattr(qs_gpu, "quant_type", None) is None: + sd_keys = list(mod.state_dict().keys()) + qs_attr = getattr(w, "quant_state", None) + qs_has_absmax = isinstance(getattr(qs_attr, "absmax", None), torch.Tensor) + raise RuntimeError( + f"[p4b-fix-needed] could not reconstruct QuantState for {mod_name!r}; " + f"w.data.dtype={w.data.dtype} w.data.shape={list(w.data.shape)} " + f"sd_keys={sd_keys} w.quant_state type={type(qs_attr).__name__} " + f"qs.absmax_tensor={qs_has_absmax} qs.quant_type={getattr(qs_attr,'quant_type',None)}. " + f"qs_gpu={qs_gpu}") + + w_fp16 = bnbF.dequantize_4bit( + w.data.to(DEVICE), qs_gpu, quant_type=qs_gpu.quant_type, + ).to(torch.float16) + params4bit_saved[mod_name] = (mod, w) + mod.weight = nn.Parameter(w_fp16, requires_grad=False) + del qs_gpu, w_fp16 + + saved["__params4bit__"] = params4bit_saved + + # FIX (v19): fused Mixtral MoE experts store their weights as named PARAMETERS + # on the MixtralExperts module ("gate_up_proj", "down_proj"), each a single + # fused Params4bit NF4 blob, NOT as a child Linear's .weight. The loops above + # only catch Linear4bit modules and modules whose .weight is Params4bit, so + # these fused expert params reach the monkeypatched forward still NF4-packed + # (the [moe-shapes] gate_up_proj[0]=[58720256,1] crash). Dequantize each fused + # param to its 3D [num_experts, out, in] logical shape for the duration of the + # forward and restore the packed Params4bit afterward so streaming stays bounded. + expert_params_saved = {} + for em_name, em in layer.named_modules(): + gup = getattr(em, "gate_up_proj", None) + dnp = getattr(em, "down_proj", None) + if not (_is_params4bit(gup) and _is_params4bit(dnp)): + continue + + # num_experts / hidden_dim / intermediate_dim live on the experts module + # (MixtralExperts sets them from config.num_local_experts/hidden_size/ + # intermediate_size); fall back to model config if an attr is absent. + num_experts = getattr(em, "num_experts", None) + hidden_dim = getattr(em, "hidden_dim", None) + inter_dim = getattr(em, "intermediate_dim", None) + + for pname in ("gate_up_proj", "down_proj"): + p = getattr(em, pname) + if not _is_params4bit(p): + continue + # Reconstruct the QuantState on GPU. bitsandbytes QuantState.to() + # moves tensors IN PLACE and returns None, so never capture its return + # as the state (the v11/v13/v15 crash pattern); move the tensor attrs + # explicitly instead. + qs_pre = getattr(p, "quant_state", None) + if qs_pre is None: + raise RuntimeError( + f"[moe-expert-fix] {em_name}.{pname} is Params4bit but has no " + f"quant_state; cannot dequantize fused expert blob.") + moved = qs_pre.to(DEVICE) + qs_gpu = moved if moved is not None else qs_pre + for attr in ("absmax", "offset", "code"): + t = getattr(qs_gpu, attr, None) + if isinstance(t, torch.Tensor): + setattr(qs_gpu, attr, t.to(DEVICE)) + s2 = getattr(qs_gpu, "state2", None) + if s2 is not None: + for attr in ("absmax", "code", "offset"): + t = getattr(s2, attr, None) + if isinstance(t, torch.Tensor): + setattr(s2, attr, t.to(DEVICE)) + if getattr(qs_gpu, "quant_type", None) is None: + qs_gpu.quant_type = "nf4" + + w_fp16 = bnbF.dequantize_4bit( + p.data.to(DEVICE), qs_gpu, quant_type=qs_gpu.quant_type, + ).to(torch.float16) + + # Ensure first dim is the expert axis: monkeypatch indexes [e]. + # gate_up_proj logical shape is [num_experts, 2*inter, hidden]; + # down_proj is [num_experts, hidden, inter]. dequantize_4bit returns + # the quant_state.shape (already 3D for these fused params); reshape + # defensively if it came back flat or 2D. + if w_fp16.ndim != 3 and num_experts: + if pname == "gate_up_proj" and hidden_dim: + w_fp16 = w_fp16.reshape(num_experts, -1, hidden_dim) + elif pname == "down_proj" and inter_dim: + w_fp16 = w_fp16.reshape(num_experts, -1, inter_dim) + else: + w_fp16 = w_fp16.reshape(num_experts, -1, w_fp16.shape[-1]) + + if not _expert_param_diag_done: + print(f" [moe-expert-fix] {em_name}.{pname}: packed={list(p.data.shape)} " + f"-> fp16 dequant shape={list(w_fp16.shape)} " + f"(num_experts={num_experts})", flush=True) + + expert_params_saved[f"{em_name}.{pname}"] = (em, pname, p) + setattr(em, pname, nn.Parameter(w_fp16, requires_grad=False)) + del qs_gpu, w_fp16 + + _expert_param_diag_done = True + + saved["__expert_params__"] = expert_params_saved + + # Move remaining non-bnb params/buffers to GPU. + for n, p in layer.named_parameters(recurse=True): + if p.device.type == "cpu": + p.data = p.data.to(DEVICE) + for n, b in layer.named_buffers(recurse=True): + if b.device.type == "cpu": + b.data = b.data.to(DEVICE) + return saved + + +def _restore_layer_from_gpu(layer, saved): + """Restore all Linear4bit modules and Params4bit weights; move non-bnb params to CPU.""" + # Restore fused MoE expert params (gate_up_proj / down_proj -- v19 fix). + # Drop the GPU fp16 dequant and put the packed Params4bit back so the next + # layer's forward does not leave 8 fp16 expert blobs resident on GPU. + expert_params_saved = saved.pop("__expert_params__", {}) + for key, (em, pname, orig_p) in expert_params_saved.items(): + fp16_p = getattr(em, pname) + del fp16_p + setattr(em, pname, orig_p) + + # Restore Params4bit weight params (MoE gate -- v6 fix). + params4bit_saved = saved.pop("__params4bit__", {}) + for mod_name, (mod, orig_w) in params4bit_saved.items(): + fp16_w = mod.weight + del fp16_w + mod.weight = orig_w + + # Restore Linear4bit modules. + for name, (parent, attr, orig_module) in saved.items(): + fp16_lin = getattr(parent, attr) + del fp16_lin.weight + if fp16_lin.bias is not None: + del fp16_lin.bias + del fp16_lin + setattr(parent, attr, orig_module) + + # Move remaining non-bnb params/buffers back to CPU (skip Linear4bit -- they live on CPU) + for n, p in layer.named_parameters(recurse=True): + if p.device.type != "cpu": + parts = n.split(".") + parent_mod = layer + for part in parts[:-1]: + parent_mod = getattr(parent_mod, part, parent_mod) + if not _is_linear4bit(parent_mod) and not _is_params4bit(p): + p.data = p.data.to("cpu") + for n, b in layer.named_buffers(recurse=True): + if b.device.type != "cpu": + b.data = b.data.to("cpu") + torch.cuda.empty_cache() + + +def _get_parent_attr(root, dotted_name): + """Given 'a.b.c', return (root.a.b, 'c').""" + parts = dotted_name.split(".") + parent = root + for p in parts[:-1]: + parent = getattr(parent, p) + return parent, parts[-1] + + +def install_streaming_hooks(model): + """Register pre/post hooks on each MixtralDecoderLayer. + pre_hook -> dequantize all Linear4bit to fp16 on GPU (NF4 stays CPU-resident) + post_hook -> restore Linear4bit modules, free GPU fp16 temps.""" + global _hook_handles + for h in _hook_handles: + h.remove() + _hook_handles.clear() + + decoder_layers = None + if hasattr(model, "model") and hasattr(model.model, "layers"): + decoder_layers = model.model.layers + elif hasattr(model, "layers"): + decoder_layers = model.layers + + if decoder_layers is None: + raise RuntimeError("Cannot find decoder layers in model; hook install failed.") + + def pre_hook(module, args): + saved = _dequant_layer_to_gpu(module) + module._dequant_saved = saved + + def post_hook(module, args, output): + saved = getattr(module, "_dequant_saved", {}) + _restore_layer_from_gpu(module, saved) + module._dequant_saved = {} + + for layer in decoder_layers: + h1 = layer.register_forward_pre_hook(pre_hook) + h2 = layer.register_forward_hook(post_hook) + _hook_handles.extend([h1, h2]) + + print(f" [hooks] installed on {len(decoder_layers)} decoder layers (dequant-fp16 path)", + flush=True) + return decoder_layers + + +def remove_streaming_hooks(): + global _hook_handles + for h in _hook_handles: + h.remove() + _hook_handles.clear() + + +def move_non_layer_parts(model, dev): + """Move embeddings, final norm, and lm_head.""" + m = model.model if hasattr(model, "model") else model + if hasattr(m, "embed_tokens"): + m.embed_tokens.to(dev) + if hasattr(m, "norm"): + m.norm.to(dev) + if hasattr(model, "lm_head"): + model.lm_head.to(dev) + + +def run_forward_streaming(model, input_ids, past_key_values=None, attention_mask=None, + cache_position=None, use_cache=True): + """Standard model forward; hooks handle layer streaming automatically. + We only need to move the non-layer parts around the call.""" + move_non_layer_parts(model, DEVICE) + try: + with torch.no_grad(): + out = model(input_ids.to(DEVICE), + past_key_values=past_key_values, + attention_mask=attention_mask, + cache_position=cache_position, + use_cache=use_cache) + finally: + move_non_layer_parts(model, "cpu") + return out + + +def streaming_generate(model, tok, input_ids, past_key_values, max_new_tokens, pad_id): + """Greedy decode from a prefilled (and possibly compressed) KV cache. + + transformers 4.56 model.generate() crashes in prepare_inputs_for_generation / + _cache_dependant_input_preparation when handed a manually prefilled cache + (IndexError on an empty cache_position list). We sidestep generation/utils + entirely with a manual single-token greedy loop driven by run_forward_streaming, + so the per-layer streaming hooks fire on every forward exactly as in prefill. + + Works for the FP16 baseline (uncompressed prefilled cache) and the K3V2/K4V2 + compressed caches: it reads the current cache length from past_key_values + each step, so it does not assume an uncompressed length. Returns the prompt + ids concatenated with the generated ids (shape [1, prompt+gen]) so the caller + slice gen_out[0, qids.shape[1]:] keeps working.""" + eos_id = tok.eos_token_id + prompt_ids = input_ids.to(DEVICE) + n_prompt = prompt_ids.shape[1] + generated = [] + move_non_layer_parts(model, DEVICE) + try: + kv = past_key_values + # First decode step: feed the whole question prompt at once, positioned + # immediately after the prefilled cache. + past_len = cache_len(kv) + cur_ids = prompt_ids + for step in range(max_new_tokens): + seq_len = cur_ids.shape[1] + cache_position = torch.arange(past_len, past_len + seq_len, + dtype=torch.long, device=DEVICE) + attention_mask = torch.ones(1, past_len + seq_len, + dtype=torch.long, device=DEVICE) + with torch.no_grad(): + out = model(cur_ids, + past_key_values=kv, + attention_mask=attention_mask, + cache_position=cache_position, + use_cache=True) + kv = to_dyn(out.past_key_values) + next_id = int(out.logits[0, -1].argmax().item()) + del out + generated.append(next_id) + if step >= 3 and eos_id is not None and next_id == eos_id: + break + past_len = cache_len(kv) + cur_ids = torch.tensor([[next_id]], dtype=torch.long, device=DEVICE) + finally: + move_non_layer_parts(model, "cpu") + gen_t = torch.tensor([generated], dtype=torch.long, device=prompt_ids.device) \ + if generated else torch.empty(1, 0, dtype=torch.long, device=prompt_ids.device) + return torch.cat([prompt_ids, gen_t], dim=1) + + +# --------------------------------------------------------------------------- +# Utilities +# --------------------------------------------------------------------------- + +def free(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def honest_bpe(kb, vb): + return (kb + 0.125 + vb + 0.125) / 2.0 + + +def to_dyn(c): + return DynamicCache.from_legacy_cache(c) if isinstance(c, tuple) else c + + +def get_kv(c, l): + if hasattr(c, "key_cache"): + return c.key_cache[l], c.value_cache[l] + return c.layers[l].keys, c.layers[l].values + + +def set_kv(c, l, k, v): + if hasattr(c, "key_cache"): + c.key_cache[l] = k + c.value_cache[l] = v + else: + c.layers[l].keys = k + c.layers[l].values = v + + +def n_layers_kv(c): + return len(c.layers) if hasattr(c, "layers") else len(c.key_cache) + + +def cache_len(c): + lens = [] + for l in range(n_layers_kv(c)): + kl, _ = get_kv(c, l) + if kl is not None: + lens.append(kl.shape[2]) + return max(lens) if lens else 0 + + +def clone_cache(kv): + new = DynamicCache() + nl = n_layers_kv(kv) + for l in range(nl): + k, v = get_kv(kv, l) + kk = k.detach().clone() if k is not None else None + vv = v.detach().clone() if v is not None else None + if hasattr(new, "key_cache"): + new.key_cache.append(kk) + new.value_cache.append(vv) + else: + new.update(kk, vv, l) + return new + + +def resolve_config(src): + _cfg = AutoConfig.from_pretrained(src) + text_cfg = getattr(_cfg, "text_config", _cfg) + n_layers = text_cfg.num_hidden_layers + n_kv = getattr(text_cfg, "num_key_value_heads", 8) + head_dim = getattr(text_cfg, "head_dim", None) + if head_dim is None: + head_dim = text_cfg.hidden_size // text_cfg.num_attention_heads + rope_theta = float(getattr(text_cfg, "rope_theta", 1e6)) + glob_ = set(range(n_layers)) + cfg_state.clear() + cfg_state.update({ + "n_layers": n_layers, "global": glob_, "n_kv_heads": n_kv, + "head_dim": head_dim, "rope_theta": rope_theta, + }) + print(f"[config] {n_layers}L all-global; KV heads={n_kv}, head_dim={head_dim}, " + f"rope_theta={rope_theta}", flush=True) + + +# --------------------------------------------------------------------------- +# E8 KV quantization (same path as all other kernels) +# --------------------------------------------------------------------------- + +def evict_quantize(kv, kb, vb, pb, first_call_log=False): + Hcache = {} + + def H(d): + p2 = 1 + while p2 < d: + p2 *= 2 + if p2 not in Hcache: + Hcache[p2] = hadamard_matrix(p2).cpu() + return Hcache[p2], p2 + + nl = n_layers_kv(kv) + plen = cache_len(kv) + glob_ = cfg_state["global"] + rope_theta = cfg_state["rope_theta"] + sorted_g = sorted(glob_) + protected = (set(sorted_g[:pb]) | set(sorted_g[-pb:])) if pb > 0 else set() + qcount = 0 + + for l in range(nl): + if l not in glob_ or l in protected: + continue + kl, vl = get_kv(kv, l) + if kl is None or kl.shape[2] != plen: + continue + d_l = kl.shape[-1] + nk = kl.shape[1] + Hm, p2 = H(d_l) + h_pad = p2 - d_l + + for is_k, t_in, b_bits in [(True, kl, kb), (False, vl, vb)]: + t = t_in[0].float().cpu().clone() + lvl = 2 ** b_bits + if is_k: + t = inverse_rope(t, seq_offset=0, base=rope_theta, + partial_rotary_factor=None) + for h in range(nk): + th = t[h] + pad_in = F.pad(th, (0, h_pad)) if h_pad > 0 else th + rot = pad_in @ Hm.T + amax = rot.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + sc = amax / (lvl / 2) + norm = rot / sc + e8p = (8 - p2 % 8) % 8 + if e8p > 0: + norm = F.pad(norm, (0, e8p)) + grp = norm.reshape(-1, 8) + lp = E8Lattice.nearest_point(grp).clamp(-lvl / 2, lvl / 2) + co = lp.reshape(-1, norm.shape[-1]) + if e8p > 0: + co = co[..., :p2] + t[h] = (co * sc @ Hm)[..., :d_l] + if is_k: + t = forward_rope(t, seq_offset=0, base=rope_theta, + partial_rotary_factor=None) + tdev, tdt = t_in.device, t_in.dtype + if is_k: + set_kv(kv, l, t.unsqueeze(0).to(dtype=tdt, device=tdev), vl) + else: + kn, _ = get_kv(kv, l) + set_kv(kv, l, kn, t.unsqueeze(0).to(dtype=tdt, device=tdev)) + qcount += 1 + + if first_call_log: + print(f" [quantized {qcount} layers, protected {len(protected)} (pb={pb})]", + flush=True) + return kv, qcount + + +# --------------------------------------------------------------------------- +# PPL scoring +# --------------------------------------------------------------------------- + +def score_continuation(model, kv, cont_t): + seg_nll, seg_tok = 0.0, 0 + cur = 0 + while cur < cont_t.shape[1] - 1: + end = min(cur + CONT_CHUNK, cont_t.shape[1]) + chunk = cont_t[:, cur:end] + past_len = (kv.get_seq_length() if hasattr(kv, "get_seq_length") else cache_len(kv)) + cache_pos = torch.arange(past_len, past_len + chunk.shape[1], device=DEVICE) + out = run_forward_streaming(model, chunk, past_key_values=kv, + cache_position=cache_pos, use_cache=True) + logits = out.logits[0, :-1, :] + kv = to_dyn(out.past_key_values) + tgts = chunk[0, 1:] + m = tgts.shape[0] + if m > 0: + pos = 0 + while pos < m: + e2 = min(pos + LOGIT_CHUNK, m) + lc = logits[pos:e2].float() + wsum = F.cross_entropy(lc, tgts[pos:e2].to(DEVICE), + reduction="sum").item() + if math.isfinite(wsum): + seg_nll += wsum + seg_tok += (e2 - pos) + del lc + pos = e2 + del out, logits + cur = end + free() + return seg_nll, seg_tok + + +def stats(deltas): + import statistics as st + n = len(deltas) + if n == 0: + return {"n": 0} + mean = sum(deltas) / n + sigma = st.stdev(deltas) if n > 1 else 0.0 + sem = sigma / math.sqrt(n) if n > 1 else 0.0 + z = (mean / sem) if sem > 0 else None + return { + "n": n, "mean_delta_pct": mean, "paired_sigma_pct": sigma, + "sem_pct": sem, "z": z, + "significant_at_2sigma": (abs(z) >= 2.0) if z is not None else False, + "n_negative_segments": sum(1 for d in deltas if d < 0), + } + + +def run_ppl_phase(model, tok, text, result, save): + bos_id = tok.bos_token_id + all_ids = tok(text, return_tensors="pt", add_special_tokens=False).input_ids[0] + n_possible = all_ids.shape[0] // SEG_LEN + n_use = min(N_SEGS_TARGET, n_possible) + if n_use < N_SEGS_MIN: + result["errors"]["ppl_segments"] = ( + f"only {n_possible} non-overlapping {SEG_LEN}-tok windows; need >={N_SEGS_MIN}") + save() + return + segs = [all_ids[i * SEG_LEN:(i + 1) * SEG_LEN] for i in range(n_use)] + segs = [s for s in segs if s.shape[0] == SEG_LEN] + print(f"[ppl] {all_ids.shape[0]} tokens -> {n_possible} windows, using {len(segs)}", + flush=True) + result["ppl"]["n_segments"] = len(segs) + save() + + fp_ppls, fp_nll_total, fp_tok_total = [], 0.0, 0 + quant_deltas = {lbl: [] for lbl, *_ in PPL_CONFIGS} + diag_logged = False + t0_phase = time.time() + + for si, seg in enumerate(segs): + try: + t0_seg = time.time() + prefix = seg[:PREFIX_LEN] + cont = seg[PREFIX_LEN:] + if bos_id is not None: + prefix = torch.cat([torch.tensor([bos_id], dtype=prefix.dtype), prefix]) + pre = prefix.unsqueeze(0) + cont_t = cont.unsqueeze(0) + if cont_t.shape[1] <= 1: + continue + + # Prefill -> FP16(NF4-weight) KV master on GPU + out_pre = run_forward_streaming(model, pre, use_cache=True) + fp_kv_master = to_dyn(out_pre.past_key_values) + del out_pre + free() + + # FP16 baseline score + fp_kv = clone_cache(fp_kv_master) + fp_nll, fp_tok = score_continuation(model, fp_kv, cont_t) + del fp_kv + free() + if fp_tok == 0: + del fp_kv_master + free() + continue + fp_seg_ppl = math.exp(fp_nll / fp_tok) + fp_ppls.append(fp_seg_ppl) + fp_nll_total += fp_nll + fp_tok_total += fp_tok + + if not diag_logged: + print(f" [diag] prefix={pre.shape[1]} cont={cont_t.shape[1]} " + f"fp_tok={fp_tok} fp_seg_ppl={fp_seg_ppl:.3f} " + f"seg_time={time.time()-t0_seg:.1f}s", flush=True) + + # KV-quant configs (paired, same master KV) + for lbl, kb, vb, pb in PPL_CONFIGS: + qkv = clone_cache(fp_kv_master) + qkv, _ = evict_quantize(qkv, kb, vb, pb, + first_call_log=(not diag_logged)) + q_nll, q_tok = score_continuation(model, qkv, cont_t) + del qkv + free() + if q_tok == 0: + continue + q_ppl = math.exp(q_nll / q_tok) + quant_deltas[lbl].append(100.0 * (q_ppl - fp_seg_ppl) / fp_seg_ppl) + + diag_logged = True + del fp_kv_master + free() + + if si % 10 == 0 or si == len(segs) - 1: + k3 = quant_deltas["K3V2_pb0"] + k4 = quant_deltas["K4V2_pb0"] + k3s = f"{k3[-1]:.3f}" if k3 else "n/a" + k4s = f"{k4[-1]:.3f}" if k4 else "n/a" + print(f" seg {si}: fp={fp_seg_ppl:.3f} " + f"K3={k3s} K4={k4s} " + f"phase_elapsed={time.time()-t0_phase:.0f}s", flush=True) + if si % 20 == 0: + result["ppl"]["k3v2_pb0"] = stats(quant_deltas["K3V2_pb0"]) + result["ppl"]["k4v2_pb0"] = stats(quant_deltas["K4V2_pb0"]) + if fp_tok_total > 0: + result["ppl"]["base_nf4weight_ppl"] = math.exp( + fp_nll_total / fp_tok_total) + save() + except Exception: + traceback.print_exc() + result["errors"][f"ppl_seg_{si}"] = traceback.format_exc()[-500:] + free() + + result["ppl"]["per_segment_fp_ppl"] = fp_ppls + result["ppl"]["k3v2_pb0"] = stats(quant_deltas["K3V2_pb0"]) + result["ppl"]["k3v2_pb0"]["per_segment_delta_pct"] = quant_deltas["K3V2_pb0"] + result["ppl"]["k4v2_pb0"] = stats(quant_deltas["K4V2_pb0"]) + result["ppl"]["k4v2_pb0"]["per_segment_delta_pct"] = quant_deltas["K4V2_pb0"] + if fp_tok_total > 0: + result["ppl"]["base_nf4weight_ppl"] = math.exp(fp_nll_total / fp_tok_total) + + # Hard scored-n floor: a wall-clock kill or per-seg exception can leave <60 + # segments scored even when n_segments target looks complete. + n_scored = min(len(v) for v in quant_deltas.values()) if quant_deltas else 0 + result["ppl"]["n_scored"] = n_scored + result["ppl"]["meets_n_floor"] = (n_scored >= N_SEGS_MIN) + if n_scored < N_SEGS_MIN and not result.get("blocked"): + result["blocked"] = f"PARTIAL: only {n_scored} segments scored (<{N_SEGS_MIN} floor)" + + k3, k4 = result["ppl"]["k3v2_pb0"], result["ppl"]["k4v2_pb0"] + result["ppl"]["summary"] = ( + f"n={result['ppl'].get('n_segments')} target; {k3.get('n')} scored. " + f"base_nf4weight_ppl={result['ppl'].get('base_nf4weight_ppl')}. " + f"K3V2_pb0: mean={k3.get('mean_delta_pct')}% +/-{k3.get('sem_pct')}% " + f"z={k3.get('z')} sig2s={k3.get('significant_at_2sigma')}. " + f"K4V2_pb0: mean={k4.get('mean_delta_pct')}% +/-{k4.get('sem_pct')}% " + f"z={k4.get('z')} sig2s={k4.get('significant_at_2sigma')}.") + save() + print(f"[ppl] {result['ppl']['summary']}", flush=True) + + +# --------------------------------------------------------------------------- +# NIAH +# --------------------------------------------------------------------------- + +def run_niah(model, tok, hay_ids, label, kb, vb, pb, do_quant, ctx, first_call_log=False): + rng = random.Random(SEED + ctx) + pad_id = tok.pad_token_id or tok.eos_token_id + cells = [] + t_start = time.time() + + for trial in range(N_TRIALS): + def n7(): + return f"{rng.randint(10**6, 10**7 - 1)}" + + keys = [n7() for _ in range(N_PAIRS)] + values = [] + while len(values) < N_PAIRS: + v = f"{rng.randint(100, 999)}" + if v not in values: + values.append(v) + target_idx = rng.randint(0, N_PAIRS - 1) + question = f"What is the value of {keys[target_idx]}?" + msgs = [{"role": "user", "content": question}] + try: + p_text = tok.apply_chat_template(msgs, add_generation_prompt=True, tokenize=False) + except Exception: + p_text = question + "\nAnswer: " + qids = tok(p_text, return_tensors="pt", add_special_tokens=False).input_ids + bos = tok.bos_token_id + if bos is not None and qids.shape[1] > 0 and qids[0, 0].item() == bos: + qids = qids[:, 1:] + + pair_strs = [f"The value for {k} is {v}." for k, v in zip(keys, values)] + pair_ids = [tok(s, return_tensors="pt", add_special_tokens=False).input_ids[0] + for s in pair_strs] + pair_total = sum(p.shape[0] for p in pair_ids) + hay_take = ctx - pair_total - qids.shape[1] + if hay_take <= 0: + cells.append({"trial": trial, "error": "context too small", "recall": False}) + continue + hay = hay_ids[:hay_take] + depths = [(i + 1) / (N_PAIRS + 1) for i in range(N_PAIRS)] + positions = sorted([(int(d * hay.shape[0]), i) for i, d in enumerate(depths)]) + out_chunks, last = [], 0 + for pos, pi in positions: + out_chunks.append(hay[last:pos]) + out_chunks.append(pair_ids[pi]) + last = pos + out_chunks.append(hay[last:]) + prefix = torch.cat(out_chunks).unsqueeze(0) + + recall = None + ans = "" + try: + out1 = run_forward_streaming(model, prefix, use_cache=True) + kv = to_dyn(out1.past_key_values) + del out1 + free() + + if do_quant: + kv, qc = evict_quantize(kv, kb, vb, pb, + first_call_log=(trial == 0 and first_call_log)) + if trial == 0: + print(f" [quant {qc} layers pb={pb}]", flush=True) + + gen_out = streaming_generate(model, tok, qids, kv, MAX_NEW_TOKENS, pad_id) + gen_ids = gen_out[0, qids.shape[1]:] + ans = tok.decode(gen_ids, skip_special_tokens=True).strip() + target_value = values[target_idx] + al = ans.lower() + # Recall rule: HIT only when the model emits the 3-digit VALUE as a + # standalone answer ("is " or a word-boundary ), never as a + # fragment of the echoed 7-digit key (\b cannot land inside a digit run). + target_key = keys[target_idx] + m = re.search(r"\bis\s+(\d{3})\b", al) + if m: + recall = (m.group(1) == target_value) + else: + stripped = al.replace(target_key, " ") + recall = re.search(rf"\b{re.escape(target_value)}\b", stripped) is not None + del gen_out + except Exception as e: + traceback.print_exc() + ans = f"{type(e).__name__}: {str(e)[:140]}" + + print(f" t{trial} tgt={keys[target_idx]} val={values[target_idx]} " + f"{'YES' if recall else 'NO' if recall is not None else 'ERR'} " + f"ans={ans[:60]!r}", flush=True) + cells.append({"trial": trial, "recall": recall, + "target_key": keys[target_idx], + "target_value": values[target_idx], "ans": ans[:160]}) + if kv is not None: + del kv + free() + + hits = sum(1 for c in cells if c.get("recall")) + elapsed = time.time() - t_start + print(f" [{label} ctx={ctx}] {hits}/{len(cells)} elapsed={elapsed:.0f}s", flush=True) + return {"config": label, "ctx": ctx, "hits": hits, "n": len(cells), + "elapsed_s": int(elapsed), "cells": cells} + + +def run_niah_phase(model, tok, result, save): + """NIAH at ctx=NIAH_CONTEXT: FP16(NF4-weight) baseline gate, then K3V2_pb0 + + K4V2_pb0 only if baseline hits>0. Each config re-prefills the same needle + layout (RNG seeded by SEED+ctx) so configs are matched. Saves before exit.""" + print(f"\n[phase B] NIAH at ctx={NIAH_CONTEXT}", flush=True) + try: + ds_train = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") + hay_text = "\n\n".join(r["text"] for r in ds_train if r["text"].strip()) + hay_ids = tok(hay_text, return_tensors="pt", truncation=True, + max_length=80000).input_ids[0] + + ctx = NIAH_CONTEXT + print(f"\n-- NIAH: FP16(NF4-weights) baseline ctx={ctx} --", flush=True) + fp16_key = f"FP16_ctx{ctx}" + res = run_niah(model, tok, hay_ids, "FP16", 0, 0, 0, False, ctx, + first_call_log=True) + result["niah_results"][fp16_key] = res + baseline_hits = res.get("hits", 0) + result["niah_fp16_baseline_hits"] = baseline_hits + save() + free() + + if baseline_hits == 0: + note = (f"DEGENERATE at ctx={ctx}: FP16(NF4-weight) hits=0/{N_TRIALS}. " + f"No compression deltas credited.") + print(f" *** {note} ***", flush=True) + result["niah_results"][f"baseline_gate_ctx{ctx}"] = note + else: + for lbl, kb, vb, pb, do_quant in NIAH_CONFIGS: + if lbl == "FP16": + continue + print(f"\n-- NIAH: {lbl} ctx={ctx} --", flush=True) + key = f"{lbl}_ctx{ctx}" + try: + res = run_niah(model, tok, hay_ids, lbl, kb, vb, pb, do_quant, ctx, + first_call_log=True) + res["fp16_baseline_hits"] = baseline_hits + result["niah_results"][key] = res + except Exception: + traceback.print_exc() + tb = traceback.format_exc() + result["niah_results"][key] = {"error": tb[-400:]} + result["errors"][f"niah_{key}"] = tb[-400:] + save() + free() + save() + except Exception: + traceback.print_exc() + result["errors"]["niah_phase"] = traceback.format_exc()[-800:] + save() + + +# --------------------------------------------------------------------------- +# Model load -- CPU-resident NF4, no accelerate device_map="auto" +# --------------------------------------------------------------------------- + +def load_model_cpu_nf4(): + """Load ALL NF4 weights onto CPU RAM. No meta tensors, no accelerate offload. + Forward hooks handle per-layer GPU streaming. + + bitsandbytes Linear4bit supports CPU residence and device migration: + the quantized uint8 data + quant_state are ordinary tensors that .to() moves fine. + The forward() kernel requires CUDA, so the hook must move the layer to GPU before + each forward -- this is exactly what install_streaming_hooks() provides. + """ + bnb = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + ) + last_err = None + for cand in MODEL_CANDIDATES: + try: + print(f"[model] loading {cand} onto CPU (device_map={{'':'cpu'}}) ...", + flush=True) + tok = AutoTokenizer.from_pretrained(cand) + t0 = time.time() + + # Attempt 1: CPU-resident NF4 via device_map={"":"cpu"} + # This bypasses accelerate's dispatch_model entirely. All weights on CPU. + try: + model = AutoModelForCausalLM.from_pretrained( + cand, + quantization_config=bnb, + torch_dtype=torch.float16, # activations/hidden_states fp16; matches bnb_4bit_compute_dtype + device_map={"": "cpu"}, + attn_implementation="eager", + low_cpu_mem_usage=True, + ) + load_mode = "cpu_resident" + except Exception as e1: + # Attempt 2: If device_map={"":"cpu"} fails for bitsandbytes, + # try loading without device_map at all (defaults to CPU for bnb <0.43) + print(f" [cpu-resident-fail] {type(e1).__name__}: {str(e1)[:200]}", + flush=True) + print(f" [model] retrying without device_map (bnb default cpu) ...", + flush=True) + model = AutoModelForCausalLM.from_pretrained( + cand, + quantization_config=bnb, + torch_dtype=torch.float16, # activations/hidden_states fp16; matches bnb_4bit_compute_dtype + attn_implementation="eager", + low_cpu_mem_usage=True, + ) + load_mode = "cpu_default" + + model.eval() + elapsed_load = time.time() - t0 + # Verify all parameters are on CPU (the invariant for streaming) + devices = {str(p.device) for p in model.parameters()} + print(f" loaded {cand} in {elapsed_load:.1f}s " + f"mode={load_mode} devices={devices}", flush=True) + if "cuda" in str(devices): + # Some part landed on GPU -- hooks may conflict; note it but proceed. + print(f" [warn] some params on GPU, hooks will still override per-layer", + flush=True) + return model, tok, cand, load_mode + except Exception as e: + tb = traceback.format_exc() + print(f" [load-fail] {cand}: {type(e).__name__}: {str(e)[:200]}", flush=True) + last_err = tb + free() + raise RuntimeError( + f"All Mixtral sources failed.\nLast error:\n{last_err[-800:] if last_err else 'unknown'}") + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + import transformers + gpu_info = None + if torch.cuda.is_available(): + p0 = torch.cuda.get_device_properties(0) + gpu_info = {"name": p0.name, "sm": f"{p0.major}{p0.minor}", + "n_gpus": torch.cuda.device_count(), + "total_memory_gb": round(p0.total_memory / 1024**3, 1)} + print(f"[start] GPU={gpu_info} transformers={transformers.__version__} " + f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}", flush=True) + + result = { + "model_name": ("Mixtral-8x7B-Instruct-v0.1 NF4-weights " + "(manual-stream v26 NIAH-only, fast)"), + "run_mode": "NIAH_ONLY" if NIAH_ONLY else "full", + "approach": ( + "CPU-resident NF4 weights (device_map={'':'cpu'}). " + "PyTorch forward hooks on each MixtralDecoderLayer: pre_hook dequantizes " + "all Linear4bit modules to fp16 nn.Linear on GPU via QuantState.from_dict " + "(reconstructed from raw module state_dict buffers; v4 fix for CPU-resident " + "models where qs.absmax is None and data lives in separate buffers). " + "v12 fix: MoE router gate (mlp.gate, bare Params4bit with w.quant_state=None) " + "is NOT NF4-quantized; bitsandbytes keeps it full-precision. Detected by " + "w.quant_state is None; moved to GPU as plain fp16 nn.Parameter, bypassing " + "dequantize_4bit entirely. Shape ~[8, 4096]. Restored from original Params4bit " + "in post_hook. Genuinely-quantized Params4bit (quant_state not None) still " + "go through the full dequantize_4bit path. " + "post_hook restores Linear4bit modules and frees GPU fp16 temps. " + "No accelerate device_map=auto, no offload_folder. " + "MoE routing is inside MixtralDecoderLayer so hooks cover all 8 experts " + "including fused gate_up_proj/down_proj batched weights. " + "KV cache + residual stream stay on GPU between layers."), + "weights": "NF4 (bitsandbytes load_in_4bit nf4 double-quant fp16-compute)", + "baseline_caveat": ( + "NF4-WEIGHTS baseline. Mixtral-8x7B 47B cannot fit fp16 on T4 (14.6GB). " + "KV-quant deltas are paired off the same NF4-weights reference on identical " + "continuation tokens -- a clean KV-only delta."), + "ppl_harness": ( + f"AQUA-iso paired prefix PPL. Wikitext-2-raw test, non-overlapping " + f"{SEG_LEN}-tok windows ({PREFIX_LEN}-prefix + {CONT_LEN}-cont). " + f"Per segment: prefill BOS+{PREFIX_LEN} prefix ONCE, score {CONT_LEN}-cont " + f"-> base NF4-weight PPL. Re-quantize SAME cached KV per config, rescore " + f"SAME cont. Paired per-seg delta%. Stats: mean +/- SEM(ddof=1), z, sig@2s."), + "niah_harness": ( + f"chat-template numeric NIAH ctx={NIAH_CONTEXT} n_pairs={N_PAIRS} " + f"n_trials={N_TRIALS} max_new_tokens={MAX_NEW_TOKENS}. " + f"FP16(NF4-weight) baseline gated."), + "prefix_len": PREFIX_LEN, "cont_len": CONT_LEN, "seg_len": SEG_LEN, + "n_segs_target": N_SEGS_TARGET, "n_segs_min": N_SEGS_MIN, + "niah_context": NIAH_CONTEXT, "n_pairs": N_PAIRS, "n_trials": N_TRIALS, + "delta_definition": "per_segment_paired KV-only off NF4-weight baseline", + "transformers": transformers.__version__, "gpu": gpu_info, + "model_source": None, "load_mode": None, + "ppl": {"base_nf4weight_ppl": None, "k3v2_pb0": None, "k4v2_pb0": None, + "per_segment_fp_ppl": [], "n_segments": 0}, + "niah_results": {}, "configs": [], "errors": {}, + "run_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + for lbl, kb, vb, pb in PPL_CONFIGS: + result["configs"].append({"label": lbl, "kb": kb, "vb": vb, "pb": pb, + "bpe": honest_bpe(kb, vb)}) + + def save(): + with open(OUT_PATH, "w") as f: + json.dump(result, f, indent=2) + save() + + # Load model (CPU-resident NF4, no accelerate dispatch) + try: + model, tok, src, load_mode = load_model_cpu_nf4() + except Exception: + traceback.print_exc() + result["errors"]["model_load"] = traceback.format_exc()[-1200:] + result["blocked"] = "BLOCKED: CPU-resident NF4 load failed. See errors.model_load." + save() + print(f"[BLOCKED] {result['blocked']}", flush=True) + return + + result["model_source"] = src + result["load_mode"] = load_mode + if tok.pad_token is None: + tok.pad_token = tok.eos_token + resolve_config(src) + result["config_meta"] = {k: (sorted(v) if isinstance(v, set) else v) + for k, v in cfg_state.items()} + save() + + # MoE path (v20): transformers==4.56.2 stores Mixtral experts as per-expert + # nn.Linear (block_sparse_moe.experts.{i}.w1/w2/w3) that bitsandbytes wraps as + # Linear4bit natively, matching the on-disk per-expert safetensors layout. The + # streaming dequant hook catches them via named_modules() with NO monkeypatch + # and NO grouped_mm/batched_mm. No config._experts_implementation toggle (5.x-only). + moe_path_used = "native-per-expert-4.56" + result["moe_path"] = moe_path_used + print(f"[moe-path] {moe_path_used} (native MixtralExperts forward, no monkeypatch)", + flush=True) + + # Discover and print the 4.x MoE submodule path from the loaded model once, + # and record a short sample of the expert module types for the result JSON. + expert_module_types = {} + try: + layers0 = (model.model.layers if hasattr(model, "model") else model.layers) + bsm = layers0[0].block_sparse_moe + moe_children = [n for n, _ in bsm.named_children()] + print(f" [moe-submodules] layers[0].block_sparse_moe children={moe_children}", + flush=True) + if hasattr(bsm, "experts"): + e0 = bsm.experts[0] + e0_children = [n for n, _ in e0.named_children()] + print(f" [moe-submodules] experts[0] children={e0_children}", flush=True) + for proj in ("w1", "w2", "w3"): + m = getattr(e0, proj, None) + if m is not None: + tn = type(m).__name__ + wt = type(getattr(m, "weight", None)).__name__ + expert_module_types[f"experts[0].{proj}"] = f"{tn}(weight={wt})" + if hasattr(bsm, "gate"): + expert_module_types["gate"] = type(bsm.gate).__name__ + print(f" [moe-submodules] expert_module_types={expert_module_types}", flush=True) + except Exception as _e_disc: + expert_module_types = {"error": f"{type(_e_disc).__name__}: {str(_e_disc)[:160]}"} + print(f" [moe-submodules] discovery failed: {_e_disc}", flush=True) + result["expert_module_types"] = expert_module_types + save() + + # Install forward hooks: per-layer CPU<->GPU streaming + print("[hooks] installing per-layer GPU streaming hooks ...", flush=True) + try: + decoder_layers = install_streaming_hooks(model) + result["n_decoder_layers_hooked"] = len(decoder_layers) + print(f" [hooks] OK: {len(decoder_layers)} layers hooked", flush=True) + except Exception: + traceback.print_exc() + result["errors"]["hook_install"] = traceback.format_exc()[-600:] + result["blocked"] = "BLOCKED: hook install failed. See errors.hook_install." + save() + return + save() + + # Smoke-check + timing probe: one forward on a short prompt. + # MUST pass before full run: assert logits finite + MoE routes 2 experts. + print("[smoke-check] single forward (64 tok) -- finite logits + MoE router check ...", + flush=True) + try: + probe_ids = torch.randint(0, 32000, (1, 64)) + t_probe = time.time() + + # Patch MoE router to count expert selections (Mixtral uses SparseMoeBlock) + moe_selections = [] + router_hooks = [] + def _router_hook(mod, inp, out): + # out is typically (hidden_states, router_logits) or similar + # SparseMoeBlock returns (hidden, router_logits) where router_logits is + # (batch*seq, n_experts); top-2 selected per token + try: + if isinstance(out, (tuple, list)) and len(out) >= 2: + rl = out[1] # (B*T, n_experts) + if rl is not None and rl.ndim == 2: + top2 = rl.topk(2, dim=-1).indices # (B*T, 2) + moe_selections.append(top2.cpu()) + except Exception: + pass + + for layer in (model.model.layers if hasattr(model, "model") else model.layers): + if hasattr(layer, "block_sparse_moe"): + h = layer.block_sparse_moe.register_forward_hook(_router_hook) + router_hooks.append(h) + + sc_out = run_forward_streaming(model, probe_ids, use_cache=False) + for h in router_hooks: + h.remove() + fwd_time = time.time() - t_probe + + logits = sc_out.logits + logits_ok = (logits is not None and + torch.isfinite(logits).all().item()) + n_experts_seen = set() + if moe_selections: + n_experts_seen = set(torch.cat(moe_selections).unique().tolist()) + experts_ok = (len(n_experts_seen) >= 2) + + smoke_msg = (f"logits={'FINITE' if logits_ok else 'NON-FINITE'} " + f"shape={list(logits.shape) if logits is not None else None} " + f"MoE-experts-seen={sorted(n_experts_seen)[:8]} " + f"experts_ok={experts_ok} " + f"fwd_time={fwd_time:.2f}s") + print(f" [smoke-check] {smoke_msg}", flush=True) + result["smoke_check"] = "pass" if (logits_ok and experts_ok) else ( + f"FAIL: {smoke_msg}") + result["smoke_check_detail"] = { + "logits_finite": bool(logits_ok), + "moe_experts_seen": sorted(n_experts_seen), + "experts_ok": bool(experts_ok), + "fwd_time_s": round(fwd_time, 2), + "msg": smoke_msg, + } + result["timing_probe_64tok_fwd_s"] = round(fwd_time, 2) + + if not logits_ok or not experts_ok: + result["errors"]["smoke_check"] = ( + f"SMOKE-CHECK FAILED: {smoke_msg}. Aborting to avoid wasting GPU hours.") + result["blocked"] = "BLOCKED: smoke-check failed. See errors.smoke_check." + save() + print(f"[BLOCKED] {result['blocked']}", flush=True) + return + + print(f" [smoke-check PASSED] logits finite, {len(n_experts_seen)} experts routed, " + f"est {PREFIX_LEN}-tok fwd: ~{fwd_time*PREFIX_LEN/64:.0f}s, " + f"est 80-seg PPL: ~{fwd_time*PREFIX_LEN/64*80*4/60:.0f}min", flush=True) + del sc_out, logits + free() + except Exception: + traceback.print_exc() + result["smoke_check"] = traceback.format_exc()[-600:] + result["errors"]["smoke_check"] = traceback.format_exc()[-600:] + result["blocked"] = "BLOCKED: smoke-check exception. See errors.smoke_check." + save() + print(f"[BLOCKED] {result['blocked']}", flush=True) + return + save() + + if SMOKE_ONLY: + # v20: validate load + dequant + native MoE forward cheaply; skip metrics. + result["ppl"] = None + result["niah_results"] = {} + result["smoke_only_note"] = "skipped: SMOKE_ONLY" + result["ppl_note"] = "skipped: SMOKE_ONLY" + result["niah_note"] = "skipped: SMOKE_ONLY" + remove_streaming_hooks() + save() + sc = result.get("smoke_check") + tp = result.get("timing_probe_64tok_fwd_s") + emt = result.get("expert_module_types") + nlh = result.get("n_decoder_layers_hooked") + print("\n========== SMOKE_ONLY RESULT ==========", flush=True) + print(f" transformers={result.get('transformers')} " + f"moe_path={result.get('moe_path')}", flush=True) + print(f" load_mode={result.get('load_mode')} layers_hooked={nlh}", flush=True) + print(f" expert_module_types={emt}", flush=True) + print(f" smoke_check={sc} fwd_time_s={tp if tp is not None else 'n/a'}", + flush=True) + print(f"[done] result -> {OUT_PATH}", flush=True) + return + + if NIAH_ONLY: + # v23: skip PPL entirely; PPL comes from the parallel v21 run. Run ONLY + # the NIAH phase at ctx>=4096 (FP16 baseline + K3V2_pb0 + K4V2_pb0). + result["ppl"] = None + result["ppl_note"] = "skipped: NIAH_ONLY" + result["niah_note"] = ( + f"NIAH-only run at ctx={NIAH_CONTEXT}; PPL half produced by v21.") + save() + run_niah_phase(model, tok, result, save) + remove_streaming_hooks() + save() + print("\n========== NIAH_ONLY RESULT ==========", flush=True) + print(f" transformers={result.get('transformers')} " + f"moe_path={result.get('moe_path')} gpu={result.get('gpu')}", flush=True) + for k, v in result.get("niah_results", {}).items(): + if isinstance(v, dict) and "hits" in v: + fb = v.get("fp16_baseline_hits") + fb_s = f" (fp16_baseline_hits={fb})" if fb is not None else "" + print(f" {k}: {v.get('hits')}/{v.get('n')}{fb_s}", flush=True) + else: + print(f" {k}: {v}", flush=True) + print(f"[done] result -> {OUT_PATH}", flush=True) + return + + # Data + print("[data] loading wikitext-2-raw ...", flush=True) + ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test") + text = "\n\n".join(r["text"] for r in ds if r["text"].strip()) + + # PHASE A: AQUA-iso paired PPL + print("\n[phase A] AQUA-iso paired PPL", flush=True) + try: + run_ppl_phase(model, tok, text, result, save) + except Exception: + traceback.print_exc() + result["errors"]["ppl_phase"] = traceback.format_exc()[-800:] + save() + + # PHASE B: NIAH at ctx=4096 + run_niah_phase(model, tok, result, save) + + remove_streaming_hooks() + save() + print("\n========== FINAL RESULT ==========", flush=True) + printable = {k: v for k, v in result.items() if k != "ppl"} + printable["ppl_summary"] = result["ppl"].get("summary") + printable["ppl_base_nf4weight_ppl"]= result["ppl"].get("base_nf4weight_ppl") + printable["ppl_k3v2_pb0"] = { + k: v for k, v in (result["ppl"].get("k3v2_pb0") or {}).items() + if k != "per_segment_delta_pct"} + printable["ppl_k4v2_pb0"] = { + k: v for k, v in (result["ppl"].get("k4v2_pb0") or {}).items() + if k != "per_segment_delta_pct"} + print(json.dumps(printable, indent=2), flush=True) + print(f"[done] result -> {OUT_PATH}", flush=True) + + +if __name__ == "__main__": + main() diff --git a/experiments/kaggle/results/nq_mixtral_stream.json b/experiments/kaggle/results/nq_mixtral_stream.json new file mode 100644 index 0000000..9010f1b --- /dev/null +++ b/experiments/kaggle/results/nq_mixtral_stream.json @@ -0,0 +1,488 @@ +{ + "model": "Mixtral-8x7B-Instruct-v0.1 (unsloth NF4-weights mirror: unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit)", + "approach": "transformers==4.56.2 native per-expert Linear4bit MoE. CPU-resident NF4 weights (device_map={'':'cpu'}); per-layer forward hooks dequantize each MixtralDecoderLayer's Linear4bit modules to fp16 on a single T4, run the layer, then free the GPU fp16 temps. CPU<->GPU per-layer streaming so the 47B model runs on one 14.6GB T4. E8 K3V2/K4V2 boundary-protect=0 (pb=0) KV quantization.", + "baseline_caveat": "PPL baseline is NF4-WEIGHTS FP16-activation, NOT strict FP16-weights (Mixtral-8x7B 47B cannot fit fp16 on a 14.6GB T4). Same caveat as the Qwen3 iso row. KV-quant deltas are paired off the same NF4-weights reference on identical continuation tokens, so they are a clean KV-only delta. The NIAH FP16 baseline is also NF4-weights (FP16 activations).", + "ppl": { + "base_nf4weight_ppl": 4.1462352957765285, + "n_segments": 80, + "protocol": "AQUA-iso paired, prefix=1024 cont=512", + "harness": "AQUA-iso paired prefix PPL. Wikitext-2-raw test, non-overlapping 1536-tok windows (1024-prefix + 512-cont). Per segment: prefill BOS+1024 prefix ONCE, score 512-cont -> base NF4-weight PPL. Re-quantize SAME cached KV per config, rescore SAME cont. Paired per-seg delta%. Stats: mean +/- SEM(ddof=1), z, sig@2s.", + "k3v2_pb0": { + "mean_delta_pct": 0.6294319394785511, + "sem_pct": 0.08403146693712102, + "n": 80, + "z": 7.490431411241956, + "significant_at_2sigma": true, + "paired_sigma_pct": 0.7516002892817146, + "n_negative_segments": 16, + "per_segment_delta_pct": [ + 0.7246854029483342, + 1.512641892735202, + 0.011980807997388546, + 1.156510137478797, + 1.1050401548796311, + 0.06544010148064011, + 1.4408210304200768, + 2.0010047303072644, + 1.005571788355941, + 2.264815969872861, + 0.04099512060036811, + 0.11465291137955612, + 1.1071112523541724, + 0.2793415973013189, + 0.5528936650092504, + 1.0395175680127957, + -0.46135143594924183, + 0.07357381833397962, + -0.20041465147164367, + 1.0415450829553292, + -0.007255873230559847, + 1.3295994040213037, + -0.7215703988405346, + 2.2209667040052032, + -0.13042976895369804, + 0.6750344364516279, + 1.2831010099955784, + 0.7367229534737959, + 2.5567925147005885, + 0.6451991050106739, + 1.0059307038722423, + 0.6102471715764747, + 1.1250222789369888, + 2.163677308645605, + 1.0036445243276158, + 1.3213212434246884, + 0.14608932281805317, + -0.7289517738425727, + 1.4952683528805393, + 0.09201308948416832, + 0.7374448912809098, + 0.2167421319933924, + 1.4488363145467098, + 0.4019655583074462, + -0.21639256064848147, + 0.568892652781818, + 0.43316529567298545, + 0.2792457763641716, + -0.48442263466791846, + 0.007931301546312277, + 0.22202110533232938, + 1.11318591121679, + 0.5420489520030642, + 0.27450874293690236, + -0.4580164727843389, + 1.0851917733346235, + 0.5987113717704589, + 0.06621698930106992, + -0.05598198984767538, + 0.6075613750540906, + 0.25976608364668463, + -0.36870968645256463, + 1.2757611090602288, + -0.08868856587939195, + 0.8970116698412253, + -0.5518194453456025, + 0.6515225189290834, + 1.6101246181354363, + -0.41490454506952784, + -0.5566063039155368, + 1.3979382075879971, + 0.38311548057026595, + -0.6551073877681329, + 0.8832860600893029, + 0.8028618897324032, + 1.0874224464792903, + 0.5808334959530415, + 1.3784591638972985, + 1.3021200882375839, + 1.3925125213005458 + ] + }, + "k4v2_pb0": { + "mean_delta_pct": 0.32031743604303753, + "sem_pct": 0.06188810936039969, + "n": 80, + "z": 5.175750872880903, + "significant_at_2sigma": true, + "paired_sigma_pct": 0.553544078115179, + "n_negative_segments": 21, + "per_segment_delta_pct": [ + 1.1135270932397365, + 0.7794344738535812, + 0.7021722473923661, + 0.7756187044803535, + 0.47900945717443444, + -0.3344043785790275, + 0.7348399233739489, + 0.6524662564066374, + 1.11797464209308, + 1.6097499020728, + 0.04930610251035868, + 0.6200476323192023, + 0.5358585024128624, + -0.2177393320138629, + 0.0013079047583600804, + 0.6344586270914692, + 0.038802474843545695, + -0.029456172146362737, + -0.43109191644294775, + 0.28615709630808106, + 0.19424676803422752, + 0.18409887680712866, + -0.6812388371247198, + 0.9756957900055446, + 0.44057009714001427, + 1.066270762696707, + 0.15171746404245745, + 0.537941957440969, + 1.229772576515426, + 0.24687551840457292, + 0.1396421703507281, + 0.5607727282717033, + -0.1725291289712261, + 0.4197936494023282, + 0.5315656404373521, + 0.48880313535037573, + -0.524739067296165, + -0.09296669296835328, + 1.0454071287330802, + 0.14157986085971633, + -0.33366035608546324, + 0.7056303932358883, + 0.4270865086881388, + 0.5220019231715531, + 0.09038121145379567, + 0.03468019195253987, + 0.553404103911444, + -0.2490019630620344, + -0.20524226411870056, + -0.11737297666147693, + 0.19291240525034004, + 0.10043591597812049, + 0.45928697428787446, + 0.23781180318391934, + -0.8289464163835601, + 0.8058659527861941, + 0.3521500413554135, + 0.37670702057862093, + -0.4807615570819441, + 0.32558588115888976, + -0.2832622266759109, + -0.7556031581003373, + 0.8075696996725478, + 0.5174935434067924, + 1.2743337158616774, + -0.3547795075945675, + -0.40622698314223854, + 0.4927700866610999, + -0.21828161092606896, + -0.5797355807022895, + 0.04755542131665192, + 0.37611954892820265, + -0.7748818267764793, + 0.7148807652381425, + 0.7220326429811952, + 1.2423798913516, + 0.6571730537746701, + 1.212092332804549, + 1.595859075703724, + 0.3676315667800051 + ] + }, + "per_segment_fp_ppl": [ + 2.9972048524921515, + 4.303900940321052, + 5.663442715479996, + 4.92098447264755, + 6.172061045750284, + 3.887229342780449, + 4.609005425693543, + 3.7661475598664453, + 1.7808021378121703, + 1.705353903211459, + 2.7154454561273154, + 3.6835066687837132, + 4.735580678953886, + 4.620966467825501, + 4.432031417467792, + 4.233276429848484, + 3.423229385877704, + 4.617257263747235, + 4.439757464499797, + 4.153537115481176, + 4.003420901681173, + 4.85938706647505, + 6.400163142604264, + 6.157074162072783, + 4.314916346676763, + 5.618097018733549, + 3.172598092584736, + 3.4662672785461055, + 1.5884358932555092, + 5.387783030806805, + 4.443155310373734, + 4.809632135679847, + 3.7262870487790685, + 5.109882630278806, + 5.120230074681932, + 4.382019711044371, + 6.7733040441997385, + 4.103728174518242, + 6.098674348329545, + 4.906307911321739, + 4.844676056019404, + 7.652768244779736, + 4.700834859284556, + 4.29548436062827, + 3.2688288114308994, + 3.876705381732775, + 3.153874364067241, + 2.7479148012306087, + 4.376157371832899, + 4.675679609471414, + 4.768483174424751, + 6.26768978808305, + 6.41445215249057, + 3.1633437055577422, + 2.071717054588711, + 2.5019753878504685, + 3.42433967328955, + 2.463567547351034, + 2.589200311667274, + 2.4212714906854482, + 3.9190969970058904, + 2.7349623494185127, + 7.889300265029461, + 6.2953202734566664, + 8.746025834057464, + 2.537860680682463, + 5.843634968082025, + 4.396148848393049, + 4.415689692758936, + 5.63451581268799, + 4.277804243069241, + 7.062510198124643, + 3.645264842583138, + 3.931264393676739, + 4.962011383700835, + 3.6752465345742413, + 4.594828482004163, + 4.00249429821308, + 2.9910945851048187, + 3.896963502367186 + ] + }, + "niah": { + "ctx": 4096, + "n_pairs": 8, + "n_trials": 5, + "max_new_tokens": 20, + "harness": "chat-template numeric NIAH ctx=4096 n_pairs=8 n_trials=5 max_new_tokens=20. FP16(NF4-weight) baseline gated.", + "recall_rule_note": "recall=True when the generated answer contains the target value string for the queried key. FP16(NF4-weight) baseline is gated; deltas credited only when the baseline passes.", + "per_config": { + "FP16": { + "hits": 5, + "n": 5, + "cells": [ + { + "trial": 0, + "recall": true, + "target_key": "2591602", + "target_value": "743", + "ans": "The value of 2591602 is 743." + }, + { + "trial": 1, + "recall": true, + "target_key": "3252379", + "target_value": "954", + "ans": "The value of 3252379 is 954." + }, + { + "trial": 2, + "recall": true, + "target_key": "2452762", + "target_value": "988", + "ans": "The value of 2452762 is 988." + }, + { + "trial": 3, + "recall": true, + "target_key": "5138650", + "target_value": "947", + "ans": "The value of 5138650 is 947." + }, + { + "trial": 4, + "recall": true, + "target_key": "8078916", + "target_value": "663", + "ans": "The value of 8078916 is 663." + } + ] + }, + "K3V2_pb0": { + "hits": 5, + "n": 5, + "cells": [ + { + "trial": 0, + "recall": true, + "target_key": "2591602", + "target_value": "743", + "ans": "The value of 2591602 is 743." + }, + { + "trial": 1, + "recall": true, + "target_key": "3252379", + "target_value": "954", + "ans": "The value of 3252379 is 954." + }, + { + "trial": 2, + "recall": true, + "target_key": "2452762", + "target_value": "988", + "ans": "The value of 2452762 is 988." + }, + { + "trial": 3, + "recall": true, + "target_key": "5138650", + "target_value": "947", + "ans": "The value of 5138650 is 947." + }, + { + "trial": 4, + "recall": true, + "target_key": "8078916", + "target_value": "663", + "ans": "The value of 8078916 is 663." + } + ] + }, + "K4V2_pb0": { + "hits": 5, + "n": 5, + "cells": [ + { + "trial": 0, + "recall": true, + "target_key": "2591602", + "target_value": "743", + "ans": "The value of 2591602 is 743." + }, + { + "trial": 1, + "recall": true, + "target_key": "3252379", + "target_value": "954", + "ans": "The value of 3252379 is 954." + }, + { + "trial": 2, + "recall": true, + "target_key": "2452762", + "target_value": "988", + "ans": "The value of 2452762 is 988." + }, + { + "trial": 3, + "recall": true, + "target_key": "5138650", + "target_value": "947", + "ans": "The value of 5138650 is 947." + }, + { + "trial": 4, + "recall": true, + "target_key": "8078916", + "target_value": "663", + "ans": "The value of 8078916 is 663." + } + ] + } + } + }, + "configs": [ + { + "label": "K3V2_pb0", + "kb": 3, + "vb": 2, + "pb": 0, + "bpe": 2.625 + }, + { + "label": "K4V2_pb0", + "kb": 4, + "vb": 2, + "pb": 0, + "bpe": 3.125 + } + ], + "config_meta": { + "n_layers": 32, + "global": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31 + ], + "n_kv_heads": 8, + "head_dim": 128, + "rope_theta": 1000000.0 + }, + "provenance": { + "ppl_half": { + "kernel": "nq-mixtral-stream-v21", + "kaggle_account": "jooandrgomesmarques", + "run_at_utc": "2026-06-17T09:01:32Z", + "produces": "base_nf4weight_ppl + k3v2_pb0/k4v2_pb0 PPL deltas and all per-segment arrays" + }, + "niah_half": { + "kernel": "nq-mixtral-stream-v26", + "kaggle_account": "jagmardrop", + "run_at_utc": "2026-06-19T09:24:06Z", + "produces": "NIAH ctx=4096 per-config hits (FP16/K3V2_pb0/K4V2_pb0)" + }, + "transformers": "4.56.2", + "gpu": { + "name": "Tesla T4", + "sm": "75", + "n_gpus": 1, + "total_memory_gb": 14.6 + }, + "model_source": "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit", + "moe_path": "native-per-expert-4.56" + }, + "superseded_note": "v21's OWN niah_results (FP16_ctx4096 hits=0/8, IndexError 'index -1 is out of bounds for dimension 0 with size 0') is a pre-fix generate-crash artifact, NOT a real model result. It is SUPERSEDED by the v26 NIAH run (FP16 5/5). Never cite the v21 0/8 NIAH number." +} \ No newline at end of file