Conversation
Adds a custom Gemma3 encoder-only model that runs on Neuron at TP=4, sharing NeuronCores with the DiT backbone for sequential execution. Reduces text encoding from ~162s (CPU) to 6.6s (Neuron), a ~25x speedup.

New files:
- modeling_gemma3_encoder.py: custom encoder returning all 49 hidden states
- compile_gemma3.py: compilation script using parallel_model_trace
- shard_gemma3_weights.py: pre-shard weights to per-rank files

Updated:
- generate_ltx23.py: --neuron-gemma flag for the Neuron text-encoding path
- __init__.py: export Gemma3 encoder classes
- README.md: updated benchmarks with measured Neuron Gemma3 timing

E2E validated: text encoding + 8-step denoising + VAE decode produces valid video (25 frames @ 384x512) and audio (stereo 48 kHz). All 4 integration tests pass.
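The "49 hidden states" count follows from the encoder depth: 48 transformer layers plus the initial embedding output. A minimal toy sketch of that collection pattern (all names here are illustrative stand-ins, not the actual modeling_gemma3_encoder.py API):

```python
import numpy as np

class ToyEncoder:
    """Stand-in for an encoder-only model that keeps every intermediate:
    the embedding output plus each layer's output, i.e. num_layers + 1
    hidden states (48 layers -> 49 states)."""

    def __init__(self, num_layers=48, dim=8):
        rng = np.random.default_rng(0)
        self.weights = [rng.standard_normal((dim, dim)) * 0.01
                        for _ in range(num_layers)]

    def forward(self, x):
        hidden_states = [x]            # index 0: embedding output
        for w in self.weights:
            x = x + x @ w              # stand-in for a transformer block
            hidden_states.append(x)    # one entry per layer
        return hidden_states           # len == num_layers + 1

enc = ToyEncoder()
states = enc.forward(np.zeros((4, 8)))
print(len(states))  # 49
```

Returning the full list (rather than only the final state) is what lets the downstream pipeline pick whichever layer's representation it conditions on.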
- Use SpatioTemporalScaleFactors.default() (time=8, height=32, width=32) instead of the incorrect (time=1, height=8, width=8) in both compile and generate scripts. This was the root cause of garbled output: wrong RoPE positions.
- Add LATENT_H/LATENT_W constants to compile_transformer.py so the half-res wrapper can override them (previously hardcoded to height=12, width=16).
- Add compile_transformer_halfres.py for Stage 1 half-res compilation (192x256, 192 video tokens).
- Replace LTX2Scheduler with hardcoded DISTILLED_SIGMA_VALUES matching the reference distilled-pipeline constants.
- Fix --neuron-gemma CLI validation (was requiring --gemma-path even when --neuron-gemma was specified).
…peline sigma schedule

- Add sequential Neuron model loading: Gemma3 encodes text first, then unloads from the NeuronCores before the DiT backbone loads. Eliminates HBM contention.
- Add unload_neuron_model() for explicit NRT resource cleanup.
- Add shard_backbone_weights.py: pre-shards DiT weights per TP rank (~9.3 GB each) to avoid loading the full 41 GB safetensors during generation.
- Update load_neuron_backbone() to use pre-sharded weights when available, with fallback to memory-mapped safetensors loading.
- Add a DiT backbone warmup pass before the denoising loop.
- Fix pipeline.py: correct SpatioTemporalScaleFactors.default() (was time=1, height=8, width=8; now time=8, height=32, width=32).
- Fix pipeline.py: replace LTX2Scheduler with distilled sigma values.
- Fix pipeline.py: correct latent dimension computation (height//32, not height//8//2).

E2E verified: Gemma3 text encoding + 8-step distilled denoising produces prompt-matching video (golden retriever on a meadow) at 192x256 half-res. Warm denoising step latency: 0.3 s/step on trn2.3xlarge at TP=4.
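The latent-grid fix above can be illustrated with a small sketch. The helper name and the temporal formula are assumptions (not the actual pipeline.py code), but with the default scale factors (time=8, height=32, width=32) it reproduces the 192 video tokens cited for Stage 1 half-res at 192x256 with 25 frames:

```python
SCALE_T, SCALE_H, SCALE_W = 8, 32, 32  # SpatioTemporalScaleFactors.default()

def latent_grid(num_frames, height, width):
    """Map pixel-space video dims to latent-grid dims.

    Assumption: a causal temporal downsampler maps N frames to
    1 + (N - 1) // 8 latent frames; spatial dims divide by 32
    (the fix: height // 32, not height // 8 // 2).
    """
    latent_t = 1 + (num_frames - 1) // SCALE_T
    latent_h = height // SCALE_H
    latent_w = width // SCALE_W
    return latent_t, latent_h, latent_w

t, h, w = latent_grid(25, 192, 256)
print((t, h, w), t * h * w)  # (4, 6, 8) 192
```

The buggy height//8//2 is equivalent to height//16, which produced a latent grid twice too large in each spatial dimension and, with it, wrong RoPE positions.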
…d sequential loading
… 0, per-token timesteps
…psample x2 + full-res S2 (3 steps)
…> 295ms)

Monkey-patch the preprocessor's _prepare_timestep to compute the AdaLN MLP once per unique sigma value instead of per token (768 tokens in T2V mode). CPU preprocessing reduced from 96.5 ms to 48.6 ms (-49.6%). Handles both T2V (1 unique sigma) and I2V (2 unique sigmas) modes. Correctness validated: cos_sim=1.0, max_diff=2.44e-4 (bf16 precision). Update README with a per-step performance detail table.
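The deduplication idea is independent of the actual preprocessor: find the unique sigma values, run the expensive MLP once per unique value, and scatter the results back to all tokens. A NumPy sketch (adaln_dedup and toy_mlp are illustrative names, not the patched _prepare_timestep):

```python
import numpy as np

def adaln_dedup(sigmas, mlp):
    """Run `mlp` once per unique sigma and broadcast to every token."""
    uniq, inverse = np.unique(sigmas, return_inverse=True)
    per_unique = np.stack([mlp(s) for s in uniq])  # one call per unique sigma
    return per_unique[inverse]                     # scatter back to all tokens

calls = []
def toy_mlp(s):
    calls.append(s)                  # count how often the "MLP" actually runs
    return np.array([s, 2.0 * s])

sigmas = np.full(768, 0.25)          # T2V mode: one sigma shared by 768 tokens
out = adaln_dedup(sigmas, toy_mlp)
print(len(calls), out.shape)         # 1 (768, 2)
```

In T2V mode this collapses 768 MLP evaluations to 1; in I2V mode (two unique sigmas, for conditioning vs. generated tokens) it collapses them to 2.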
Cache RoPE embeddings, context projection (caption_projection Linear), and attention masks across denoising steps since they depend only on spatial positions and text context, not the diffusion timestep. Combined with AdaLN dedup, CPU preprocessing drops from 96.5ms to 39.0ms (60% reduction). Overall warm per-step latency: 330ms -> 286.5ms.
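Because the cached quantities depend only on the latent grid and the text context, not the diffusion timestep, a simple memo keyed on the grid shape is enough to reuse them across every denoising step. An illustrative sketch (function name and contents are hypothetical):

```python
from functools import lru_cache

compute_count = 0  # instrumentation: how many times we actually compute

@lru_cache(maxsize=4)
def rope_embeddings(latent_grid):
    """RoPE depends only on spatial positions, so compute it once per
    latent-grid shape and reuse it across all denoising steps."""
    global compute_count
    compute_count += 1
    t, h, w = latent_grid
    return [(ti, hi, wi) for ti in range(t)
            for hi in range(h) for wi in range(w)]

for step in range(8):                  # 8-step distilled denoising loop
    pos = rope_embeddings((4, 6, 8))   # same grid every step -> single compute

print(compute_count, len(pos))  # 1 192
```

The same pattern applies to the caption_projection output and the attention masks: compute on the first step, reuse on steps 2 through 8.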
Replace --enable-saturate-infinity and --enable-mixed-precision-accumulation with tensorizer flags (--enable-ccop-compute-overlap, --cc-pipeline-tiling-factor=1, --vectorize-strided-dma, --enable-scalar-dge-vectorization) for the Gemma3 text encoder. No accuracy degradation.

Update README with:
- Gemma3 warm encoding: 6.7s → 0.6s (encoder forward), ~1.3s E2E
- Per-step totals: 286.5ms → 279.3ms (15.4% vs baseline)
- Optional DiT --vectorize-strided-dma flag documented (1.9% speedup, 0.996 cosine sim; not enabled by default)
- DiT tensorizer ablation results (tiling factor hurts, scalar-dge neutral)
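One plausible way to wire these flags up, sketched below as an environment-variable export (an assumption: the actual compile_gemma3.py may instead forward them through the tracing API's compiler-args mechanism):

```shell
# Assumed plumbing: collect the Gemma3 tensorizer flags in NEURON_CC_FLAGS,
# which the Neuron compiler picks up at trace time.
export NEURON_CC_FLAGS="--auto-cast=none \
--enable-ccop-compute-overlap \
--cc-pipeline-tiling-factor=1 \
--vectorize-strided-dma \
--enable-scalar-dge-vectorization"
echo "$NEURON_CC_FLAGS"
```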
Description
Adds LTX-2.3, a 22B parameter DiT (Diffusion Transformer) for joint audio-video generation, running on Trainium 2 via NxD Inference. The model generates synchronized video + audio from text prompts using bidirectional audio-video cross-attention with flow matching.
Both the 22B DiT backbone and Gemma 3 12B text encoder run on Neuron at TP=4, sharing 4 NeuronCores and executing sequentially. CPU handles VAE decode and optional latent upscaling. Uses the native ltx-core framework (not Diffusers).
Key results:
Model Information
Model Name: LTX-2.3
Model Architecture: DiT (Diffusion Transformer) — 48 transformer blocks, 32 attention heads, 4096 video dim, 2048 audio dim. Bidirectional audio-video cross-attention, gated attention, QK-RMSNorm, split RoPE, flow matching.
Purpose: Joint text-to-video and text-to-audio generation with optional spatial/temporal latent upscaling.
Checklist
Please ensure your PR includes the following items. Refer to the contrib/CONTRIBUTING.md for detailed guidelines.
Required Components
- Integration test (test/integration/test_model.py)
- Model source code (src/)

Optional Components
- Unit tests in the test/unit/ directory

Folder Structure
Confirm your contribution follows this structure:
```
/contrib/models/LTX-2.3/
  README.md
  /src
    __init__.py
    modeling_ltx23.py            # Core DiT backbone (TP sharding, DistributedRMSNorm, SDPA replacement)
    modeling_gemma3_encoder.py   # Custom Gemma3 encoder-only model (returns all 49 hidden states)
    pipeline.py                  # NeuronTransformerWrapper (CPU preprocessing, backbone routing)
    compile_transformer.py       # DiT backbone compilation script
    compile_gemma3.py            # Gemma3 encoder compilation script
    shard_gemma3_weights.py      # Pre-shard Gemma3 weights to per-rank files
    load_with_weights.py         # DiT backbone weight sharding and injection
    generate_ltx23.py            # E2E generation pipeline
  /test
    __init__.py
    /unit
      __init__.py
    /integration
      __init__.py
      test_model.py
```
Testing
How did you test this change?
All testing performed on trn2.3xlarge (TP=4, LNC=2) with Neuron SDK 2.28, PyTorch 2.9, Deep Learning AMI Neuron (Ubuntu 24.04) 20260227.
Tests include:
Test Results:
test/integration/test_model.py::test_model_loads PASSED
test/integration/test_model.py::test_forward_pass_no_nan PASSED
test/integration/test_model.py::test_accuracy_vs_cpu PASSED
test/integration/test_model.py::test_performance_latency PASSED
========================= 4 passed in 456.89s =========================
Accuracy validation (measured against unsharded BF16 CPU reference):
| Component | Metric | Value |
|-----------|--------|-------|
| Single forward pass (video) | Cosine similarity | 0.999947 |
| Single forward pass (audio) | Cosine similarity | 0.999867 |
| 8-step denoised latent (real text) | Cosine similarity | 0.972 |
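For reference, the cosine-similarity figures in the table are the usual flattened dot-product form. A minimal sketch of such a check (illustrative only, not the actual test_accuracy_vs_cpu code; the sample arrays are made up):

```python
import numpy as np

def cosine_sim(a, b):
    """Cosine similarity between two tensors, flattened to vectors."""
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

ref = np.array([1.0, 2.0, 3.0])      # e.g. unsharded BF16 CPU reference
neu = np.array([1.0, 2.0, 3.0001])   # e.g. Neuron output with small drift
print(cosine_sim(ref, neu) > 0.999)  # True
```

Note that cosine similarity is scale-invariant, so a max-abs-diff check (as in the AdaLN validation above) is a useful complement when absolute magnitudes matter.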
Compatibility
Tested with:
Additional Information
- DiT backbone: `--auto-cast matmult` for throughput; Gemma3 uses `--auto-cast=none --enable-saturate-infinity --enable-mixed-precision-accumulation` for text-encoder precision.
- `--gemma-path` without `--neuron-gemma`.

Related Issues
N/A — first contribution for this model.
vLLM Integration
By submitting this PR, I confirm that: