Add Kokoro-82M TTS contrib: 82M-param text-to-speech on Neuron (trn2 …#71
Open

jimburtoft wants to merge 5 commits into aws-neuron:main from
Conversation
…+ inf2) First non-LLM contrib model. Kokoro-82M (hexgrad/Kokoro-82M) is a parallel TTS model using StyleTTS 2 + ISTFTNet, compiled via torch_neuronx.trace() with a 3-part architecture split to work around XLA tracer limitations. Performance: 60-80x real-time on trn2.3xlarge, 47-55x on inf2.xlarge. Accuracy: cosine 0.985-0.993 vs CPU reference. 15 integration tests.
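The accuracy figure above ("cosine 0.985-0.993 vs CPU reference") implies a similarity check between Neuron and CPU waveforms. Below is a minimal, dependency-free sketch of that kind of check; the function names (`cosine_similarity`, `passes`) and the 0.985 threshold placement are illustrative, not the PR's actual test code.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sample sequences."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)

def passes(neuron_audio, cpu_audio, threshold=0.985):
    """Accept the Neuron output if it is close enough to the CPU reference."""
    return cosine_similarity(neuron_audio, cpu_audio) >= threshold
```

Identical waveforms score 1.0 and orthogonal ones score 0.0, so a threshold of 0.985 tolerates small numeric drift from compilation while still catching gross divergence.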
… streaming API
- generate() now handles text of any length via KPipeline phoneme-aware chunking
- generate_stream() yields audio chunks as a Python generator for low-latency streaming
- Crossfade stitching (25ms linear) eliminates click artifacts at chunk boundaries
- Sub-chunk splitting when KPipeline chunks exceed bucket limits
- 5 new integration tests (TestLongForm): long text, streaming, audio duration, clicks, timing
- Updated standalone runner with long-form and streaming tests (8 tests total)
- All 20 pytest tests pass on trn2.3xlarge (SDK 2.27)
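The 25ms linear crossfade mentioned above can be sketched as follows. This is a hypothetical, pure-Python illustration (lists stand in for sample arrays; `crossfade_concat` and `SAMPLE_RATE` are names I've invented, not the PR's API): at 24kHz, 25ms is 600 samples, and the two chunks are blended with complementary linear ramps over that overlap so the boundary has no discontinuity.

```python
SAMPLE_RATE = 24_000
FADE = int(0.025 * SAMPLE_RATE)  # 25 ms -> 600 samples at 24 kHz

def crossfade_concat(a, b, fade=FADE):
    """Concatenate two audio chunks, linearly blending the overlap region."""
    fade = min(fade, len(a), len(b))
    if fade == 0:
        return a + b
    blended = []
    for i in range(fade):
        w = (i + 1) / fade  # ramps from ~0 to 1 across the overlap
        blended.append(a[len(a) - fade + i] * (1 - w) + b[i] * w)
    return a[:len(a) - fade] + blended + b[fade:]
```

Because the fade weights sum to 1 at every sample, a chunk boundary between two identical signals reproduces the signal exactly, which is why this removes click artifacts.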
…y sub-chunking

Critical fix: HOP_SIZE was 300 (generator only) but must be 600 to account for decode[3]'s 2x ConvTranspose1d upsample. This caused all audio to be truncated to half length.

Large bucket support: systematic testing found dead zones at 256-384 and 1344-2624 (SB overflow), but 512-1312 and 2688+ compile successfully. Default buckets now [64, 128, 192, 512, 768, 1024], covering up to ~25.6s.

Sub-chunking improvements:
- Split at word boundaries (spaces) instead of mid-word
- Recursive sub-chunking (up to 3 levels) replaces silence fallback
- 60% target (was 80%) to leave headroom for ALBERT re-prediction
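The word-boundary, recursive sub-chunking described above can be sketched like this. It is a simplified illustration, not the PR's implementation: real chunks are phoneme sequences with re-predicted durations, whereas here plain character counts stand in for the bucket limit. The 60% target and 3-level recursion cap come from the commit message; `sub_chunk` and the hard-cut fallback are my assumptions.

```python
TARGET = 0.60    # aim for 60% of the limit, leaving headroom for ALBERT re-prediction
MAX_DEPTH = 3    # recursive sub-chunking up to 3 levels

def sub_chunk(text, limit, depth=0):
    """Split `text` into pieces of at most `limit` units, cutting at spaces."""
    if len(text) <= limit or depth >= MAX_DEPTH:
        return [text]
    # find the last space before the 60% target point
    cut = text.rfind(" ", 0, max(1, int(limit * TARGET)))
    if cut <= 0:
        cut = int(limit * TARGET)  # no space found: fall back to a hard cut
    head, tail = text[:cut], text[cut:].lstrip()
    return [head] + sub_chunk(tail, limit, depth + 1)
```

Splitting at the last space before the target keeps every piece on a word boundary, which avoids the audible artifacts a mid-word cut would cause.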
Corrected benchmark numbers after HOP_SIZE fix (audio durations were previously reported at half their actual length). New trn2 benchmarks:
- Decoder-only: 129-181x real-time (was 60-80x)
- End-to-end with CPU: ~4-6x real-time (CPU ALBERT is the bottleneck)
- Large buckets: 512/768/1024 extend per-chunk audio to 25.6s

Documented the dead-zone pattern in bucket compilation and updated the compatibility matrix, known limitations, and usage examples.
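The arithmetic behind the HOP_SIZE fix and the bucket durations above is worth making explicit. A back-of-the-envelope check (constant names here are illustrative): the generator hops 300 samples per frame, decode[3]'s ConvTranspose1d doubles the sample count, so the effective hop is 600 samples per frame at 24kHz.

```python
SAMPLE_RATE = 24_000
GENERATOR_HOP = 300   # generator-only hop that was wrongly used
UPSAMPLE = 2          # decode[3]'s 2x ConvTranspose1d upsample
HOP_SIZE = GENERATOR_HOP * UPSAMPLE  # correct effective hop: 600

def bucket_seconds(frames, hop=HOP_SIZE, sr=SAMPLE_RATE):
    """Seconds of audio produced by a bucket of `frames` frames."""
    return frames * hop / sr
```

With the correct hop, the largest default bucket (1024 frames) yields 25.6s per chunk, matching the figure above; with the old hop of 300 the same bucket would be reported as only 12.8s, which is exactly the half-length truncation the fix addressed.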
Description
Adds hexgrad/Kokoro-82M, an 82M-parameter text-to-speech model (StyleTTS 2 + ISTFTNet, Apache 2.0) to NxDI contrib. This is the first non-LLM contrib model -- it uses torch_neuronx.trace() with a 3-part architecture split to handle XLA tracer incompatibilities. Generates 24kHz audio at 60-80x real-time on trn2.3xlarge.

Model Information
Model Name: Kokoro-82M
Model Architecture: Parallel (non-autoregressive) TTS: ALBERT duration predictor + StyleTTS 2 decoder + ISTFTNet vocoder
Purpose: Text-to-speech synthesis (text input → 24kHz mono audio output)
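To make the parallel (non-autoregressive) structure above concrete, here is a toy dataflow sketch with plain Python functions: a duration predictor assigns frames per phoneme, a decoder length-regulates the sequence in one shot, and a vocoder turns frames into samples. All names, dummy durations, and shapes are assumptions for illustration; the real model runs ALBERT, the StyleTTS 2 decoder, and ISTFTNet.

```python
def predict_durations(phonemes):
    """Stand-in for the ALBERT duration predictor: frames per phoneme."""
    return [2 for _ in phonemes]  # dummy: 2 frames each

def decode(phonemes, durations):
    """Stand-in for the decoder: length-regulate the whole sequence at once."""
    frames = []
    for p, d in zip(phonemes, durations):
        frames.extend([p] * d)
    return frames

def vocode(frames, hop=600):
    """Stand-in for the ISTFTNet vocoder: 600 samples per frame at 24 kHz."""
    return [0.0] * (len(frames) * hop)

def tts(phonemes):
    # the entire utterance is produced in one parallel pass --
    # no sample-by-sample autoregression
    return vocode(decode(phonemes, predict_durations(phonemes)))
```

The point of the sketch is the control flow: unlike an autoregressive LLM, every stage consumes the whole sequence at once, which is what makes the high real-time factors above achievable.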
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
- Integration tests (test/integration/test_model.py)
- Source code (src/): kokoro_neuron.py, a 769-line KokoroNeuron class with compile, save, load, generate, and warmup

Optional Components
Folder Structure
Confirm your contribution follows this structure:
Testing
How did you test this change?
Full test suite executed on trn2.3xlarge (SDK 2.27, PyTorch 2.9, LNC=2). Compiled all 6 bucket sizes (32, 64, 96, 128, 160, 192 frames), ran 15 integration tests covering compilation, inference quality, accuracy validation, multi-voice support, and performance thresholds. Also tested on inf2.xlarge with the -O1 compiler flag (4 of 6 buckets compile successfully).

Test Results:
Compatibility
Tested with:
(-O1 required)

Additional Information
Key Technical Challenges Solved
5 root causes were identified and fixed during onboarding:
- torch.where
- torch.repeat_interleave
- -O2 → use -O1 compiler flag
- (NCC_ITEN404) → element-wise decomposition

Performance Summary
Known Limitations
inf2.xlarge requires -O1 and cannot compile buckets 64 and 192

Related Issues
None.
vLLM Integration
Not applicable -- TTS model, not an LLM.
By submitting this PR, I confirm that: