Contrib/granite 4.0 h small#69

Open
jimburtoft wants to merge 5 commits into aws-neuron:main from jimburtoft:contrib/granite-4.0-h-small

Conversation

@jimburtoft

Description

NxDI (NeuronX Distributed Inference) contrib port of IBM's Granite 4.0-H-Small (GraniteMoeHybridForCausalLM) -- a hybrid Mamba2/Attention architecture with Mixture-of-Experts. This is one of the first hybrid SSM/Attention + MoE models, combining Mamba2 recurrent layers for efficient sequence modeling, sparse attention for long-range dependencies, and MoE for parameter efficiency.
Key implementation challenges solved:

  1. Mamba state persistence across XLA graph executions -- conv_state and ssm_state for 36 Mamba layers persisted using input_output_aliases (same mechanism as KV cache), following the MLlama vision_key_values pattern
  2. Mamba2 parallel scan for prefill -- O(L^2) cumulative-sum-in-log-space, mathematically equivalent to HF's chunk-based SSD. Also includes an optional NKI kernel using nisa.tensor_tensor_scan for O(L) hardware-accelerated scanning at longer context lengths
  3. Conv1d compiler bug workaround -- SDK 2.28 TEN404 crashes on seq_len=1 (decode path); worked around with manual depthwise convolution
  4. Gated RMSNorm ordering -- Granite applies gate before normalization (silu(gate) * x -> RMSNorm -> weight), validated against HF reference
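The gated RMSNorm ordering from item 4 can be sketched in NumPy. This is a minimal illustration of the operation order described above (gate first, then normalize, then scale), not the actual `modeling_granite.py` code; function names, shapes, and `eps` are assumptions:

```python
import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def gated_rmsnorm(x, gate, weight, eps=1e-6):
    # Granite ordering: silu(gate) * x -> RMSNorm -> weight.
    # Many Mamba2 implementations instead normalize first and gate after,
    # which gives different numerics -- hence the validation against HF.
    y = x * silu(gate)
    rms = np.sqrt(np.mean(y * y, axis=-1, keepdims=True) + eps)
    return (y / rms) * weight
```

With `weight` set to ones, the output has unit RMS along the last axis, which is a quick sanity check for the ordering.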

Model Information

Model Name: Granite 4.0-H-Small
Model Architecture: Hybrid Mamba2/Attention with MoE -- 40 layers (36 Mamba2 + 4 Attention at indices 5, 15, 25, 35), 72 experts per layer with top-10 routing + 1 shared expert, ~4B total parameters (~800M active per token), hidden_size=4096, no positional embeddings ("nope")
Purpose: Text generation (code, general-purpose)

Checklist

Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.

Required Components

  • Accuracy Test (e.g., test/integration/test_model.py)
    • At least one integration test that validates model accuracy
    • Uses logit validation or equivalent accuracy verification
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types (Trn1/Trn2/Inf2)
    • Example Checkpoints: Links to compatible model checkpoints (e.g., HuggingFace Hub)
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Tests for individual modeling components
    • Located in test/unit/ directory

Folder Structure

Confirm your contribution follows this structure:

/contrib/models/granite-4.0-h-small/
  README.md
  /src
    __init__.py
    modeling_granite.py
  /test
    __init__.py
    /unit
      __init__.py
      test_nki_selective_scan.py
    /integration
      __init__.py
      test_model.py
      benchmark_latency.py

Testing

How did you test this change?
Tested on trn2.3xlarge (LNC=2, 4 NeuronCores) with Neuron SDK 2.28 (DLAMI Deep Learning AMI Neuron (Ubuntu 24.04) 20260227). Configuration: TP=4, batch_size=1, max_context_length=128, seq_len=2048, bfloat16.
Accuracy validation against HuggingFace BF16 CPU reference across 10 diverse prompts (factual, code, conversational, single-token). Also validated decode quality with 30-token greedy generation.
Test Results:

Prefill accuracy (10 prompts vs HF BF16 CPU):

| Metric | Value |
| ------ | ----- |
| Greedy token match rate | 100% (10/10) |
| Average Pearson correlation | 0.9968 |
| Average cosine similarity | 0.9987 |
| Max absolute logit diff | 3.00 |

Performance (max_context_length=128, quadratic scan):

| Metric | Value |
| ------ | ----- |
| Prefill latency | 717 ms |
| Decode per-token latency | ~50 ms |
| 100-token throughput | 17.6 tok/s |
| Compile time | ~16 min |
Token-level decode divergence from HF is expected: the full-sequence parallel scan and HF's chunk-based SSD are mathematically equivalent but accumulate BF16 rounding differently, causing early divergence that cascades through autoregressive generation. Deterministic answers (e.g., "Paris") match exactly.
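As a minimal NumPy sketch of why the full-sequence scan and the sequential recurrence agree in exact arithmetic (and diverge only through rounding): the recurrence h_t = a_t * h_{t-1} + b_t x_t has the closed form h_t = sum over s <= t of (product of a_r for s < r <= t) * b_s x_s, and the pairwise decay products can be computed from a cumulative sum of logs. This assumes per-step decays a_t in (0, 1), as Mamba2's discretization produces; names are illustrative, not the actual kernel:

```python
import numpy as np

def sequential_scan(a, bx):
    # Reference recurrence: h_t = a_t * h_{t-1} + bx_t, with h_{-1} = 0.
    h = np.empty_like(bx)
    prev = 0.0
    for t in range(len(bx)):
        prev = a[t] * prev + bx[t]
        h[t] = prev
    return h

def quadratic_scan(a, bx):
    # O(L^2) cumulative-sum-in-log-space formulation:
    # decay[t, s] = prod_{r=s+1..t} a_r = exp(cum[t] - cum[s]) for s <= t.
    cum = np.cumsum(np.log(a))
    decay = np.tril(np.exp(cum[:, None] - cum[None, :]))
    return decay @ bx  # h_t = sum_{s<=t} decay[t, s] * bx_s
```

In float64 the two match to near machine precision; in BF16 the different summation orders accumulate rounding differently, which is the divergence mechanism described above.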

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.28
  • Instance Type(s): trn2.3xlarge (TP=4, LNC=2)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

NKI Selective Scan Kernel (optional): The model includes an optional NKI kernel (USE_NKI_SCAN flag in modeling_granite.py) that replaces the O(L^2) quadratic scan with O(L) hardware-accelerated nisa.tensor_tensor_scan. At max_context_length=128, the quadratic scan is ~30% faster (compiler vectorizes the 128x128 matrix efficiently). At max_context_length=256+, the NKI kernel is required -- the quadratic scan causes compiler OOM.
Context length scaling:

| max_context_length | Quadratic compile | NKI compile | Runtime (trn2.3xlarge) |
| ------------------ | ----------------- | ----------- | ---------------------- |
| 128 | OK | OK | OK |
| 256 | Compiler OOM | OK | HBM OOM (needs larger instance) |
| 512+ | Compiler OOM | Compiler OOM | N/A |

MoE (72 experts x 40 layers = 2,880 expert weight sets) dominates memory. Longer contexts require trn2.48xlarge to distribute experts across more cores.
Known limitations:
  1. max_context_length=128 on trn2.3xlarge -- longer contexts require larger instances
  2. Batch size 1 only -- batch_size > 1 not yet validated
  3. No on-device sampling tested -- uses raw logits (on_device_sampling_config=None)
  4. Conv1d workaround -- manual depthwise conv avoids SDK 2.28 TEN404 bug; future SDK versions may not need it
  5. NKI scan overhead at short contexts -- disabled by default, enable for max_context_length >= 256
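For reference, the conv1d workaround from limitation 4 amounts to a per-channel dot product over a rolling window: at seq_len=1, a depthwise causal conv reduces to multiplying the cached last K-1 inputs plus the new token by each channel's K-tap filter. A minimal NumPy sketch (names and shapes are assumptions, not the actual implementation):

```python
import numpy as np

def depthwise_conv_decode(conv_state, x_new, weight, bias):
    """Manual depthwise causal conv for one decode token (seq_len=1).

    conv_state: (C, K-1) last K-1 inputs per channel (the rolling cache)
    x_new:      (C,)     current token's features
    weight:     (C, K)   one K-tap filter per channel (depthwise, groups=C)
    bias:       (C,)
    """
    window = np.concatenate([conv_state, x_new[:, None]], axis=1)  # (C, K)
    y = (window * weight).sum(axis=1) + bias   # per-channel dot product
    new_state = window[:, 1:]                  # roll the cache forward
    return y, new_state
```

Expressed this way, the decode path uses only elementwise multiply and reduce ops, sidestepping the conv1d lowering that triggers the SDK 2.28 TEN404 crash.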

Related Issues

None.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

Port of IBM granite-4.0-h-small (GraniteMoeHybridForCausalLM) to NxDI.
Hybrid architecture: 36 Mamba2 + 4 Attention layers with 72-expert MoE.
Mamba state persistence via input_output_aliases (conv_state + ssm_state).

Validated on trn2.3xlarge (TP=4, SDK 2.28):
- Prefill: 100% greedy match, Pearson=0.9968, Cosine=0.9987 (10 prompts)
- Decode: coherent text generation matching HF reference

Replace O(L^2) quadratic parallel scan with O(L) hardware-accelerated
scan using nisa.tensor_tensor_scan on Trainium2. The NKI kernel is
enabled by default (USE_NKI_SCAN=True) and validated against the
quadratic baseline (Pearson=0.987, Cosine=0.978, 100% greedy match).

Changes:
- modeling_granite.py: Add nki_scan_kernel and _nki_selective_scan helper
  with USE_NKI_SCAN toggle (falls back to quadratic scan when disabled)
- test/unit/test_nki_selective_scan.py: Standalone kernel with CPU
  reference, quadratic reference, and validation tests
- README.md: Document NKI kernel, accuracy results, and requirements

tensor_tensor_scan is a NeuronCore ISA primitive available on all Neuron
hardware with NKI support, not just Trainium2. The platform override env
var is for telling the compiler which target to compile for, not a
hardware restriction.

Benchmark results on trn2.3xlarge (TP=4, max_context_length=128):
- Quadratic scan: 717ms prefill, 50.3ms/token decode, 17.6 tok/s
- NKI scan: 935ms prefill (+30%), identical decode latency
- The NKI kernel's 8,192 tensor_tensor_scan invocations have more
  overhead than the compiler-optimized 128x128 quadratic matrix at
  short contexts. NKI should win at L>=512 where O(L^2) dominates.

Changes:
- Set USE_NKI_SCAN=False as default (quadratic is faster at L=128)
- Add performance benchmarks section to README with latency data
- Add benchmark_latency.py script for reproducible measurements
- Update known limitations to explain the NKI tradeoff