C6: Mixtral-8x7B manual per-layer GPU streaming (bypass accelerate) #44

Draft
jagmarques wants to merge 15 commits into main from company/c6-mixtral-stream

jagmarques commented Jun 16, 2026

Closes part of C6: Mixtral-8x7B MoE coverage on Kaggle T4x2 via manual per-layer GPU streaming (no accelerate).

Progress this run: the streaming loader has now cleared every earlier forward-pass failure. v4 reconstructed the NF4 QuantState from raw module buffers (dequant works), and v5 aligned the load dtype to fp16 (attention q_proj cleared). v5 then failed at the MoE router: the gate weight (block_sparse_moe.gate) is itself a packed NF4 Params4bit that the per-layer dequant pass missed, so the router's F.linear tried to multiply a 64x4096 activation by the raw 1x16384 packed blob.

v6 (this branch) fixes that: the dequant pass now iterates over every bitsandbytes 4-bit module AND every standalone Params4bit weight in the layer, so the gate and all 8 experts dequantize to fp16 before the forward. There is no hard-coded name list that can drift.

Status: ready to run, but not yet run. All three of our Kaggle accounts have exhausted their weekly GPU quota (a GPU SaveKernel push returns a 500; the same 500 appears on the known-exhausted primary, and non-GPU pushes succeed). The kernel will run the moment any account's weekly quota resets, or on an A100. DRAFT until a clean run produces paired K3V2/K4V2 PPL (n>=60) and NIAH at 4K.

Bypasses accelerate entirely: CPU-resident NF4 weights + PyTorch forward
hooks on each MixtralDecoderLayer. Pre-hook moves layer to cuda:0, post-hook
moves back to CPU. No dispatch wall, no meta-tensor serialization.

Replace plain module.to(DEVICE) hooks (which lose Linear4bit QuantState
on round-trip) with a dequant-to-fp16 approach: pre_hook dequantizes
every Linear4bit submodule to a temporary fp16 nn.Linear on GPU via
bitsandbytes.functional.dequantize_4bit; post_hook restores the original
NF4 modules and frees GPU temps. NF4 weights stay CPU-resident throughout.
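
For reference, the plain pre-hook/post-hook streaming pattern described above can be sketched on ordinary nn.Linear layers. This is a minimal illustration, not the kernel's code: attach_streaming_hooks is a hypothetical name, the CPU fallback for DEVICE is an assumption so the sketch runs without a GPU, and it uses the bare module.to() round trip that, per this commit, loses the QuantState on real Linear4bit modules.

```python
import torch
import torch.nn as nn

# Assumption: fall back to CPU so the sketch also runs on a machine with no GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def attach_streaming_hooks(layer: nn.Module):
    """Move `layer` to DEVICE right before its forward and back to CPU after it."""

    def pre_hook(module, args):
        module.to(DEVICE)
        # Returning a tuple replaces the positional inputs of the forward call.
        return tuple(a.to(DEVICE) if torch.is_tensor(a) else a for a in args)

    def post_hook(module, args, output):
        # Weights go back to CPU; the output stays on DEVICE for the next layer.
        module.to("cpu")

    return [
        layer.register_forward_pre_hook(pre_hook),
        layer.register_forward_hook(post_hook),
    ]


# Toy stand-in for a stack of decoder layers.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(3)])
handles = [h for layer in layers for h in attach_streaming_hooks(layer)]

x = torch.randn(2, 16)
with torch.no_grad():
    for layer in layers:
        x = layer(x)
```

Only one layer's weights are resident on DEVICE at any time, which is the point of the approach; calling remove() on each handle detaches the hooks again.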

Adds smoke-check gate: assert logits finite + MoE routes >=2 experts
before starting the 2-3h PPL+NIAH run.
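
A rough sketch of such a smoke-check gate, written against plain Python lists so it runs anywhere. The function name smoke_check, its parameters, and the list-based inputs are assumptions for illustration; the kernel presumably applies the same two checks to tensors.

```python
import math


def smoke_check(logits, routed_expert_ids, min_experts=2):
    """Fail fast unless every logit is finite and the router used enough experts.

    logits: nested list of floats (rows of output logits).
    routed_expert_ids: flat list of expert indices chosen by the router.
    """
    for row in logits:
        for value in row:
            assert math.isfinite(value), f"non-finite logit: {value!r}"
    used = set(routed_expert_ids)
    assert len(used) >= min_experts, f"router used only {len(used)} expert(s)"
    return used


# Hypothetical values: four finite logits, routing spread over experts 0, 3 and 5.
used = smoke_check([[0.1, -2.3], [4.0, 0.0]], [0, 3, 3, 5])
```

Running this before the long evaluation costs a single short forward pass and catches both NaN/inf outputs and a collapsed router.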

Slug updated to jagmardrop/nq-mixtral-stream-v2.
… (v3)

qs.to(DEVICE) does not deep-move state2.absmax/code or qs.offset when
bnb_4bit_use_double_quant=True, leaving the top-level absmax=None and tripping
the assert in dequantize_4bit. Add _move_qs_to_device(), which explicitly moves
every nested tensor field. Also add a one-time bnb version + QuantState
attribute diagnostic so the next iteration works from the real structure, not guesses.
Bump kernel slug to jagmardrop/nq-mixtral-stream-v3.
…rs (v4)

v3 DIAGNOSTICS confirmed qs.absmax is None on CPU-resident models; the absmax
tensor lives as a separate state_dict buffer (weight.absmax, weight.nested_absmax,
weight.quant_map, weight.quant_state.bitsandbytes__nf4, etc.). _move_qs_to_device
was a no-op on None fields, leaving dequantize_4bit asserting.

v4 fix: _build_qs_from_module() strips the "weight." prefix from the module's
state_dict keys and calls QuantState.from_dict(qs_dict, device=DEVICE), which
mirrors bnb's own from_prequantized internals. Handles both attention Linear4bit
and fused expert batched weights (experts.gate_up_proj/down_proj) identically.
Diagnostic now prints state_dict keys + shapes for one attn weight AND one fused
expert so next iteration is grounded. v3 deep-move kept as fallback on from_dict
exception. Bump slug to jagmardrop/nq-mixtral-stream-v4.
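
The prefix-stripping step in the v4 fix can be sketched in plain Python. strip_weight_prefix is a hypothetical helper, and the string values stand in for the real buffer tensors; the key names are the ones listed in the v3 diagnostics above. The resulting dict is what the commit describes passing to QuantState.from_dict(qs_dict, device=DEVICE), which is not called here.

```python
def strip_weight_prefix(state_dict, prefix="weight."):
    """Keep only the quantization buffers and drop the leading 'weight.'.

    The bare 'weight' entry (the packed data itself) does not start with
    'weight.' and is therefore left out of the result.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Placeholder strings stand in for tensors; key names follow the v3 diagnostics.
module_state = {
    "weight": "<packed uint8 data>",
    "weight.absmax": "<absmax>",
    "weight.nested_absmax": "<nested_absmax>",
    "weight.quant_map": "<quant_map>",
    "weight.quant_state.bitsandbytes__nf4": "<packed metadata>",
}
qs_dict = strip_weight_prefix(module_state)
```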
Add torch_dtype=torch.float16 to both from_pretrained calls so
hidden_states/activations are fp16 throughout, matching bnb_4bit_compute_dtype.
Kernel slug: jagmardrop/nq-mixtral-stream-v5.

transformers>=5.5.3 uses MixtralTopKRouter for the MoE gate. Unlike
attention layers (Linear4bit modules), the router holds self.weight as
a bare nn.Parameter that bitsandbytes promotes to Params4bit.
F.linear(hidden_states, self.weight) then uses the raw uint8 blob,
causing: "mat1 and mat2 shapes cannot be multiplied (64x4096 and 1x16384)".
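
The shapes in that error message can be reproduced with a little arithmetic. The sketch below assumes the 4-bit packing stores two values per byte in a single uint8 column of shape (numel/2, 1), which is consistent with the 1x16384 operand in the message; the helper names are hypothetical.

```python
HIDDEN = 4096      # activation width, from the 64x4096 operand in the error
NUM_EXPERTS = 8    # Mixtral-8x7B routes over 8 experts
TOKENS = 64        # rows of hidden_states in the failing call


def packed_nf4_shape(out_features, in_features):
    """Shape of the packed blob, assuming two 4-bit values per uint8 byte."""
    numel = out_features * in_features
    assert numel % 2 == 0, "4-bit packing needs an even element count"
    return (numel // 2, 1)


def linear_matmul_shapes(x_shape, weight_shape):
    """F.linear(x, W) computes x @ W.T, so mat2 is the transposed weight."""
    return x_shape, (weight_shape[1], weight_shape[0])


packed = packed_nf4_shape(NUM_EXPERTS, HIDDEN)           # (16384, 1)
mat1, mat2 = linear_matmul_shapes((TOKENS, HIDDEN), packed)
inner_dims_match = mat1[1] == mat2[0]                     # 4096 vs 1
```

The inner dimensions are 4096 and 1, so the matmul cannot proceed until the gate weight is dequantized back to its logical 8x4096 shape.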

Fix: _dequant_layer_to_gpu now also scans all named_modules for any
module whose .weight is Params4bit (not already a Linear4bit), rebuilds
the QuantState from state_dict buffers, dequantizes to fp16, and swaps
the parameter in-place. _restore_layer_from_gpu restores the original
Params4bit weight after the forward. All Linear4bit expert linears
(w1/w2/w3) are already covered by the existing loop.

Root cause: for the MixtralTopKRouter gate (bare Params4bit, not Linear4bit),
QuantState.from_dict fails because the gate's state_dict only has 'weight' with
no separate quant buffer entries. The v6 fallback set qs_gpu = w.quant_state.to(DEVICE)
but quant_type was None on the CPU-resident model, causing AttributeError at
dequantize_4bit(..., quant_type=None.quant_type).

Fix: after from_dict fails, use w.quant_state attrs directly (with explicit GPU
tensor moves for absmax/offset/code/state2), then patch quant_type=None -> "nf4"
(all weights in this NF4 model are nf4). Guard: if qs_gpu or quant_type still None
after both paths, print [gate-diag] buffer layout and raise a clear error.

Slug: jagmardrop/nq-mixtral-stream-v11 (RUNNING on Kaggle).

mlp.gate is wrapped in Params4bit but w.quant_state is None: bitsandbytes
keeps the tiny router gate in full precision, not NF4. v11 fell through to
the qs_gpu=None guard and crashed at qs_gpu.quant_type="nf4" (AttributeError
on None). v12 short-circuits on quant_state is None: moves gate weight to GPU
as plain fp16 nn.Parameter, prints [gate-unquant] diagnostic, validates 2D
shape. Genuinely-quantized Params4bit (quant_state not None) still go through
dequantize_4bit unchanged.

…it vs float=unquantized gate), not the unreliable quant_state-None check (v14+)
…n place); use qs_pre directly. Enrich guard diagnostics (v16)
The grouped_mm_fallback used by transformers on sm_75 materializes a dense
~14 GiB block even for 64 tokens. v17 sets
config._experts_implementation = "batched_mm" after model load, switching
to the token-indexed batched path (gate_up_proj[expert_ids] allocates only
S*2*intermediate rows). A fallback monkeypatch covers older transformers
snapshots without _experts_implementation. A printed [moe-path] marker states
which path is active. Push under jooandrgomesmarques/nq-mixtral-stream-v17.
…hed_mm bmm-shape-fails + grouped_mm OOMs on sm_75; looped path is the only correct+bounded one