Contrib/zuna#59

Open
jimburtoft wants to merge 5 commits into aws-neuron:main from jimburtoft:contrib/zuna

Conversation

@jimburtoft

Description

Add ZUNA (Zyphra/ZUNA), a 382M-parameter masked diffusion autoencoder for EEG signals, to the contrib directory. The model is ported to Neuron using torch_neuronx.trace() with SDPA replacement for PyTorch's flex_attention API (unsupported on XLA). The 50-step rectified flow diffusion pipeline achieves 102ms latency on inf2.xlarge -- 28x faster than GPU (A10G) and 106x faster than CPU.

Model Information

Model Name: ZUNA
Model Architecture: Encoder-decoder masked diffusion autoencoder (16-layer encoder with self-attention + SwiGLU + register interleaving + MMD bottleneck, 16-layer decoder with cross-attention + self-attention + AdaRMSNorm timestep conditioning, 4D axial RoPE, rectified flow with Euler ODE solver)
Purpose: EEG signal reconstruction / brain-computer interface foundation model
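The 50-step rectified-flow pipeline integrates an ODE with a fixed-step Euler solver. A stdlib-only sketch of that loop follows; `velocity` is a hypothetical stand-in for the traced Neuron decoder, and the linear toy field is purely illustrative:

```python
# Sketch of a rectified-flow sampler with a fixed-step Euler ODE solver,
# mirroring the 50-step loop described above. The real pipeline would call
# the traced Neuron decoder at each step; `velocity` is a hypothetical
# stand-in so the loop is runnable here.

def euler_rectified_flow(velocity, x0, num_steps=50):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with Euler steps."""
    x = list(x0)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v = velocity(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x

# Toy linear field v(x, t) = target - x: each Euler step moves x a fraction
# dt of the way toward the target.
target = [1.0, -2.0, 0.5]
out = euler_rectified_flow(lambda x, t: [ti - xi for ti, xi in zip(target, x)],
                           x0=[0.0, 0.0, 0.0])
```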

Checklist

Required Components

  • Accuracy Test (ex. test/integration/test_model.py)
    • At least one integration test that validates model accuracy
    • Uses logit validation or equivalent accuracy verification
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types (Trn1/Trn2/Inf2)
    • Example Checkpoints: Links to compatible model checkpoints (e.g., HuggingFace Hub)
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Tests for individual modeling components
    • Located in test/unit/ directory

Folder Structure

Confirm your contribution follows this structure:
/contrib/models/ZUNA/
  README.md
  zuna_neuron_comprehensive.ipynb (executed, with cell outputs)
  /src
    __init__.py
    modeling_zuna.py
  /test
    __init__.py
    /unit
      __init__.py
      test_patching.py
    /integration
      __init__.py
      test_model.py

Testing

How did you test this change?
All tests were run on an inf2.xlarge instance (1 Inferentia2 chip, 2 NeuronCores) in sa-east-1 using the Deep Learning AMI Neuron (Ubuntu 24.04) with Neuron SDK 2.27 and the pre-installed PyTorch inference venv.
The test suite includes 46 tests across 12 classes in 2 files:
Unit tests (test/unit/test_patching.py) -- 26 CPU-only tests, ~30s:

  • TestFlexAttentionPatch (5): Dummy flex_attention symbols installed correctly
  • TestModelLoading (4): HuggingFace load, ~382M params, eval mode, config
  • TestPatching (2): All 32 Attention + 16 CrossAttention modules patched to SDPA
  • TestSyntheticInput (4): Input shapes, dtype, determinism, seed independence
  • TestEncoderWrapper (3): Output shape, finiteness, 2D tok_idx auto-unsqueeze
  • TestDecoderWrapper (3): Output shape, finiteness, timestep sensitivity
  • TestDiffusionLoop (5): Shape, finiteness, determinism, seed/step sensitivity

Integration tests (test/integration/test_model.py) -- 20 Neuron tests, ~7min:

  • TestModelLoads (5): Encoder/decoder compile, run, output shape, full 50-step pipeline
  • TestAccuracy (6): Cosine similarity across 5 seeds with --auto-cast=matmult, MSE bounds
  • TestNoAutocast (5): Compiles with --auto-cast=none, verifies perfect cosine similarity (>=0.999) across all seeds -- confirms all accuracy loss comes from BF16 matmul conversion, not compiler or SDPA replacement
  • TestDataParallel (2): Multi-core execution and 1.63x speedup validation
  • TestPerformance (2): Throughput (>5 samples/sec) and latency (<200ms) thresholds
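The TestPerformance gates are simple aggregate checks; a stdlib sketch of how the p50-latency and throughput thresholds might be evaluated (sample values are illustrative, not the PR's measurements):

```python
import statistics

def p50_latency_ms(latencies_ms):
    """Median (p50) of per-iteration latencies in milliseconds."""
    return statistics.median(latencies_ms)

def throughput(num_samples, total_seconds):
    """Samples processed per second over a timed run."""
    return num_samples / total_seconds

# Illustrative measurements, not the PR's actual numbers.
latencies = [101.8, 102.2, 103.0, 99.7, 102.5]
p50 = p50_latency_ms(latencies)
rate = throughput(50, 5.12)

assert p50 < 200.0   # latency gate from TestPerformance
assert rate > 5.0    # throughput gate from TestPerformance
```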

Test Results:

================= 46 passed, 12 warnings in 454.16s (0:07:34) ==================
    test_patching.py::TestFlexAttentionPatch::test_module_exists PASSED
    test_patching.py::TestFlexAttentionPatch::test_create_block_mask_callable PASSED
    test_patching.py::TestFlexAttentionPatch::test_noop_mask_callable PASSED
    test_patching.py::TestFlexAttentionPatch::test_mask_mod_signature_exists PASSED
    test_patching.py::TestFlexAttentionPatch::test_flex_attention_raises PASSED
    test_patching.py::TestModelLoading::test_model_loads PASSED
    test_patching.py::TestModelLoading::test_parameter_count PASSED
    test_patching.py::TestModelLoading::test_model_in_eval_mode PASSED
    test_patching.py::TestModelLoading::test_model_args_input_dim PASSED
    test_patching.py::TestPatching::test_self_attention_patched PASSED
    test_patching.py::TestPatching::test_cross_attention_patched PASSED
    test_patching.py::TestSyntheticInput::test_shapes PASSED
    test_patching.py::TestSyntheticInput::test_tok_idx_dtype PASSED
    test_patching.py::TestSyntheticInput::test_deterministic PASSED
    test_patching.py::TestSyntheticInput::test_different_seeds PASSED
    test_patching.py::TestEncoderWrapper::test_output_shape PASSED
    test_patching.py::TestEncoderWrapper::test_output_finite PASSED
    test_patching.py::TestEncoderWrapper::test_tok_idx_2d_auto_unsqueeze PASSED
    test_patching.py::TestDecoderWrapper::test_output_shape PASSED
    test_patching.py::TestDecoderWrapper::test_output_finite PASSED
    test_patching.py::TestDecoderWrapper::test_different_timesteps_different_output PASSED
    test_patching.py::TestDiffusionLoop::test_output_shape PASSED
    test_patching.py::TestDiffusionLoop::test_output_finite PASSED
    test_patching.py::TestDiffusionLoop::test_deterministic_with_seed PASSED
    test_patching.py::TestDiffusionLoop::test_different_seeds_different_output PASSED
    test_patching.py::TestDiffusionLoop::test_more_steps_changes_output PASSED
    test_model.py::TestModelLoads::test_encoder_loads PASSED
    test_model.py::TestModelLoads::test_decoder_loads PASSED
    test_model.py::TestModelLoads::test_encoder_runs PASSED
    test_model.py::TestModelLoads::test_decoder_runs PASSED
    test_model.py::TestModelLoads::test_full_pipeline PASSED
    test_model.py::TestAccuracy::test_cosine_similarity0 seed=0: cosine_sim=0.981687 PASSED
    test_model.py::TestAccuracy::test_cosine_similarity7 seed=7: cosine_sim=0.965521 PASSED
    test_model.py::TestAccuracy::test_cosine_similarity13 seed=13: cosine_sim=0.999861 PASSED
    test_model.py::TestAccuracy::test_cosine_similarity42 seed=42: cosine_sim=0.998712 PASSED
    test_model.py::TestAccuracy::test_cosine_similarity99 seed=99: cosine_sim=0.999337 PASSED
    test_model.py::TestAccuracy::test_mse_bounded MSE: 0.00003314 PASSED
    test_model.py::TestNoAutocast::test_cosine_similarity_noautocast0 cosine_sim=1.000000 PASSED
    test_model.py::TestNoAutocast::test_cosine_similarity_noautocast7 cosine_sim=1.000000 PASSED
    test_model.py::TestNoAutocast::test_cosine_similarity_noautocast13 cosine_sim=1.000000 PASSED
    test_model.py::TestNoAutocast::test_cosine_similarity_noautocast42 cosine_sim=1.000000 PASSED
    test_model.py::TestNoAutocast::test_cosine_similarity_noautocast99 cosine_sim=1.000000 PASSED
    test_model.py::TestDataParallel::test_data_parallel_runs PASSED
    test_model.py::TestDataParallel::test_data_parallel_speedup 1.63x speedup PASSED
    test_model.py::TestPerformance::test_pipeline_throughput 9.77 samples/sec PASSED
    test_model.py::TestPerformance::test_pipeline_latency p50: 102.2 ms PASSED

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.27
  • Instance Type(s): inf2.xlarge
  • PyTorch Version: 2.9
  • Python Version: 3.12.3

Additional Information

Porting challenge: flex_attention. ZUNA uses PyTorch's flex_attention API which is not supported on Neuron's XLA device. The solution patches the torch.nn.attention.flex_attention module with dummy symbols before importing ZUNA, then monkey-patches all 48 attention modules (32 self-attention + 16 cross-attention) to use F.scaled_dot_product_attention. Since sliding_window=65536 far exceeds the inference sequence length (~100), full attention without masking is mathematically correct.
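The two-step pattern described above (shim the import, then rebind the attention forwards) can be sketched with plain-Python stand-ins so it runs without torch; all class and function names below are illustrative, not ZUNA's actual code:

```python
import sys
import types

# --- Step 1: shim flex_attention before the model code is imported ---
# A dummy module is registered under the flex_attention import path so that
# `from torch.nn.attention.flex_attention import ...` resolves; the real API
# is unsupported on Neuron's XLA device, so the dummy symbol raises if hit.

def _flex_attention_stub(*args, **kwargs):
    raise RuntimeError("flex_attention is not supported on Neuron/XLA")

shim = types.ModuleType("torch.nn.attention.flex_attention")
shim.flex_attention = _flex_attention_stub
shim.create_block_mask = lambda *args, **kwargs: None  # no-op mask builder
sys.modules[shim.__name__] = shim

# --- Step 2: rebind attention forwards to an SDPA implementation ---
# Plain-Python stand-ins for the attention modules keep this runnable; the
# real patch rebinds forward to F.scaled_dot_product_attention.

class Attention:                    # stand-in for a self-attention module
    def forward(self, x):
        return ("flex", x)

class CrossAttention(Attention):    # stand-in for a cross-attention module
    pass

def sdpa_forward(self, x):
    # Real version would call F.scaled_dot_product_attention(q, k, v) with
    # no mask, valid because sliding_window (65536) >> sequence length (~100).
    return ("sdpa", x)

def patch_attention(modules):
    """Rebind forward on every (Cross)Attention instance; return the count."""
    patched = 0
    for m in modules:
        if isinstance(m, Attention):
            m.forward = types.MethodType(sdpa_forward, m)
            patched += 1
    return patched

mods = [Attention() for _ in range(32)] + [CrossAttention() for _ in range(16)]
count = patch_attention(mods)       # 48 modules, matching the PR's 32 + 16
```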
Auto-cast accuracy tradeoff. With --auto-cast=matmult, mean cosine similarity is 0.990 (min 0.937 over 50 seeds). With --auto-cast=none, cosine similarity is 1.000000 for all seeds, at ~2x latency cost (212ms vs 102ms). The TestNoAutocast class explicitly verifies this, confirming that all accuracy loss originates from BF16 matmul conversion, not the compiler or SDPA replacement.
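For reference, the cosine-similarity and MSE numbers quoted above are standard flattened-vector comparisons; a stdlib-only sketch (the vectors are illustrative, not model outputs):

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two flattened output vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) *
                  math.sqrt(sum(y * y for y in b)))

def mse(a, b):
    """Mean squared error between reference and device outputs."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

ref = [0.1, -0.5, 0.25, 0.9]       # stand-in for the CPU FP32 reference
exact = list(ref)                  # --auto-cast=none: identical output
approx = [0.1, -0.5, 0.26, 0.9]    # --auto-cast=matmult: small BF16 error

sim_exact = cosine_similarity(ref, exact)    # ~1.0 (identical vectors)
sim_approx = cosine_similarity(ref, approx)  # slightly below 1.0
```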
Cross-platform benchmarks. The 50-step diffusion pipeline runs at 102ms on Neuron (inf2.xlarge), compared to 2,854ms on GPU (A10G g5.xlarge) and 10,856ms on CPU -- 28x and 106x faster respectively.

Related Issues

N/A

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

jimburtoft and others added 5 commits March 5, 2026 23:17

ZUNA (Zyphra/ZUNA) is a 382M-parameter masked diffusion autoencoder for EEG
signals. Ported to AWS Neuron using torch_neuronx.trace() with SDPA replacement
for flex_attention, achieving 28x speedup over GPU (A10G) and 106x over CPU.

- Encoder-decoder architecture with 4D axial RoPE, rectified flow diffusion
- 50-step pipeline: 102ms latency, 9.80 samples/sec on inf2.xlarge
- DataParallel: 16.29 samples/sec (1.66x scaling on 2 NeuronCores)
- Cosine similarity: 0.990 (auto-cast=matmult), 1.000000 (pure FP32)
- Integration tests: accuracy, performance, DataParallel validation
- Comprehensive executed notebook with full walkthrough

Measured min across 50 seeds is 0.937; seed 7 gives 0.965. Setting the
threshold at 0.930 ensures all seeds pass while remaining meaningful. All 15
tests pass on inf2.xlarge.

- README: Add per-seed auto-cast comparison table showing that outlier seeds
  (e.g. seed 7: 0.966 with matmult) become perfect (1.000000) without it.
  Confirms all error is from BF16 matmul, not compiler or SDPA replacement.
- Tests: Add TestNoAutocast class (5 parametrized seeds) that compiles with
  --auto-cast=none and asserts cosine_sim >= 0.999. All 20 tests pass on
  inf2.xlarge (15 original + 5 no-autocast).

26 unit tests in test/unit/test_patching.py covering:
- flex_attention patch installation (5 tests)
- Model loading and config validation (4 tests)
- SDPA monkey-patching of all 48 attention modules (2 tests)
- Synthetic input generation (4 tests)
- Encoder/decoder wrapper shapes and finiteness (6 tests)
- Diffusion loop determinism and correctness (5 tests)

All run on CPU in ~30s with no Neuron hardware required.
