sample QAD example script #933
README.md
@@ -0,0 +1,162 @@
# LTX-2 QAD Example (Quantization-Aware Distillation)

**Note:** This is a **sample script for illustrating the QAD pipeline**. It has been verified to launch on a **Linux RTX 5090** system, but full training runs into **OOM (Out of Memory)** on that configuration, so a single RTX 5090 is not sufficient for end-to-end training. A known limitation: the text encoder currently fails to load on Windows.

This example demonstrates **Quantization-Aware Distillation (QAD)** for [LTX-2](https://github.com/Lightricks/LTX-2) using the native LTX training loop and [NVIDIA ModelOpt](https://github.com/NVIDIA/Model-Optimizer). It combines:

- **LTX packages**: training loop, datasets, and strategies (masked loss, audio/video split)
- **NVIDIA ModelOpt**: PTQ calibration (`mtq.quantize`), distillation (`mtd.convert`), and NVFP4 quantization

Combined loss (same idea as the full distillation trainer):

```text
L_total = α × L_task + (1−α) × L_distill
```

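As a minimal Python sketch (the function and argument names here are illustrative, not taken from the script; `kd_loss_weight` in the config corresponds to `1 − α`):

```python
import torch

def combined_loss(task_loss: torch.Tensor, distill_loss: torch.Tensor, alpha: float = 0.5) -> torch.Tensor:
    """L_total = alpha * L_task + (1 - alpha) * L_distill."""
    return alpha * task_loss + (1.0 - alpha) * distill_loss
```
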
For the **full-stage QAD** implementation (LTX-2 DiT with ModelOpt quantization, full calibration options, checkpoint resume, and multi-node training), see the NVIDIA Model-Optimizer distillation example:

- **Full distillation trainer**: [distillation_trainer.py](https://github.com/NVIDIA/Model-Optimizer/blob/ca1f9687bd741a0c73791c093692eff0f95d2d46/examples/diffusers/distillation/distillation_trainer.py)
- **Example docs**: [examples/diffusers/distillation](https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/diffusers/distillation) (README, configs, and usage)

## Requirements

- Python 3.10+
- CUDA-capable GPU(s)
- [Accelerate](https://huggingface.co/docs/accelerate) (for FSDP multi-GPU training)

## Installation

Create a virtual environment and install dependencies:

```bash
python -m venv .venv
source .venv/bin/activate  # Linux/macOS
# .venv\Scripts\activate   # Windows

pip install -r requirements.txt
```

The `requirements.txt` includes:

| Package | Source |
|---------|--------|
| **ltx-core** | `git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core` |
| **ltx-pipelines** | `git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-pipelines` |
| **ltx-trainer** | `git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-trainer` |
| **nvidia-modelopt** | PyPI (`nvidia-modelopt`) |

You may also need to install PyTorch, Accelerate, safetensors, and PyYAML if not already present:

```bash
pip install torch accelerate safetensors pyyaml
```

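As an optional sanity check, the installed packages can be imported directly (module names assumed from this README's references to `modelopt` and `ltx_trainer`):

```bash
python -c "import torch, modelopt, ltx_trainer; print(torch.__version__, modelopt.__version__)"
```
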
## Project layout

| File | Description |
|------|-------------|
| `sample_example_qad_diffusers.py` | Main script: QAD training and inference checkpoint creation |
| `ltx2_qad.yaml` | LTX training config (model, data, optimization, QAD options) |
| `fsdp_custom.yaml` | Accelerate FSDP config for multi-GPU training |

## Usage

### 1. Prepare your dataset

Run the LTX preprocessing script to extract latents and text embeddings from your videos. Use `scripts/preprocess_dataset.py` with the following arguments (matching the LTX training pipeline):

```bash
python scripts/preprocess_dataset.py /path/to/dataset.json \
  --resolution-buckets 384x256x97 \
  --output-dir /path/to/preprocessed \
  --model-path /path/to/ltx2/checkpoint.safetensors \
  --text-encoder-path /path/to/gemma \
  --batch-size 4 \
  --with-audio \
  --decode
```

- **Positional**: path to dataset metadata file (CSV/JSON/JSONL with captions and video paths; see the example below).
- **Required**: `--resolution-buckets`, `--model-path`, `--text-encoder-path`.
- **Optional**: `--output-dir` (defaults to `.precomputed` in dataset dir), `--batch-size` (default 1), `--with-audio`, `--decode` (decode and save videos for verification).

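For illustration, a minimal JSON metadata file might look like this (the field names `caption` and `media_path` are assumptions; match whatever schema your preprocessing script expects):

```json
[
  {"caption": "a professional portrait video of a person", "media_path": "videos/portrait_001.mp4"},
  {"caption": "a person wearing a nice suit", "media_path": "videos/suit_002.mp4"}
]
```
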
Set `data.preprocessed_data_root` in your config (step 2) to the same path as `--output-dir`.

On a **Slurm cluster**, run the same script via `srun` and `torchrun` (set `MASTER_ADDR`, `MASTER_PORT`, and `WORLD_SIZE` from Slurm and use `--nnodes=$SLURM_NNODES` and `--nproc_per_node=8`), as sketched below.

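A possible sbatch sketch (illustrative only; node counts, GPU counts, and rendezvous settings will differ per cluster):

```bash
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8

# Derive rendezvous info from Slurm (illustrative; adjust for your cluster).
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=29500
export WORLD_SIZE=$((SLURM_NNODES * 8))

srun torchrun \
  --nnodes=$SLURM_NNODES \
  --nproc_per_node=8 \
  scripts/preprocess_dataset.py /path/to/dataset.json \
  --resolution-buckets 384x256x97 \
  --output-dir /path/to/preprocessed \
  --model-path /path/to/ltx2/checkpoint.safetensors \
  --text-encoder-path /path/to/gemma
```
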
### 2. Configure paths

Edit `ltx2_qad.yaml` and set:

- `model.model_path` – path to base LTX checkpoint (e.g. `.safetensors`)
- `model.text_encoder_path` – path to Gemma text encoder
- `data.preprocessed_data_root` – path to preprocessed LTX dataset

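For example (placeholder paths, mirroring the shipped config):

```yaml
model:
  model_path: "/path/to/ltx2/checkpoint.safetensors"
  text_encoder_path: "/path/to/gemma"

data:
  preprocessed_data_root: "/path/to/preprocessed"
```
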
Adjust the `qad` section as needed: `calib_size`, `kd_loss_weight`, `exclude_blocks`, `skip_inference_ckpt`.

#### Hyperparameters controllable via YAML (`ltx2_qad.yaml`)

All of the following can be set in `ltx2_qad.yaml`. QAD-specific options can also be overridden from the CLI (see step 3).

| Section | Key | Default (example) | Description |
|---------|-----|-------------------|-------------|
| **qad** | `calib_size` | `512` | Number of calibration batches for PTQ (more = better scale estimates, slower startup). |
| **qad** | `kd_loss_weight` | `0.5` | Weight for distillation loss in combined loss; `0` = task loss only, `1` = distillation only. |
| **qad** | `exclude_blocks` | `[0, 1, 46, 47]` | Transformer block indices to exclude from quantization (e.g. first/last blocks). |
| **qad** | `skip_inference_ckpt` | `false` | If `true`, do not build the inference checkpoint after training. |
| **optimization** | `learning_rate` | `1e-6` | Learning rate (low is typical for QAD/distillation). |
| **optimization** | `steps` | `300` | Total training steps. |
| **optimization** | `batch_size` | `1` | Per-device batch size. |
| **optimization** | `gradient_accumulation_steps` | `4` | Gradient accumulation steps (effective batch = batch_size × accumulation × num_gpus). |
| **optimization** | `optimizer_type` | `"adamw"` | Optimizer (`adamw`, etc.). |
| **checkpoints** | `interval` | `100` | Save a checkpoint every N steps; `null` to disable. |
| (root) | `output_dir` | `"outputs/ltx2_qad"` | Where to write checkpoints and logs. |

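With the example defaults and the provided 8-GPU FSDP config (`num_processes: 8`), the effective batch size is 1 × 4 × 8 = 32.
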
### 3. Run QAD training

Using Accelerate with the provided FSDP config:

```bash
accelerate launch --config_file fsdp_custom.yaml sample_example_qad_diffusers.py train \
  --config ltx2_qad.yaml
```

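The QAD options can also be overridden on the command line, per the CLI flags noted in `ltx2_qad.yaml` (the exact value syntax for `--exclude-blocks` is an assumption):

```bash
accelerate launch --config_file fsdp_custom.yaml sample_example_qad_diffusers.py train \
  --config ltx2_qad.yaml \
  --calib-size 512 \
  --kd-loss-weight 0.5 \
  --exclude-blocks 0 1 46 47
```
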
Checkpoints are saved under `output_dir` (e.g. `outputs/ltx2_qad/checkpoints/`) as safetensors, plus optional amax and ModelOpt state files.

### 4. Create inference checkpoint (ComfyUI-compatible)

**ComfyUI** is a node-based interface for running diffusion models (Stable Diffusion, LTX, etc.). You load your exported checkpoint in ComfyUI to generate images or videos from prompts and workflows.

- **ComfyUI**: [github.com/comfyanonymous/ComfyUI](https://github.com/comfyanonymous/ComfyUI)
- **ComfyUI documentation**: [comfyanonymous.github.io/ComfyUI_examples](https://comfyanonymous.github.io/ComfyUI_examples/) (examples and node docs)

To build a single inference checkpoint compatible with ComfyUI, use the PTQ checkpoint merger:

```bash
python -m ltx2.tools.ptq.checkpoint_merger \
  --artefact /path/to/amax_artifact.json \
  --checkpoint /path/to/ltx2_qad_bf16.safetensors \
  --config /path/to/config.yaml \
  --output /path/to/comfyui_checkpoints/nvfp4_qad_inference.safetensors
```

- **`--artefact`** – Path to the amax artifact JSON (from calibration / QAD training).
- **`--checkpoint`** – Path to the trained QAD weights (e.g. `ltx2_qad_bf16.safetensors` from your run).
- **`--config`** – Path to the merger config YAML.
- **`--output`** – Output path for the ComfyUI-ready `.safetensors` file.

This produces a single `.safetensors` file you can load in ComfyUI.

## How it works

1. **Model load** – The base transformer is loaded via `ltx_trainer.model_loader.load_transformer`.
2. **PTQ calibration** – ModelOpt `mtq.quantize` runs a calibration loop using the LTX dataset and training strategy; the NVFP4 config excludes sensitive layers and, optionally, specific blocks.
3. **Distillation** – A full-precision teacher (same checkpoint) is loaded and the quantized model is wrapped with ModelOpt `mtd.convert` (KD loss).
4. **Training** – Standard LTX training loop with an overridden `_training_step` that adds KD loss via ModelOpt's loss balancer.
5. **Checkpoint save** – Checkpoints are filtered (no teacher/loss/quantizer state), optionally dtype-matched to the base model, and saved as safetensors; amax and ModelOpt state can be saved separately.

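A condensed sketch of steps 2 and 3 using the ModelOpt APIs, with a toy `torch.nn` model standing in for the LTX transformer (the criterion and loss balancer below are stock ModelOpt choices shown for illustration; the actual script wires these into the LTX training strategy):

```python
import torch
import modelopt.torch.quantization as mtq
import modelopt.torch.distill as mtd

def make_model() -> torch.nn.Module:
    # Toy stand-in for the LTX-2 transformer.
    return torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64))

student, teacher = make_model(), make_model()  # the real script loads the same checkpoint twice
calibration_batches = [torch.randn(4, 64) for _ in range(8)]  # stand-in for LTX latents

def calib_forward_loop(model):
    # PTQ calibration: run batches so the quantizers can collect amax statistics.
    with torch.no_grad():
        for batch in calibration_batches:
            model(batch)

# Step 2: NVFP4 PTQ calibration (NVFP4_DEFAULT_CFG is ModelOpt's stock NVFP4 recipe).
student = mtq.quantize(student, mtq.NVFP4_DEFAULT_CFG, calib_forward_loop)

# Step 3: wrap the quantized student for knowledge distillation against the teacher.
kd_config = {
    "teacher_model": teacher,
    "criterion": mtd.LogitsDistillationLoss(),  # assumption: the script likely uses a diffusion-specific loss
    "loss_balancer": mtd.StaticLossBalancer(),  # combines task and KD losses, as in step 4
}
student = mtd.convert(student, mode=[("kd_loss", kd_config)])
```
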
## References

- LTX-2: [Lightricks/LTX-2](https://github.com/Lightricks/LTX-2)
- NVIDIA ModelOpt: [NVIDIA/Model-Optimizer](https://github.com/NVIDIA/Model-Optimizer)
- Full-stage QAD / distillation trainer: [distillation_trainer.py](https://github.com/NVIDIA/Model-Optimizer/blob/ca1f9687bd741a0c73791c093692eff0f95d2d46/examples/diffusers/distillation/distillation_trainer.py) in Model-Optimizer (`examples/diffusers/distillation`)

fsdp_custom.yaml
@@ -0,0 +1,29 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'yes'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_reshard_after_forward: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BasicAVTransformerBlock
  fsdp_use_orig_params: true
  fsdp_version: 1
machine_rank: 0
main_training_function: main
mixed_precision: "no"
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

ltx2_qad.yaml
@@ -0,0 +1,78 @@
# LTX-2 QAD Training Configuration
model:
  model_path: "/path/to/ltx2/checkpoint.safetensors"  # TODO: set your LTX-2 checkpoint path
  training_mode: "full"
  load_checkpoint:
  text_encoder_path: "/path/to/gemma"  # TODO: set your Gemma text encoder path

training_strategy:
  name: "text_to_video"
  first_frame_conditioning_p: 0.1

optimization:
  learning_rate: 1e-6  # Low LR for QAD (distillation)
  steps: 300
  batch_size: 1
  gradient_accumulation_steps: 4
  max_grad_norm: 1.0
  optimizer_type: "adamw"
  scheduler_type: "linear"
  scheduler_params: {}
  enable_gradient_checkpointing: true

acceleration:
  mixed_precision_mode: "bf16"
  quantization:  # We use ModelOpt, not LTX quantization
  load_text_encoder_in_8bit: true

data:
  preprocessed_data_root: "/path/to/preprocessed"  # TODO: set your preprocessed dataset path
  num_dataloader_workers: 2

validation:
  prompts:
    - "a professional portrait video of a person with blurry bokeh background"
    - "a video of a person wearing a nice suit"
  negative_prompt: "worst quality, inconsistent motion, blurry, jittery, distorted"
  images:  # Set to a list of image paths to use first-frame conditioning, or null to disable
  video_dims: [768, 448, 89]  # [width, height, frames]
  seed: 42
  inference_steps: 50
  interval:  # Set to null to disable validation
  videos_per_prompt: 1
  guidance_scale: 3.5

checkpoints:
  interval: 100  # Save a checkpoint every N steps, set to null to disable (matches the README default)
  keep_last_n: -1  # Keep only the N most recent checkpoints, set to -1 to keep all

# Flow matching configuration
flow_matching:
  timestep_sampling_mode: "shifted_logit_normal"  # Options: "uniform", "shifted_logit_normal"
  timestep_sampling_params: {}

# HuggingFace Hub configuration
hub:
  push_to_hub: false  # Whether to push the model weights to the Hugging Face Hub
  hub_model_id:  # Hugging Face Hub repository ID (e.g., 'username/repo-name'). Must be provided if `push_to_hub` is set to True

# W&B configuration
wandb:
  enabled: false  # Set to true to enable W&B logging
  project: "ltxv-trainer"
  entity:  # Your W&B username or team
  tags: []
  log_validation_videos: true

# QAD-specific configuration (not part of LtxvTrainerConfig)
# These can also be overridden via CLI: --calib-size, --kd-loss-weight, --exclude-blocks
qad:
  calib_size: 512     # Matches the README default; more batches = better amax estimates
  kd_loss_weight: 0.5 # Matches the README default (0 = task loss only, 1 = distillation only)
  exclude_blocks: [0, 1, 46, 47]
  skip_inference_ckpt: false

# General configuration
seed: 42
output_dir: "outputs/ltx2_qad"

requirements.txt
@@ -0,0 +1,12 @@
accelerate
# LTX-2 packages (from Lightricks/LTX-2 monorepo)
ltx-core @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-core
ltx-pipelines @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-pipelines
ltx-trainer @ git+https://github.com/Lightricks/LTX-2.git#subdirectory=packages/ltx-trainer

# NVIDIA ModelOpt (quantization & distillation)
nvidia-modelopt
pyyaml
safetensors

torch>=2.0