
Add Kokoro-82M TTS contrib: 82M-param text-to-speech on Neuron (trn2 …#71

Open
jimburtoft wants to merge 5 commits into aws-neuron:main from jimburtoft:contrib/kokoro-82m

Conversation

@jimburtoft

Description

Adds hexgrad/Kokoro-82M, an 82M-parameter text-to-speech model (StyleTTS 2 + ISTFTNet, Apache 2.0), to NxDI contrib. This is the first non-LLM contrib model: it uses torch_neuronx.trace() with a 3-part architecture split to work around XLA tracer incompatibilities. It generates 24kHz audio at 60-80x real-time on trn2.3xlarge.
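The 3-part split can be illustrated with a minimal sketch. The submodules below are hypothetical stand-ins for the three traced parts (duration predictor, decoder, vocoder), and torch.jit.trace is used here as a CPU analog of torch_neuronx.trace, which takes the same (module, example inputs) shape:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for one traceable part of the model; the real
# parts come from splitting Kokoro-82M's architecture.
class Part(nn.Module):
    def forward(self, x):
        return x * 2.0

parts = [Part(), Part(), Part()]
example = torch.randn(1, 64)

# On Neuron this would be torch_neuronx.trace(part, example);
# torch.jit.trace serves as a CPU analog for illustration.
traced = [torch.jit.trace(p, example) for p in parts]

# Inference runs the three traced parts in sequence.
y = example
for t in traced:
    y = t(y)
```

Splitting at module boundaries keeps each traced graph small enough to avoid the tracer incompatibilities described under "Key Technical Challenges Solved" below.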

Model Information

Model Name: Kokoro-82M
Model Architecture: Parallel (non-autoregressive) TTS: ALBERT duration predictor + StyleTTS 2 decoder + ISTFTNet vocoder
Purpose: Text-to-speech synthesis (text input → 24kHz mono audio output)

Checklist

Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.

Required Components

  • Accuracy Test (ex. test/integration/test_model.py)
    • 15 integration tests validating compilation, inference, accuracy (cosine > 0.95, SNR > 10dB), voice styles, and performance
    • Tests compile and run the model on Neuron (trn2.3xlarge validated, inf2.xlarge validated)
    • Accuracy verified via cosine similarity against CPU reference: 0.985 on trn2, 0.993 on inf2
  • README.md with the following sections:
    • Usage Example: Clear code example showing compile, save, load, generate, and audio export
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types (trn2.3xlarge, inf2.xlarge)
    • Example Checkpoints: Link to hexgrad/Kokoro-82M
    • Testing Instructions: Command to run pytest and standalone test
  • Source Code (src/)
    • kokoro_neuron.py: 769-line KokoroNeuron class with compile, save, load, generate, warmup
    • All monkey-patches applied at import time (5 XLA workarounds)
    • Bucket-based compilation for variable-length inputs (6 bucket sizes)
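Bucket-based compilation means each input is padded up to the smallest pre-compiled frame count that fits it. A plain-Python sketch of that selection logic (bucket sizes are the six from this PR; the helper name is ours, not necessarily the source's):

```python
BUCKETS = [32, 64, 96, 128, 160, 192]  # frame counts compiled ahead of time

def select_bucket(n_frames, buckets=BUCKETS):
    """Return the smallest compiled bucket that fits n_frames, and the padding needed."""
    for b in buckets:
        if n_frames <= b:
            return b, b - n_frames
    raise ValueError(f"{n_frames} frames exceeds largest bucket {buckets[-1]}")

bucket, pad = select_bucket(70)
# 70 frames falls into the 96-frame bucket, padded by 26 frames
```

Pre-compiling a fixed set of shapes is the standard way to serve variable-length inputs on hardware that requires static graphs.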

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Not included (all tests require Neuron hardware due to compilation)

Folder Structure

Confirm your contribution follows this structure:

/contrib/models/Kokoro-82M/
  README.md
  /src
    __init__.py
    kokoro_neuron.py
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py

Testing

How did you test this change?
Full test suite executed on trn2.3xlarge (SDK 2.27, PyTorch 2.9, LNC=2). Compiled all 6 bucket sizes (32, 64, 96, 128, 160, 192 frames), ran 15 integration tests covering compilation, inference quality, accuracy validation, multi-voice support, and performance thresholds. Also tested on inf2.xlarge with -O1 compiler flag (4 of 6 buckets compile successfully).
Test Results:

======================== 15 passed in 26.54s ========================
TestCompilation: 4/4 PASS (model loads, sample rate, compiled files exist, NEFF sizes)
TestInference:   4/4 PASS (audio produced, range [-1,1], duration within 20%, timing info)
TestAccuracy:    2/2 PASS (cosine > 0.95, SNR > 10dB vs CPU reference)
TestVoices:      3/3 PASS (af_heart, af_sky, am_adam all produce valid audio)
TestPerformance: 2/2 PASS (>10x real-time decoder, P50 < 50ms)
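The two accuracy metrics (cosine similarity > 0.95 and SNR > 10 dB against a CPU reference) can be reproduced in a few lines of NumPy; this is a sketch of the math, not necessarily the exact implementation in test_model.py:

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between two waveforms, treated as flat vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def snr_db(reference, test):
    """Signal-to-noise ratio in dB, treating (reference - test) as noise."""
    noise = reference - test
    return float(10.0 * np.log10(np.sum(reference ** 2) / np.sum(noise ** 2)))

# Stand-in data: a CPU-reference waveform and a slightly perturbed "Neuron" output.
rng = np.random.default_rng(0)
ref = np.sin(np.linspace(0, 100, 24000))
neuron = ref + 0.01 * rng.standard_normal(24000)

cos = cosine_similarity(ref, neuron)   # well above the 0.95 threshold here
snr = snr_db(ref, neuron)              # well above the 10 dB threshold here
```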

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.27
  • Instance Type(s): trn2.3xlarge (primary), inf2.xlarge (secondary, -O1 required)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

Key Technical Challenges Solved

Five root causes were identified and fixed during onboarding:

  1. CustomSTFT boolean mask indexing crashes XLA tracer → torch.where
  2. UpSample1d F.interpolate causes XLA to silently drop input tensors → torch.repeat_interleave
  3. Generator State Buffer overflow on inf2 at -O2 → use -O1 compiler flag
  4. NKI depthwise ConvTranspose1d kernel bug (NCC_ITEN404) → element-wise decomposition
  5. Generator F.interpolate(scale_factor=300) not traceable → precompute harmonics on CPU
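Workaround 1 is worth a sketch: data-dependent boolean mask indexing produces a dynamic output shape, which the XLA tracer cannot handle, while the three-argument torch.where keeps a static shape. The tensors here are illustrative, not from CustomSTFT:

```python
import torch

x = torch.tensor([0.5, -1.0, 2.0, -0.3])
mask = x > 0

# XLA-unfriendly: output length depends on runtime mask values.
#   selected = x[mask]

# XLA-friendly equivalent: same static shape, masked-out entries zeroed.
selected = torch.where(mask, x, torch.zeros_like(x))
```

The same shape-preserving principle underlies fix 2, where F.interpolate was replaced by torch.repeat_interleave with a constant repeat count.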

Performance Summary

Instance       Latency Range   Real-Time Factor   Buckets
trn2.3xlarge   6.7-30.7 ms     60-80x             All 6 (32-192)
inf2.xlarge    7.3-41.0 ms     47-55x             4 of 6 (32, 96, 128, 160)

Known Limitations

  • Max ~2.4s audio per utterance (bucket 192 on trn2, 160 on inf2)
  • CPU preprocessing adds ~5-15ms (ALBERT duration prediction + harmonic precompute)
  • inf2 requires -O1 and cannot compile buckets 64 and 192
  • Single-utterance batch size (use DataParallel for throughput)

Related Issues

None.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions
    Not applicable -- TTS model, not an LLM.

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

…+ inf2)

First non-LLM contrib model. Kokoro-82M (hexgrad/Kokoro-82M) is a parallel
TTS model using StyleTTS 2 + ISTFTNet, compiled via torch_neuronx.trace()
with a 3-part architecture split to work around XLA tracer limitations.

Performance: 60-80x real-time on trn2.3xlarge, 47-55x on inf2.xlarge.
Accuracy: cosine 0.985-0.993 vs CPU reference. 15 integration tests.
… streaming API

- generate() now handles text of any length via KPipeline phoneme-aware chunking
- generate_stream() yields audio chunks as a Python generator for low-latency streaming
- Crossfade stitching (25ms linear) eliminates click artifacts at chunk boundaries
- Sub-chunk splitting when KPipeline chunks exceed bucket limits
- 5 new integration tests (TestLongForm): long text, streaming, audio duration, clicks, timing
- Updated standalone runner with long-form and streaming tests (8 tests total)
- All 20 pytest tests pass on trn2.3xlarge (SDK 2.27)
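The 25 ms linear crossfade at chunk boundaries corresponds to 600 samples at 24 kHz. A NumPy sketch of that stitching step (function name and details are ours, not necessarily the PR's implementation):

```python
import numpy as np

SAMPLE_RATE = 24000
FADE = int(0.025 * SAMPLE_RATE)  # 25 ms -> 600 samples

def crossfade_concat(a, b, fade=FADE):
    """Stitch two audio chunks with a linear crossfade over `fade` samples."""
    ramp = np.linspace(0.0, 1.0, fade)
    # Fade out the tail of `a` while fading in the head of `b`.
    overlap = a[-fade:] * (1.0 - ramp) + b[:fade] * ramp
    return np.concatenate([a[:-fade], overlap, b[fade:]])

a = np.ones(1200)
b = np.zeros(1200)
out = crossfade_concat(a, b)
# Output length: 1200 + 1200 - 600 = 1800 samples
```

Because the two gains sum to one at every overlap sample, the amplitude ramps smoothly instead of jumping, which is what removes the click at the boundary.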
…y sub-chunking

Critical fix: HOP_SIZE was 300 (generator only) but must be 600 to account
for decode[3]'s 2x ConvTranspose1d upsample. This caused all audio to be
truncated to half length.

Large bucket support: systematic testing found dead zones at 256-384 and
1344-2624 (SB overflow), but 512-1312 and 2688+ compile successfully.
Default buckets now [64, 128, 192, 512, 768, 1024] covering up to ~25.6s.
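The ~25.6 s figure follows directly from the corrected hop size: each frame yields HOP_SIZE audio samples, so duration = frames × 600 / 24000. A quick check over the new default buckets:

```python
HOP_SIZE = 600       # 300 (generator) x 2 (decode[3]'s ConvTranspose1d upsample)
SAMPLE_RATE = 24000  # Hz

def bucket_seconds(frames, hop=HOP_SIZE, sr=SAMPLE_RATE):
    """Audio duration in seconds produced by one bucket of `frames` frames."""
    return frames * hop / sr

durations = {b: bucket_seconds(b) for b in [64, 128, 192, 512, 768, 1024]}
# Largest default bucket: 1024 * 600 / 24000 = 25.6 s per chunk
```

With the old HOP_SIZE of 300, the same formula gives half these values, which is exactly the truncation bug described above.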

Sub-chunking improvements:
- Split at word boundaries (spaces) instead of mid-word
- Recursive sub-chunking (up to 3 levels) replaces silence fallback
- 60% target (was 80%) to leave headroom for ALBERT re-prediction

Corrected benchmark numbers after HOP_SIZE fix (audio durations were
previously reported at half their actual length). New trn2 benchmarks:
- Decoder-only: 129-181x real-time (was 60-80x)
- End-to-end with CPU: ~4-6x real-time (CPU ALBERT is the bottleneck)
- Large buckets: 512/768/1024 extend per-chunk audio to 25.6s

Documented dead zone pattern in bucket compilation and updated
compatibility matrix, known limitations, and usage examples.
