Skip to content

Add Boltz-2 contrib model with NKI kernels for pairformer inference#64

Open
jimburtoft wants to merge 2 commits intoaws-neuron:mainfrom
jimburtoft:contrib/boltz-2
Open

Add Boltz-2 contrib model with NKI kernels for pairformer inference#64
jimburtoft wants to merge 2 commits intoaws-neuron:mainfrom
jimburtoft:contrib/boltz-2

Conversation

@jimburtoft
Copy link

Boltz-2 biomolecular structure prediction on AWS Trainium 2. This is a novel contrib that uses custom NKI (Neuron Kernel Interface) kernels for the four O(N^3) triangular operations in the pairformer trunk, compiled via torch_neuronx.trace() with the weight replacement pattern to handle all 64 pairformer layers in 7.4 minutes of setup time.
Boltz-2 predicts 3D structures of proteins, RNA, DNA, and small molecules. The pairformer trunk (64 transformer layers with triangular attention and triangular multiplicative updates) is the computational bottleneck. Two custom NKI kernels replace the expensive operations with implementations that run directly on the NeuronCore, and the weight replacement pattern (inline_weights_to_neff=False + replace_weights()) avoids recompiling each of the 64 identical-graph layers.

Model Information

Model Name: Boltz-2 (v2.2.1)
Model Architecture: 64-layer pairformer trunk (triangular attention, triangular multiplicative updates, pair bias attention) + diffusion score model for coordinate generation. 507M parameters.
Purpose: Biomolecular structure prediction (proteins, RNA, DNA, small molecules)

Checklist

Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md (../contrib/CONTRIBUTING.md) for detailed guidelines.
Required Components

  • Accuracy Test (ex. test/integration/test_model.py)
    • 4 integration tests: model loading, compilation, forward pass validation (no NaN/Inf), accuracy vs CPU reference
    • Cosine similarity validation against CPU reference (s_cos >= 0.99, z_cos >= 0.99)
    • Tests compile 2 pairformer layers at N=128 on Neuron
  • README.md with the following sections:
    • Usage Example: Step-by-step code showing monkey-patching, compilation, and inference
    • Compatibility Matrix: Tested on trn2.3xlarge (SDK 2.27, 2.28) and trn2.48xlarge (SDK 2.28)
    • Example Checkpoints: boltz==2.2.1 on PyPI (auto-downloads checkpoint)
    • Testing Instructions: pytest command with environment setup
  • Source Code (src/)
    • modeling_boltz2.py: Pipeline module (monkey-patching, weight-replaced compilation, inference)
    • nki_triangular_attention.py: NKI kernel for triangular attention with online softmax
    • nki_triangular_mul.py: NKI kernel for triangular multiplicative update (einsum contraction)

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Not included; standalone NKI kernel tests are in the development repository

Folder Structure

Confirm your contribution follows this structure:
/contrib/models/Boltz-2/
README.md
/src
init.py
modeling_boltz2.py
nki_triangular_attention.py
nki_triangular_mul.py
/test
init.py
/unit
init.py
/integration
init.py
test_model.py

Testing

How did you test this change?
Tested on trn2.3xlarge (Neuron SDK 2.27 and 2.28) and trn2.48xlarge (SDK 2.28). Full 64-layer pairformer compiled and validated at N=128, N=256, and N=512 with trained Boltz-2 weights loaded from the official checkpoint.
Test Results:

N Layers Instance s_cos z_cos Status
128 1 trn2.3xlarge 0.999796 0.998359 PASS
128 8 trn2.3xlarge 0.999713 0.995417 PASS
256 64 trn2.3xlarge 0.999220 0.943929 PASS
512 64 trn2.48xlarge 0.999460 0.979214 PASS
Standalone NKI kernel accuracy (cosine similarity vs CPU reference):
  • Triangular Attention N=128: 0.999713, N=256: 1.000029
  • Triangular Mul (Outgoing) N=128: 0.999967, N=256: 0.999903
  • Triangular Mul (Incoming) N=128: 0.999967

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.27, 2.28
  • Instance Type(s): trn2.3xlarge, trn2.48xlarge
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

This is the first NxDI contrib to use NKI (Neuron Kernel Interface) custom kernels. The approach differs from typical contribs in that it uses torch_neuronx.trace() directly rather than NxDI model classes (NeuronBaseModel/NeuronBaseForCausalLM), because Boltz-2's pairformer architecture (triangular attention with full 2D bias, triangular multiplicative updates) has no equivalent in the NxDI model zoo.
Key limitations:

  • Per-layer host-device round trips (64 separate traced models) contribute significant sync overhead
  • N must be a multiple of 128 (NKI kernel tiling constraint)
  • --auto-cast matmult must NOT be used (destroys accuracy for this model)

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines (../contrib/CONTRIBUTING.md)
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

Boltz-2 biomolecular structure prediction on Trainium 2 using custom NKI
kernels for the O(N^3) triangular attention and multiplicative update
operations. Uses torch_neuronx.trace() with weight replacement pattern
to compile 64 pairformer layers in 7.4 minutes.

Validated on trn2.3xlarge at N=256 (s_cos=0.999220, z_cos=0.943929)
and trn2.48xlarge at N=512 (s_cos=0.999460, z_cos=0.979214).
Add SPMD grid=[2] NKI mega-kernel that fuses all 7 sub-operations of a
PairformerLayer into a single kernel call, eliminating host-device round
trips. Validated at N=256: s_cos=0.999995, z_cos=0.999245.

New files:
- src/fused_z_ops_spmd.py: z-only fused operations (TriMul, TriAttn, Transition_z)
- src/full_pairformer_layer_spmd.py: full layer kernel (PBA + all z-ops + Transition_s)
- test/integration/compile_full_layer_spmd.py: compilation script
- test/integration/test_full_layer_spmd.py: correctness test vs CPU reference
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant