Add Qwen3.5-35B-A3B contrib model #60
Open
jimburtoft wants to merge 18 commits into aws-neuron:main from
Conversation
Implements Qwen3.5-35B-A3B (35B total parameters, 3B active per token) with:
- Custom NKI v2 DeltaNet kernels for the 30 linear attention layers
- Standard GQA with output gate and partial RoPE for the 10 attention layers
- Sparse MoE (256 experts, top-8 routing plus a sigmoid-gated shared expert; see the sketch below)
- CTE-to-TKG state carry-over via input_output_aliases

Validated on trn2.3xlarge with SDK 2.28: CTE produces 'Paris' (17.88), and TKG generates coherent multi-token output across all test prompts.
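For readers unfamiliar with the sigmoid-gated shared expert, here is a minimal PyTorch sketch of the idea: the shared expert's output is scaled by a sigmoid gate and added to the top-8 routed expert output. The class and layer names are illustrative only and are not the actual modules in `modeling_qwen35_moe.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SigmoidGatedSharedExpert(nn.Module):
    """Illustrative sketch: a shared expert whose contribution is scaled by a
    sigmoid gate, added on top of the routed (top-8) expert output."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        # Per-token scalar sigmoid gate on the shared-expert branch (assumed shape).
        self.shared_gate = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden: torch.Tensor, routed_out: torch.Tensor) -> torch.Tensor:
        shared = self.down_proj(F.silu(self.gate_proj(hidden)) * self.up_proj(hidden))
        # Sigmoid-gated addition rather than NxDI's default additive shared expert.
        return routed_out + torch.sigmoid(self.shared_gate(hidden)) * shared
```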
…EADME
- Neuron vs GPU benchmark results (trn2.3xlarge vs g6.12xlarge, 4x L4)
- 5.5x TKG throughput advantage (54.9 vs 10.0 tok/s at BS=1)
- Long context (seq_len >= 1024) config with attn_kernel_enabled=False
- Document the NKI flash attention head_dim > 128 limitation and workaround
- Document compilation time scaling (O(seq_len) due to DeltaNet unrolling)
- Add token match vs CPU test result (100%)
… head_dim>128
- Add nki_flash_attn_d256.py: a custom NKI kernel that tiles the QK contraction in 2x128 chunks, supporting head_dim=256 (the NxDI kernel asserts d <= 128); the tiling idea is illustrated in plain PyTorch below
- Add a perform_prefill() override to NeuronQwen35Attention that auto-disables NKI flash attention for head_dim > 128, eliminating the need for a manual attn_kernel_enabled=False in NeuronConfig
- The custom kernel is opt-in via QWEN35_USE_FLASH_ATTN_D256=1 (benchmarks show a 2.4x TTFT regression due to BHSD->BHDS layout conversion overhead)
- Update README: remove the attn_kernel_enabled requirement, document kernel findings
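The d-tiling idea referenced above can be shown in plain PyTorch (this is only a functional illustration of the math, not the NKI kernel itself): split the head dimension into chunks of at most 128 and accumulate the partial QK^T products.

```python
import torch

def qk_scores_d_tiled(q: torch.Tensor, k: torch.Tensor, tile: int = 128) -> torch.Tensor:
    """Illustrative PyTorch version of d-tiling: split the head dimension into
    <=128-wide chunks and accumulate partial QK^T products.
    q, k: [batch*heads, seqlen, head_dim] with head_dim possibly > 128."""
    head_dim = q.shape[-1]
    scores = torch.zeros(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    for start in range(0, head_dim, tile):
        end = min(start + tile, head_dim)
        # Partial contraction over one <=128-wide slice of the head dimension.
        scores += q[..., start:end] @ k[..., start:end].transpose(-1, -2)
    return scores / head_dim ** 0.5
```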
Adds nkilib_kernel_patch.py, which replaces NxDI's bundled flash attention kernel (_pre_prod_kernels) with our modified nkilib kernel supporting head_dim up to 256 via d-tiling. Requires installing the nki-library fork (feature/head-dim-256 branch) as a standalone override.

Enable with: QWEN35_PATCH_FLASH_ATTN=1
Prerequisite: pip install git+https://github.com/jimburtoft/nki-library.git@feature/head-dim-256

The adapter translates NxDI parameter names (do_out_tp, kernel_name, use_dma_transpose) to the nkilib convention (tp_out, causal_mask) and applies the same decorator stack (peel, re-jit torchxla, skip_middle_end, enable_stack_allocator).

perform_prefill updated with three priority paths (sketched below):
1. Patched nkilib kernel (zero layout overhead) - QWEN35_PATCH_FLASH_ATTN=1
2. External NKI kernel (has TTFT regression) - QWEN35_USE_FLASH_ATTN_D256=1
3. PyTorch softmax fallback (default for head_dim > 128)
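A minimal sketch of the three-path dispatch, assuming the two kernel entry points are passed in as callables (the real perform_prefill lives in modeling_qwen35_moe.py and calls the NKI kernels with their own signatures; only the environment variable names are taken from the PR):

```python
import os
from typing import Callable, Optional
import torch

def select_prefill_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    patched_nkilib_kernel: Optional[Callable] = None,
    external_d256_kernel: Optional[Callable] = None,
) -> torch.Tensor:
    """Illustrative dispatch only -- not the actual perform_prefill override."""
    if os.environ.get("QWEN35_PATCH_FLASH_ATTN") == "1" and patched_nkilib_kernel is not None:
        # Path 1: patched nkilib flash-attention kernel (zero layout overhead).
        return patched_nkilib_kernel(q, k, v)
    if os.environ.get("QWEN35_USE_FLASH_ATTN_D256") == "1" and external_d256_kernel is not None:
        # Path 2: external d-tiled NKI kernel (known TTFT regression from layout conversion).
        return external_d256_kernel(q, k, v)
    # Path 3: PyTorch softmax attention fallback, the default for head_dim > 128.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```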
The previous approach tried to monkey-patch _flash_fwd_call_nki, which passed do_out_tp=True (translated to tp_out=True). However, tp_out=True is not supported for d > 128 because the output write-back path would need d on the par_dim.

New approach: call the nkilib kernel directly from perform_prefill with tp_out=False. The kernel returns (B*H, seqlen, d), which we reshape to BHSD (B, H, S, D) and return with FlashAttentionStrategy.NONE. There is no layout conversion overhead: Q and K are prepared in tp_q=True format (B*H, seqlen, d), which matches what the nkilib kernel expects with GQA.
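The reshape described above is only a view change; a sketch (the kernel call itself is omitted, and the function name is hypothetical):

```python
import torch

def reshape_kernel_output_to_bhsd(kernel_out: torch.Tensor, batch: int, num_heads: int) -> torch.Tensor:
    """The nkilib kernel called with tp_out=False returns attention output shaped
    (B*H, seqlen, d); NxDI expects BHSD, i.e. (B, H, S, D). No layout conversion
    pass is needed -- this is purely a view over the same memory."""
    bh, seqlen, head_dim = kernel_out.shape
    assert bh == batch * num_heads
    return kernel_out.view(batch, num_heads, seqlen, head_dim)
```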
…ssue
Calling the nkilib kernel with grid=(nc(lnc),) causes num_shard=None from nl.num_programs(). For now, call without a grid (single-core mode) to test correctness first. Performance optimization with an SPMD grid can follow.
The num_shard=None issue is fixed in nkilib kernel_helpers.py (commit 91e06cf in the nki-library fork). Re-enable the full NxDI decorator stack, which is needed for CTE kernel optimizations.
The previous approach (peel @nki.jit -> re-jit with mode='torchxla' + skip_middle_end_transformations + enable_stack_allocator) caused tracing issues: the NKI compiler's inline_function mechanism doesn't properly propagate trace context (builtin, address= support) to sub-functions. nki_jit() wraps the already-traced GenericKernel directly for XLA execution, preserving the original trace context intact. This is simpler and avoids all the sub-function tracing issues.
The nkilib flash attention kernel for head_dim > 128 works in standalone mode (all tests pass) but cannot be integrated with NxDI due to an NKI compiler bug: TraceKernel.inline_function doesn't propagate trace context (builtin injection) to sub-functions in torchxla mode.
- nkilib_kernel_patch.py: disabled with an early return, ready to enable when the compiler bug is fixed
- README: documented the compiler bug and the nki-library fork work
- issue_nki_inline_function_builtin.md: full bug report filed
- Add is_speculative_decoding flag to the MoE forward call, aligning with upstream NxDI 2.28 Qwen3 MoE changes
- Add vision encoder (modeling_qwen35_moe_vision.py): Qwen3.5 ViT with a 27-layer transformer, patch merger, rotary embeddings, block-diagonal attention mask, and sequence length bucketing
- Add VL orchestrator (modeling_qwen35_moe_vl.py): mRoPE position ID computation, vision+text input preparation, generate loop with rope_deltas for the TKG decode phase
- Add image input integration tests: tokenization, mRoPE correctness, vision encoder shapes, end-to-end VL pipeline
- Update __init__.py exports and README with vision-language docs
Vision encoder (27-layer ViT, 442M params) compiles via torch_neuronx.trace() in 162.7s and runs inference in 8.3ms with cosine similarity 0.999 vs CPU.

Changes:
- compile_vision_encoder.py: standalone compilation script for the ViT (sketched below)
- Fix rotary embeddings to use rotate_half (matching the HF reference)
- Fix weight loading prefix (model.visual.*, not visual.*)
- Fix HF import chain for Qwen3VL classes in transformers 4.57+
- Add load_compiled() and load_vision_weights_from_hf() to the wrapper
- Add standalone mode (model_cls=None) for pre-compiled model loading
- Update the VL orchestrator compile/load to support a pre-compiled vision encoder

Tested on trn2.3xlarge: 3/3 E2E tests pass (dummy image, real image via HF processor, compiled model load through the wrapper).
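A rough outline of what standalone ViT compilation with torch_neuronx.trace() looks like. This is not the PR's compile_vision_encoder.py: the example input shapes and the two-argument (pixel_values, grid_thw) signature are placeholders assumed for illustration.

```python
import torch
import torch_neuronx

def compile_vision_encoder(vision_encoder: torch.nn.Module,
                           output_path: str = "vision_encoder_neuron.pt"):
    """Sketch of standalone vision-encoder compilation; shapes are placeholders."""
    vision_encoder.eval()
    # Placeholder example inputs: flattened image patches and per-image (t, h, w) grid sizes.
    pixel_values = torch.randn(1024, 1176)
    grid_thw = torch.tensor([[1, 32, 32]], dtype=torch.int64)
    # torch_neuronx.trace compiles the module for Neuron using the example inputs.
    traced = torch_neuronx.trace(vision_encoder, (pixel_values, grid_thw))
    torch.jit.save(traced, output_path)
    return traced
```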
- Override _get_model_outputs in NeuronQwen35MoeForCausalLM to pass all 24 positional args explicitly (positions 7-21 as torch.empty(0), 22-23 as vision_embeddings/vision_mask); see the sketch after this list
- Add get_required_kwargs to propagate llava_args through the HF generation loop
- Qwen35ModelWrapper.input_generator produces 24-arg tuples for CTE and TKG
- NeuronQwen35MoeModel.forward accepts vision_embeddings/vision_mask at positions 22-23
- encode_vision_to_input uses index_put_ to scatter vision embeddings
- VL generate() in modeling_qwen35_moe_vl.py computes mRoPE position IDs, runs the vision encoder, and passes llava_args to the text model
- Add test_vl_e2e.py for full VL pipeline testing on trn2
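A simplified sketch of the fixed 24-argument assembly described in the first item above. The names of the first seven arguments are illustrative guesses at the NxDI positional layout; only the placeholder range (7-21) and the vision slots (22-23) are taken from the PR.

```python
import torch

def build_model_args(input_ids, attention_mask, position_ids, seq_ids,
                     sampling_params, prev_hidden, adapter_ids,
                     vision_embeddings, vision_mask):
    """Illustrative only: the traced model takes a fixed 24-argument signature,
    so unused slots are filled with empty tensors and the vision inputs ride
    in the last two positions."""
    args = [input_ids, attention_mask, position_ids, seq_ids,
            sampling_params, prev_hidden, adapter_ids]       # positions 0-6 (assumed names)
    args += [torch.empty(0) for _ in range(7, 22)]           # positions 7-21: placeholders
    args += [vision_embeddings, vision_mask]                 # positions 22-23
    return tuple(args)
```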
…eness
Wire 3D position IDs (T/H/W) through the full inference pipeline:
- Qwen35MRoPEEmbedding with interleaved layout [11,11,10] replaces RotaryEmbedding (see the sketch below)
- rotary_position_id at slot 21 carries (3,B,S) mRoPE positions through CTE/TKG
- VL generate() passes 3D positions via llava_args[2]
- _get_model_outputs extracts mRoPE positions or generates text-only defaults
- input_generator and pad_inputs handle trace-time and bucket padding
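To make the interleaved mRoPE layout concrete, here is a simplified sketch of how (3, B, S) T/H/W position IDs can drive the rotary frequencies under section sizes [11, 11, 10]. The interleaving pattern shown (t, h, w, t, h, w, ... capped at each section size) is an assumption about what "interleaved layout" means here; Qwen35MRoPEEmbedding in the PR is the authoritative version.

```python
import torch

def interleaved_mrope_angles(position_ids: torch.Tensor, inv_freq: torch.Tensor,
                             sections=(11, 11, 10)) -> torch.Tensor:
    """position_ids: (3, B, S) with T/H/W positions; inv_freq: (sum(sections),).
    Each frequency is driven by one axis, assigned in an interleaved pattern."""
    counts, order = [0, 0, 0], []
    while len(order) < sum(sections):
        for axis in range(3):
            if counts[axis] < sections[axis]:
                order.append(axis)
                counts[axis] += 1
    axis_of_freq = torch.tensor(order, dtype=torch.long)   # which axis drives each frequency
    pos = position_ids[axis_of_freq]                        # (num_freq, B, S)
    angles = pos.permute(1, 2, 0).float() * inv_freq        # (B, S, num_freq)
    return angles  # cos(angles) / sin(angles) are then applied as in standard RoPE
```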
Description
Add Qwen3.5-35B-A3B (hybrid DeltaNet + GQA + MoE) as a contrib model. This is a novel architecture with 30 linear recurrent DeltaNet layers and 10 standard GQA attention layers, all with 256-expert sparse MoE. Implements custom NKI v2 kernels for the DeltaNet gated delta rule recurrence, padding-aware state management, sigmoid-gated shared experts, partial RoPE, and attention output gates. Achieves 100% token match against CPU reference and 5.5x higher TKG throughput than a 4x L4 GPU setup.
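For context, a pure-PyTorch reference of the gated delta rule recurrence that the NKI kernels implement, written as the same sequential token loop over a per-(batch, head) state matrix. Sign and normalization conventions follow the common gated DeltaNet formulation and are an assumption; nki_deltanet.py is the authoritative version.

```python
import torch

def gated_delta_rule_reference(q, k, v, g, beta):
    """Reference recurrence for one (batch, head) pair.
    Shapes: q, k are (S, d_k); v is (S, d_v); g (log-decay) and beta are (S,).
    The kernel keeps the d_k x d_v state matrix resident in SBUF."""
    seqlen, d_k = k.shape
    d_v = v.shape[-1]
    state = torch.zeros(d_k, d_v, dtype=torch.float32)
    outputs = []
    for t in range(seqlen):
        k_t, v_t, q_t = k[t].float(), v[t].float(), q[t].float()
        # Decay the state; for padding tokens g is zeroed, so exp(0)=1 preserves the state.
        state = state * torch.exp(g[t].float())
        # Delta-rule update: correct the value currently stored under key k_t.
        delta = v_t - k_t @ state
        state = state + beta[t].float() * torch.outer(k_t, delta)
        outputs.append(q_t @ state)
    return torch.stack(outputs)
```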
Model Information
Model Name: Qwen3.5-35B-A3B
Model Architecture: Hybrid decoder-only transformer (30 DeltaNet linear recurrent layers + 10 GQA attention layers + 256-expert sparse MoE on all 40 layers)
Purpose: Text generation (language backbone of a Vision-Language model)
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
- Integration tests (test/integration/test_model.py)
- Model source code (src/)

Optional Components
- Unit tests in the test/unit/ directory

Folder Structure
Confirm your contribution follows this structure:
/contrib/models/Qwen3.5-35B-A3B/
README.md
/src
__init__.py
modeling_qwen35_moe.py
nki_deltanet.py
/test
__init__.py
/unit
__init__.py
/integration
__init__.py
test_model.py
Testing
How did you test this change?
Tested on trn2.3xlarge with Neuron SDK 2.28 (PyTorch 2.9). Model was compiled, loaded, and run through multiple test prompts. Token generation was validated against CPU reference (HuggingFace transformers) achieving 100% top-1 token match across 3 test prompts. Benchmark measurements taken for both short context (128 tokens) and long context (2048 tokens) sequences.
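A sketch of what the CPU-reference comparison looks like: greedily decode the same prompt with the HuggingFace model on CPU and with the Neuron model, then report the fraction of matching top-1 tokens. This is not the PR's test code; `neuron_generate` is a placeholder callable standing in for the compiled model's generation path.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def top1_token_match(neuron_generate, model_path: str, prompt: str, max_new_tokens: int = 32) -> float:
    """Fraction of greedy-decoded positions where Neuron and CPU top-1 tokens agree."""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    cpu_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    inputs = tokenizer(prompt, return_tensors="pt")
    cpu_ids = cpu_model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)[0]
    cpu_new = cpu_ids[inputs["input_ids"].shape[1]:]          # newly generated CPU tokens
    neuron_new = neuron_generate(prompt, max_new_tokens)       # placeholder Neuron generation
    n = min(len(cpu_new), len(neuron_new))
    return float((cpu_new[:n] == torch.as_tensor(neuron_new[:n])).float().mean())
```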
Test Results:
Compatibility
Tested with:
Additional Information
Key implementation details:
- Custom NKI v2 kernels (nki_deltanet.py) for the gated delta rule recurrence, processing one (batch, head) pair per kernel call with sequential token iteration over a 128x128 state matrix in SBUF.
- DeltaNet recurrent state held in nn.Parameter buffers and carried from CTE to TKG with input_output_aliases through a custom Qwen35DecoderModelInstance.
- Padding-aware state management: g is zeroed so padding tokens preserve (rather than decay) the recurrent state.
- Custom SharedExperts with a sigmoid gate, since Qwen3.5's shared expert gating differs from NxDI's default additive behavior.
- Known limitations: forward_blockwise produces incorrect output on trn2 with SDK 2.28 (workaround: large block_size), and NKI flash attention is limited to head_dim <= 128 (workaround: attn_kernel_enabled=False).

Related Issues
N/A
vLLM Integration
By submitting this PR, I confirm that: