Port of IBM granite-4.0-h-small (GraniteMoeHybridForCausalLM) to NxDI. Hybrid architecture: 36 Mamba2 + 4 Attention layers with 72-expert MoE. Mamba state persistence via input_output_aliases (conv_state + ssm_state). Validated on trn2.3xlarge (TP=4, SDK 2.28):

- Prefill: 100% greedy match, Pearson=0.9968, Cosine=0.9987 (10 prompts)
- Decode: coherent text generation matching HF reference
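The state-persistence mechanism mentioned above can be sketched in plain Python. This is an illustrative model of the pattern, not the NxDI API: in the real port, the `conv_state`/`ssm_state` output tensors are aliased to the corresponding inputs via `input_output_aliases` so state stays resident in device memory across decode steps. All function names below are hypothetical.

```python
# Toy model of Mamba state carried across decode steps. In the traced
# NxDI graph, new_conv/new_ssm would be aliased to the conv_state/ssm_state
# inputs via input_output_aliases (same mechanism as the KV cache).

def mamba_step(conv_state, ssm_state, x, a=0.9, b=0.1):
    """One toy decode step: a 1-tap conv buffer and a scalar SSM state."""
    new_conv = x                     # sliding conv window (length 1 for brevity)
    new_ssm = a * ssm_state + b * x  # h[t] = A * h[t-1] + B * x[t]
    y = new_ssm + new_conv
    return new_conv, new_ssm, y

def decode(tokens):
    conv, ssm = 0.0, 0.0             # zero-initialized persistent state
    outs = []
    for x in tokens:
        conv, ssm, y = mamba_step(conv, ssm, x)
        outs.append(y)
    return outs
```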
Replace the O(L^2) quadratic parallel scan with an O(L) hardware-accelerated scan using nisa.tensor_tensor_scan on Trainium2. The NKI kernel is enabled by default (USE_NKI_SCAN=True) and validated against the quadratic baseline (Pearson=0.987, Cosine=0.978, 100% greedy match). Changes:

- modeling_granite.py: Add nki_scan_kernel and a _nki_selective_scan helper with a USE_NKI_SCAN toggle (falls back to the quadratic scan when disabled)
- test/unit/test_nki_selective_scan.py: Standalone kernel with CPU reference, quadratic reference, and validation tests
- README.md: Document the NKI kernel, accuracy results, and requirements
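As a sketch of what is being swapped: both scan formulations evaluate the same first-order recurrence h[t] = a[t]·h[t-1] + b[t]. The quadratic form materializes pairwise decay products (O(L^2) work, which the compiler vectorizes well at short L), while the linear form runs the recurrence directly, which is what the hardware scan accelerates. Pure-Python reference only, not the kernel code:

```python
def scan_linear(a, b):
    """O(L) sequential recurrence: h[t] = a[t] * h[t-1] + b[t]."""
    h, out = 0.0, []
    for at, bt in zip(a, b):
        h = at * h + bt
        out.append(h)
    return out

def scan_quadratic(a, b):
    """O(L^2) form: h[t] = sum_s (prod_{r=s+1..t} a[r]) * b[s]."""
    out = []
    for t in range(len(a)):
        acc = 0.0
        for s in range(t + 1):
            decay = 1.0
            for r in range(s + 1, t + 1):
                decay *= a[r]
            acc += decay * b[s]
        out.append(acc)
    return out
```

The unit test in the PR validates the NKI kernel against exactly this kind of quadratic reference.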
tensor_tensor_scan is a NeuronCore ISA primitive available on all Neuron hardware with NKI support, not just Trainium2. The platform override environment variable tells the compiler which target to compile for; it is not a hardware restriction.
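A pure-Python reading of the primitive's semantics (my paraphrase of the NKI documentation; treat the exact signature and argument order as an assumption): result[i] = op1(op0(result[i-1], data0[i]), data1[i]), with result[-1] = initial. With op0 = multiply and op1 = add, this is exactly the SSM recurrence:

```python
import operator

def tensor_tensor_scan_ref(data0, data1, initial, op0, op1):
    """CPU reference for the scan semantics assumed above."""
    out, acc = [], initial
    for d0, d1 in zip(data0, data1):
        acc = op1(op0(acc, d0), d1)  # result[i] = op1(op0(result[i-1], d0[i]), d1[i])
        out.append(acc)
    return out

# h[t] = a[t] * h[t-1] + b[t]
hs = tensor_tensor_scan_ref([0.5, 0.5], [1.0, 1.0], 0.0, operator.mul, operator.add)
```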
Benchmark results on trn2.3xlarge (TP=4, max_context_length=128):

- Quadratic scan: 717 ms prefill, 50.3 ms/token decode, 17.6 tok/s
- NKI scan: 935 ms prefill (+30%), identical decode latency

The NKI kernel's 8,192 tensor_tensor_scan invocations carry more overhead than the compiler-optimized 128x128 quadratic matrix at short contexts; NKI should win at L>=512, where the O(L^2) cost dominates. Changes:

- Set USE_NKI_SCAN=False as the default (the quadratic scan is faster at L=128)
- Add a performance benchmarks section to the README with latency data
- Add benchmark_latency.py for reproducible measurements
- Update known limitations to explain the NKI tradeoff
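The benchmark script itself isn't shown here; a minimal timing harness of the kind such a script would use (warmup iterations to exclude compilation, then averaged wall-clock time) might look like this. The function names are illustrative, not the actual benchmark_latency.py contents:

```python
import time

def mean_latency(fn, warmup=2, iters=10):
    """Average wall-clock latency of fn(), excluding warmup runs."""
    for _ in range(warmup):
        fn()                          # first calls may trigger compilation
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - t0) / iters

# Hypothetical usage:
#   decode_latency = mean_latency(decode_one_token)
#   tok_per_s = 1.0 / decode_latency
```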
…M limits on trn2.3xlarge
Description
NxDI contrib port of IBM's Granite 4.0-H-Small (`GraniteMoeHybridForCausalLM`), a hybrid Mamba2/Attention architecture with Mixture-of-Experts. This is one of the first hybrid SSM/Attention + MoE models, combining Mamba2 recurrent layers for efficient sequence modeling, sparse attention for long-range dependencies, and MoE for parameter efficiency.

Key implementation challenges solved:

- Mamba state persistence via `input_output_aliases` (same mechanism as the KV cache), following the MLlama `vision_key_values` pattern
- `nisa.tensor_tensor_scan` for O(L) hardware-accelerated scanning at longer context lengths
- Gated RMSNorm ordering (`silu(gate) * x -> RMSNorm -> weight`), validated against the HF reference

Model Information
Model Name: Granite 4.0-H-Small
Model Architecture: Hybrid Mamba2/Attention with MoE -- 40 layers (36 Mamba2 + 4 Attention at indices 5, 15, 25, 35), 72 experts per layer with top-10 routing + 1 shared expert, ~4B total parameters (~800M active per token), hidden_size=4096, no positional embeddings ("nope")
Purpose: Text generation (code, general-purpose)
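The routing described above (top-10 of 72 experts plus one always-on shared expert) follows the standard top-k softmax-gating pattern. A minimal sketch under that assumption, not the model code; expert shapes and the renormalize-over-top-k choice here are illustrative:

```python
import math

def top_k_route(logits, k):
    """Pick the top-k experts and softmax-renormalize their gate weights."""
    idx = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in idx)
    exps = [math.exp(logits[i] - m) for i in idx]
    z = sum(exps)
    return [(i, e / z) for i, e in zip(idx, exps)]

def moe_forward(x, experts, shared_expert, router_logits, k=10):
    """Weighted sum of the top-k routed experts, plus the shared expert."""
    y = shared_expert(x)                     # shared expert always contributes
    for i, w in top_k_route(router_logits, k):
        y += w * experts[i](x)
    return y
```

Only the selected experts run per token, which is how ~4B total parameters reduce to ~800M active per token.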
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
- Integration tests (`test/integration/test_model.py`)
- Source code (`src/`)

Optional Components
- Unit tests in the `test/unit/` directory

Folder Structure
Confirm your contribution follows this structure:
Testing
How did you test this change?
Tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.28 (Deep Learning AMI Neuron (Ubuntu 24.04) 20260227). Configuration: TP=4, batch_size=1, max_context_length=128, seq_len=2048, bfloat16.

Accuracy validation against the HuggingFace BF16 CPU reference across 10 diverse prompts (factual, code, conversational, single-token). Also validated decode quality with 30-token greedy generation.
Test Results:
Prefill accuracy (10 prompts vs HF BF16 CPU):
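For reference, the three metrics quoted in this PR (greedy match, Pearson, cosine) can be computed from the two models' logits as below. This is a generic sketch, not the repo's test code; it assumes per-position logit vectors compared against the HF reference:

```python
import math

def pearson(x, y):
    """Pearson correlation between two flattened logit vectors."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)

def cosine(x, y):
    """Cosine similarity between two flattened logit vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    nx = math.sqrt(sum(a * a for a in x))
    ny = math.sqrt(sum(b * b for b in y))
    return dot / (nx * ny)

def greedy_match(logits_a, logits_b):
    """Fraction of positions where the argmax (greedy) token agrees."""
    hits = sum(max(range(len(a)), key=a.__getitem__) ==
               max(range(len(b)), key=b.__getitem__)
               for a, b in zip(logits_a, logits_b))
    return hits / len(logits_a)
```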
Compatibility
Tested with:
Additional Information
NKI Selective Scan Kernel (optional): The model includes an optional NKI kernel (`USE_NKI_SCAN` flag in `modeling_granite.py`) that replaces the O(L^2) quadratic scan with an O(L) hardware-accelerated `nisa.tensor_tensor_scan`. At max_context_length=128, the quadratic scan is ~30% faster (the compiler vectorizes the 128x128 matrix efficiently). At max_context_length=256 and above, the NKI kernel is required: the quadratic scan causes a compiler OOM.

Context length scaling:
Related Issues
None.
vLLM Integration
By submitting this PR, I confirm that: