From fcb6be780c61d4f76727217c25610d6f1fd917c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Andr=C3=A9=20Gomes=20Marques?= Date: Mon, 15 Jun 2026 18:07:39 +0200 Subject: [PATCH 1/2] Add Mistral-7B diverse-text NIAH rescue kernel (n=40, ctx=4K, no chat template) Discriminative operating-point harness: completion prompt on wikitext haystack creates partial FP16 failure band. Power-tests K2V2/K3V2 pb=0 rescue at n=40. --- .../kernel-metadata.json | 20 + .../nq_mistral_diverse_rescue.py | 466 ++++++++++++++++++ 2 files changed, 486 insertions(+) create mode 100644 experiments/kaggle/nq_mistral_diverse_rescue/kernel-metadata.json create mode 100644 experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py diff --git a/experiments/kaggle/nq_mistral_diverse_rescue/kernel-metadata.json b/experiments/kaggle/nq_mistral_diverse_rescue/kernel-metadata.json new file mode 100644 index 0000000..9931737 --- /dev/null +++ b/experiments/kaggle/nq_mistral_diverse_rescue/kernel-metadata.json @@ -0,0 +1,20 @@ +{ + "id": "jagmarques/nq-mistral-diverse-rescue", + "title": "nq-mistral-diverse-rescue", + "code_file": "nq_mistral_diverse_rescue.py", + "language": "python", + "kernel_type": "script", + "is_private": true, + "enable_gpu": true, + "enable_tpu": false, + "enable_internet": true, + "keywords": [ + "gpu" + ], + "dataset_sources": [], + "kernel_sources": [], + "competition_sources": [], + "model_sources": [], + "docker_image": "gcr.io/kaggle-private-byod/python@sha256:00377cd1b3d470a605bc5b0ceca79969e369644e9b36802242a1c70e627372f9", + "machine_shape": "NvidiaTeslaT4" +} diff --git a/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py b/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py new file mode 100644 index 0000000..56fcf8a --- /dev/null +++ b/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py @@ -0,0 +1,466 @@ +# NexusQuant NIAH rescue test: Mistral-7B-Instruct-v0.3, Kaggle T4x2 
(sm_75, fp16). +# Configs: FP16, K2V2_pb0, K3V2_pb0 at ctx=4096, n=40 trials. +# Mistral-7B-Instruct-v0.3: 8 KV heads, head_dim=128, rope_theta=1e6, full rotary. +# DIVERSE-TEXT COMPLETION harness (no chat template): creates partial FP16 band at 4K. +# Paper: FP16 1/5, K2V2 pb=0 3/5 at 4K diverse-text (documented beneficial rescue). +# Primary model: NousResearch/Mistral-7B-Instruct-v0.3 (ungated mirror; no token needed). +# T4x2: GPU0 capped at 11GiB (embed+head+vocab overhead), GPU1 at 14GiB. + +import sys, os, gc, json, time, re, random, traceback, subprocess + +os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True") + +print("Installing transformers/accelerate + nexusquant ...", flush=True) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "numpy<2"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "-U", + "transformers>=5.5.3", "accelerate>=1.1.1"]) +subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", + "git+https://github.com/jagmarques/nexusquant.git@main"]) + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache +from datasets import load_dataset + +from nexusquant.core.e8_lattice import E8Lattice +from nexusquant.core.hadamard import hadamard_matrix +from nexusquant.core.rope_utils import inverse_rope, forward_rope + +SEED = 42 +random.seed(SEED) +torch.manual_seed(SEED) + +# Primary: NousResearch ungated mirror. Fallback: gated official (needs HF_TOKEN). 
+MODEL_PRIMARY = "NousResearch/Mistral-7B-Instruct-v0.3" +MODEL_FALLBACK = "mistralai/Mistral-7B-Instruct-v0.3" + +CONTEXT = 4096 # 4K: partial FP16 NIAH band on diverse-text (not saturated) +N_NEEDLES = 40 # 40 trials for McNemar statistical power +MAX_NEW_TOKENS = 16 # 3-digit value fits in <16 tokens +MAX_MEMORY = {0: "11GiB", 1: "14GiB"} +OUT_PATH = "/kaggle/working/nq_mistral_diverse_rescue.json" + +# Configs: (label, k_bits, v_bits, protect_boundary_layers, do_quant) +CONFIGS = [ + ("FP16", 0, 0, 0, False), + ("K2V2_pb0", 2, 2, 0, True), + ("K3V2_pb0", 3, 2, 0, True), +] + +# Rope config for Mistral-7B-Instruct-v0.3: rope_theta=1e6, full rotary. +ROPE_THETA = 1_000_000.0 +ROPE_PRF = None # full rotary (not partial) + +result = { + "model": MODEL_PRIMARY, + "context": CONTEXT, + "n_needles": N_NEEDLES, + "rope_theta": ROPE_THETA, + "rope_partial_rotary_factor": ROPE_PRF, + "harness": "diverse-text completion (NO chat template); partial FP16 band target", + "run_at_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + "niah_results": {}, + "errors": {}, +} + + +def save(): + with open(OUT_PATH, "w") as f: + json.dump(result, f, indent=2) + + +def free(): + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def to_dyn(c): + return DynamicCache.from_legacy_cache(c) if isinstance(c, tuple) else c + + +def get_kv(c, l): + if hasattr(c, "key_cache"): + return c.key_cache[l], c.value_cache[l] + return c.layers[l].keys, c.layers[l].values + + +def set_kv(c, l, k, v): + if hasattr(c, "key_cache"): + c.key_cache[l] = k + c.value_cache[l] = v + else: + c.layers[l].keys = k + c.layers[l].values = v + + +def n_layers_kv(c): + return len(c.layers) if hasattr(c, "layers") else len(c.key_cache) + + +def cache_seq_len(c): + # Returns max sequence length across cache layers. 
+ lens = [] + for l in range(n_layers_kv(c)): + kl, _ = get_kv(c, l) + if kl is not None: + lens.append(kl.shape[2]) + return max(lens) if lens else 0 + + +def quantize_kv(kv_cache, k_bits, v_bits): + """Apply Hadamard + E8 VQ to all layers; no boundary protection (pb=0).""" + Hcache = {} + + def get_H(d): + p2 = 1 + while p2 < d: + p2 *= 2 + if p2 not in Hcache: + Hcache[p2] = hadamard_matrix(p2).cpu() + return Hcache[p2], p2 + + nl = n_layers_kv(kv_cache) + quant_count = 0 + for l in range(nl): + kl, vl = get_kv(kv_cache, l) + if kl is None: + continue + d_l = kl.shape[-1] # head_dim + nk = kl.shape[1] # n_kv_heads + Hm, p2 = get_H(d_l) + h_pad = p2 - d_l + + for is_k, t_in, b_bits in [(True, kl, k_bits), (False, vl, v_bits)]: + t = t_in[0].float().cpu().clone() # (n_kv_heads, seq, head_dim) + if is_k: + # Remove RoPE (standard split-half, full rotary, theta=1e6). + t = inverse_rope(t, seq_offset=0, base=ROPE_THETA, + partial_rotary_factor=ROPE_PRF) + lvl = 2 ** b_bits + for h in range(nk): + th = t[h] # (seq, head_dim) + pad_in = F.pad(th, (0, h_pad)) if h_pad > 0 else th + rot = pad_in @ Hm.T # Hadamard rotation + amax = rot.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) + sc = amax / (lvl / 2) + norm = rot / sc + # Pad to nearest multiple of 8 for E8 grouping. 
+ e8p = (8 - p2 % 8) % 8 + if e8p > 0: + norm = F.pad(norm, (0, e8p)) + grp = norm.reshape(-1, 8) + qnt = E8Lattice.nearest_point(grp).clamp(-lvl / 2, lvl / 2) + co = qnt.reshape(-1, norm.shape[-1]) + if e8p > 0: + co = co[..., :p2] + t[h] = (co * sc @ Hm)[..., :d_l] + if is_k: + t = forward_rope(t, seq_offset=0, base=ROPE_THETA, + partial_rotary_factor=ROPE_PRF) + tdev = t_in.device + tdt = t_in.dtype + new_t = t.unsqueeze(0).to(dtype=tdt, device=tdev) + if is_k: + set_kv(kv_cache, l, new_t, vl) + else: + kn, _ = get_kv(kv_cache, l) + set_kv(kv_cache, l, kn, new_t) + quant_count += 1 + + return kv_cache, quant_count + + +def run_niah(model, tok, label, k_bits, v_bits, do_quant, hay_ids, device): + # label-independent rng: all configs see identical needles (true pairing). + rng = random.Random(SEED) + cells = [] + recall_arr = [] + t0 = time.time() + + for trial in range(N_NEEDLES): + kv = None + try: + # 7-digit keys, 3-digit values -- both rare in wikitext prose. + keys = [str(rng.randint(1_000_000, 9_999_999)) for _ in range(N_NEEDLES)] + seen = set() + values = [] + while len(values) < N_NEEDLES: + v = str(rng.randint(100, 999)) + if v not in seen: + seen.add(v) + values.append(v) + + target_idx = rng.randint(0, N_NEEDLES - 1) + target_key = keys[target_idx] + target_val = values[target_idx] + + # COMPLETION CUE -- no chat template; forces model into completion mode. + # "\n\nThe special pass key for {key} is" continues the haystack text. + cue_str = f"\n\nThe special pass key for {target_key} is" + cue_ids = tok(cue_str, return_tensors="pt", + add_special_tokens=False).input_ids[0] + + # NEEDLE sentences embedded in the haystack. + # Format: "The special pass key for {key} is {value}." + needle_strs = [ + f"The special pass key for {k} is {v}." 
for k, v in zip(keys, values) + ] + needle_ids_list = [ + tok(s, return_tensors="pt", + add_special_tokens=False).input_ids[0] + for s in needle_strs + ] + + needle_total = sum(p.shape[0] for p in needle_ids_list) + hay_budget = CONTEXT - needle_total - cue_ids.shape[0] - 1 # -1 for BOS + if hay_budget <= 0: + cells.append({"trial": trial, + "error": "context_too_small", "recall": False}) + recall_arr.append(None) + continue + + hay = hay_ids[:hay_budget] + # Spread ALL N_NEEDLES needles at evenly-spaced fractional depths. + depths = [(i + 1) / (N_NEEDLES + 1) for i in range(N_NEEDLES)] + positions = sorted( + [(int(d * hay.shape[0]), i) for i, d in enumerate(depths)]) + chunks = [] + last = 0 + for pos, pi in positions: + chunks.append(hay[last:pos]) + chunks.append(needle_ids_list[pi]) + last = pos + chunks.append(hay[last:]) + + # Prepend BOS; append completion cue (no [INST] wrapping). + bos_t = torch.tensor([tok.bos_token_id], dtype=torch.long) + prefix = torch.cat([bos_t] + chunks + [cue_ids]).unsqueeze(0).to(device) + + with torch.no_grad(): + out1 = model(prefix, use_cache=True) + kv = to_dyn(out1.past_key_values) + del out1 + free() + + if do_quant: + kv, qc = quantize_kv(kv, k_bits, v_bits) + if trial == 0: + print(f" [{label}] quantized {qc} layers " + f"k_bits={k_bits} v_bits={v_bits}", flush=True) + + past_len = cache_seq_len(kv) + # Generate from an empty continuation (model already saw the cue in prefix). 
+ gen_input = torch.zeros(1, 0, dtype=torch.long, device=device) + am = torch.ones(1, past_len, dtype=torch.long, device=device) + with torch.no_grad(): + gen = model.generate( + gen_input, past_key_values=kv, attention_mask=am, + max_new_tokens=MAX_NEW_TOKENS, min_new_tokens=1, + do_sample=False, + pad_token_id=(tok.pad_token_id or tok.eos_token_id), + use_cache=True, + ) + gen_ids = gen[0] + ans = tok.decode(gen_ids, skip_special_tokens=True).strip() + al = ans.lower() + # Match 3-digit value as whole word; guard against substring match inside 7-digit key. + m = re.search(r"\b" + re.escape(target_val) + r"\b", al) + recall = m is not None + del gen + except Exception as e: + traceback.print_exc() + recall = None + ans = f"{type(e).__name__}: {str(e)[:120]}" + + tag = "YES" if recall is True else ("NO" if recall is False else "ERR") + print(f" [{label}] t{trial} key={target_key} val={target_val} " + f"{tag} ans={ans[:60]!r}", flush=True) + cells.append({ + "trial": trial, + "target_key": target_key, + "target_value": target_val, + "answer": ans[:160], + "recall": recall, + }) + recall_arr.append(recall) + if kv is not None: + del kv + free() + + hits = sum(1 for c in cells if c.get("recall") is True) + elapsed = time.time() - t0 + print(f" [{label}] {hits}/{N_NEEDLES} elapsed={elapsed:.0f}s", flush=True) + return { + "config": label, + "ctx": CONTEXT, + "hits": hits, + "n": N_NEEDLES, + "elapsed_s": int(elapsed), + "recall_array": recall_arr, + "cells": cells, + } + + +def compute_paired(fp16_arr, quant_arr): + """McNemar b/c counts: b=fp16-hit/quant-miss, c=fp16-miss/quant-RESCUE.""" + b, c = 0, 0 + for f, q in zip(fp16_arr, quant_arr): + if f is True and q is False: + b += 1 + elif f is False and q is True: + c += 1 + return b, c + + +def main(): + import transformers + gpu_info = None + if torch.cuda.is_available(): + p = torch.cuda.get_device_properties(0) + gpu_info = { + "name": p.name, + "sm": f"{p.major}{p.minor}", + "n_gpus": torch.cuda.device_count(), + 
"total_gb": round(p.total_memory / 1024**3, 1), + } + print(f"[start] GPU={gpu_info} transformers={transformers.__version__} " + f"alloc_conf={os.environ.get('PYTORCH_CUDA_ALLOC_CONF')}", flush=True) + + result["gpu"] = gpu_info + result["transformers"] = transformers.__version__ + + # Try NousResearch ungated mirror first; fall back to gated official. + model_id = MODEL_PRIMARY + hf_token = None + try: + tok = AutoTokenizer.from_pretrained(model_id) + print(f"[model] using primary (ungated): {model_id}", flush=True) + except Exception as e_primary: + print(f"[model] primary failed ({e_primary}); trying fallback ...", flush=True) + model_id = MODEL_FALLBACK + try: + from kaggle_secrets import UserSecretsClient + hf_token = UserSecretsClient().get_secret("HF_TOKEN") + except Exception: + hf_token = os.environ.get("HF_TOKEN") + if not hf_token: + raise SystemExit("primary and fallback both unreachable; set HF_TOKEN secret") + os.environ["HF_TOKEN"] = hf_token + tok = AutoTokenizer.from_pretrained(model_id, token=hf_token) + print(f"[model] using fallback (gated): {model_id}", flush=True) + + result["model"] = model_id + if tok.pad_token is None: + tok.pad_token = tok.eos_token + + t0 = time.time() + model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.float16, + device_map="auto", + max_memory=MAX_MEMORY, + attn_implementation="eager", + low_cpu_mem_usage=True, + **({"token": hf_token} if hf_token else {}), + ) + model.eval() + print(f" loaded in {time.time()-t0:.1f}s " + f"mem0={torch.cuda.memory_allocated(0)/1e9:.2f}GB " + f"mem1={torch.cuda.memory_allocated(1)/1e9:.2f}GB", flush=True) + dmap = getattr(model, "hf_device_map", None) + result["device_map"] = ( + {str(k): str(v) for k, v in dmap.items()} if dmap else None) + device = next(model.parameters()).device + + # Verify config: expect 8 KV heads, head_dim=128. 
+ cfg = getattr(model.config, "text_config", model.config) + n_kv = getattr(cfg, "num_key_value_heads", None) + hd = getattr(cfg, "head_dim", None) + theta = getattr(cfg, "rope_theta", None) + n_lay = getattr(cfg, "num_hidden_layers", None) + print(f"[config] n_kv_heads={n_kv} head_dim={hd} " + f"rope_theta={theta} n_layers={n_lay}", flush=True) + result["model_config"] = { + "n_kv_heads": n_kv, "head_dim": hd, + "rope_theta": theta, "n_layers": n_lay, + } + + print("[data] loading wikitext-2 haystack ...", flush=True) + ds = load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train") + hay_text = "\n\n".join(r["text"] for r in ds if r["text"].strip()) + hay_ids = tok(hay_text, return_tensors="pt", + truncation=True, max_length=80_000).input_ids[0] + print(f" haystack tokenized: {hay_ids.shape[0]} tokens", flush=True) + save() + + print("\n[phase] NIAH sweep: FP16 -> K2V2_pb0 -> K3V2_pb0", flush=True) + fp16_recall = None + for label, k_bits, v_bits, _pb, do_quant in CONFIGS: + print(f"\n--- {label} ctx={CONTEXT} ---", flush=True) + try: + res = run_niah(model, tok, label, k_bits, v_bits, do_quant, + hay_ids, device) + result["niah_results"][label] = res + if label == "FP16": + fp16_recall = res["recall_array"] + # Flag non-discriminative operating points immediately. + if res["hits"] == 0: + result["errors"]["non_discriminative"] = ( + f"FP16={res['hits']}/40: DEGENERATE baseline; quant rescue credits invalid") + print(" *** DEGENERATE: FP16=0/40. Completion cue may need adjustment. ***", + flush=True) + elif res["hits"] == N_NEEDLES: + result["errors"]["non_discriminative"] = ( + f"FP16={res['hits']}/40: SATURATED baseline; no rescue possible") + print(" *** SATURATED: FP16=40/40. Chat-template leak or wrong harness. 
***", + flush=True) + else: + print(f" *** DISCRIMINATIVE: FP16={res['hits']}/40 -- " + f"partial-failure band CONFIRMED ***", flush=True) + except Exception: + traceback.print_exc() + tb = traceback.format_exc() + result["niah_results"][label] = {"error": tb[-400:]} + result["errors"][f"niah_{label}"] = tb[-400:] + save() + free() + + # Paired McNemar b/c counts for each quant config vs FP16. + if fp16_recall is not None: + paired = {} + for label, k_bits, v_bits, _pb, do_quant in CONFIGS: + if not do_quant: + continue + quant_res = result["niah_results"].get(label, {}) + q_arr = quant_res.get("recall_array", []) + if q_arr: + b, c = compute_paired(fp16_recall, q_arr) + paired[label] = { + "b_fp16hit_quantmiss": b, + "c_fp16miss_quantrescue": c, + } + print(f" paired {label}: b(FP16-hit/quant-miss)={b} " + f"c(FP16-miss/quant-RESCUE)={c}", flush=True) + result["paired_mcnemar"] = paired + + # Summary. + print("\n========== NIAH SUMMARY ==========", flush=True) + for label, *_ in CONFIGS: + r = result["niah_results"].get(label, {}) + hits = r.get("hits", "?") + n = r.get("n", N_NEEDLES) + err = r.get("error", "") + if err: + print(f" {label}: ERROR {err[:80]}", flush=True) + else: + print(f" {label}: {hits}/{n}", flush=True) + + result["run_complete_utc"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()) + save() + print(f"\n[done] -> {OUT_PATH}", flush=True) + print(json.dumps(result, indent=2), flush=True) + + +if __name__ == "__main__": + main() From 9894e98555eed3159b54f159cb0849e605102fd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Andr=C3=A9=20Gomes=20Marques?= Date: Mon, 15 Jun 2026 20:30:58 +0200 Subject: [PATCH 2/2] Fix model download: switch to ungated unsloth/mistral-7b-instruct-v0.3 primary NousResearch/Mistral-7B-Instruct-v0.3 does not exist; primary is now unsloth/mistral-7b-instruct-v0.3 (ungated, files confirmed via HF API). Fallback is mistralai/Mistral-7B-Instruct-v0.3 (also ungated). 
HF_TOKEN is now purely optional (used if present; neither model requires it). --- .../nq_mistral_diverse_rescue.py | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py b/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py index 56fcf8a..2252f6a 100644 --- a/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py +++ b/experiments/kaggle/nq_mistral_diverse_rescue/nq_mistral_diverse_rescue.py @@ -3,7 +3,8 @@ # Mistral-7B-Instruct-v0.3: 8 KV heads, head_dim=128, rope_theta=1e6, full rotary. # DIVERSE-TEXT COMPLETION harness (no chat template): creates partial FP16 band at 4K. # Paper: FP16 1/5, K2V2 pb=0 3/5 at 4K diverse-text (documented beneficial rescue). -# Primary model: NousResearch/Mistral-7B-Instruct-v0.3 (ungated mirror; no token needed). +# Primary model: unsloth/mistral-7b-instruct-v0.3 (ungated, no HF_TOKEN needed). +# Fallback: mistralai/Mistral-7B-Instruct-v0.3 (also ungated; no token needed). # T4x2: GPU0 capped at 11GiB (embed+head+vocab overhead), GPU1 at 14GiB. import sys, os, gc, json, time, re, random, traceback, subprocess @@ -30,8 +31,9 @@ random.seed(SEED) torch.manual_seed(SEED) -# Primary: NousResearch ungated mirror. Fallback: gated official (needs HF_TOKEN). -MODEL_PRIMARY = "NousResearch/Mistral-7B-Instruct-v0.3" +# Primary: unsloth ungated mirror (gated=False, config.json+safetensors confirmed via HF API). +# Fallback: official repo (also ungated, gated=False confirmed via HF API). No HF_TOKEN needed. +MODEL_PRIMARY = "unsloth/mistral-7b-instruct-v0.3" MODEL_FALLBACK = "mistralai/Mistral-7B-Instruct-v0.3" CONTEXT = 4096 # 4K: partial FP16 NIAH band on diverse-text (not saturated) @@ -330,25 +332,25 @@ def main(): result["gpu"] = gpu_info result["transformers"] = transformers.__version__ - # Try NousResearch ungated mirror first; fall back to gated official. 
- model_id = MODEL_PRIMARY + # Both PRIMARY and FALLBACK are ungated; no HF_TOKEN needed for either. + # Use optional HF_TOKEN from env/secrets only if present (speeds up rate-limited downloads). hf_token = None try: - tok = AutoTokenizer.from_pretrained(model_id) - print(f"[model] using primary (ungated): {model_id}", flush=True) + from kaggle_secrets import UserSecretsClient + hf_token = UserSecretsClient().get_secret("HF_TOKEN") + except Exception: + hf_token = os.environ.get("HF_TOKEN") + token_kwargs = {"token": hf_token} if hf_token else {} + + model_id = MODEL_PRIMARY + try: + tok = AutoTokenizer.from_pretrained(model_id, **token_kwargs) + print(f"[model] using primary: {model_id}", flush=True) except Exception as e_primary: print(f"[model] primary failed ({e_primary}); trying fallback ...", flush=True) model_id = MODEL_FALLBACK - try: - from kaggle_secrets import UserSecretsClient - hf_token = UserSecretsClient().get_secret("HF_TOKEN") - except Exception: - hf_token = os.environ.get("HF_TOKEN") - if not hf_token: - raise SystemExit("primary and fallback both unreachable; set HF_TOKEN secret") - os.environ["HF_TOKEN"] = hf_token - tok = AutoTokenizer.from_pretrained(model_id, token=hf_token) - print(f"[model] using fallback (gated): {model_id}", flush=True) + tok = AutoTokenizer.from_pretrained(model_id, **token_kwargs) + print(f"[model] using fallback: {model_id}", flush=True) result["model"] = model_id if tok.pad_token is None: @@ -362,7 +364,7 @@ def main(): max_memory=MAX_MEMORY, attn_implementation="eager", low_cpu_mem_usage=True, - **({"token": hf_token} if hf_token else {}), + **token_kwargs, ) model.eval() print(f" loaded in {time.time()-t0:.1f}s "