Contrib/ltx 2.3 #63

Draft

jimburtoft wants to merge 11 commits into aws-neuron:main from jimburtoft:contrib/ltx-2.3

Conversation

@jimburtoft

Description

Adds LTX-2.3, a 22B parameter DiT (Diffusion Transformer) for joint audio-video generation, running on Trainium 2 via NxD Inference. The model generates synchronized video + audio from text prompts using bidirectional audio-video cross-attention with flow matching.
Both the 22B DiT backbone and Gemma 3 12B text encoder run on Neuron at TP=4, sharing 4 NeuronCores and executing sequentially. CPU handles VAE decode and optional latent upscaling. Uses the native ltx-core framework (not Diffusers).
Key results:

  • Single forward pass cosine similarity: 0.9999 (video), 0.9999 (audio) vs CPU reference
  • Neuron Gemma3 text encoding: 6.6s vs ~162s for the CPU fallback (~25x faster)
  • Warm denoising step: ~0.3s per step (22B DiT, 8 Euler steps)
  • E2E output: 25 frames @ 384x512 video + stereo 48kHz audio
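The cosine-similarity figures above compare flattened Neuron outputs against an unsharded CPU reference. A minimal sketch of that metric (the tensor shapes and names are illustrative, not the test suite's actual code):

```python
import torch

def cosine_similarity(neuron_out: torch.Tensor, cpu_ref: torch.Tensor) -> float:
    # Flatten both outputs and compute one global cosine similarity in fp32,
    # so bf16 model outputs are compared at higher precision.
    a = neuron_out.flatten().to(torch.float32)
    b = cpu_ref.flatten().to(torch.float32)
    return torch.nn.functional.cosine_similarity(a, b, dim=0).item()

# Illustrative check: identical tensors score 1.0.
x = torch.randn(2, 8, 16)
score = cosine_similarity(x, x.clone())
```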

Model Information

Model Name: LTX-2.3
Model Architecture: DiT (Diffusion Transformer) — 48 transformer blocks, 32 attention heads, 4096 video dim, 2048 audio dim. Bidirectional audio-video cross-attention, gated attention, QK-RMSNorm, split RoPE, flow matching.
Purpose: Joint text-to-video and text-to-audio generation with optional spatial/temporal latent upscaling.

Checklist

Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.

Required Components

  • Accuracy Test (ex. test/integration/test_model.py)
    • At least one integration test that validates model accuracy
    • Uses logit validation or equivalent accuracy verification
    • Test can compile and run the model on Neuron
  • README.md with the following sections:
    • Usage Example: Clear code example showing how to use the model
    • Compatibility Matrix: Table showing tested Neuron SDK versions and instance types (Trn1/Trn2/Inf2)
    • Example Checkpoints: Links to compatible model checkpoints (e.g., HuggingFace Hub)
    • Testing Instructions: Command to run the test suite for the model
  • Source Code (src/)
    • Modeling code following NxD Inference patterns
    • Properly structured in the contrib folder hierarchy

Optional Components

  • Unit Tests (CPU or Neuron-based)
    • Tests for individual modeling components
    • Located in test/unit/ directory

Folder Structure

Confirm your contribution follows this structure:
```
/contrib/models/LTX-2.3/
  README.md
  /src
    __init__.py
    modeling_ltx23.py          # Core DiT backbone (TP sharding, DistributedRMSNorm, SDPA replacement)
    modeling_gemma3_encoder.py # Custom Gemma3 encoder-only model (returns all 49 hidden states)
    pipeline.py                # NeuronTransformerWrapper (CPU preprocessing, backbone routing)
    compile_transformer.py     # DiT backbone compilation script
    compile_gemma3.py          # Gemma3 encoder compilation script
    shard_gemma3_weights.py    # Pre-shard Gemma3 weights to per-rank files
    load_with_weights.py       # DiT backbone weight sharding and injection
    generate_ltx23.py          # E2E generation pipeline
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py
```

Testing

How did you test this change?
All testing performed on trn2.3xlarge (TP=4, LNC=2) with Neuron SDK 2.28, PyTorch 2.9, Deep Learning AMI Neuron (Ubuntu 24.04) 20260227.
Tests include:

  1. Integration tests (4 tests): model loading, forward pass validity, accuracy vs CPU reference, latency measurement
  2. E2E generation: Full pipeline with Neuron Gemma3 text encoding → Neuron DiT denoising (8 steps) → CPU VAE decode, producing 25 video frames @ 384x512 + stereo 48kHz audio
Test Results:

```
test/integration/test_model.py::test_model_loads PASSED
test/integration/test_model.py::test_forward_pass_no_nan PASSED
test/integration/test_model.py::test_accuracy_vs_cpu PASSED
test/integration/test_model.py::test_performance_latency PASSED
========================= 4 passed in 456.89s =========================
```

Accuracy validation (measured against unsharded BF16 CPU reference):

| Component | Metric | Value |
|-----------|--------|-------|
| Single forward pass (video) | Cosine similarity | 0.999947 |
| Single forward pass (audio) | Cosine similarity | 0.999867 |
| 8-step denoised latent (real text) | Cosine similarity | 0.972 |

Compatibility

Tested with:

  • Neuron SDK Version(s): 2.28
  • Instance Type(s): trn2.3xlarge (TP=4, LNC=2)
  • PyTorch Version: 2.9
  • Python Version: 3.12

Additional Information

  • Two Neuron models share NeuronCores: Both the 22B DiT backbone and 12B Gemma3 text encoder are compiled for TP=4 and share the same 4 NeuronCores, executing sequentially (text encoding runs once, then denoising loop runs 8 steps).
  • Different compiler flags per model: DiT uses --auto-cast matmult for throughput; Gemma3 uses --auto-cast=none --enable-saturate-infinity --enable-mixed-precision-accumulation for text encoder precision.
  • Cold start: First two denoising steps are slow (~144s + ~177s) due to Neuron device initialization. Warm steps run at ~0.3s each.
  • CPU fallback available: Gemma3 can run on CPU (~162s) without Neuron compilation, using --gemma-path without --neuron-gemma.
  • Stage 1 only: This submission covers Stage 1 generation with optional latent upscaling. Stage 2 refinement denoising (recompiling at larger latent shapes) is planned as a follow-up.
  • Uses native ltx-core framework, not Diffusers.
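The sequential NeuronCore-sharing flow above can be sketched as follows. The loader/unloader names mirror functions mentioned in this PR (`load_neuron_backbone`, `unload_neuron_model`), but every body here is a trivial stand-in rather than the real NEFF-loading code:

```python
# All functions below are illustrative stand-ins; the real versions load
# compiled Neuron models and release NRT/HBM resources.

def load_neuron_gemma3():
    return lambda prompt: f"states:{prompt}"          # stand-in encoder

def unload_neuron_model(model):
    pass                                              # real version frees NRT resources

def load_neuron_backbone():
    return lambda latents, states, step: latents + 1  # stand-in DiT step

def generate(prompt, num_steps=8):
    gemma3 = load_neuron_gemma3()      # occupies the 4 NeuronCores first
    text_states = gemma3(prompt)       # text encoding runs once per generation
    unload_neuron_model(gemma3)        # free HBM before the 22B DiT loads

    dit = load_neuron_backbone()       # same 4 cores, loaded after Gemma3 unloads
    latents = 0                        # stand-in for the initial noise latents
    for step in range(num_steps):      # 8 distilled Euler steps
        latents = dit(latents, text_states, step)
    return latents
```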

Related Issues

N/A — first contribution for this model.

vLLM Integration

  • This model/feature is intended for use with vLLM
  • Documentation includes vLLM registration instructions

By submitting this PR, I confirm that:

  • I have read and followed the contributing guidelines
  • This is a community contribution and may have limited testing compared to officially-supported models
  • The code follows best practices and is well-documented
  • All required components listed above are included

Adds custom Gemma3 encoder-only model that runs on Neuron TP=4,
sharing NeuronCores with the DiT backbone for sequential execution.
Reduces text encoding from ~162s (CPU) to 6.6s (Neuron), a ~25x speedup.

New files:
- modeling_gemma3_encoder.py: Custom encoder returning all 49 hidden states
- compile_gemma3.py: Compilation script using parallel_model_trace
- shard_gemma3_weights.py: Pre-shard weights to per-rank files

Updated:
- generate_ltx23.py: --neuron-gemma flag for Neuron text encoding path
- __init__.py: Export Gemma3 encoder classes
- README.md: Updated benchmarks with measured Neuron Gemma3 timing

E2E validated: text encoding + 8-step denoising + VAE decode produces
valid video (25 frames @ 384x512) and audio (stereo 48kHz).
All 4 integration tests pass.
jimburtoft marked this pull request as draft March 9, 2026 06:00
- Use SpatioTemporalScaleFactors.default() (time=8, h=32, w=32) instead
  of incorrect (time=1, w=8, h=8) in both compile and generate scripts.
  This was the root cause of garbled output — wrong RoPE positions.
- Add LATENT_H/LATENT_W constants to compile_transformer.py so the
  halfres wrapper can override them (previously hardcoded height=12, w=16).
- Add compile_transformer_halfres.py for Stage 1 half-res compilation
  (192x256, 192 video tokens).
- Replace LTX2Scheduler with hardcoded DISTILLED_SIGMA_VALUES matching
  the reference distilled pipeline constants.
- Fix --neuron-gemma CLI validation (was requiring --gemma-path even
  when --neuron-gemma was specified).
…peline sigma schedule

- Add sequential Neuron model loading: Gemma3 encodes text first, then unloads
  from NeuronCores before DiT backbone loads. Eliminates HBM contention.
- Add unload_neuron_model() for explicit NRT resource cleanup
- Add shard_backbone_weights.py: pre-shards DiT weights per TP rank (~9.3GB each)
  to avoid loading full 41GB safetensors during generation
- Update load_neuron_backbone() to use pre-sharded weights when available,
  with fallback to memory-mapped safetensors loading
- Add DiT backbone warmup pass before denoising loop
- Fix pipeline.py: correct SpatioTemporalScaleFactors.default() (was time=1,
  height=8, width=8; now time=8, height=32, width=32)
- Fix pipeline.py: replace LTX2Scheduler with distilled sigma values
- Fix pipeline.py: correct latent dimension computation (height//32, not height//8//2)

E2E verified: Gemma3 text encoding + 8-step distilled denoising produces
prompt-matching video (golden retriever on meadow) at 192x256 half-res.
Warm denoising step latency: 0.3s/step on trn2.3xlarge TP=4.
…> 295ms)

Monkey-patch preprocessor _prepare_timestep to compute AdaLN MLP once
per unique sigma value instead of per-token (768 tokens in T2V mode).
CPU preprocessing reduced from 96.5ms to 48.6ms (-49.6%). Handles
both T2V (1 unique sigma) and I2V (2 unique sigmas) modes.

Correctness validated: cos_sim=1.0, max_diff=2.44e-4 (bf16 precision).
Update README with per-step performance detail table.
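The AdaLN deduplication above can be sketched as running the modulation MLP once per unique sigma and gathering the result back to every token; `adaln_mlp` below is a stand-in `Linear`, not the model's real module:

```python
import torch

def prepare_timestep(sigmas: torch.Tensor, adaln_mlp) -> torch.Tensor:
    """sigmas: (num_tokens,) per-token sigma values.
    Runs the AdaLN MLP once per unique sigma, then scatters the result
    back to all tokens -- equivalent to the per-token version."""
    unique, inverse = torch.unique(sigmas, return_inverse=True)
    per_unique = adaln_mlp(unique.unsqueeze(-1))  # (n_unique, hidden): 1 MLP call per sigma
    return per_unique[inverse]                    # (num_tokens, hidden): gather to tokens

adaln_mlp = torch.nn.Linear(1, 8)
sigmas = torch.full((768,), 0.7)          # T2V mode: 768 tokens, one unique sigma
out = prepare_timestep(sigmas, adaln_mlp) # MLP runs once instead of 768 times
```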
Cache RoPE embeddings, context projection (caption_projection Linear),
and attention masks across denoising steps since they depend only on
spatial positions and text context, not the diffusion timestep.

Combined with AdaLN dedup, CPU preprocessing drops from 96.5ms to 39.0ms
(60% reduction). Overall warm per-step latency: 330ms -> 286.5ms.
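The step-invariant caching above can be sketched as a keyed memo: RoPE tables, the caption projection, and attention masks are computed on the first denoising step and reused on the remaining seven. Names here are illustrative, not the PR's actual cache:

```python
class StepInvariantCache:
    """Caches tensors that depend only on spatial positions and text
    context, not the diffusion timestep."""
    def __init__(self):
        self._cache = {}

    def get(self, key, compute):
        if key not in self._cache:         # first denoising step: compute
            self._cache[key] = compute()
        return self._cache[key]            # later steps: reuse

cache = StepInvariantCache()
calls = {"rope": 0}

def compute_rope():
    calls["rope"] += 1
    return [[0.0] * 64 for _ in range(768)]  # stand-in RoPE table

for step in range(8):                        # denoising loop
    rope = cache.get("rope", compute_rope)
```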
Replace --enable-saturate-infinity and --enable-mixed-precision-accumulation with
tensorizer flags (--enable-ccop-compute-overlap, --cc-pipeline-tiling-factor=1,
--vectorize-strided-dma, --enable-scalar-dge-vectorization) for the Gemma3 text
encoder. No accuracy degradation.

Update README with:
- Gemma3 warm encoding: 6.7s → 0.6s (encoder forward), ~1.3s E2E
- Per-step totals: 286.5ms → 279.3ms (15.4% vs baseline)
- Optional DiT --vectorize-strided-dma flag documented (1.9% speedup,
  0.996 cosine sim — not default)
- DiT tensorizer ablation results (tiling factor hurts, scalar-dge neutral)