From 1a089007784df58a96ba98f8c725be0886042a7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Andr=C3=A9=20Gomes=20Marques?= Date: Tue, 16 Jun 2026 08:35:24 +0200 Subject: [PATCH] C6: Mixtral-8x7B T4x2 AQUA-iso paired PPL + NIAH kernel (v3, memory-fit) v1 OOMed (93GB bf16 download), v2 OOMed (15GiB/card, paired-clone on GPU). v3 fixes: max_memory={0:11GiB, 1:11GiB, cpu:26GiB} + offload_folder pushes expert blocks to CPU/disk; CPU-master pattern keeps one GPU KV copy at a time enabling AQUA-iso PAIRED deltas (n=80 >= contract floor 60). Mirror: unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit (24.5GB, Apache-2.0, gated=false, no HF_TOKEN). Local smoke: AST OK, E8 ordering OK, json OK. --- .../nq_mixtral_t4fit/kernel-metadata.json | 20 + .../nq_mixtral_t4fit/nq_mixtral_t4fit.py | 781 ++++++++++++++++++ 2 files changed, 801 insertions(+) create mode 100644 experiments/kaggle/nq_mixtral_t4fit/kernel-metadata.json create mode 100644 experiments/kaggle/nq_mixtral_t4fit/nq_mixtral_t4fit.py diff --git a/experiments/kaggle/nq_mixtral_t4fit/kernel-metadata.json b/experiments/kaggle/nq_mixtral_t4fit/kernel-metadata.json new file mode 100644 index 0000000..85b74e4 --- /dev/null +++ b/experiments/kaggle/nq_mixtral_t4fit/kernel-metadata.json @@ -0,0 +1,20 @@ +{ + "id": "jooandrgomesmarques/nq-mixtral-t4fit", + "title": "nq-mixtral-t4fit", + "code_file": "nq_mixtral_t4fit.py", + "language": "python", + "kernel_type": "script", + "is_private": true, + "enable_gpu": true, + "enable_tpu": false, + "enable_internet": true, + "keywords": [ + "gpu" + ], + "dataset_sources": [], + "kernel_sources": [], + "competition_sources": [], + "model_sources": [], + "docker_image": "gcr.io/kaggle-private-byod/python@sha256:00377cd1b3d470a605bc5b0ceca79969e369644e9b36802242a1c70e627372f9", + "machine_shape": "NvidiaTeslaT4" +} diff --git a/experiments/kaggle/nq_mixtral_t4fit/nq_mixtral_t4fit.py b/experiments/kaggle/nq_mixtral_t4fit/nq_mixtral_t4fit.py new file mode 100644 index 0000000..38d938c --- 
/dev/null +++ b/experiments/kaggle/nq_mixtral_t4fit/nq_mixtral_t4fit.py @@ -0,0 +1,781 @@ +# NexusQuant Mixtral-8x7B T4x2 memory-fit kernel (v3) -- AQUA-iso paired PPL + NIAH. +# +# WHY PRIOR VERSIONS OOMed: +# v1: downloaded 93GB bf16 repo (OOM mid-download). +# v2 (lean): re-prefilled each config separately (no paired clone), small n=25, +# still OOMed: max_memory=15GiB/card left no headroom for 4K NIAH prefill. +# +# v3 MEMORY STRATEGY: +# 1. max_memory={0:"11GiB", 1:"11GiB", "cpu":"26GiB"} -- pushes several expert +# blocks to CPU RAM. Slower but fits: ~24.5GB NF4 weights + activations. +# 2. offload_folder on disk as safety valve for anything that overflows CPU too. +# 3. CPU-master approach for PAIRED PPL: prefill -> move entire kv_master to CPU. +# For each config: clone master on CPU -> move clone to GPU -> quantize the clone +# in place (the E8 math runs per layer on CPU in float32; results are written +# back to the clone's device) -> score continuation -> del GPU clone + free(). +# One GPU copy at a time. +# This enables AQUA-iso PAIRED deltas (same tokens, same prefix KV) without +# holding 3 GPU copies simultaneously. +# 4. Short per-segment continuation (512 tok) keeps activation memory bounded. +# 5. NIAH at ctx=4096: full 4K prefill on GPU. If the FP16 baseline errors or scores +# 0 hits, retry at ctx=2048 (lower but still >=2048 for NIAH non-degeneracy). +# Each trial: prefill -> (quant) -> generate -> del kv -> free(). +# 6. SAVE-BEFORE-PRINT: json.dump before any f-string/print summary. +# 7. Guard every f-string against None with 'if x is not None else "n/a"'. +# +# Architecture: Mixtral-8x7B is dense GQA per layer (MoE routes FFN only; KV cache +# is standard). 32 layers, 8 KV heads, head_dim=128, rope_theta=1e6, no SWA, no +# partial-rotary. Same E8 K/V path as Mistral-7B; every layer is a quant target. +# +# Protocol: AQUA-iso paired. Wikitext-2-raw test, non-overlapping 1536-tok windows +# (1024 prefix + 512 continuation). n_target=80 >= contract floor 60. 
+# Per segment: prefill ONCE -> master on CPU; per config clone master to GPU, +# quantize, score identical cont tokens, paired delta% = (q_ppl-fp_ppl)/fp_ppl*100. +# Stats: mean, sigma(ddof=1), SEM, z=mean/SEM, n_negative, sig@2sigma. +# +# Mirror: unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit (pre-quantized NF4, ~24.5GB, +# Apache-2.0, gated=false, no HF_TOKEN needed). + +import sys, os, gc, json, math, time, re, traceback, random, subprocess + +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + +print("Installing transformers/accelerate/bitsandbytes + nexusquant ...", flush=True) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "numpy<2"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U", + "transformers>=5.5.3", "accelerate>=1.1.1", "bitsandbytes>=0.43.0"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", + "git+https://github.com/jagmarques/nexusquant.git@main"]) + +import torch +import torch.nn.functional as F +from transformers import (AutoModelForCausalLM, AutoTokenizer, AutoConfig, + DynamicCache, BitsAndBytesConfig) +from datasets import load_dataset + +from nexusquant.core.e8_lattice import E8Lattice +from nexusquant.core.hadamard import hadamard_matrix +from nexusquant.core.rope_utils import inverse_rope, forward_rope + +SEED = 42 +random.seed(SEED) +torch.manual_seed(SEED) + +# Pre-quantized NF4 mirror: ~24.5GB, Apache-2.0, gated=false, no HF_TOKEN. +MODEL_CANDIDATES = [ + "unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit", + "unsloth/Mixtral-8x7B-v0.1-bnb-4bit", +] +MODEL_NAME = "Mixtral-8x7B-Instruct-v0.1 NF4-weights (v3 paired)" + +PREFIX_LEN = 1024 +CONT_LEN = 512 +SEG_LEN = PREFIX_LEN + CONT_LEN +N_SEGS_TARGET = 80 # >= contract floor 60 +N_SEGS_MIN = 60 +LOGIT_CHUNK = 128 # small to limit GPU memory peak during scoring +CONT_CHUNK = 256 # continuation forward chunk +# Low GPU caps push expert blocks to CPU; overflow to offload_folder on disk. 
def free():
    """Run the garbage collector, then release cached CUDA blocks if CUDA is usable."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def honest_bpe(kb, vb):
    """Return the mean bits-per-element of a K/V bit-width pair.

    Each side is charged its nominal bit-width plus a fixed 0.125
    (presumably per-element scale overhead -- confirm against the paper/repo).
    """
    return (kb + 0.125 + vb + 0.125) / 2.0


def to_dyn(c):
    """Return ``c`` as a cache object, converting legacy tuple caches to DynamicCache."""
    return DynamicCache.from_legacy_cache(c) if isinstance(c, tuple) else c


def get_kv(c, l):
    """Return ``(keys, values)`` for layer ``l``.

    Supports both cache layouts used in this file: objects exposing
    ``key_cache``/``value_cache`` lists, and objects exposing ``layers`` whose
    items carry ``keys``/``values``.
    """
    if hasattr(c, "key_cache"):
        return c.key_cache[l], c.value_cache[l]
    return c.layers[l].keys, c.layers[l].values


def set_kv(c, l, k, v):
    """Overwrite the key and value tensors of layer ``l`` in place (either layout)."""
    if hasattr(c, "key_cache"):
        c.key_cache[l] = k
        c.value_cache[l] = v
    else:
        c.layers[l].keys = k
        c.layers[l].values = v


def n_layers_kv(c):
    """Return the number of layers stored in the cache (either layout)."""
    return len(c.layers) if hasattr(c, "layers") else len(c.key_cache)


def cache_len(c):
    """Return the largest ``shape[2]`` over all non-None key tensors, or 0 if none."""
    keys = (get_kv(c, l)[0] for l in range(n_layers_kv(c)))
    return max((k.shape[2] for k in keys if k is not None), default=0)


def clone_cache_cpu(kv):
    """Clone a KV cache to CPU tensors; returns a DynamicCache with cpu tensors.

    Each tensor is copied exactly once.  ``to("cpu", copy=True)`` always yields
    a fresh tensor, even when the source already lives on the CPU, so the
    result never aliases ``kv`` and no second host-side copy is made (the
    previous ``.cpu().clone()`` copied GPU tensors twice).  ``None`` entries
    are preserved as ``None``.
    """
    new = DynamicCache()
    for l in range(n_layers_kv(kv)):
        k, v = get_kv(kv, l)
        kk = k.detach().to("cpu", copy=True) if k is not None else None
        vv = v.detach().to("cpu", copy=True) if v is not None else None
        if hasattr(new, "key_cache"):
            new.key_cache.append(kk)
            new.value_cache.append(vv)
        else:
            new.update(kk, vv, l)
    return new
def nq_quantize_kv(kv, kb, vb, pb, first_call_log=False):
    """E8-quantize the keys and values of a KV cache, layer by layer, in place.

    Args:
        kv: cache object readable/writable through ``get_kv``/``set_kv``.
        kb: bit-width for keys; the clamp range is derived from ``2 ** kb``.
        vb: bit-width for values; the clamp range is derived from ``2 ** vb``.
        pb: number of layers left untouched at each end of the sorted
            ``cfg_state["global"]`` layer list (0 protects nothing).
        first_call_log: when true, print a one-line summary after the pass.

    Returns:
        A ``(kv, qcount)`` tuple: the same cache object (mutated) and the
        counter that the log line reports as the number of quantized layers.

    Notes:
        The arithmetic runs on CPU in float32; each tensor is pulled off its
        device, processed, and written back with its original dtype/device.
        Only batch index 0 of each tensor is read and written back.
    """
    # Per-call cache of Hadamard matrices keyed by padded (power-of-two) size.
    Hcache = {}

    def H(d):
        # Round d up to the next power of two and return (matrix, padded_size).
        p2 = 1
        while p2 < d:
            p2 *= 2
        if p2 not in Hcache:
            Hcache[p2] = hadamard_matrix(p2).cpu()
        return Hcache[p2], p2

    nl = n_layers_kv(kv)
    plen = cache_len(kv)
    glob_ = cfg_state["global"]
    rope_theta = cfg_state["rope_theta"]
    sorted_g = sorted(glob_)
    # First pb and last pb layers of the sorted target list are skipped.
    protected = (set(sorted_g[:pb]) | set(sorted_g[-pb:])) if pb > 0 else set()
    qcount = 0

    for l in range(nl):
        if l not in glob_ or l in protected:
            continue
        kl, vl = get_kv(kv, l)
        # Skip empty layers and layers whose sequence length differs from the
        # longest one in the cache.
        if kl is None or kl.shape[2] != plen:
            continue
        d_l = kl.shape[-1]
        nk = kl.shape[1]
        Hm, p2 = H(d_l)
        h_pad = p2 - d_l

        for is_k, t_in, b_bits in [(True, kl, kb), (False, vl, vb)]:
            # Work on a float32 CPU copy of batch element 0.
            t = t_in[0].float().cpu().clone()
            lvl = 2 ** b_bits
            if is_k:
                # Keys go through inverse_rope before quantization and
                # forward_rope afterwards (presumably to quantize in the
                # un-rotated domain -- see nexusquant.core.rope_utils).
                t = inverse_rope(t, seq_offset=0, base=rope_theta,
                                 partial_rotary_factor=None)
            for h in range(nk):
                th = t[h]
                # Zero-pad the feature dim up to the Hadamard size.
                pad_in = F.pad(th, (0, h_pad)) if h_pad > 0 else th
                rot = pad_in @ Hm.T
                # One scale per row, chosen so the row maximum maps to lvl / 2.
                amax = rot.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
                sc = amax / (lvl / 2)
                norm = rot / sc
                # Pad to a multiple of 8 so rows split into 8-dim lattice groups.
                e8p = (8 - p2 % 8) % 8
                if e8p > 0:
                    norm = F.pad(norm, (0, e8p))
                grp = norm.reshape(-1, 8)
                lp = E8Lattice.nearest_point(grp).clamp(-lvl / 2, lvl / 2)
                co = lp.reshape(-1, norm.shape[-1])
                if e8p > 0:
                    co = co[..., :p2]
                # `*` and `@` share precedence and associate left-to-right, so
                # this is (co * sc) @ Hm; then drop the Hadamard padding.
                t[h] = (co * sc @ Hm)[..., :d_l]
            if is_k:
                t = forward_rope(t, seq_offset=0, base=rope_theta,
                                 partial_rotary_factor=None)
            tdev, tdt = t_in.device, t_in.dtype
            if is_k:
                set_kv(kv, l, t.unsqueeze(0).to(dtype=tdt, device=tdev), vl)
            else:
                # Re-read the key so the freshly quantized key is kept.
                kn, _ = get_kv(kv, l)
                set_kv(kv, l, kn, t.unsqueeze(0).to(dtype=tdt, device=tdev))
        qcount += 1

    if first_call_log:
        print(f" [quantized {qcount} layers, protected {len(protected)} (pb={pb})]",
              flush=True)
    return kv, qcount
def stats(deltas):
    """Summarise a list of paired per-segment deltas (in percent).

    Returns ``{"n": 0}`` for an empty list.  Otherwise returns the count, the
    mean, the sample standard deviation (0.0 for a single value), the standard
    error of the mean, ``z = mean / sem`` (``None`` when the standard error is
    not positive), a ``|z| >= 2`` flag, and the number of negative deltas.
    """
    import statistics

    count = len(deltas)
    if not count:
        return {"n": 0}

    avg = sum(deltas) / count
    if count > 1:
        spread = statistics.stdev(deltas)
        std_err = spread / math.sqrt(count)
    else:
        spread = 0.0
        std_err = 0.0

    z_score = None
    is_significant = False
    if std_err > 0:
        z_score = avg / std_err
        is_significant = abs(z_score) >= 2.0

    negatives = len([d for d in deltas if d < 0])

    summary = {}
    summary["n"] = count
    summary["mean_delta_pct"] = avg
    summary["paired_sigma_pct"] = spread
    summary["sem_pct"] = std_err
    summary["z"] = z_score
    summary["significant_at_2sigma"] = is_significant
    summary["n_negative_segments"] = negatives
    return summary


def oom_msg(tb):
    """Return the last 500 characters of a traceback string.

    The result is prefixed with ``"CUDA_OOM: "`` when the full text contains an
    out-of-memory marker (matched case-insensitively).
    """
    tail = tb[-500:]
    lowered = tb.lower()
    for marker in ("out of memory", "cuda oom"):
        if marker in lowered:
            return "CUDA_OOM: " + tail
    return tail
def _fmt_last_delta(deltas):
    """Format the most recent delta with 4 decimals, or 'n/a' when the list is empty."""
    return f"{deltas[-1]:.4f}" if deltas else "n/a"


def run_ppl_phase(model, tok, device, result, save):
    """Run the paired perplexity phase and store its statistics in ``result``.

    The wikitext-2-raw test split is tokenized and cut into non-overlapping
    ``SEG_LEN`` windows.  For each window the prefix is prefilled once, the
    resulting cache is cloned to CPU, and every configuration (baseline first,
    then each entry of ``PPL_CONFIGS``) scores the same continuation tokens
    from its own fresh device clone of that CPU master.

    Args:
        model: causal LM called as ``model(ids, use_cache=True)``.
        tok: tokenizer; ``bos_token_id`` is prepended to each prefix if set.
        device: device the input ids and cache clones are moved to.
        result: results dict; ``result["ppl"]`` and ``result["errors"]`` are
            updated in place.
        save: zero-argument callable that persists ``result``.

    Returns:
        None.  Returns early (after recording an error) when fewer than
        ``N_SEGS_MIN`` windows are available.
    """
    bos_id = tok.bos_token_id
    print("[data] loading wikitext-2-raw test ...", flush=True)
    ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="test")
    text = "\n\n".join(r["text"] for r in ds if r["text"].strip())
    all_ids = tok(text, return_tensors="pt", add_special_tokens=False).input_ids[0]
    n_possible = all_ids.shape[0] // SEG_LEN
    n_use = min(N_SEGS_TARGET, n_possible)
    if n_use < N_SEGS_MIN:
        result["errors"]["ppl_segments"] = (
            f"only {n_possible} non-overlapping {SEG_LEN}-tok windows; need >= {N_SEGS_MIN}")
        save()
        return
    segments = [all_ids[i * SEG_LEN:(i + 1) * SEG_LEN]
                for i in range(n_use)
                if all_ids[i * SEG_LEN:(i + 1) * SEG_LEN].shape[0] == SEG_LEN]
    print(f"[ppl] {all_ids.shape[0]} tokens -> {n_possible} windows possible, using "
          f"{len(segments)}", flush=True)
    result["ppl"]["n_segments"] = len(segments)
    save()

    fp_ppls = []
    quant_deltas = {label: [] for label, *_ in PPL_CONFIGS}
    fp_nll_total, fp_tok_total = 0.0, 0
    diag_logged = False

    for si, seg in enumerate(segments):
        # Every name that can hold large tensors is pre-bound so the `finally`
        # clause can drop the references on any exit path.
        pre = cont_t = out_pre = gpu_kv = cpu_master = fp_kv = q_kv = None
        try:
            prefix = seg[:PREFIX_LEN]
            cont = seg[PREFIX_LEN:]
            if bos_id is not None:
                prefix = torch.cat([torch.tensor([bos_id], dtype=prefix.dtype), prefix])
            pre = prefix.unsqueeze(0).to(device)
            cont_t = cont.unsqueeze(0).to(device)
            if cont_t.shape[1] <= 1:
                continue

            # Prefill ONCE on GPU -> move master to CPU immediately.
            with torch.no_grad():
                out_pre = model(pre, use_cache=True)
            gpu_kv = to_dyn(out_pre.past_key_values)
            out_pre = None
            free()
            cpu_master = clone_cache_cpu(gpu_kv)
            gpu_kv = None
            free()

            # FP16 baseline: clone master to GPU, score, release.
            fp_kv = clone_cache_gpu(cpu_master, device)
            fp_nll, fp_tok = score_continuation(model, fp_kv, cont_t, device)
            fp_kv = None
            free()
            if fp_tok == 0:
                continue
            fp_seg_ppl = math.exp(fp_nll / fp_tok)
            fp_ppls.append(fp_seg_ppl)
            fp_nll_total += fp_nll
            fp_tok_total += fp_tok

            if not diag_logged:
                print(f" [diag] prefix_with_bos={pre.shape[1]} cont={cont_t.shape[1]} "
                      f"fp_tok={fp_tok} fp_seg_ppl={fp_seg_ppl:.3f}", flush=True)

            # Quant configs: clone master to GPU, quantize, score, release.
            for label, kb, vb, pb in PPL_CONFIGS:
                q_kv = clone_cache_gpu(cpu_master, device)
                q_kv, _ = nq_quantize_kv(q_kv, kb, vb, pb,
                                         first_call_log=(not diag_logged))
                q_nll, q_tok = score_continuation(model, q_kv, cont_t, device)
                q_kv = None
                free()
                if q_tok == 0:
                    continue
                q_seg_ppl = math.exp(q_nll / q_tok)
                delta = 100.0 * (q_seg_ppl - fp_seg_ppl) / fp_seg_ppl
                quant_deltas[label].append(delta)

            diag_logged = True
            cpu_master = None

            if si % 10 == 0 or si == len(segments) - 1:
                # The lists can be empty (e.g. q_tok == 0 on every segment so
                # far), so format through the helper instead of applying a
                # float format spec to the 'n/a' placeholder.
                print(f" seg {si}: fp_ppl={fp_seg_ppl:.3f} "
                      f"K3={_fmt_last_delta(quant_deltas['K3V2_pb0'])}% "
                      f"K4={_fmt_last_delta(quant_deltas['K4V2_pb0'])}%", flush=True)
            if si % 20 == 0:
                result["ppl"]["per_segment_fp_ppl"] = fp_ppls
                result["ppl"]["k3v2_pb0"] = stats(quant_deltas["K3V2_pb0"])
                result["ppl"]["k4v2_pb0"] = stats(quant_deltas["K4V2_pb0"])
                if fp_tok_total > 0:
                    result["ppl"]["base_nf4weight_ppl"] = math.exp(fp_nll_total / fp_tok_total)
                save()

        except RuntimeError:
            tb = traceback.format_exc()
            print(f" [OOM seg {si}] {oom_msg(tb)[:200]}", flush=True)
            result["errors"][f"ppl_seg{si}"] = oom_msg(tb)
        except Exception:
            tb = traceback.format_exc()
            print(f" [ERR seg {si}] {tb[-200:]}", flush=True)
            result["errors"][f"ppl_seg{si}"] = tb[-400:]
        finally:
            # Drop every tensor reference BEFORE emptying the allocator cache;
            # otherwise a failed segment's device clone would stay alive through
            # the next segment's prefill.
            pre = cont_t = out_pre = gpu_kv = cpu_master = fp_kv = q_kv = None
            free()

    result["ppl"]["per_segment_fp_ppl"] = fp_ppls
    k3 = stats(quant_deltas["K3V2_pb0"])
    k3["per_segment_delta_pct"] = quant_deltas["K3V2_pb0"]
    result["ppl"]["k3v2_pb0"] = k3
    k4 = stats(quant_deltas["K4V2_pb0"])
    k4["per_segment_delta_pct"] = quant_deltas["K4V2_pb0"]
    result["ppl"]["k4v2_pb0"] = k4
    if fp_tok_total > 0:
        result["ppl"]["base_nf4weight_ppl"] = math.exp(fp_nll_total / fp_tok_total)

    k3m = k3.get("mean_delta_pct")
    k3s = k3.get("sem_pct")
    k4m = k4.get("mean_delta_pct")
    k4s = k4.get("sem_pct")
    result["ppl"]["summary"] = (
        f"n_segs={len(fp_ppls)}/{len(segments)} scored; "
        f"base_nf4w_ppl={result['ppl'].get('base_nf4weight_ppl')}. "
        f"K3V2_pb0 mean={k3m if k3m is not None else 'n/a'}% "
        f"+/-{k3s if k3s is not None else 'n/a'}% "
        f"(z={k3.get('z')}, sig2s={k3.get('significant_at_2sigma')}, "
        f"neg={k3.get('n_negative_segments')}/{k3.get('n')}). "
        f"K4V2_pb0 mean={k4m if k4m is not None else 'n/a'}% "
        f"+/-{k4s if k4s is not None else 'n/a'}% "
        f"(z={k4.get('z')}, sig2s={k4.get('significant_at_2sigma')}, "
        f"neg={k4.get('n_negative_segments')}/{k4.get('n')}).")
    # Persist before printing so a failure while printing cannot lose results.
    save()
    print(f"[ppl] {result['ppl']['summary']}", flush=True)
def run_niah_phase(model, tok, device, result, save):
    """Run the needle-in-a-haystack phase at the preferred context, then the fallback.

    For each context length every entry of ``NIAH_CONFIGS`` is evaluated with
    ``run_niah_config`` and stored under ``result["niah"]["<label>_ctx<ctx>"]``.
    The phase stops at the first context where the FP16 baseline scored at
    least one hit and no RuntimeError escaped a configuration; otherwise it
    records why the baseline was rejected and tries the fallback context.

    Args:
        model: causal LM passed through to ``run_niah_config``.
        tok: tokenizer used to build the haystack ids.
        device: device passed through to ``run_niah_config``.
        result: results dict; ``result["niah"]`` and ``result["errors"]`` are
            updated in place.
        save: zero-argument callable that persists ``result``.
    """
    print("[niah] loading haystack ...", flush=True)
    ds_wiki = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train")
    hay_txt = "\n\n".join(r["text"] for r in ds_wiki if r["text"].strip())
    hay_ids = tok(hay_txt, return_tensors="pt", truncation=True,
                  max_length=80000).input_ids[0]

    # Try the preferred ctx first, then the fallback ctx.
    for ctx in (NIAH_CTX_PREF, NIAH_CTX_FALL):
        print(f"\n[niah] ctx={ctx}", flush=True)
        ctx_oom = False
        for label, kb, vb, pb, do_quant in NIAH_CONFIGS:
            key = f"{label}_ctx{ctx}"
            print(f"\n-- NIAH: {label} ctx={ctx} --", flush=True)
            try:
                res = run_niah_config(model, tok, hay_ids, label, kb, vb, pb,
                                      do_quant, ctx, device)
                result["niah"][key] = res
            except RuntimeError:
                msg = oom_msg(traceback.format_exc())
                # Tag as OOM only when the traceback actually says so.
                err_class = "CUDA_OOM" if msg.startswith("CUDA_OOM: ") else "RuntimeError"
                result["niah"][key] = {"error": msg, "err_class": err_class}
                result["errors"][f"niah_{key}"] = msg
                ctx_oom = True
            except Exception:
                tb = traceback.format_exc()
                result["niah"][key] = {"error": tb[-400:]}
                result["errors"][f"niah_{key}"] = tb[-400:]
            save()
            free()

        fp_key = f"FP16_ctx{ctx}"
        fp_res = result["niah"].get(fp_key, {})
        fp_hits = fp_res.get("hits", 0)
        # run_niah_config records per-trial failures as recall=None instead of
        # raising, so errored trials must be told apart from genuine misses.
        fp_cells = fp_res.get("cells") or []
        fp_errored = sum(1 for c in fp_cells if c.get("recall") is None)
        baseline_failed = ("error" in fp_res
                           or (bool(fp_cells) and fp_errored == len(fp_cells)))

        # If FP16 baseline at this context completed and got hits>0, done.
        if fp_hits > 0 and not ctx_oom:
            print(f"[niah] FP16 baseline hits={fp_hits}/{NIAH_TRIALS} at ctx={ctx}; done.",
                  flush=True)
            result["niah"]["baseline_ctx_used"] = ctx
            save()
            break
        if not ctx_oom and fp_hits == 0:
            if baseline_failed:
                note = (f"BASELINE FAILED at ctx={ctx}: FP16 produced no scored trials "
                        f"({fp_errored}/{len(fp_cells)} trials errored). "
                        f"Falling back to ctx={NIAH_CTX_FALL} if not already there.")
            else:
                note = (f"DEGENERATE BASELINE at ctx={ctx}: FP16 hits=0/{NIAH_TRIALS}. "
                        f"Falling back to ctx={NIAH_CTX_FALL} if not already there.")
            print(f" *** {note} ***", flush=True)
            result["niah"][f"baseline_gate_ctx{ctx}"] = note
            save()
        if ctx == NIAH_CTX_FALL:
            print(f"[niah] Exhausted fallback ctx={NIAH_CTX_FALL}.", flush=True)
            break
" + "CPU master keeps KV off GPU between config runs."), + }, + "prefix_len": PREFIX_LEN, + "cont_len": CONT_LEN, + "seg_len": SEG_LEN, + "n_segs_target": N_SEGS_TARGET, + "n_segs_min": N_SEGS_MIN, + "niah_ctx_preferred": NIAH_CTX_PREF, + "niah_ctx_fallback": NIAH_CTX_FALL, + "n_trials": NIAH_TRIALS, + "n_pairs": NIAH_PAIRS, + "ppl": { + "base_nf4weight_ppl": None, "k3v2_pb0": None, "k4v2_pb0": None, + "per_segment_fp_ppl": [], "n_segments": 0, + }, + "niah": {}, + "errors": {}, + "model_source": None, + "configs": [{"label": l, "kb": kb, "vb": vb, "pb": pb, "bpe": honest_bpe(kb, vb)} + for l, kb, vb, pb in PPL_CONFIGS], + "fp16_bpe": 16.0, + "transformers": transformers.__version__, + "gpu": gpu_info, + "run_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + + def save(): + try: + with open(OUT_PATH, "w") as f: + json.dump(result, f, indent=2) + except Exception as e: + print(f"[SAVE FAILED] {e} -- raw result keys: {list(result.keys())}", flush=True) + + save() + + try: + model, tok, src = load_model() + except Exception: + tb = traceback.format_exc() + result["errors"]["model_load"] = oom_msg(tb) + result["blocked"] = ( + "BLOCKED: all Mixtral sources failed to load. 
" + "Unblock: verify unsloth/Mixtral-8x7B-Instruct-v0.1-bnb-4bit is public " + "on HF, or supply an HF_TOKEN for the gated mistralai source.") + save() + print(f"[BLOCKED] {result['blocked']}", flush=True) + return + + result["model_source"] = src + if tok.pad_token is None: + tok.pad_token = tok.eos_token + resolve_config(src) + result["config_meta"] = {k: (sorted(v) if isinstance(v, set) else v) + for k, v in cfg_state.items()} + dmap = getattr(model, "hf_device_map", None) + result["device_map_n_entries"] = len(dmap) if dmap else None + device = next(model.parameters()).device + save() + + print("\n[phase A] AQUA-iso paired PPL (CPU-master)", flush=True) + try: + run_ppl_phase(model, tok, device, result, save) + except Exception: + tb = traceback.format_exc() + result["errors"]["ppl_phase"] = oom_msg(tb) + save() + + print(f"\n[phase B] NIAH ctx={NIAH_CTX_PREF} (fallback {NIAH_CTX_FALL})", flush=True) + try: + run_niah_phase(model, tok, device, result, save) + except Exception: + tb = traceback.format_exc() + result["errors"]["niah_phase"] = oom_msg(tb) + save() + + # SAVE-BEFORE-PRINT: all stats already in result; json.dump above. + print("\n========== FINAL JSON ==========", flush=True) + k3 = result["ppl"].get("k3v2_pb0") or {} + k4 = result["ppl"].get("k4v2_pb0") or {} + printable = {k: v for k, v in result.items() if k not in ("ppl",)} + printable["ppl_summary"] = result["ppl"].get("summary") + printable["ppl_base_nf4w"] = result["ppl"].get("base_nf4weight_ppl") + printable["ppl_n_scored"] = len(result["ppl"].get("per_segment_fp_ppl") or []) + printable["ppl_k3v2_pb0"] = {kk: vv for kk, vv in k3.items() + if kk != "per_segment_delta_pct"} + printable["ppl_k4v2_pb0"] = {kk: vv for kk, vv in k4.items() + if kk != "per_segment_delta_pct"} + print(json.dumps(printable, indent=2), flush=True) + print(f"[done] result -> {OUT_PATH}", flush=True) + + +if __name__ == "__main__": + main()