Add Boltz-2 contrib model with NKI kernels for pairformer inference #64
Open
jimburtoft wants to merge 2 commits into aws-neuron:main
Boltz-2 biomolecular structure prediction on Trainium 2 using custom NKI kernels for the O(N^3) triangular attention and multiplicative update operations. Uses torch_neuronx.trace() with weight replacement pattern to compile 64 pairformer layers in 7.4 minutes. Validated on trn2.3xlarge at N=256 (s_cos=0.999220, z_cos=0.943929) and trn2.48xlarge at N=512 (s_cos=0.999460, z_cos=0.979214).
Add SPMD grid=[2] NKI mega-kernel that fuses all 7 sub-operations of a PairformerLayer into a single kernel call, eliminating host-device round trips. Validated at N=256: s_cos=0.999995, z_cos=0.999245.
New files:
- src/fused_z_ops_spmd.py: z-only fused operations (TriMul, TriAttn, Transition_z)
- src/full_pairformer_layer_spmd.py: full layer kernel (PBA + all z-ops + Transition_s)
- test/integration/compile_full_layer_spmd.py: compilation script
- test/integration/test_full_layer_spmd.py: correctness test vs CPU reference
Boltz-2 biomolecular structure prediction on AWS Trainium 2. This is a novel contrib that uses custom NKI (Neuron Kernel Interface) kernels for the four O(N^3) triangular operations in the pairformer trunk, compiled via torch_neuronx.trace() with the weight replacement pattern to handle all 64 pairformer layers in 7.4 minutes of setup time.
Boltz-2 predicts 3D structures of proteins, RNA, DNA, and small molecules. The pairformer trunk (64 transformer layers with triangular attention and triangular multiplicative updates) is the computational bottleneck. Two custom NKI kernels, one for triangular attention and one for the triangular multiplicative update, replace these expensive operations with implementations that run directly on the NeuronCore. Because all 64 layers share an identical graph, the weight replacement pattern (inline_weights_to_neff=False + replace_weights()) compiles that graph once and swaps in each layer's weights instead of recompiling per layer.
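The weight replacement idea can be illustrated without Neuron hardware. The sketch below is a plain-Python analogy, not the code in this PR: `Compiler`, `CompiledGraph`, and `run_trunk` are hypothetical names, and the "layer" is a toy scale-and-bias function. It only shows the control flow of compiling one graph and reusing it with a different weight set per layer; on a real device that role is played by tracing with inline_weights_to_neff=False and then calling replace_weights().

```python
class CompiledGraph:
    """Hypothetical stand-in for a compiled artifact whose weights are kept outside it."""

    def __init__(self):
        self.weights = None

    def replace_weights(self, weights):
        # Swap in a new weight set without touching the compiled graph.
        self.weights = dict(weights)

    def __call__(self, x):
        # Toy layer: the same computation for every layer, only the weights differ.
        return [self.weights["scale"] * v + self.weights["bias"] for v in x]


class Compiler:
    """Hypothetical compiler that counts how often the expensive step runs."""

    def __init__(self):
        self.calls = 0

    def compile(self):
        self.calls += 1
        return CompiledGraph()


def run_trunk(compiler, x, layer_weights):
    """Run every layer through one compiled graph, replacing weights per layer."""
    graph = compiler.compile()  # the expensive step happens exactly once
    for weights in layer_weights:
        graph.replace_weights(weights)
        x = graph(x)
    return x
```

With 64 weight sets, `run_trunk` still triggers a single `compile()` call, which is the property the pattern relies on when every layer has an identical graph.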
Model Information
Model Name: Boltz-2 (v2.2.1)
Model Architecture: 64-layer pairformer trunk (triangular attention, triangular multiplicative updates, pair bias attention) + diffusion score model for coordinate generation. 507M parameters.
Purpose: Biomolecular structure prediction (proteins, RNA, DNA, small molecules)
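To make the cubic cost of the triangular operations concrete, here is a naive reference for the core contraction of an "outgoing" triangular multiplicative update, written as a sketch in plain Python. It assumes the standard formulation `out[i][j][c] = sum_k a[i][k][c] * b[j][k][c]` and deliberately omits the layer norm, gating, and projections that surround it in a real layer. The function names are illustrative and do not come from this PR.

```python
def triangle_mul_outgoing(a, b):
    """Naive contraction: out[i][j][c] = sum over k of a[i][k][c] * b[j][k][c].

    a and b are nested lists of shape [N][N][C].
    """
    n = len(a)
    channels = len(a[0][0])
    out = [[[0.0] * channels for _ in range(n)] for _ in range(n)]
    for i in range(n):
        for j in range(n):
            for k in range(n):  # the third loop over N is what makes this O(N^3)
                for c in range(channels):
                    out[i][j][c] += a[i][k][c] * b[j][k][c]
    return out


def multiply_add_count(n, channels):
    """Number of multiply-adds performed by the naive contraction."""
    return n * n * n * channels
```

Doubling the sequence length from N=256 to N=512 multiplies the work in this contraction by 8, which is why these operations dominate at larger N.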
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md (../contrib/CONTRIBUTING.md) for detailed guidelines.
Required Components
Optional Components
Folder Structure
Confirm your contribution follows this structure:
/contrib/models/Boltz-2/
  README.md
  /src
    __init__.py
    modeling_boltz2.py
    nki_triangular_attention.py
    nki_triangular_mul.py
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py
Testing
How did you test this change?
Tested on trn2.3xlarge (Neuron SDK 2.27 and 2.28) and trn2.48xlarge (SDK 2.28). Full 64-layer pairformer compiled and validated at N=128, N=256, and N=512 with trained Boltz-2 weights loaded from the official checkpoint.
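The s_cos and z_cos figures reported for this PR are not defined in the text. A reasonable reading, assumed here, is that they are cosine similarities between the flattened single (s) and pair (z) representations produced on Neuron and those produced by a CPU reference. Under that assumption, a minimal dependency-free sketch of the metric looks like this; `cosine_similarity` is an illustrative helper, not a function from the PR.

```python
import math


def cosine_similarity(reference, candidate):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(reference) != len(candidate):
        raise ValueError("inputs must have the same length")
    dot = sum(r * c for r, c in zip(reference, candidate))
    norm_ref = math.sqrt(sum(r * r for r in reference))
    norm_cand = math.sqrt(sum(c * c for c in candidate))
    if norm_ref == 0.0 or norm_cand == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_ref * norm_cand)
```

A value of 1.0 means the two outputs point in exactly the same direction. The metric ignores overall scale, so it is usually paired with an absolute or relative error check when magnitudes matter.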
Test Results:
Compatibility
Tested with:
Additional Information
This is the first NxDI contrib to use NKI (Neuron Kernel Interface) custom kernels. The approach differs from typical contribs in that it uses torch_neuronx.trace() directly rather than NxDI model classes (NeuronBaseModel/NeuronBaseForCausalLM), because Boltz-2's pairformer architecture (triangular attention with full 2D bias, triangular multiplicative updates) has no equivalent in the NxDI model zoo.
Key limitations:
vLLM Integration
By submitting this PR, I confirm that: