ZUNA (Zyphra/ZUNA) is a 382M-parameter masked diffusion autoencoder for EEG signals. Ported to AWS Neuron using `torch_neuronx.trace()` with SDPA replacement for `flex_attention`, achieving 28x speedup over GPU (A10G) and 106x over CPU.

- Encoder-decoder architecture with 4D axial RoPE, rectified flow diffusion
- 50-step pipeline: 102ms latency, 9.80 samples/sec on inf2.xlarge
- DataParallel: 16.29 samples/sec (1.66x scaling on 2 NeuronCores)
- Cosine similarity: 0.990 (auto-cast=matmult), 1.000000 (pure FP32)
- Integration tests: accuracy, performance, DataParallel validation
- Comprehensive executed notebook with full walkthrough
The measured minimum across 50 seeds is 0.937; seed 7 gives 0.965. Setting the threshold at 0.930 ensures all seeds pass while remaining meaningful. All 15 tests pass on inf2.xlarge.
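For reference, the cosine-similarity metric these thresholds apply to can be sketched in plain Python (the real tests presumably compute it on flattened model output tensors; the vectors below are illustrative only):

```python
import math

def cosine_sim(a, b):
    """Cosine similarity between two flat vectors; 1.0 means identical direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm

reference = [0.5, -1.2, 3.3, 0.07]
scaled = [v * 1.001 for v in reference]   # pure rescaling leaves cosine at 1.0
assert abs(cosine_sim(reference, scaled) - 1.0) < 1e-9
assert cosine_sim(reference, scaled) > 0.930   # would clear the 0.930 threshold
```

Because the metric is scale-invariant, it isolates directional error (e.g. from reduced-precision matmuls) from any uniform magnitude drift.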
- README: Add per-seed auto-cast comparison table showing that outlier seeds (e.g. seed 7: 0.966 with matmult) become perfect (1.000000) without it. Confirms all error is from BF16 matmul, not compiler or SDPA replacement. - Tests: Add TestNoAutocast class (5 parametrized seeds) that compiles with --auto-cast=none and asserts cosine_sim >= 0.999. All 20 tests pass on inf2.xlarge (15 original + 5 no-autocast).
26 unit tests in test/unit/test_patching.py covering:

- flex_attention patch installation (5 tests)
- Model loading and config validation (4 tests)
- SDPA monkey-patching of all 48 attention modules (2 tests)
- Synthetic input generation (4 tests)
- Encoder/decoder wrapper shapes and finiteness (6 tests)
- Diffusion loop determinism and correctness (5 tests)

All run on CPU in ~30s with no Neuron hardware required.
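The dummy-symbol patching these tests exercise can be sketched in pure Python. The module name below is a fake placeholder so the sketch runs without torch installed; the real patch targets `torch.nn.attention.flex_attention`:

```python
import sys
import types

# Illustrative stand-in: register a stub module under the name the model
# imports, exposing the symbols it references, BEFORE the model is imported.
MODULE_NAME = "fake_flex_attention"   # placeholder name for this sketch

stub = types.ModuleType(MODULE_NAME)

def _flex_attention_stub(*args, **kwargs):
    # flex_attention itself must never execute; calling the stub is a bug,
    # since every attention module is rerouted to SDPA instead.
    raise RuntimeError("flex_attention is unsupported here; SDPA patch should be active")

stub.flex_attention = _flex_attention_stub
stub.create_block_mask = lambda *args, **kwargs: None   # masks become no-ops

sys.modules[MODULE_NAME] = stub        # install before any model import

import fake_flex_attention             # the model-side import now resolves to the stub
assert fake_flex_attention.create_block_mask() is None
```

This mirrors the two properties the unit tests assert: the dummy mask helpers are callable no-ops, and calling `flex_attention` directly raises.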
Description
Add ZUNA (Zyphra/ZUNA), a 382M-parameter masked diffusion autoencoder for EEG signals, to the contrib directory. The model is ported to Neuron using `torch_neuronx.trace()` with SDPA replacement for PyTorch's `flex_attention` API (unsupported on XLA). The 50-step rectified flow diffusion pipeline achieves 102ms latency on inf2.xlarge -- 28x faster than GPU (A10G) and 106x faster than CPU.

Model Information
Model Name: ZUNA
Model Architecture: Encoder-decoder masked diffusion autoencoder (16-layer encoder with self-attention + SwiGLU + register interleaving + MMD bottleneck, 16-layer decoder with cross-attention + self-attention + AdaRMSNorm timestep conditioning, 4D axial RoPE, rectified flow with Euler ODE solver)
Purpose: EEG signal reconstruction / brain-computer interface foundation model
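As a minimal illustration of the rectified-flow sampling named in the architecture above, here is a fixed-step Euler ODE solve of dx/dt = v(x, t) over t in [0, 1]. The toy velocity field stands in for the decoder, which is an assumption for illustration only:

```python
def euler_sample(velocity, x0, num_steps=50):
    """Fixed-step Euler integration of dx/dt = velocity(x, t) over t in [0, 1],
    the scheme a rectified-flow sampler uses (50 steps, as in the pipeline)."""
    x = list(x0)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        v = velocity(x, t)                       # model's velocity prediction
        x = [xi + vi * dt for xi, vi in zip(x, v)]
    return x

# Toy field pulling x toward a fixed target; the real model predicts the
# velocity with its timestep-conditioned decoder instead.
target = [1.0, -2.0]
x1 = euler_sample(lambda x, t: [ti - xi for xi, ti in zip(x, target)], [0.0, 0.0])
# after 50 steps, x1 has moved most of the way toward target
```

Each extra step refines the trajectory, which is why the step count (50 here) trades latency against sample quality.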
Checklist
Required Components
- Integration tests (test/integration/test_model.py)
- Source code (src/)

Optional Components
- Unit tests in the test/unit/ directory

Folder Structure
Confirm your contribution follows this structure:

```
/contrib/models/ZUNA/
    README.md
    zuna_neuron_comprehensive.ipynb   (executed, with cell outputs)
    /src
        __init__.py
        modeling_zuna.py
    /test
        __init__.py
        /unit
            __init__.py
            test_patching.py
        /integration
            __init__.py
            test_model.py
```
Testing
How did you test this change?
All tests were run on an inf2.xlarge instance (1 Inferentia2 chip, 2 NeuronCores) in sa-east-1 using the Deep Learning AMI Neuron (Ubuntu 24.04) with Neuron SDK 2.27 and the pre-installed PyTorch inference venv.
The test suite includes 46 tests across 12 classes in 2 files:
- Unit tests (test/unit/test_patching.py) -- 26 CPU-only tests, ~30s
- Integration tests (test/integration/test_model.py) -- 20 Neuron tests, ~7min:
  - accuracy with --auto-cast=matmult, MSE bounds
  - TestNoAutocast with --auto-cast=none, verifies perfect cosine similarity (>=0.999) across all seeds -- confirms all accuracy loss comes from BF16 matmul conversion, not the compiler or SDPA replacement

Test Results:
================= 46 passed, 12 warnings in 454.16s (0:07:34) ==================
test_patching.py::TestFlexAttentionPatch::test_module_exists PASSED
test_patching.py::TestFlexAttentionPatch::test_create_block_mask_callable PASSED
test_patching.py::TestFlexAttentionPatch::test_noop_mask_callable PASSED
test_patching.py::TestFlexAttentionPatch::test_mask_mod_signature_exists PASSED
test_patching.py::TestFlexAttentionPatch::test_flex_attention_raises PASSED
test_patching.py::TestModelLoading::test_model_loads PASSED
test_patching.py::TestModelLoading::test_parameter_count PASSED
test_patching.py::TestModelLoading::test_model_in_eval_mode PASSED
test_patching.py::TestModelLoading::test_model_args_input_dim PASSED
test_patching.py::TestPatching::test_self_attention_patched PASSED
test_patching.py::TestPatching::test_cross_attention_patched PASSED
test_patching.py::TestSyntheticInput::test_shapes PASSED
test_patching.py::TestSyntheticInput::test_tok_idx_dtype PASSED
test_patching.py::TestSyntheticInput::test_deterministic PASSED
test_patching.py::TestSyntheticInput::test_different_seeds PASSED
test_patching.py::TestEncoderWrapper::test_output_shape PASSED
test_patching.py::TestEncoderWrapper::test_output_finite PASSED
test_patching.py::TestEncoderWrapper::test_tok_idx_2d_auto_unsqueeze PASSED
test_patching.py::TestDecoderWrapper::test_output_shape PASSED
test_patching.py::TestDecoderWrapper::test_output_finite PASSED
test_patching.py::TestDecoderWrapper::test_different_timesteps_different_output PASSED
test_patching.py::TestDiffusionLoop::test_output_shape PASSED
test_patching.py::TestDiffusionLoop::test_output_finite PASSED
test_patching.py::TestDiffusionLoop::test_deterministic_with_seed PASSED
test_patching.py::TestDiffusionLoop::test_different_seeds_different_output PASSED
test_patching.py::TestDiffusionLoop::test_more_steps_changes_output PASSED
test_model.py::TestModelLoads::test_encoder_loads PASSED
test_model.py::TestModelLoads::test_decoder_loads PASSED
test_model.py::TestModelLoads::test_encoder_runs PASSED
test_model.py::TestModelLoads::test_decoder_runs PASSED
test_model.py::TestModelLoads::test_full_pipeline PASSED
test_model.py::TestAccuracy::test_cosine_similarity0 seed=0: cosine_sim=0.981687 PASSED
test_model.py::TestAccuracy::test_cosine_similarity7 seed=7: cosine_sim=0.965521 PASSED
test_model.py::TestAccuracy::test_cosine_similarity13 seed=13: cosine_sim=0.999861 PASSED
test_model.py::TestAccuracy::test_cosine_similarity42 seed=42: cosine_sim=0.998712 PASSED
test_model.py::TestAccuracy::test_cosine_similarity99 seed=99: cosine_sim=0.999337 PASSED
test_model.py::TestAccuracy::test_mse_bounded MSE: 0.00003314 PASSED
test_model.py::TestNoAutocast::test_cosine_similarity_noautocast0 cosine_sim=1.000000 PASSED
test_model.py::TestNoAutocast::test_cosine_similarity_noautocast7 cosine_sim=1.000000 PASSED
test_model.py::TestNoAutocast::test_cosine_similarity_noautocast13 cosine_sim=1.000000 PASSED
test_model.py::TestNoAutocast::test_cosine_similarity_noautocast42 cosine_sim=1.000000 PASSED
test_model.py::TestNoAutocast::test_cosine_similarity_noautocast99 cosine_sim=1.000000 PASSED
test_model.py::TestDataParallel::test_data_parallel_runs PASSED
test_model.py::TestDataParallel::test_data_parallel_speedup 1.63x speedup PASSED
test_model.py::TestPerformance::test_pipeline_throughput 9.77 samples/sec PASSED
test_model.py::TestPerformance::test_pipeline_latency p50: 102.2 ms PASSED
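As a quick sanity check, the two performance numbers above are mutually consistent: for a single serial caller, throughput is roughly the inverse of p50 latency.

```python
# 102.2 ms p50 latency implies about 1 / 0.1022 ≈ 9.78 samples/sec for one
# serial caller, which lines up with the measured 9.77 samples/sec.
p50_seconds = 0.1022
implied_throughput = 1.0 / p50_seconds
print(f"{implied_throughput:.2f} samples/sec")
```

The DataParallel figure (1.63x) sits below the ideal 2x for 2 NeuronCores, consistent with per-call host-side overhead that does not parallelize.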
Compatibility
Tested with:
Additional Information
Porting challenge: flex_attention. ZUNA uses PyTorch's flex_attention API, which is not supported on Neuron's XLA device. The solution patches the torch.nn.attention.flex_attention module with dummy symbols before importing ZUNA, then monkey-patches all 48 attention modules (32 self-attention + 16 cross-attention) to use F.scaled_dot_product_attention. Since sliding_window=65536 far exceeds the inference sequence length (~100), full attention without masking is mathematically correct.

auto-cast accuracy tradeoff. With --auto-cast=matmult, mean cosine similarity is 0.990 (min 0.937 over 50 seeds). With --auto-cast=none, cosine similarity is 1.000000 for all seeds, at ~2x latency cost (212ms vs 102ms). The TestNoAutocast class explicitly verifies this, confirming that all accuracy loss originates from BF16 matmul conversion, not the compiler or SDPA replacement.

Cross-platform benchmarks. The 50-step diffusion pipeline runs at 102ms on Neuron (inf2.xlarge), compared to 2,854ms on GPU (A10G g5.xlarge) and 10,856ms on CPU -- 28x and 106x faster respectively.
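To make the BF16 point concrete, here is a rough emulation of bfloat16 matmul inputs in pure Python. It truncates the float32 mantissa to bfloat16 width; real hardware rounds to nearest even, so this slightly overstates the error, and the vector below is synthetic rather than a model activation:

```python
import math
import struct

def to_bf16(x):
    """Emulate bfloat16 by keeping only the top 16 bits of the float32
    encoding (sign, 8 exponent bits, 7 mantissa bits); truncation, not
    round-to-nearest, so this is a pessimistic approximation."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def dot(a, b, cast=lambda v: v):
    """Dot product with an optional per-operand precision cast."""
    return sum(cast(x) * cast(y) for x, y in zip(a, b))

vec = [math.sin(0.7 * i) for i in range(64)]
exact = dot(vec, vec)
lossy = dot(vec, vec, cast=to_bf16)
rel_err = abs(exact - lossy) / exact
# rel_err is small but nonzero -- the same kind of gap --auto-cast=matmult
# introduces and --auto-cast=none removes.
```

With only 7 mantissa bits, each operand carries up to ~0.8% relative error, which accumulates into the sub-1.0 cosine similarities seen under matmult auto-cast.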
Related Issues
N/A
vLLM Integration
By submitting this PR, I confirm that: