C6: Mixtral-8x7B manual per-layer GPU streaming (bypass accelerate) #44
Draft
jagmarques wants to merge 15 commits into
Bypasses accelerate entirely: CPU-resident NF4 weights + PyTorch forward hooks on each MixtralDecoderLayer. Pre-hook moves layer to cuda:0, post-hook moves back to CPU. No dispatch wall, no meta-tensor serialization.
Replace plain module.to(DEVICE) hooks (which lose Linear4bit QuantState on round-trip) with a dequant-to-fp16 approach: pre_hook dequantizes every Linear4bit submodule to a temporary fp16 nn.Linear on GPU via bitsandbytes.functional.dequantize_4bit; post_hook restores the original NF4 modules and frees GPU temps. NF4 weights stay CPU-resident throughout. Adds smoke-check gate: assert logits finite + MoE routes >=2 experts before starting the 2-3h PPL+NIAH run. Slug updated to jagmardrop/nq-mixtral-stream-v2.
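The smoke-check gate mentioned above could look roughly like the sketch below. It is not the code from this branch: `smoke_check`, its parameter names, and the assumption that logits arrive as a flat sequence of floats and routed expert indices as a flat sequence of ints are all hypothetical, chosen only to illustrate the two conditions (finite logits, at least 2 distinct experts routed).

```python
import math
from typing import Iterable


def smoke_check(
    logits: Iterable[float],
    routed_expert_ids: Iterable[int],
    min_experts: int = 2,
) -> None:
    """Raise AssertionError unless logits are finite and routing hits enough experts.

    Hypothetical helper: the real gate would read these values from the model's
    output and from the MoE router, which this sketch does not attempt to do.
    """
    values = list(logits)
    assert values, "smoke check: no logits produced"
    assert all(math.isfinite(v) for v in values), "smoke check: non-finite logits"

    used = set(routed_expert_ids)
    assert len(used) >= min_experts, (
        f"smoke check: only {len(used)} distinct expert(s) routed, "
        f"expected >= {min_experts}"
    )


# Usage: run once on a short prompt before starting the long evaluation.
smoke_check(logits=[0.12, -3.4, 7.0], routed_expert_ids=[0, 5, 5, 2])
```

Failing fast here is cheap compared with discovering NaN logits or a collapsed router partway through a multi-hour run.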
… (v3) qs.to(DEVICE) does not deep-move state2.absmax/code or qs.offset when bnb_4bit_use_double_quant=True, leaving top-level absmax=None and hitting dequantize_4bit assert. Add _move_qs_to_device() that explicitly moves every nested tensor field. Also add one-time bnb version + QuantState attribute diagnostic so next iteration has real structure, not guesses. Bump kernel slug to jagmardrop/nq-mixtral-stream-v3.
…rs (v4) v3 DIAGNOSTICS confirmed qs.absmax is None on CPU-resident models; the absmax tensor lives as a separate state_dict buffer (weight.absmax, weight.nested_absmax, weight.quant_map, weight.quant_state.bitsandbytes__nf4, etc.). _move_qs_to_device was a no-op on None fields, leaving dequantize_4bit asserting. v4 fix: _build_qs_from_module() strips the "weight." prefix from the module's state_dict keys and calls QuantState.from_dict(qs_dict, device=DEVICE), which mirrors bnb's own from_prequantized internals. Handles both attention Linear4bit and fused expert batched weights (experts.gate_up_proj/down_proj) identically. Diagnostic now prints state_dict keys + shapes for one attn weight AND one fused expert so next iteration is grounded. v3 deep-move kept as fallback on from_dict exception. Bump slug to jagmardrop/nq-mixtral-stream-v4.
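The prefix-stripping step described in the v4 message can be sketched in plain Python. The helper name `strip_weight_prefix` is hypothetical, the buffer values are placeholders standing in for tensors, and the sketch stops before the `QuantState.from_dict` call, so it shows only the key handling and not the actual `_build_qs_from_module()` from this branch.

```python
from typing import Any, Dict


def strip_weight_prefix(state_dict: Dict[str, Any], prefix: str = "weight.") -> Dict[str, Any]:
    """Keep only the quantization buffers and drop their 'weight.' prefix.

    The bare 'weight' entry (the packed 4-bit data) and unrelated keys such
    as 'bias' do not start with the prefix, so they are left out.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Placeholder values; in the real module these would be tensors.
module_state = {
    "weight": "<packed uint8 data>",
    "weight.absmax": "<absmax>",
    "weight.nested_absmax": "<nested_absmax>",
    "weight.quant_map": "<quant_map>",
    "weight.quant_state.bitsandbytes__nf4": "<packed metadata>",
}

qs_dict = strip_weight_prefix(module_state)
print(sorted(qs_dict))
# ['absmax', 'nested_absmax', 'quant_map', 'quant_state.bitsandbytes__nf4']
```

The resulting dictionary is what would then be handed to `QuantState.from_dict(qs_dict, device=DEVICE)`, as the commit message describes.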
Add torch_dtype=torch.float16 to both from_pretrained calls so hidden_states/activations are fp16 throughout, matching bnb_4bit_compute_dtype. Kernel slug: jagmardrop/nq-mixtral-stream-v5.
transformers>=5.5.3 uses MixtralTopKRouter for the MoE gate. Unlike attention layers (Linear4bit modules), the router holds self.weight as a bare nn.Parameter that bitsandbytes promotes to Params4bit. F.linear(hidden_states, self.weight) then uses the raw uint8 blob, causing: "mat1 and mat2 shapes cannot be multiplied (64x4096 and 1x16384)". Fix: _dequant_layer_to_gpu now also scans all named_modules for any module whose .weight is Params4bit (not already a Linear4bit), rebuilds the QuantState from state_dict buffers, dequantizes to fp16, and swaps the parameter in-place. _restore_layer_from_gpu restores the original Params4bit weight after the forward. All Linear4bit expert linears (w1/w2/w3) are already covered by the existing loop.
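The 16384 in that error message is consistent with 4-bit packing, which stores two weight elements per byte. The sketch below is only that arithmetic; the 8 x 4096 logical gate shape is an assumption (8 experts, hidden size 4096, matching the 64x4096 activation in the error), and `packed_4bit_bytes` is a hypothetical helper, not a bitsandbytes function.

```python
def packed_4bit_bytes(out_features: int, in_features: int) -> int:
    """Bytes needed to store a weight at 4 bits per element (two per byte)."""
    numel = out_features * in_features
    return (numel + 1) // 2  # round up if the element count is odd


# Assumed router gate shape: 8 experts x hidden size 4096.
gate_bytes = packed_4bit_bytes(8, 4096)
print(gate_bytes)  # 16384
```

A packed blob of 16384 bytes no longer has an inner dimension of 4096, so a matmul against a 64x4096 activation cannot line up until the weight is dequantized back to its logical shape.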
… GPU-quota-exhausted window)
…, keep T4x2 machine_shape
Root cause: for the MixtralTopKRouter gate (bare Params4bit, not Linear4bit), QuantState.from_dict fails because the gate's state_dict only has 'weight' with no separate quant buffer entries. The v6 fallback set qs_gpu = w.quant_state.to(DEVICE) but quant_type was None on the CPU-resident model, causing AttributeError at dequantize_4bit(..., quant_type=None.quant_type). Fix: after from_dict fails, use w.quant_state attrs directly (with explicit GPU tensor moves for absmax/offset/code/state2), then patch quant_type=None -> "nf4" (all weights in this NF4 model are nf4). Guard: if qs_gpu or quant_type still None after both paths, print [gate-diag] buffer layout and raise a clear error. Slug: jagmardrop/nq-mixtral-stream-v11 (RUNNING on Kaggle).
mlp.gate is wrapped in Params4bit but w.quant_state is None: bitsandbytes keeps the tiny router gate in full precision, not NF4. v11 fell through to the qs_gpu=None guard and crashed at qs_gpu.quant_type="nf4" (AttributeError on None). v12 short-circuits on quant_state is None: moves gate weight to GPU as plain fp16 nn.Parameter, prints [gate-unquant] diagnostic, validates 2D shape. Genuinely-quantized Params4bit (quant_state not None) still go through dequantize_4bit unchanged.
…it vs float=unquantized gate), not the unreliable quant_state-None check (v14+)
…n place); use qs_pre directly. Enrich guard diagnostics (v16)
The grouped_mm_fallback used by transformers on sm_75 materializes a dense ~14 GiB block even for 64 tokens. v17 sets config._experts_implementation = "batched_mm" after model load, switching to the token-indexed batched path (gate_up_proj[expert_ids] allocates only S*2*intermediate rows). A fallback monkeypatch covers older transformers snapshots without _experts_implementation. A printed [moe-path] marker states which path is active. Push under jooandrgomesmarques/nq-mixtral-stream-v17.
…hed_mm bmm-shape-fails + grouped_mm OOMs on sm_75; looped path is the only correct+bounded one
Closes part of C6: Mixtral-8x7B MoE coverage on Kaggle T4x2 via manual per-layer GPU streaming (no accelerate).
Progress this run: the streaming loader advanced through every prior forward wall. v4 reconstructed the NF4 QuantState from raw module buffers (dequant works), and v5 aligned the load dtype to fp16 (attention q_proj cleared). v5 then failed at the MoE router: the gate weight (block_sparse_moe.gate) is itself a packed NF4 Params4bit, which the per-layer dequant pass missed. The router's F.linear therefore tried to multiply a 64x4096 activation by the 1x16384 packed blob.
v6 (this branch) fixes that: the dequant pass now iterates every bitsandbytes 4-bit module AND every standalone Params4bit weight in the layer, so the gate and all 8 experts dequantize to fp16 before the forward. No name list to drift.
Status: ready to run, but not yet run. All three of our Kaggle accounts have exhausted their weekly GPU quota (a GPU SaveKernel push returns a 500; the same 500 appears on the known-exhausted primary, and non-GPU pushes succeed). The kernel will run the moment any account's weekly quota resets, or on an A100. DRAFT until a clean run produces paired K3V2/K4V2 PPL (n>=60) and NIAH at 4K.