
Add Qwen3.5-35B-A3B contrib model#60

Open
jimburtoft wants to merge 18 commits into aws-neuron:main from jimburtoft:contrib/qwen3.5-35b-a3b

Conversation

@jimburtoft

Description

Add Qwen3.5-35B-A3B (hybrid DeltaNet + GQA + MoE) as a contrib model. This is a novel architecture with 30 linear recurrent DeltaNet layers and 10 standard GQA attention layers, all with 256-expert sparse MoE. Implements custom NKI v2 kernels for the DeltaNet gated delta rule recurrence, padding-aware state management, sigmoid-gated shared experts, partial RoPE, and attention output gates. Achieves 100% token match against CPU reference and 5.5x higher TKG throughput than a 4x L4 GPU setup.

Model Information

Model Name: Qwen3.5-35B-A3B
Model Architecture: Hybrid decoder-only transformer (30 DeltaNet linear recurrent layers + 10 GQA attention layers + 256-expert sparse MoE on all 40 layers)
Purpose: Text generation (language backbone of a Vision-Language model)

Checklist

Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.

Required Components

  • Accuracy Test (ex. test/integration/test_model.py)
    • At least one integration test that validates model accuracy
    • Uses logit validation or equivalent accuracy verification
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types (Trn1/Trn2/Inf2)
    • Example Checkpoints: Links to compatible model checkpoints (e.g., HuggingFace Hub)
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Tests for individual modeling components
    • Located in test/unit/ directory

Folder Structure

Confirm your contribution follows this structure:
/contrib/models/Qwen3.5-35B-A3B/
    README.md
    /src
        __init__.py
        modeling_qwen35_moe.py
        nki_deltanet.py
    /test
        __init__.py
        /unit
            __init__.py
        /integration
            __init__.py
            test_model.py

Testing

How did you test this change?
Tested on trn2.3xlarge with Neuron SDK 2.28 (PyTorch 2.9). Model was compiled, loaded, and run through multiple test prompts. Token generation was validated against CPU reference (HuggingFace transformers) achieving 100% top-1 token match across 3 test prompts. Benchmark measurements taken for both short context (128 tokens) and long context (2048 tokens) sequences.
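The token-match validation amounts to an argmax comparison against the CPU reference. A minimal sketch (the function name and tensor shapes are illustrative, not the actual test code):

```python
import torch

def token_match_rate(neuron_logits, cpu_logits):
    """Both tensors: (num_steps, vocab_size). Returns the fraction of
    steps where Neuron and the CPU reference agree on the top-1 token."""
    neuron_ids = neuron_logits.argmax(dim=-1)
    cpu_ids = cpu_logits.argmax(dim=-1)
    return (neuron_ids == cpu_ids).float().mean().item()
```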
Test Results:

| Test | Status | Result |
| --- | --- | --- |
| Smoke Test | PASS | Model loads successfully |
| Generation ("Paris") | PASS | CTE top-1 = "Paris" (score 17.88) |
| Coherence | PASS | Multi-prompt coherent generation |
| Token Match vs CPU | PASS | 3/3 = 100% token match |
| TTFT (seq_len=128) | | 1,051 ms |
| TKG throughput | | 54.3 tok/s (18.4 ms/tok) |

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.28
  • Instance Type(s): trn2.3xlarge (TP=4, LNC=2)
  • PyTorch Version: 2.9
  • Python Version: 3.10

Additional Information

Key implementation details:

  • NKI DeltaNet kernel: Custom NKI v2 kernels (nki_deltanet.py) for the gated delta rule recurrence, processing one (batch, head) pair per kernel call with sequential token iteration over a 128x128 state matrix in SBUF.
  • State carry-over: DeltaNet recurrent state and conv1d state are carried between context encoding (CTE) and token generation (TKG) via nn.Parameter buffers with input_output_aliases through a custom Qwen35DecoderModelInstance.
  • Padding-aware recurrence: Right-padded inputs have their decay factor g zeroed so padding tokens preserve (rather than decay) the recurrent state.
  • Sigmoid-gated shared expert: Wraps NxDI's SharedExperts with a sigmoid gate, since Qwen3.5's shared expert gating differs from NxDI's default additive behavior.
  • Known SDK issues: MoE forward_blockwise produces incorrect output on trn2 SDK 2.28 (workaround: large block_size). NKI flash attention limited to head_dim<=128 (workaround: attn_kernel_enabled=False).
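The gated delta rule recurrence and the padding-aware decay handling described above can be sketched as a plain PyTorch reference for one (batch, head) pair. This is an illustrative sketch based on the description, not the kernel's actual code; the update form and the additional zeroing of beta for padding tokens are assumptions:

```python
import torch

def gated_delta_rule_reference(q, k, v, g, beta, pad_mask=None):
    """CPU reference for one (batch, head) pair of the gated delta rule.

    Illustrative shapes: q, k: (S, d_k); v: (S, d_v); g, beta: (S,);
    pad_mask: (S,) bool, True at right-padding positions.
    The state S has shape (d_k, d_v); the NKI kernel keeps it in SBUF.
    """
    if pad_mask is not None:
        g = g.masked_fill(pad_mask, 0.0)        # exp(0)=1: padding preserves state
        beta = beta.masked_fill(pad_mask, 0.0)  # assumption: also skip the update
    S = torch.zeros(k.shape[-1], v.shape[-1], dtype=q.dtype)
    out = torch.empty_like(v)
    for t in range(q.shape[0]):                 # sequential token iteration
        S = torch.exp(g[t]) * S                 # gated decay of the running state
        err = v[t] - k[t] @ S                   # delta-rule prediction error
        S = S + beta[t] * torch.outer(k[t], err)
        out[t] = q[t] @ S
    return out, S
```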

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

Implements Qwen3.5-35B-A3B (35B total, 3B active per token) with:
- Custom NKI v2 DeltaNet kernels for 30 linear attention layers
- Standard GQA with output gate and partial RoPE for 10 attention layers
- Sparse MoE (256 experts, top-8 + sigmoid-gated shared expert)
- CTE-to-TKG state carry-over via input_output_aliases
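The sigmoid-gated shared expert mentioned above can be sketched roughly as follows. Module and projection names are illustrative, not the NxDI API; the point is that the shared expert's output is scaled by a per-token sigmoid gate before being combined with the routed-expert output, rather than added unconditionally:

```python
import torch
import torch.nn as nn

class SigmoidGatedSharedExpert(nn.Module):
    """Illustrative sketch: a SwiGLU shared expert whose contribution is
    scaled by a learned per-token sigmoid gate (assumed names/structure)."""
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.shared_gate = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, x, routed_out):
        expert = self.down_proj(
            nn.functional.silu(self.gate_proj(x)) * self.up_proj(x))
        gate = torch.sigmoid(self.shared_gate(x))  # (B, S, 1) per-token scale
        return routed_out + gate * expert          # gated, not plain additive
```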

Validated on trn2.3xlarge with SDK 2.28: CTE produces 'Paris' (17.88),
TKG generates coherent multi-token output across all test prompts.
…EADME

- Neuron vs GPU benchmark results (trn2.3xlarge vs g6.12xlarge 4x L4)
- 5.5x TKG throughput advantage (54.9 vs 10.0 tok/s at BS=1)
- Long context (seq_len>=1024) config with attn_kernel_enabled=False
- Document NKI flash attention head_dim>128 limitation and workaround
- Document compilation time scaling (O(seq_len) due to DeltaNet unrolling)
- Add token match vs CPU test result (100%)
… head_dim>128

- Add nki_flash_attn_d256.py: custom NKI kernel that tiles QK contraction in
  2x128 chunks, supporting head_dim=256 (NxDI kernel asserts d<=128)
- Add perform_prefill() override to NeuronQwen35Attention that auto-disables
  NKI flash attention for head_dim>128, eliminating the need for manual
  attn_kernel_enabled=False in NeuronConfig
- Custom kernel is opt-in via QWEN35_USE_FLASH_ATTN_D256=1 (benchmarks show
  2.4x TTFT regression due to BHSD->BHDS layout conversion overhead)
- Update README: remove attn_kernel_enabled requirement, document kernel findings

Adds nkilib_kernel_patch.py that replaces NxDI's bundled flash attention
kernel (_pre_prod_kernels) with our modified nkilib kernel supporting
head_dim up to 256 via d-tiling. Requires installing the nki-library
fork (feature/head-dim-256 branch) as a standalone override.

Enable with: QWEN35_PATCH_FLASH_ATTN=1
Prerequisite: pip install git+https://github.com/jimburtoft/nki-library.git@feature/head-dim-256

The adapter translates NxDI parameter names (do_out_tp, kernel_name,
use_dma_transpose) to nkilib convention (tp_out, causal_mask) and applies
the same decorator stack (peel, re-jit torchxla, skip_middle_end,
enable_stack_allocator).

perform_prefill updated with three priority paths:
1. Patched nkilib kernel (zero layout overhead) - QWEN35_PATCH_FLASH_ATTN=1
2. External NKI kernel (has TTFT regression) - QWEN35_USE_FLASH_ATTN_D256=1
3. PyTorch softmax fallback (default for head_dim>128)
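The three priority paths above reduce to a small env-var dispatch. A minimal sketch (the environment variable names come from this PR; the function and path labels are illustrative placeholders, not the actual perform_prefill code):

```python
import os

def select_flash_attn_path(head_dim):
    """Pick a prefill attention path, mirroring the priority order above."""
    if head_dim <= 128:
        return "nxdi_flash_attn"           # stock kernel handles d <= 128
    if os.environ.get("QWEN35_PATCH_FLASH_ATTN") == "1":
        return "patched_nkilib_kernel"     # zero layout overhead
    if os.environ.get("QWEN35_USE_FLASH_ATTN_D256") == "1":
        return "external_d256_kernel"      # works, but has a TTFT regression
    return "pytorch_softmax_fallback"      # default for head_dim > 128
```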

The previous approach tried to monkey-patch _flash_fwd_call_nki which passed
do_out_tp=True (translated to tp_out=True). But tp_out=True is not supported
for d>128 because the output write-back path would need d on the par_dim.

New approach: call the nkilib kernel directly from perform_prefill with
tp_out=False. The kernel returns (B*H, seqlen, d) which we reshape to
BHSD (B, H, S, D) and return with FlashAttentionStrategy.NONE.

No layout conversion overhead -- Q and K are prepared in tp_q=True format
(B*H, seqlen, d) which matches what the nkilib kernel expects with GQA.
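The output handling above is a plain view from the kernel's (B*H, seqlen, d) layout back to BHSD; no copy is needed. Sizes here are illustrative:

```python
import torch

B, H, S, D = 1, 4, 16, 256            # illustrative sizes
kernel_out = torch.randn(B * H, S, D)  # what the nkilib kernel returns (tp_out=False)
bhsd = kernel_out.view(B, H, S, D)     # reshape to (B, H, S, D) without copying
```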
…ssue

Calling the nkilib kernel with grid=(nc(lnc),) causes num_shard=None from
nl.num_programs(). For now, call without grid (single-core mode) to test
correctness first. Performance optimization with SPMD grid can follow.

The num_shard=None issue is fixed in nkilib kernel_helpers.py (commit 91e06cf
in nki-library fork). Re-enable the full NxDI decorator stack which is needed
for CTE kernel optimizations.

The previous approach (peel @nki.jit -> re-jit with mode='torchxla' +
skip_middle_end_transformations + enable_stack_allocator) caused tracing
issues: the NKI compiler's inline_function mechanism doesn't properly
propagate trace context (builtin, address= support) to sub-functions.

nki_jit() wraps the already-traced GenericKernel directly for XLA
execution, preserving the original trace context intact. This is simpler
and avoids all the sub-function tracing issues.

The nkilib flash attention kernel for head_dim>256 works in standalone
mode (all tests pass) but cannot be integrated with NxDI due to an NKI
compiler bug: TraceKernel.inline_function doesn't propagate trace context
(builtin injection) to sub-functions in torchxla mode.

- nkilib_kernel_patch.py: disabled with early return, ready to enable
  when the compiler bug is fixed
- README: documented the compiler bug and nki-library fork work
- issue_nki_inline_function_builtin.md: full bug report filed

- Add is_speculative_decoding flag to MoE forward call, aligning with
  upstream NxDI 2.28 Qwen3 MoE changes
- Add vision encoder (modeling_qwen35_moe_vision.py): Qwen3.5 ViT with
  27-layer transformer, patch merger, rotary embeddings, block-diagonal
  attention mask, and sequence length bucketing
- Add VL orchestrator (modeling_qwen35_moe_vl.py): mRoPE position ID
  computation, vision+text input preparation, generate loop with
  rope_deltas for TKG decode phase
- Add image input integration tests: tokenization, mRoPE correctness,
  vision encoder shapes, end-to-end VL pipeline
- Update __init__.py exports and README with vision-language docs

Vision encoder (27-layer ViT, 442M params) compiles via torch_neuronx.trace()
in 162.7s and runs inference in 8.3ms with cosine similarity 0.999 vs CPU.
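The cosine-similarity check against the CPU reference is a flatten-and-compare. A minimal sketch (the helper name is illustrative, not the test's actual code):

```python
import torch

def cosine_sim(a, b):
    """Cosine similarity between two tensors, flattened to vectors."""
    a, b = a.flatten().float(), b.flatten().float()
    return torch.nn.functional.cosine_similarity(a, b, dim=0).item()
```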

Changes:
- compile_vision_encoder.py: Standalone compilation script for the ViT
- Fix rotary embeddings to use rotate_half (matching HF reference)
- Fix weight loading prefix (model.visual.* not visual.*)
- Fix HF import chain for Qwen3VL classes in transformers 4.57+
- Add load_compiled() and load_vision_weights_from_hf() to wrapper
- Add standalone mode (model_cls=None) for pre-compiled model loading
- Update VL orchestrator compile/load to support pre-compiled vision encoder

Tested on trn2.3xlarge: 3/3 E2E tests pass (dummy image, real image via
HF processor, compiled model load through wrapper).

- Override _get_model_outputs in NeuronQwen35MoeForCausalLM to pass all 24
  positional args explicitly (positions 7-21 as torch.empty(0), 22-23 as
  vision_embeddings/vision_mask)
- Add get_required_kwargs to propagate llava_args through HF generation loop
- Qwen35ModelWrapper.input_generator produces 24-arg tuples for CTE and TKG
- NeuronQwen35MoeModel.forward accepts vision_embeddings/vision_mask at positions 22-23
- encode_vision_to_input uses index_put_ to scatter vision embeddings
- VL generate() in modeling_qwen35_moe_vl.py computes mRoPE position IDs,
  runs vision encoder, and passes llava_args to text model
- Add test_vl_e2e.py for full VL pipeline testing on trn2
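The index_put_ scatter used by encode_vision_to_input can be illustrated as follows: a boolean vision_mask marks the image placeholder positions in the text sequence, and the vision embeddings are written into those slots in order. Shapes and variable names here are illustrative:

```python
import torch

hidden = 8
input_embeds = torch.zeros(1, 6, hidden)   # (B, S, H) text embeddings
vision_embeds = torch.ones(3, hidden)      # one row per image token
vision_mask = torch.tensor([[False, True, True, True, False, False]])

# Scatter vision embeddings into the masked positions in place.
b_idx, s_idx = vision_mask.nonzero(as_tuple=True)
input_embeds.index_put_((b_idx, s_idx), vision_embeds)
```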
…eness

Wire 3D position IDs (T/H/W) through the full inference pipeline:
- Qwen35MRoPEEmbedding with interleaved layout [11,11,10] replaces RotaryEmbedding
- rotary_position_id at slot 21 carries (3,B,S) mRoPE through CTE/TKG
- VL generate() passes 3D positions via llava_args[2]
- _get_model_outputs extracts mRoPE or generates text-only defaults
- input_generator and pad_inputs handle trace-time and bucket padding
