diff --git a/.github/workflows/build-wheel.yml b/.github/workflows/build-wheel.yml
new file mode 100644
index 000000000..f6c2d1eb2
--- /dev/null
+++ b/.github/workflows/build-wheel.yml
@@ -0,0 +1,67 @@
+name: Build Wheel
+
+on:
+ push:
+ branches:
+ - main
+ paths-ignore:
+ - 'docs/**'
+ - '.assets/**'
+ - '**.md'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ pull_request:
+ branches:
+ - main
+ paths-ignore:
+ - 'docs/**'
+ - '.assets/**'
+ - '**.md'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.12"]
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ fetch-tags: true
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+
+ - name: Install build
+ run: |
+ pip install -U pip && pip install build
+
+ - name: Check build configuration
+ run: |
+ if [ ! -f "pyproject.toml" ]; then
+ echo "Error: pyproject.toml not found! Cannot build wheel."
+ exit 1
+ fi
+ echo "✅ Build configuration file (pyproject.toml) found."
+
+ - name: Build Wheel
+ run: |
+ python -m build --wheel --outdir dist
+ ls -la dist/
+
+ - name: Upload Wheel Artifact
+ uses: actions/upload-artifact@v4
+ with:
+ name: cache-dist-wheel-py${{ matrix.python-version }}
+ path: dist/*.whl
+ retention-days: 15
diff --git a/.github/workflows/check-mkdocs.yml b/.github/workflows/check-mkdocs.yml
new file mode 100644
index 000000000..210aa7b33
--- /dev/null
+++ b/.github/workflows/check-mkdocs.yml
@@ -0,0 +1,70 @@
+name: Check MkDocs Build
+
+on:
+ push:
+ branches:
+ - main
+ paths-ignore:
+ - '.github/**'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ pull_request:
+ branches:
+ - main
+ paths-ignore:
+ - '.github/**'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ workflow_dispatch:
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.12"]
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+ fetch-tags: true
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+
+ - name: Check build configuration
+ run: |
+ if [ ! -f "pyproject.toml" ]; then
+ echo "Error: pyproject.toml not found! Cannot build docs."
+ exit 1
+ fi
+ echo "✅ Build configuration file (pyproject.toml) found."
+
+ - name: Install Only MkDocs Dependencies
+ run: |
+        pip install "mkdocs>=1.5.0" mkdocs-api-autonav mkdocs-material \
+ mkdocstrings-python mkdocs-gen-files mkdocs-awesome-nav \
+ mkdocs-glightbox mkdocs-git-revision-date-localized-plugin \
+ mkdocs-minify-plugin regex ruff pydantic
+
+    # Build MkDocs documentation in strict mode and fail the step if the
+    # log contains 'WARNING', 'abort', or 'ERROR', or does not contain
+    # 'Documentation built'.
+ - name: Build MkDocs Documentation Strictly
+ run: |
+ mkdocs build --strict 2>&1 | tee build.log
+        if grep -iE 'WARNING|abort|ERROR' build.log; then
+ echo "MkDocs build failed due to warnings or errors."
+ exit 1
+ elif ! grep -q 'Documentation built' build.log; then
+ echo "MkDocs build did not complete successfully."
+ exit 1
+ else
+ echo "MkDocs build completed successfully!"
+ fi
diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml
new file mode 100644
index 000000000..edb008794
--- /dev/null
+++ b/.github/workflows/cpu-tests.yml
@@ -0,0 +1,105 @@
+name: Run CPU Tests
+
+on:
+ push:
+ branches:
+ - main
+ paths-ignore:
+ - 'docs/**'
+ - '.assets/**'
+ - '**.md'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ pull_request:
+ branches:
+ - main
+ paths-ignore:
+ - 'docs/**'
+ - '.assets/**'
+ - '**.md'
+ - '.gitignore'
+ - '.gitattributes'
+ - 'LICENSE'
+ workflow_run:
+ workflows: ["Build Wheel"]
+ branches: [main]
+ types: [completed]
+ workflow_dispatch:
+
+jobs:
+ Basic_CPU_Tests:
+ if: >-
+ (github.event_name == 'workflow_run' && github.event.workflow_run.conclusion == 'success') ||
+ github.event_name == 'pull_request' ||
+ github.event_name == 'push' ||
+ github.event_name == 'workflow_dispatch'
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: ["3.12"]
+
+ steps:
+ - name: Checkout code (for test files)
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: 'pip'
+
+ - name: Download built wheel artifact (only for workflow_run)
+ if: github.event_name == 'workflow_run'
+ uses: dawidd6/action-download-artifact@v6
+ with:
+        workflow: build-wheel.yml
+ run_id: ${{ github.event.workflow_run.id }}
+ name: cache-dist-wheel-py${{ matrix.python-version }}
+ path: dist/
+
+ - name: Install dependencies
+ run: |
+ pip install -U pip && pip install torch==2.9.1 torchvision --index-url https://download.pytorch.org/whl/cpu
+        if ls dist/*.whl >/dev/null 2>&1; then
+ pip install dist/*.whl
+ else
+ pip install -e .
+ fi
+ pip install pytest
+
+ - name: Run Forward Pattern CPU Tests (pytest)
+ run: |
+ cd tests
+ pytest api/test_forward_pattern.py -v -s -x
+
+ - name: Run Refresh Context CPU Tests (pytest)
+ run: |
+ cd tests
+ pytest api/test_refresh_context.py -v -s -x
+
+ - name: Run TaylorSeers CPU Tests (pytest)
+ run: |
+ cd tests
+ pytest api/test_taylorseers.py -v -s -x
+
+ - name: Run Load Configs CPU Tests (pytest)
+ run: |
+ cd tests
+ pytest api/test_load_configs.py -v -s -x
+
+ - name: Verify cache_dit installation
+ run: |
+ python -c "import cache_dit; print(f'✅ cache_dit imported successfully, version: {cache_dit.__version__}')"
+ continue-on-error: false
+
+ - name: Upload test logs (optional)
+ uses: actions/upload-artifact@v4
+ with:
+ name: test-logs-py${{ matrix.python-version }}
+ path: tests/basic-cpu-tests-results.txt
+ retention-days: 7
+ if: always()
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
new file mode 100644
index 000000000..d887378cd
--- /dev/null
+++ b/.github/workflows/pre-commit.yml
@@ -0,0 +1,26 @@
+name: pre-commit
+
+on:
+ pull_request:
+ branches: [main]
+ push:
+ branches: [main]
+ workflow_dispatch: # for manual trigger
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: ${{ github.event_name == 'pull_request' }}
+
+permissions:
+ contents: read
+
+jobs:
+ pre-commit:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.12"
+ cache: 'pip'
+ - uses: pre-commit/action@v3.0.1
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bf3cb9c2a..736b90b47 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -5,7 +5,7 @@ repos:
- id: check-docstring-first
- id: check-toml
- id: check-yaml
- exclude: packaging/.*
+ exclude: ^(packaging/.*|mkdocs\.yml)$
args:
- --allow-multiple-documents
- id: mixed-line-ending
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
new file mode 100644
index 000000000..5a06c6634
--- /dev/null
+++ b/.readthedocs.yaml
@@ -0,0 +1,24 @@
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+version: 2
+
+build:
+ os: ubuntu-22.04
+ tools:
+ python: "3.12"
+ jobs:
+ post_checkout:
+ - git fetch --unshallow || true
+
+mkdocs:
+ configuration: mkdocs.yml
+ fail_on_warning: true
+
+# Optionally declare the Python requirements required to build your docs
+python:
+ install:
+ - method: pip
+ path: .
+ extra_requirements:
+ - docs
diff --git a/LICENSE b/LICENSE
index 6871e49a5..e27f99d89 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,4 +1,4 @@
-Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved
+Copyright (c) 2025 Cache-DiT Authors. All Rights Reserved
Apache License
Version 2.0, January 2004
diff --git a/README.md b/README.md
index b4831506b..c9e15b111 100644
--- a/README.md
+++ b/README.md
@@ -1,304 +1,90 @@
-
+
A PyTorch-native and Flexible Inference Engine with Hybrid Cache Acceleration and Parallelism for 🤗DiTs
-
+
-
+
-|Baseline|SCM S S*|SCM F D*|SCM U D*|+TS|+compile|+FP8*|
+|Baseline|SCM Slow|SCM Fast|SCM Ultra|+compile|+FP8*|+CP2|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.4s|11.4s|8.2s|8.2s|**🎉7.1s**|**🎉4.5s**|
-|
|
|
|
|
|
|
|
+|24.85s|15.4s|11.4s|8.2s|**🎉7.1s**|**🎉4.5s**|**🎉2.9s**|
+|
|
|
|
|
|
|
|
-
- Scheme: DBCache + SCM(steps_computation_mask) + TS(TaylorSeer) + FP8* , L20x1, S*: static cache, D*: dynamic cache , S : Slow, F : Fast, U : Ultra Fast, TS : TaylorSeer, FP8* : FP8 DQ + Sage, FLUX.1 -Dev
-
+
-
+**🤗Why Cache-DiT❓❓** Cache-DiT is built on top of the Diffusers library and now supports nearly **[🔥ALL](https://cache-dit.readthedocs.io/en/latest/)** DiTs from Diffusers, including over **[🤗70+](https://github.com/vipshop/cache-dit)** DiTs. Please refer to our online documentation at [readthedocs.io](https://cache-dit.readthedocs.io/en/latest/) for more details. The optimizations made by Cache-DiT include: (**UAA**: [Ulysses Anything Attention](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention))
-
- U*: Ulysses Attention, UAA: Ulysses Anything Attenton , UAA*: UAA + Gloo, Device: NVIDIA L20
- FLUX.1-Dev w/o CPU Offload, 28 steps; Qwen-Image w/ CPU Offload, 50 steps; Gloo: Extra All Gather w/ Gloo
-
+- 🎉**Hybrid Cache Acceleration** ([**DBCache**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#dbcache-dual-block-cache), DBPrune, [**TaylorSeer**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#hybrid-taylorseer-calibrator), [**SCM**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#scm-steps-computation-masking) and more)
+- 🎉**Context Parallelism** (w/ Extended Diffusers' CP APIs, [**UAA**](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), Async Ulysses, FP8 comm)
+- 🎉**Tensor Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- 🎉**Text Encoder Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- 🎉**Auto Encoder (VAE) Parallelism** (w/ Data or Tile Parallelism, avoid OOM)
+- 🎉**ControlNet Parallelism** (w/ Context Parallelism for ControlNet module)
+- 🎉Built-in **HTTP serving** deployment support with simple REST APIs
+- 🎉**Natively** compatible with **Compile**, **Offloading**, **Quantization**, ...
+- 🎉Integration into **vLLM-Omni**, **SGLang Diffusion**, SD.Next, ...
+- 🎉**Natively** supports **NVIDIA GPUs**, [**Ascend NPUs**](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) (>= 1.2.0), ...
-|CP2 U* |CP2 UAA* | L20x1 | CP2 UAA* | CP2 U* | L20x1 | CP2 UAA* |
-|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|FLUX, 13.87s|**🎉13.88s**|23.25s| **🎉13.75s**|Qwen, 132s|181s|**🎉133s**|
-| | | | | | | |
-|1024x1024|1024x1024|1008x1008|1008x1008|1312x1312|1328x1328|1328x1328|
-|✔️U* ✔️UAA|✔️U* ✔️UAA| NO CP|❌U* ✔️UAA|✔️U* ✔️UAA| NO CP|❌U* ✔️UAA|
+## 🔥Latest News
-
+- [2026/01] **[🎉v1.2.0 Major Release](https://github.com/vipshop/cache-dit)** is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything Attention](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) Support.
-## 🔥Hightlight
-
-We are excited to announce that the 🎉[**v1.1.0**](https://github.com/vipshop/cache-dit/releases/tag/v1.1.0) version of cache-dit has finally been released! It brings **[🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism)** and **[🔥Tensor Parallelism](./docs/User_Guide.md#️hybrid-tensor-parallelism)** to cache-dit, thus making it a **[PyTorch-native](./)** and **[Flexible](./)** Inference Engine for 🤗DiTs. Key features: **Unified Cache APIs**, **Forward Pattern Matching**, **Block Adapter**, **DBCache**, **DBPrune**, **Cache CFG**, **TaylorSeer**, **[SCM](./docs/User_Guide.md#scm-steps-computation-masking)**, **Context Parallelism (w/ [UAA](./docs/User_Guide.md#uaa-ulysses-anything-attention))**, **Tensor Parallelism** and **🎉SOTA** performance.
+## 🚀Quick Start
+You can install the cache-dit from PyPI or from source:
```bash
-pip3 install -U cache-dit # Also, pip3 install git+https://github.com/huggingface/diffusers.git (latest)
+pip3 install -U cache-dit # or, pip3 install git+https://github.com/vipshop/cache-dit.git
```
-You can install the stable release of cache-dit from PyPI, or the latest development version from GitHub. Then try ♥️ Cache Acceleration with just **one line** of code ~ ♥️
+Then try ♥️ Cache Acceleration with just **one line** of code ~ ♥️
```python
>>> import cache_dit
>>> from diffusers import DiffusionPipeline
->>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
->>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
+>>> # The pipe can be any diffusion pipeline.
+>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+>>> # Cache Acceleration with One-line code.
+>>> cache_dit.enable_cache(pipe)
+>>> # Or, Hybrid Cache Acceleration + Parallelism.
+>>> from cache_dit import DBCacheConfig, ParallelismConfig
+>>> cache_dit.enable_cache(
+... pipe, cache_config=DBCacheConfig(),
+... parallelism_config=ParallelismConfig(ulysses_size=2)
+... )
+>>> from cache_dit import load_configs
+>>> # Or, Load Acceleration config from a custom yaml file.
+>>> cache_dit.enable_cache(pipe, **load_configs("config.yaml"))
>>> output = pipe(...) # Just call the pipe as normal.
->>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
->>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
```
+Please refer to our online documentation at [readthedocs.io](https://cache-dit.readthedocs.io/en/latest/) for more details.
-### 📚Core Features
-
-- **[🎉Full 🤗Diffusers Support](./docs/User_Guide.md#supported-pipelines)**: Notably, **[cache-dit](https://github.com/vipshop/cache-dit)** now supports nearly **all** of Diffusers' **DiT-based** pipelines, include **[30+](./examples/pipeline/)** series, nearly **[100+](./examples/pipeline/)** pipelines, such as FLUX.1, Qwen-Image, Qwen-Image-Lightning, Wan 2.1/2.2, HunyuanImage-2.1, HunyuanVideo, HiDream, AuraFlow, CogView3Plus, CogView4, CogVideoX, LTXVideo, ConsisID, SkyReelsV2, VisualCloze, PixArt, Chroma, Mochi, SD 3.5, DiT-XL, etc.
-- **[🎉Extremely Easy to Use](./docs/User_Guide.md#unified-cache-apis)**: In most cases, you only need **one line** of code: `cache_dit.enable_cache(...)`. After calling this API, just use the pipeline as normal.
-- **[🎉Easy New Model Integration](./docs/User_Guide.md#automatic-block-adapter)**: Features like **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, and **Patch Functor** make it highly functional and flexible. For example, we achieved 🎉 Day 1 support for [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with 1.7x speedup w/o precision loss—even before it was available in the Diffusers library.
-- **[🎉State-of-the-Art Performance](./bench/)**: Compared with algorithms including Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa, cache-dit achieved the **SOTA** performance w/ **7.4x↑🎉** speedup on ClipScore!
-- **[🎉Support for 4/8-Steps Distilled Models](./bench/)**: Surprisingly, cache-dit's **DBCache** works for extremely few-step distilled models—something many other methods fail to do.
-- **[🎉Compatibility with Other Optimizations](./docs/User_Guide.md#️torch-compile)**: Designed to work seamlessly with torch.compile, Quantization ([torchao](./examples/quantize/), [🔥nunchaku](./examples/quantize/)), CPU or Sequential Offloading, **[🔥Context Parallelism](./docs/User_Guide.md/#️hybrid-context-parallelism)**, **[🔥Tensor Parallelism](./docs/User_Guide.md#️hybrid-tensor-parallelism)**, etc.
-- **[🎉Hybrid Cache Acceleration](./docs/User_Guide.md#taylorseer-calibrator)**: Now supports hybrid **Block-wise Cache + Calibrator** schemes (e.g., DBCache or DBPrune + TaylorSeerCalibrator). DBCache or DBPrune acts as the **Indicator** to decide *when* to cache, while the Calibrator decides *how* to cache. More mainstream cache acceleration algorithms (e.g., FoCa) will be supported in the future, along with additional benchmarks—stay tuned for updates!
-- **[🤗Diffusers Ecosystem Integration](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**: 🔥**cache-dit** has joined the Diffusers community ecosystem as the **first** DiT-specific cache acceleration framework! Check out the documentation here:
-
-
-
-The comparison between **cache-dit** and other algorithms shows that within a speedup ratio (TFLOPs) less than 🎉**4x**, cache-dit achieved the **SOTA** performance. Please refer to [📚Benchmarks](https://github.com/vipshop/cache-dit/tree/main/bench/) for more details.
-
-
-
-| Method | TFLOPs(↓) | SpeedUp(↑) | ImageReward(↑) | Clip Score(↑) |
-| --- | --- | --- | --- | --- |
-| [**FLUX.1**-dev]: 50 steps | 3726.87 | 1.00× | 0.9898 | 32.404 |
-| Chipmunk | 1505.87 | 2.47× | 0.9936 | 32.776 |
-| FORA(N=3) | 1320.07 | 2.82× | 0.9776 | 32.266 |
-| **[DBCache(S)](https://github.com/vipshop/cache-dit)** | 1400.08 | **2.66×** | **1.0065** | 32.838 |
-| DuCa(N=5) | 978.76 | 3.80× | 0.9955 | 32.241 |
-| TeaCache(l=0.8) | 892.35 | 4.17× | 0.8683 | 31.704 |
-| TaylorSeer(N=4,O=2) | 1042.27 | 3.57× | 0.9857 | 32.413 |
-| **[DBCache(S)+TS](https://github.com/vipshop/cache-dit)** | 1153.05 | **3.23×** | **1.0221** | 32.819 |
-| **[DBCache(M)+TS](https://github.com/vipshop/cache-dit)** | 944.75 | **3.94×** | **1.0107** | 32.865 |
-| FoCa(N=5) | 893.54 | **4.16×** | 1.0029 | **32.948** |
-| [**FLUX.1**-dev]: 22% steps | 818.29 | 4.55× | 0.8183 | 31.772 |
-| TaylorSeer(N=7,O=2) | 670.44 | 5.54× | 0.9128 | 32.128 |
-| FoCa(N=8) | 596.07 | 6.24× | 0.9502 | **32.706** |
-| **[DBCache(F)+TS](https://github.com/vipshop/cache-dit)** | 651.90 | **5.72x** | **0.9526** | 32.568 |
-| **[DBCache(U)+TS](https://github.com/vipshop/cache-dit)** | 505.47 | **7.37x** | 0.8645 | **32.719** |
-
-
-
-🎉Surprisingly, **cache-dit** still works in the **extremely few-step** distill model, such as **Qwen-Image-Lightning**, with the F16B16 config, the PSNR is 34.8 and the ImageReward is 1.26. It maintained a relatively high precision.
-
-
-| Config | PSNR(↑) | Clip Score(↑) | ImageReward(↑) | TFLOPs(↓) | SpeedUp(↑) |
-|----------------------------|-----------|------------|--------------|----------|------------|
-| **[Full 4 steps]** | INF | 35.5797 | 1.2630 | 274.33 | 1.00x |
-| F24B24 | 36.3242 | 35.6224 | 1.2630 | 264.74 | 1.04x |
-| F16B16 | 34.8163 | 35.6109 | 1.2614 | 244.25 | 1.12x |
-| F12B12 | 33.8953 | 35.6535 | 1.2549 | 234.63 | 1.17x |
-| F8B8 | 33.1374 | 35.7284 | 1.2517 | 224.29 | 1.22x |
-| F1B0 | 31.8317 | 35.6651 | 1.2397 | 206.90 | 1.33x |
-
-
-
-## 🔥Supported DiTs
+## 🚀Quick Links
-> [!Tip]
-> One **Model Series** may contain **many** pipelines. cache-dit applies optimizations at the **Transformer** level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✅: known work and official supported now; ✖️: unofficial supported now, but maybe support in the future; **[`Q`](https://github.com/nunchaku-tech/nunchaku)**: **4-bits** models w/ [nunchaku](https://github.com/nunchaku-tech/nunchaku) + SVDQ **W4A4**.
+- [📊Examples](https://github.com/vipshop/cache-dit/tree/main/examples/) - The **easiest** way to enable **hybrid cache acceleration** and **parallelism** for DiTs with cache-dit is to start with our examples for popular models: FLUX, Z-Image, Qwen-Image, Wan, etc.
+- [🌐HTTP Serving](https://cache-dit.readthedocs.io/en/latest) - Deploy cache-dit models with HTTP API for **text-to-image**, **image editing**, **multi-image editing**, and **text/image-to-video** generation.
+- [🎉User Guide](https://cache-dit.readthedocs.io/en/latest/) - For more advanced features, please refer to the [🎉User Guide](https://cache-dit.readthedocs.io/en/latest/) for details.
+- [❓FAQ](https://cache-dit.readthedocs.io/en/latest) - Frequently asked questions including attention backend configuration, troubleshooting, and optimization tips.
-
-
-| 📚Model | Cache | CP | TP | 📚Model | Cache | CP | TP |
-|:---|:---|:---|:---|:---|:---|:---|:---|
-| **🎉[FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[FLUX.1-Fill](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[FLUX.1-Fill `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen-Image `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Edit](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Edit `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Lightning](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Light `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Control..](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...E...Light `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Wan 2.1 I2V/T2V](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Mochi](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✅ |
-| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[HiDream](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Wan 2.2 I2V/T2V](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[HunyunDiT](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✅ |
-| **🎉[HunyuanVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Sana](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[ChronoEdit](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Bria](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[CogVideoX](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[SkyReelsV2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
-| **🎉[CogVideoX 1.5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Lumina 1/2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[CogView4](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[DiT-XL](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[CogView3Plus](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Allegro](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Cosmos](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[OmniGen](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Chroma-HD](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ️✅ | **🎉[EasyAnimate](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[VisualCloze](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[StableDiffusion3](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[HunyuanImage](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PRX T2I](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Kandinsky5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅️ | ✅️ | **🎉[Amused](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[LTXVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[AuraFlow](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[ConsisID](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[LongCatVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-
-
-
-
-🔥Click here to show many Image/Video cases🔥
-
-
- 🎉Now, cache-dit covers almost All Diffusers' DiT Pipelines🎉
- 🔥Qwen-Image | Qwen-Image-Edit | Qwen-Image-Edit-Plus 🔥
- 🔥FLUX.1 | Qwen-Image-Lightning 4/8 Steps | Wan 2.1 | Wan 2.2 🔥
- 🔥HunyuanImage-2.1 | HunyuanVideo | HunyuanDiT | HiDream | AuraFlow 🔥
- 🔥CogView3Plus | CogView4 | LTXVideo | CogVideoX | CogVideoX 1.5 | ConsisID 🔥
- 🔥Cosmos | SkyReelsV2 | VisualCloze | OmniGen 1/2 | Lumina 1/2 | PixArt 🔥
- 🔥Chroma | Sana | Allegro | Mochi | SD 3/3.5 | Amused | ... | DiT-XL 🔥
-
-
-
-
-
-
-
-## 📖Table of Contents
+## 🌐Community Integration
-
+- 🔥[Ascend NPU x Cache-DiT](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/)
+- 🎉[Diffusers x Cache-DiT](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)
+- 🎉[SGLang Diffusion x Cache-DiT](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/docs/cache_dit.md)
+- 🎉[vLLM-Omni x Cache-DiT](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/diffusion/cache_dit_acceleration/)
+- 🎉[Nunchaku x Cache-DiT](https://nunchaku.tech/docs/nunchaku/usage/cache.html#cache-dit)
+- 🎉[SD.Next x Cache-DiT](https://github.com/vladmandic/sdnext/blob/master/modules/cachedit.py)
+- 🎉[stable-diffusion.cpp x Cache-DiT](https://github.com/leejet/stable-diffusion.cpp/blob/master/cache_dit.hpp)
+- 🎉[jetson-containers x Cache-DiT](https://github.com/dusty-nv/jetson-containers/tree/master/packages/diffusion/cache_edit)
-For more advanced features such as **Unified Cache APIs**, **Forward Pattern Matching**, **Automatic Block Adapter**, **Hybrid Forward Pattern**, **Patch Functor**, **DBCache**, **DBPrune**, **TaylorSeer Calibrator**, **SCM**, **Hybrid Cache CFG**, **Context Parallelism (w/ UAA)** and **Tensor Parallelism**, please refer to the [🎉User_Guide.md](./docs/User_Guide.md) for details.
-
-- [⚙️Installation](./docs/User_Guide.md#️installation)
-- [🔥Supported DiTs](./docs/User_Guide.md#supported)
-- [🔥Benchmarks](./docs/User_Guide.md#benchmarks)
-- [🎉Unified Cache APIs](./docs/User_Guide.md#unified-cache-apis)
- - [📚Forward Pattern Matching](./docs/User_Guide.md#forward-pattern-matching)
- - [📚Cache with One-line Code](./docs/User_Guide.md#%EF%B8%8Fcache-acceleration-with-one-line-code)
- - [🔥Automatic Block Adapter](./docs/User_Guide.md#automatic-block-adapter)
- - [📚Hybrid Forward Pattern](./docs/User_Guide.md#hybrid-forward-pattern)
- - [📚Implement Patch Functor](./docs/User_Guide.md#implement-patch-functor)
- - [📚Transformer-Only Interface](./docs/User_Guide.md#transformer-only-interface)
- - [📚How to use ParamsModifier](./docs/User_Guide.md#how-to-use-paramsmodifier)
- - [🤖Cache Acceleration Stats](./docs/User_Guide.md#cache-acceleration-stats-summary)
-- [⚡️DBCache: Dual Block Cache](./docs/User_Guide.md#️dbcache-dual-block-cache)
-- [⚡️DBPrune: Dynamic Block Prune](./docs/User_Guide.md#️dbprune-dynamic-block-prune)
-- [⚡️Hybrid Cache CFG](./docs/User_Guide.md#️hybrid-cache-cfg)
-- [🔥Hybrid TaylorSeer Calibrator](./docs/User_Guide.md#taylorseer-calibrator)
-- [🤖SCM: Steps Computation Masking](./docs/User_Guide.md#steps-mask)
-- [⚡️Hybrid Context Parallelism](./docs/User_Guide.md#context-parallelism)
-- [🤖UAA: Ulysses Anything Attention](./docs/User_Guide.md#ulysses-anything-attention)
-- [⚡️Hybrid Tensor Parallelism](./docs/User_Guide.md#tensor-parallelism)
-- [🤖Low-bits Quantization](./docs/User_Guide.md#quantization)
-- [🤖How to use FP8 Attention](./docs/User_Guide.md#fp8-attention)
-- [🛠Metrics Command Line](./docs/User_Guide.md#metrics-cli)
-- [⚙️Torch Compile](./docs/User_Guide.md#️torch-compile)
-- [📚API Documents](./docs/User_Guide.md#api-documentation)
-
-## 👋Contribute
-
-
-How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/docs/CONTRIBUTE.md).
-
-
-
-## 🎉Projects Using CacheDiT
-
-Here is a curated list of open-source projects integrating **CacheDiT**, including popular repositories like [jetson-containers](https://github.com/dusty-nv/jetson-containers/blob/master/packages/diffusion/cache_edit/build.sh), [flux-fast](https://github.com/huggingface/flux-fast), and [sdnext](https://github.com/vladmandic/sdnext/discussions/4269). 🎉**CacheDiT** has been **recommended** by: [Wan 2.2](https://github.com/Wan-Video/Wan2.2), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [LongCat-Video](https://github.com/meituan-longcat/LongCat-Video), [Kandinsky-5](https://github.com/ai-forever/Kandinsky-5), [LeMiCa](https://github.com/UnicomAI/LeMiCa), [🤗diffusers](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit) and [HelloGitHub](https://hellogithub.com/repository/vipshop/cache-dit), among others.
## ©️Acknowledgements
-Special thanks to vipshop's Computer Vision AI Team for supporting document, testing and production-level deployment of this project. We learned the design and reused code from the following projects: [🤗diffusers](https://huggingface.co/docs/diffusers), [ParaAttention](https://github.com/chengzeyi/ParaAttention), [xDiT](https://github.com/xdit-project/xDiT), [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) and [LeMiCa](https://github.com/UnicomAI/LeMiCa).
+Special thanks to vipshop's Computer Vision AI Team for supporting documentation, testing and deployment of this project. We learned the design and reused code from the following projects: [Diffusers](https://huggingface.co/docs/diffusers), [SGLang](https://github.com/sgl-project/sglang), [vLLM-Omni](https://github.com/vllm-project/vllm-omni), [ParaAttention](https://github.com/chengzeyi/ParaAttention), [xDiT](https://github.com/xdit-project/xDiT), [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) and [LeMiCa](https://github.com/UnicomAI/LeMiCa).
+
## ©️Citations
diff --git a/assets/cache-dit-logo-v2.png b/assets/cache-dit-logo-v2.png
new file mode 100644
index 000000000..68513634d
Binary files /dev/null and b/assets/cache-dit-logo-v2.png differ
diff --git a/assets/cache-dit-logo-v2.svg b/assets/cache-dit-logo-v2.svg
new file mode 100644
index 000000000..807d4d785
--- /dev/null
+++ b/assets/cache-dit-logo-v2.svg
@@ -0,0 +1,4 @@
+
+
+
+
diff --git a/assets/npu_sample/flux.1024x1024.C0_Q0_NONE.png b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE.png
new file mode 100644
index 000000000..97a008359
Binary files /dev/null and b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE.png differ
diff --git a/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2.png b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..74e33862d
Binary files /dev/null and b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2.png differ
diff --git a/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async.png b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async.png
new file mode 100644
index 000000000..74e33862d
Binary files /dev/null and b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async.png differ
diff --git a/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png
new file mode 100644
index 000000000..74e33862d
Binary files /dev/null and b/assets/npu_sample/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png differ
diff --git a/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses2_TEP_ulysses_anything.png b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses2_TEP_ulysses_anything.png
new file mode 100644
index 000000000..20046acc5
Binary files /dev/null and b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses2_TEP_ulysses_anything.png differ
diff --git a/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything.png b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything.png
new file mode 100644
index 000000000..c57e59e60
Binary files /dev/null and b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything.png differ
diff --git a/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async.png b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async.png
new file mode 100644
index 000000000..f4df539f4
Binary files /dev/null and b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async.png differ
diff --git a/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async_native_npu.png b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async_native_npu.png
new file mode 100644
index 000000000..b704e83f6
Binary files /dev/null and b/assets/npu_sample/qwen_image_edit.1024x1024.C0_Q0_NONE_Ulysses4_TEP_ulysses_anything_ulysses_async_native_npu.png differ
diff --git a/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_native_npu.png b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_native_npu.png
new file mode 100644
index 000000000..597c197e8
Binary files /dev/null and b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_native_npu.png differ
diff --git a/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png
new file mode 100644
index 000000000..597c197e8
Binary files /dev/null and b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_native_npu.png differ
diff --git a/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_native_npu.png b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_native_npu.png
new file mode 100644
index 000000000..7391a61a3
Binary files /dev/null and b/assets/npu_sample/zimage.1024x1024.C0_Q0_NONE_native_npu.png differ
diff --git a/assets/parallelism/async_ulysses.png b/assets/parallelism/async_ulysses.png
new file mode 100644
index 000000000..932a6ee59
Binary files /dev/null and b/assets/parallelism/async_ulysses.png differ
diff --git a/assets/parallelism/async_ulysses_fp8.png b/assets/parallelism/async_ulysses_fp8.png
new file mode 100644
index 000000000..96cb91e43
Binary files /dev/null and b/assets/parallelism/async_ulysses_fp8.png differ
diff --git a/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2.png b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..9b7a3dd23
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2.png differ
diff --git a/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png
new file mode 100644
index 000000000..9b7a3dd23
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png differ
diff --git a/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_float8.png b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_float8.png
new file mode 100644
index 000000000..b3a7eef5f
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C0_Q0_NONE_Ulysses2_ulysses_float8.png differ
diff --git a/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2.png b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..79ade9718
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2.png differ
diff --git a/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png
new file mode 100644
index 000000000..79ade9718
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_async_qkv_proj.png differ
diff --git a/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_float8.png b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_float8.png
new file mode 100644
index 000000000..dfc81e6d9
Binary files /dev/null and b/assets/parallelism/flux.1024x1024.C1_Q0_NONE_Ulysses2_ulysses_float8.png differ
diff --git a/assets/steps_mask/flux.1024x1024.C1_Q1_float8_DBCache_F1B0_W8I1M0MC0_R0.35_SCM1111001000001000000100000001_dynamic_CFG0_T1O1_Ulysses2_S19_ulysses_float8_sage.png b/assets/steps_mask/flux.1024x1024.C1_Q1_float8_DBCache_F1B0_W8I1M0MC0_R0.35_SCM1111001000001000000100000001_dynamic_CFG0_T1O1_Ulysses2_S19_ulysses_float8_sage.png
new file mode 100644
index 000000000..7ab43ed10
Binary files /dev/null and b/assets/steps_mask/flux.1024x1024.C1_Q1_float8_DBCache_F1B0_W8I1M0MC0_R0.35_SCM1111001000001000000100000001_dynamic_CFG0_T1O1_Ulysses2_S19_ulysses_float8_sage.png differ
diff --git a/collect_env.py b/collect_env.py
new file mode 100644
index 000000000..45a1ca97e
--- /dev/null
+++ b/collect_env.py
@@ -0,0 +1,709 @@
+# mypy: allow-untyped-defs
+
+# Unlike the rest of the PyTorch this file must be python2 compliant.
+# This script outputs relevant system environment info
+# Run it with `python collect_env.py` or `python -m torch.utils.collect_env`
+import datetime
+import json
+import locale
+import os
+import re
+import subprocess
+import sys
+from collections import namedtuple
+
+
+try:
+ import torch
+
+ TORCH_AVAILABLE = True
+except (ImportError, NameError, AttributeError, OSError):
+ TORCH_AVAILABLE = False
+
+# System Environment Information
+SystemEnv = namedtuple(
+ "SystemEnv",
+ [
+ "torch_version",
+ "is_debug_build",
+ "cuda_compiled_version",
+ "gcc_version",
+ "clang_version",
+ "cmake_version",
+ "os",
+ "libc_version",
+ "python_version",
+ "python_platform",
+ "is_cuda_available",
+ "cuda_runtime_version",
+ "cuda_module_loading",
+ "nvidia_driver_version",
+ "nvidia_gpu_models",
+ "cudnn_version",
+ "pip_version", # 'pip' or 'pip3'
+ "pip_packages",
+ "conda_packages",
+ "hip_compiled_version",
+ "hip_runtime_version",
+ "miopen_runtime_version",
+ "caching_allocator_config",
+ "is_xnnpack_available",
+ "cpu_info",
+ ],
+)
+
+COMMON_PATTERNS = [
+ "torch",
+ "numpy",
+ "triton",
+ "optree",
+]
+
+NVIDIA_PATTERNS = [
+ "cuda-cudart",
+ "cuda-cupti",
+ "cuda-libraries",
+ "cuda-opencl",
+ "cuda-nvrtc",
+ "cuda-runtime",
+ "cublas",
+ "cudnn",
+ "cufft",
+ "curand",
+ "cusolver",
+ "cusparse",
+ "nccl",
+ "nvjitlink",
+ "nvtx",
+]
+
+CONDA_PATTERNS = [
+ "cudatoolkit",
+ "soumith",
+ "mkl",
+ "magma",
+]
+
+PIP_PATTERNS = [
+ "mypy",
+ "flake8",
+ "onnx",
+]
+
+
+def run(command):
+ """Return (return-code, stdout, stderr)."""
+ shell = True if type(command) is str else False
+ p = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=shell)
+ raw_output, raw_err = p.communicate()
+ rc = p.returncode
+ if get_platform() == "win32":
+ enc = "oem"
+ else:
+ enc = locale.getpreferredencoding()
+ output = raw_output.decode(enc)
+ err = raw_err.decode(enc)
+ return rc, output.strip(), err.strip()
+
+
+def run_and_read_all(run_lambda, command):
+ """Run command using run_lambda; reads and returns entire output if rc is 0."""
+ rc, out, _ = run_lambda(command)
+ if rc != 0:
+ return None
+ return out
+
+
+def run_and_parse_first_match(run_lambda, command, regex):
+ """Run command using run_lambda, returns the first regex match if it exists."""
+ rc, out, _ = run_lambda(command)
+ if rc != 0:
+ return None
+ match = re.search(regex, out)
+ if match is None:
+ return None
+ return match.group(1)
+
+
+def run_and_return_first_line(run_lambda, command):
+ """Run command using run_lambda and returns first line if output is not empty."""
+ rc, out, _ = run_lambda(command)
+ if rc != 0:
+ return None
+ return out.split("\n")[0]
+
+
+def get_conda_packages(run_lambda, patterns=None):
+ if patterns is None:
+ patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS
+ conda = os.environ.get("CONDA_EXE", "conda")
+ out = run_and_read_all(run_lambda, "{} list".format(conda))
+ if out is None:
+ return out
+
+ return "\n".join(
+ line
+ for line in out.splitlines()
+ if not line.startswith("#") and any(name in line for name in patterns)
+ )
+
+
+def get_gcc_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "gcc --version", r"gcc (.*)")
+
+
+def get_clang_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "clang --version", r"clang version (.*)")
+
+
+def get_cmake_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "cmake --version", r"cmake (.*)")
+
+
+def get_nvidia_driver_version(run_lambda):
+ if get_platform() == "darwin":
+ cmd = "kextstat | grep -i cuda"
+ return run_and_parse_first_match(run_lambda, cmd, r"com[.]nvidia[.]CUDA [(](.*?)[)]")
+ smi = get_nvidia_smi()
+ return run_and_parse_first_match(run_lambda, smi, r"Driver Version: (.*?) ")
+
+
+def get_gpu_info(run_lambda):
+ if get_platform() == "darwin" or (
+ TORCH_AVAILABLE and hasattr(torch.version, "hip") and torch.version.hip is not None
+ ):
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ if torch.version.hip is not None:
+ prop = torch.cuda.get_device_properties(0)
+ if hasattr(prop, "gcnArchName"):
+ gcnArch = " ({})".format(prop.gcnArchName)
+ else:
+ gcnArch = "NoGCNArchNameOnOldPyTorch"
+ else:
+ gcnArch = ""
+ return torch.cuda.get_device_name(None) + gcnArch
+ return None
+ smi = get_nvidia_smi()
+ uuid_regex = re.compile(r" \(UUID: .+?\)")
+ rc, out, _ = run_lambda(smi + " -L")
+ if rc != 0:
+ return None
+ # Anonymize GPUs by removing their UUID
+ return re.sub(uuid_regex, "", out)
+
+
+def get_running_cuda_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "nvcc --version", r"release .+ V(.*)")
+
+
+def get_cudnn_version(run_lambda):
+ """Return a list of libcudnn.so; it's hard to tell which one is being used."""
+ if get_platform() == "win32":
+ system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
+ cuda_path = os.environ.get("CUDA_PATH", "%CUDA_PATH%")
+ where_cmd = os.path.join(system_root, "System32", "where")
+ cudnn_cmd = '{} /R "{}\\bin" cudnn*.dll'.format(where_cmd, cuda_path)
+ elif get_platform() == "darwin":
+ # CUDA libraries and drivers can be found in /usr/local/cuda/. See
+ # https://docs.nvidia.com/cuda/archive/9.0/cuda-installation-guide-mac-os-x/index.html#installation
+ # https://docs.nvidia.com/deeplearning/cudnn/installation/latest/
+ # Use CUDNN_LIBRARY when cudnn library is installed elsewhere.
+ cudnn_cmd = "ls /usr/local/cuda/lib/libcudnn*"
+ else:
+ cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev'
+ rc, out, _ = run_lambda(cudnn_cmd)
+ # find will return 1 if there are permission errors or if not found
+ if len(out) == 0 or (rc != 1 and rc != 0):
+ lib = os.environ.get("CUDNN_LIBRARY")
+ if lib is not None and os.path.isfile(lib):
+ return os.path.realpath(lib)
+ return None
+ files_set = set()
+ for fn in out.split("\n"):
+ fn = os.path.realpath(fn) # eliminate symbolic links
+ if os.path.isfile(fn):
+ files_set.add(fn)
+ if not files_set:
+ return None
+ # Alphabetize the result because the order is non-deterministic otherwise
+ files = sorted(files_set)
+ if len(files) == 1:
+ return files[0]
+ result = "\n".join(files)
+ return "Probably one of the following:\n{}".format(result)
+
+
+def get_nvidia_smi():
+ # Note: nvidia-smi is currently available only on Windows and Linux
+ smi = "nvidia-smi"
+ if get_platform() == "win32":
+ system_root = os.environ.get("SYSTEMROOT", "C:\\Windows")
+ program_files_root = os.environ.get("PROGRAMFILES", "C:\\Program Files")
+ legacy_path = os.path.join(program_files_root, "NVIDIA Corporation", "NVSMI", smi)
+ new_path = os.path.join(system_root, "System32", smi)
+ smis = [new_path, legacy_path]
+ for candidate_smi in smis:
+ if os.path.exists(candidate_smi):
+ smi = '"{}"'.format(candidate_smi)
+ break
+ return smi
+
+
+# example outputs of CPU infos
+# * linux
+# Architecture: x86_64
+# CPU op-mode(s): 32-bit, 64-bit
+# Address sizes: 46 bits physical, 48 bits virtual
+# Byte Order: Little Endian
+# CPU(s): 128
+# On-line CPU(s) list: 0-127
+# Vendor ID: GenuineIntel
+# Model name: Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# CPU family: 6
+# Model: 106
+# Thread(s) per core: 2
+# Core(s) per socket: 32
+# Socket(s): 2
+# Stepping: 6
+# BogoMIPS: 5799.78
+# Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr
+# sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl
+# xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16
+# pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand
+# hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced
+# fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap
+# avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1
+# xsaves wbnoinvd ida arat avx512vbmi pku ospke avx512_vbmi2 gfni vaes vpclmulqdq
+# avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear flush_l1d arch_capabilities
+# Virtualization features:
+# Hypervisor vendor: KVM
+# Virtualization type: full
+# Caches (sum of all):
+# L1d: 3 MiB (64 instances)
+# L1i: 2 MiB (64 instances)
+# L2: 80 MiB (64 instances)
+# L3: 108 MiB (2 instances)
+# NUMA:
+# NUMA node(s): 2
+# NUMA node0 CPU(s): 0-31,64-95
+# NUMA node1 CPU(s): 32-63,96-127
+# Vulnerabilities:
+# Itlb multihit: Not affected
+# L1tf: Not affected
+# Mds: Not affected
+# Meltdown: Not affected
+# Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
+# Retbleed: Not affected
+# Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
+# Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
+# Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
+# Srbds: Not affected
+# Tsx async abort: Not affected
+# * win32
+# Architecture=9
+# CurrentClockSpeed=2900
+# DeviceID=CPU0
+# Family=179
+# L2CacheSize=40960
+# L2CacheSpeed=
+# Manufacturer=GenuineIntel
+# MaxClockSpeed=2900
+# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# ProcessorType=3
+# Revision=27142
+#
+# Architecture=9
+# CurrentClockSpeed=2900
+# DeviceID=CPU1
+# Family=179
+# L2CacheSize=40960
+# L2CacheSpeed=
+# Manufacturer=GenuineIntel
+# MaxClockSpeed=2900
+# Name=Intel(R) Xeon(R) Platinum 8375C CPU @ 2.90GHz
+# ProcessorType=3
+# Revision=27142
+
+
+def get_cpu_info(run_lambda):
+ rc, out, err = 0, "", ""
+ if get_platform() == "linux":
+ rc, out, err = run_lambda("lscpu")
+ elif get_platform() == "win32":
+ rc, out, err = run_lambda(
+ 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\
+ Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\
+ | ConvertTo-Json"'
+ )
+ if rc == 0:
+ lst = []
+ try:
+ obj = json.loads(out)
+ if type(obj) is list:
+ for o in obj:
+ lst.append("----------------------")
+ lst.extend([f"{k}: {v}" for (k, v) in o.items()])
+ else:
+ lst.extend([f"{k}: {v}" for (k, v) in obj.items()])
+ except ValueError as e:
+ lst.append(out)
+ lst.append(str(e))
+ out = "\n".join(lst)
+ elif get_platform() == "darwin":
+ rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
+ cpu_info = "None"
+ if rc == 0:
+ cpu_info = out
+ else:
+ cpu_info = err
+ return cpu_info
+
+
+def get_platform():
+ if sys.platform.startswith("linux"):
+ return "linux"
+ elif sys.platform.startswith("win32"):
+ return "win32"
+ elif sys.platform.startswith("cygwin"):
+ return "cygwin"
+ elif sys.platform.startswith("darwin"):
+ return "darwin"
+ else:
+ return sys.platform
+
+
+def get_mac_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "sw_vers -productVersion", r"(.*)")
+
+
+def get_windows_version(run_lambda):
+ ret = run_and_read_all(
+ run_lambda,
+ 'powershell.exe "gwmi -Class Win32_OperatingSystem | Select-Object -Property Caption,\
+ OSArchitecture,Version | ConvertTo-Json"',
+ )
+ try:
+ obj = json.loads(ret)
+ ret = f'{obj["Caption"]} ({obj["Version"]} {obj["OSArchitecture"]})'
+ except ValueError as e:
+ ret += f"\n{str(e)}"
+ return ret
+
+
+def get_lsb_version(run_lambda):
+ return run_and_parse_first_match(run_lambda, "lsb_release -a", r"Description:\t(.*)")
+
+
+def check_release_file(run_lambda):
+ return run_and_parse_first_match(run_lambda, "cat /etc/*-release", r'PRETTY_NAME="(.*)"')
+
+
+def get_os(run_lambda):
+ from platform import machine
+
+ platform = get_platform()
+
+ if platform == "win32" or platform == "cygwin":
+ return get_windows_version(run_lambda)
+
+ if platform == "darwin":
+ version = get_mac_version(run_lambda)
+ if version is None:
+ return None
+ return "macOS {} ({})".format(version, machine())
+
+ if platform == "linux":
+ # Ubuntu/Debian based
+ desc = get_lsb_version(run_lambda)
+ if desc is not None:
+ return "{} ({})".format(desc, machine())
+
+ # Try reading /etc/*-release
+ desc = check_release_file(run_lambda)
+ if desc is not None:
+ return "{} ({})".format(desc, machine())
+
+ return "{} ({})".format(platform, machine())
+
+ # Unknown platform
+ return platform
+
+
+def get_python_platform():
+ import platform
+
+ return platform.platform()
+
+
+def get_libc_version():
+ import platform
+
+ if get_platform() != "linux":
+ return "N/A"
+ return "-".join(platform.libc_ver())
+
+
+def get_pip_packages(run_lambda, patterns=None):
+ """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages."""
+ if patterns is None:
+ patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS
+
+ pip_version = "pip3" if sys.version_info.major == 3 else "pip"
+
+ os.environ["PIP_DISABLE_PIP_VERSION_CHECK"] = "1"
+ # People generally have pip as `pip` or `pip3`
+ # But here it is invoked as `python -mpip`
+ out = run_and_read_all(run_lambda, [sys.executable, "-mpip", "list", "--format=freeze"])
+ if out is None:
+ return pip_version, out
+
+ filtered_out = "\n".join(
+ line for line in out.splitlines() if any(name in line for name in patterns)
+ )
+
+ return pip_version, filtered_out
+
+
+def get_cachingallocator_config():
+ ca_config = os.environ.get("PYTORCH_CUDA_ALLOC_CONF", "")
+ if not ca_config:
+ ca_config = os.environ.get("PYTORCH_HIP_ALLOC_CONF", "")
+ return ca_config
+
+
+def get_cuda_module_loading_config():
+ if TORCH_AVAILABLE and torch.cuda.is_available():
+ torch.cuda.init()
+ config = os.environ.get("CUDA_MODULE_LOADING", "")
+ return config
+ else:
+ return "N/A"
+
+
+def is_xnnpack_available():
+ if TORCH_AVAILABLE:
+ import torch.backends.xnnpack
+
+ return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
+ else:
+ return "N/A"
+
+
+def get_env_info():
+ """
+ Collects environment information to aid in debugging.
+
+ The returned environment information contains details on torch version, is debug build
+ or not, cuda compiled version, gcc version, clang version, cmake version, operating
+ system, libc version, python version, python platform, CUDA availability, CUDA
+ runtime version, CUDA module loading config, GPU model and configuration, Nvidia
+ driver version, cuDNN version, pip version and versions of relevant pip and
+ conda packages, HIP runtime version, MIOpen runtime version,
+ Caching allocator config, XNNPACK availability and CPU information.
+
+ Returns:
+        SystemEnv (namedtuple): A tuple containing various environment details
+ and system information.
+ """
+ run_lambda = run
+ pip_version, pip_list_output = get_pip_packages(run_lambda)
+
+ if TORCH_AVAILABLE:
+ version_str = torch.__version__
+ debug_mode_str = str(torch.version.debug)
+ cuda_available_str = str(torch.cuda.is_available())
+ cuda_version_str = torch.version.cuda
+ if not hasattr(torch.version, "hip") or torch.version.hip is None: # cuda version
+ hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
+ else: # HIP version
+
+ def get_version_or_na(cfg, prefix):
+ _lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
+ return _lst[0] if _lst else "N/A"
+
+ cfg = torch._C._show_config().split("\n")
+ hip_runtime_version = get_version_or_na(cfg, "HIP Runtime")
+ miopen_runtime_version = get_version_or_na(cfg, "MIOpen")
+ cuda_version_str = "N/A"
+ hip_compiled_version = torch.version.hip
+ else:
+ version_str = debug_mode_str = cuda_available_str = cuda_version_str = "N/A"
+ hip_compiled_version = hip_runtime_version = miopen_runtime_version = "N/A"
+
+ sys_version = sys.version.replace("\n", " ")
+
+ conda_packages = get_conda_packages(run_lambda)
+
+ return SystemEnv(
+ torch_version=version_str,
+ is_debug_build=debug_mode_str,
+ python_version="{} ({}-bit runtime)".format(sys_version, sys.maxsize.bit_length() + 1),
+ python_platform=get_python_platform(),
+ is_cuda_available=cuda_available_str,
+ cuda_compiled_version=cuda_version_str,
+ cuda_runtime_version=get_running_cuda_version(run_lambda),
+ cuda_module_loading=get_cuda_module_loading_config(),
+ nvidia_gpu_models=get_gpu_info(run_lambda),
+ nvidia_driver_version=get_nvidia_driver_version(run_lambda),
+ cudnn_version=get_cudnn_version(run_lambda),
+ hip_compiled_version=hip_compiled_version,
+ hip_runtime_version=hip_runtime_version,
+ miopen_runtime_version=miopen_runtime_version,
+ pip_version=pip_version,
+ pip_packages=pip_list_output,
+ conda_packages=conda_packages,
+ os=get_os(run_lambda),
+ libc_version=get_libc_version(),
+ gcc_version=get_gcc_version(run_lambda),
+ clang_version=get_clang_version(run_lambda),
+ cmake_version=get_cmake_version(run_lambda),
+ caching_allocator_config=get_cachingallocator_config(),
+ is_xnnpack_available=is_xnnpack_available(),
+ cpu_info=get_cpu_info(run_lambda),
+ )
+
+
+env_info_fmt = """
+PyTorch version: {torch_version}
+Is debug build: {is_debug_build}
+CUDA used to build PyTorch: {cuda_compiled_version}
+ROCM used to build PyTorch: {hip_compiled_version}
+
+OS: {os}
+GCC version: {gcc_version}
+Clang version: {clang_version}
+CMake version: {cmake_version}
+Libc version: {libc_version}
+
+Python version: {python_version}
+Python platform: {python_platform}
+Is CUDA available: {is_cuda_available}
+CUDA runtime version: {cuda_runtime_version}
+CUDA_MODULE_LOADING set to: {cuda_module_loading}
+GPU models and configuration: {nvidia_gpu_models}
+Nvidia driver version: {nvidia_driver_version}
+cuDNN version: {cudnn_version}
+HIP runtime version: {hip_runtime_version}
+MIOpen runtime version: {miopen_runtime_version}
+Is XNNPACK available: {is_xnnpack_available}
+
+CPU:
+{cpu_info}
+
+Versions of relevant libraries:
+{pip_packages}
+{conda_packages}
+""".strip()
+
+
+def pretty_str(envinfo):
+ def replace_nones(dct, replacement="Could not collect"):
+ for key in dct.keys():
+ if dct[key] is not None:
+ continue
+ dct[key] = replacement
+ return dct
+
+ def replace_bools(dct, true="Yes", false="No"):
+ for key in dct.keys():
+ if dct[key] is True:
+ dct[key] = true
+ elif dct[key] is False:
+ dct[key] = false
+ return dct
+
+ def prepend(text, tag="[prepend]"):
+ lines = text.split("\n")
+ updated_lines = [tag + line for line in lines]
+ return "\n".join(updated_lines)
+
+ def replace_if_empty(text, replacement="No relevant packages"):
+ if text is not None and len(text) == 0:
+ return replacement
+ return text
+
+ def maybe_start_on_next_line(string):
+ # If `string` is multiline, prepend a \n to it.
+ if string is not None and len(string.split("\n")) > 1:
+ return "\n{}\n".format(string)
+ return string
+
+ mutable_dict = envinfo._asdict()
+
+ # If nvidia_gpu_models is multiline, start on the next line
+ mutable_dict["nvidia_gpu_models"] = maybe_start_on_next_line(envinfo.nvidia_gpu_models)
+
+ # If the machine doesn't have CUDA, report some fields as 'No CUDA'
+ dynamic_cuda_fields = [
+ "cuda_runtime_version",
+ "nvidia_gpu_models",
+ "nvidia_driver_version",
+ ]
+ all_cuda_fields = dynamic_cuda_fields + ["cudnn_version"]
+ all_dynamic_cuda_fields_missing = all(
+ mutable_dict[field] is None for field in dynamic_cuda_fields
+ )
+ if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
+ for field in all_cuda_fields:
+ mutable_dict[field] = "No CUDA"
+ if envinfo.cuda_compiled_version is None:
+ mutable_dict["cuda_compiled_version"] = "None"
+
+ # Replace True with Yes, False with No
+ mutable_dict = replace_bools(mutable_dict)
+
+ # Replace all None objects with 'Could not collect'
+ mutable_dict = replace_nones(mutable_dict)
+
+ # If either of these are '', replace with 'No relevant packages'
+ mutable_dict["pip_packages"] = replace_if_empty(mutable_dict["pip_packages"])
+ mutable_dict["conda_packages"] = replace_if_empty(mutable_dict["conda_packages"])
+
+ # Tag conda and pip packages with a prefix
+ # If they were previously None, they'll show up as ie '[conda] Could not collect'
+ if mutable_dict["pip_packages"]:
+ mutable_dict["pip_packages"] = prepend(
+ mutable_dict["pip_packages"], "[{}] ".format(envinfo.pip_version)
+ )
+ if mutable_dict["conda_packages"]:
+ mutable_dict["conda_packages"] = prepend(mutable_dict["conda_packages"], "[conda] ")
+ mutable_dict["cpu_info"] = envinfo.cpu_info
+ return env_info_fmt.format(**mutable_dict)
+
+
+def get_pretty_env_info():
+ """
+ Returns a pretty string of environment information.
+
+ This function retrieves environment information by calling the `get_env_info` function
+ and then formats the information into a human-readable string. The retrieved environment
+ information is listed in the document of `get_env_info`.
+ This function is used in `python collect_env.py` that should be executed when reporting a bug.
+
+ Returns:
+ str: A pretty string of the environment information.
+ """
+ return pretty_str(get_env_info())
+
+
+def main():
+ print("Collecting environment information...")
+ output = get_pretty_env_info()
+ print(output)
+
+ if TORCH_AVAILABLE and hasattr(torch, "utils") and hasattr(torch.utils, "_crash_handler"):
+ minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
+ if sys.platform == "linux" and os.path.exists(minidump_dir):
+ dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
+ latest = max(dumps, key=os.path.getctime)
+ ctime = os.path.getctime(latest)
+ creation_time = datetime.datetime.fromtimestamp(ctime).strftime("%Y-%m-%d %H:%M:%S")
+ msg = (
+ "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time)
+ + "if this is related to your bug please include it when you file a report ***"
+ )
+ print(msg, file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/docs/.gitignore b/docs/.gitignore
index 792db7346..3e41d33f6 100644
--- a/docs/.gitignore
+++ b/docs/.gitignore
@@ -164,3 +164,4 @@ _version.py
report*.html
.DS_Store
+tmp
diff --git a/docs/COMMUNITY.md b/docs/COMMUNITY.md
new file mode 100644
index 000000000..211dd3ccc
--- /dev/null
+++ b/docs/COMMUNITY.md
@@ -0,0 +1,10 @@
+# Community Integration
+
+- 🔥[Ascend NPU x Cache-DiT](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/)
+- 🎉[Diffusers x Cache-DiT](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)
+- 🎉[SGLang Diffusion x Cache-DiT](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/docs/cache_dit.md)
+- 🎉[vLLM-Omni x Cache-DiT](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/diffusion/cache_dit_acceleration/)
+- 🎉[Nunchaku x Cache-DiT](https://nunchaku.tech/docs/nunchaku/usage/cache.html#cache-dit)
+- 🎉[SD.Next x Cache-DiT](https://github.com/vladmandic/sdnext/blob/master/modules/cachedit.py)
+- 🎉[stable-diffusion.cpp x Cache-DiT](https://github.com/leejet/stable-diffusion.cpp/blob/master/cache_dit.hpp)
+- 🎉[jetson-containers x Cache-DiT](https://github.com/dusty-nv/jetson-containers/tree/master/packages/diffusion/cache_edit)
diff --git a/docs/EXAMPLES.md b/docs/EXAMPLES.md
new file mode 100644
index 000000000..5ec1d2e73
--- /dev/null
+++ b/docs/EXAMPLES.md
@@ -0,0 +1,381 @@
+# Examples for Cache-DiT
+
+|Z-Image-ControlNet| Context Parallel: Ulysses 2 | Context Parallel: Ulysses 4 | + ControlNet Parallel |
+|:---:|:---:|:---:|:---:|
+|Base L20x1: 22s|15.7s|12.7s|**🚀7.71s**|
+| | | | |
+| **+ Hybrid Cache** | **+ Torch Compile** | **+ Async Ulysses CP** | **+ FP8 All2All + CUDNN ATTN** |
+|**🚀6.85s**|6.45s|6.38s|**🚀6.19s, 5.47s**|
+| | | |
+
+
+## Installation
+
+```bash
+pip3 install torch==2.9.1 transformers accelerate torchao==0.14.1 bitsandbytes torchvision
+pip3 install opencv-python-headless einops imageio-ffmpeg ftfy
+pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0
+pip3 install git+https://github.com/vipshop/cache-dit.git # latest
+
+git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples
+```
+
+## Available Examples
+
+```bash
+python3 generate.py list # list all available examples
+
+[generate.py:47] Available examples:
+[generate.py:53] - ✅ flux_nunchaku - Defalut: nunchaku-tech/nunchaku-flux.1-dev
+[generate.py:53] - ✅ flux - Defalut: black-forest-labs/FLUX.1-dev
+[generate.py:53] - ✅ flux_fill - Defalut: black-forest-labs/FLUX.1-Fill-dev
+[generate.py:53] - ✅ flux2 - Defalut: black-forest-labs/FLUX.2-dev
+[generate.py:53] - ✅ flux2_klein_base_9b - Defalut: black-forest-labs/FLUX.2-klein-base-9B
+[generate.py:53] - ✅ flux2_klein_base_4b - Defalut: black-forest-labs/FLUX.2-klein-base-4B
+[generate.py:53] - ✅ flux2_klein_9b - Defalut: black-forest-labs/FLUX.2-klein-9B
+[generate.py:53] - ✅ flux2_klein_4b - Defalut: black-forest-labs/FLUX.2-klein-4B
+[generate.py:53] - ✅ qwen_image_lightning - Defalut: lightx2v/Qwen-Image-Lightning
+[generate.py:53] - ✅ qwen_image_2512 - Defalut: Qwen/Qwen-Image-2512
+[generate.py:53] - ✅ qwen_image - Defalut: Qwen/Qwen-Image
+[generate.py:53] - ✅ qwen_image_edit_2511_lightning - Defalut: lightx2v/Qwen-Image-Edit-2511-Lightning
+[generate.py:53] - ✅ qwen_image_edit_2511 - Defalut: Qwen/Qwen-Image-Edit-2511
+[generate.py:53] - ✅ qwen_image_edit_lightning - Defalut: lightx2v/Qwen-Image-Lightning
+[generate.py:53] - ✅ qwen_image_edit - Defalut: Qwen/Qwen-Image-Edit-2509
+[generate.py:53] - ✅ qwen_image_controlnet - Defalut: InstantX/Qwen-Image-ControlNet-Inpainting
+[generate.py:53] - ✅ qwen_image_layered - Defalut: Qwen/Qwen-Image-Layered
+[generate.py:53] - ✅ ltx2_t2v - Defalut: Lightricks/LTX-2
+[generate.py:53] - ✅ ltx2_i2v - Defalut: Lightricks/LTX-2
+[generate.py:53] - ✅ skyreels_v2 - Defalut: Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+[generate.py:53] - ✅ wan2.2_t2v - Defalut: Wan-AI/Wan2.2-T2V-A14B-Diffusers
+[generate.py:53] - ✅ wan2.1_t2v - Defalut: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+[generate.py:53] - ✅ wan2.2_i2v - Defalut: Wan-AI/Wan2.2-I2V-A14B-Diffusers
+[generate.py:53] - ✅ wan2.1_i2v - Defalut: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
+[generate.py:53] - ✅ wan2.2_vace - Defalut: linoyts/Wan2.2-VACE-Fun-14B-diffusers
+[generate.py:53] - ✅ wan2.1_vace - Defalut: Wan-AI/Wan2.1-VACE-1.3B-diffusers
+[generate.py:53] - ✅ ovis_image - Defalut: AIDC-AI/Ovis-Image-7B
+[generate.py:53] - ✅ zimage_nunchaku - Defalut: nunchaku/nunchaku-z-image-turbo
+[generate.py:53] - ✅ zimage - Defalut: Tongyi-MAI/Z-Image-Turbo
+[generate.py:53] - ✅ zimage_controlnet_2.0 - Defalut: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0
+[generate.py:53] - ✅ zimage_controlnet_2.1 - Defalut: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1
+[generate.py:53] - ✅ longcat_image - Defalut: meituan-longcat/LongCat-Image
+[generate.py:53] - ✅ longcat_image_edit - Defalut: meituan-longcat/LongCat-Image-Edit
+```
+
+## Single GPU Inference
+
+The easiest way to enable hybrid cache acceleration for DiTs with cache-dit is to start with single GPU inference. For example:
+
+```bash
+# baseline
+# use default model path, e.g, "black-forest-labs/FLUX.1-dev"
+python3 generate.py flux
+python3 generate.py flux_nunchaku # need nunchaku library
+python3 generate.py flux2
+python3 generate.py ovis_image
+python3 generate.py qwen_image_edit_lightning
+python3 generate.py qwen_image
+python3 generate.py ltx2_t2v --cache --cpu-offload
+python3 generate.py ltx2_i2v --cache --cpu-offload
+python3 generate.py skyreels_v2
+python3 generate.py wan2.2_t2v
+python3 generate.py zimage
+python3 generate.py zimage_nunchaku
+python3 generate.py zimage_controlnet_2.1
+python3 generate.py generate longcat_image
+python3 generate.py generate longcat_image_edit
+# w/ cache acceleration
+python3 generate.py flux --cache
+python3 generate.py flux --cache --taylorseer
+python3 generate.py flux_nunchaku --cache
+python3 generate.py qwen_image --cache
+python3 generate.py zimage --cache --rdt 0.6 --scm fast
+python3 generate.py zimage_controlnet_2.1 --cache --rdt 0.6 --scm fast
+# enable cpu offload or vae tiling if you encounter an OOM error
+python3 generate.py qwen_image --cache --cpu-offload
+python3 generate.py qwen_image --cache --cpu-offload --vae-tiling
+python3 generate.py qwen_image_edit_lightning --cpu-offload --steps 4
+python3 generate.py qwen_image_edit_lightning --cpu-offload --steps 8
+# or, enable sequential cpu offload for extremely low VRAM device
+python3 generate.py flux2 --sequential-cpu-offload # FLUX2 56B total
+# use `--summary` option to show the cache acceleration stats
+python3 generate.py zimage --cache --rdt 0.6 --scm fast --summary
+```
+
+## Custom Model Path
+
+The default model paths are the official model names on HuggingFace Hub. Users can set a custom local model path by setting `--model-path`. For example:
+
+```bash
+python3 generate.py flux --model-path /PATH/TO/FLUX.1-dev
+python3 generate.py zimage --model-path /PATH/TO/Z-Image-Turbo
+python3 generate.py qwen_image --model-path /PATH/TO/Qwen-Image
+```
+
+## Multi-GPU Inference
+
+cache-dit is designed to work seamlessly with CPU or Sequential Offloading, 🔥Context Parallelism, and 🔥Tensor Parallelism. For example:
+
+```bash
+# context parallelism or tensor parallelism
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses
+torchrun --nproc_per_node=4 generate.py flux --parallel ring
+torchrun --nproc_per_node=4 generate.py flux --parallel tp
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses
+# ulysses anything attention
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses --ulysses-anything
+torchrun --nproc_per_node=4 generate.py qwen_image_edit_lightning --parallel ulysses --ulysses-anything
+# text encoder parallelism, enable it by add: `--parallel-text-encoder`
+torchrun --nproc_per_node=4 generate.py flux --parallel tp --parallel-text-encoder
+torchrun --nproc_per_node=4 generate.py qwen_image_edit_lightning --parallel ulysses --ulysses-anything --parallel-text-encoder
+# Hint: set `--local-ranks-filter=0` to torchrun -> only show logs on rank 0
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_t2v --parallel ulysses --parallel-vae --parallel-text-encoder --cache
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_t2v --parallel tp --parallel-vae --parallel-text-encoder --cache
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_i2v --parallel ulysses --parallel-vae --parallel-text-encoder --cache
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_i2v --parallel tp --parallel-vae --parallel-text-encoder --cache
+```
+
+## Low-bits Quantization
+
+cache-dit is designed to work seamlessly with torch.compile and Quantization (🔥torchao, 🔥nunchaku). For example:
+
+```bash
+# please also enable torch.compile if quantization is used.
+python3 generate.py flux --cache --quantize-type float8 --compile
+python3 generate.py flux --cache --quantize-type int8 --compile
+python3 generate.py flux --cache --quantize-type float8_weight_only --compile
+python3 generate.py flux --cache --quantize-type int8_weight_only --compile
+python3 generate.py flux --cache --quantize-type bnb_4bit --compile # w4a16
+python3 generate.py flux_nunchaku --cache --compile # w4a16 SVDQ
+```
+
+## Hybrid Acceleration
+
+Here are some examples for `hybrid cache acceleration + parallelism` for popular DiTs with cache-dit.
+
+```bash
+# DBCache + SCM + Taylorseer
+python3 generate.py flux --cache --scm fast --taylorseer --taylorseer-order 1
+# DBCache + SCM + Taylorseer + Context Parallelism + Text Encoder Parallelism + Compile
+# + FP8 quantization + FP8 All2All comm + CUDNN Attention (--attn _sdpa_cudnn)
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --ulysses-float8 \
+ --attn _sdpa_cudnn --parallel-text-encoder --cache --scm fast --taylorseer \
+ --taylorseer-order 1 --quantize-type float8 --warmup 2 --repeat 5 --compile
+# DBCache + SCM + Taylorseer + Context Parallelism + Text Encoder Parallelism + Compile
+# + FP8 quantization + FP8 All2All comm + FP8 SageAttention (--attn sage)
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --ulysses-float8 \
+ --attn sage --parallel-text-encoder --cache --scm fast --taylorseer \
+ --taylorseer-order 1 --quantize-type float8 --warmup 2 --repeat 5 --compile
+# Case: Hybrid Acceleration for Qwen-Image-Edit-Lightning, tracking memory usage.
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py qwen_image_edit_lightning \
+ --parallel ulysses --ulysses-anything --parallel-text-encoder \
+ --quantize-type float8_weight_only --steps 4 --track-memory --compile
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py qwen_image_edit_lightning \
+ --parallel tp --parallel-text-encoder --quantize-type float8_weight_only \
+ --steps 4 --track-memory --compile
+# Case: Hybrid Acceleration + Context Parallelism + ControlNet Parallelism, e.g, Z-Image-ControlNet
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses \
+ --parallel-controlnet --cache --rdt 0.6 --scm fast
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses \
+ --parallel-controlnet --cache --scm fast --rdt 0.6 --compile \
+ --compile-controlnet --ulysses-float8 --attn _sdpa_cudnn \
+ --warmup 2 --repeat 4
+```
+
+## End2End Examples
+
+```bash
+# NO Cache Acceleration: 8.27s
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses
+
+INFO 12-17 09:02:31 [base.py:151] Example Input Summary:
+INFO 12-17 09:02:31 [base.py:151] - prompt: A cat holding a sign that says hello world
+INFO 12-17 09:02:31 [base.py:151] - height: 1024
+INFO 12-17 09:02:31 [base.py:151] - width: 1024
+INFO 12-17 09:02:31 [base.py:151] - num_inference_steps: 28
+INFO 12-17 09:02:31 [base.py:214] Example Output Summary:
+INFO 12-17 09:02:31 [base.py:225] - Model: flux
+INFO 12-17 09:02:31 [base.py:225] - Optimization: C0_Q0_NONE_Ulysses4
+INFO 12-17 09:02:31 [base.py:225] - Load Time: 0.79s
+INFO 12-17 09:02:31 [base.py:225] - Warmup Time: 21.09s
+INFO 12-17 09:02:31 [base.py:225] - Inference Time: 8.27s
+INFO 12-17 09:02:32 [base.py:182] Image saved to flux.1024x1024.C0_Q0_NONE_Ulysses4.png
+
+# Enabled Cache Acceleration: 4.23s
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses --cache --scm fast
+
+INFO 12-17 09:10:09 [base.py:151] Example Input Summary:
+INFO 12-17 09:10:09 [base.py:151] - prompt: A cat holding a sign that says hello world
+INFO 12-17 09:10:09 [base.py:151] - height: 1024
+INFO 12-17 09:10:09 [base.py:151] - width: 1024
+INFO 12-17 09:10:09 [base.py:151] - num_inference_steps: 28
+INFO 12-17 09:10:09 [base.py:214] Example Output Summary:
+INFO 12-17 09:10:09 [base.py:225] - Model: flux
+INFO 12-17 09:10:09 [base.py:225] - Optimization: C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15
+INFO 12-17 09:10:09 [base.py:225] - Load Time: 0.78s
+INFO 12-17 09:10:09 [base.py:225] - Warmup Time: 18.49s
+INFO 12-17 09:10:09 [base.py:225] - Inference Time: 4.23s
+INFO 12-17 09:10:09 [base.py:182] Image saved to flux.1024x1024.C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15.png
+```
+
+|NO Cache Acceleration: 8.27s| w/ Cache Acceleration: 4.23s|
+|:---:|:---:|
+|||
+
+## How to Add New Example
+
+It is very easy to add a new example. Please refer to the specific implementation in [registers.py](https://github.com/vipshop/cache-dit/blob/main/examples/registers.py). For example:
+
+```python
+@ExampleRegister.register("flux")
+def flux_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import FluxPipeline
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("black-forest-labs/FLUX.1-dev"),
+ pipeline_class=FluxPipeline,
+ # `text_encoder_2` will be quantized when `--quantize-type`
+ # is set to `bnb_4bit`.
+ bnb_4bit_components=["text_encoder_2"],
+ ),
+ input_data=ExampleInputData(
+ prompt="A cat holding a sign that says hello world",
+ height=1024,
+ width=1024,
+ num_inference_steps=28,
+ ),
+ )
+
+# NOTE: DON'T forget to add `flux_example` into helpers.py
+```
+
+## More Usages about Examples
+
+```bash
+python3 generate.py --help
+
+usage: generate.py [-h] [--model-path MODEL_PATH] [--controlnet-path CONTROLNET_PATH] [--lora-path LORA_PATH] [--transformer-path TRANSFORMER_PATH] [--image-path IMAGE_PATH] [--mask-image-path MASK_IMAGE_PATH] [--prompt PROMPT]
+ [--negative-prompt NEGATIVE_PROMPT] [--num_inference_steps NUM_INFERENCE_STEPS] [--warmup WARMUP] [--repeat REPEAT] [--height HEIGHT] [--width WIDTH] [--seed SEED] [--num-frames NUM_FRAMES] [--save-path SAVE_PATH] [--cache]
+ [--cache-summary] [--Fn-compute-blocks FN_COMPUTE_BLOCKS] [--Bn-compute-blocks BN_COMPUTE_BLOCKS] [--residual-diff-threshold RESIDUAL_DIFF_THRESHOLD] [--max-warmup-steps MAX_WARMUP_STEPS] [--warmup-interval WARMUP_INTERVAL]
+ [--max-cached-steps MAX_CACHED_STEPS] [--max-continuous-cached-steps MAX_CONTINUOUS_CACHED_STEPS] [--taylorseer] [--taylorseer-order TAYLORSEER_ORDER] [--steps-mask] [--mask-policy {None,slow,s,medium,m,fast,f,ultra,u}]
+ [--quantize] [--quantize-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--quantize-text-encoder]
+ [--quantize-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--quantize-controlnet]
+ [--quantize-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--parallel-type {None,tp,ulysses,ring}] [--parallel-vae]
+ [--parallel-text-encoder] [--parallel-controlnet] [--attn {None,flash,_flash_3,native,_native_cudnn,_sdpa_cudnn,sage}] [--ulysses-anything] [--ulysses-float8] [--ulysses-async] [--cpu-offload]
+ [--sequential-cpu-offload] [--device-map-balance] [--vae-tiling] [--vae-slicing] [--compile] [--compile-repeated-blocks] [--compile-vae] [--compile-text-encoder] [--compile-controlnet] [--max-autotune] [--track-memory]
+ [--profile] [--profile-name PROFILE_NAME] [--profile-dir PROFILE_DIR] [--profile-activities {CPU,GPU,MEM} [{CPU,GPU,MEM} ...]] [--profile-with-stack] [--profile-record-shapes] [--disable-fuse-lora DISABLE_FUSE_LORA]
+ [{generate,list,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}]
+ [{None,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}]
+
+positional arguments:
+ {generate,list,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}
+ The task to perform or example name to run. Use 'list' to list all available examples, or specify an example name directly (defaults to 'generate' task).
+ {None,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}
+ Names of the examples to run. If not specified, skip running example.
+
+options:
+ -h, --help show this help message and exit
+ --model-path MODEL_PATH
+ Override model path if provided
+ --controlnet-path CONTROLNET_PATH
+ Override controlnet model path if provided
+ --lora-path LORA_PATH
+ Override lora model path if provided
+ --transformer-path TRANSFORMER_PATH
+ Override transformer model path if provided
+ --image-path IMAGE_PATH
+ Override image path if provided
+ --mask-image-path MASK_IMAGE_PATH
+ Override mask image path if provided
+ --prompt PROMPT Override default prompt if provided
+ --negative-prompt NEGATIVE_PROMPT
+ Override default negative prompt if provided
+ --num_inference_steps NUM_INFERENCE_STEPS, --steps NUM_INFERENCE_STEPS
+ Number of inference steps
+ --warmup WARMUP Number of warmup steps before measuring performance
+ --repeat REPEAT Number of times to repeat the inference for performance measurement
+ --height HEIGHT Height of the generated image
+ --width WIDTH Width of the generated image
+ --seed SEED Random seed for reproducibility
+ --num-frames NUM_FRAMES, --frames NUM_FRAMES
+ Number of frames to generate for video
+ --save-path SAVE_PATH
+ Path to save the generated output, e.g., output.png or output.mp4
+ --cache Enable Cache Acceleration
+ --cache-summary, --summary
+ Enable Cache Summary logging
+ --Fn-compute-blocks FN_COMPUTE_BLOCKS, --Fn FN_COMPUTE_BLOCKS
+ CacheDiT Fn_compute_blocks parameter
+ --Bn-compute-blocks BN_COMPUTE_BLOCKS, --Bn BN_COMPUTE_BLOCKS
+ CacheDiT Bn_compute_blocks parameter
+ --residual-diff-threshold RESIDUAL_DIFF_THRESHOLD, --rdt RESIDUAL_DIFF_THRESHOLD
+ CacheDiT residual diff threshold
+ --max-warmup-steps MAX_WARMUP_STEPS, --ws MAX_WARMUP_STEPS
+ Maximum warmup steps for CacheDiT
+ --warmup-interval WARMUP_INTERVAL, --wi WARMUP_INTERVAL
+ Warmup interval for CacheDiT
+ --max-cached-steps MAX_CACHED_STEPS, --mc MAX_CACHED_STEPS
+ Maximum cached steps for CacheDiT
+ --max-continuous-cached-steps MAX_CONTINUOUS_CACHED_STEPS, --mcc MAX_CONTINUOUS_CACHED_STEPS
+ Maximum continuous cached steps for CacheDiT
+ --taylorseer Enable TaylorSeer for CacheDiT
+ --taylorseer-order TAYLORSEER_ORDER, -order TAYLORSEER_ORDER
+ TaylorSeer order
+ --steps-mask Enable steps mask for CacheDiT
+ --mask-policy {None,slow,s,medium,m,fast,f,ultra,u}, --scm {None,slow,s,medium,m,fast,f,ultra,u}
+ Pre-defined steps computation mask policy
+ --quantize, --q Enable quantization for transformer
+ --quantize-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --quantize-text-encoder, --q-text
+ Enable quantization for text encoder
+ --quantize-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --quantize-controlnet, --q-controlnet
+ Enable quantization for text encoder
+ --quantize-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --parallel-type {None,tp,ulysses,ring}, --parallel {None,tp,ulysses,ring}
+ --parallel-vae Enable VAE parallelism if applicable.
+ --parallel-text-encoder, --parallel-text
+ Enable text encoder parallelism if applicable.
+ --parallel-controlnet
+ Enable ControlNet parallelism if applicable.
+ --attn {None,flash,_flash_3,native,_native_cudnn,_sdpa_cudnn,sage}
+ --ulysses-anything, --uaa
+ Enable Ulysses Anything Attention for context parallelism
+ --ulysses-float8, --ufp8
+ Enable Ulysses Attention/UAA Float8 for context parallelism
+ --ulysses-async, --uaqkv
+ Enabled experimental Async QKV Projection with Ulysses for context parallelism
+ --cpu-offload, --cpu-offload-model
+ Enable CPU offload for model if applicable.
+ --sequential-cpu-offload
+ Enable sequential GPU offload for model if applicable.
+ --device-map-balance, --device-map
+ Enable automatic device map balancing model if multiple GPUs are available.
+ --vae-tiling Enable VAE tiling for low memory device.
+ --vae-slicing Enable VAE slicing for low memory device.
+ --compile Enable compile for transformer
+ --compile-repeated-blocks
+ Enable compile for repeated blocks in transformer
+ --compile-vae Enable compile for VAE
+ --compile-text-encoder, --compile-text
+ Enable compile for text encoder
+ --compile-controlnet Enable compile for ControlNet
+ --max-autotune Enable max-autotune mode for torch.compile
+ --track-memory Track and report peak GPU memory usage
+ --profile Enable profiling with torch.profiler
+ --profile-name PROFILE_NAME
+ Name for the profiling session
+ --profile-dir PROFILE_DIR
+ Directory to save profiling results
+ --profile-activities {CPU,GPU,MEM} [{CPU,GPU,MEM} ...]
+ Activities to profile (CPU, GPU, MEM)
+ --profile-with-stack profile with stack for better traceability
+ --profile-record-shapes
+ profile record shapes for better analysis
+ --disable-fuse-lora DISABLE_FUSE_LORA
+ Disable fuse_lora even if lora weights are provided.
+```
diff --git a/docs/FAQ.md b/docs/FAQ.md
new file mode 100644
index 000000000..189ee1c7b
--- /dev/null
+++ b/docs/FAQ.md
@@ -0,0 +1,88 @@
+# Frequently Asked Questions (FAQ)
+
+## Installation & Dependencies
+
+### How to install Flash Attention 3 (FA3)?
+
+Flash Attention 3 provides optimized attention kernels for better performance. To install:
+
+```bash
+git clone git@github.com:Dao-AILab/flash-attention.git
+cd flash-attention/hopper
+python setup.py install
+```
+
+After installation, you need to modify the attention dispatch file:
+
+```bash
+vi /usr/local/lib/python3.12/dist-packages/diffusers/models/attention_dispatch.py
+```
+
+Find `_diffusers_flash_attn_3::_flash_attn_forward` and add `return_attn_probs=True`:
+
+```python
+return_attn_probs=True
+```
+
+### How to install Sage Attention?
+
+Sage Attention is an efficient attention implementation. To install:
+
+```bash
+git clone https://github.com/thu-ml/SageAttention.git
+cd SageAttention
+export EXT_PARALLEL=4 NVCC_APPEND_FLAGS="--threads 8" MAX_JOBS=32 # Optional
+export TORCH_CUDA_ARCH_LIST=9.0
+python setup.py install
+```
+
+## Common Issues
+
+### torch.compile errors when running examples
+
+If you encounter errors with `torch.compile` when running cache-dit examples, try the following solutions:
+
+1. **Clear the torch inductor cache:**
+ ```bash
+ rm -rf /tmp/torchinductor_root/
+ ```
+ Then retry running your example.
+
+2. **Upgrade PyTorch to the latest version:**
+ ```bash
+ pip install --upgrade torch torchvision
+ ```
+
+3. **If the issue persists:**
+ Please [open an issue](https://github.com/vipshop/cache-dit/issues) with:
+ - Your PyTorch version (`python -c "import torch; print(torch.__version__)"`)
+ - The complete error traceback
+ - Your system configuration (GPU model, CUDA version, etc.)
+
+## Performance Optimization
+
+### Which attention backend should I use?
+
+Cache-DiT supports multiple attention backends for different use cases. For a complete overview of attention backends in diffusers, see the [official documentation](https://github.com/huggingface/diffusers/blob/main/docs/source/en/optimization/attention_backends.md).
+
+Currently supported backends in cache-dit (see [`examples/utils.py#L126`](https://github.com/vipshop/cache-dit/blob/main/examples/utils.py#L126)):
+
+- **`flash`**: Flash Attention 2 - Good performance on Ampere/Ada GPUs
+- **`_flash_3`**: Flash Attention 3 - Best for Hopper architecture GPUs (H100, H200)
+- **`native`**: Native PyTorch SDPA - Default, works on all devices
+- **`_native_cudnn`**: cuDNN-based native attention
+- **`_sdpa_cudnn`**: SDPA with cuDNN (cache-dit specific, supports context parallelism with attention masks)
+- **`sage`**: Sage Attention - Good balance between performance and compatibility
+
+**Recommendation:**
+- **H100/H200**: Use `_flash_3` for best performance
+- **A100/A6000**: Use `flash` or `sage`
+- **Other GPUs**: Use `native` or `sage`
+
+
+## Other Questions
+
+For other questions or issues not covered here, please:
+1. Check the [documentation](https://cache-dit.readthedocs.io/en/latest/)
+2. Search [existing issues](https://github.com/vipshop/cache-dit/issues)
+3. [Open a new issue](https://github.com/vipshop/cache-dit/issues/new) if needed
diff --git a/docs/README.md b/docs/README.md
new file mode 100644
index 000000000..da467758f
--- /dev/null
+++ b/docs/README.md
@@ -0,0 +1,105 @@
+
+
+
+
+Star
+Watch
+Fork
+
+
+|Baseline|SCM Slow|SCM Fast|SCM Ultra|+compile|+FP8*|+CP2|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.4s|11.4s|8.2s|**🎉7.1s**|**🎉4.5s**|**🎉2.9s**|
+| | | | | | | |
+
+**🤗Why Cache-DiT❓❓**Cache-DiT is built on top of the Diffusers library and now supports nearly **[🔥ALL](https://cache-dit.readthedocs.io/en/latest/)** DiTs from Diffusers, including over **[🤗70+](https://github.com/vipshop/cache-dit)** DiTs. Please refer to our online documentation at [readthedocs.io](https://cache-dit.readthedocs.io/en/latest/) for more details. The optimizations made by Cache-DiT include: (**UAA**: [Ulysses Anything Attention](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention))
+
+- 🎉**Hybrid Cache Acceleration** ([**DBCache**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#dbcache-dual-block-cache), DBPrune, [**TaylorSeer**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#hybrid-taylorseer-calibrator), [**SCM**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#scm-steps-computation-masking) and more)
+- 🎉**Context Parallelism** (w/ Extended Diffusers' CP APIs, [**UAA**](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), Async Ulysses, FP8 comm)
+- 🎉**Tensor Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- 🎉**Text Encoder Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- 🎉**Auto Encoder (VAE) Parallelism** (w/ Data or Tile Parallelism, avoid OOM)
+- 🎉**ControlNet Parallelism** (w/ Context Parallelism for ControlNet module)
+- 🎉Built-in **HTTP serving** deployment support with simple REST APIs
+- 🎉**Natively** compatible with **Compile**, **Offloading**, **Quantization**, ...
+- 🎉Integration into **vLLM-Omni**, **SGLang Diffusion**, SD.Next, ...
+- 🎉**Natively** supports **NVIDIA GPUs**, [**Ascend NPUs**](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) (>= 1.2.0), ...
+
+## 🔥Latest News
+
+- [2026/01] **[🎉v1.2.0 Major Release](https://github.com/vipshop/cache-dit)** is ready: New Models Support(Z-Image, FLUX.2, LTX-2, etc), Request level Cache Context, HTTP Serving, [Ulysses Anything Attention](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), TE-P, VAE-P, CN-P and [Ascend NPUs](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/) Support.
+
+## 🚀Quick Start
+
+You can install the cache-dit from PyPI or from source:
+```bash
+pip3 install -U cache-dit # or, pip3 install git+https://github.com/vipshop/cache-dit.git
+```
+Then try ♥️ Cache Acceleration with just **one line** of code ~ ♥️
+```python
+>>> import cache_dit
+>>> from diffusers import DiffusionPipeline
+>>> # The pipe can be any diffusion pipeline.
+>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+>>> # Cache Acceleration with One-line code.
+>>> cache_dit.enable_cache(pipe)
+>>> # Or, Hybrid Cache Acceleration + Parallelism.
+>>> from cache_dit import DBCacheConfig, ParallelismConfig
+>>> cache_dit.enable_cache(
+... pipe, cache_config=DBCacheConfig(),
+... parallelism_config=ParallelismConfig(ulysses_size=2)
+... )
+>>> from cache_dit import load_configs
+>>> # Or, Load Acceleration config from a custom yaml file.
+>>> cache_dit.enable_cache(pipe, **load_configs("config.yaml"))
+>>> output = pipe(...) # Just call the pipe as normal.
+```
+Please refer to our online documentation at [readthedocs.io](https://cache-dit.readthedocs.io/en/latest/) for more details.
+
+## 🚀Quick Links
+
+- [📊Examples](https://github.com/vipshop/cache-dit/tree/main/examples/) - The **easiest** way to enable **hybrid cache acceleration** and **parallelism** for DiTs with cache-dit is to start with our examples for popular models: FLUX, Z-Image, Qwen-Image, Wan, etc.
+- [🌐HTTP Serving](https://cache-dit.readthedocs.io/en/latest) - Deploy cache-dit models with HTTP API for **text-to-image**, **image editing**, **multi-image editing**, and **text/image-to-video** generation.
+- [🎉User Guide](https://cache-dit.readthedocs.io/en/latest/) - For more advanced features, please refer to the [🎉User Guide](https://cache-dit.readthedocs.io/en/latest/) for details.
+- [❓FAQ](https://cache-dit.readthedocs.io/en/latest) - Frequently asked questions including attention backend configuration, troubleshooting, and optimization tips.
+
+## 🌐Community Integration
+
+- 🔥[Ascend NPU x Cache-DiT](https://cache-dit.readthedocs.io/en/latest/user_guide/ASCEND_NPU/)
+- 🎉[Diffusers x Cache-DiT](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)
+- 🎉[SGLang Diffusion x Cache-DiT](https://github.com/sgl-project/sglang/blob/main/python/sglang/multimodal_gen/docs/cache_dit.md)
+- 🎉[vLLM-Omni x Cache-DiT](https://docs.vllm.ai/projects/vllm-omni/en/latest/user_guide/diffusion/cache_dit_acceleration/)
+- 🎉[Nunchaku x Cache-DiT](https://nunchaku.tech/docs/nunchaku/usage/cache.html#cache-dit)
+- 🎉[SD.Next x Cache-DiT](https://github.com/vladmandic/sdnext/blob/master/modules/cachedit.py)
+- 🎉[stable-diffusion.cpp x Cache-DiT](https://github.com/leejet/stable-diffusion.cpp/blob/master/cache_dit.hpp)
+- 🎉[jetson-containers x Cache-DiT](https://github.com/dusty-nv/jetson-containers/tree/master/packages/diffusion/cache_edit)
+
+## ©️Acknowledgements
+
+Special thanks to vipshop's Computer Vision AI Team for supporting document, testing and deployment of this project. We learned the design and reused code from the following projects: [Diffusers](https://huggingface.co/docs/diffusers), [SGLang](https://github.com/sgl-project/sglang), [vLLM-Omni](https://github.com/vllm-project/vllm-omni), [ParaAttention](https://github.com/chengzeyi/ParaAttention), [xDiT](https://github.com/xdit-project/xDiT), [TaylorSeer](https://github.com/Shenyi-Z/TaylorSeer) and [LeMiCa](https://github.com/UnicomAI/LeMiCa).
+
+## ©️Citations
+
+
+
+```BibTeX
+@misc{cache-dit@2025,
+ title={cache-dit: A PyTorch-native and Flexible Inference Engine with Hybrid Cache Acceleration and Parallelism for DiTs.},
+ url={https://github.com/vipshop/cache-dit.git},
+ note={Open-source software available at https://github.com/vipshop/cache-dit.git},
+ author={DefTruth, vipshop.com},
+ year={2025}
+}
+```
diff --git a/docs/README_CN.md b/docs/README_CN.md
deleted file mode 100644
index 374d4c7aa..000000000
--- a/docs/README_CN.md
+++ /dev/null
@@ -1,213 +0,0 @@
-📚English | 📚中文阅读
-
-
-
-## 🔥重点
-
-我们非常兴奋地宣布,cache-dit 的**首个 API 稳定版本 (v1.0.0)**终于正式发布!
-
-**[cache-dit](https://github.com/vipshop/cache-dit)** 是一款为 🤗 Diffusers 打造的**统一化(Unified)、高灵活(Flexible)、无需训练(Training-free)** 的缓存加速框架,仅需**一行代码**即可实现缓存加速。核心特性包括**统一缓存接口(Unified Cache APIs)**、**前向模式匹配(Forward Pattern Matching)**、**自动块适配(Automatic Block Adapter)**、**混合前向模式(Hybrid Forward Pattern)**、**DBCache 机制**、**TaylorSeer 校准器(TaylorSeer Calibrator)** 及**Cache CFG**。
-
-```bash
-pip3 install -U cache-dit # pip3 install git+https://github.com/vipshop/cache-dit.git
-```
-
-您可以从 PyPI 安装 cache-dit 的稳定版本,或从 GitHub 安装最新的开发版本。然后,只需一行代码即可体验 ♥️ 缓存加速~♥️
-
-```python
->>> import cache_dit
->>> from diffusers import DiffusionPipeline
->>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
->>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
->>> output = pipe(...) # Just call the pipe as normal.
->>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
->>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
-```
-
-
-
-点击这里查看更多Image/Video加速示例
-
-
-
-
-### 📚核心特性
-
-- **全面支持 🤗 Diffusers**:值得注意的是,**[cache-dit](https://github.com/vipshop/cache-dit)** 目前已支持 Diffusers 中几乎**所有**基于 DiT(Transformer 扩散模型)的流水线,例如 Qwen-Image、FLUX.1、Qwen-Image-Lightning、Wan 2.1/2.2、HunyuanImage-2.1、HunyuanVideo、HunyuanDiT、HiDream、AuraFlow、CogView3Plus、CogView4、LTXVideo、CogVideoX/X 1.5、ConsisID、Cosmos、SkyReelsV2、VisualCloze、OmniGen 1/2、Lumina 1/2、PixArt、Chroma、Sana、Allegro、Mochi、SD 3/3.5、Amused 以及 DiT-XL 等。
-- **极致易用**:在大多数场景下,仅需**♥️ 一行 ♥️** 代码即可启用:`cache_dit.enable_cache(...)`。调用该接口后,正常使用流水线即可享受加速。
-- **轻松集成新模型**:统一缓存接口、前向模式匹配、自动块适配、混合前向模式及 Patch Functor 等特性,使其具备极强的功能性与灵活性。例如,我们实现了对 [HunyuanImage-2.1](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) 的 🎉 首日支持(Day 1 Support)——即便该模型当时尚未在 Diffusers 库中正式发布。
-- **业界领先性能**:与 Δ-DiT、Chipmunk、FORA、DuCa、TaylorSeer、FoCa 等算法相比,在加速比低于 4 倍的场景下,cache-dit 的 DBCache 机制实现了最优精度。
-- **支持 4/8 步蒸馏模型**:令人惊喜的是,cache-dit 的 DBCache 机制可适配极少量步数的蒸馏模型,而这是许多其他方法无法实现的。
-- **兼容多种优化方案**:设计上可与 torch.compile、模型 CPU 卸载、顺序 CPU 卸载、分组卸载等优化方案无缝协同。
-- **混合缓存加速**:目前已支持 **DBCache + 校准器** 混合方案(例如 DBCache + TaylorSeerCalibrator)。其中 DBCache 作为**指示器(Indicator)** 决定*何时(when)* 缓存,校准器则负责决定*如何(how)* 缓存。未来将支持更多主流缓存加速算法(如 FoCa 等)及更多基准测试,敬请期待更新!
-- **🤗 Diffusers 生态集成**:🔥 **cache-dit** 已正式加入 🤗 Diffusers 社区生态,成为**首个**针对 DiT 的缓存加速框架!查看文档:**[Diffusers 官方文档](https://huggingface.co/docs/diffusers/main/en/optimization/cache_dit)**。
-
-
-
-
-## 🎉用户指引
-
-
-
-对于更高级的功能,如**Unified Cache APIs**、**Forward Pattern Matching**、**Automatic Block Adapter**、**Hybrid Forward Pattern**、**DBCache**、**TaylorSeer Calibrator**和**Hybrid Cache CFG**,详情请参考[🎉User_Guide.md](./docs/User_Guide.md)。
-
-- [⚙️Installation](./docs/User_Guide.md#️installation)
-- [🔥Benchmarks](./docs/User_Guide.md#benchmarks)
-- [🔥Supported Pipelines](./docs/User_Guide.md#supported-pipelines)
-- [🎉Unified Cache APIs](./docs/User_Guide.md#unified-cache-apis)
- - [📚Forward Pattern Matching](./docs/User_Guide.md#forward-pattern-matching)
- - [📚Cache with One-line Code](./docs/User_Guide.md#%EF%B8%8Fcache-acceleration-with-one-line-code)
- - [🔥Automatic Block Adapter](./docs/User_Guide.md#automatic-block-adapter)
- - [📚Hybrid Forward Pattern](./docs/User_Guide.md#hybrid-forward-pattern)
- - [📚Implement Patch Functor](./docs/User_Guide.md#implement-patch-functor)
- - [🤖Cache Acceleration Stats](./docs/User_Guide.md#cache-acceleration-stats-summary)
-- [⚡️DBCache: Dual Block Cache](./docs/User_Guide.md#️dbcache-dual-block-cache)
-- [⚡️DBPrune: Dynamic Block Prune](./docs/User_Guide.md#️dbprune-dynamic-block-prune)
-- [⚡️Hybrid Cache CFG](./docs/User_Guide.md#️hybrid-cache-cfg)
-- [🔥Hybrid TaylorSeer Calibrator](./docs/User_Guide.md#taylorseer-calibrator)
-- [⚡️Hybrid Context Parallelism](./docs/User_Guide.md#context-paralleism)
-- [🛠Metrics Command Line](./docs/User_Guide.md#metrics-cli)
-- [⚙️Torch Compile](./docs/User_Guide.md#️torch-compile)
-- [📚API Documents](./docs/User_Guide.md#api-documentation)
-
-## 👋参与贡献
-
-
-
-如何贡献?点亮星标 ⭐️ 支持我们,或查看 [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/blob/main/CONTRIBUTE.md)。
-
-
-
-## ©️特别声明
-
-本项目的顺利推进与落地,离不开 唯品会-计算机视觉算法团队 的鼎力支持。特别鸣谢该团队在文档建设、功能测试及生产级应用落地等关键环节提供的专业指导与全面协助。
-
-## ©️引用我们
-
-
-
-```BibTeX
-@misc{cache-dit@2025,
- title={cache-dit: A Unified, Flexible and Training-free Cache Acceleration Framework for Diffusers.},
- url={https://github.com/vipshop/cache-dit.git},
- note={Open-source software available at https://github.com/vipshop/cache-dit.git},
- author={vipshop.com},
- year={2025}
-}
-```
diff --git a/docs/User_Guide.md b/docs/User_Guide.md
deleted file mode 100644
index ac39f6470..000000000
--- a/docs/User_Guide.md
+++ /dev/null
@@ -1,1051 +0,0 @@
-
-
-
- CacheDiT: A PyTorch-native and Flexible Inference Engine with 🤗🎉 Hybrid Cache Acceleration and Parallelism for DiTs
-
-
-
-
-
-## 📖Table of Contents
-
-
-
-- [⚙️Installation](#️installation)
-- [🔥Supported DiTs](#supported)
-- [🔥Benchmarks](#benchmarks)
-- [🎉Unified Cache APIs](#unified)
- - [📚Forward Pattern Matching](#forward-pattern-matching)
- - [📚Cache with One-line Code](#%EF%B8%8Fcache-acceleration-with-one-line-code)
- - [🔥Automatic Block Adapter](#automatic-block-adapter)
- - [📚Hybrid Forward Pattern](#automatic-block-adapter)
- - [📚Implement Patch Functor](#implement-patch-functor)
- - [📚Transformer-Only Interface](#transformer-only-interface)
- - [📚How to use ParamsModifier](#how-to-use-paramsmodifier)
- - [🤖Cache Acceleration Stats](#cache-acceleration-stats-summary)
-- [⚡️DBCache: Dual Block Cache](#dbcache)
-- [⚡️DBPrune: Dynamic Block Prune](#dbprune)
-- [⚡️Hybrid Cache CFG](#cfg)
-- [🔥Hybrid TaylorSeer Calibrator](#taylorseer)
-- [🤖SCM: Steps Computation Masking](#steps-mask)
-- [⚡️Hybrid Context Parallelism](#context-parallelism)
-- [🤖UAA: Ulysses Anything Attention](#ulysses-anything-attention)
-- [⚡️Hybrid Tensor Parallelism](#tensor-parallelism)
-- [🤖Low-bits Quantization](#quantization)
-- [🤖How to use FP8 Attention](#fp8-attention)
-- [🛠Metrics Command Line](#metrics)
-- [⚙️Torch Compile](#compile)
-- [📚API Documents](#api-docs)
-
-## ⚙️Installation
-
-
-
-You can install the stable release of `cache-dit` from PyPI:
-
-```bash
-pip3 install -U cache-dit # or, pip3 install -U "cache-dit[all]" for all features
-```
-Or you can install the latest develop version from GitHub:
-
-```bash
-pip3 install git+https://github.com/vipshop/cache-dit.git
-```
-Please also install the latest main branch of diffusers for context parallelism:
-```bash
-pip3 install git+https://github.com/huggingface/diffusers.git
-```
-
-## 🔥Supported DiTs
-
-
-
-Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Examples](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline) for more details. Here are just some of the tested models listed.
-
-```python
->>> import cache_dit
->>> cache_dit.supported_pipelines()
-(32, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
-'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
-'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
-'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*',
-'Kandinsky5*', 'PRX*'])
-```
-
-> [!Tip]
-> One **Model Series** may contain **many** pipelines. cache-dit applies optimizations at the **Transformer** level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✅: known work and official supported now; ✖️: unofficial supported now, but maybe support in the future; **[`Q`](https://github.com/nunchaku-tech/nunchaku)**: **4-bits** models w/ [nunchaku](https://github.com/nunchaku-tech/nunchaku) + SVDQ **W4A4**.
-
-
-
-| 📚Model | Cache | CP | TP | 📚Model | Cache | CP | TP |
-|:---|:---|:---|:---|:---|:---|:---|:---|
-| **🎉[FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[FLUX.1 `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[FLUX.1-Fill](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[FLUX.1-Fill `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen-Image `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Edit](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Edit `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Lightning](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...Light `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Qwen...Control..](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Qwen...E...Light `Q`](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[Wan 2.1 I2V/T2V](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Mochi](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✅ |
-| **🎉[Wan 2.1 VACE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[HiDream](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Wan 2.2 I2V/T2V](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[HunyunDiT](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✅ |
-| **🎉[HunyuanVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Sana](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[ChronoEdit](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Bria](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[CogVideoX](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[SkyReelsV2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ |
-| **🎉[CogVideoX 1.5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Lumina 1/2](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[CogView4](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[DiT-XL](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✖️ |
-| **🎉[CogView3Plus](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Allegro](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[PixArt Sigma](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[Cosmos](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[PixArt Alpha](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[OmniGen](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Chroma-HD](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ️✅ | **🎉[EasyAnimate](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[VisualCloze](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[StableDiffusion3](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[HunyuanImage](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[PRX T2I](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[Kandinsky5](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅️ | ✅️ | **🎉[Amused](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[LTXVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[AuraFlow](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-| **🎉[ConsisID](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✅ | ✅ | **🎉[LongCatVideo](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline)** | ✅ | ✖️ | ✖️ |
-
-
-
-## 🔥Benchmarks
-
-
-
-cache-dit will support more mainstream Cache acceleration algorithms in the future. More benchmarks will be released, please stay tuned for update. Here, only the results of some precision and performance benchmarks are presented. The test dataset is **DrawBench**. For a complete benchmark, please refer to [📚Benchmarks](https://github.com/vipshop/cache-dit/tree/main/bench/).
-
-### 📚Text2Image DrawBench: FLUX.1-dev
-
-Comparisons between different FnBn compute block configurations show that **more compute blocks result in higher precision**. For example, the F8B0_W8MC0 configuration achieves the best Clip Score (33.007) and ImageReward (1.0333). **Device**: NVIDIA L20. **F**: Fn_compute_blocks, **B**: Bn_compute_blocks, 50 steps.
-
-
-
-
-| Config | Clip Score(↑) | ImageReward(↑) | PSNR(↑) | TFLOPs(↓) | SpeedUp(↑) |
-| --- | --- | --- | --- | --- | --- |
-| [**FLUX.1**-dev]: 50 steps | 32.9217 | 1.0412 | INF | 3726.87 | 1.00x |
-| F8B0_W4MC0_R0.08 | 32.9871 | 1.0370 | 33.8317 | 2064.81 | 1.80x |
-| F8B0_W4MC2_R0.12 | 32.9535 | 1.0185 | 32.7346 | 1935.73 | 1.93x |
-| F8B0_W4MC3_R0.12 | 32.9234 | 1.0085 | 32.5385 | 1816.58 | 2.05x |
-| F4B0_W4MC3_R0.12 | 32.8981 | 1.0130 | 31.8031 | 1507.83 | 2.47x |
-| F4B0_W4MC4_R0.12 | 32.8384 | 1.0065 | 31.5292 | 1400.08 | 2.66x |
-
-
-
-### 📚Compare with Other Methods: Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa
-
-
-
-
-
-
-
-
-
-The comparison between **cache-dit: DBCache** and algorithms such as Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa is as follows. Now, in the comparison with a speedup ratio less than **4x**, cache-dit achieved the best accuracy. Surprisingly, cache-dit: DBCache still works in the extremely few-step distill model. For a complete benchmark, please refer to [📚Benchmarks](https://github.com/vipshop/cache-dit/raw/main/bench/). NOTE: Except for DBCache, other performance data are referenced from the paper [FoCa, arxiv.2508.16211](https://arxiv.org/pdf/2508.16211).
-
-
-
-| Method | TFLOPs(↓) | SpeedUp(↑) | ImageReward(↑) | Clip Score(↑) |
-| --- | --- | --- | --- | --- |
-| [**FLUX.1**-dev]: 50 steps | 3726.87 | 1.00× | 0.9898 | 32.404 |
-| [**FLUX.1**-dev]: 60% steps | 2231.70 | 1.67× | 0.9663 | 32.312 |
-| Δ-DiT(N=2) | 2480.01 | 1.50× | 0.9444 | 32.273 |
-| Δ-DiT(N=3) | 1686.76 | 2.21× | 0.8721 | 32.102 |
-| [**FLUX.1**-dev]: 34% steps | 1264.63 | 3.13× | 0.9453 | 32.114 |
-| Chipmunk | 1505.87 | 2.47× | 0.9936 | 32.776 |
-| FORA(N=3) | 1320.07 | 2.82× | 0.9776 | 32.266 |
-| **[DBCache(S)](https://github.com/vipshop/cache-dit)** | 1400.08 | **2.66×** | **1.0065** | 32.838 |
-| DuCa(N=5) | 978.76 | 3.80× | 0.9955 | 32.241 |
-| TaylorSeer(N=4,O=2) | 1042.27 | 3.57× | 0.9857 | 32.413 |
-| **[DBCache(S)+TS](https://github.com/vipshop/cache-dit)** | 1153.05 | **3.23×** | **1.0221** | 32.819 |
-| **[DBCache(M)](https://github.com/vipshop/cache-dit)** | 944.75 | **3.94×** | 0.9997 | 32.849 |
-| **[DBCache(M)+TS](https://github.com/vipshop/cache-dit)** | 944.75 | **3.94×** | **1.0107** | 32.865 |
-| **[FoCa(N=5): arxiv.2508.16211](https://arxiv.org/pdf/2508.16211)** | 893.54 | **4.16×** | 1.0029 | **32.948** |
-| [**FLUX.1**-dev]: 22% steps | 818.29 | 4.55× | 0.8183 | 31.772 |
-| FORA(N=7) | 670.14 | 5.55× | 0.7418 | 31.519 |
-| ToCa(N=12) | 644.70 | 5.77× | 0.7155 | 31.808 |
-| DuCa(N=10) | 606.91 | 6.13× | 0.8382 | 31.759 |
-| TeaCache(l=1.2) | 669.27 | 5.56× | 0.7394 | 31.704 |
-| TaylorSeer(N=7,O=2) | 670.44 | 5.54× | 0.9128 | 32.128 |
-| **[DBCache(F)](https://github.com/vipshop/cache-dit)** | 651.90 | **5.72x** | 0.9271 | 32.552 |
-| **[FoCa(N=8): arxiv.2508.16211](https://arxiv.org/pdf/2508.16211)** | 596.07 | 6.24× | 0.9502 | 32.706 |
-| **[DBCache(F)+TS](https://github.com/vipshop/cache-dit)** | 651.90 | **5.72x** | **0.9526** | 32.568 |
-| **[DBCache(U)+TS](https://github.com/vipshop/cache-dit)** | 505.47 | **7.37x** | 0.8645 | **32.719** |
-
-
-
-### 📚Text2Image Distillation DrawBench: Qwen-Image-Lightning
-
-Surprisingly, cache-dit: DBCache still works in the extremely few-step distill model. For example, **Qwen-Image-Lightning w/ 4 steps**, with the F16B16 configuration, the PSNR is 34.8163, the Clip Score is 35.6109, and the ImageReward is 1.2614. It maintained a relatively high precision.
-
-
-
-| Config | PSNR(↑) | Clip Score(↑) | ImageReward(↑) | TFLOPs(↓) | SpeedUp(↑) |
-|----------------------------|-----------|------------|--------------|----------|------------|
-| [**Lightning**]: 4 steps | INF | 35.5797 | 1.2630 | 274.33 | 1.00x |
-| F24B24_W2MC1_R0.8 | 36.3242 | 35.6224 | 1.2630 | 264.74 | 1.04x |
-| F16B16_W2MC1_R0.8 | 34.8163 | 35.6109 | 1.2614 | 244.25 | 1.12x |
-| F12B12_W2MC1_R0.8 | 33.8953 | 35.6535 | 1.2549 | 234.63 | 1.17x |
-| F8B8_W2MC1_R0.8 | 33.1374 | 35.7284 | 1.2517 | 224.29 | 1.22x |
-| F1B0_W2MC1_R0.8 | 31.8317 | 35.6651 | 1.2397 | 206.90 | 1.33x |
-
-
-
-## 🎉Unified Cache APIs
-
-
-
-### 📚Forward Pattern Matching
-
-Currently, for any **Diffusion** models with **Transformer Blocks** that match the specific **Input/Output patterns**, we can use the **Unified Cache APIs** from **cache-dit**, namely, the `cache_dit.enable_cache(...)` API. The **Unified Cache APIs** are currently in the experimental phase; please stay tuned for updates. The supported patterns are listed as follows:
-
-
-
-### ♥️Cache Acceleration with One-line Code
-
-In most cases, you only need to call **one-line** of code, that is `cache_dit.enable_cache(...)`. After this API is called, you just need to call the pipe as normal. The `pipe` param can be **any** Diffusion Pipeline. Please refer to [Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_qwen_image.py) as an example.
-
-```python
-import cache_dit
-from diffusers import DiffusionPipeline
-
-# Can be any diffusion pipeline
-pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
-# One-line code with default cache options.
-cache_dit.enable_cache(pipe)
-# Just call the pipe as normal.
-output = pipe(...)
-# Disable cache and run original pipe.
-cache_dit.disable_cache(pipe)
-```
-
-### 🔥Automatic Block Adapter
-
-But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve this problems. Please refer to [🔥Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
-
-```python
-from cache_dit import ForwardPattern, BlockAdapter
-
-# Use 🔥BlockAdapter with `auto` mode.
-cache_dit.enable_cache(
- BlockAdapter(
- # Any DiffusionPipeline, Qwen-Image, etc.
- pipe=pipe, auto=True,
- # Check `📚Forward Pattern Matching` documentation and hack the code of
- # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
- forward_pattern=ForwardPattern.Pattern_1,
- ),
-)
-
-# Or, manually setup transformer configurations.
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe, # Qwen-Image, etc.
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- ),
-)
-```
-For such situations, **BlockAdapter** can help you quickly apply various cache acceleration features to your own Diffusion Pipelines and Transformers.
-
-### 📚Hybrid Forward Pattern
-
-Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, **FLUX.1** (HiDream, Chroma, etc) contains transformer_blocks and single_transformer_blocks (with different forward patterns). The **BlockAdapter** can also help you solve this problem. Please refer to [📚FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
-
-```python
-# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
-# single_transformer_blocks have different forward patterns.
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe, # FLUX.1, etc.
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- ),
-)
-```
-
-Even sometimes you have more complex cases, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`) in its structure. Fortunately, **cache-dit** can also handle this situation very well. Please refer to [📚Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
-
-```python
-from cache_dit import ForwardPattern, BlockAdapter, ParamsModifier, DBCacheConfig
-
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- # Setup different cache params for each 'blocks'. You can
- # pass any specific cache params to ParamModifier, the old
- # value will be overwrite by the new one.
- params_modifiers=[
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=4,
- max_cached_steps=8,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=2,
- max_cached_steps=20,
- ),
- ),
- ],
- has_separate_cfg=True,
- ),
-)
-```
-
-### 📚Implement Patch Functor
-
-For any PATTERN not in {0...5}, we introduced the simple abstract concept of **Patch Functor**. Users can implement a subclass of Patch Functor to convert an unknown Pattern into a known PATTERN, and for some models, users may also need to fuse the operations within the blocks for loop into block forward.
-
-
-
-Some Patch functors have already been provided in cache-dit: [📚HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_chroma.py), etc. After implementing Patch Functor, users need to set the `patch_functor` property of **BlockAdapter**.
-
-```python
-@BlockAdapterRegister.register("HiDream")
-def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import HiDreamImageTransformer2DModel
- from cache_dit.caching.patch_functors import HiDreamPatchFunctor
-
- assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.double_stream_blocks,
- pipe.transformer.single_stream_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_3,
- ],
- # NOTE: Setup your custom patch functor here.
- patch_functor=HiDreamPatchFunctor(),
- **kwargs,
- )
-```
-
-### 📚Transformer-Only Interface
-
-In some cases, users may **not use Diffusers or DiffusionPipeline** at all, and may not even have the concept of a "pipeline"—for instance, **ComfyUI** (which breaks down the pipeline into individual components while still retaining transformer components). cache-dit also supports such scenarios; it only needs to be configured via **BlockAdapter**. The pipeline is not mandatory, and you can simply keep it at the default value of None. In this case, the `num_inference_steps` parameter in cache_config **must be set**, as cache-dit relies on this parameter to refresh the cache context at the appropriate time. Please refer to [📚run_transformer_only.py](https://github.com/vipshop/cache-dit/blob/main/examples/api/run_transformer_only.py) as an example.
-
-```python
-cache_dit.enable_cache(
- BlockAdapter(
- # NO `pipe` required
- transformer=transformer,
- blocks=transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- ),
- cache_config=DBCacheConfig(
- num_inference_steps=50 # required
- ),
-)
-```
-
-### 📚How to use ParamsModifier
-
-Sometimes you may encounter more complex cases, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`), or FLUX.1, which has multiple transformer blocks (namely `single_transformer_blocks` and `transformer_blocks`). cache-dit will assign separate cache contexts for different `blocks` instances but share the same `cache_config` by default. Users who want to achieve fine-grained control over different cache contexts can consider using `ParamsModifier`. Just pass the `ParamsModifier` per `blocks` to the `BlockAdapter` or `enable_cache(...)` API. Then, the shared `cache_config` will be overwritten by the new configurations from the `ParamsModifier`. For example:
-
-```python
-from cache_dit import ParamsModifier
-
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe, # FLUX.1, etc.
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- ),
- # Basic shared cache config
- cache_config=DBCacheConfig(...),
- params_modifiers=[
- ParamsModifier(
- # Modified config only for transformer_blocks
- # Must call the `reset` method of DBCacheConfig.
- cache_config=DBCacheConfig().reset(
- Fn_compute_blocks=8,
- residual_diff_threshold=0.08,
- ),
- ),
- ParamsModifier(
- # Modified config only for single_transformer_blocks
- # NOTE: FLUX.1, single_transformer_blocks should have `higher`
- # residual_diff_threshold because of the precision error
- # accumulation from previous transformer_blocks
- cache_config=DBCacheConfig().reset(
- Fn_compute_blocks=1,
- residual_diff_threshold=0.16,
- ),
- ),
- ],
-)
-```
-
-### 🤖Cache Acceleration Stats Summary
-
-After finishing each inference of `pipe(...)`, you can call the `cache_dit.summary()` API on pipe to get the details of the **Cache Acceleration Stats** for the current inference.
-```python
-stats = cache_dit.summary(pipe)
-```
-
-You can set `details` param as `True` to show more details of cache stats. (markdown table format) Sometimes, this may help you analyze what values of the residual diff threshold would be better.
-
-```python
-⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
-
-| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
-|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
-| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
-```
-
-## ⚡️DBCache: Dual Block Cache
-
-
-
-
-
-**DBCache**: **Dual Block Caching** for Diffusion Transformers. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache, enabling a balanced trade-off between performance and precision. Moreover, it can be entirely **training**-**free**. Please check [DBCache.md](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) docs for more design details.
-
-- **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
-- **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
-
-
-
-```python
-import cache_dit
-from diffusers import FluxPipeline
-
-pipe_or_adapter = FluxPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-# Default options, F8B0, 8 warmup steps, and unlimited cached
-# steps for good balance between performance and precision
-cache_dit.enable_cache(pipe_or_adapter)
-
-# Custom options, F8B8, higher precision
-from cache_dit import DBCacheConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(
- max_warmup_steps=8, # steps do not cache
- max_cached_steps=-1, # -1 means no limit
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=8, # Bn, B8, etc.
- residual_diff_threshold=0.12,
- ),
-)
-```
-
-
-
- DBCache, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"
-
-
-
-
-
-|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|
|
|
|
|
|
|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
-|27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|
|
|
|
|
|
|
-
-
-
-
-
- DBCache, L20x4 , Steps: 20, case to show the texture recovery ability of DBCache
-
-
-
-These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
-
-## ⚡️DBPrune: Dynamic Block Prune
-
-
-
-
-
-
-We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.
-
-```python
-from cache_dit import DBPruneConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBPruneConfig(
- max_warmup_steps=8, # steps do not apply prune
- residual_diff_threshold=0.12,
- enable_dynamic_prune_threshold=True,
- ),
-)
-```
-We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_block_ids** to avoid aggressive pruning. For example:
-
-```python
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBPruneConfig(
- max_warmup_steps=8, # steps do not apply prune
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=8, # Bn, B8, etc
- residual_diff_threshold=0.12,
- enable_dynamic_prune_threshold=True,
- non_prune_block_ids=list(range(16,24)),
- ),
-)
-```
-
-
- DBPrune, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"
-
-
-
-
-
-|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(38%)|Pruned(45%)|Pruned(60%)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|19.43s|16.82s|15.95s|14.24s|10.66s|
-|
|
|
|
|
|
|
-
-
-
-## ⚡️Hybrid Cache CFG
-
-
-
-cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `enable_separate_cfg` param to **False (default, None)**. Otherwise, set it to True. For examples:
-
-```python
-from cache_dit import DBCacheConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(
- ...,
- # CFG: classifier free guidance or not
- # For model that fused CFG and non-CFG into single forward step,
- # should set enable_separate_cfg as False. For example, set it as True
- # for Wan 2.1/Qwen-Image and set it as False for FLUX.1, HunyuanVideo,
- # CogVideoX, Mochi, LTXVideo, Allegro, CogView3Plus, EasyAnimate, SD3, etc.
- enable_separate_cfg=True, # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
- # Compute cfg forward first or not, default False, namely,
- # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
- cfg_compute_first=False,
- # Compute separate diff values for CFG and non-CFG step,
- # default True. If False, we will use the computed diff from
- # current non-CFG transformer step for current CFG step.
- cfg_diff_compute_separate=True,
- ),
-)
-```
-
-## 🔥Hybrid TaylorSeer Calibrator
-
-
-
-We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
-
-$$
-\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
-$$
-
-**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in cache-dit supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
-
-```python
-from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- # Basic DBCache w/ FnBn configurations
- cache_config=DBCacheConfig(
- max_warmup_steps=8, # steps do not cache
- max_cached_steps=-1, # -1 means no limit
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=8, # Bn, B8, etc.
- residual_diff_threshold=0.12,
- ),
- # Then, you can use the TaylorSeer Calibrator to approximate
- # the values in cached steps, taylorseer_order default is 1.
- calibrator_config=TaylorSeerCalibratorConfig(
- taylorseer_order=1,
- ),
-)
-```
-
-> [!Important]
-> Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn is also act as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
-
-
-
- DBCache F1B0 + TaylorSeer , L20x1, Steps: 28, "A cat holding a sign that says hello world with complex background"
-
-
-
-
-
-|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
-|
|
|
|
|
|
|
-
-
-
-## 🤖SCM: Steps Computation Masking
-
-
-
-
-The `steps_computation_mask` parameter adopts a step-wise computation masking approach inspired by [LeMiCa](https://github.com/UnicomAI/LeMiCa) and [EasyCache](https://github.com/H-EmbodVis/EasyCache). Its key insight is that **early caching induces amplified downstream errors, whereas later caching is less disruptive**, resulting in a **non-uniform** distribution of cached steps.
-
-
-
-|LeMiCa: Non-Uniform Cache Steps|LeMiCa: Cache Errors|EasyCache: Transformation rate Analysis|
-|:---:|:---:|:---:|
-|
|
|
|
-
-
-
-It is a list of length num_inference_steps indicating whether to compute each step or not. 1 means must compute, 0 means use dynamic/static cache. If provided, will override other settings to decide whether to compute each step. Please check the [📚examples/steps_mask](../examples/api/run_steps_mask.py) for more details.
-
-
-
-```python
-from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
-
-# Scheme: Hybrid DBCache + LeMiCa/EasyCache + TaylorSeer
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(
- # Basic DBCache configs
- Fn_compute_blocks=8,
- Bn_compute_blocks=0,
- # keep is the same as first compute bin
- max_warmup_steps=6,
- residual_diff_threshold=0.12,
- # LeMiCa or EasyCache style Mask for 28 steps, e.g,
- # SCM=111111010010000010000100001, 1: compute, 0: cache.
- steps_computation_mask=cache_dit.steps_mask(
- compute_bins=[6, 1, 1, 1, 1], # 10
- cache_bins=[1, 2, 5, 5, 5], # 18
- ),
- # The policy for cache steps can be 'dynamic' or 'static'
- steps_computation_policy="dynamic",
- ),
- calibrator_config=TaylorSeerCalibratorConfig(
- taylorseer_order=1,
- ),
-)
-
-```
-
-As we can observe, in the case of **static cache**, the image of `SCM Slow S*` (please click to enlarge) has shown **obvious blurriness**. However, the **Ultra** version under **dynamic cache** (`SCM Ultra D*`) still maintains excellent clarity. Therefore, we prioritize recommending the use of dynamic cache while using `SCM: steps_computation_mask`.
-
-
-
-|Baseline|SCM S S*|SCM S D*|SCM F D*|SCM U D*|+TS|+compile|+FP8 +Sage|
-|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.4s|17.1s|11.4s|8.2s|8.2s|7.1s|4.5s|
-|
|
|
|
|
|
|
|
|
-
-
- Scheme: DBCache + SCM(steps_computation_mask) + TaylorSeer , L20x1, S*: static cache, D*: dynamic cache , S : Slow, F : Fast, U : Ultra Fast, TS : TaylorSeer, FP8: FP8 DQ, Sage: SageAttention, FLUX.1-Dev , Steps: 28, HxW=1024x1024, Prompt: "A cat holding a sign that says hello world"
-
-
-|DBCache + SCM Slow S*|DBCache + SCM Ultra D* + TaylorSeer + compile|
-|:---:|:---:|
-|15.4s|7.1s|
-|
|
|
-
-
-Dynamic Caching is all you need! The Ultra fast version under dynamic cache (SCM Ultra D* ) maintains better clarity than the slower static cache one (SCM Slow S* ).
-
-
-
-
-
-## ⚡️Hybrid Context Parallelism
-
-
-
-cache-dit is compatible with context parallelism. Currently, we support the use of `Hybrid Cache` + `Context Parallelism` scheme (via NATIVE_DIFFUSER parallelism backend) in cache-dit. Users can use Context Parallelism to further accelerate the speed of inference! For more details, please refer to [📚examples/parallelism](https://github.com/vipshop/cache-dit/tree/main/examples/parallelism). Currently, cache-dit supported context parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [LTXVideo](https://huggingface.co/Lightricks/LTX-Video), [Wan 2.1](https://github.com/Wan-Video/Wan2.1), [Wan 2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo), [CogVideoX 1.0](https://github.com/zai-org/CogVideo), [CogVideoX 1.5](https://github.com/zai-org/CogVideo), [CogView 3/4](https://github.com/zai-org/CogView4) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.
-
-```python
-# pip3 install "cache-dit[parallelism]"
-from cache_dit import ParallelismConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(...),
- # Set ulysses_size > 1 to enable ulysses style context parallelism.
- parallelism_config=ParallelismConfig(ulysses_size=2),
-)
-# torchrun --nproc_per_node=2 parallel_cache.py
-```
-
-## 🤖UAA: Ulysses Anything Attention
-
-
-
-We have implemented the **[📚UAA: Ulysses Anything Attention](#uaa-ulysses-anything-attention)**: An Ulysses Attention that supports **arbitrary sequence length** with ✅**zero padding** and **nearly ✅zero theoretical communication overhead**. The default Ulysses Attention requires that the sequence len of hidden states **must be divisible by the number of devices**. This imposes **significant limitations** on the practical application of Ulysses.
-
-
-```python
-# pip3 install "cache-dit[parallelism]"
-from cache_dit import ParallelismConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(...),
- # Set `experimental_ulysses_anything` as True to enable UAA
- parallelism_config=ParallelismConfig(
- ulysses_size=2,
- parallel_kwargs={
- "experimental_ulysses_anything": True
- },
- ),
-)
-# torchrun --nproc_per_node=2 parallel_cache_ulysses_anything.py
-```
-
-For example, in the T2I and I2V tasks, the length of prompts input by users is often variable, and it is difficult to ensure that this length is divisible by the number of devices. To address this issue, we have developed a **✅padding-free** Ulysses Attention (UAA) for **arbitrary sequence length**, which enhances the versatility of Ulysses.
-
-```python
-dist.init_process_group(backend="cpu:gloo,cuda:nccl")
-```
-Compared to Ulysses Attention, in **UAA**, we have only added an **extra all-gather** op for scalar types to gather the seq_len value of each rank. To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **✅gloo** backend in `init_process_group`. This will significantly reduce commucation latency.
-
-
-
-
- U*: Ulysses Attention, UAA: Ulysses Anything Attenton , UAA*: UAA + Gloo, Device: NVIDIA L20
- FLUX.1-Dev w/o CPU Offload, 28 steps; Qwen-Image w/ CPU Offload, 50 steps; Gloo: Extra All Gather w/ Gloo
-
-
-|CP2 w/ U* |CP2 w/ UAA* | CP2 w/ UAA | L20x1 | CP2 w/ UAA* | CP2 w/ U* | L20x1 | CP2 w/ UAA* |
-|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
-|FLUX, 13.87s|**🎉13.88s**|14.75s|23.25s| **🎉13.75s**|Qwen, 132s|181s|**🎉133s**|
-|
|
|
|
|
|
|
|
|
-|1024x1024|1024x1024|1024x1024|1008x1008|1008x1008|1312x1312|1328x1328|1328x1328|
-|✔️U* ✔️UAA|✔️U* ✔️UAA|✔️U* ✔️UAA| NO CP|❌U* ✔️UAA|✔️U* ✔️UAA|NO CP|❌U* ✔️UAA|
-
-
-
-> [!Important]
-> Please note that **Ulysses Anything Attention (UAA)** is currently an **experimental** feature. It has not undergone large-scale testing, and may introduce a slight performance degradation while the `cpu:gloo` commucation backend is not available.
-
-
-## ⚡️Hybrid Tensor Parallelism
-
-
-
-cache-dit is also compatible with tensor parallelism. Currently, we support the use of `Hybrid Cache` + `Tensor Parallelism` scheme (via NATIVE_PYTORCH parallelism backend) in cache-dit. Users can use Tensor Parallelism to further accelerate the speed of inference and **reduce the VRAM usage per GPU**! For more details, please refer to [📚examples/parallelism](https://github.com/vipshop/cache-dit/tree/main/examples/parallelism). Now, cache-dit supported tensor parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [Wan2.1](https://github.com/Wan-Video/Wan2.1), [Wan2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.
-
-```python
-# pip3 install "cache-dit[parallelism]"
-from cache_dit import ParallelismConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(...),
- # Set tp_size > 1 to enable tensor parallelism.
- parallelism_config=ParallelismConfig(tp_size=2),
-)
-# torchrun --nproc_per_node=2 parallel_cache.py
-```
-
-> [!Important]
-> Please note that in the short term, we have no plans to support Hybrid Parallelism. Please choose to use either Context Parallelism or Tensor Parallelism based on your actual scenario.
-
-## 🤖Low-bits Quantization
-
-
-
-Currently, torchao has been integrated into cache-dit as the backend for **online** model quantization (with more backends to be supported in the future). You can implement model quantization by calling `cache_dit.quantize(...)`. At present, cache-dit supports the `Hybrid Cache + Low-bits Quantization` scheme. For GPUs with low memory capacity, we recommend using `float8_weight_only` or `int8_weight_only`, as these two schemes cause almost no loss in precision. For more details, please refer to [📚examples/quantize](https://github.com/vipshop/cache-dit/tree/main/examples/quantize).
-
-```python
-# pip3 install "cache-dit[quantization]"
-import cache_dit
-
-cache_dit.enable_cache(pipe_or_adapter)
-
-# float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
-# int4_weight_only requires fbgemm-gpu-genai>=1.2.0, which only supports
-# Compute Architectures >= Hopper (and does not support Ada, ..., etc.)
-pipe.transformer = cache_dit.quantize(
- pipe.transformer, quant_type="float8_weight_only"
-)
-pipe.text_encoder = cache_dit.quantize(
- pipe.text_encoder, quant_type="float8_weight_only"
-)
-```
-
-For **4-bits W4A16 (weight only)** quantization, we recommend `nf4` from **bitsandbytes** due to its better compatibility for many devices. Users can directly use it via the `quantization_config` of diffusers. For example:
-
-```python
-from diffusers import QwenImagePipeline
-from diffusers.quantizers import PipelineQuantizationConfig
-
-pipe = QwenImagePipeline.from_pretrained(
- "Qwen/Qwen-Image",
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder", "transformer"],
- )
- ),
-).to("cuda")
-
-# Then, apply cache acceleration using cache-dit
-cache_dit.enable_cache(pipe, cache_config=...)
-```
-
-cache-dit natively supports the `Hybrid Cache + 🔥Nunchaku SVDQ INT4/FP4 + Context Parallelism` scheme. Users can leverage caching and context parallelism to speed up Nunchaku **4-bit** models. For more details, please refer to [📚parallelism+nunchaku](https://github.com/vipshop/cache-dit/tree/main/examples/parallelism/run_qwen_image_nunchaku_cp.py).
-
-```python
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"path-to/svdq-int4_r32-qwen-image.safetensors"
-)
-pipe = QwenImagePipeline.from_pretrained(
- "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16,
-).to("cuda")
-
-cache_dit.enable_cache(pipe, cache_config=..., parallelism_config=...)
-```
-
-## 🤖How to use FP8 Attention
-
-
-
-For FP8 Attention, users must install `sage-attention`. Then, pass the `sage` attention backend to the context parallelism configuration as an extra parameter. Please note that `attention mask` is not currently supported for FP8 sage attention.
-
-```python
-# pip3 install "cache-dit[parallelism]"
-# pip3 install git+https://github.com/thu-ml/SageAttention.git
-from cache_dit import ParallelismConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(...),
- parallelism_config=ParallelismConfig(
- ulysses_size=2,
- parallel_kwargs={
- # flash, native(sdpa), _native_cudnn, sage
- "attention_backend": "sage",
- },
- ),
-)
-# torchrun --nproc_per_node=2 parallel_fp8_cache.py
-```
-
-## 🛠Metrics Command Line
-
-
-
-You can utilize the APIs provided by cache-dit to quickly evaluate the accuracy losses caused by different cache configurations. For example:
-
-```python
-# pip3 install "cache-dit[metrics]"
-from cache_dit.metrics import compute_psnr
-from cache_dit.metrics import compute_ssim
-from cache_dit.metrics import compute_fid
-from cache_dit.metrics import compute_lpips
-from cache_dit.metrics import compute_clip_score
-from cache_dit.metrics import compute_image_reward
-
-psnr, n = compute_psnr("true.png", "test.png") # Num: n
-psnr, n = compute_psnr("true_dir", "test_dir")
-ssim, n = compute_ssim("true_dir", "test_dir")
-fid, n = compute_fid("true_dir", "test_dir")
-lpips, n = compute_lpips("true_dir", "test_dir")
-clip, n = compute_clip_score("DrawBench200.txt", "test_dir")
-reward, n = compute_image_reward("DrawBench200.txt", "test_dir")
-```
-
-Or, you can use `cache-dit-metrics-cli` tool. For examples:
-
-```bash
-cache-dit-metrics-cli -h # show usage
-# all: PSNR, FID, SSIM, MSE, ..., etc.
-cache-dit-metrics-cli all -i1 true.png -i2 test.png # image
-cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # image dir
-```
-
-## ⚙️Torch Compile
-
-
-
-By the way, **cache-dit** is designed to work compatibly with **torch.compile.** You can easily use cache-dit with torch.compile to further achieve a better performance. For example:
-
-```python
-cache_dit.enable_cache(pipe)
-
-# Compile the Transformer module
-pipe.transformer = torch.compile(pipe.transformer)
-```
-However, users intending to use **cache-dit** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
-```python
-torch._dynamo.config.recompile_limit = 96 # default is 8
-torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
-```
-
-Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
-
----
-
-## 📚API Documentation
-
-
-
-Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks that match the specific Input and Output patterns). For a good balance between performance and precision, DBCache is configured by default with F8B0, 8 warmup steps, and unlimited cached steps. All the configurable params are listed beflows.
-
-### 👏API: enable_cache
-
-```python
-def enable_cache(...) -> Union[DiffusionPipeline, BlockAdapter, Transformer]
-```
-
-### 🌟Function Description
-
-The `enable_cache` function serves as a unified caching interface designed to optimize the performance of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`. This API is engineered to be compatible with nearly `all` diffusion transformer architectures that feature transformer blocks adhering to standard input-output patterns, eliminating the need for architecture-specific modifications.
-
-By strategically caching intermediate outputs of transformer blocks during the diffusion process, `DBCache` significantly reduces redundant computations without compromising generation quality. The caching mechanism works by tracking residual differences between consecutive steps, allowing the model to reuse previously computed features when these differences fall below a configurable threshold. This approach maintains a balance between computational efficiency and output precision.
-
-The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that the first 8 transformer blocks are used to compute stable feature differences, while no final blocks are employed for additional fusion. The warmup phase ensures the model establishes sufficient feature representation before caching begins, preventing potential degradation of output quality.
-
-This function seamlessly integrates with both standard diffusion pipelines and custom block adapters, making it versatile for various deployment scenarios—from research prototyping to production environments where inference speed is critical. By abstracting the complexity of caching logic behind a simple interface, it enables developers to enhance model performance with minimal code changes.
-
-### 👇Quick Start
-
-```python
->>> import cache_dit
->>> from diffusers import DiffusionPipeline
->>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
->>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
->>> output = pipe(...) # Just call the pipe as normal.
->>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
->>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
-```
-
-### 👇Parameter Description
-
-- **pipe_or_adapter**(`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
- The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
- For example: `cache_dit.enable_cache(FluxPipeline(...))`.
-
-- **cache_config**(`DBCacheConfig`, *required*, defaults to DBCacheConfig()):
- Basic DBCache config for cache context, defaults to DBCacheConfig(). The configurable parameters are listed below:
- - `Fn_compute_blocks`: (`int`, *required*, defaults to 8):
- Specifies that `DBCache` uses the**first n**Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 difference and delivering more accurate information to subsequent blocks.
- Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md for more details of DBCache.
- - `Bn_compute_blocks`: (`int`, *required*, defaults to 0):
- Further fuses approximate information in the**last n**Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
- - `residual_diff_threshold`: (`float`, *required*, defaults to 0.08):
- The value of residual difference threshold, a higher value leads to faster performance at the cost of lower precision.
- - `max_accumulated_residual_diff_threshold`: (`float`, *optional*, defaults to None):
- The maximum accumulated relative l1 diff threshold for Cache. If set, when the
- accumulated relative l1 diff exceeds this threshold, the caching strategy will be
- disabled for current step. This is useful for some cases where the input condition
- changes significantly in a single step. Default None means this feature is disabled.
- - `max_warmup_steps`: (`int`, *required*, defaults to 8):
- DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
- - `warmup_interval`: (`int`, *required*, defaults to 1):
- Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
- in warmup steps will be computed, others will use dynamic cache.
- - `max_cached_steps`: (`int`, *required*, defaults to -1):
- DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
- - `max_continuous_cached_steps`: (`int`, *required*, defaults to -1):
- DBCache disables the caching strategy when the previous continuous cached steps exceed this value to prevent precision degradation.
- - `enable_separate_cfg`: (`bool`, *required*, defaults to None):
- Whether to use separate cfg or not, such as in Wan 2.1, Qwen-Image. For models that fuse CFG and non-CFG into a single forward step, set enable_separate_cfg as False. Examples include: CogVideoX, HunyuanVideo, Mochi, etc.
- - `cfg_compute_first`: (`bool`, *required*, defaults to False):
- Whether to compute cfg forward first, default is False, meaning:
- 0, 2, 4, ... -> non-CFG step; 1, 3, 5, ... -> CFG step.
- - `cfg_diff_compute_separate`: (`bool`, *required*, defaults to True):
- Whether to compute separate difference values for CFG and non-CFG steps, default is True. If False, we will use the computed difference from the current non-CFG transformer step for the current CFG step.
- - `num_inference_steps` (`int`, *optional*, defaults to None):
- num_inference_steps for DiffusionPipeline, used to adjust some internal settings
- for better caching performance. For example, we will refresh the cache once the
- executed steps exceed num_inference_steps if num_inference_steps is provided.
- - `steps_computation_mask`: (`List[int]`, *optional*, defaults to None):
- This param introduce LeMiCa/EasyCache style compute mask for steps. It is a list
- of length num_inference_steps indicating whether to compute each step or not.
- 1 means must compute, 0 means use dynamic/static cache. If provided, will override
- other settings to decide whether to compute each step.
- - `steps_computation_policy`: (`str`, *optional*, defaults to "dynamic"):
- The computation policy for steps when using steps_computation_mask. It can be
- "dynamic" or "static". "dynamic" means using dynamic cache for steps marked as 0
- in steps_computation_mask, while "static" means using static cache for those steps.
-
-- **calibrator_config** (`CalibratorConfig`, *optional*, defaults to None):
- Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache with a specific calibrator, such as taylorseer, foca, and so on.
-
-- **params_modifiers** ('ParamsModifier', *optional*, defaults to None):
- Modify cache context parameters for specific blocks. The configurable parameters are listed below:
- - `cache_config`: (`DBCacheConfig`, *required*, defaults to DBCacheConfig()):
- The same as the 'cache_config' parameter in the cache_dit.enable_cache() interface.
- - `calibrator_config`: (`CalibratorConfig`, *optional*, defaults to None):
- The same as the 'calibrator_config' parameter in the cache_dit.enable_cache() interface.
- - `**kwargs`: (`dict`, *optional*, defaults to {}):
- The same as the 'kwargs' parameter in the cache_dit.enable_cache() interface.
-
-- **parallelism_config** (`ParallelismConfig`, *optional*, defaults to None):
- Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
- parallelism for cache-dit.
- - `backend`: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
- Parallelism backend, currently only NATIVE_DIFFUSER and NVTIVE_PYTORCH are supported.
- For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
- only NATIVE_PYTORCH backend is supported.
- - `ulysses_size`: (`int`, *optional*, defaults to None):
- The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
- This setting is only valid when backend is NATIVE_DIFFUSER.
- - `ring_size`: (`int`, *optional*, defaults to None):
- The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
- This setting is only valid when backend is NATIVE_DIFFUSER.
- - `tp_size`: (`int`, *optional*, defaults to None):
- The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
- This setting is only valid when backend is NATIVE_PYTORCH.
- - `parallel_kwargs`: (`dict`, *optional*, defaults to {}):
- Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
- it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
-
-- **kwargs** (`dict`, *optional*, defaults to {}):
- Other cache context keyword arguments. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py for more details.
diff --git a/docs/assets/dbprune.png b/docs/assets/dbprune.png
new file mode 100644
index 000000000..5d3cfa098
Binary files /dev/null and b/docs/assets/dbprune.png differ
diff --git a/docs/assets/flux.1024x1024.C0_Q0_NONE.png b/docs/assets/flux.1024x1024.C0_Q0_NONE.png
new file mode 100644
index 000000000..c58a4c9a1
Binary files /dev/null and b/docs/assets/flux.1024x1024.C0_Q0_NONE.png differ
diff --git a/docs/assets/flux.1024x1024.C0_Q0_NONE_TP2.png b/docs/assets/flux.1024x1024.C0_Q0_NONE_TP2.png
new file mode 100644
index 000000000..4ea4dc361
Binary files /dev/null and b/docs/assets/flux.1024x1024.C0_Q0_NONE_TP2.png differ
diff --git a/docs/assets/flux.1024x1024.C0_Q0_NONE_TP4.png b/docs/assets/flux.1024x1024.C0_Q0_NONE_TP4.png
new file mode 100644
index 000000000..146e2266f
Binary files /dev/null and b/docs/assets/flux.1024x1024.C0_Q0_NONE_TP4.png differ
diff --git a/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses2.png b/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..9b7a3dd23
Binary files /dev/null and b/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses2.png differ
diff --git a/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png b/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png
new file mode 100644
index 000000000..e7334e55a
Binary files /dev/null and b/docs/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png differ
diff --git a/docs/assets/flux.1024x1024.C1_Q0_NONE_TP4.png b/docs/assets/flux.1024x1024.C1_Q0_NONE_TP4.png
new file mode 100644
index 000000000..6f85b1b07
Binary files /dev/null and b/docs/assets/flux.1024x1024.C1_Q0_NONE_TP4.png differ
diff --git a/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses2.png b/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..79ade9718
Binary files /dev/null and b/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses2.png differ
diff --git a/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses4.png b/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses4.png
new file mode 100644
index 000000000..7b819ca21
Binary files /dev/null and b/docs/assets/flux.1024x1024.C1_Q0_NONE_Ulysses4.png differ
diff --git a/docs/assets/lemica.png b/docs/assets/lemica.png
new file mode 100644
index 000000000..9087e7831
Binary files /dev/null and b/docs/assets/lemica.png differ
diff --git a/docs/assets/profile_0.png b/docs/assets/profile_0.png
new file mode 100644
index 000000000..2f40945ac
Binary files /dev/null and b/docs/assets/profile_0.png differ
diff --git a/docs/assets/profile_1.png b/docs/assets/profile_1.png
new file mode 100644
index 000000000..1ba598e68
Binary files /dev/null and b/docs/assets/profile_1.png differ
diff --git a/docs/assets/profile_2.png b/docs/assets/profile_2.png
new file mode 100644
index 000000000..439474889
Binary files /dev/null and b/docs/assets/profile_2.png differ
diff --git a/docs/assets/taylorseer_0.png b/docs/assets/taylorseer_0.png
new file mode 100644
index 000000000..f7c67678b
Binary files /dev/null and b/docs/assets/taylorseer_0.png differ
diff --git a/docs/assets/taylorseer_1.png b/docs/assets/taylorseer_1.png
new file mode 100644
index 000000000..4d3675dbd
Binary files /dev/null and b/docs/assets/taylorseer_1.png differ
diff --git a/docs/assets/zimage.C1_Q0_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png b/docs/assets/zimage.C1_Q0_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png
new file mode 100644
index 000000000..ff26fe0cb
Binary files /dev/null and b/docs/assets/zimage.C1_Q0_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png differ
diff --git a/docs/assets/zimage.C1_Q0_NONE_Ulysses2_sdpa_cudnn.png b/docs/assets/zimage.C1_Q0_NONE_Ulysses2_sdpa_cudnn.png
new file mode 100644
index 000000000..b05f6f6a6
Binary files /dev/null and b/docs/assets/zimage.C1_Q0_NONE_Ulysses2_sdpa_cudnn.png differ
diff --git a/docs/assets/zimage.C1_Q0_NONE_Ulysses4_sdpa_cudnn.png b/docs/assets/zimage.C1_Q0_NONE_Ulysses4_sdpa_cudnn.png
new file mode 100644
index 000000000..b4adb73db
Binary files /dev/null and b/docs/assets/zimage.C1_Q0_NONE_Ulysses4_sdpa_cudnn.png differ
diff --git a/docs/assets/zimage.C1_Q0_NONE_Ulysses4_ulysses_float8_sdpa_cudnn.png b/docs/assets/zimage.C1_Q0_NONE_Ulysses4_ulysses_float8_sdpa_cudnn.png
new file mode 100644
index 000000000..31313dfc1
Binary files /dev/null and b/docs/assets/zimage.C1_Q0_NONE_Ulysses4_ulysses_float8_sdpa_cudnn.png differ
diff --git a/docs/assets/zimage.C1_Q1_float8_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png b/docs/assets/zimage.C1_Q1_float8_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png
new file mode 100644
index 000000000..4ff1f7086
Binary files /dev/null and b/docs/assets/zimage.C1_Q1_float8_DBCache_F1B0_W4I1M0MC0_R0.6_SCM111110101_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_sdpa_cudnn.png differ
diff --git a/docs/benchmark/ASCEND_NPU.md b/docs/benchmark/ASCEND_NPU.md
new file mode 100644
index 000000000..10524a849
--- /dev/null
+++ b/docs/benchmark/ASCEND_NPU.md
@@ -0,0 +1,30 @@
+# Ascend NPU Benchmark
+
+## Atlas 800I A2
+
+### FLUX.1-dev
+
+
+ Ulysses: Standard Ulysses Attention, Async Ulysses : Ulysses Attention with Async QKV Projection, NPU Attn Backend : NPU Attention Backend, UAA: Ulysses Anything Attention
+
+
+|800I A2x1| 800I A2x2 w/ Ulysses|w/ Async Ulysses| w/ Async Ulysses + NPU Attn Backend|
+|:---:|:---:|:---:|:---:|
+|FLUX.1, 16.13s|**🎉11.45s**|10.47s|**🎉10.34s**|
+| | | | |
+
+
+### Qwen-Image-Edit
+
+|800I A2x2 w/ UAA + TEP| 800I A2x4 w/ UAA + TEP|w/ UAA + TEP + Async Ulysses| w/ UAA + TEP + Async Ulysses + NPU Attn Backend|
+|:---:|:---:|:---:|:---:|
+|Qwen-Image-Edit, 134.76s|**🎉67.44s**|64.82s|**🎉64.43s**|
+| | | | |
+
+
+### Z-Image-Turbo w/ NPU Attn Backend
+
+|800I A2x1| 800I A2x2 w/ Ulysses|w/ Async Ulysses|
+|:---:|:---:|:---:|
+|Z-Image-Turbo, 3.39s|**🎉2.49s**|**🎉2.38s**|
+| | | |
diff --git a/docs/benchmark/HYBRID_CACHE.md b/docs/benchmark/HYBRID_CACHE.md
new file mode 100644
index 000000000..c46751516
--- /dev/null
+++ b/docs/benchmark/HYBRID_CACHE.md
@@ -0,0 +1,76 @@
+# Hybrid Cache Acceleration Benchmark
+
+|Baseline|SCM S S*|SCM F D*|SCM U D*|+TS|+compile|+FP8*|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.4s|11.4s|8.2s|8.2s|**🎉7.1s**|**🎉4.5s**|
+| | | | | | | |
+
+
+
+ Scheme: DBCache + SCM(steps_computation_mask) + TS(TaylorSeer) + FP8* , L20x1, S*: static cache, D*: dynamic cache , S : Slow, F : Fast, U : Ultra Fast, TS : TaylorSeer, FP8* : FP8 DQ + Sage, FLUX.1 -Dev
+
+
+
+
+
+
+
+
+cache-dit will support more mainstream Cache acceleration algorithms in the future. More benchmarks will be released; please stay tuned for updates. Here, only the results of some precision and performance benchmarks are presented. The test dataset is **DrawBench**. For a complete benchmark, please refer to [📚Benchmarks](https://github.com/vipshop/cache-dit/tree/main/bench/).
+
+## Text2Image DrawBench
+
+Comparisons between different FnBn compute block configurations show that **more compute blocks result in higher precision**. For example, the F8B0_W8MC0 configuration achieves the best Clip Score (33.007) and ImageReward (1.0333). **Device**: NVIDIA L20. **F**: Fn_compute_blocks, **B**: Bn_compute_blocks, 50 steps.
+
+| Config | Clip Score(↑) | ImageReward(↑) | PSNR(↑) | TFLOPs(↓) | SpeedUp(↑) |
+| --- | --- | --- | --- | --- | --- |
+| [**FLUX.1**-dev]: 50 steps | 32.9217 | 1.0412 | INF | 3726.87 | 1.00x |
+| F8B0_W4MC0_R0.08 | 32.9871 | 1.0370 | 33.8317 | 2064.81 | 1.80x |
+| F8B0_W4MC2_R0.12 | 32.9535 | 1.0185 | 32.7346 | 1935.73 | 1.93x |
+| F8B0_W4MC3_R0.12 | 32.9234 | 1.0085 | 32.5385 | 1816.58 | 2.05x |
+| F4B0_W4MC3_R0.12 | 32.8981 | 1.0130 | 31.8031 | 1507.83 | 2.47x |
+| F4B0_W4MC4_R0.12 | 32.8384 | 1.0065 | 31.5292 | 1400.08 | 2.66x |
+
+## SOTA Performance
+
+The comparison between **cache-dit: DBCache** and algorithms such as Δ-DiT, Chipmunk, FORA, DuCa, TaylorSeer and FoCa is as follows. Among methods with a speedup ratio of less than **4x**, cache-dit achieves the best accuracy. Surprisingly, cache-dit: DBCache still works in the extremely few-step distill model. For a complete benchmark, please refer to [📚Benchmarks](https://github.com/vipshop/cache-dit/raw/main/bench/). NOTE: Except for DBCache, other performance data are referenced from the paper [FoCa, arxiv.2508.16211](https://arxiv.org/pdf/2508.16211).
+
+| Method | TFLOPs(↓) | SpeedUp(↑) | ImageReward(↑) | Clip Score(↑) |
+| --- | --- | --- | --- | --- |
+| [**FLUX.1**-dev]: 50 steps | 3726.87 | 1.00× | 0.9898 | 32.404 |
+| [**FLUX.1**-dev]: 60% steps | 2231.70 | 1.67× | 0.9663 | 32.312 |
+| Δ-DiT(N=2) | 2480.01 | 1.50× | 0.9444 | 32.273 |
+| Δ-DiT(N=3) | 1686.76 | 2.21× | 0.8721 | 32.102 |
+| [**FLUX.1**-dev]: 34% steps | 1264.63 | 3.13× | 0.9453 | 32.114 |
+| Chipmunk | 1505.87 | 2.47× | 0.9936 | 32.776 |
+| FORA(N=3) | 1320.07 | 2.82× | 0.9776 | 32.266 |
+| **[DBCache(S)](https://github.com/vipshop/cache-dit)** | 1400.08 | **2.66×** | **1.0065** | 32.838 |
+| DuCa(N=5) | 978.76 | 3.80× | 0.9955 | 32.241 |
+| TaylorSeer(N=4,O=2) | 1042.27 | 3.57× | 0.9857 | 32.413 |
+| **[DBCache(S)+TS](https://github.com/vipshop/cache-dit)** | 1153.05 | **3.23×** | **1.0221** | 32.819 |
+| **[DBCache(M)](https://github.com/vipshop/cache-dit)** | 944.75 | **3.94×** | 0.9997 | 32.849 |
+| **[DBCache(M)+TS](https://github.com/vipshop/cache-dit)** | 944.75 | **3.94×** | **1.0107** | 32.865 |
+| **[FoCa(N=5): arxiv.2508.16211](https://arxiv.org/pdf/2508.16211)** | 893.54 | **4.16×** | 1.0029 | **32.948** |
+| [**FLUX.1**-dev]: 22% steps | 818.29 | 4.55× | 0.8183 | 31.772 |
+| FORA(N=7) | 670.14 | 5.55× | 0.7418 | 31.519 |
+| ToCa(N=12) | 644.70 | 5.77× | 0.7155 | 31.808 |
+| DuCa(N=10) | 606.91 | 6.13× | 0.8382 | 31.759 |
+| TeaCache(l=1.2) | 669.27 | 5.56× | 0.7394 | 31.704 |
+| TaylorSeer(N=7,O=2) | 670.44 | 5.54× | 0.9128 | 32.128 |
+| **[DBCache(F)](https://github.com/vipshop/cache-dit)** | 651.90 | **5.72x** | 0.9271 | 32.552 |
+| **[FoCa(N=8): arxiv.2508.16211](https://arxiv.org/pdf/2508.16211)** | 596.07 | 6.24× | 0.9502 | 32.706 |
+| **[DBCache(F)+TS](https://github.com/vipshop/cache-dit)** | 651.90 | **5.72x** | **0.9526** | 32.568 |
+| **[DBCache(U)+TS](https://github.com/vipshop/cache-dit)** | 505.47 | **7.37x** | 0.8645 | **32.719** |
+
+## Text2Image Distillation DrawBench
+
+Surprisingly, cache-dit: DBCache still works in the extremely few-step distill model. For example, **Qwen-Image-Lightning w/ 4 steps**, with the F16B16 configuration, the PSNR is 34.8163, the Clip Score is 35.6109, and the ImageReward is 1.2614. It maintained a relatively high precision.
+
+| Config | PSNR(↑) | Clip Score(↑) | ImageReward(↑) | TFLOPs(↓) | SpeedUp(↑) |
+|----------------------------|-----------|------------|--------------|----------|------------|
+| [**Lightning**]: 4 steps | INF | 35.5797 | 1.2630 | 274.33 | 1.00x |
+| F24B24_W2MC1_R0.8 | 36.3242 | 35.6224 | 1.2630 | 264.74 | 1.04x |
+| F16B16_W2MC1_R0.8 | 34.8163 | 35.6109 | 1.2614 | 244.25 | 1.12x |
+| F12B12_W2MC1_R0.8 | 33.8953 | 35.6535 | 1.2549 | 234.63 | 1.17x |
+| F8B8_W2MC1_R0.8 | 33.1374 | 35.7284 | 1.2517 | 224.29 | 1.22x |
+| F1B0_W2MC1_R0.8 | 31.8317 | 35.6651 | 1.2397 | 206.90 | 1.33x |
diff --git a/docs/benchmark/NVIDIA_GPU.md b/docs/benchmark/NVIDIA_GPU.md
new file mode 100644
index 000000000..a0a2b4843
--- /dev/null
+++ b/docs/benchmark/NVIDIA_GPU.md
@@ -0,0 +1,101 @@
+# NVIDIA GPU Benchmark
+
+## NVIDIA L20
+
+### Z-Image-ControlNet: Hybrid Cache + Parallelism
+
+|Z-Image-ControlNet| Context Parallel: Ulysses 2 | Context Parallel: Ulysses 4 | + ControlNet Parallel |
+|:---:|:---:|:---:|:---:|
+|Base L20x1: 22s|15.7s|12.7s|**🚀7.71s**|
+| | | | |
+| **+ Hybrid Cache** | **+ Torch Compile** | **+ Async Ulysses CP** | **+ FP8 All2All + CUDNN ATTN** |
+|**🚀6.85s**|6.45s|6.38s|**🚀6.19s, 5.47s**|
+| | | | |
+
+### FLUX.1-dev: Hybrid Cache + Parallelism
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|
+| | | | | |
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|
+|27.85s|6.04s|5.88s|5.77s|6.01s|
+| | | | | |
+
+### UAA: Ulysses Anything Attention
+
+#### Qwen-Image & FLUX.1-dev
+
+
+ ✅Any Sequence Length
+ U*: Ulysses Attention, UAA: Ulysses Anything Attention, UAA*: UAA + Gloo, Device: NVIDIA L20
+ FLUX.1-Dev w/o CPU Offload, 28 steps; Qwen-Image w/ CPU Offload, 50 steps; Gloo: Extra All Gather w/ Gloo
+
+
+|CP2 w/ U* |CP2 w/ UAA* | CP2 w/ UAA | L20x1 | CP2 w/ UAA* | CP2 w/ U* | L20x1 | CP2 w/ UAA* |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|FLUX, 13.87s|**🎉13.88s**|14.75s|23.25s| **🎉13.75s**|Qwen, 132s|181s|**🎉133s**|
+| | | | | | | | |
+|1024x1024|1024x1024|1024x1024|1008x1008|1008x1008|1312x1312|1328x1328|1328x1328|
+|✔️U* ✔️UAA|✔️U* ✔️UAA|✔️U* ✔️UAA| NO CP|❌U* ✔️UAA|✔️U* ✔️UAA|NO CP|❌U* ✔️UAA|
+
+#### Z-Image-Turbo
+
+
+ ✅Any Head Num
+ Ulysses: Ulysses Attention, FP8 Ulysses: Ulysses w/ FP8 All2All , Device: NVIDIA L20
+ 🔥Z-Image (Head=30, ❌CAN NOT divisible by 4), 1024x1024, 9 steps.
+
+
+|Ulysses 2, L20|Ulysses 4|FP8 Ulysses 4| + Cache | + FP8 DQ |
+|:---:|:---:|:---:|:---:|:---:|
+|1024x1024, 3.19s|1024x1024, 1.98s|1024x1024, 1.89s|1024x1024, 1.63s|1024x1024, 1.23s|
+| | | | | |
+
+### Async Ulysses QKV Projection
+
+#### FLUX.1-dev
+
+
+ Ulysses: Standard Ulysses Attention, Async Ulysses : Ulysses Attention with Async QKV Projection
+
+
+|L20x2 w/ Ulysses| w/ Async Ulysses|w/ Ulysses + compile| w/ Async Ulysses + compile|
+|:---:|:---:|:---:|:---:|
+|FLUX.1, 13.87s|**🎉13.20s**|12.21s|**🎉11.97s**|
+| | | | |
+
+
+### Async FP8 Ulysses Attention
+
+#### FLUX.1-dev
+
+|L20x2 w/ Ulysses| w/ Ulysses FP8|w/ Ulysses + compile|w/ Ulysses FP8 + compile|
+|:---:|:---:|:---:|:---:|
+|FLUX.1, 13.87s|**🎉13.36s**|12.21s|**🎉11.54s**|
+| | | | |
+
+
+## NVIDIA H100
+
+|Model|Baseline H100x1|Ulysses 2| + FA3| + cache| + compile|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|FLUX.1-dev: 50 steps | 9.30s | 6.04s | 5.99s | 2.60s | 1.92s |
+|Qwen-Image: 50 steps | 18.49s | 12.81s | 12.75s | 5.67s | 4.20s |
+
+Reproduce command:
+
+```shell
+# FLUX.1-dev: 50 steps
+python3 generate.py flux --steps 50
+torchrun --nproc_per_node=2 generate.py flux --steps 50 --parallel ulysses
+torchrun --nproc_per_node=2 generate.py flux --steps 50 --parallel ulysses --attn _flash_3
+torchrun --nproc_per_node=2 generate.py flux --steps 50 --parallel ulysses --attn _flash_3 --cache
+torchrun --nproc_per_node=2 generate.py flux --steps 50 --parallel ulysses --attn _flash_3 --cache --compile
+# Qwen-Image: 50 steps
+python3 generate.py qwen_image --steps 50
+torchrun --nproc_per_node=2 generate.py qwen_image --steps 50 --parallel ulysses
+torchrun --nproc_per_node=2 generate.py qwen_image --steps 50 --parallel ulysses --attn _flash_3
+torchrun --nproc_per_node=2 generate.py qwen_image --steps 50 --parallel ulysses --attn _flash_3 --cache
+torchrun --nproc_per_node=2 generate.py qwen_image --steps 50 --parallel ulysses --attn _flash_3 --cache --compile
+```
diff --git a/docs/community_optimization.md b/docs/community_optimization.md
deleted file mode 100644
index 0aabc392b..000000000
--- a/docs/community_optimization.md
+++ /dev/null
@@ -1,270 +0,0 @@
-## CacheDiT
-
-CacheDiT is a unified, flexible, and training-free cache acceleration framework designed to support nearly all Diffusers' DiT-based pipelines. It provides a unified cache API that supports automatic block adapter, DBCache, and more.
-
-To learn more, refer to the [CacheDiT](https://github.com/vipshop/cache-dit) repository.
-
-Install a stable release of CacheDiT from PyPI or you can install the latest version from GitHub.
-
-
-
-
-```bash
-pip3 install -U cache-dit
-```
-
-
-
-
-```bash
-pip3 install git+https://github.com/vipshop/cache-dit.git
-```
-
-
-
-
-Run the command below to view supported DiT pipelines.
-
-```python
->>> import cache_dit
->>> cache_dit.supported_pipelines()
-(30, ['Flux*', 'Mochi*', 'CogVideoX*', 'Wan*', 'HunyuanVideo*', 'QwenImage*', 'LTX*', 'Allegro*',
-'CogView3Plus*', 'CogView4*', 'Cosmos*', 'EasyAnimate*', 'SkyReelsV2*', 'StableDiffusion3*',
-'ConsisID*', 'DiT*', 'Amused*', 'Bria*', 'Lumina*', 'OmniGen*', 'PixArt*', 'Sana*', 'StableAudio*',
-'VisualCloze*', 'AuraFlow*', 'Chroma*', 'ShapE*', 'HiDream*', 'HunyuanDiT*', 'HunyuanDiTPAG*'])
-```
-
-For a complete benchmark, please refer to [Benchmarks](https://github.com/vipshop/cache-dit/blob/main/bench/).
-
-
-## Unified Cache API
-
-CacheDiT works by matching specific input/output patterns as shown below.
-
-
-
-Call the `enable_cache()` function on a pipeline to enable cache acceleration. This function is the entry point to many of CacheDiT's features.
-
-```python
-import cache_dit
-from diffusers import DiffusionPipeline
-
-# Can be any diffusion pipeline
-pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
-
-# One-line code with default cache options.
-cache_dit.enable_cache(pipe)
-
-# Just call the pipe as normal.
-output = pipe(...)
-
-# Disable cache and run original pipe.
-cache_dit.disable_cache(pipe)
-```
-
-## Automatic Block Adapter
-
-For custom or modified pipelines or transformers not included in Diffusers, use the `BlockAdapter` in `auto` mode or via manual configuration. Please check the [BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#automatic-block-adapter) docs for more details. Refer to [Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
-
-
-```python
-from cache_dit import ForwardPattern, BlockAdapter
-
-# Use 🔥BlockAdapter with `auto` mode.
-cache_dit.enable_cache(
- BlockAdapter(
- # Any DiffusionPipeline, Qwen-Image, etc.
- pipe=pipe, auto=True,
- # Check `📚Forward Pattern Matching` documentation and hack the code of
- # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
- forward_pattern=ForwardPattern.Pattern_1,
- ),
-)
-
-# Or, manually setup transformer configurations.
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe, # Qwen-Image, etc.
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- ),
-)
-```
-
-Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, FLUX.1 (HiDream, Chroma, etc) contains `transformer_blocks` and `single_transformer_blocks` (with different forward patterns). The BlockAdapter is able to detect this hybrid pattern type as well.
-Refer to [FLUX.1](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_flux_adapter.py) as an example.
-
-```python
-# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
-# single_transformer_blocks have different forward patterns.
-cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe, # FLUX.1, etc.
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- ),
-)
-```
-
-This also works if there is more than one transformer (namely `transformer` and `transformer_2`) in its structure. Refer to [Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples/pipeline/run_wan_2.2.py) as an example.
-
-## Patch Functor
-
-For any pattern not included in CacheDiT, use the Patch Functor to convert the pattern into a known pattern. You need to subclass the Patch Functor and may also need to fuse the operations within the blocks for loop into block `forward`. After implementing a Patch Functor, set the `patch_functor` property in `BlockAdapter`.
-
-
-
-Some Patch Functors are already provided in CacheDiT, [HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_hidream.py), [ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_chroma.py), etc.
-
-```python
-@BlockAdapterRegister.register("HiDream")
-def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import HiDreamImageTransformer2DModel
- from cache_dit.caching.patch_functors import HiDreamPatchFunctor
-
- assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.double_stream_blocks,
- pipe.transformer.single_stream_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_3,
- ],
- # NOTE: Setup your custom patch functor here.
- patch_functor=HiDreamPatchFunctor(),
- **kwargs,
- )
-```
-
-Finally, you can call the `cache_dit.summary()` function on a pipeline after its completed inference to get the cache acceleration details.
-
-```python
-stats = cache_dit.summary(pipe)
-```
-
-```python
-⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
-
-| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
-|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
-| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
-```
-
-## DBCache: Dual Block Cache
-
-
-
-DBCache (Dual Block Caching) supports different configurations of compute blocks (F8B12, etc.) to enable a balanced trade-off between performance and precision.
-- Fn_compute_blocks: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
-- Bn_compute_blocks: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
-
-
-```python
-import cache_dit
-from diffusers import FluxPipeline
-
-pipe_or_adapter = FluxPipeline.from_pretrained(
- "black-forest-labs/FLUX.1-dev",
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-# Default options, F8B0, 8 warmup steps, and unlimited cached
-# steps for good balance between performance and precision
-cache_dit.enable_cache(pipe_or_adapter)
-
-# Custom options, F8B8, higher precision
-from cache_dit import DBCacheConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(
- max_warmup_steps=8, # steps do not cache
- max_cached_steps=-1, # -1 means no limit
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=8, # Bn, B8, etc.
- residual_diff_threshold=0.12,
- ),
-)
-```
-Check the [DBCache](https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md) and [User Guide](https://github.com/vipshop/cache-dit/blob/main/docs/User_Guide.md#dbcache) docs for more design details.
-
-## TaylorSeer Calibrator
-
-The [TaylorSeers](https://huggingface.co/papers/2503.06923) algorithm further improves the precision of DBCache in cases where the cached steps are large (Hybrid TaylorSeer + DBCache). At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
-
-TaylorSeer employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. F_pred can be a residual cache or a hidden-state cache.
-
-```python
-from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- # Basic DBCache w/ FnBn configurations
- cache_config=DBCacheConfig(
- max_warmup_steps=8, # steps do not cache
- max_cached_steps=-1, # -1 means no limit
- Fn_compute_blocks=8, # Fn, F8, etc.
- Bn_compute_blocks=8, # Bn, B8, etc.
- residual_diff_threshold=0.12,
- ),
- # Then, you can use the TaylorSeer Calibrator to approximate
- # the values in cached steps, taylorseer_order default is 1.
- calibrator_config=TaylorSeerCalibratorConfig(
- taylorseer_order=1,
- ),
-)
-```
-
-> [!TIP]
-> The `Bn_compute_blocks` parameter of DBCache can be set to `0` if you use TaylorSeer as the calibrator for approximate hidden states. DBCache's `Bn_compute_blocks` also acts as a calibrator, so you can choose either `Bn_compute_blocks` > 0 or TaylorSeer. We recommend using the configuration scheme of TaylorSeer + DBCache FnB0.
-
-## Hybrid Cache CFG
-
-CacheDiT supports caching for CFG (classifier-free guidance). For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG in the forward step, please set `enable_separate_cfg` parameter to `False (default, None)`. Otherwise, set it to `True`.
-
-```python
-from cache_dit import DBCacheConfig
-
-cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=DBCacheConfig(
- ...,
- # For example, set it as True for Wan 2.1, Qwen-Image
- # and set it as False for FLUX.1, HunyuanVideo, etc.
- enable_separate_cfg=True,
- ),
-)
-```
-
-## torch.compile
-
-CacheDiT is designed to work with torch.compile for even better performance. Call `torch.compile` after enabling the cache.
-
-
-```python
-cache_dit.enable_cache(pipe)
-
-# Compile the Transformer module
-pipe.transformer = torch.compile(pipe.transformer)
-```
-
-If you're using CacheDiT with dynamic input shapes, consider increasing the `recompile_limit` of `torch._dynamo`. Otherwise, the `recompile_limit` error may be triggered, causing the module to fall back to eager mode.
-
-```python
-torch._dynamo.config.recompile_limit = 96 # default is 8
-torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
-```
-
-Please check [perf.py](https://github.com/vipshop/cache-dit/blob/main/bench/perf.py) for more details.
diff --git a/docs/CONTRIBUTE.md b/docs/developer_guide/PRE_COMMIT.md
similarity index 52%
rename from docs/CONTRIBUTE.md
rename to docs/developer_guide/PRE_COMMIT.md
index 68e56be40..e37df4408 100644
--- a/docs/CONTRIBUTE.md
+++ b/docs/developer_guide/PRE_COMMIT.md
@@ -1,6 +1,6 @@
-# Developer Guide
+# Prepare before commit
-## 👨💻Pre-commit
+## 👨💻 Run Pre-commit
Before submitting code, configure pre-commit, for example:
@@ -16,7 +16,7 @@ pre-commit install
pre-commit run --all-files
```
-## 👨💻Add a new feature
+## 👨💻 Add a new feature
```bash
# feat: support xxx-cache method
@@ -26,3 +26,20 @@ git commit -m "support xxx-cache method"
git push
# then, open a PR from your personal branch to cache-dit:main
```
+
+## 👨💻 Check MkDocs
+
+Please also check the mkdocs build status on your local branch.
+```bash
+pip3 install -e ".[docs]"
+mkdocs build --strict
+mkdocs serve # Then check the docs
+```
+
+Ensure that your new commits do not break the mkdocs build process.
+
+```bash
+INFO - Cleaning site directory
+INFO - Building documentation to directory: /workspace/dev/vipshop/cache-dit/site
+INFO - Documentation built in 0.97 seconds
+```
diff --git a/docs/developer_guide/SUPPORT_NEW_MODEL.md b/docs/developer_guide/SUPPORT_NEW_MODEL.md
new file mode 100644
index 000000000..66a57bea3
--- /dev/null
+++ b/docs/developer_guide/SUPPORT_NEW_MODEL.md
@@ -0,0 +1,75 @@
+# Support New Model
+
+Please make sure you have installed and initialized pre-commit before adding any new commits. Refer to [PRE_COMMIT](PRE_COMMIT.md) for more details.
+
+## Cache Acceleration
+
+In order to support cache acceleration for a new model, we have to register its BlockAdapter at [caching/block_adapters/adapters.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/block_adapters/adapters.py) and use `_safe_import` func to import it at [caching/block_adapters/\_\_init\_\_.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/block_adapters/__init__.py). For example:
+
+- step 1: Implement the `qwenimage_adapter` at [caching/block_adapters/adapters.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/block_adapters/adapters.py)
+
+```python
+@BlockAdapterRegister.register("QwenImage")
+def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
+ try:
+ from diffusers import QwenImageTransformer2DModel
+ except ImportError:
+ QwenImageTransformer2DModel = None # requires diffusers>=0.35.2
+
+ _relaxed_assert(pipe.transformer, QwenImageTransformer2DModel)
+
+ return BlockAdapter(
+ pipe=pipe,
+ transformer=pipe.transformer,
+ blocks=pipe.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_1,
+ check_forward_pattern=True,
+ has_separate_cfg=True,
+ **kwargs,
+ )
+```
+
+- step 2: use `_safe_import` to import it at [caching/block_adapters/\_\_init\_\_.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/block_adapters/__init__.py).
+
+```python
+qwenimage_adapter = _safe_import(".adapters", "qwenimage_adapter")
+```
+
+
+## Context Parallelism
+
+In order to support context parallelism for a new model, we have to register its ContextParallelismPlanner at [context_parallelism](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/context_parallelism) and use `_safe_import` func to import it at [cp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py). For example:
+
+- step 1: Implement the `FluxContextParallelismPlanner`
+ at FLUX.1 CP planner at [cp_plan_flux.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux.py)
+- step 2: use `_safe_import` func to import it at [cp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py).
+
+## Tensor Parallelism
+
+In order to support tensor parallelism for a new model, we have to register its TensorParallelismPlanner at [tensor_parallelism](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/tensor_parallelism) and use `_safe_import` func to import it at [tp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_planners.py). For example:
+
+- step 1: Implement the `FluxTensorParallelismPlanner`
+ at FLUX.1 TP planner at [tp_plan_flux.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux.py)
+- step 2: use `_safe_import` func to import it at [tp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_planners.py).
+
+## Text Encoder Parallelism
+
+In order to support text encoder tensor parallelism for a new model, we have to register its TextEncoderTensorParallelismPlanner at [text_encoders/tensor_parallelism](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/text_encoders/tensor_parallelism) and use `_safe_import` func to import it at [text_encoders/tensor_parallelism/tp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_planners.py). For example:
+
+- step 1: Implement the `T5EncoderTensorParallelismPlanner`
+ at T5 TP planner at [tp_plan_t5_encoder.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_t5_encoder.py)
+- step 2: use `_safe_import` func to import it at [text_encoders/tensor_parallelism/tp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_planners.py).
+
+
+## Auto Encoder (VAE) Parallelism
+
+In order to support auto encoder (VAE) data parallelism for a new model, we have to register its AutoEncoderDataParallelismPlanner at [autoencoders/data_parallelism](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/autoencoders/data_parallelism) and use `_safe_import` func to import it at [autoencoders/data_parallelism/dp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py). For example:
+
+- step 1: Implement the `AutoencoderKLDataParallelismPlanner`
+ at AutoencoderKL DP planner at [dp_plan_autoencoder_kl.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl.py)
+- step 2: use `_safe_import` func to import it at [autoencoders/data_parallelism/dp_planners.py](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py).
+
+
+## Examples and Tests
+
+Once the acceleration support for the new model is completed, we should add the new models to the [Examples](https://github.com/vipshop/cache-dit/blob/main/examples) and perform the necessary tests.
diff --git a/docs/supported_matrix/ASCEND_NPU.md b/docs/supported_matrix/ASCEND_NPU.md
new file mode 100644
index 000000000..44640049f
--- /dev/null
+++ b/docs/supported_matrix/ASCEND_NPU.md
@@ -0,0 +1,72 @@
+# Ascend NPU Supported Matrix
+
+
+
+Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Examples](https://github.com/vipshop/cache-dit/blob/main/examples) for more details. Here are just some of the tested models listed.
+
+Theoretically, almost all models supported by Cache-DiT can run on Ascend NPU. Here, only some of the models we have tested are listed. We will continue testing more models for Ascend NPU, so stay tuned for updates!
+
+## Transformer Optimization
+
+|📚Models|Hybrid Cache|Context Parallel|Tensor Parallel|
+|:---|:---:|:---:|:---:|
+|FLUX.2-Klein-4B|✅|✅|✅|
+|FLUX.2-Klein-base-4B|✅|✅|✅|
+|FLUX.2-Klein-9B|✅|✅|✅|
+|FLUX.2-Klein-base-9B|✅|✅|✅|
+|FLUX.2-dev|✅|✅|✅|
+|FLUX.1-dev|✅|✅|✅|
+|FLUX.1-Fill-dev|✅|✅|✅|
+|FLUX.1-Kontext-dev|✅|✅|✅|
+|Z-Image-Turbo*|✅|✅|✅|
+|Qwen-Image|✅|✅|✅|
+|Qwen-Image-Layered|✅|✅|✅|
+|Qwen-Image-2512|✅|✅|✅|
+|Qwen-Image-Edit|✅|✅|✅|
+|Qwen-Image-Edit-2509|✅|✅|✅|
+|Qwen-Image-Edit-2511|✅|✅|✅|
+|Qwen-Image-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-2509-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-2511-Lightning|✅|✅|✅|
+|Wan-2.2-T2V|✅|✅|✅|
+|Wan-2.2-I2V|✅|✅|✅|
+|Wan-2.1-T2V|✅|✅|✅|
+|Wan-2.1-I2V|✅|✅|✅|
+|LongCat-Image|✅|✅|✅|
+|LongCat-Image-Edit|✅|✅|✅|
+|Ovis-Image|✅|✅|✅|
+
+
+## Text Encoder & VAE Optimization
+
+|📚Models|Text Encoder Parallel|AutoEncoder(VAE) Parallel|
+|:---|:---:|:---:|
+|FLUX.2-Klein-4B|✅|✅|
+|FLUX.2-Klein-base-4B|✅|✅|
+|FLUX.2-Klein-9B|✅|✅|
+|FLUX.2-Klein-base-9B|✅|✅|
+|FLUX.2-dev|✅|✅|
+|FLUX.1-dev|✅|✅|
+|FLUX.1-Fill-dev|✅|✅|
+|FLUX.1-Kontext-dev|✅|✅|
+|Z-Image-Turbo*|✅|✅|
+|Qwen-Image|✅|✅|
+|Qwen-Image-Layered|✅|✅|
+|Qwen-Image-2512|✅|✅|
+|Qwen-Image-Edit|✅|✅|
+|Qwen-Image-Edit-2509|✅|✅|
+|Qwen-Image-Edit-2511|✅|✅|
+|Qwen-Image-Lightning|✅|✅|
+|Qwen-Image-Edit-Lightning|✅|✅|
+|Qwen-Image-Edit-2509-Lightning|✅|✅|
+|Qwen-Image-Edit-2511-Lightning|✅|✅|
+|Wan-2.2-T2V|✅|✅|
+|Wan-2.2-I2V|✅|✅|
+|Wan-2.1-T2V|✅|✅|
+|Wan-2.1-I2V|✅|✅|
+|LongCat-Image|✅|✅|
+|LongCat-Image-Edit|✅|✅|
+|Ovis-Image|✅|✅|
+
+Z-Image-Turbo*: Since diffusers does not yet support this model on NPU, you need to merge this PR into your local diffusers repo: https://github.com/huggingface/diffusers/pull/12979
diff --git a/docs/supported_matrix/NVIDIA_GPU.md b/docs/supported_matrix/NVIDIA_GPU.md
new file mode 100644
index 000000000..e7a8179b9
--- /dev/null
+++ b/docs/supported_matrix/NVIDIA_GPU.md
@@ -0,0 +1,174 @@
+# Supported Matrix
+
+
+
+Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Examples](https://github.com/vipshop/cache-dit/blob/main/examples) for more details. Here are just some of the tested models listed.
+
+## Transformers Optimization
+One Model Series may contain many pipelines. cache-dit applies optimizations at the Transformer level; thus, any pipelines that include the supported transformer are already supported by cache-dit. ✅: supported now; ✖️: not supported now; **[🤖Q](https://github.com/nunchaku-tech/nunchaku)**: **[nunchaku](https://github.com/nunchaku-tech/nunchaku)** w/ SVDQ W4A4;
+
+|📚Models: `🤗70+`|Hybrid Cache|Context Parallel|Tensor Parallel|
+|:---|:---:|:---:|:---:|
+|FLUX.2-Klein-4B|✅|✅|✅|
+|FLUX.2-Klein-base-4B|✅|✅|✅|
+|FLUX.2-Klein-9B|✅|✅|✅|
+|FLUX.2-Klein-base-9B|✅|✅|✅|
+|LTX-2-I2V|✅|✅|✅|
+|LTX-2-T2V|✅|✅|✅|
+|Qwen-Image-2512|✅|✅|✅|
+|Z-Image-Turbo `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Layered|✅|✅|✅|
+|Qwen-Image-Edit-2511-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-2511|✅|✅|✅|
+|LongCat-Image|✅|✅|✅|
+|LongCat-Image-Edit|✅|✅|✅|
+|Z-Image-Turbo|✅|✅|✅|
+|Z-Image-Turbo-Fun-ControlNet-2.0|✅|✅|✅|
+|Z-Image-Turbo-Fun-ControlNet-2.1|✅|✅|✅|
+|Ovis-Image|✅|✅|✅|
+|FLUX.2-dev|✅|✅|✅|
+|FLUX.1-dev|✅|✅|✅|
+|FLUX.1-Fill-dev|✅|✅|✅|
+|FLUX.1-Kontext-dev|✅|✅|✅|
+|Qwen-Image|✅|✅|✅|
+|Qwen-Image-Edit|✅|✅|✅|
+|Qwen-Image-Edit-2509|✅|✅|✅|
+|Qwen-Image-ControlNet|✅|✅|✅|
+|Qwen-Image-ControlNet-Inpainting|✅|✅|✅|
+|Qwen-Image-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-Lightning|✅|✅|✅|
+|Qwen-Image-Edit-2509-Lightning|✅|✅|✅|
+|Wan-2.2-T2V|✅|✅|✅|
+|Wan-2.2-I2V|✅|✅|✅|
+|Wan-2.2-VACE-Fun|✅|✅|✅|
+|Wan-2.1-T2V|✅|✅|✅|
+|Wan-2.1-I2V|✅|✅|✅|
+|Wan-2.1-FLF2V|✅|✅|✅|
+|Wan-2.1-VACE|✅|✅|✅|
+|HunyuanImage-2.1|✅|✅|✅|
+|HunyuanVideo-1.5|✅|✖️|✖️|
+|HunyuanVideo|✅|✅|✅|
+|FLUX.1-dev `🤖Q`|✅|✅|✖️|
+|FLUX.1-Fill-dev `🤖Q`|✅|✅|✖️|
+|FLUX.1-Kontext-dev `🤖Q`|✅|✅|✖️|
+|Qwen-Image `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Edit `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Edit-2509 `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Lightning `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Edit-Lightning `🤖Q`|✅|✅|✖️|
+|Qwen-Image-Edit-2509-Lightning `🤖Q`|✅|✅|✖️|
+|SkyReels-V2-T2V|✅|✅|✅|
+|LongCat-Video|✅|✖️|✖️|
+|ChronoEdit-14B|✅|✅|✅|
+|Kandinsky-5.0-T2V-Lite|✅|✅️|✅️|
+|PRX-512-t2i-sft|✅|✖️|✖️|
+|LTX-Video-v0.9.8|✅|✅|✅|
+|LTX-Video-v0.9.7|✅|✅|✅|
+|CogVideoX|✅|✅|✅|
+|CogVideoX-1.5|✅|✅|✅|
+|CogView-4|✅|✅|✅|
+|CogView-3-Plus|✅|✅|✅|
+|Chroma1-HD|✅|✅|✅|
+|PixArt-Sigma-XL-2-1024-MS|✅|✅|✅|
+|PixArt-XL-2-1024-MS|✅|✅|✅|
+|VisualCloze-512|✅|✅|✅|
+|ConsisID-preview|✅|✅|✅|
+|mochi-1-preview|✅|✖️|✅|
+|Lumina-Image-2.0|✅|✖️|✅|
+|HiDream-I1-Full|✅|✖️|✖️|
+|HunyuanDiT|✅|✖️|✅|
+|Sana-1600M-1024px|✅|✖️|✖️|
+|DiT-XL-2-256|✅|✅|✖️|
+|Allegro-T2V|✅|✖️|✖️|
+|OmniGen-2|✅|✖️|✖️|
+|stable-diffusion-3.5-large|✅|✖️|✅|
+|Amused-512|✅|✖️|✖️|
+|AuraFlow|✅|✖️|✖️|
+
+## Text Encoder & VAE Optimization
+
+|📚Models: `🤗70+`|Text Encoder Parallel|AutoEncoder(VAE) Parallel|
+|:---|:---:|:---:|
+|FLUX.2-Klein-4B|✅|✅|
+|FLUX.2-Klein-base-4B|✅|✅|
+|FLUX.2-Klein-9B|✅|✅|
+|FLUX.2-Klein-base-9B|✅|✅|
+|LTX-2-I2V|✅|✅|
+|LTX-2-T2V|✅|✅|
+|Qwen-Image-2512|✅|✅|
+|Z-Image-Turbo `🤖Q`|✅|✅|
+|Qwen-Image-Layered|✅|✅|
+|Qwen-Image-Edit-2511-Lightning|✅|✅|
+|Qwen-Image-Edit-2511|✅|✅|
+|LongCat-Image|✅|✅|
+|LongCat-Image-Edit|✅|✅|
+|Z-Image-Turbo|✅|✅|
+|Z-Image-Turbo-Fun-ControlNet-2.0|✅|✅|
+|Z-Image-Turbo-Fun-ControlNet-2.1|✅|✅|
+|Ovis-Image|✅|✅|
+|FLUX.2-dev|✅|✅|
+|FLUX.1-dev|✅|✅|
+|FLUX.1-Fill-dev|✅|✅|
+|FLUX.1-Kontext-dev|✅|✅|
+|Qwen-Image|✅|✅|
+|Qwen-Image-Edit|✅|✅|
+|Qwen-Image-Edit-2509|✅|✅|
+|Qwen-Image-ControlNet|✅|✅|
+|Qwen-Image-ControlNet-Inpainting|✅|✅|
+|Qwen-Image-Lightning|✅|✅|
+|Qwen-Image-Edit-Lightning|✅|✅|
+|Qwen-Image-Edit-2509-Lightning|✅|✅|
+|Wan-2.2-T2V|✅|✅|
+|Wan-2.2-I2V|✅|✅|
+|Wan-2.2-VACE-Fun|✅|✅|
+|Wan-2.1-T2V|✅|✅|
+|Wan-2.1-I2V|✅|✅|
+|Wan-2.1-FLF2V|✅|✅|
+|Wan-2.1-VACE|✅|✅|
+|HunyuanImage-2.1|✅|✖️|
+|HunyuanVideo-1.5|✅|✖️|
+|HunyuanVideo|✅|✅|
+|FLUX.1-dev `🤖Q`|✅|✅|
+|FLUX.1-Fill-dev `🤖Q`|✅|✅|
+|FLUX.1-Kontext-dev `🤖Q`|✅|✅|
+|Qwen-Image `🤖Q`|✅|✅|
+|Qwen-Image-Edit `🤖Q`|✅|✅|
+|Qwen-Image-Edit-2509 `🤖Q`|✅|✅|
+|Qwen-Image-Lightning `🤖Q`|✅|✅|
+|Qwen-Image-Edit-Lightning `🤖Q`|✅|✅|
+|Qwen-Image-Edit-2509-Lightning `🤖Q`|✅|✅|
+|SkyReels-V2-T2V|✅|✅|
+|ChronoEdit-14B|✅|✅|
+|Kandinsky-5.0-T2V-Lite|✅|✅|
+|PRX-512-t2i-sft|✅|✖️|
+|LTX-Video-v0.9.8|✅|✖️|
+|LTX-Video-v0.9.7|✅|✖️|
+|CogVideoX|✅|✖️|
+|CogVideoX-1.5|✅|✖️|
+|CogView-4|✅|✅|
+|CogView-3-Plus|✅|✅|
+|Chroma1-HD|✅|✅|
+|PixArt-Sigma-XL-2-1024-MS|✅|✅|
+|PixArt-XL-2-1024-MS|✅|✅|
+|VisualCloze-512|✅|✅|
+|ConsisID-preview|✅|✖️|
+|mochi-1-preview|✅|✖️|
+|Lumina-Image-2.0|✅|✅|
+|HiDream-I1-Full|✅|✅|
+|HunyuanDiT|✅|✅|
+|Sana-1600M-1024px|✅|✖️|
+|DiT-XL-2-256|✅|✅|
+|Allegro-T2V|✅|✖️|
+|OmniGen-2|✅|✅|
+|stable-diffusion-3.5-large|✖️|✅|
+|Amused-512|✅|✖️|
+|AuraFlow|✅|✅|
+
+## ControlNet Optimization
+
+|Models|ControlNet Parallel|
+|:---|:---:|
+|Z-Image-Turbo-Fun-ControlNet-2.0|✅|
+|Z-Image-Turbo-Fun-ControlNet-2.1|✅|
+|Qwen-Image-ControlNet|TODO|
+|Qwen-Image-ControlNet-Inpainting|TODO|
diff --git a/docs/user_guide/API_DOCS.md b/docs/user_guide/API_DOCS.md
new file mode 100644
index 000000000..03a3272ea
--- /dev/null
+++ b/docs/user_guide/API_DOCS.md
@@ -0,0 +1,118 @@
+# API Documentation
+
+
+
+Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks that match the specific Input and Output patterns). For a good balance between performance and precision, DBCache is configured by default with F8B0, 8 warmup steps, and unlimited cached steps. All the configurable params are listed below.
+
+## API: enable_cache
+
+```python
+def enable_cache(...) -> Union[DiffusionPipeline, BlockAdapter, Transformer]
+```
+
+## Function Description
+
+The `enable_cache` function serves as a unified caching interface designed to optimize the performance of diffusion transformer models by implementing an intelligent caching mechanism known as `DBCache`. This API is engineered to be compatible with nearly `all` diffusion transformer architectures that feature transformer blocks adhering to standard input-output patterns, eliminating the need for architecture-specific modifications.
+
+By strategically caching intermediate outputs of transformer blocks during the diffusion process, `DBCache` significantly reduces redundant computations without compromising generation quality. The caching mechanism works by tracking residual differences between consecutive steps, allowing the model to reuse previously computed features when these differences fall below a configurable threshold. This approach maintains a balance between computational efficiency and output precision.
+
+The default configuration (`F8B0, 8 warmup steps, unlimited cached steps`) is carefully tuned to provide an optimal tradeoff for most common use cases. The "F8B0" configuration indicates that the first 8 transformer blocks are used to compute stable feature differences, while no final blocks are employed for additional fusion. The warmup phase ensures the model establishes sufficient feature representation before caching begins, preventing potential degradation of output quality.
+
+This function seamlessly integrates with both standard diffusion pipelines and custom block adapters, making it versatile for various deployment scenarios—from research prototyping to production environments where inference speed is critical. By abstracting the complexity of caching logic behind a simple interface, it enables developers to enhance model performance with minimal code changes.
+
+## Quick Start
+
+```python
+>>> import cache_dit
+>>> from diffusers import DiffusionPipeline
+>>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image") # Can be any diffusion pipeline
+>>> cache_dit.enable_cache(pipe) # One-line code with default cache options.
+>>> output = pipe(...) # Just call the pipe as normal.
+>>> stats = cache_dit.summary(pipe) # Then, get the summary of cache acceleration stats.
+>>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
+```
+
+## Parameter Description
+
+- **pipe_or_adapter**(`DiffusionPipeline`, `BlockAdapter` or `Transformer`, *required*):
+ The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
+ For example: `cache_dit.enable_cache(FluxPipeline(...))`.
+
+- **cache_config**(`DBCacheConfig`, *required*, defaults to DBCacheConfig()):
+ Basic DBCache config for cache context, defaults to DBCacheConfig(). The configurable parameters are listed below:
+ - `Fn_compute_blocks`: (`int`, *required*, defaults to 8):
+      Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 difference and delivering more accurate information to subsequent blocks.
+ Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md for more details of DBCache.
+ - `Bn_compute_blocks`: (`int`, *required*, defaults to 0):
+      Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+ - `residual_diff_threshold`: (`float`, *required*, defaults to 0.08):
+ The value of residual difference threshold, a higher value leads to faster performance at the cost of lower precision.
+ - `max_accumulated_residual_diff_threshold`: (`float`, *optional*, defaults to None):
+ The maximum accumulated relative l1 diff threshold for Cache. If set, when the
+ accumulated relative l1 diff exceeds this threshold, the caching strategy will be
+ disabled for current step. This is useful for some cases where the input condition
+ changes significantly in a single step. Default None means this feature is disabled.
+ - `max_warmup_steps`: (`int`, *required*, defaults to 8):
+ DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+ - `warmup_interval`: (`int`, *required*, defaults to 1):
+ Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+ in warmup steps will be computed, others will use dynamic cache.
+ - `max_cached_steps`: (`int`, *required*, defaults to -1):
+ DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
+ - `max_continuous_cached_steps`: (`int`, *required*, defaults to -1):
+ DBCache disables the caching strategy when the previous continuous cached steps exceed this value to prevent precision degradation.
+ - `enable_separate_cfg`: (`bool`, *required*, defaults to None):
+ Whether to use separate cfg or not, such as in Wan 2.1, Qwen-Image. For models that fuse CFG and non-CFG into a single forward step, set enable_separate_cfg as False. Examples include: CogVideoX, HunyuanVideo, Mochi, etc.
+ - `cfg_compute_first`: (`bool`, *required*, defaults to False):
+ Whether to compute cfg forward first, default is False, meaning:
+ 0, 2, 4, ... -> non-CFG step; 1, 3, 5, ... -> CFG step.
+ - `cfg_diff_compute_separate`: (`bool`, *required*, defaults to True):
+ Whether to compute separate difference values for CFG and non-CFG steps, default is True. If False, we will use the computed difference from the current non-CFG transformer step for the current CFG step.
+ - `num_inference_steps` (`int`, *optional*, defaults to None):
+ num_inference_steps for DiffusionPipeline, used to adjust some internal settings
+ for better caching performance. For example, we will refresh the cache once the
+ executed steps exceed num_inference_steps if num_inference_steps is provided.
+ - `steps_computation_mask`: (`List[int]`, *optional*, defaults to None):
+      This param introduces a LeMiCa/EasyCache-style compute mask for steps. It is a list
+ of length num_inference_steps indicating whether to compute each step or not.
+ 1 means must compute, 0 means use dynamic/static cache. If provided, will override
+ other settings to decide whether to compute each step.
+ - `steps_computation_policy`: (`str`, *optional*, defaults to "dynamic"):
+ The computation policy for steps when using steps_computation_mask. It can be
+ "dynamic" or "static". "dynamic" means using dynamic cache for steps marked as 0
+ in steps_computation_mask, while "static" means using static cache for those steps.
+
+- **calibrator_config** (`CalibratorConfig`, *optional*, defaults to None):
+ Config for calibrator. If calibrator_config is not None, it means the user wants to use DBCache with a specific calibrator, such as taylorseer, foca, and so on.
+
+- **params_modifiers** ('ParamsModifier', *optional*, defaults to None):
+ Modify cache context parameters for specific blocks. The configurable parameters are listed below:
+ - `cache_config`: (`DBCacheConfig`, *required*, defaults to DBCacheConfig()):
+ The same as the 'cache_config' parameter in the cache_dit.enable_cache() interface.
+ - `calibrator_config`: (`CalibratorConfig`, *optional*, defaults to None):
+ The same as the 'calibrator_config' parameter in the cache_dit.enable_cache() interface.
+ - `**kwargs`: (`dict`, *optional*, defaults to {}):
+ The same as the 'kwargs' parameter in the cache_dit.enable_cache() interface.
+
+- **parallelism_config** (`ParallelismConfig`, *optional*, defaults to None):
+ Config for Parallelism. If parallelism_config is not None, it means the user wants to enable
+ parallelism for cache-dit.
+ - `backend`: (`ParallelismBackend`, *required*, defaults to "ParallelismBackend.NATIVE_DIFFUSER"):
+      Parallelism backend, currently only NATIVE_DIFFUSER and NATIVE_PYTORCH are supported.
+ For context parallelism, only NATIVE_DIFFUSER backend is supported, for tensor parallelism,
+ only NATIVE_PYTORCH backend is supported.
+ - `ulysses_size`: (`int`, *optional*, defaults to None):
+ The size of Ulysses cluster. If ulysses_size is not None, enable Ulysses style parallelism.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
+ - `ring_size`: (`int`, *optional*, defaults to None):
+ The size of ring for ring parallelism. If ring_size is not None, enable ring attention.
+ This setting is only valid when backend is NATIVE_DIFFUSER.
+ - `tp_size`: (`int`, *optional*, defaults to None):
+ The size of tensor parallelism. If tp_size is not None, enable tensor parallelism.
+ This setting is only valid when backend is NATIVE_PYTORCH.
+ - `parallel_kwargs`: (`dict`, *optional*, defaults to {}):
+ Additional kwargs for parallelism backends. For example, for NATIVE_DIFFUSER backend,
+ it can include `cp_plan` and `attention_backend` arguments for `Context Parallelism`.
+
+- **kwargs** (`dict`, *optional*, defaults to {}):
+ Other cache context keyword arguments. Please check https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/cache_contexts/cache_context.py for more details.
diff --git a/docs/user_guide/ASCEND_NPU.md b/docs/user_guide/ASCEND_NPU.md
new file mode 100644
index 000000000..9b5eb6c7d
--- /dev/null
+++ b/docs/user_guide/ASCEND_NPU.md
@@ -0,0 +1,204 @@
+# Ascend NPU Support
+
+🔥We are excited to announce that Cache-DiT now provides **native** support for **Ascend NPU**. Theoretically, **nearly all** models supported by Cache-DiT can run on Ascend NPU with most of Cache-DiT’s optimization technologies, including:
+
+- **Hybrid Cache Acceleration** ([**DBCache**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#dbcache-dual-block-cache), DBPrune, [**TaylorSeer**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#hybrid-taylorseer-calibrator), [**SCM**](https://cache-dit.readthedocs.io/en/latest/user_guide/CACHE_API/#scm-steps-computation-masking) and more)
+- **Context Parallelism** (w/ Extended Diffusers' CP APIs, [**UAA**](https://cache-dit.readthedocs.io/en/latest/user_guide/CONTEXT_PARALLEL/#uaa-ulysses-anything-attention), Async Ulysses, ...)
+- **Tensor Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- **Text Encoder Parallelism** (w/ PyTorch native DTensor and Tensor Parallelism APIs)
+- **Auto Encoder (VAE) Parallelism** (w/ Data or Tile Parallelism, avoid OOM)
+- **ControlNet Parallelism** (w/ Context Parallelism for ControlNet module)
+- Built-in **HTTP serving** deployment support with simple REST APIs
+
+Please refer to **[Ascend NPU Supported Matrix](../supported_matrix/ASCEND_NPU.md)** for more details.
+
+## Features Support
+
+|Device|Hybrid Cache|Context Parallel|Tensor Parallel|Text Encoder Parallel|Auto Encoder(VAE) Parallel|
+|:---|:---:|:---:|:---:|:---:|:---:|
+|Atlas 800T A2|✅|✅|✅|✅|✅|
+|Atlas 800I A2|✅|✅|✅|✅|✅|
+
+## Attention backend
+
+Cache-DiT supports multiple Attention backends for better performance. The supported attention backends for Ascend NPU list is as follows:
+
+|backend|details|parallelism|attn_mask|
+|:---|:---|:---|:---|
+|native| Native SDPA Attention in PyTorch|✅|✅|
+|_native_npu| Optimized Ascend NPU Attention|✅|✅|
+
+We strongly recommend using the `_native_npu` backend to achieve better performance.
+
+## Environment Requirements
+
+There are two installation methods:
+
+- **Using pip**: first prepare env manually or via CANN image, then install `cache-dit` using pip.
+- **Using docker**: use the [Ascend NPU community: vllm-ascend](https://quay.io/repository/ascend/vllm-ascend?tab=tags) pre-built docker image as the base image for **cache-dit** directly. (**Recommended**, no need for installing torch and torch_npu manually)
+
+## Install NPU SDKs Manually
+
+This section describes how to install NPU environment manually.
+
+### Requirements
+
+OS: Linux; Python: >= 3.10, < 3.12; hardware with an Ascend NPU (usually the Atlas 800 A2 series); software:
+
+| Software | Supported version | Note |
+|---------------|----------------------------------|-------------------------------------------|
+| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/releasenote/releasenote_0000.html) | Required for CANN |
+| CANN | == 8.3.RC2 | Required for cache-dit and torch-npu |
+| torch-npu | == 2.8.0 | Required for cache-dit|
+| torch | == 2.8.0 | Required for torch-npu and cache-dit |
+| NNAL | == 8.3.RC2 | Required for libatb.so, enables advanced tensor operations |
+
+
+### Configure CANN environment.
+
+Before installation, you need to make sure firmware/driver and CANN are installed correctly, refer to [Ascend Environment Setup Guide](https://ascend.github.io/docs/sources/ascend/quick_install.html) for more details. To verify that the Ascend NPU firmware and driver were correctly installed, run:
+
+```bash
+npu-smi info
+```
+
+Please refer to [Ascend Environment Setup Guide](https://ascend.github.io/docs/sources/ascend/quick_install.html) for more details.
+
+### Configure software environment.
+
+The easiest way to prepare your software environment is using CANN image directly. We recommend using the [Ascend NPU community: vllm-ascend](https://quay.io/repository/ascend/vllm-ascend?tab=tags) pre-built docker image as the base image of Ascend NPU for **cache-dit**. CANN image can be found in Ascend official community website: [here](https://www.hiascend.com/developer/ascendhub/detail/17da20d1c2b6493cb38765adeba85884). The CANN prebuilt image includes NNAL (Ascend Neural Network Acceleration Library) which provides libatb.so for advanced tensor operations. No additional installation is required when using the prebuilt image.
+
+```bash
+# Update DEVICE according to your device (/dev/davinci[0-7])
+export DEVICE=/dev/davinci7
+# Update the pre-built image
+export IMAGE=quay.io/ascend/cann:|cann_image_tag|
+docker run --rm \
+ --name cache-dit-ascend \
+ --shm-size=1g \
+ --device $DEVICE \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -v /root/.cache:/root/.cache \
+ -it $IMAGE bash
+```
+
+### Install PyTorch
+
+If installation via the pip command fails, you can get the `torch-2.8.0*.whl` file from [Link](https://download.pytorch.org/whl/torch/) and install it manually.
+
+```bash
+# torch: aarch64
+pip3 install torch==2.8.0
+# torch: x86
+pip3 install torch==2.8.0+cpu --index-url https://download.pytorch.org/whl/cpu
+```
+
+### Install torch_npu
+
+We strongly recommend installing torch_npu by acquiring the `torch_npu-2.8.0*.whl` file from [Link](https://gitcode.com/Ascend/pytorch/releases) and installing it manually. For more details about Ascend PyTorch Adapter installation, please refer to [https://gitcode.com/Ascend/pytorch](https://gitcode.com/Ascend/pytorch)
+
+### Install Extra Dependencies
+
+```bash
+pip install --no-deps torchvision==0.16.0
+pip install einops sentencepiece accelerate
+```
+
+## Use prebuilt Docker Image
+
+We recommend using the prebuilt image from the [Ascend NPU community: vllm-ascend](https://quay.io/repository/ascend/vllm-ascend?tab=tags) as the base image of Ascend NPU for **cache-dit**. You can just pull the **prebuilt image** from the image [repository](https://quay.io/repository/ascend/vllm-ascend?tab=tags) and run it with bash. For example:
+
+```bash
+# Download pre-built image for Ascend NPU
+docker pull quay.io/ascend/vllm-ascend:v0.13.0rc1
+
+# Use the pre-built image for cache-dit
+docker run \
+ --name cache-dit-ascend \
+ --device /dev/davinci_manager \
+ --device /dev/devmm_svm \
+ --device /dev/hisi_hdc \
+ --net=host \
+ --shm-size=80g \
+ --privileged=true \
+ -v /usr/local/dcmi:/usr/local/dcmi \
+ -v /usr/local/bin/npu-smi:/usr/local/bin/npu-smi \
+ -v /usr/local/Ascend/driver/lib64/:/usr/local/Ascend/driver/lib64/ \
+ -v /usr/local/Ascend/driver/version.info:/usr/local/Ascend/driver/version.info \
+ -v /etc/ascend_install.info:/etc/ascend_install.info \
+ -v /data:/data \
+ -itd quay.io/ascend/vllm-ascend:v0.13.0rc1 bash
+```
+
+## Ascend Environment variables
+```bash
+# Make sure CANN_path is set to your CANN installation path
+# e.g., export CANN_path=/usr/local/Ascend
+source $CANN_path/ascend-toolkit/set_env.sh
+export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
+# Set NPU devices by ASCEND_RT_VISIBLE_DEVICES env
+export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+```
+
+Once it is done, you can start to set up `cache-dit`.
+
+## Install Cache-DiT Library
+
+You can install the stable release of `cache-dit` from PyPI:
+
+```bash
+pip3 install -U cache-dit
+```
+Or you can install the latest develop version from GitHub:
+
+```bash
+pip3 install git+https://github.com/vipshop/cache-dit.git
+```
+Please also install the latest main branch of diffusers for context parallelism:
+```bash
+pip3 install git+https://github.com/huggingface/diffusers.git # or >= 0.36.0
+```
+
+## Examples and Benchmark
+
+After the environment configuration is complete, users can refer to the **[Quick Examples](../EXAMPLES.md)**, **[Ascend NPU Benchmark](../benchmark/ASCEND_NPU.md)** and **[Ascend NPU Supported Matrix](../supported_matrix/ASCEND_NPU.md)** for more details.
+
+```bash
+pip3 install opencv-python-headless einops imageio-ffmpeg ftfy
+pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0
+pip3 install git+https://github.com/vipshop/cache-dit.git # latest
+
+git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples
+```
+
+### Single NPU Inference
+
+The easiest way to enable hybrid cache acceleration for DiTs with cache-dit is to start with single NPU inference. For examples:
+
+```bash
+# use default model path, e.g, "black-forest-labs/FLUX.1-dev"
+python3 generate.py flux --attn _native_npu
+python3 generate.py qwen_image --attn _native_npu
+python3 generate.py flux --cache --attn _native_npu
+python3 generate.py qwen_image --cache --attn _native_npu
+```
+
+### Distributed Inference
+
+cache-dit is designed to work with 🔥Context Parallelism and 🔥Tensor Parallelism. For examples:
+
+```bash
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --attn _native_npu
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses --attn _native_npu
+torchrun --nproc_per_node=4 generate.py qwen_image --parallel ulysses --attn _native_npu
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --cache --attn _native_npu
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses --cache --attn _native_npu
+torchrun --nproc_per_node=4 generate.py qwen_image --parallel ulysses --cache --attn _native_npu
+```
diff --git a/docs/user_guide/ATTENTION.md b/docs/user_guide/ATTENTION.md
new file mode 100644
index 000000000..4df12164f
--- /dev/null
+++ b/docs/user_guide/ATTENTION.md
@@ -0,0 +1,58 @@
+# Attention Backend
+
+## Available backend
+
+Cache-DiT supports multiple Attention backends for better performance. The supported list is as follows:
+
+|backend|details|parallelism|attn_mask|
+|:---|:---|:---|:---|
+|native| Native SDPA Attention, w/ cache-dit optimized|✅|✅|
+|_sdpa_cudnn| CUDNN Attention via SDPA API, w/ cache-dit optimized|✅|✅|
+|_native_cudnn| CUDNN Attention via SDPA API, w/o cache-dit optimized|✅|✖️|
+|flash| official FlashAttention-2|✅|✖️|
+|_flash_3| official FlashAttention-3|✅|✖️|
+|sage| FP8 SageAttention|✅|✖️|
+|_native_npu| Optimized Ascend NPU Attention|✅|✅|
+
+Users can specify Attention backend by setting the `attention_backend` parameter of `parallel_kwargs`:
+
+```python
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2, # or, tp_size=2
+ parallel_kwargs={
+ # flash, native(sdpa), _native_cudnn, _sdpa_cudnn, sage
+ "attention_backend": "_sdpa_cudnn",
+ },
+ ),
+)
+```
+
+## FP8 Attention
+
+
+
+For FP8 Attention, users must install `sage-attention`. Then, pass the `sage` attention backend to the parallelism configuration as an extra parameter. Please note that `attention mask` is not currently supported for FP8 sage attention.
+
+```python
+# pip3 install "cache-dit[parallelism]"
+# pip3 install git+https://github.com/thu-ml/SageAttention.git
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2, # or, tp_size=2
+ parallel_kwargs={
+ # flash, native(sdpa), _native_cudnn, _sdpa_cudnn, sage
+ "attention_backend": "sage",
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_fp8_cache.py
+```
diff --git a/docs/user_guide/CACHE_API.md b/docs/user_guide/CACHE_API.md
new file mode 100644
index 000000000..9b92c093f
--- /dev/null
+++ b/docs/user_guide/CACHE_API.md
@@ -0,0 +1,536 @@
+# Unified Cache APIs
+
+
+
+## Forward Pattern Matching
+
+Currently, for any **Diffusion** models with **Transformer Blocks** that match the specific **Input/Output patterns**, we can use the **Unified Cache APIs** from **cache-dit**, namely, the `cache_dit.enable_cache(...)` API. The **Unified Cache APIs** are currently in the experimental phase; please stay tuned for updates. The supported patterns are listed as follows:
+
+
+
+## Cache Acceleration with One-line Code
+
+In most cases, you only need to call **one-line** of code, that is `cache_dit.enable_cache(...)`. After this API is called, you just need to call the pipe as normal. The `pipe` param can be **any** Diffusion Pipeline. Please refer to [Qwen-Image](https://github.com/vipshop/cache-dit/blob/main/examples/run_qwen_image.py) as an example.
+
+```python
+import cache_dit
+from diffusers import DiffusionPipeline
+
+# Can be any diffusion pipeline
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+# One-line code with default cache options.
+cache_dit.enable_cache(pipe)
+# Just call the pipe as normal.
+output = pipe(...)
+# Disable cache and run original pipe.
+cache_dit.disable_cache(pipe)
+```
+
+## Automatic Block Adapter
+
+But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve this problem. Please refer to [🔥Qwen-Image w/ BlockAdapter](https://github.com/vipshop/cache-dit/blob/main/examples/adapter/run_qwen_image_adapter.py) as an example.
+
+```python
+from cache_dit import ForwardPattern, BlockAdapter
+
+# Use 🔥BlockAdapter with `auto` mode.
+cache_dit.enable_cache(
+ BlockAdapter(
+ # Any DiffusionPipeline, Qwen-Image, etc.
+ pipe=pipe, auto=True,
+ # Check `📚Forward Pattern Matching` documentation and hack the code
+ # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+
+# Or, manually setup transformer configurations.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # Qwen-Image, etc.
+ transformer=pipe.transformer,
+ blocks=pipe.transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+)
+```
+For such situations, **BlockAdapter** can help you quickly apply various cache acceleration features to your own Diffusion Pipelines and Transformers.
+
+## Hybrid Forward Pattern
+
+Sometimes, a Transformer class will contain more than one transformer `blocks`. For example, **FLUX.1** (HiDream, Chroma, etc) contains transformer_blocks and single_transformer_blocks (with different forward patterns). The **BlockAdapter** can also help you solve this problem.
+```python
+# For diffusers <= 0.34.0, FLUX.1 transformer_blocks and
+# single_transformer_blocks have different forward patterns.
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # FLUX.1, etc.
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.transformer_blocks,
+ pipe.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_3,
+ ],
+ ),
+)
+```
+
+Even sometimes you have more complex cases, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`) in its structure. Fortunately, **cache-dit** can also handle this situation very well. Please refer to [📚Wan 2.2 MoE](https://github.com/vipshop/cache-dit/blob/main/examples) as an example.
+
+```python
+from cache_dit import ForwardPattern, BlockAdapter, ParamsModifier, DBCacheConfig
+
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe,
+ transformer=[
+ pipe.transformer,
+ pipe.transformer_2,
+ ],
+ blocks=[
+ pipe.transformer.blocks,
+ pipe.transformer_2.blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_2,
+ ForwardPattern.Pattern_2,
+ ],
+ # Setup different cache params for each 'blocks'. You can
+ # pass any specific cache params to ParamModifier, the old
+ # value will be overwritten by the new one.
+ params_modifiers=[
+ ParamsModifier(
+ cache_config=DBCacheConfig().reset(
+ max_warmup_steps=4,
+ max_cached_steps=8,
+ ),
+ ),
+ ParamsModifier(
+ cache_config=DBCacheConfig().reset(
+ max_warmup_steps=2,
+ max_cached_steps=20,
+ ),
+ ),
+ ],
+ has_separate_cfg=True,
+ ),
+)
+```
+
+## Implement Patch Functor
+
+For any PATTERN not in {0...5}, we introduced the simple abstract concept of **Patch Functor**. Users can implement a subclass of Patch Functor to convert an unknown Pattern into a known PATTERN, and for some models, users may also need to fuse the operations within the blocks for loop into block forward.
+
+
+
+Some Patch functors have already been provided in cache-dit: [📚HiDreamPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_hidream.py), [📚ChromaPatchFunctor](https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/caching/patch_functors/functor_chroma.py), etc. After implementing Patch Functor, users need to set the `patch_functor` property of **BlockAdapter**.
+
+```python
+@BlockAdapterRegister.register("HiDream")
+def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+ from diffusers import HiDreamImageTransformer2DModel
+ from cache_dit.caching.patch_functors import HiDreamPatchFunctor
+
+ assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
+ return BlockAdapter(
+ pipe=pipe,
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.double_stream_blocks,
+ pipe.transformer.single_stream_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_0,
+ ForwardPattern.Pattern_3,
+ ],
+ # NOTE: Setup your custom patch functor here.
+ patch_functor=HiDreamPatchFunctor(),
+ **kwargs,
+ )
+```
+
+## Transformer-Only Interface
+
+In some cases, users may **not use Diffusers or DiffusionPipeline** at all, and may not even have the concept of a "pipeline"—for instance, **ComfyUI** (which breaks down the pipeline into individual components while still retaining transformer components). cache-dit also supports such scenarios; it only needs to be configured via **BlockAdapter**. The pipeline is not mandatory, and you can simply keep it at the default value of None. In this case, the `num_inference_steps` parameter in cache_config **must be set**, as cache-dit relies on this parameter to refresh the cache context at the appropriate time. Please refer to [📚run_transformer_only.py](https://github.com/vipshop/cache-dit/blob/main/examples/api/run_transformer_only.py) as an example.
+
+```python
+cache_dit.enable_cache(
+ BlockAdapter(
+ # NO `pipe` required
+ transformer=transformer,
+ blocks=transformer.transformer_blocks,
+ forward_pattern=ForwardPattern.Pattern_1,
+ ),
+ cache_config=DBCacheConfig(
+ num_inference_steps=50 # required
+ ),
+)
+```
+
+If you need to use a **different** num_inference_steps for each user request instead of a fixed value, you should use it in conjunction with `refresh_context` API. Before performing inference for each user request, update the cache context based on the actual number of steps. Please refer to [📚run_cache_refresh](https://github.com/vipshop/cache-dit/blob/main/examples/api) as an example.
+
+```python
+import cache_dit
+from cache_dit import DBCacheConfig
+from diffusers import DiffusionPipeline
+
+# Init cache context with num_inference_steps=None (default)
+pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+pipe = cache_dit.enable_cache(pipe.transformer, cache_config=DBCacheConfig(num_inference_steps=None))
+
+# Assume num_inference_steps is 28, and we want to refresh the context
+cache_dit.refresh_context(pipe.transformer, num_inference_steps=28, verbose=True)
+output = pipe(...) # Just call the pipe as normal.
+stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+
+# Update the cache context with new num_inference_steps=50.
+cache_dit.refresh_context(pipe.transformer, num_inference_steps=50, verbose=True)
+output = pipe(...) # Just call the pipe as normal.
+stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+
+# Update the cache context with new cache_config.
+cache_dit.refresh_context(
+ pipe.transformer,
+ cache_config=DBCacheConfig(
+ residual_diff_threshold=0.1,
+ max_warmup_steps=10,
+ max_cached_steps=20,
+ max_continuous_cached_steps=4,
+ # The cache settings should all be located in the cache config
+ # if cache config is provided. Otherwise, we will skip it.
+ num_inference_steps=50,
+ ),
+ verbose=True,
+)
+output = pipe(...) # Just call the pipe as normal.
+stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+```
+
+## How to use ParamsModifier
+
+Sometimes you may encounter more complex cases, such as **Wan 2.2 MoE**, which has more than one Transformer (namely `transformer` and `transformer_2`), or FLUX.1, which has multiple transformer blocks (namely `single_transformer_blocks` and `transformer_blocks`). cache-dit will assign separate cache contexts for different `blocks` instances but share the same `cache_config` by default. Users who want to achieve fine-grained control over different cache contexts can consider using `ParamsModifier`. Just pass the `ParamsModifier` per `blocks` to the `BlockAdapter` or `enable_cache(...)` API. Then, the shared `cache_config` will be overwritten by the new configurations from the `ParamsModifier`. For example:
+
+```python
+from cache_dit import ParamsModifier
+
+cache_dit.enable_cache(
+ BlockAdapter(
+ pipe=pipe, # FLUX.1, etc.
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.transformer_blocks,
+ pipe.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_3,
+ ],
+ ),
+ # Basic shared cache config
+ cache_config=DBCacheConfig(...),
+ params_modifiers=[
+ ParamsModifier(
+ # Modified config only for transformer_blocks
+ # Must call the `reset` method of DBCacheConfig.
+ cache_config=DBCacheConfig().reset(
+ Fn_compute_blocks=8,
+ residual_diff_threshold=0.08,
+ ),
+ ),
+ ParamsModifier(
+ # Modified config only for single_transformer_blocks
+ # NOTE: FLUX.1, single_transformer_blocks should have `higher`
+ # residual_diff_threshold because of the precision error
+ # accumulation from previous transformer_blocks
+ cache_config=DBCacheConfig().reset(
+ Fn_compute_blocks=1,
+ residual_diff_threshold=0.16,
+ ),
+ ),
+ ],
+)
+```
+
+## Cache Stats Summary
+
+After finishing each inference of `pipe(...)`, you can call the `cache_dit.summary()` API on pipe to get the details of the **Cache Acceleration Stats** for the current inference.
+```python
+stats = cache_dit.summary(pipe)
+```
+
+You can set `details` param as `True` to show more details of cache stats. (markdown table format) Sometimes, this may help you analyze what values of the residual diff threshold would be better.
+
+```python
+⚡️Cache Steps and Residual Diffs Statistics: QwenImagePipeline
+
+| Cache Steps | Diffs Min | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Max |
+|-------------|-----------|-----------|-----------|-----------|-----------|-----------|
+| 23 | 0.045 | 0.084 | 0.114 | 0.147 | 0.241 | 0.297 |
+```
+
+## Disable Cache Acceleration
+
+Users can call the `cache_dit.disable_cache` API to disable and delete all acceleration hooks from the optimized pipeline or block adapter.
+
+```python
+import cache_dit
+# Disable all acceleration and run the original pipe.
+cache_dit.disable_cache(pipe_or_adapter)
+```
+
+## DBCache: Dual Block Cache
+
+
+
+
+
+**DBCache**: **Dual Block Caching** for Diffusion Transformers. Different configurations of compute blocks (**F8B12**, etc.) can be customized in DBCache, enabling a balanced trade-off between performance and precision. Moreover, it can be entirely **training**-**free**. Please check [DBCache Design](./DBCACHE_DESIGN.md) docs for more design details.
+
+- **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
+- **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+
+```python
+import cache_dit
+from diffusers import FluxPipeline
+
+pipe_or_adapter = FluxPipeline.from_pretrained(
+ "black-forest-labs/FLUX.1-dev",
+ torch_dtype=torch.bfloat16,
+).to("cuda")
+
+# Default options, F8B0, 8 warmup steps, and unlimited cached
+# steps for good balance between performance and precision
+cache_dit.enable_cache(pipe_or_adapter)
+
+# Custom options, F8B8, higher precision
+from cache_dit import DBCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+)
+```
+
+
+
+ DBCache, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"
+
+
+
+
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|
+| | | | | |
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|
+|27.85s|6.04s|5.88s|5.77s|6.01s|
+| | | | | |
+
+
+
+
+ DBCache, L20x4 , Steps: 20, case to show the texture recovery ability of DBCache
+
+
+
+These case studies demonstrate that even with relatively high thresholds (such as 0.12, 0.15, 0.2, etc.) under the DBCache **F12B12** or **F8B16** configuration, the detailed texture of the kitten's fur, colored cloth, and the clarity of text can still be preserved. This suggests that users can leverage DBCache to effectively balance performance and precision in their workflows!
+
+## DBPrune: Dynamic Block Prune
+
+
+
+
+
+
+We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.
+
+```python
+from cache_dit import DBPruneConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBPruneConfig(
+ max_warmup_steps=8, # steps do not apply prune
+ residual_diff_threshold=0.12,
+ enable_dynamic_prune_threshold=True,
+ ),
+)
+```
+We have also brought the designs from DBCache to DBPrune to make it a more general and customizable block prune algorithm. You can specify the values of **Fn** and **Bn** for higher precision, or set up the non-prune blocks list **non_prune_block_ids** to avoid aggressive pruning. For example:
+
+```python
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBPruneConfig(
+ max_warmup_steps=8, # steps do not apply prune
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc
+ residual_diff_threshold=0.12,
+ enable_dynamic_prune_threshold=True,
+ non_prune_block_ids=list(range(16,24)),
+ ),
+)
+```
+
+
+ DBPrune, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"
+
+
+
+|Baseline(L20x1)|Pruned(24%)|Pruned(35%)|Pruned(45%)|Pruned(60%)|
+|:---:|:---:|:---:|:---:|:---:|
+|24.85s|19.43s|16.82s|14.24s|10.66s|
+| | | | | |
+
+## Hybrid Cache CFG
+
+
+
+cache-dit supports caching for **CFG (classifier-free guidance)**. For models that fuse CFG and non-CFG into a single forward step, or models that do not include CFG (classifier-free guidance) in the forward step, please set `enable_separate_cfg` param to **False (default, None)**. Otherwise, set it to True. For examples:
+
+```python
+from cache_dit import DBCacheConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(
+ ...,
+ # CFG: classifier free guidance or not
+ # For model that fused CFG and non-CFG into single forward step,
+ # should set enable_separate_cfg as False. For example, set it as True
+ # for Wan 2.1/Qwen-Image and set it as False for FLUX.1, HunyuanVideo,
+ # CogVideoX, Mochi, LTXVideo, Allegro, CogView3Plus, EasyAnimate, SD3, etc.
+ enable_separate_cfg=True, # Wan 2.1, Qwen-Image, CogView4, Cosmos, SkyReelsV2, etc.
+ # Compute cfg forward first or not, default False, namely,
+ # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
+ cfg_compute_first=False,
+ # Compute separate diff values for CFG and non-CFG step,
+ # default True. If False, we will use the computed diff from
+ # current non-CFG transformer step for current CFG step.
+ cfg_diff_compute_separate=True,
+ ),
+)
+```
+
+## Hybrid TaylorSeer Calibrator
+
+
+
+
+
+We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+
+
+**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in cache-dit supports both hidden states and residual cache types. That F_pred can be a residual cache or a hidden-state cache.
+
+```python
+from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ # Basic DBCache w/ FnBn configurations
+ cache_config=DBCacheConfig(
+ max_warmup_steps=8, # steps do not cache
+ max_cached_steps=-1, # -1 means no limit
+ Fn_compute_blocks=8, # Fn, F8, etc.
+ Bn_compute_blocks=8, # Bn, B8, etc.
+ residual_diff_threshold=0.12,
+ ),
+ # Then, you can use the TaylorSeer Calibrator to approximate
+ # the values in cached steps, taylorseer_order default is 1.
+ calibrator_config=TaylorSeerCalibratorConfig(
+ taylorseer_order=1,
+ ),
+)
+```
+
+Please note that if you have used TaylorSeer as the calibrator for approximate hidden states, the **Bn** param of DBCache can be set to **0**. In essence, DBCache's Bn also acts as a calibrator, so you can choose either Bn > 0 or TaylorSeer. We recommend using the configuration scheme of **TaylorSeer** + **DBCache FnB0**.
+
+
+
+ DBCache F1B0 + TaylorSeer , L20x1, Steps: 28, "A cat holding a sign that says hello world with complex background"
+
+
+
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer +compile|
+|:---:|:---:|:---:|:---:|:---:|
+|24.85s|12.85s|12.86s|10.27s|8.48s|
+| | | | | |
+
+## SCM: Steps Computation Masking
+
+
+
+
+The `steps_computation_mask` parameter adopts a step-wise computation masking approach inspired by [LeMiCa](https://github.com/UnicomAI/LeMiCa) and [EasyCache](https://github.com/H-EmbodVis/EasyCache). Its key insight is that **early caching induces amplified downstream errors, whereas later caching is less disruptive**, resulting in a **non-uniform** distribution of cached steps.
+
+|LeMiCa: Non-Uniform Cache Steps|LeMiCa: Cache Errors|EasyCache: Transformation rate Analysis|
+|:---:|:---:|:---:|
+| | | |
+
+It is a list of length num_inference_steps indicating whether to compute each step or not. 1 means must compute, 0 means use dynamic/static cache. If provided, will override other settings to decide whether to compute each step. Please check the [📚examples/steps_mask](https://github.com/vipshop/cache-dit/blob/main/examples/api/run_steps_mask.py) for more details.
+
+
+```python
+from cache_dit import DBCacheConfig, TaylorSeerCalibratorConfig
+
+# Scheme: Hybrid DBCache + SCM + TaylorSeer
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(
+ # Basic DBCache configs
+ Fn_compute_blocks=8,
+ Bn_compute_blocks=0,
+ # NOTE: warmup steps is not required now!
+ residual_diff_threshold=0.12,
+ # LeMiCa or EasyCache style Mask for 28 steps, e.g,
+ # SCM=111111010010000010000100001, 1: compute, 0: cache.
+ steps_computation_mask=cache_dit.steps_mask(
+ # e.g: slow, medium, fast, ultra.
+ mask_policy="fast", total_steps=28,
+ # Or, you can use bins setting to get custom mask.
+ # compute_bins=[6, 1, 1, 1, 1], # 10
+ # cache_bins=[1, 2, 5, 5, 5], # 18
+ ),
+ # The policy for cache steps can be 'dynamic' or 'static'
+ steps_computation_policy="dynamic",
+ ),
+ calibrator_config=TaylorSeerCalibratorConfig(
+ taylorseer_order=1,
+ ),
+)
+
+```
+
+As we can observe, in the case of **static cache**, the image of `SCM Slow S*` (please click to enlarge) has shown **obvious blurriness**. However, the **Ultra** version under **dynamic cache** (`SCM Ultra D*`) still maintains excellent clarity. Therefore, we prioritize recommending the use of dynamic cache while using `SCM: steps_computation_mask`.
+
+
+|Baseline|SCM S S*|SCM S D*|SCM F D*|SCM U D*|+TS|+compile|+FP8 +Sage|
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.4s|17.1s|11.4s|8.2s|8.2s|7.1s|4.5s|
+| | | | | | | | |
+
+
+ Scheme: DBCache + SCM(steps_computation_mask) + TaylorSeer , L20x1, S*: static cache, D*: dynamic cache , S : Slow, F : Fast, U : Ultra Fast, TS : TaylorSeer, FP8: FP8 DQ, Sage: SageAttention, FLUX.1-Dev , Steps: 28, HxW=1024x1024, Prompt: "A cat holding a sign that says hello world"
+
+
+|DBCache + SCM Slow S*|DBCache + SCM Ultra D* + TaylorSeer + compile|
+|:---:|:---:|
+|15.4s|7.1s|
+| | |
+
+
+Dynamic Caching is all you need! The Ultra fast version under dynamic cache (SCM Ultra D* ) maintains better clarity than the slower static cache one (SCM Slow S* ).
+
diff --git a/docs/user_guide/COMPILE.md b/docs/user_guide/COMPILE.md
new file mode 100644
index 000000000..ad0c827ba
--- /dev/null
+++ b/docs/user_guide/COMPILE.md
@@ -0,0 +1,23 @@
+# Compile
+
+## Torch Compile
+
+
+
+By the way, **cache-dit** is designed to work compatibly with **torch.compile.** You can easily use cache-dit with torch.compile to further achieve a better performance. For example:
+
+```python
+cache_dit.enable_cache(pipe)
+
+# Compile the Transformer module
+pipe.transformer = torch.compile(pipe.transformer)
+```
+However, users intending to use **cache-dit** for DiT with **dynamic input shapes** should consider increasing the **recompile** **limit** of `torch._dynamo`. Otherwise, the recompile_limit error may be triggered, causing the module to fall back to eager mode.
+```python
+torch._dynamo.config.recompile_limit = 96 # default is 8
+torch._dynamo.config.accumulated_recompile_limit = 2048 # default is 256
+```
+Or, you can use the `set_compile_configs` util func in cache-dit:
+```python
+cache_dit.set_compile_configs()
+```
diff --git a/docs/user_guide/CONTEXT_PARALLEL.md b/docs/user_guide/CONTEXT_PARALLEL.md
new file mode 100644
index 000000000..711ca6928
--- /dev/null
+++ b/docs/user_guide/CONTEXT_PARALLEL.md
@@ -0,0 +1,162 @@
+# Context Parallelism
+
+## Hybrid Context Parallelism
+
+
+
+cache-dit is compatible with context parallelism. Currently, we support the use of `Hybrid Cache` + `Context Parallelism` scheme (via NATIVE_DIFFUSER parallelism backend) in cache-dit. Users can use Context Parallelism to further accelerate the speed of inference! For more details, please refer to [📚examples](https://github.com/vipshop/cache-dit/tree/main/examples). Currently, cache-dit supported context parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), 🔥[FLUX.2](https://huggingface.co/black-forest-labs/FLUX.2-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [LTXVideo](https://huggingface.co/Lightricks/LTX-Video), [Wan 2.1](https://github.com/Wan-Video/Wan2.1), [Wan 2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo), [CogVideoX 1.0](https://github.com/zai-org/CogVideo), [CogVideoX 1.5](https://github.com/zai-org/CogVideo), [CogView 3/4](https://github.com/zai-org/CogView4) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ # Set ulysses_size > 1 to enable ulysses style context parallelism.
+ parallelism_config=ParallelismConfig(ulysses_size=2),
+)
+# torchrun --nproc_per_node=2 parallel_cache.py
+```
+
+|L20x1| Ulysses-2 | Ulysses-4 | + compile |
+|:---:|:---:|:---:|:---:|
+|FLUX, 23.56s| 13.80s | 8.28s | 7.27s |
+| | | | |
+
+## UAA: Ulysses Anything Attention
+
+
+
+✅**Any Sequence Length**: We have implemented the **[📚UAA: Ulysses Anything Attention](#uaa-ulysses-anything-attention)**: An Ulysses Attention that supports **arbitrary sequence length** with ✅**zero padding** and **nearly ✅zero theoretical communication overhead**. The default Ulysses Attention requires that the sequence len of hidden states **must be divisible by the number of devices**. This imposes **significant limitations** on the practical application of Ulysses.
+
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ # Set `experimental_ulysses_anything` as True to enable UAA
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ parallel_kwargs={
+ "experimental_ulysses_anything": True
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache_ulysses_anything.py
+```
+
+For example, in the T2I and I2V tasks, the length of prompts input by users is often variable, and it is difficult to ensure that this length is divisible by the number of devices. To address this issue, we have developed a **✅padding-free** Ulysses Attention (UAA) for **arbitrary sequence length**, which enhances the versatility of Ulysses.
+
+```python
+dist.init_process_group(backend="cpu:gloo,cuda:nccl")
+```
+Compared to Ulysses Attention, in **UAA**, we have only added an **extra all-gather** op for scalar types to gather the seq_len value of each rank. To avoid multiple forced CUDA sync caused by H2D and D2H transfers, please add the **✅gloo** backend in `init_process_group`. This will significantly reduce communication latency.
+
+
+ ✅Any Sequence Length
+ U*: Ulysses Attention, UAA: Ulysses Anything Attention , UAA*: UAA + Gloo, Device: NVIDIA L20
+ FLUX.1-Dev w/o CPU Offload, 28 steps; Qwen-Image w/ CPU Offload, 50 steps; Gloo: Extra All Gather w/ Gloo
+
+
+|CP2 w/ U* |CP2 w/ UAA* | CP2 w/ UAA | L20x1 | CP2 w/ UAA* | CP2 w/ U* | L20x1 | CP2 w/ UAA* |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+|FLUX, 13.87s|**🎉13.88s**|14.75s|23.25s| **🎉13.75s**|Qwen, 132s|181s|**🎉133s**|
+| | | | | | | | |
+|1024x1024|1024x1024|1024x1024|1008x1008|1008x1008|1312x1312|1328x1328|1328x1328|
+|✔️U* ✔️UAA|✔️U* ✔️UAA|✔️U* ✔️UAA| NO CP|❌U* ✔️UAA|✔️U* ✔️UAA|NO CP|❌U* ✔️UAA|
+
+
+✅**Any Head Num**: By the way, Ulysses Attention and UAA in cache-dit now **support arbitrary numbers of heads** via additional padding and unpadding operations implemented before and after all-to-all. The overhead incurred by these extra padding and unpadding steps can be **partially hidden** through asynchronous communication. This support for arbitrary head counts is **automatically activated** whenever the number of heads is not divisible by the world size. For Example:
+
+
+
+ ✅Any Head Num
+ Ulysses: Ulysses Attention, FP8 Ulysses: Ulysses w/ FP8 All2All , Device: NVIDIA L20
+ 🔥Z-Image (Head=30, ❌CAN NOT divisible by 4), 1024x1024, 9 steps.
+
+
+|Ulysses 2, L20|Ulysses 4|FP8 Ulysses 4| + Cache | + FP8 DQ |
+|:---:|:---:|:---:|:---:|:---:|
+|1024x1024, 3.19s|1024x1024, 1.98s|1024x1024, 1.89s|1024x1024, 1.63s|1024x1024, 1.23s|
+| | | | | |
+
+
+We have also implemented a ✅**padding-free** version that supports any head num. Please be informed that this solution cannot be used when seq len is not divisible by world size. Users can enable this feature through environment variables:
+
+```bash
+export CACHE_DIT_UNEVEN_HEADS_COMM_NO_PAD=1 # NOT WORK if seq len is also not divisible by world size
+```
+
+Important: Please note that **Ulysses Anything Attention (UAA)** is currently an **experimental** feature. It has not undergone large-scale testing, and may introduce a slight performance degradation while the `cpu:gloo` communication backend is not available.
+
+## Async Ulysses QKV Projection
+
+
+
+
+
+
+Inspired by [ByteDance-Seed/VeOmni: Async Ulysses CP](https://github.com/ByteDance-Seed/VeOmni/blob/main/veomni/distributed/sequence_parallel/async_ulysses.py), we have also added support for **Async Ulysses QKV Projection** for certain models in cache-dit. This enables partial overlap of communication and computation, which can further enhance the performance of Ulysses style Context Parallelism. Currently, only the 🔥[FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), 🔥[Qwen-Image](https://github.com/QwenLM/Qwen-Image), 🔥[Z-Image](https://github.com/Tongyi-MAI/Z-Image) and 🔥[Ovis-Image](https://github.com/AIDC-AI/Ovis-Image) models are supported, and more models will be added in the future—stay tuned!
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ # Set `experimental_ulysses_async` as True to enable Async Ulysses QKV Projection.
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ parallel_kwargs={
+ "experimental_ulysses_async": True
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache_ulysses_async.py
+```
+
+
+
+ Ulysses: Standard Ulysses Attention, Async Ulysses : Ulysses Attention with Async QKV Projection
+
+
+|L20x2 w/ Ulysses| w/ Async Ulysses|w/ Ulysses + compile| w/ Async Ulysses + compile|
+|:---:|:---:|:---:|:---:|
+|FLUX.1, 13.87s|**🎉13.20s**|12.21s|**🎉11.97s**|
+| | | | |
+
+## Async FP8 Ulysses Attention
+
+
+
+
+
+cache-dit has implemented **Async FP8 Ulysses Attention** for **🔥all** supported DiTs. This optimization reduces communication latency while preserving high precision. Users can enable this feature by setting `experimental_ulysses_float8=True`. To maintain higher precision during softmax computation—where `Softmax(Q@K^T)` is sensitive to numerical instability—we currently retain `K in FP16/BF16` format. Float8-optimized all_to_all communication is therefore only applied to Q, V, and O.
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ # Set `experimental_ulysses_float8` as True to enable Async FP8 Ulysses Attention
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ parallel_kwargs={
+ "experimental_ulysses_float8": True
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache_ulysses_float8.py
+```
+
+|L20x2 w/ Ulysses| w/ Ulysses FP8|w/ Ulysses + compile|w/ Ulysses FP8 + compile|
+|:---:|:---:|:---:|:---:|
+|FLUX.1, 13.87s|**🎉13.36s**|12.21s|**🎉11.54s**|
+| | | | |
diff --git a/docs/DBCache.md b/docs/user_guide/DBCACHE_DESIGN.md
similarity index 73%
rename from docs/DBCache.md
rename to docs/user_guide/DBCACHE_DESIGN.md
index 0854a72b7..c625dac07 100644
--- a/docs/DBCache.md
+++ b/docs/user_guide/DBCACHE_DESIGN.md
@@ -13,13 +13,13 @@
-|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-| | | | | | |
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
-|27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-| | | | | | |
+|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|
+|:---:|:---:|:---:|:---:|:---:|
+|24.85s|15.59s|8.58s|15.41s|15.11s|
+| | | | | |
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|
+|27.85s|6.04s|5.88s|5.77s|6.01s|
+| | | | | |
@@ -70,16 +70,6 @@ cache_dit.enable_cache(
),
)
```
-
-
- DBCache, L20x1 , Steps: 28, "A cat holding a sign that says hello world with complex background"
-
-
-
-|Baseline(L20x1)|F1B0 (0.08)|F1B0 (0.20)|F8B8 (0.15)|F12B12 (0.20)|F16B16 (0.20)|
-|:---:|:---:|:---:|:---:|:---:|:---:|
-|24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
-|
|
|
|
|
|
|
## ⚡️Hybrid Cache CFG
diff --git a/docs/user_guide/EXTRA_PARALLEL.md b/docs/user_guide/EXTRA_PARALLEL.md
new file mode 100644
index 000000000..8ac53add3
--- /dev/null
+++ b/docs/user_guide/EXTRA_PARALLEL.md
@@ -0,0 +1,89 @@
+# Extra Modules Parallelism
+
+## Parallelize Text Encoder
+
+
+
+Users can set the `extra_parallel_modules` parameter in parallelism_config (when using Tensor Parallelism or Context Parallelism) to specify additional modules that need to be parallelized beyond the main transformer — e.g., `text_encoder` in `Flux2Pipeline`. It can further reduce the per-GPU memory requirement and slightly improve the inference performance of the text encoder.
+
+Currently, cache-dit supports text encoder parallelism for the **T5Encoder, UMT5Encoder, Llama, Gemma 1/2/3, Mistral, Mistral-3, Qwen-3, Qwen-2.5 VL, Glm and Glm-4** model series, namely, it supports almost **[🔥ALL](../supported_matrix/NVIDIA_GPU.md)** pipelines in diffusers.
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+# Transformer Tensor Parallelism + Text Encoder Tensor Parallelism
+cache_dit.enable_cache(
+ pipe,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ tp_size=2,
+ parallel_kwargs={
+ "extra_parallel_modules": [pipe.text_encoder], # FLUX.2
+ },
+ ),
+)
+
+# Transformer Context Parallelism + Text Encoder Tensor Parallelism
+cache_dit.enable_cache(
+ pipe,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ parallel_kwargs={
+ "extra_parallel_modules": [pipe.text_encoder], # FLUX.2
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache.py
+```
+
+## Parallelize Auto Encoder (VAE)
+
+
+
+Currently, cache-dit supports auto encoder (VAE) parallelism for the **AutoencoderKL, AutoencoderKLQwenImage, AutoencoderKLWan, and AutoencoderKLHunyuanVideo** series, namely, it supports almost **[🔥ALL](../supported_matrix/NVIDIA_GPU.md)** pipelines in diffusers. It can further reduce the per-GPU memory requirement and slightly improve the inference performance of the auto encoder. Users can set it via the `extra_parallel_modules` parameter in parallelism_config, for example:
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+# Transformer Context Parallelism + Text Encoder Tensor Parallelism + VAE Data Parallelism
+cache_dit.enable_cache(
+ pipe,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ parallel_kwargs={
+ "extra_parallel_modules": [pipe.text_encoder, pipe.vae], # FLUX.1
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache.py
+```
+
+## Parallelize ControlNet
+
+
+
+Further, cache-dit even supports ControlNet parallelism for specific models, such as Z-Image-Turbo with ControlNet. Users can set it via the `extra_parallel_modules` parameter in parallelism_config, for example:
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+# Transformer Context Parallelism + Text Encoder Tensor Parallelism
+# + VAE Data Parallelism + ControlNet Context Parallelism
+cache_dit.enable_cache(
+ pipe,
+ cache_config=DBCacheConfig(...),
+ parallelism_config=ParallelismConfig(
+ ulysses_size=2,
+ # case: Z-Image-Turbo-Fun-ControlNet-2.1
+ parallel_kwargs={
+ "extra_parallel_modules": [pipe.text_encoder, pipe.vae, pipe.controlnet],
+ },
+ ),
+)
+# torchrun --nproc_per_node=2 parallel_cache.py
+```
diff --git a/docs/user_guide/INSTALL.md b/docs/user_guide/INSTALL.md
new file mode 100644
index 000000000..a46f47239
--- /dev/null
+++ b/docs/user_guide/INSTALL.md
@@ -0,0 +1,24 @@
+# Installation
+
+## Installation with Nvidia GPU
+
+
+
+You can install the stable release of `cache-dit` from PyPI:
+
+```bash
+pip3 install -U cache-dit # or, pip3 install -U "cache-dit[all]" for all features
+```
+Or you can install the latest develop version from GitHub:
+
+```bash
+pip3 install git+https://github.com/vipshop/cache-dit.git
+```
+Please also install the latest main branch of diffusers for context parallelism:
+```bash
+pip3 install git+https://github.com/huggingface/diffusers.git # or >= 0.36.0
+```
+
+## Installation with Ascend NPU
+
+Please refer to [Ascend NPU Support](./ASCEND_NPU.md) documentation for more details.
diff --git a/docs/user_guide/LOAD_CONFIGS.md b/docs/user_guide/LOAD_CONFIGS.md
new file mode 100644
index 000000000..9098d05da
--- /dev/null
+++ b/docs/user_guide/LOAD_CONFIGS.md
@@ -0,0 +1,67 @@
+# Use Yaml Config File
+
+Cache-DiT now supports loading the acceleration configs from a custom YAML file. Here are some examples.
+
+## Single GPU inference
+
+Define a `config.yaml` file that contains:
+
+```yaml
+cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
+```
+Then, apply the acceleration config from yaml.
+
+```python
+>>> import cache_dit
+>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("config.yaml"))
+```
+
+## Distributed inference
+
+Define a `parallel_config.yaml` file that contains:
+
+```yaml
+cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
+parallelism_config:
+ ulysses_size: auto
+ parallel_kwargs:
+ attention_backend: native
+ extra_parallel_modules: ["text_encoder", "vae"]
+```
+Then, apply the distributed inference acceleration config from the YAML file. `ulysses_size: auto` means that cache-dit will auto-detect the `world_size` and use it as the ulysses_size. Otherwise, you should manually set it to a specific integer, e.g., 4.
+```python
+>>> import cache_dit
+>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel_config.yaml"))
+```
+
+## Quick Examples
+
+```bash
+pip3 install torch==2.9.1 transformers accelerate torchao bitsandbytes torchvision
+pip3 install opencv-python-headless einops imageio-ffmpeg ftfy
+pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0
+pip3 install git+https://github.com/vipshop/cache-dit.git # latest
+
+git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples
+
+python3 generate.py flux --config config.yaml
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --config parallel_config.yaml
+```
diff --git a/docs/user_guide/METRICS.md b/docs/user_guide/METRICS.md
new file mode 100644
index 000000000..4ab559ab3
--- /dev/null
+++ b/docs/user_guide/METRICS.md
@@ -0,0 +1,38 @@
+# Metrics
+
+You can utilize the APIs provided by cache-dit to quickly evaluate the accuracy losses caused by different cache configurations.
+
+## Metrics Functions
+
+Use the metrics functions to evaluate the accuracy losses:
+
+
+
+```python
+# pip3 install "cache-dit[metrics]"
+from cache_dit.metrics import compute_psnr
+from cache_dit.metrics import compute_ssim
+from cache_dit.metrics import compute_fid
+from cache_dit.metrics import compute_lpips
+from cache_dit.metrics import compute_clip_score
+from cache_dit.metrics import compute_image_reward
+
+psnr, n = compute_psnr("true.png", "test.png") # Num: n
+psnr, n = compute_psnr("true_dir", "test_dir")
+ssim, n = compute_ssim("true_dir", "test_dir")
+fid, n = compute_fid("true_dir", "test_dir")
+lpips, n = compute_lpips("true_dir", "test_dir")
+clip, n = compute_clip_score("DrawBench200.txt", "test_dir")
+reward, n = compute_image_reward("DrawBench200.txt", "test_dir")
+```
+
+## Metrics Command Line
+
+Or, you can use `cache-dit-metrics-cli` tool. For examples:
+
+```bash
+cache-dit-metrics-cli -h # show usage
+# all: PSNR, FID, SSIM, MSE, ..., etc.
+cache-dit-metrics-cli all -i1 true.png -i2 test.png # image
+cache-dit-metrics-cli all -i1 true_dir -i2 test_dir # image dir
+```
diff --git a/docs/user_guide/OVERVIEWS.md b/docs/user_guide/OVERVIEWS.md
new file mode 100644
index 000000000..bed9cbc2c
--- /dev/null
+++ b/docs/user_guide/OVERVIEWS.md
@@ -0,0 +1,16 @@
+
+
+
+ CacheDiT: A PyTorch-native and Flexible Inference Engine with 🤗🎉 Hybrid Cache Acceleration and Parallelism for DiTs
+
+
+
+
+
+# Overviews
+
+Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Supported Matrix](../supported_matrix/NVIDIA_GPU.md) for more details.
+
+- [📊Examples](https://github.com/vipshop/cache-dit/tree/main/examples) - The **easiest** way to enable **hybrid cache acceleration** and **parallelism** for DiTs with cache-dit is to start with our examples for popular models: FLUX, Z-Image, Qwen-Image, Wan, etc.
+- [🌐HTTP Serving](./SERVING.md) - Deploy cache-dit models with HTTP API for **text-to-image**, **image editing**, **multi-image editing**, and **text-to-video** generation
+- [❓FAQ](../FAQ.md) - Frequently asked questions including attention backend configuration, troubleshooting, and optimization tips
diff --git a/docs/user_guide/PROFILER.md b/docs/user_guide/PROFILER.md
new file mode 100644
index 000000000..03b3085d8
--- /dev/null
+++ b/docs/user_guide/PROFILER.md
@@ -0,0 +1,202 @@
+# Torch Profiler Usage
+
+## Quick Start
+
+### Basic Usage
+
+`cache-dit` examples have Torch Profiler built in: pass `--profile` to `examples/generate.py` to generate a trace file.
+
+Before running examples, make sure `cache_dit` is importable by Python.
+
+Recommended: run from the `examples/` directory (consistent with `examples/README.md`):
+
+```bash
+cd examples
+
+# List all available examples
+python3 generate.py list
+
+# Basic profiling (recommended: reduce steps to keep the trace small)
+python3 generate.py flux --profile --steps 3
+```
+
+If you want to write traces to a specific directory (or customize the filename prefix):
+
+```bash
+cd examples
+python3 generate.py flux --profile --steps 3 --profile-dir /tmp/cache_dit_profiles --profile-name flux_test
+```
+
+> Note: for multi-GPU runs (`torchrun`), each rank produces its own trace file, e.g. `flux_test-rank0.trace.json.gz`.
+
+### Example: `examples/base.py` Integration
+
+`generate.py` eventually calls `ExampleBase.run()`, which already integrates `--profile/--profile-dir/--profile-activities`; you only need to pass these flags on the command line.
+
+## Command-Line Arguments
+
+```bash
+# Basic profiling
+cd examples
+python3 generate.py flux --profile --steps 3
+
+# With custom profile name and output directory
+cd examples
+python3 generate.py flux --profile --steps 3 --profile-name flux_test --profile-dir /tmp/profiles
+
+# Profile with memory tracking
+cd examples
+python3 generate.py flux --profile --steps 3 --profile-activities CPU GPU MEM
+```
+
+## Parameters
+
+### `create_profiler_from_args(args, profile_name=None)`
+
+Creates a ProfilerContext from command-line arguments.
+
+**Arguments:**
+
+- `args` : Parsed command-line arguments containing profiler settings
+- `profile_name` (str, optional): Override the profile name
+
+**Command-Line Arguments:**
+
+- `--profile` : Enable profiler (default: False)
+- `--profile-name` (str) : Profile name prefix (default: auto-generated timestamp)
+- `--profile-dir` (str) : Output directory (default: $CACHE_DIT_TORCH_PROFILER_DIR or `/tmp/cache_dit_profiles`)
+- `--profile-activities` (list[str]) : Activities to profile - CPU, GPU, MEM (default: ["CPU", "GPU"])
+- `--profile-with-stack` : Record stack traces (default: True, enable for detailed debugging)
+- `--profile-record-shapes` : Record tensor shapes (default: True)
+
+**Returns:**
+
+- `ProfilerContext`: Context manager for profiling
+
+**Environment Variables:**
+
+- `CACHE_DIT_TORCH_PROFILER_DIR`: Default output directory
+
+### Controlling Trace File Size
+
+Torch Profiler trace files can be large. Recommendations:
+- Reduce `--steps` (e.g., 3–5)
+- Reduce `--repeat`
+- Optionally disable `--profile-with-stack` / `--profile-record-shapes` (if you add a way to disable them in your workflow)
+
+```bash
+# Profile with 3 steps (small trace file, recommended)
+cd examples
+python3 generate.py flux --profile --steps 3 --warmup 0 --repeat 1
+
+# Profile with full 28 steps (larger trace file)
+cd examples
+python3 generate.py flux --profile --steps 28 --warmup 0 --repeat 1
+```
+
+## View Results
+
+### Perfetto UI (Recommended)
+Visit https://ui.perfetto.dev/ and drag-drop the generated `.trace.json.gz` file. Perfetto provides a more powerful and feature-rich interface compared to Chrome Tracing.
+
+The screenshots below show an example profiling result from `generate.py flux` (model: FLUX.1-dev).
+
+
+
+
+
+
+
+
+
+### Chrome Tracing
+Open `chrome://tracing` in Chrome browser and load the generated `.trace.json.gz` file.
+
+### TensorBoard
+```bash
+pip install tensorboard
+tensorboard --logdir=/path/to/profiles
+```
+
+## Multi-GPU Usage
+
+The profiler automatically handles distributed environments. Each rank will generate its own trace file.
+
+### Example: Tensor Parallelism
+
+```bash
+# 2 GPUs with tensor parallelism
+cd examples
+torchrun --nproc_per_node=2 generate.py flux \
+ --parallel tp \
+ --profile --profile-name flux_tp --steps 3 --warmup 0 --repeat 1
+
+# Output files:
+# - flux_tp-rank0.trace.json.gz
+# - flux_tp-rank1.trace.json.gz
+```
+
+### Example: Context Parallelism
+
+```bash
+# 4 GPUs with context parallelism
+cd examples
+torchrun --nproc_per_node=4 generate.py flux \
+ --parallel ulysses \
+ --profile --profile-name flux_cp --profile-activities CPU GPU MEM \
+ --steps 3 --warmup 0 --repeat 1
+
+# Output files:
+# - flux_cp-rank0.trace.json.gz
+# - flux_cp-rank1.trace.json.gz
+# - flux_cp-rank2.trace.json.gz
+# - flux_cp-rank3.trace.json.gz
+# - flux_cp-rank0-memory-*.pickle (if MEM profiling enabled)
+# - flux_cp-rank1-memory-*.pickle
+# - ...
+```
+
+You can view each rank's trace separately in Perfetto UI or Chrome Tracing to analyze per-GPU performance.
+
+
+# Nsight Systems (nsys) Usage
+
+If you need a lower-level CUDA view (kernel timeline, CUDA API, CPU/GPU concurrency, etc.), use Nsight Systems.
+
+## Installation
+
+Follow NVIDIA Nsight Systems installation instructions (the CLI is usually `nsys`), or your internal environment setup.
+
+## Basic Profiling
+
+The example below profiles a single inference (recommended: set `--warmup 0` so warmup is not included):
+
+```bash
+cd examples
+nsys profile \
+ --trace=cuda,nvtx,osrt \
+ --force-overwrite=true \
+ -o cache_dit_flux \
+ python3 generate.py flux --steps 28 --warmup 0 --repeat 1
+```
+
+## Targeted Capture (reduce file size)
+
+Use `--delay/--duration` to skip model loading/initialization and capture only the main inference window:
+
+```bash
+cd examples
+nsys profile \
+ --trace=cuda,nvtx,osrt \
+ --force-overwrite=true \
+ --delay 10 \
+ --duration 30 \
+ -o cache_dit_flux_infer \
+ python3 generate.py flux --steps 28 --warmup 0 --repeat 1
+```
+
+**Parameter notes:**
+
+- `--delay N` : wait N seconds before capture (commonly used to skip initialization)
+- `--duration N` : stop capture after N seconds (commonly used to limit file size)
+- `-o <name>` : output file prefix
diff --git a/docs/user_guide/QUANTIZATION.md b/docs/user_guide/QUANTIZATION.md
new file mode 100644
index 000000000..49efdf85b
--- /dev/null
+++ b/docs/user_guide/QUANTIZATION.md
@@ -0,0 +1,67 @@
+# Low-bits Quantization
+
+
+
+## TorchAo
+
+Currently, torchao has been integrated into cache-dit as the backend for **online** model quantization (with more backends to be supported in the future). You can implement model quantization by calling `cache_dit.quantize(...)`. At present, cache-dit supports the `Hybrid Cache + Low-bits Quantization` scheme. For GPUs with low memory capacity, we recommend using `float8_weight_only` or `int8_weight_only`, as these two schemes cause almost no loss in precision.
+
+```python
+# pip3 install "cache-dit[quantization]"
+import cache_dit
+
+cache_dit.enable_cache(pipe_or_adapter)
+
+# float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
+# int4_weight_only requires fbgemm-gpu-genai>=1.2.0, which only supports
+# Compute Architectures >= Hopper (and does not support Ada, ..., etc.)
+pipe.transformer = cache_dit.quantize(
+ pipe.transformer, quant_type="float8_weight_only"
+)
+pipe.text_encoder = cache_dit.quantize(
+ pipe.text_encoder, quant_type="float8_weight_only"
+)
+```
+
+## bitsandbytes
+
+For **4-bits W4A16 (weight only)** quantization, we recommend `nf4` from **bitsandbytes** due to its better compatibility for many devices. Users can directly use it via the `quantization_config` of diffusers. For example:
+
+```python
+from diffusers import QwenImagePipeline
+from diffusers.quantizers import PipelineQuantizationConfig
+
+pipe = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image",
+ torch_dtype=torch.bfloat16,
+ quantization_config=(
+ PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16,
+ },
+ components_to_quantize=["text_encoder", "transformer"],
+ )
+ ),
+).to("cuda")
+
+# Then, apply cache acceleration using cache-dit
+cache_dit.enable_cache(pipe, cache_config=...)
+```
+
+## Nunchaku
+
+cache-dit natively supports the `Hybrid Cache + 🔥Nunchaku SVDQ INT4/FP4 + Context Parallelism` scheme. Users can leverage caching and context parallelism to speed up Nunchaku **4-bit** models.
+
+```python
+transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
+    "path-to/svdq-int4_r32-qwen-image.safetensors"
+)
+pipe = QwenImagePipeline.from_pretrained(
+ "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16,
+).to("cuda")
+
+cache_dit.enable_cache(pipe, cache_config=..., parallelism_config=...)
+```
diff --git a/docs/user_guide/SERVING.md b/docs/user_guide/SERVING.md
new file mode 100644
index 000000000..ab4940c05
--- /dev/null
+++ b/docs/user_guide/SERVING.md
@@ -0,0 +1,51 @@
+# Cache-DiT Serving
+
+HTTP serving for diffusion models with cache-dit acceleration. Supports **text-to-image**, **image editing**, **multi-image editing**, **text-to-video**, and **image-to-video** generation.
+
+Adapted from [SGLang](https://github.com/sgl-project/sglang).
+
+## Supported Tasks
+
+- **Text-to-Image (t2i)**
+- **Image Editing (edit)**
+- **Text-to-Video (t2v)**
+- **Image-to-Video (i2v)**
+
+Serving setups for LoRA, multi-image editing, distributed parallelism, etc. are available as runnable recipes.
+
+## Start Server
+
+```bash
+pip install -e ".[serving]"
+
+torchrun --nproc_per_node=1 -m cache_dit.serve.serve \
+ --model-path black-forest-labs/FLUX.1-dev \
+ --cache
+
+curl http://localhost:8000/health
+open http://localhost:8000/docs
+```
+
+## Example 1: Text-to-Image
+
+```bash
+curl -X POST \
+ -H "Content-Type: application/json" \
+ -d '{"prompt":"A beautiful sunset over the ocean","width":1024,"height":1024,"num_inference_steps":50}' \
+ http://localhost:8000/generate
+```
+
+## Example 2: Text-to-Video
+
+```bash
+curl -X POST \
+ -H "Content-Type: application/json" \
+ -d '{"prompt":"A cat walks on the grass, realistic","width":832,"height":480,"num_frames":49,"fps":16,"num_inference_steps":30}' \
+ http://localhost:8000/generate
+```
+
+## More Recipes
+
+For t2i / edit / t2v / i2v, LoRA, and multi-GPU launch examples, see:
+
+https://github.com/vipshop/cache-dit/tree/main/tests/serving
diff --git a/docs/user_guide/TENSOR_PARALLEL.md b/docs/user_guide/TENSOR_PARALLEL.md
new file mode 100644
index 000000000..030595e16
--- /dev/null
+++ b/docs/user_guide/TENSOR_PARALLEL.md
@@ -0,0 +1,28 @@
+# Tensor Parallelism
+
+## Hybrid Tensor Parallelism
+
+
+
+cache-dit is also compatible with tensor parallelism. Currently, we support the use of `Hybrid Cache` + `Tensor Parallelism` scheme (via NATIVE_PYTORCH parallelism backend) in cache-dit. Users can use Tensor Parallelism to further accelerate the speed of inference and **reduce the VRAM usage per GPU**! For more details, please refer to [📚examples/parallelism](https://github.com/vipshop/cache-dit/tree/main/examples). Now, cache-dit supports tensor parallelism for [FLUX.1](https://huggingface.co/black-forest-labs/FLUX.1-dev), 🔥[FLUX.2](https://huggingface.co/black-forest-labs/FLUX.2-dev), [Qwen-Image](https://github.com/QwenLM/Qwen-Image), [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning), [Wan2.1](https://github.com/Wan-Video/Wan2.1), [Wan2.2](https://github.com/Wan-Video/Wan2.2), [HunyuanImage-2.1](https://huggingface.co/tencent/HunyuanImage-2.1), [HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) and [VisualCloze](https://github.com/lzyhha/VisualCloze), etc. cache-dit will support more models in the future.
+
+```python
+# pip3 install "cache-dit[parallelism]"
+from cache_dit import ParallelismConfig
+
+cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=DBCacheConfig(...),
+ # Set tp_size > 1 to enable tensor parallelism.
+ parallelism_config=ParallelismConfig(tp_size=2),
+)
+# torchrun --nproc_per_node=2 parallel_cache.py
+```
+
+|L20x1| TP-2 | TP-4 | + compile |
+|:---:|:---:|:---:|:---:|
+|FLUX, 23.56s| 14.61s | 10.69s | 9.84s |
+| | | | |
+
+
+Please note that in the short term, we have no plans to support Hybrid Parallelism. Please choose to use either Context Parallelism or Tensor Parallelism based on your actual scenario.
diff --git a/examples/README.md b/examples/README.md
index 373aa3069..db2056526 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -1,90 +1,396 @@
-# Examples for CacheDiT
+
+
🚀 Examples for Cache-DiT
-## ⚙️Install Requirements
+|Z-Image-ControlNet| Context Parallel: Ulysses 2 | Context Parallel: Ulysses 4 | + ControlNet Parallel |
+|:---:|:---:|:---:|:---:|
+|Base L20x1: 22s|15.7s|12.7s|**🚀7.71s**|
+|
|
|
|
|
+| **+ Hybrid Cache** | **+ Torch Compile** | **+ Async Ulysses CP** | **+ FP8 All2All + CUDNN ATTN** |
+|**🚀6.85s**|6.45s|6.38s|**🚀6.19s, 5.47s**|
+|
|
|
|
-```bash
-pip3 install -r requirements.txt
-```
+
+
+## 📚 Table of Contents
-## 🚀Run Examples
+- [📚 Installation](#-installation)
+- [📚 Available Examples](#-available-examples)
+- [📚 Single GPU Inference](#-single-gpu-inference)
+- [📚 Custom Model Path](#-custom-model-path)
+- [📚 Multi-GPU Inference](#-multi-gpu-inference)
+- [📚 Low-bits Quantization](#-low-bits-quantization)
+- [📚 Hybrid Acceleration](#-hybrid-acceleration)
+- [📚 End2End Examples](#-end2end-examples)
+- [📚 How to Add New Example](#-how-to-add-new-example)
+- [📚 More Usages about Examples](#-more-usages-about-examples)
-- Qwen-Image-Edit
+## 📚 Installation
```bash
-python3 run_qwen_image_edit.py # baseline
-python3 run_qwen_image_edit.py --cache
+pip3 install torch==2.9.1 transformers accelerate torchao==0.14.1 bitsandbytes torchvision
+pip3 install opencv-python-headless einops imageio-ffmpeg ftfy
+pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0
+pip3 install git+https://github.com/vipshop/cache-dit.git # latest
+
+git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples
```
-- Qwen-Image
+## 📚 Available Examples
```bash
-python3 run_qwen_image.py # baseline
-python3 run_qwen_image.py --cache
-python3 run_qwen_image.py --cache --compile
-python3 run_qwen_image.py --cache --compile --quantize
+python3 generate.py list # list all available examples
+
+[generate.py:47] Available examples:
+[generate.py:53] - ✅ flux_nunchaku - Default: nunchaku-tech/nunchaku-flux.1-dev
+[generate.py:53] - ✅ flux - Default: black-forest-labs/FLUX.1-dev
+[generate.py:53] - ✅ flux_fill - Default: black-forest-labs/FLUX.1-Fill-dev
+[generate.py:53] - ✅ flux2 - Default: black-forest-labs/FLUX.2-dev
+[generate.py:53] - ✅ flux2_klein_base_9b - Default: black-forest-labs/FLUX.2-klein-base-9B
+[generate.py:53] - ✅ flux2_klein_base_4b - Default: black-forest-labs/FLUX.2-klein-base-4B
+[generate.py:53] - ✅ flux2_klein_9b - Default: black-forest-labs/FLUX.2-klein-9B
+[generate.py:53] - ✅ flux2_klein_4b - Default: black-forest-labs/FLUX.2-klein-4B
+[generate.py:53] - ✅ qwen_image_lightning - Default: lightx2v/Qwen-Image-Lightning
+[generate.py:53] - ✅ qwen_image_2512 - Default: Qwen/Qwen-Image-2512
+[generate.py:53] - ✅ qwen_image - Default: Qwen/Qwen-Image
+[generate.py:53] - ✅ qwen_image_edit_2511_lightning - Default: lightx2v/Qwen-Image-Edit-2511-Lightning
+[generate.py:53] - ✅ qwen_image_edit_2511 - Default: Qwen/Qwen-Image-Edit-2511
+[generate.py:53] - ✅ qwen_image_edit_lightning - Default: lightx2v/Qwen-Image-Lightning
+[generate.py:53] - ✅ qwen_image_edit - Default: Qwen/Qwen-Image-Edit-2509
+[generate.py:53] - ✅ qwen_image_controlnet - Default: InstantX/Qwen-Image-ControlNet-Inpainting
+[generate.py:53] - ✅ qwen_image_layered - Default: Qwen/Qwen-Image-Layered
+[generate.py:53] - ✅ ltx2_t2v - Default: Lightricks/LTX-2
+[generate.py:53] - ✅ ltx2_i2v - Default: Lightricks/LTX-2
+[generate.py:53] - ✅ skyreels_v2 - Default: Skywork/SkyReels-V2-T2V-14B-720P-Diffusers
+[generate.py:53] - ✅ wan2.2_t2v - Default: Wan-AI/Wan2.2-T2V-A14B-Diffusers
+[generate.py:53] - ✅ wan2.1_t2v - Default: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+[generate.py:53] - ✅ wan2.2_i2v - Default: Wan-AI/Wan2.2-I2V-A14B-Diffusers
+[generate.py:53] - ✅ wan2.1_i2v - Default: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers
+[generate.py:53] - ✅ wan2.2_vace - Default: linoyts/Wan2.2-VACE-Fun-14B-diffusers
+[generate.py:53] - ✅ wan2.1_vace - Default: Wan-AI/Wan2.1-VACE-1.3B-diffusers
+[generate.py:53] - ✅ ovis_image - Default: AIDC-AI/Ovis-Image-7B
+[generate.py:53] - ✅ zimage_nunchaku - Default: nunchaku/nunchaku-z-image-turbo
+[generate.py:53] - ✅ zimage - Default: Tongyi-MAI/Z-Image-Turbo
+[generate.py:53] - ✅ zimage_controlnet_2.0 - Default: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0
+[generate.py:53] - ✅ zimage_controlnet_2.1 - Default: alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1
+[generate.py:53] - ✅ longcat_image - Default: meituan-longcat/LongCat-Image
+[generate.py:53] - ✅ longcat_image_edit - Default: meituan-longcat/LongCat-Image-Edit
```
-- FLUX.1-dev
+## 📚 Single GPU Inference
+
+The easiest way to enable hybrid cache acceleration for DiTs with cache-dit is to start with single GPU inference. For example:
```bash
-python3 run_flux.py # baseline
-python3 run_flux.py --cache
+# baseline
+# use default model path, e.g., "black-forest-labs/FLUX.1-dev"
+python3 generate.py flux
+python3 generate.py flux_nunchaku # need nunchaku library
+python3 generate.py flux2
+python3 generate.py ovis_image
+python3 generate.py qwen_image_edit_lightning
+python3 generate.py qwen_image
+python3 generate.py ltx2_t2v --cache --cpu-offload
+python3 generate.py ltx2_i2v --cache --cpu-offload
+python3 generate.py skyreels_v2
+python3 generate.py wan2.2_t2v
+python3 generate.py zimage
+python3 generate.py zimage_nunchaku
+python3 generate.py zimage_controlnet_2.1
+python3 generate.py longcat_image
+python3 generate.py longcat_image_edit
+# w/ cache acceleration
+python3 generate.py flux --cache
+python3 generate.py flux --cache --taylorseer
+python3 generate.py flux_nunchaku --cache
+python3 generate.py qwen_image --cache
+python3 generate.py zimage --cache --rdt 0.6 --scm fast
+python3 generate.py zimage_controlnet_2.1 --cache --rdt 0.6 --scm fast
+# enable cpu offload or vae tiling if you encounter an OOM error
+python3 generate.py qwen_image --cache --cpu-offload
+python3 generate.py qwen_image --cache --cpu-offload --vae-tiling
+python3 generate.py qwen_image_edit_lightning --cpu-offload --steps 4
+python3 generate.py qwen_image_edit_lightning --cpu-offload --steps 8
+# or, enable sequential cpu offload for extremely low VRAM device
+python3 generate.py flux2 --sequential-cpu-offload # FLUX2 56B total
+# use `--summary` option to show the cache acceleration stats
+python3 generate.py zimage --cache --rdt 0.6 --scm fast --summary
```
-- FLUX.1-Fill-dev
+## 📚 Custom Model Path
+
+The default model paths are the official model names on HuggingFace Hub. Users can set a custom local model path by setting `--model-path`. For example:
```bash
-python3 run_flux_fill.py # baseline
-python3 run_flux_fill.py --cache
+python3 generate.py flux --model-path /PATH/TO/FLUX.1-dev
+python3 generate.py zimage --model-path /PATH/TO/Z-Image-Turbo
+python3 generate.py qwen_image --model-path /PATH/TO/Qwen-Image
```
-- FLUX.1-Kontext-dev
+## 📚 Multi-GPU Inference
+
+cache-dit is designed to work seamlessly with CPU or Sequential Offloading, 🔥Context Parallelism, 🔥Tensor Parallelism. For example:
```bash
-python3 run_flux_kontext.py # baseline
-python3 run_flux_kontext.py --cache
+# context parallelism or tensor parallelism
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses
+torchrun --nproc_per_node=4 generate.py flux --parallel ring
+torchrun --nproc_per_node=4 generate.py flux --parallel tp
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses
+# ulysses anything attention
+torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses --ulysses-anything
+torchrun --nproc_per_node=4 generate.py qwen_image_edit_lightning --parallel ulysses --ulysses-anything
+# text encoder parallelism, enable it by add: `--parallel-text-encoder`
+torchrun --nproc_per_node=4 generate.py flux --parallel tp --parallel-text-encoder
+torchrun --nproc_per_node=4 generate.py qwen_image_edit_lightning --parallel ulysses --ulysses-anything --parallel-text-encoder
+# Hint: set `--local-ranks-filter=0` to torchrun -> only show logs on rank 0
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_t2v --parallel ulysses --parallel-vae --parallel-text-encoder --cache --ulysses-anything
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_t2v --parallel tp --parallel-vae --parallel-text-encoder --cache
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_i2v --parallel ulysses --parallel-vae --parallel-text-encoder --cache --ulysses-anything
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py ltx2_i2v --parallel tp --parallel-vae --parallel-text-encoder --cache
```
-- CogVideoX
+## 📚 Low-bits Quantization
+
+cache-dit is designed to work seamlessly with torch.compile and Quantization (🔥torchao, 🔥nunchaku). For example:
```bash
-python3 run_cogvideox.py # baseline
-python3 run_cogvideox.py --cache
+# please also enable torch.compile if quantization is used.
+python3 generate.py flux --cache --quantize-type float8 --compile
+python3 generate.py flux --cache --quantize-type int8 --compile
+python3 generate.py flux --cache --quantize-type float8_weight_only --compile
+python3 generate.py flux --cache --quantize-type int8_weight_only --compile
+python3 generate.py flux --cache --quantize-type bnb_4bit --compile # w4a16
+python3 generate.py flux_nunchaku --cache --compile # w4a16 SVDQ
```
-- Wan2.2 T2V
+## 📚 Hybrid Acceleration
+
+Here are some examples for `hybrid cache acceleration + parallelism` for popular DiTs with cache-dit.
```bash
-python3 run_wan_2.2.py # baseline
-python3 run_wan_2.2.py --cache
-python3 run_wan_2.2.py --cache --compile
-python3 run_wan_2.2.py --cache --compile --quantize
+# DBCache + SCM + Taylorseer
+python3 generate.py flux --cache --scm fast --taylorseer --taylorseer-order 1
+# DBCache + SCM + Taylorseer + Context Parallelism + Text Encoder Parallelism + Compile
+# + FP8 quantization + FP8 All2All comm + CUDNN Attention (--attn _sdpa_cudnn)
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --ulysses-float8 \
+ --attn _sdpa_cudnn --parallel-text-encoder --cache --scm fast --taylorseer \
+ --taylorseer-order 1 --quantize-type float8 --warmup 2 --repeat 5 --compile
+# DBCache + SCM + Taylorseer + Context Parallelism + Text Encoder Parallelism + Compile
+# + FP8 quantization + FP8 All2All comm + FP8 SageAttention (--attn sage)
+torchrun --nproc_per_node=4 generate.py flux --parallel ulysses --ulysses-float8 \
+ --attn sage --parallel-text-encoder --cache --scm fast --taylorseer \
+ --taylorseer-order 1 --quantize-type float8 --warmup 2 --repeat 5 --compile
+# Case: Hybrid Acceleration for Qwen-Image-Edit-Lightning, tracking memory usage.
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py qwen_image_edit_lightning \
+ --parallel ulysses --ulysses-anything --parallel-text-encoder \
+ --quantize-type float8_weight_only --steps 4 --track-memory --compile
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py qwen_image_edit_lightning \
+ --parallel tp --parallel-text-encoder --quantize-type float8_weight_only \
+ --steps 4 --track-memory --compile
+# Case: Hybrid Acceleration + Context Parallelism + ControlNet Parallelism, e.g., Z-Image-ControlNet
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses \
+ --parallel-controlnet --cache --rdt 0.6 --scm fast
+torchrun --nproc_per_node=4 generate.py zimage_controlnet_2.1 --parallel ulysses \
+ --parallel-controlnet --cache --scm fast --rdt 0.6 --compile \
+ --compile-controlnet --ulysses-float8 --attn _sdpa_cudnn \
+ --warmup 2 --repeat 4
```
-- Wan2.1 T2V
+## 📚 End2End Examples
```bash
-python3 run_wan.py # baseline
-python3 run_wan.py --cache
-```
+# NO Cache Acceleration: 8.27s
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses
-- Wan2.1 FLF2V
+INFO 12-17 09:02:31 [base.py:151] Example Input Summary:
+INFO 12-17 09:02:31 [base.py:151] - prompt: A cat holding a sign that says hello world
+INFO 12-17 09:02:31 [base.py:151] - height: 1024
+INFO 12-17 09:02:31 [base.py:151] - width: 1024
+INFO 12-17 09:02:31 [base.py:151] - num_inference_steps: 28
+INFO 12-17 09:02:31 [base.py:214] Example Output Summary:
+INFO 12-17 09:02:31 [base.py:225] - Model: flux
+INFO 12-17 09:02:31 [base.py:225] - Optimization: C0_Q0_NONE_Ulysses4
+INFO 12-17 09:02:31 [base.py:225] - Load Time: 0.79s
+INFO 12-17 09:02:31 [base.py:225] - Warmup Time: 21.09s
+INFO 12-17 09:02:31 [base.py:225] - Inference Time: 8.27s
+INFO 12-17 09:02:32 [base.py:182] Image saved to flux.1024x1024.C0_Q0_NONE_Ulysses4.png
-```bash
-python3 run_wan_flf2v.py # baseline
-python3 run_wan_flf2v.py --cache
+# Enabled Cache Acceleration: 4.23s
+torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --parallel ulysses --cache --scm fast
+
+INFO 12-17 09:10:09 [base.py:151] Example Input Summary:
+INFO 12-17 09:10:09 [base.py:151] - prompt: A cat holding a sign that says hello world
+INFO 12-17 09:10:09 [base.py:151] - height: 1024
+INFO 12-17 09:10:09 [base.py:151] - width: 1024
+INFO 12-17 09:10:09 [base.py:151] - num_inference_steps: 28
+INFO 12-17 09:10:09 [base.py:214] Example Output Summary:
+INFO 12-17 09:10:09 [base.py:225] - Model: flux
+INFO 12-17 09:10:09 [base.py:225] - Optimization: C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15
+INFO 12-17 09:10:09 [base.py:225] - Load Time: 0.78s
+INFO 12-17 09:10:09 [base.py:225] - Warmup Time: 18.49s
+INFO 12-17 09:10:09 [base.py:225] - Inference Time: 4.23s
+INFO 12-17 09:10:09 [base.py:182] Image saved to flux.1024x1024.C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15.png
```
-- mochi-1-preview
+|NO Cache Acceleration: 8.27s| w/ Cache Acceleration: 4.23s|
+|:---:|:---:|
+|||
-```bash
-python3 run_mochi.py # baseline
-python3 run_mochi.py --cache
+## 📚 How to Add New Example
+
+It is very easy to add a new example. Please refer to the specific implementation in [registers.py](./registers.py). For example:
+
+```python
+@ExampleRegister.register("flux")
+def flux_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import FluxPipeline
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("black-forest-labs/FLUX.1-dev"),
+ pipeline_class=FluxPipeline,
+ # `text_encoder_2` will be quantized when `--quantize-type`
+ # is set to `bnb_4bit`.
+ bnb_4bit_components=["text_encoder_2"],
+ ),
+ input_data=ExampleInputData(
+ prompt="A cat holding a sign that says hello world",
+ height=1024,
+ width=1024,
+ num_inference_steps=28,
+ ),
+ )
+
+# NOTE: DON'T forget to add `flux_example` into helpers.py
```
-- HunyuanVideo
+## 📚 More Usages about Examples
```bash
-python3 run_hunyuan_video.py # baseline
-python3 run_hunyuan_video.py --cache
+python3 generate.py --help
+
+usage: generate.py [-h] [--model-path MODEL_PATH] [--controlnet-path CONTROLNET_PATH] [--lora-path LORA_PATH] [--transformer-path TRANSFORMER_PATH] [--image-path IMAGE_PATH] [--mask-image-path MASK_IMAGE_PATH] [--prompt PROMPT]
+ [--negative-prompt NEGATIVE_PROMPT] [--num_inference_steps NUM_INFERENCE_STEPS] [--warmup WARMUP] [--repeat REPEAT] [--height HEIGHT] [--width WIDTH] [--seed SEED] [--num-frames NUM_FRAMES] [--save-path SAVE_PATH] [--cache]
+ [--cache-summary] [--Fn-compute-blocks FN_COMPUTE_BLOCKS] [--Bn-compute-blocks BN_COMPUTE_BLOCKS] [--residual-diff-threshold RESIDUAL_DIFF_THRESHOLD] [--max-warmup-steps MAX_WARMUP_STEPS] [--warmup-interval WARMUP_INTERVAL]
+ [--max-cached-steps MAX_CACHED_STEPS] [--max-continuous-cached-steps MAX_CONTINUOUS_CACHED_STEPS] [--taylorseer] [--taylorseer-order TAYLORSEER_ORDER] [--steps-mask] [--mask-policy {None,slow,s,medium,m,fast,f,ultra,u}]
+ [--quantize] [--quantize-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--quantize-text-encoder]
+ [--quantize-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--quantize-controlnet]
+ [--quantize-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}] [--parallel-type {None,tp,ulysses,ring}] [--parallel-vae]
+ [--parallel-text-encoder] [--parallel-controlnet] [--attn {None,flash,_flash_3,native,_native_cudnn,_sdpa_cudnn,sage}] [--ulysses-anything] [--ulysses-float8] [--ulysses-async] [--cpu-offload]
+ [--sequential-cpu-offload] [--device-map-balance] [--vae-tiling] [--vae-slicing] [--compile] [--compile-repeated-blocks] [--compile-vae] [--compile-text-encoder] [--compile-controlnet] [--max-autotune] [--track-memory]
+ [--profile] [--profile-name PROFILE_NAME] [--profile-dir PROFILE_DIR] [--profile-activities {CPU,GPU,MEM} [{CPU,GPU,MEM} ...]] [--profile-with-stack] [--profile-record-shapes] [--disable-fuse-lora DISABLE_FUSE_LORA]
+ [{generate,list,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}]
+ [{None,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}]
+
+positional arguments:
+ {generate,list,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}
+ The task to perform or example name to run. Use 'list' to list all available examples, or specify an example name directly (defaults to 'generate' task).
+ {None,flux_nunchaku,flux,flux2,qwen_image_lightning,qwen_image,qwen_image_edit_lightning,qwen_image_edit,qwen_image_controlnet,skyreels_v2,wan2.2_t2v,wan2.1_t2v,wan2.2_i2v,wan2.1_i2v,wan2.2_vace,wan2.1_vace,ovis_image,zimage,zimage_controlnet,longcat_image,longcat_image_edit}
+ Names of the examples to run. If not specified, skip running example.
+
+options:
+ -h, --help show this help message and exit
+ --model-path MODEL_PATH
+ Override model path if provided
+ --controlnet-path CONTROLNET_PATH
+ Override controlnet model path if provided
+ --lora-path LORA_PATH
+ Override lora model path if provided
+ --transformer-path TRANSFORMER_PATH
+ Override transformer model path if provided
+ --image-path IMAGE_PATH
+ Override image path if provided
+ --mask-image-path MASK_IMAGE_PATH
+ Override mask image path if provided
+ --prompt PROMPT Override default prompt if provided
+ --negative-prompt NEGATIVE_PROMPT
+ Override default negative prompt if provided
+ --num_inference_steps NUM_INFERENCE_STEPS, --steps NUM_INFERENCE_STEPS
+ Number of inference steps
+ --warmup WARMUP Number of warmup steps before measuring performance
+ --repeat REPEAT Number of times to repeat the inference for performance measurement
+ --height HEIGHT Height of the generated image
+ --width WIDTH Width of the generated image
+ --seed SEED Random seed for reproducibility
+ --num-frames NUM_FRAMES, --frames NUM_FRAMES
+ Number of frames to generate for video
+ --save-path SAVE_PATH
+ Path to save the generated output, e.g., output.png or output.mp4
+ --cache Enable Cache Acceleration
+ --cache-summary, --summary
+ Enable Cache Summary logging
+ --Fn-compute-blocks FN_COMPUTE_BLOCKS, --Fn FN_COMPUTE_BLOCKS
+ CacheDiT Fn_compute_blocks parameter
+ --Bn-compute-blocks BN_COMPUTE_BLOCKS, --Bn BN_COMPUTE_BLOCKS
+ CacheDiT Bn_compute_blocks parameter
+ --residual-diff-threshold RESIDUAL_DIFF_THRESHOLD, --rdt RESIDUAL_DIFF_THRESHOLD
+ CacheDiT residual diff threshold
+ --max-warmup-steps MAX_WARMUP_STEPS, --ws MAX_WARMUP_STEPS
+ Maximum warmup steps for CacheDiT
+ --warmup-interval WARMUP_INTERVAL, --wi WARMUP_INTERVAL
+ Warmup interval for CacheDiT
+ --max-cached-steps MAX_CACHED_STEPS, --mc MAX_CACHED_STEPS
+ Maximum cached steps for CacheDiT
+ --max-continuous-cached-steps MAX_CONTINUOUS_CACHED_STEPS, --mcc MAX_CONTINUOUS_CACHED_STEPS
+ Maximum continuous cached steps for CacheDiT
+ --taylorseer Enable TaylorSeer for CacheDiT
+ --taylorseer-order TAYLORSEER_ORDER, -order TAYLORSEER_ORDER
+ TaylorSeer order
+ --steps-mask Enable steps mask for CacheDiT
+ --mask-policy {None,slow,s,medium,m,fast,f,ultra,u}, --scm {None,slow,s,medium,m,fast,f,ultra,u}
+ Pre-defined steps computation mask policy
+ --quantize, --q Enable quantization for transformer
+ --quantize-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --quantize-text-encoder, --q-text
+ Enable quantization for text encoder
+ --quantize-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-text-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --quantize-controlnet, --q-controlnet
+ Enable quantization for controlnet
+ --quantize-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}, --q-controlnet-type {None,float8,float8_weight_only,float8_wo,int8,int8_weight_only,int8_wo,int4,int4_weight_only,int4_wo,bitsandbytes_4bit,bnb_4bit}
+ --parallel-type {None,tp,ulysses,ring}, --parallel {None,tp,ulysses,ring}
+ --parallel-vae Enable VAE parallelism if applicable.
+ --parallel-text-encoder, --parallel-text
+ Enable text encoder parallelism if applicable.
+ --parallel-controlnet
+ Enable ControlNet parallelism if applicable.
+ --attn {None,flash,_flash_3,native,_native_cudnn,_sdpa_cudnn,sage}
+ --ulysses-anything, --uaa
+ Enable Ulysses Anything Attention for context parallelism
+ --ulysses-float8, --ufp8
+ Enable Ulysses Attention/UAA Float8 for context parallelism
+ --ulysses-async, --uaqkv
+ Enabled experimental Async QKV Projection with Ulysses for context parallelism
+ --cpu-offload, --cpu-offload-model
+ Enable CPU offload for model if applicable.
+ --sequential-cpu-offload
+ Enable sequential GPU offload for model if applicable.
+ --device-map-balance, --device-map
+ Enable automatic device map balancing model if multiple GPUs are available.
+ --vae-tiling Enable VAE tiling for low memory device.
+ --vae-slicing Enable VAE slicing for low memory device.
+ --compile Enable compile for transformer
+ --compile-repeated-blocks
+ Enable compile for repeated blocks in transformer
+ --compile-vae Enable compile for VAE
+ --compile-text-encoder, --compile-text
+ Enable compile for text encoder
+ --compile-controlnet Enable compile for ControlNet
+ --max-autotune Enable max-autotune mode for torch.compile
+ --track-memory Track and report peak GPU memory usage
+ --profile Enable profiling with torch.profiler
+ --profile-name PROFILE_NAME
+ Name for the profiling session
+ --profile-dir PROFILE_DIR
+ Directory to save profiling results
+ --profile-activities {CPU,GPU,MEM} [{CPU,GPU,MEM} ...]
+ Activities to profile (CPU, GPU, MEM)
+ --profile-with-stack profile with stack for better traceability
+ --profile-record-shapes
+ profile record shapes for better analysis
+ --disable-fuse-lora DISABLE_FUSE_LORA
+ Disable fuse_lora even if lora weights are provided.
```
diff --git a/examples/adapter/run_flux_adapter.py b/examples/adapter/run_flux_adapter.py
deleted file mode 100644
index 6c252d0ea..000000000
--- a/examples/adapter/run_flux_adapter.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxPipeline
-from utils import get_args, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-if args.cache:
-
- from cache_dit import (
- ForwardPattern,
- BlockAdapter,
- ParamsModifier,
- DBCacheConfig,
- )
- from cache_dit.utils import is_diffusers_at_least_0_3_5
- from diffusers import FluxTransformer2DModel
-
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
-
- if is_diffusers_at_least_0_3_5():
- # For diffusers >= 0.35.0
- cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_1,
- ],
- params_modifiers=[
- ParamsModifier(
- cache_config=DBCacheConfig(
- residual_diff_threshold=0.12,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig(
- Fn_compute_blocks=1,
- residual_diff_threshold=0.25,
- ),
- ),
- ],
- ),
- )
-
- else:
-
- # For diffusers <= 0.34.0
- cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- params_modifiers=[
- ParamsModifier(
- cache_config=DBCacheConfig(
- residual_diff_threshold=0.12,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig(
- Fn_compute_blocks=1,
- residual_diff_threshold=0.25,
- ),
- ),
- ],
- ),
- )
-
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=28,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.adapter.{cache_dit.strify(pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/adapter/run_qwen_image_adapter.py b/examples/adapter/run_qwen_image_adapter.py
deleted file mode 100644
index 76a2fcc44..000000000
--- a/examples/adapter/run_qwen_image_adapter.py
+++ /dev/null
@@ -1,124 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
-from utils import GiB, get_args, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-
-if args.cache:
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
- from cache_dit import (
- BlockAdapter,
- ForwardPattern,
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- BlockAdapter(
- # Any DiffusionPipeline, Qwen-Image, etc.
- pipe=pipe,
- auto=True,
- # Check `📚Forward Pattern Matching` documentation and hack the code of
- # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
- forward_pattern=ForwardPattern.Pattern_1,
- ),
- # Cache context kwargs
- cache_config=DBCacheConfig(
- residual_diff_threshold=0.12,
- enable_separate_cfg=True,
- ),
- calibrator_config=TaylorSeerCalibratorConfig(
- taylorseer_order=4,
- ),
- )
-
-
-if torch.cuda.device_count() <= 1:
- # Enable memory savings
- pipe.enable_model_cpu_offload()
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-width, height = aspect_ratios["16:9"]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-
-# do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image.adapter.{cache_dit.strify(stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/api/run_cache_refresh_flux.py b/examples/api/run_cache_refresh_flux.py
new file mode 100644
index 000000000..9d610311f
--- /dev/null
+++ b/examples/api/run_cache_refresh_flux.py
@@ -0,0 +1,171 @@
+import os
+import sys
+
+sys.path.append("..")
+
+import time
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+from utils import get_args, strify, MemoryTracker
+from cache_dit import (
+ BlockAdapter,
+ ForwardPattern,
+ ParamsModifier,
+ DBCacheConfig,
+ TaylorSeerCalibratorConfig,
+)
+import cache_dit
+from cache_dit.platforms import current_platform
+
+parser = get_args(parse=False)
+parser.add_argument(
+ "--no-adapt",
+ action="store_true",
+ default=False,
+ help="Disable BlockAdapter or not",
+)
+parser.add_argument(
+ "--summary",
+ action="store_true",
+ default=False,
+ help="Print summary of the model after each inference",
+)
+parser.add_argument(
+ "--refresh-use-cache-config",
+ action="store_true",
+ default=False,
+ help="Use the cache config during cache refreshing",
+)
+args = parser.parse_args()
+print(args)
+
+
+pipe = FluxPipeline.from_pretrained(
+ (
+ args.model_path
+ if args.model_path is not None
+ else os.environ.get(
+ "FLUX_DIR",
+ "black-forest-labs/FLUX.1-dev",
+ )
+ ),
+ torch_dtype=torch.bfloat16,
+).to(current_platform.device_type)
+
+if args.cache:
+
+ assert isinstance(pipe.transformer, FluxTransformer2DModel)
+
+ cache_dit.enable_cache(
+ (
+ BlockAdapter(
+ transformer=pipe.transformer,
+ blocks=[
+ pipe.transformer.transformer_blocks,
+ pipe.transformer.single_transformer_blocks,
+ ],
+ forward_pattern=[
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_1,
+ ],
+ )
+ if not args.no_adapt
+ else pipe.transformer
+ ),
+ cache_config=(
+ DBCacheConfig(
+ Fn_compute_blocks=args.Fn,
+ Bn_compute_blocks=args.Bn,
+ max_warmup_steps=args.max_warmup_steps,
+ max_cached_steps=args.max_cached_steps,
+ max_continuous_cached_steps=args.max_continuous_cached_steps,
+ residual_diff_threshold=args.rdt,
+ # NOTE: num_inference_steps can be None here, we will
+ # set it properly during cache refreshing.
+ num_inference_steps=None,
+ )
+ if args.cache
+ else None
+ ),
+ params_modifiers=[
+ ParamsModifier(
+ cache_config=DBCacheConfig().reset(
+ residual_diff_threshold=args.rdt,
+ ),
+ ),
+ ParamsModifier(
+ # NOTE: single_transformer_blocks should have higher
+ # residual_diff_threshold because of the precision error
+ # accumulation from previous transformer_blocks
+ cache_config=DBCacheConfig().reset(
+ residual_diff_threshold=args.rdt * 3,
+ ),
+ ),
+ ],
+ )
+
+# Set default prompt
+prompt = "A cat holding a sign that says hello world"
+if args.prompt is not None:
+ prompt = args.prompt
+
+
+def run_pipe(steps: int = 28):
+ if args.refresh_use_cache_config:
+ cache_dit.refresh_context(
+ pipe.transformer,
+ # The cache settings should all be located in the cache config
+ # if cache config is provided. Otherwise, we will skip it.
+ cache_config=DBCacheConfig().reset(
+ num_inference_steps=steps,
+ ),
+ calibrator_config=TaylorSeerCalibratorConfig().reset(
+ taylorseer_order=1,
+ ),
+ verbose=True,
+ )
+ else:
+ cache_dit.refresh_context(
+ pipe.transformer,
+ num_inference_steps=steps,
+ verbose=True,
+ )
+ image = pipe(
+ prompt,
+ height=1024 if args.height is None else args.height,
+ width=1024 if args.width is None else args.width,
+ num_inference_steps=steps,
+ generator=torch.Generator("cpu").manual_seed(0),
+ ).images[0]
+ return image
+
+
+if args.compile:
+ cache_dit.set_compile_configs()
+ pipe.transformer = torch.compile(pipe.transformer)
+
+
+memory_tracker = MemoryTracker() if args.track_memory else None
+if memory_tracker:
+ memory_tracker.__enter__()
+
+steps = [8, 16, 28, 40, 50]
+for i in range(len(steps)):
+ print("-" * 150)
+ start = time.time()
+ image = run_pipe(steps=steps[i])
+ end = time.time()
+ time_cost = end - start
+
+ save_path = f"flux.steps{steps[i]}.{strify(args, pipe.transformer)}.png"
+ image.save(save_path)
+
+ if args.summary:
+ cache_dit.summary(pipe.transformer)
+ print(f"Time cost: {time_cost:.2f}s")
+ print(f"Saving image to {save_path}")
+
+
+if memory_tracker:
+ memory_tracker.__exit__(None, None, None)
+ memory_tracker.report()
diff --git a/examples/pipeline/run_wan_2.2.py b/examples/api/run_cache_refresh_wan_2.2.py
similarity index 58%
rename from examples/pipeline/run_wan_2.2.py
rename to examples/api/run_cache_refresh_wan_2.2.py
index 1be426828..2b859de3d 100644
--- a/examples/pipeline/run_wan_2.2.py
+++ b/examples/api/run_cache_refresh_wan_2.2.py
@@ -11,11 +11,18 @@
from diffusers.schedulers.scheduling_unipc_multistep import (
UniPCMultistepScheduler,
)
-from utils import get_args, GiB, strify, cachify, MemoryTracker
+from utils import get_args, GiB, strify, MemoryTracker
import cache_dit
-
-
-args = get_args()
+from cache_dit.platforms import current_platform
+
+parser = get_args(parse=False)
+parser.add_argument(
+ "--summary",
+ action="store_true",
+ default=False,
+ help="Print summary of the model after each inference",
+)
+args = parser.parse_args()
print(args)
@@ -31,7 +38,7 @@
),
torch_dtype=torch.bfloat16,
# https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
+ device_map=("balanced" if (current_platform.device_count() > 1 and GiB() <= 48) else None),
)
# flow shift should be 3.0 for 480p images, 5.0 for 720p images
@@ -46,16 +53,18 @@
if args.cache:
from cache_dit import (
- ForwardPattern,
BlockAdapter,
+ ForwardPattern,
ParamsModifier,
DBCacheConfig,
)
- cachify(
- args,
+ assert isinstance(pipe.transformer, WanTransformer3DModel)
+ assert isinstance(pipe.transformer_2, WanTransformer3DModel)
+
+ # Dual transformer caching with transformer-only api in cache-dit.
+ cache_dit.enable_cache(
BlockAdapter(
- pipe=pipe,
transformer=[
pipe.transformer,
pipe.transformer_2,
@@ -85,16 +94,27 @@
],
has_separate_cfg=True,
),
+ cache_config=DBCacheConfig(
+ Fn_compute_blocks=args.Fn,
+ Bn_compute_blocks=args.Bn,
+ max_warmup_steps=args.max_warmup_steps,
+ max_cached_steps=args.max_cached_steps,
+ max_continuous_cached_steps=args.max_continuous_cached_steps,
+ residual_diff_threshold=args.rdt,
+ # NOTE: num_inference_steps can be None here, we will
+ # set it properly during cache refreshing.
+ num_inference_steps=None,
+ ),
)
# When device_map is None, we need to explicitly move the model to GPU
# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
+if current_platform.device_count() <= 1:
# Single GPU: use CPU offload for memory efficiency
pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
+elif current_platform.device_count() > 1 and pipe.device.type == "cpu":
# Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
+ pipe.to(current_platform.device_type)
# Wan currently requires installing diffusers from source
assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
@@ -138,46 +158,83 @@
if args.negative_prompt is not None:
negative_prompt = args.negative_prompt
-if args.compile or args.quantize:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
- pipe.transformer_2.compile_repeated_blocks(fullgraph=True)
- # warmup
+def split_inference_steps(num_inference_steps: int = 30) -> tuple[int, int]:
+ if pipe.config.boundary_ratio is not None:
+ boundary_timestep = pipe.config.boundary_ratio * pipe.scheduler.config.num_train_timesteps
+ else:
+ boundary_timestep = None
+ pipe.scheduler.set_timesteps(num_inference_steps, device=current_platform.device_type)
+ timesteps = pipe.scheduler.timesteps
+ num_high_noise_steps = 0 # high-noise steps for transformer
+ for t in timesteps:
+ if boundary_timestep is not None and t >= boundary_timestep:
+ num_high_noise_steps += 1
+ # low-noise steps for transformer_2
+ num_low_noise_steps = num_inference_steps - num_high_noise_steps
+ return num_high_noise_steps, num_low_noise_steps
+
+
+def run_pipe(steps: int = 30):
+
+ if args.cache:
+ # Refresh cache context with proper num_inference_steps
+ num_high_noise_steps, num_low_noise_steps = split_inference_steps(
+ num_inference_steps=steps,
+ )
+
+ cache_dit.refresh_context(
+ pipe.transformer,
+ num_inference_steps=num_high_noise_steps,
+ verbose=True,
+ )
+ cache_dit.refresh_context(
+ pipe.transformer_2,
+ num_inference_steps=num_low_noise_steps,
+ verbose=True,
+ )
video = pipe(
prompt=prompt,
height=height,
width=width,
num_frames=81,
- num_inference_steps=50,
+ num_inference_steps=steps,
generator=torch.Generator("cpu").manual_seed(0),
).frames[0]
+ return video
+
+
+if args.compile or args.quantize:
+ cache_dit.set_compile_configs()
+ pipe.transformer.compile_repeated_blocks(fullgraph=True)
+ pipe.transformer_2.compile_repeated_blocks(fullgraph=True)
+
+ # warmup
+ run_pipe(steps=8)
memory_tracker = MemoryTracker() if args.track_memory else None
if memory_tracker:
memory_tracker.__enter__()
-start = time.time()
-video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=81,
- num_inference_steps=50,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
+
+steps = [16, 28, 50]
+for i in range(len(steps)):
+ print("-" * 150)
+ start = time.time()
+ video = run_pipe(steps=steps[i])
+ end = time.time()
+ time_cost = end - start
+
+ save_path = f"wan2.2.steps{steps[i]}.{strify(args, pipe.transformer)}.mp4"
+ export_to_video(video, save_path, fps=16)
+
+ if args.summary:
+ cache_dit.summary(pipe, details=True)
+ print(f"Time cost: {time_cost:.2f}s")
+ print(f"Saving video to {save_path}")
+
if memory_tracker:
memory_tracker.__exit__(None, None, None)
memory_tracker.report()
-
-cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"wan2.2.{strify(args, pipe)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=16)
diff --git a/examples/api/run_cpu_offload.py b/examples/api/run_cpu_offload.py
deleted file mode 100644
index 4fb4e4f42..000000000
--- a/examples/api/run_cpu_offload.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-parser = get_args(parse=False)
-parser.add_argument(
- "--offload-type",
- type=str,
- choices=["model", "sequential", "group"],
- default="model",
-)
-parser.add_argument(
- "--cache-after-offload",
- action="store_true",
- default=False,
-)
-args = parser.parse_args()
-logger.info(args)
-
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-if args.cache and not args.cache_after_offload:
- logger.info("Enabled Cache before offload")
- cachify(args, pipe)
-
-if torch.cuda.device_count() <= 1:
- # Enable memory savings
- if args.offload_type == "model":
- logger.info("Enabled Model CPU Offload")
- pipe.enable_model_cpu_offload()
- elif args.offload_type == "sequential":
- logger.info("Enabled Sequential CPU Offload")
- pipe.enable_sequential_cpu_offload()
- elif args.offload_type == "group":
- logger.info("Enabled Group Offload")
- pipe.enable_group_offload(
- onload_device=torch.device("cuda"),
- offload_device=torch.device("cpu"),
- offload_type="block_level",
- num_blocks_per_group=1,
- use_stream=True,
- record_stream=True,
- exclude_modules=[
- "vae",
- ],
- )
-
-if args.cache and args.cache_after_offload:
- logger.info("Enabled Cache after offload")
- # WARN: cache after group offload still not work.
- cachify(args, pipe)
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-width, height = aspect_ratios["16:9"]
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-if args.quantize:
- # Apply Quantization (default: FP8 DQ) to Transformer
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- per_row=False,
- exclude_layers=[
- "img_in",
- "txt_in",
- "embedder",
- "embed",
- "norm_out",
- "proj_out",
- ],
- )
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-# do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image.{strify(args, stats)}.png"
-logger.info(f"Time cost: {time_cost:.2f}s")
-logger.info(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/api/run_disable_cache.py b/examples/api/run_disable_cache.py
deleted file mode 100644
index eda7743e9..000000000
--- a/examples/api/run_disable_cache.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxPipeline
-from utils import get_args, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-if args.cache:
- adapter = cache_dit.enable_cache(pipe)
- print(cache_dit.strify(pipe))
- # Test disable_cache api
- # cache_dit.disable_cache(adapter)
- cache_dit.disable_cache(pipe)
- print(cache_dit.strify(pipe))
-
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=28,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.{cache_dit.strify(pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/api/run_steps_mask.py b/examples/api/run_steps_mask.py
index ce562b58d..b1ff4f910 100644
--- a/examples/api/run_steps_mask.py
+++ b/examples/api/run_steps_mask.py
@@ -8,6 +8,7 @@
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import get_args, strify, MemoryTracker
import cache_dit
+from cache_dit.platforms import current_platform
parser = get_args(parse=False)
@@ -36,25 +37,6 @@
if args.step_mask in step_mask_aliases:
args.step_mask = step_mask_aliases[args.step_mask]
-# Define different step computation masks for 28 steps
-step_computation_masks = {
- "slow": cache_dit.steps_mask(
- compute_bins=[8, 3, 3, 2, 2], # 18
- cache_bins=[1, 2, 2, 2, 3], # 10
- ),
- "medium": cache_dit.steps_mask(
- compute_bins=[6, 2, 2, 2, 2], # 14
- cache_bins=[1, 3, 3, 3, 4], # 14
- ),
- "fast": cache_dit.steps_mask(
- compute_bins=[6, 1, 1, 1, 1], # 10
- cache_bins=[1, 3, 4, 5, 5], # 18
- ),
- "ultra": cache_dit.steps_mask(
- compute_bins=[4, 1, 1, 1, 1], # 8
- cache_bins=[1, 4, 5, 6, 6], # 20
- ),
-}
step_computation_dynamic_policy_rdt = {
"slow": 0.20,
@@ -95,8 +77,12 @@
max_continuous_cached_steps=args.max_continuous_cached_steps,
residual_diff_threshold=args.rdt,
# LeMiCa or EasyCache style Mask for 28 steps, e.g,
- # 111111010010000010000100001, 1: compute, 0: cache.
- steps_computation_mask=step_computation_masks[args.step_mask],
+ # slow: 11111111 0 111 00 111 00 11 00 1 000 1,
+ # 1: full compute steps, 0: dynamic/static cache.
+ steps_computation_mask=cache_dit.steps_mask(
+ mask_policy=args.step_mask, # slow, medium, fast, ultra.
+ total_steps=28 if args.steps is None else args.steps,
+ ),
# The policy for cache steps can be 'dynamic' or 'static'
steps_computation_policy=args.step_policy,
),
@@ -125,7 +111,7 @@
)
print(f"Applied quantization: {args.quantize_type} to Transformer and Text Encoder 2.")
-pipe.to("cuda")
+pipe.to(current_platform.device_type)
if args.attn is not None:
if hasattr(pipe.transformer, "set_attention_backend"):
diff --git a/examples/api/run_transformer_only.py b/examples/api/run_transformer_only.py
index 4896d2b34..1c4aa115a 100644
--- a/examples/api/run_transformer_only.py
+++ b/examples/api/run_transformer_only.py
@@ -8,6 +8,7 @@
from diffusers import FluxPipeline, FluxTransformer2DModel
from utils import get_args, strify, MemoryTracker
import cache_dit
+from cache_dit.platforms import current_platform
parser = get_args(parse=False)
@@ -31,7 +32,7 @@
)
),
torch_dtype=torch.bfloat16,
-).to("cuda")
+).to(current_platform.device_type)
if args.cache:
from cache_dit import (
diff --git a/examples/api/run_unified_api.py b/examples/api/run_unified_api.py
deleted file mode 100644
index 4411f32fb..000000000
--- a/examples/api/run_unified_api.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxPipeline
-from utils import get_args, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-if args.cache:
- cachify(args, pipe)
- print(cache_dit.strify(pipe))
-
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=28,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.{cache_dit.strify(pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S17.png b/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S17.png
new file mode 100644
index 000000000..6563cddf3
Binary files /dev/null and b/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S17.png differ
diff --git a/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15.png b/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15.png
new file mode 100644
index 000000000..afa457cb5
Binary files /dev/null and b/examples/assets/flux.1024x1024.C0_Q0_DBCache_F1B0_W8I1M0MC3_R0.24_CFG0_T0O0_Ulysses4_S15.png differ
diff --git a/examples/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png b/examples/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png
new file mode 100644
index 000000000..e7334e55a
Binary files /dev/null and b/examples/assets/flux.1024x1024.C0_Q0_NONE_Ulysses4.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png b/examples/assets/zimage_controlnet.1728x992.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png
new file mode 100644
index 000000000..3ee69fd40
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C0_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE.png b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE.png
new file mode 100644
index 000000000..263ea34ad
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses2.png b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses2.png
new file mode 100644
index 000000000..ac1041aa1
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses2.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4.png b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4.png
new file mode 100644
index 000000000..3a303a3bb
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4_CNP.png b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4_CNP.png
new file mode 100644
index 000000000..c2376fb9f
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C0_Q0_NONE_Ulysses4_CNP.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png
new file mode 100644
index 000000000..7ec108736
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_CNP.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_async_CNP.png b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_async_CNP.png
new file mode 100644
index 000000000..039ec2c09
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_async_CNP.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP.png b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP.png
new file mode 100644
index 000000000..35e75e7c0
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP.png differ
diff --git a/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP_sdpa_cudnn.png b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP_sdpa_cudnn.png
new file mode 100644
index 000000000..beb87c6cb
Binary files /dev/null and b/examples/assets/zimage_controlnet.1728x992.C1_Q0_DBCache_F1B0_W4I1M0MC3_R0.6_SCM111101001_dynamic_CFG0_T0O0_Ulysses4_S2_ulysses_float8_CNP_sdpa_cudnn.png differ
diff --git a/examples/base.py b/examples/base.py
new file mode 100644
index 000000000..e35de9494
--- /dev/null
+++ b/examples/base.py
@@ -0,0 +1,682 @@
+import os
+import time
+import types
+import torch
+import argparse
+import dataclasses
+from PIL import Image
+from enum import Enum
+import numpy as np
+from typing import Dict, Any, Union, Optional, List, Callable
+from diffusers.utils import export_to_video
+from diffusers.schedulers import SchedulerMixin
+from diffusers import DiffusionPipeline, ModelMixin
+from transformers import GenerationMixin
+from diffusers.loaders.lora_base import LoraBaseMixin
+from diffusers.quantizers import PipelineQuantizationConfig
+from cache_dit.logger import init_logger
+import cache_dit
+
+from utils import (
+ strify,
+ maybe_destroy_distributed,
+ maybe_init_distributed,
+ maybe_apply_optimization,
+ pipe_quant_bnb_4bit_config,
+ create_profiler_from_args,
+ MemoryTracker,
+)
+
+logger = init_logger(__name__)
+
+
+class ExampleType(Enum):
+ T2V = "T2V - Text to Video"
+ I2V = "I2V - Image to Video"
+ T2I = "T2I - Text to Image"
+ IE2I = "IE2I - Image Editing to Image"
+ FLF2V = "FLF2V - First Last Frames to Video"
+ VACE = "VACE - Video All-in-one Creation and Editing"
+
+
+@dataclasses.dataclass
+class ExampleInputData:
+ # This class provides default input data for examples.
+ # The default values may be overridden by command line
+ # args or other means.
+ # General inputs for both image and video generation
+ prompt: Optional[str] = None
+ negative_prompt: Optional[str] = None
+ height: Optional[int] = None
+ width: Optional[int] = None
+ guidance_scale: Optional[float] = None
+ guidance_scale_2: Optional[float] = None # for dual guidance scale
+ true_cfg_scale: Optional[float] = None
+ num_inference_steps: Optional[int] = None
+ num_images_per_prompt: Optional[int] = None
+ num_frames: Optional[int] = None
+ # Specific inputs for image editing
+ image: Optional[Union[List[Image.Image], Image.Image]] = None
+ mask_image: Optional[Union[List[Image.Image], Image.Image]] = None
+    # Specific inputs for video generation, e.g., Wan VACE
+ video: Optional[List[Image.Image]] = None
+ mask: Optional[List[Image.Image]] = None
+    # Specific inputs for controlnet, e.g., Qwen-Image-ControlNet-Inpainting
+ control_image: Optional[Union[List[Image.Image], Image.Image]] = None
+ control_mask: Optional[Union[List[Image.Image], Image.Image]] = None
+ controlnet_conditioning_scale: Optional[float] = None
+ # Specific inputs for Qwen Image Layered
+ layers: Optional[int] = None
+ resolution: Optional[int] = None
+ cfg_normalize: Optional[bool] = None
+ use_en_prompt: Optional[bool] = None
+ # Other inputs
+ seed: int = 0
+ # Use 'cpu' by default for better reproducibility across different hardware
+ gen_device: str = "cpu"
+ generator: torch.Generator = dataclasses.field(
+ default_factory=lambda: torch.Generator("cpu").manual_seed(0)
+ )
+    # Some extra args, e.g., editing model specific inputs
+ extra_input_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def data(self, args: argparse.Namespace) -> Dict[str, Any]:
+ self._preprocess()
+ data = dataclasses.asdict(self)
+ # Flatten extra_args and merge into main dict
+ extra_args = data.pop("extra_input_kwargs") # {key: value, ...}
+ extra_args = extra_args if extra_args is not None else {}
+ # Remove None values from extra_args
+ extra_data = {k: v for k, v in extra_args.items() if v is not None}
+ input_data = {k: v for k, v in data.items() if v is not None}
+ input_data.update(extra_data)
+ # Override with args if provided
+ if args.prompt is not None:
+ input_data["prompt"] = args.prompt
+ if args.negative_prompt is not None:
+ input_data["negative_prompt"] = args.negative_prompt
+ if args.height is not None:
+ input_data["height"] = args.height
+ if args.width is not None:
+ input_data["width"] = args.width
+ if args.num_inference_steps is not None:
+ input_data["num_inference_steps"] = args.num_inference_steps
+ if args.num_frames is not None:
+ input_data["num_frames"] = args.num_frames
+ if args.image_path is not None:
+ if "image" in input_data:
+ if isinstance(input_data["image"], list):
+ if len(input_data["image"]) > 1:
+ logger.warning(
+ "Overriding multiple input images with a single image "
+ "from args.image_path."
+ )
+ input_data["image"] = Image.open(args.image_path).convert("RGB")
+ if args.mask_image_path is not None:
+ if "mask_image" in input_data:
+ if isinstance(input_data["mask_image"], list):
+ if len(input_data["mask_image"]) > 1:
+ logger.warning(
+ "Overriding multiple input mask images with a single mask "
+ "image from args.mask_image_path."
+ )
+ input_data["mask_image"] = Image.open(args.mask_image_path).convert("RGB")
+ # Set generator with seed from input data or args
+ if args.generator_device is not None:
+ self.gen_device = args.generator_device
+ if args.seed is not None:
+ self.seed = args.seed
+ input_data["generator"] = torch.Generator(self.gen_device).manual_seed(self.seed)
+ # Remove redundant keys from input data
+ input_data.pop("seed", None)
+ input_data.pop("gen_device", None)
+ return input_data
+
+ def new_generator(self, args: argparse.Namespace = None) -> torch.Generator:
+ # NOTE: We should always create a new generator before each inference to
+ # ensure reproducibility when using the same seed. Alawys use cpu generator
+ # for better cross-device consistency.
+ if args is not None and args.generator_device is not None:
+ self.gen_device = args.generator_device
+ if args is not None and args.seed is not None:
+ return torch.Generator(self.gen_device).manual_seed(args.seed)
+ elif self.seed is not None:
+ return torch.Generator(self.gen_device).manual_seed(self.seed)
+ else:
+ return torch.Generator(self.gen_device).manual_seed(0)
+
+ def _preprocess(self):
+ if self.image is not None:
+ if isinstance(self.image, list) and len(self.image) == 1:
+ # unwrap single image from list for general use cases
+ self.image = self.image[0]
+ if self.mask_image is not None:
+ if isinstance(self.mask_image, list) and len(self.mask_image) == 1:
+ # unwrap single mask image from list for general use cases
+ self.mask_image = self.mask_image[0]
+
+ def summary(self, args: argparse.Namespace) -> str:
+ summary_str = "🤖 Example Input Summary:\n"
+ data = self.data(args)
+ for k, v in data.items():
+ if k in ["prompt", "negative_prompt"]:
+ summary_str += f"- {k}: {v}\n"
+ elif k in ["height", "width", "num_inference_steps", "num_frames"]:
+ summary_str += f"- {k}: {v}\n"
+ elif k in ["image", "mask_image", "control_image", "control_mask"]:
+ if isinstance(v, Image.Image):
+ W, H = v.size
+ summary_str += f"- {k}: Single Image ({H}x{W})\n"
+ elif isinstance(v, list):
+ if len(v) > 0:
+ summary_str += f"- {k}: List Images ({len(v)} images)\n"
+ for i in range(min(len(v), 3)): # show up to 3 images
+ if isinstance(v[i], Image.Image):
+ W, H = v[i].size
+ summary_str += f" - Image {i}: ({H}x{W})\n"
+ else:
+ summary_str += f" - Image {i}: Not a valid PIL Image\n"
+ elif len(v) == 1:
+ if isinstance(v[0], Image.Image):
+ W, H = v[0].size
+ summary_str += f"- {k}: Single Image ({H}x{W})\n"
+ else:
+ summary_str += f"- {k}: Not a valid PIL Image\n"
+ else:
+ summary_str += f"- {k}: Empty List\n"
+ elif k in ["video", "mask"]:
+ if isinstance(v, list):
+ if len(v) > 0:
+ summary_str += f"- {k}: List of Frames ({len(v)} frames)\n"
+ for i in range(min(len(v), 1)): # show up to 1 frames
+ if isinstance(v[i], Image.Image):
+ W, H = v[i].size
+ summary_str += f" - Frame {i}: ({H}x{W})\n"
+ else:
+ summary_str += f" - Frame {i}: Not a valid PIL Image\n"
+ else:
+ summary_str += f"- {k}: Empty List\n"
+ else:
+ summary_str += f"- {k}: Not a valid list of frames\n"
+ elif k == "generator":
+ # Show seed and device info
+ if isinstance(v, torch.Generator):
+ gen_device = v.device if hasattr(v, "device") else "cpu"
+ gen_seed = v.initial_seed() if hasattr(v, "initial_seed") else "N/A"
+ summary_str += f"- {k}: device {gen_device}, seed {gen_seed}\n"
+ else:
+ summary_str += f"- {k}: Not a valid torch.Generator\n"
+ else:
+ summary_str += f"- {k}: {v}\n"
+ summary_str = summary_str.rstrip("\n")
+ logger.info(summary_str)
+ return summary_str
+
+
+@dataclasses.dataclass
+class ExampleOutputData:
+ # Tag
+ model_tag: Optional[str] = None
+ strify_tag: Optional[str] = None
+ # Generated image or video
+ image: Optional[Image.Image | List[Image.Image]] = (
+ None # Single PIL Images or list of PIL Images
+ )
+ video: Optional[List[Image.Image]] = None # List of PIL Images or video frames
+ # Performance metrics
+ load_time: Optional[float] = None
+ warmup_time: Optional[float] = None
+ inference_time: Optional[float] = None
+ memory_usage: Optional[float] = None
+ # Other outputs
+ extra_output_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def save(self, args: argparse.Namespace) -> None:
+ # TODO: Handle other extra outputs as needed
+ save_path = args.save_path
+ if save_path is None:
+ save_path = self._default_save_path()
+ if save_path is None:
+ logger.warning("No valid save path found for output data.")
+ return
+
+ if self.image is not None:
+ if isinstance(self.image, Image.Image):
+ self.image.save(save_path)
+ logger.info(f"Image saved to {save_path}")
+ elif isinstance(self.image, list):
+ save_pre = ".".join(save_path.split(".")[:-1])
+ save_ext = save_path.split(".")[-1]
+ for i, img in enumerate(self.image):
+ img_save_path = f"{save_pre}_{i}.{save_ext}"
+ img.save(img_save_path)
+ logger.info(f"Image {i} saved to {img_save_path}")
+
+ if self.video is not None:
+ export_to_video(self.video, save_path, fps=8)
+ logger.info(f"Video saved to {save_path}")
+
+ def _default_save_path(self) -> Optional[str]:
+ if self.image is not None:
+ try:
+ W, H = self.image.size
+ HxW_str = f"{H}x{W}"
+ except Exception:
+ HxW_str = None
+ if HxW_str is not None:
+ if HxW_str not in self.strify_tag:
+ return f"{self.model_tag}.{HxW_str}.{self.strify_tag}.png"
+ else:
+ return f"{self.model_tag}.{self.strify_tag}.png"
+ else:
+ return f"{self.model_tag}.{self.strify_tag}.png"
+ elif self.video is not None:
+ try:
+ if isinstance(self.video, (list, np.ndarray)) and len(self.video) > 0:
+ if isinstance(self.video[0], Image.Image):
+ W, H = self.video[0].size
+ elif isinstance(self.video[0], np.ndarray):
+ frame = self.video[0] # type: np.ndarray
+ H, W = frame.shape[:2]
+ else:
+ raise ValueError("Invalid video frame type.")
+ if isinstance(self.video, list):
+ num_frames = len(self.video)
+ elif isinstance(self.video, np.ndarray):
+ num_frames = self.video.shape[0]
+ else:
+ raise ValueError("Invalid video type.")
+ HxW_str = f"{H}x{W}x{num_frames}"
+ else:
+ HxW_str = None
+ except Exception:
+ HxW_str = None
+ if HxW_str is not None:
+ if HxW_str not in self.strify_tag:
+ return f"{self.model_tag}.{HxW_str}.{self.strify_tag}.mp4"
+ else:
+ return f"{self.model_tag}.{self.strify_tag}.mp4"
+ else:
+ return f"{self.model_tag}.{self.strify_tag}.mp4"
+ else:
+ return None
+
+ def summary(self, args: argparse.Namespace) -> str:
+ from cache_dit.platforms import current_platform
+
+ logger.info("🤖 Example Output Summary:")
+ summary_str = f"- Model: {args.example}\n- Optimization: {self.strify_tag}\n"
+ device_name = current_platform.get_device_name()
+ world_size = (
+ 1 if not torch.distributed.is_initialized() else torch.distributed.get_world_size()
+ )
+ summary_str += f"- Device: {device_name} x {world_size}\n"
+ if self.load_time is not None:
+ summary_str += f"- Load Time: {self.load_time:.2f}s\n"
+ if self.warmup_time is not None:
+ summary_str += f"- Warmup Time: {self.warmup_time:.2f}s\n"
+ if self.inference_time is not None:
+ summary_str += f"- Inference Time: {self.inference_time:.2f}s\n"
+ if self.memory_usage is not None:
+ summary_str += f"- Memory Usage: {self.memory_usage:.2f}GiB\n"
+ summary_str = summary_str.rstrip("\n")
+ logger.info(summary_str)
+ return summary_str
+
+
+@dataclasses.dataclass
+class ExampleInitConfig:
+ # This class provides default initialization config for examples.
+ # The default values may be overridden by command line args or other means.
+ task_type: ExampleType
+ model_name_or_path: str
+ pipeline_class: Optional[type[DiffusionPipeline]] = DiffusionPipeline
+ torch_dtype: Optional[torch.dtype] = torch.bfloat16
+ bnb_4bit_components: Optional[List[str]] = dataclasses.field(default_factory=list)
+ scheduler: Optional[Union[SchedulerMixin, Callable]] = None # lora case
+ transformer: Optional[Union[ModelMixin, Callable]] = None # lora or nunchaku case
+ vae: Optional[Union[ModelMixin, Callable]] = None
+ text_encoder: Optional[Union[GenerationMixin, Callable]] = None
+ controlnet: Optional[Union[ModelMixin, Callable]] = None
+ lora_weights_path: Optional[str] = None
+ lora_weights_name: Optional[str] = None
+ # For parallelism compatibility, tensor parallelism requires fused LoRA
+ force_fuse_lora: bool = True
+ pre_init_hook: Optional[Callable[[Any], None]] = None # For future use
+ post_init_hook: Optional[Callable[[DiffusionPipeline], None]] = None
+ # For DBCache, Parallelism optimization.
+ extra_optimize_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+ def __post_init__(self):
+ if not self.bnb_4bit_components:
+ self.bnb_4bit_components = ["text_encoder"]
+
+ def get_pipe(self, args: argparse.Namespace, **kwargs) -> DiffusionPipeline:
+ if self.pipeline_class is None:
+ raise ValueError("pipeline_class must be provided to get the pipeline instance.")
+ pipeline_quantization_config = self._pipeline_quantization_config(args)
+ pipe = self.pipeline_class.from_pretrained(
+ self.model_name_or_path if args.model_path is None else args.model_path,
+ torch_dtype=self.torch_dtype,
+ quantization_config=pipeline_quantization_config,
+ device_map="balanced" if args.device_map_balance else None,
+ **self._custom_components_kwargs(),
+ ) # type: LoraBaseMixin
+ if self.post_init_hook is not None:
+ self.post_init_hook(pipe, **kwargs)
+
+ # Load lora and fuse if needed
+ if self.has_lora:
+ assert issubclass(
+ type(pipe), LoraBaseMixin
+ ), "Pipeline class must inherit from LoraBaseMixin to load LoRA weights."
+ assert hasattr(
+ pipe, "load_lora_weights"
+ ), "Pipeline instance must have load_lora_weights method to load LoRA weights."
+ if self.lora_weights_name is None:
+ # TODO: Support adapter name in the future
+ pipe.load_lora_weights(self.lora_weights_path)
+ else:
+ pipe.load_lora_weights(self.lora_weights_path, weight_name=self.lora_weights_name)
+ if not args.disable_fuse_lora and (
+ pipeline_quantization_config is None
+ or "transformer" not in pipeline_quantization_config.components_to_quantize
+ or self.force_fuse_lora
+ ):
+ pipe.fuse_lora()
+ pipe.unload_lora_weights()
+ logger.info("Fused and unloaded LoRA weights into the transformer.")
+ else:
+ logger.warning("Keep LoRA weights in memory since transformer is quantized.")
+
+ return pipe
+
+ def summary(self, args: argparse.Namespace, **kwargs) -> str:
+ logger.info("🤖 Example Init Config Summary:")
+ extra_model_path = kwargs.get("extra_model_path", "")
+ model_name_or_path = self.model_name_or_path if args.model_path is None else args.model_path
+ summary_str = "- Model: "
+ if (
+ os.path.basename(extra_model_path).lower()
+ != os.path.basename(model_name_or_path).lower()
+ ):
+ summary_str += f"\n - {model_name_or_path}\n"
+ summary_str += f" - {extra_model_path}\n"
+ else:
+ summary_str += f"{model_name_or_path}\n"
+ summary_str += f"- Task Type: {self.task_type.value}\n"
+ summary_str += f"- Torch Dtype: {self.torch_dtype}\n"
+ if self.lora_weights_path is not None and self.lora_weights_name is not None:
+ summary_str += (
+ f"- LoRA Weights: {os.path.join(self.lora_weights_path, self.lora_weights_name)}\n"
+ )
+ elif self.lora_weights_path is not None:
+ summary_str += f"- LoRA Path: {self.lora_weights_path}\n"
+ else:
+ summary_str += "- LoRA Weights: None\n"
+ summary_str = summary_str.rstrip("\n")
+ logger.info(summary_str)
+ return summary_str
+
+ def _custom_components_kwargs(self) -> Dict[str, Any]:
+ custom_components_kwargs = {}
+
+ custom_components_kwargs["scheduler"] = (
+ self.scheduler
+ if not _is_function_or_method(
+ self.scheduler,
+ )
+ else self.scheduler() # get scheduler instance
+ )
+ custom_components_kwargs["transformer"] = (
+ self.transformer
+ if not _is_function_or_method(
+ self.transformer,
+ )
+ else self.transformer() # get transformer instance
+ )
+ custom_components_kwargs["vae"] = (
+ self.vae
+ if not _is_function_or_method(
+ self.vae,
+ )
+ else self.vae() # get vae instance
+ )
+ custom_components_kwargs["text_encoder"] = (
+ self.text_encoder
+ if not _is_function_or_method(
+ self.text_encoder,
+ )
+ else self.text_encoder() # get text_encoder instance
+ )
+ custom_components_kwargs["controlnet"] = (
+ self.controlnet
+ if not _is_function_or_method(
+ self.controlnet,
+ )
+ else self.controlnet() # get controlnet instance
+ )
+ # Remove None components
+ custom_components_kwargs = {
+ k: v for k, v in custom_components_kwargs.items() if v is not None
+ }
+ return custom_components_kwargs
+
+ @property
+ def has_lora(self) -> bool:
+ return (
+ self.lora_weights_path is not None
+ and os.path.exists(self.lora_weights_path)
+ and self.lora_weights_name is not None
+ and os.path.exists(os.path.join(self.lora_weights_path, self.lora_weights_name))
+ )
+
+ def _pipeline_quantization_config(
+ self, args: argparse.Namespace
+ ) -> Optional[PipelineQuantizationConfig]:
+ if self.bnb_4bit_components is None or len(self.bnb_4bit_components) == 0:
+ return None
+ return pipe_quant_bnb_4bit_config(
+ args=args,
+ components_to_quantize=self.bnb_4bit_components,
+ )
+
+
+def _is_function_or_method(component: Any) -> bool:
+ func_types = (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ types.LambdaType,
+ )
+ excluded_module_classes = (
+ SchedulerMixin,
+ ModelMixin,
+ GenerationMixin,
+ torch.nn.Module,
+ )
+
+ is_basic_func = isinstance(component, func_types)
+ is_excluded_instance = isinstance(component, excluded_module_classes)
+ is_method = isinstance(
+ component,
+ (
+ types.MethodType,
+ types.ClassMethodDescriptorType,
+ ),
+ )
+ return is_basic_func and not is_excluded_instance and not is_method
+
+
+class Example:
+ def __init__(
+ self,
+ args: argparse.Namespace,
+ init_config: Optional[ExampleInitConfig] = None,
+ input_data: Optional[ExampleInputData] = None,
+ ):
+ self.args = args
+ self.init_config: Optional[ExampleInitConfig] = init_config
+ self.input_data: Optional[ExampleInputData] = input_data
+ self.output_data: Optional[ExampleOutputData] = None
+ self.rank, self.device = maybe_init_distributed(self.args)
+
+ def check_valid(self) -> bool:
+ if self.args is None:
+ raise ValueError("args must be provided.")
+ if self.input_data is None:
+ raise ValueError("input_data must be provided.")
+ if self.init_config is None:
+ raise ValueError("init_config must be provided.")
+ return True
+
+ def prepare_input_data(self):
+ input_kwargs = self.input_data.data(self.args)
+ default_num_inference_steps = input_kwargs.get("num_inference_steps", None)
+ extra_optimize_kwargs = self.init_config.extra_optimize_kwargs
+ extra_optimize_kwargs["default_num_inference_steps"] = default_num_inference_steps
+ return input_kwargs, extra_optimize_kwargs
+
+ def run(self) -> None:
+ self.check_valid()
+ start_time = time.time()
+ pipe = self.init_config.get_pipe(self.args)
+ load_time = time.time() - start_time
+
+ input_kwargs, extra_optimize_kwargs = self.prepare_input_data()
+ default_num_inference_steps = input_kwargs.get("num_inference_steps", None)
+
+ maybe_apply_optimization(self.args, pipe, **extra_optimize_kwargs)
+
+ pipe.set_progress_bar_config(disable=self.rank != 0)
+
+ # track memory if needed
+ memory_tracker = MemoryTracker() if self.args.track_memory else None
+ if memory_tracker:
+ memory_tracker.__enter__()
+
+ # warm up
+ start_time = time.time()
+ for _ in range(self.args.warmup):
+ input_kwargs = self.new_generator(input_kwargs, self.args)
+ if self.args.warmup_num_inference_steps is not None:
+ input_kwargs["num_inference_steps"] = self.args.warmup_num_inference_steps
+ _ = pipe(**input_kwargs)
+ if self.args.warmup > 0:
+ warmup_time = (time.time() - start_time) / self.args.warmup
+ else:
+ warmup_time = None
+ # restore num_inference_steps
+ input_kwargs["num_inference_steps"] = default_num_inference_steps
+
+ start_time = time.time()
+ # actual inference
+ model_tag = self.args.example if self.args.example is not None else "None"
+ if self.args.profile:
+ requested_profile_name = getattr(self.args, "profile_name", None)
+ profile_name = requested_profile_name or f"{model_tag}_profile"
+ profiler = create_profiler_from_args(self.args, profile_name=profile_name)
+ with profiler:
+ for _ in range(self.args.repeat):
+ input_kwargs = self.new_generator(input_kwargs, self.args)
+ output = pipe(**input_kwargs)
+ if self.rank == 0:
+ logger.info(
+ f"Profiler traces saved to: {profiler.output_dir}/{profiler.trace_path.name}"
+ )
+ else:
+ for _ in range(self.args.repeat):
+ input_kwargs = self.new_generator(input_kwargs, self.args)
+ output = pipe(**input_kwargs)
+ if self.args.repeat > 0:
+ inference_time = (time.time() - start_time) / self.args.repeat
+ else:
+ inference_time = None
+
+ if self.args.cache_summary:
+ if self.rank == 0:
+ cache_dit.summary(pipe)
+
+ if memory_tracker:
+ memory_tracker.__exit__(None, None, None)
+ peak_gb = memory_tracker.report()
+ else:
+ peak_gb = None
+
+ # Prepare output data
+ output_data = ExampleOutputData(
+ model_tag=model_tag,
+ strify_tag=f"{strify(self.args, pipe)}",
+ load_time=load_time,
+ warmup_time=warmup_time,
+ inference_time=inference_time,
+ memory_usage=peak_gb,
+ )
+
+ if self.init_config.task_type in [ExampleType.T2I, ExampleType.IE2I]:
+ output_data.image = (
+ output.images[0] if isinstance(output.images, list) else output.images
+ )
+ elif self.init_config.task_type in [
+ ExampleType.T2V,
+ ExampleType.I2V,
+ ExampleType.FLF2V,
+ ExampleType.VACE,
+ ]:
+ output_data.video = output.frames[0] if hasattr(output, "frames") else output
+
+ self.output_data = output_data
+
+ if self.rank == 0:
+ logger.info("-" * 100)
+ self.init_config.summary(
+ self.args,
+ # path for extra model, e.g., lora weights, svdq int4 weights, etc.
+ extra_model_path=ExampleRegister.get_default(
+ self.args.example,
+ ),
+ )
+ self.input_data.summary(self.args)
+ self.output_data.summary(self.args)
+ self.output_data.save(self.args)
+ logger.info("-" * 100)
+
+ maybe_destroy_distributed()
+
+ def new_generator(
+ self, input_kwargs: Dict[str, Any], args: argparse.Namespace
+ ) -> torch.Generator:
+ # NOTE: We should always create a new generator before each inference to
+ # ensure reproducibility when using the same seed.
+ input_kwargs["generator"] = self.input_data.new_generator(args=args)
+ return input_kwargs
+
+
+class ExampleRegister:
+ _example_registry: Dict[str, Callable[..., Example]] = {}
+ _example_registry_defaults: Dict[str, str] = {}
+
+ @classmethod
+ def register(cls, name: str, default: str = ""):
+ def decorator(example_func: Callable[..., Example]):
+ if name in cls._example_registry:
+ raise ValueError(f"Example '{name}' is already registered.")
+ cls._example_registry[name] = example_func
+ cls._example_registry_defaults[name] = default
+ return example_func
+
+ return decorator
+
+ @classmethod
+ def get_example(cls, args: argparse.Namespace, name: str, **kwargs) -> Example:
+ if name not in cls._example_registry:
+ raise ValueError(f"Example '{name}' is not registered.")
+ example_func = cls._example_registry[name]
+ return example_func(args, **kwargs)
+
+ @classmethod
+ def list_examples(cls) -> List[str]:
+ return list(cls._example_registry.keys())
+
+ @classmethod
+ def get_default(cls, name: str) -> str:
+ return cls._example_registry_defaults.get(name, "")
diff --git a/examples/config.yaml b/examples/config.yaml
new file mode 100644
index 000000000..82fe12234
--- /dev/null
+++ b/examples/config.yaml
@@ -0,0 +1,12 @@
+cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ num_inference_steps: 28
+ steps_computation_mask: fast
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
diff --git a/examples/data/edit2509_1.jpg b/examples/data/edit2509_1.jpg
new file mode 100644
index 000000000..ebb019ebc
Binary files /dev/null and b/examples/data/edit2509_1.jpg differ
diff --git a/examples/data/edit2509_2.jpg b/examples/data/edit2509_2.jpg
new file mode 100644
index 000000000..1df751434
Binary files /dev/null and b/examples/data/edit2509_2.jpg differ
diff --git a/examples/data/pose.jpg b/examples/data/pose.jpg
new file mode 100644
index 000000000..473c670c7
Binary files /dev/null and b/examples/data/pose.jpg differ
diff --git a/examples/data/yarn-art-pikachu.png b/examples/data/yarn-art-pikachu.png
new file mode 100644
index 000000000..32a69608f
Binary files /dev/null and b/examples/data/yarn-art-pikachu.png differ
diff --git a/examples/generate.py b/examples/generate.py
new file mode 100644
index 000000000..738c9db64
--- /dev/null
+++ b/examples/generate.py
@@ -0,0 +1,77 @@
+from cache_dit.logger import init_logger
+from utils import get_base_args, maybe_postprocess_args
+from registers import ExampleRegister # noqa: F403, F401
+from helpers import activate_all_examples
+
+# Make sure all examples are registered
+activate_all_examples()
+
+logger = init_logger(__name__)
+
+
+def get_example_args():
+ parser = get_base_args(parse=False)
+ parser.add_argument(
+ "task",
+ type=str,
+ nargs="?",
+ default="generate",
+ choices=["generate", "list"] + ExampleRegister.list_examples(),
+ help=(
+ "The task to perform or example name to run. "
+ "Use 'list' to list all available examples, "
+ "or specify an example name directly (defaults to 'generate' task)."
+ ),
+ )
+ parser.add_argument(
+ "example",
+ type=str,
+ nargs="?",
+ default=None,
+ choices=[None] + ExampleRegister.list_examples(),
+ help="Names of the examples to run. If not specified, skip running example.",
+ )
+ args = parser.parse_args()
+
+ if args.task in ExampleRegister.list_examples():
+ args.example = args.task
+ args.task = "generate"
+
+ return maybe_postprocess_args(args)
+
+
+if __name__ == "__main__":
+ logger = init_logger(__name__)
+ args = get_example_args()
+ if args.task == "list":
+ logger.info("Available examples:")
+ max_name_len = max(len(name) for name in ExampleRegister.list_examples())
+ for name in ExampleRegister.list_examples():
+ default = ExampleRegister.get_default(name)
+ # format by max_name_len
+        info = f"- ✅ {name:<{max_name_len}} - Default: {default}"
+ logger.info(info)
+ exit(0)
+ else:
+ if args.example is None:
+ logger.error(
+            "Please specify an example name to run. Use the 'list' task to "
+ "see all available examples."
+ )
+ exit(1)
+
+ # logging all args with better formatting
+ logger.info("Running example with the following arguments:")
+ for arg, value in vars(args).items():
+ logger.info(f"- {arg}: {value}")
+
+ example = ExampleRegister.get_example(args, args.example)
+ example.run()
+
+ # Usage:
+ # python3 generate.py list
+ # python3 generate.py zimage # simplified, task defaults to 'generate'
+ # python3 generate.py qwen_image_edit_lightning --cpu-offload
+ # python3 generate.py generate zimage # full syntax, still supported
+ # torchrun --nproc_per_node=4 generate.py zimage --parallel ulysses --ulysses-anything
+ # torchrun --nproc_per_node=4 generate.py zimage --parallel tp --parallel-text
diff --git a/examples/helpers.py b/examples/helpers.py
new file mode 100644
index 000000000..2205d83a3
--- /dev/null
+++ b/examples/helpers.py
@@ -0,0 +1,20 @@
+def activate_all_examples():
+ from registers import flux_example # noqa: F403, F401
+ from registers import flux_fill_example # noqa: F403, F401
+ from registers import flux2_klein_example # noqa: F403, F401
+ from registers import flux2_example # noqa: F403, F401
+ from registers import qwen_image_example # noqa: F403, F401
+ from registers import qwen_image_controlnet_example # noqa: F403, F401
+ from registers import qwen_image_edit_example # noqa: F403, F401
+ from registers import qwen_image_layered_example # noqa: F403, F401
+ from registers import skyreels_v2_example # noqa: F403, F401
+ from registers import ltx2_t2v_example # noqa: F403, F401
+ from registers import ltx2_i2v_example # noqa: F403, F401
+ from registers import wan_example # noqa: F403, F401
+ from registers import wan_i2v_example # noqa: F403, F401
+ from registers import wan_vace_example # noqa: F403, F401
+ from registers import ovis_image_example # noqa: F403, F401
+ from registers import zimage_example # noqa: F403, F401
+ from registers import longcat_image_example # noqa: F403, F401
+ from registers import longcat_image_edit_example # noqa: F403, F401
+ from registers import zimage_controlnet_example # noqa: F403, F401
diff --git a/examples/parallel_config.yaml b/examples/parallel_config.yaml
new file mode 100644
index 000000000..cbf482538
--- /dev/null
+++ b/examples/parallel_config.yaml
@@ -0,0 +1,15 @@
+cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
+parallelism_config:
+ ulysses_size: auto
+ parallel_kwargs:
+ attention_backend: native
+ extra_parallel_modules: ["text_encoder", "vae"]
diff --git a/examples/parallelism/.gitignore b/examples/parallelism/.gitignore
deleted file mode 100644
index 9469a63d1..000000000
--- a/examples/parallelism/.gitignore
+++ /dev/null
@@ -1,4 +0,0 @@
-tmp
-*.png
-*.mp4
-__pycache__
diff --git a/examples/parallelism/run_chroma_cp.py b/examples/parallelism/run_chroma_cp.py
deleted file mode 100644
index 6c7500e3f..000000000
--- a/examples/parallelism/run_chroma_cp.py
+++ /dev/null
@@ -1,89 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import ChromaPipeline
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-# NOTE: Please use `--parallel ulysses --attn naitve` for Chroma with context parallelism,
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = ChromaPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("CHROMA1_DIR", "lodestones/Chroma1-HD")
- ),
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to("cuda")
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-prompt = [
- "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
-]
-if args.prompt is not None:
- prompt = [args.prompt]
-
-negative_prompt = [
- "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
-]
-if args.negative_prompt is not None:
- negative_prompt = [args.negative_prompt]
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.Generator("cpu").manual_seed(433),
- num_inference_steps=40 if not warmup else 5,
- guidance_scale=3.0,
- num_images_per_prompt=1,
- ).images[0]
- return image
-
-
-# warmup
-run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"chroma1-hd.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_chroma_tp.py b/examples/parallelism/run_chroma_tp.py
deleted file mode 100644
index 58666b301..000000000
--- a/examples/parallelism/run_chroma_tp.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import ChromaPipeline
-from utils import (
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- cachify,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = ChromaPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "CHROMA1_DIR",
- "lodestones/Chroma1-HD",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.to(device)
-
-torch.cuda.empty_cache()
-pipe.set_progress_bar_config(disable=rank != 0)
-
-prompt = [
- "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
-]
-if args.prompt is not None:
- prompt = [args.prompt]
-
-negative_prompt = [
- "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
-]
-if args.negative_prompt is not None:
- negative_prompt = [args.negative_prompt]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.Generator("cpu").manual_seed(433),
- num_inference_steps=40,
- guidance_scale=3.0,
- num_images_per_prompt=1,
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"chroma1-hd.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_chrono_edit.py b/examples/parallelism/run_chrono_edit.py
deleted file mode 100644
index 07ff56088..000000000
--- a/examples/parallelism/run_chrono_edit.py
+++ /dev/null
@@ -1,150 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import numpy as np
-from PIL import Image
-from diffusers import (
- AutoencoderKLWan,
- ChronoEditTransformer3DModel,
- ChronoEditPipeline,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import load_image
-from transformers import CLIPVisionModel
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = args.model_path if args.model_path is not None else "nvidia/ChronoEdit-14B-Diffusers"
-model_id = (
- args.model_path if args.model_path is not None else os.environ.get("CHRONO_EDIT_DIR", model_id)
-)
-
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-transformer = ChronoEditTransformer3DModel.from_pretrained(
- model_id, subfolder="transformer", torch_dtype=torch.bfloat16
-)
-
-enable_quantization = args.quantize and args.quantize_type == "bitsandbytes_4bit"
-
-pipe = ChronoEditPipeline.from_pretrained(
- model_id,
- vae=vae,
- image_encoder=image_encoder,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- # text_encoder: ~ 6GiB, transformer: ~ 8GiB, total: ~14GiB
- components_to_quantize=["text_encoder", "transformer"],
- )
- if enable_quantization
- else None
- ),
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-# Enable memory savings
-if not enable_quantization:
- pipe.enable_model_cpu_offload(device=device)
-else:
- pipe.to(device)
-
-assert isinstance(pipe.vae, AutoencoderKLWan)
-pipe.vae.enable_tiling()
-
-image = load_image("../data/chrono_edit_example.png")
-
-max_area = 720 * 1280
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-
-prompt = (
- "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
- "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
-)
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- image=image,
- prompt=prompt,
- height=height,
- width=width,
- num_frames=5,
- guidance_scale=5.0,
- enable_temporal_reasoning=False,
- num_temporal_reasoning_steps=0,
- num_inference_steps=((50 if not warmup else 5) if args.steps is None else args.steps),
- generator=torch.Generator("cuda").manual_seed(0),
- ).frames[0]
- output = Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8"))
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
- # warmup
- _ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"chrono-edit.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- output.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_cogvideox_1.5.py b/examples/parallelism/run_cogvideox_1.5.py
deleted file mode 100644
index 47ea2e347..000000000
--- a/examples/parallelism/run_cogvideox_1.5.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.utils import export_to_video
-from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = CogVideoXPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("COGVIDEOX_1_5_DIR", "zai-org/CogVideoX1.5-5B")
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-assert isinstance(pipe.vae, AutoencoderKLCogVideoX) # enable type check for IDE
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-pipe.vae.enable_tiling()
-
-prompt = (
- "A panda, dressed in a small, red jacket and a tiny hat, "
- "sits on a wooden stool in a serene bamboo forest. The "
- "panda's fluffy paws strum a miniature acoustic guitar, "
- "producing soft, melodic tunes. Nearby, a few other pandas "
- "gather, watching curiously and some clapping in rhythm. "
- "Sunlight filters through the tall bamboo, casting a gentle "
- "glow on the scene. The panda's face is expressive, showing "
- "concentration and joy as it plays. The background includes "
- "a small, flowing stream and vibrant green foliage, enhancing "
- "the peaceful and magical atmosphere of this unique musical "
- "performance."
-)
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- video = pipe(
- prompt=prompt,
- num_videos_per_prompt=1,
- num_inference_steps=50 if not warmup else 5,
- num_frames=16,
- guidance_scale=6,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- parallel_type = args.parallel_type or "none"
- save_path = f"cogvideox_1.5_{parallel_type}.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_cogview3_plus.py b/examples/parallelism/run_cogview3_plus.py
deleted file mode 100644
index 269196881..000000000
--- a/examples/parallelism/run_cogview3_plus.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import CogView3PlusPipeline
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = CogView3PlusPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "COGVIEW3_DIR",
- "THUDM/CogView3-Plus-3B",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-
-prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt=prompt,
- guidance_scale=7.0,
- num_inference_steps=50 if not warmup else 5,
- width=1024,
- height=1024,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-# warmup
-run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- parallel_type = args.parallel_type or "none"
- save_path = f"cogview3_plus_{parallel_type}.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_cogview4.py b/examples/parallelism/run_cogview4.py
deleted file mode 100644
index 79f546344..000000000
--- a/examples/parallelism/run_cogview4.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import CogView4Pipeline
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = CogView4Pipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("COGVIEW4_DIR", "THUDM/CogView4-6B")
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe, enable_separate_cfg=True)
-
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-
-prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt=prompt,
- guidance_scale=3.5, # >1, do separate cfg
- num_inference_steps=50 if not warmup else 5,
- width=1024,
- height=1024,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-# warmup
-run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- parallel_type = args.parallel_type or "none"
- save_path = f"cogview4_{parallel_type}.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_consisid.py b/examples/parallelism/run_consisid.py
deleted file mode 100644
index 535649455..000000000
--- a/examples/parallelism/run_consisid.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import ConsisIDPipeline
-from diffusers.pipelines.consisid.consisid_utils import (
- prepare_face_models,
- process_face_embeddings_infer,
-)
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("CONSISID_DIR", "BestWishYsh/ConsisID-preview")
-)
-
-(
- face_helper_1,
- face_helper_2,
- face_clip_model,
- face_main_model,
- eva_transform_mean,
- eva_transform_std,
-) = prepare_face_models(model_id, device="cuda", dtype=torch.bfloat16)
-pipe = ConsisIDPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.enable_model_cpu_offload(device=device)
-pipe.vae.enable_tiling()
-
-# ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body).
-prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
-if args.prompt is not None:
- prompt = args.prompt
-# image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
-image = "../data/consisid_input.png"
-
-id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
- face_helper_1,
- face_clip_model,
- face_helper_2,
- eva_transform_mean,
- eva_transform_std,
- face_main_model,
- "cuda",
- torch.bfloat16,
- image,
- is_align_face=True,
-)
-
-del face_helper_1
-del face_helper_2
-del face_clip_model
-del face_main_model
-del eva_transform_mean
-del eva_transform_std
-torch.cuda.empty_cache()
-torch.cuda.ipc_collect()
-
-
-def run_pipe(warmup: bool = False):
- video = pipe(
- image=image,
- prompt=prompt,
- num_inference_steps=50 if not warmup else 5,
- guidance_scale=6.0,
- use_dynamic_cfg=False,
- id_vit_hidden=id_vit_hidden,
- id_cond=id_cond,
- kps_cond=face_kps,
- generator=torch.Generator("cpu").manual_seed(42),
- ).frames[0]
- return video
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"consisid.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_dit_xl_cp.py b/examples/parallelism/run_dit_xl_cp.py
deleted file mode 100644
index 1509aafc5..000000000
--- a/examples/parallelism/run_dit_xl_cp.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-
-from diffusers import DiTPipeline, DPMSolverMultistepScheduler
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = DiTPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "DIT_XL_DIR",
- "facebook/DiT-XL-2-256",
- )
- ),
- torch_dtype=torch.float16,
-)
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-pipe = pipe.to("cuda")
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-words = ["white shark"]
-
-class_ids = pipe.get_label_ids(words)
-
-
-def run_pipe():
- image = pipe(
- class_labels=class_ids,
- num_inference_steps=25,
- generator=torch.Generator("cpu").manual_seed(33),
- ).images[0]
- return image
-
-
-# warmup
-_ = run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"dit-xl.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_flux_cp.py b/examples/parallelism/run_flux_cp.py
deleted file mode 100644
index ffc82844d..000000000
--- a/examples/parallelism/run_flux_cp.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- FluxPipeline,
- FluxTransformer2DModel,
- PipelineQuantizationConfig,
-)
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe: FluxPipeline = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder_2"],
- )
- if args.quantize
- else None
- ),
-).to("cuda")
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-assert isinstance(pipe.transformer, FluxTransformer2DModel)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-height = 1024 if args.height is None else args.height
-width = 1024 if args.width is None else args.width
-
-
-def run_pipe(pipe: FluxPipeline):
- image = pipe(
- prompt,
- height=height,
- width=width,
- num_inference_steps=28 if args.steps is None else args.steps,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"flux.{height}x{width}.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_flux_nunchaku_cp.py b/examples/parallelism/run_flux_nunchaku_cp.py
deleted file mode 100644
index 0f692c2f3..000000000
--- a/examples/parallelism/run_flux_nunchaku_cp.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-import time
-
-import torch
-import torch.distributed as dist
-from diffusers import (
- FluxPipeline,
- FluxTransformer2DModel,
- PipelineQuantizationConfig,
-)
-from nunchaku.models.transformers.transformer_flux_v2 import (
- NunchakuFluxTransformer2DModelV2,
-)
-from utils import (
- get_args,
- strify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-nunchaku_flux_dir = os.environ.get(
- "NUNCHAKA_FLUX_DIR",
- "nunchaku-tech/nunchaku-flux.1-dev",
-)
-transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
- f"{nunchaku_flux_dir}/svdq-int4_r32-flux.1-dev.safetensors",
-)
-pipe: FluxPipeline = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev")
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder_2"],
- )
- if args.quantize
- else None
- ),
-).to("cuda")
-
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- ParamsModifier,
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- ParallelismConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- )
- if args.cache
- else None
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- params_modifiers=[
- ParamsModifier(
- # transformer_blocks
- cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt),
- ),
- ParamsModifier(
- # single_transformer_blocks
- cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt * 3),
- ),
- ],
- parallelism_config=(
- ParallelismConfig(
- ulysses_size=(dist.get_world_size() if args.parallel_type == "ulysses" else None),
- ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
- )
- if args.parallel_type in ["ulysses", "ring"]
- else None
- ),
- )
-
-assert isinstance(pipe.transformer, FluxTransformer2DModel)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(pipe: FluxPipeline):
- image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- num_inference_steps=28 if args.steps is None else args.steps,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"flux.nunchaku.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_flux_tp.py b/examples/parallelism/run_flux_tp.py
deleted file mode 100644
index 5a12ec3eb..000000000
--- a/examples/parallelism/run_flux_tp.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- FluxPipeline,
- FluxTransformer2DModel,
-)
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe: FluxPipeline = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.to(device)
-
-assert isinstance(pipe.transformer, FluxTransformer2DModel)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(pipe: FluxPipeline):
- image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- num_inference_steps=28 if args.steps is None else args.steps,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"flux.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_hunyuan_dit_tp.py b/examples/parallelism/run_hunyuan_dit_tp.py
deleted file mode 100644
index 3328e3c27..000000000
--- a/examples/parallelism/run_hunyuan_dit_tp.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import HunyuanDiTPipeline
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_DIT_DIR",
- "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers",
- # "Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers",
- )
-)
-
-# Initialize distributed for tensor parallelism
-rank, device = maybe_init_distributed(args)
-
-pipe = HunyuanDiTPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.float16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-
-# You may also use English prompt as HunyuanDiT supports both English and Chinese
-prompt = "An astronaut riding a horse on Mars"
-
-if args.prompt is not None:
- prompt = args.prompt
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=args.steps if args.steps else 50,
- generator=torch.Generator(device).manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-if rank == 0:
- time_cost = end - start
- version = "11" if "1.1" in model_id else "12"
- save_path = f"hunyuan_dit_{version}.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_hunyuan_image_2.1_cp.py b/examples/parallelism/run_hunyuan_image_2.1_cp.py
deleted file mode 100644
index 924605f36..000000000
--- a/examples/parallelism/run_hunyuan_image_2.1_cp.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- HunyuanImagePipeline,
- HunyuanImageTransformer2DModel,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-
-from utils import (
- GiB,
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-# NOTE: Please use `--parallel ulysses --attn naitve` for HunyuanImage with context parallelism,
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-enable_quatization = args.quantize and GiB() < 96
-# For now you need to install the latest diffusers as below:
-# pip install git+https://github.com/huggingface/diffusers@main
-pipe: HunyuanImagePipeline = HunyuanImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_IMAGE_DIR",
- "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
- )
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if enable_quatization
- else None
- ),
-)
-
-if GiB() < 96:
- if enable_quatization:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type, # float8_weight_only
- )
- pipe.to(device)
-else:
- pipe.to(device)
-
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
-
-if GiB() < 96 and not enable_quatization:
- pipe.enable_model_cpu_offload(device=device)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- prompt = 'A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word "Tencent" on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style.'
- if args.prompt is not None:
- prompt = args.prompt
- image = pipe(
- prompt,
- num_inference_steps=50 if not warmup else 5,
- height=2048,
- width=2048,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"hunyuan_image_2.1.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_hunyuan_image_2.1_tp.py b/examples/parallelism/run_hunyuan_image_2.1_tp.py
deleted file mode 100644
index 8e0c3879e..000000000
--- a/examples/parallelism/run_hunyuan_image_2.1_tp.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- HunyuanImagePipeline,
- HunyuanImageTransformer2DModel,
-)
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# For now you need to install the latest diffusers as below:
-# pip install git+https://github.com/huggingface/diffusers@main
-pipe: HunyuanImagePipeline = HunyuanImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_IMAGE_DIR",
- "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
-pipe.enable_model_cpu_offload(device=device)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- prompt = 'A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, standing in a painting studio, wearing a red knitted scarf and a red beret with the word "Tencent" on it, holding a paintbrush with a focused expression as it paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style.'
- if args.prompt is not None:
- prompt = args.prompt
- image = pipe(
- prompt,
- num_inference_steps=50 if not warmup else 5,
- height=2048,
- width=2048,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"hunyuan_image_2.1.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_hunyuan_video_cp.py b/examples/parallelism/run_hunyuan_video_cp.py
deleted file mode 100644
index 64b3ff88d..000000000
--- a/examples/parallelism/run_hunyuan_video_cp.py
+++ /dev/null
@@ -1,130 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import (
- HunyuanVideoPipeline,
- HunyuanVideoTransformer3DModel,
- AutoencoderKLHunyuanVideo,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-enable_quatization = args.quantize and GiB() < 96
-
-pipe: HunyuanVideoPipeline = HunyuanVideoPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_VIDEO_DIR",
- "hunyuanvideo-community/HunyuanVideo",
- )
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"], # 4GiB
- )
- if enable_quatization
- else None
- ),
-)
-
-
-if GiB() < 96:
- if enable_quatization:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type, # float8_weight_only, 12GiB
- )
- pipe.to(device)
-else:
- pipe.to(device)
-
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-
-if GiB() < 96 and not enable_quatization:
- pipe.enable_model_cpu_offload(device=device)
-
-assert isinstance(pipe.vae, AutoencoderKLHunyuanVideo)
-pipe.vae.enable_tiling()
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- prompt = "A cat walks on the grass, realistic"
- if args.prompt is not None:
- prompt = args.prompt
- output = pipe(
- prompt,
- height=320,
- width=512,
- num_frames=61,
- num_inference_steps=30 if not warmup else 5,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"hunyuan_video.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=15)
-
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_hunyuan_video_tp.py b/examples/parallelism/run_hunyuan_video_tp.py
deleted file mode 100644
index 96da321a9..000000000
--- a/examples/parallelism/run_hunyuan_video_tp.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe: HunyuanVideoPipeline = HunyuanVideoPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_VIDEO_DIR",
- "hunyuanvideo-community/HunyuanVideo",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-pipe.enable_model_cpu_offload(device=device)
-pipe.vae.enable_tiling()
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(pipe: HunyuanVideoPipeline):
- prompt = "A cat walks on the grass, realistic"
- if args.prompt is not None:
- prompt = args.prompt
- output = pipe(
- prompt,
- height=320,
- width=512,
- num_frames=61,
- num_inference_steps=30,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"hunyuan_video.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=15)
-
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_kandinsky5_t2v_cp.py b/examples/parallelism/run_kandinsky5_t2v_cp.py
deleted file mode 100644
index 9b2ae9e7f..000000000
--- a/examples/parallelism/run_kandinsky5_t2v_cp.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- Kandinsky5T2VPipeline,
- AutoencoderKLHunyuanVideo,
- Kandinsky5Transformer3DModel,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# Available models:
-# ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
-)
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("KANDINSKY5_T2V_DIR", model_id)
-)
-pipe = Kandinsky5T2VPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder", "text_encoder_2"],
- ),
-)
-pipe = pipe.to("cuda")
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe, enable_separate_cfg=not ("nocfg" in model_id))
-
-prompt = "A cat and a dog baking a cake together in a kitchen."
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-assert isinstance(pipe.vae, AutoencoderKLHunyuanVideo)
-
-pipe.vae.enable_tiling()
-
-
-def run_pipe(warmup: bool = False):
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=512,
- width=768,
- num_frames=121,
- num_inference_steps=50 if not warmup else 5,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
- pipe.transformer.compile_repeated_blocks(mode="default")
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"kandinsky5.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=24, quality=9)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_kandinsky5_t2v_tp.py b/examples/parallelism/run_kandinsky5_t2v_tp.py
deleted file mode 100644
index fd1d74f83..000000000
--- a/examples/parallelism/run_kandinsky5_t2v_tp.py
+++ /dev/null
@@ -1,103 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import AutoencoderKLHunyuanVideo, Kandinsky5T2VPipeline
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# Available models:
-# ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
-)
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("KANDINSKY5_T2V_DIR", model_id)
-)
-# For now you need to install the latest diffusers as below:
-# pip install git+https://github.com/huggingface/diffusers@main
-pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe, enable_separate_cfg=not ("nocfg" in model_id))
-
-pipe = pipe.to(device)
-
-torch.cuda.empty_cache()
-
-prompt = "A cat and a dog baking a cake together in a kitchen."
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-assert isinstance(pipe.vae, AutoencoderKLHunyuanVideo)
-
-pipe.vae.enable_tiling()
-
-
-def run_pipe():
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=512,
- width=768,
- num_frames=121,
- num_inference_steps=50,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"kandinsky5.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=24, quality=9)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_ltx_video_cp.py b/examples/parallelism/run_ltx_video_cp.py
deleted file mode 100644
index bce9f3da4..000000000
--- a/examples/parallelism/run_ltx_video_cp.py
+++ /dev/null
@@ -1,172 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- LTXConditionPipeline,
- LTXLatentUpsamplePipeline,
- AutoencoderKLLTXVideo,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-import cache_dit
-
-# NOTE: Please use `--parallel ulysses --attn naitve` for LTXVideo with context parallelism,
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = LTXConditionPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("LTX_VIDEO_DIR", "Lightricks/LTX-Video-0.9.7-dev")
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder", "transformer"],
- ),
-)
-
-pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
- os.environ.get("LTX_UPSCALER_DIR", "Lightricks/ltxv-spatial-upscaler-0.9.7"),
- vae=pipe.vae,
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to(device)
-pipe_upsample.to(device)
-assert isinstance(pipe.vae, AutoencoderKLLTXVideo)
-assert isinstance(pipe_upsample.vae, AutoencoderKLLTXVideo)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-pipe_upsample.set_progress_bar_config(disable=rank != 0)
-
-if args.cache or args.parallel_type is not None:
- if args.parallel_type is not None:
- assert args.attn == "native", "Context parallelism for LTXVideo requires " "--attn native"
- cachify(args, pipe)
-
-
-def round_to_nearest_resolution_acceptable_by_vae(height, width):
- height = height - (height % pipe.vae_spatial_compression_ratio)
- width = width - (width % pipe.vae_spatial_compression_ratio)
- return height, width
-
-
-prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-expected_height, expected_width = 512, 704
-downscale_factor = 2 / 3
-num_frames = 49
-
-# Part 1. Generate video at smaller resolution
-downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(
- expected_width * downscale_factor
-)
-downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
- downscaled_height, downscaled_width
-)
-
-
-stats = None
-
-
-def run_pipe(warmup: bool = False):
- global stats
-
- latents = pipe(
- conditions=None,
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=downscaled_width,
- height=downscaled_height,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 4,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="latent",
- ).frames
-
- stats = cache_dit.summary(pipe, details=True)
-
- # Part 2. Upscale generated video using latent upsampler with fewer inference steps
- # The available latent upsampler upscales the height/width by 2x
- upscaled_height, upscaled_width = (
- downscaled_height * 2,
- downscaled_width * 2,
- )
- upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
-
- if warmup:
- return None
-
- # Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=upscaled_width,
- height=upscaled_height,
- num_frames=num_frames,
- denoise_strength=0.4, # Effectively, 4 inference steps out of 10
- num_inference_steps=10,
- latents=upscaled_latents,
- decode_timestep=0.05,
- image_cond_noise_scale=0.025,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="pil",
- ).frames[0]
- return video
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- # Part 4. Downscale the video to the expected resolution
- video = [frame.resize((expected_width, expected_height)) for frame in video]
-
- time_cost = end - start
- save_path = f"ltx-video.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_ltx_video_tp.py b/examples/parallelism/run_ltx_video_tp.py
deleted file mode 100644
index f21c0813a..000000000
--- a/examples/parallelism/run_ltx_video_tp.py
+++ /dev/null
@@ -1,161 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import (
- AutoencoderKLLTXVideo,
- LTXConditionPipeline,
- LTXLatentUpsamplePipeline,
-)
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = LTXConditionPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("LTX_VIDEO_DIR", "Lightricks/LTX-Video-0.9.7-dev")
- ),
- torch_dtype=torch.bfloat16,
-)
-
-pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
- os.environ.get("LTX_UPSCALER_DIR", "Lightricks/ltxv-spatial-upscaler-0.9.7"),
- vae=pipe.vae,
- torch_dtype=torch.bfloat16,
-)
-
-pipe_upsample.to(device)
-assert isinstance(pipe.vae, AutoencoderKLLTXVideo)
-assert isinstance(pipe_upsample.vae, AutoencoderKLLTXVideo)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-pipe_upsample.set_progress_bar_config(disable=rank != 0)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-
-
-def round_to_nearest_resolution_acceptable_by_vae(height, width):
- height = height - (height % pipe.vae_spatial_compression_ratio)
- width = width - (width % pipe.vae_spatial_compression_ratio)
- return height, width
-
-
-prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-expected_height, expected_width = 512, 704
-downscale_factor = 2 / 3
-num_frames = 49
-
-# Part 1. Generate video at smaller resolution
-downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(
- expected_width * downscale_factor
-)
-downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
- downscaled_height, downscaled_width
-)
-
-
-stats = None
-
-
-def run_pipe(warmup: bool = False):
- global stats
-
- latents = pipe(
- conditions=None,
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=downscaled_width,
- height=downscaled_height,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 4,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="latent",
- ).frames
-
- stats = cache_dit.summary(pipe, details=True)
-
- # Part 2. Upscale generated video using latent upsampler with fewer inference steps
- # The available latent upsampler upscales the height/width by 2x
- upscaled_height, upscaled_width = (
- downscaled_height * 2,
- downscaled_width * 2,
- )
- upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
-
- if warmup:
- return None
-
- # Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=upscaled_width,
- height=upscaled_height,
- num_frames=num_frames,
- denoise_strength=0.4, # Effectively, 4 inference steps out of 10
- num_inference_steps=10,
- latents=upscaled_latents,
- decode_timestep=0.05,
- image_cond_noise_scale=0.025,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="pil",
- ).frames[0]
- return video
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- # Part 4. Downscale the video to the expected resolution
- video = [frame.resize((expected_width, expected_height)) for frame in video]
-
- time_cost = end - start
- save_path = f"ltx-video.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_mochi_tp.py b/examples/parallelism/run_mochi_tp.py
deleted file mode 100644
index 10d9c0d77..000000000
--- a/examples/parallelism/run_mochi_tp.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import MochiPipeline
-from diffusers.utils import export_to_video
-from utils import (
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = args.model_path if args.model_path is not None else "genmo/mochi-1-preview"
-model_id = args.model_path if args.model_path is not None else os.environ.get("MOCHI_DIR", model_id)
-
-pipe = MochiPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.enable_model_cpu_offload(device=device)
-pipe.vae.enable_tiling()
-
-torch.cuda.empty_cache()
-
-prompt = (
- "Close-up of a chameleon's eye, with its scaly skin "
- "changing color. Ultra high resolution 4k."
-)
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-
-
-def run_pipe():
- video = pipe(
- prompt,
- num_frames=49,
- num_inference_steps=64,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"mochi.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=10)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_pixart_sigma_cp.py b/examples/parallelism/run_pixart_sigma_cp.py
deleted file mode 100644
index bf9233c1a..000000000
--- a/examples/parallelism/run_pixart_sigma_cp.py
+++ /dev/null
@@ -1,95 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import Transformer2DModel, PixArtSigmaPipeline
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "PIXART_SIGMA_DIR",
- "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- )
-)
-transformer = Transformer2DModel.from_pretrained(
- model_id,
- subfolder="transformer",
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-pipe = PixArtSigmaPipeline.from_pretrained(
- model_id,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-pipe.to("cuda")
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-# Set default prompt
-prompt = "A small cactus with a happy face in the Sahara desert."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- num_inference_steps=50 if not warmup else 5,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
- return image
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
- time_cost = end - start
- save_path = f"pixart-sigma.{strify(args, stats)}.png"
-
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_pixart_tp.py b/examples/parallelism/run_pixart_tp.py
deleted file mode 100644
index e19e2e85d..000000000
--- a/examples/parallelism/run_pixart_tp.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- Transformer2DModel,
- PixArtSigmaPipeline,
- PixArtAlphaPipeline,
-)
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# Support both PixArt-Alpha and PixArt-Sigma models
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "PIXART_DIR",
- "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- # Alternative models:
- # "PixArt-alpha/PixArt-XL-2-1024-MS",
- # "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- )
-)
-
-# Determine pipeline type based on model
-if "Sigma" in model_id:
- pipeline_class = PixArtSigmaPipeline
-else:
- pipeline_class = PixArtAlphaPipeline
-
-transformer = Transformer2DModel.from_pretrained(
- model_id,
- subfolder="transformer",
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-
-pipe = pipeline_class.from_pretrained(
- model_id,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-pipe.enable_model_cpu_offload(device=device)
-pipe.set_progress_bar_config(disable=rank != 0)
-
-# Set default prompt
-prompt = "A small cactus with a happy face in the Sahara desert."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- num_inference_steps=50 if not warmup else 5,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
- return image
-
-
-# Warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
- time_cost = end - start
- model_name = "pixart-sigma" if "Sigma" in model_id else "pixart-alpha"
- save_path = f"{model_name}.{strify(args, stats)}.png"
-
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_cp.py b/examples/parallelism/run_qwen_image_cp.py
deleted file mode 100644
index f37a56c96..000000000
--- a/examples/parallelism/run_qwen_image_cp.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- AutoencoderKLQwenImage,
-)
-
-from utils import (
- GiB,
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-enable_quatization = args.quantize and GiB() < 96
-if GiB() < 96:
- if enable_quatization:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- exclude_layers=[
- "img_in",
- "txt_in",
- ],
- )
- pipe.text_encoder = cache_dit.quantize(
- pipe.text_encoder,
- quant_type=args.quantize_type,
- )
- pipe.to(device)
-else:
- pipe.to(device)
-
-if GiB() <= 48 or not enable_quatization:
- assert isinstance(pipe.vae, AutoencoderKLQwenImage)
- pipe.vae.enable_tiling()
-
-# Apply cache and context parallelism here
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-
-if GiB() < 96 and not enable_quatization:
- # NOTE: Enable cpu offload before enabling context parallelism will
- # raise shape error after first pipe call, so we enable it after.
- # It seems a bug of diffusers that cpu offload is not fully
- # compatible with context parallelism, visa versa.
- pipe.enable_model_cpu_offload(device=device)
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-height = 1024 if args.height is None else args.height
-width = 1024 if args.width is None else args.width
-
-
-def run_pipe(warmup: bool = False):
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- input_prompt = prompt + positive_magic["en"]
- output = pipe(
- prompt=input_prompt,
- negative_prompt=negative_prompt,
- width=height,
- height=width,
- num_inference_steps=((50 if args.steps is None else args.steps) if not warmup else 5),
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"qwen-image.{height}x{width}.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- if not args.perf:
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_lightning_cp.py b/examples/parallelism/run_qwen_image_lightning_cp.py
deleted file mode 100644
index 77a0249b9..000000000
--- a/examples/parallelism/run_qwen_image_lightning_cp.py
+++ /dev/null
@@ -1,202 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import math
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- AutoencoderKLQwenImage,
- FlowMatchEulerDiscreteScheduler,
-)
-from utils import (
- GiB,
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- scheduler=scheduler,
- torch_dtype=torch.bfloat16,
-)
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-pipe.load_lora_weights(
- os.environ.get(
- "QWEN_IMAGE_LIGHT_DIR",
- "lightx2v/Qwen-Image-Lightning",
- ),
- weight_name=(
- "Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors"
- if steps > 4
- else "Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors"
- ),
-)
-
-pipe.fuse_lora()
-pipe.unload_lora_weights()
-
-enable_quatization = args.quantize and GiB() < 96
-if GiB() < 96:
- if enable_quatization:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- exclude_layers=[
- "img_in",
- "txt_in",
- ],
- )
- pipe.text_encoder = cache_dit.quantize(
- pipe.text_encoder,
- quant_type=args.quantize_type,
- )
- pipe.to(device)
-else:
- pipe.to(device)
-
-if GiB() <= 48 and not enable_quatization:
- assert isinstance(pipe.vae, AutoencoderKLQwenImage)
- pipe.vae.enable_tiling()
-
-# Apply cache and context parallelism here
-if args.cache or args.parallel_type is not None:
- from cache_dit import DBCacheConfig
-
- cachify(
- args,
- pipe,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- )
- if args.cache
- else None
- ),
- )
-
-
-if GiB() < 96 and not enable_quatization:
- # NOTE: Enable cpu offload before enabling context parallelism will
- # raise shape error after first pipe call, so we enable it after.
- # It seems a bug of diffusers that cpu offload is not fully
- # compatible with context parallelism, visa versa.
- pipe.enable_model_cpu_offload(device=device)
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- output = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=steps if not warmup else steps,
- true_cfg_scale=1.0, # means no separate cfg
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"qwen-image-lightning.{steps}steps.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_lightning_nunchaku_cp.py b/examples/parallelism/run_qwen_image_lightning_nunchaku_cp.py
deleted file mode 100644
index fea60e50d..000000000
--- a/examples/parallelism/run_qwen_image_lightning_nunchaku_cp.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import math
-import torch.distributed as dist
-from diffusers import (
- QwenImagePipeline,
- FlowMatchEulerDiscreteScheduler,
- PipelineQuantizationConfig,
-)
-from nunchaku.models.transformers.transformer_qwenimage import (
- NunchakuQwenImageTransformer2DModel,
-)
-from utils import (
- get_args,
- strify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-nunchaku_qwen_image_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_DIR",
- "nunchaku-tech/nunchaku-qwen-image",
-)
-lightning_version = "v1.1" if steps == 8 else "v1.0"
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_dir}/svdq-int4_r32-qwen-image-lightning"
- f"{lightning_version}-{steps}steps.safetensors"
-)
-
-# Minimize VRAM required: 25GiB if use w4a16_text_encoder else 35GiB
-w4a16_text_encoder = args.quantize
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- transformer=transformer,
- scheduler=scheduler,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- DBCacheConfig,
- ParallelismConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- warmup_interval=2 if steps > 4 else 1,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- )
- if args.cache
- else None
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- parallelism_config=(
- ParallelismConfig(
- ulysses_size=(dist.get_world_size() if args.parallel_type == "ulysses" else None),
- ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
- )
- if args.parallel_type in ["ulysses", "ring"]
- else None
- ),
- )
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe():
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- output = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=steps,
- true_cfg_scale=1.0,
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"qwen-image-lightning.{steps}steps.nunchaku.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- if not args.perf:
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_lightning_tp.py b/examples/parallelism/run_qwen_image_lightning_tp.py
deleted file mode 100644
index 109779854..000000000
--- a/examples/parallelism/run_qwen_image_lightning_tp.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import math
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- AutoencoderKLQwenImage,
- FlowMatchEulerDiscreteScheduler,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-
-from utils import (
- GiB,
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-enable_quatization = args.quantize and GiB() < 96
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- scheduler=scheduler,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- (
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- )
- if enable_quatization
- else None
- ),
-)
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-pipe.load_lora_weights(
- os.environ.get(
- "QWEN_IMAGE_LIGHT_DIR",
- "lightx2v/Qwen-Image-Lightning",
- ),
- weight_name=(
- "Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors"
- if steps > 4
- else "Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors"
- ),
-)
-
-pipe.fuse_lora()
-pipe.unload_lora_weights()
-
-if GiB() <= 48 and not enable_quatization:
- assert isinstance(pipe.vae, AutoencoderKLQwenImage)
- pipe.vae.enable_tiling()
-
-# Apply cache and context parallelism here
-if args.cache or args.parallel_type is not None:
- from cache_dit import DBCacheConfig
-
- cachify(
- args,
- pipe,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- )
- if args.cache
- else None
- ),
- )
-
-
-# Minimum 40GiB is required for tensor parallelism = 2
-if GiB() < 48 and not enable_quatization:
- if not args.parallel_type == "tp":
- # NOTE: Seems CPU offload is not compatible with tensor
- # parallelism (via DTensor).
- pipe.enable_model_cpu_offload(device=device)
- else:
- pipe.to(device)
-else:
- pipe.to(device)
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- output = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=steps if not warmup else steps,
- true_cfg_scale=1.0, # means no separate cfg
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-
-if rank == 0:
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"qwen-image-lightning.{steps}steps.{strify(args, stats)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_nunchaku_cp.py b/examples/parallelism/run_qwen_image_nunchaku_cp.py
deleted file mode 100644
index e83fc7439..000000000
--- a/examples/parallelism/run_qwen_image_nunchaku_cp.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import torch.distributed as dist
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers import QwenImagePipeline
-from nunchaku.models.transformers.transformer_qwenimage import (
- NunchakuQwenImageTransformer2DModel,
-)
-from utils import (
- get_args,
- strify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-nunchaku_qwen_image_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_DIR",
- "nunchaku-tech/nunchaku-qwen-image",
-)
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_dir}/svdq-int4_r32-qwen-image.safetensors"
-)
-
-# Minimize VRAM required: 20GiB if use w4a16_text_encoder else 30GiB
-w4a16_text_encoder = args.quantize
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- DBCacheConfig,
- ParallelismConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- )
- if args.cache
- else None
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- parallelism_config=(
- ParallelismConfig(
- ulysses_size=(dist.get_world_size() if args.parallel_type == "ulysses" else None),
- ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
- )
- if args.parallel_type in ["ulysses", "ring"]
- else None
- ),
- )
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- output = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=((50 if args.steps is None else args.steps) if not warmup else 5),
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"qwen-image.nunchaku.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- if not args.perf:
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_qwen_image_tp.py b/examples/parallelism/run_qwen_image_tp.py
deleted file mode 100644
index a9e7cd7e6..000000000
--- a/examples/parallelism/run_qwen_image_tp.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- AutoencoderKLQwenImage,
-)
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-pipe: QwenImagePipeline = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-# Apply cache and tensor parallelism here
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-
-# NOTE: Please handle to(device) after cachify(namely, after applying
-# tensor parallelism), in order to reduce the memory usage at the beginning.
-enable_quatization = args.quantize and GiB() < 96
-if GiB() < 96:
- if enable_quatization:
- # Only quantize text encoder module to fit in GPUs with
- # 48GiB memory for better performance. the required memory
- # for transformer per GPU is reduced significantly after
- # tensor parallelism.
- pipe.text_encoder = cache_dit.quantize(
- pipe.text_encoder,
- quant_type=args.quantize_type,
- )
-else:
- pipe.to(device)
-
-
-# Minimum 40GiB is required for tensor parallelism = 2
-if GiB() < 48 and not enable_quatization:
- if not args.parallel_type == "tp":
- # NOTE: Seems CPU offload is not compatible with tensor
- # parallelism (via DTensor).
- pipe.enable_model_cpu_offload(device=device)
- else:
- pipe.to(device)
-else:
- if not args.parallel_type == "tp":
- pipe.enable_model_cpu_offload(device=device)
- else:
- pipe.to(device)
-
-
-if GiB() <= 48 and not enable_quatization:
- assert isinstance(pipe.vae, AutoencoderKLQwenImage)
- pipe.vae.enable_tiling()
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- output = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=((50 if args.steps is None else args.steps) if not warmup else 5),
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(0),
- output_type="latent" if args.perf else "pil",
- )
- image = output.images[0] if not args.perf else None
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
- time_cost = end - start
- save_path = f"qwen-image.{strify(args, pipe)}.png"
- print(f"Time cost: {time_cost:.2f}s")
- if not args.perf:
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_skyreels_v2_cp.py b/examples/parallelism/run_skyreels_v2_cp.py
deleted file mode 100644
index 8ddae6df6..000000000
--- a/examples/parallelism/run_skyreels_v2_cp.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import AutoModel, SkyReelsV2Pipeline, UniPCMultistepScheduler
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("SKYREELS_V2_DIR", "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers")
-)
-
-vae = AutoModel.from_pretrained(
- model_id,
- subfolder="vae",
- torch_dtype=torch.float32,
-).to("cuda")
-
-pipe = SkyReelsV2Pipeline.from_pretrained(
- model_id,
- vae=vae,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer", "text_encoder"],
- )
- if args.quantize
- else None
- ),
-).to("cuda")
-
-flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
-
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(pipe: SkyReelsV2Pipeline):
- video = pipe(
- prompt=prompt,
- num_inference_steps=50 if args.steps is None else args.steps,
- height=720, # 720 for 720P
- width=1280, # 1280 for 720P
- num_frames=21,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"skyreels_v2.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8, quality=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_skyreels_v2_tp.py b/examples/parallelism/run_skyreels_v2_tp.py
deleted file mode 100644
index 3e3cd5ac4..000000000
--- a/examples/parallelism/run_skyreels_v2_tp.py
+++ /dev/null
@@ -1,101 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import AutoModel, SkyReelsV2Pipeline, UniPCMultistepScheduler
-from diffusers.utils import export_to_video
-from utils import (
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("SKYREELS_V2_DIR", "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers")
-)
-
-vae = AutoModel.from_pretrained(
- model_id,
- subfolder="vae",
- torch_dtype=torch.float32,
-).to("cuda")
-
-pipe = SkyReelsV2Pipeline.from_pretrained(
- model_id,
- vae=vae,
- torch_dtype=torch.bfloat16,
-)
-
-# Apply TP before moving to device
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-pipe.to(device)
-
-flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
-
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(pipe: SkyReelsV2Pipeline):
- video = pipe(
- prompt=prompt,
- num_inference_steps=50 if args.steps is None else args.steps,
- height=720, # 720 for 720P
- width=1280, # 1280 for 720P
- num_frames=21,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"skyreels_v2.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=8, quality=8)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_wan_2.2_i2v_cp.py b/examples/parallelism/run_wan_2.2_i2v_cp.py
deleted file mode 100644
index c3e129ba5..000000000
--- a/examples/parallelism/run_wan_2.2_i2v_cp.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- AutoencoderKLWan,
- WanTransformer3DModel,
- WanImageToVideoPipeline,
-)
-from diffusers.utils import export_to_video, load_image
-
-from utils import (
- GiB,
- get_args,
- strify,
- cachify,
- maybe_init_distributed,
- maybe_destroy_distributed,
- MemoryTracker,
-)
-import cache_dit
-import numpy as np
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_2_2_I2V_DIR",
- "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- )
-)
-
-pipe: WanImageToVideoPipeline = WanImageToVideoPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-
-
-if args.quantize:
- # default: float8_weight_only
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- )
- pipe.transformer_2 = cache_dit.quantize(
- pipe.transformer_2,
- quant_type=args.quantize_type,
- )
-
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- ForwardPattern,
- BlockAdapter,
- ParamsModifier,
- DBCacheConfig,
- )
-
- cachify(
- args,
- BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- params_modifiers=[
- # high-noise transformer only have 30% steps
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=4,
- max_cached_steps=8,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=2,
- max_cached_steps=20,
- ),
- ),
- ],
- has_separate_cfg=True,
- ),
- )
-
-
-pipe.enable_model_cpu_offload(device=device)
-
-# Wan currently requires installing diffusers from source
-if GiB() <= 48:
- assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
- pipe.vae.enable_tiling()
-
-
-image = load_image(
- "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
-)
-
-max_area = 480 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-
-prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
-if args.prompt is not None:
- prompt = args.prompt
-negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- num_inference_steps = 50 if not warmup else 5
- if args.steps is not None:
- num_inference_steps = args.steps
- video = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=49, # pipe.vae_scale_factor_temporal=4
- guidance_scale=3.5,
- num_inference_steps=num_inference_steps,
- generator=torch.Generator(device="cpu").manual_seed(0),
- ).frames[0]
-
- return video
-
-
-if args.compile or args.quantize:
- assert isinstance(pipe.transformer, WanTransformer3DModel)
- assert isinstance(pipe.transformer_2, WanTransformer3DModel)
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks()
- pipe.transformer_2.compile_repeated_blocks()
-
-
-# warmup
-run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if rank == 0:
-
- if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
- cache_dit.summary(pipe, details=True)
-
- time_cost = end - start
- save_path = f"wan2.2-i2v.cp.frame{len(video)}.{height}x{width}.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(video, save_path, fps=16)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_wan_cp.py b/examples/parallelism/run_wan_cp.py
deleted file mode 100644
index 66df4cd8b..000000000
--- a/examples/parallelism/run_wan_cp.py
+++ /dev/null
@@ -1,151 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import WanPipeline, WanTransformer3DModel
-from diffusers.utils import export_to_video
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_2_2_DIR",
- "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
- )
-)
-pipe = WanPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- ForwardPattern,
- BlockAdapter,
- ParamsModifier,
- DBCacheConfig,
- )
-
- if "Wan2.1" in model_id:
- cachify(args, pipe)
- else:
- # Wan 2.2 only
- cachify(
- args,
- BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- params_modifiers=[
- # high-noise transformer only have 30% steps
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=4,
- max_cached_steps=8,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=2,
- max_cached_steps=20,
- ),
- ),
- ],
- has_separate_cfg=True,
- ),
- )
-
-assert isinstance(pipe.transformer, WanTransformer3DModel)
-# Enable memory savings
-if GiB() < 40:
- pipe.enable_model_cpu_offload(device=device)
-else:
- pipe.to(device)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- prompt = "A cat walks on the grass, realistic"
- if args.prompt is not None:
- prompt = args.prompt
- negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
- if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
- seed = 1234
- generator = torch.Generator(device="cpu").manual_seed(seed)
-
- num_inference_steps = 30 if not warmup else 4
- output = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=480,
- width=832,
- num_frames=49,
- guidance_scale=5.0,
- generator=generator,
- num_inference_steps=num_inference_steps,
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"wan.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- export_to_video(video, save_path, fps=16)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_wan_tp.py b/examples/parallelism/run_wan_tp.py
deleted file mode 100644
index 5d75d7daa..000000000
--- a/examples/parallelism/run_wan_tp.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from diffusers import WanPipeline, WanTransformer3DModel
-from diffusers.utils import export_to_video
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_2_2_DIR",
- "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
- )
-)
-pipe = WanPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache or args.parallel_type is not None:
- from cache_dit import (
- ForwardPattern,
- BlockAdapter,
- ParamsModifier,
- DBCacheConfig,
- )
-
- if "Wan2.1" in model_id:
- cachify(args, pipe)
- else:
- # Wan 2.2 only
- cachify(
- args,
- BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- params_modifiers=[
- # high-noise transformer only have 30% steps
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=4,
- max_cached_steps=8,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=2,
- max_cached_steps=20,
- ),
- ),
- ],
- has_separate_cfg=True,
- ),
- )
-
-# Enable memory savings
-if GiB() < 40:
- pipe.enable_model_cpu_offload(device=device)
-else:
- pipe.to(device)
-
-assert isinstance(pipe.transformer, WanTransformer3DModel)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- prompt = "A cat walks on the grass, realistic"
- if args.prompt is not None:
- prompt = args.prompt
- negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
- if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
- seed = 1234
- generator = torch.Generator(device="cpu").manual_seed(seed)
-
- num_inference_steps = 30 if not warmup else 4
- output = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=480,
- width=832,
- num_frames=49,
- guidance_scale=5.0,
- generator=generator,
- num_inference_steps=num_inference_steps,
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
- cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"wan.{strify(args, pipe)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- export_to_video(video, save_path, fps=16)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_wan_vace_cp.py b/examples/parallelism/run_wan_vace_cp.py
deleted file mode 100644
index 4b0d00f98..000000000
--- a/examples/parallelism/run_wan_vace_cp.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import PIL.Image
-from diffusers import AutoencoderKLWan, WanVACEPipeline
-from diffusers.schedulers.scheduling_unipc_multistep import (
- UniPCMultistepScheduler,
-)
-from diffusers.utils import export_to_video, load_image
-
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-
-def prepare_video_and_mask(
- first_img: PIL.Image.Image,
- last_img: PIL.Image.Image,
- height: int,
- width: int,
- num_frames: int,
-):
- first_img = first_img.resize((width, height))
- last_img = last_img.resize((width, height))
- frames = []
- frames.append(first_img)
- # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
- # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
- # match the original code.
- frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
- frames.append(last_img)
- mask_black = PIL.Image.new("L", (width, height), 0)
- mask_white = PIL.Image.new("L", (width, height), 255)
- mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
- return frames, mask
-
-
-model_id = args.model_path if args.model_path is not None else "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
-model_id = (
- args.model_path if args.model_path is not None else os.environ.get("WAN_VACE_DIR", model_id)
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
-flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-# Enable memory savings
-if GiB() < 40:
- pipe.enable_model_cpu_offload(device=device)
-else:
- pipe.to(device)
-
-assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-
-prompt = (
- "CG animation style, a small blue bird takes off from the ground, "
- "flapping its wings. The bird's feathers are delicate, with a unique "
- "pattern on its chest. The background shows a blue sky with white "
- "clouds under bright sunshine. The camera follows the bird upward, "
- "capturing its flight and the vastness of the sky from a close-up, "
- "low-angle perspective."
-)
-if args.prompt is not None:
- prompt = args.prompt
-
-negative_prompt = (
- "Bright tones, overexposed, static, blurred details, subtitles, "
- "style, works, paintings, images, static, overall gray, worst "
- "quality, low quality, JPEG compression residue, ugly, incomplete, "
- "extra fingers, poorly drawn hands, poorly drawn faces, deformed, "
- "disfigured, misshapen limbs, fused fingers, still picture, messy "
- "background, three legs, many people in the background, walking "
- "backwards"
-)
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-first_frame = load_image("../data/flf2v_input_first_frame.png")
-last_frame = load_image("../data/flf2v_input_last_frame.png")
-
-height = 512
-width = 512
-num_frames = 81
-video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- video=video,
- mask=mask,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 5,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(42),
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = run_pipe(warmup=False)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
-
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"wan-vace.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(output, save_path, fps=16)
-
-maybe_destroy_distributed()
diff --git a/examples/parallelism/run_wan_vace_tp.py b/examples/parallelism/run_wan_vace_tp.py
deleted file mode 100644
index a0f289877..000000000
--- a/examples/parallelism/run_wan_vace_tp.py
+++ /dev/null
@@ -1,158 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import PIL.Image
-from diffusers import AutoencoderKLWan, WanVACEPipeline
-from diffusers.schedulers.scheduling_unipc_multistep import (
- UniPCMultistepScheduler,
-)
-from diffusers.utils import export_to_video, load_image
-
-from utils import (
- GiB,
- cachify,
- get_args,
- maybe_destroy_distributed,
- maybe_init_distributed,
- strify,
- MemoryTracker,
-)
-
-import cache_dit
-
-args = get_args()
-print(args)
-
-rank, device = maybe_init_distributed(args)
-
-
-def prepare_video_and_mask(
- first_img: PIL.Image.Image,
- last_img: PIL.Image.Image,
- height: int,
- width: int,
- num_frames: int,
-):
- first_img = first_img.resize((width, height))
- last_img = last_img.resize((width, height))
- frames = []
- frames.append(first_img)
- # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
- # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
- # match the original code.
- frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
- frames.append(last_img)
- mask_black = PIL.Image.new("L", (width, height), 0)
- mask_white = PIL.Image.new("L", (width, height), 255)
- mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
- return frames, mask
-
-
-model_id = args.model_path if args.model_path is not None else "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
-model_id = (
- args.model_path if args.model_path is not None else os.environ.get("WAN_VACE_DIR", model_id)
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
-flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-if args.cache or args.parallel_type is not None:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-# Enable memory savings
-if GiB() < 40:
- pipe.enable_model_cpu_offload(device=device)
-else:
- pipe.to(device)
-
-assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-
-# Set default prompt and negative prompt
-prompt = (
- "CG animation style, a small blue bird takes off from the ground, "
- "flapping its wings. The bird's feathers are delicate, with a unique "
- "pattern on its chest. The background shows a blue sky with white "
- "clouds under bright sunshine. The camera follows the bird upward, "
- "capturing its flight and the vastness of the sky from a close-up, "
- "low-angle perspective."
-)
-if args.prompt is not None:
- prompt = args.prompt
-
-negative_prompt = (
- "Bright tones, overexposed, static, blurred details, subtitles, "
- "style, works, paintings, images, static, overall gray, worst "
- "quality, low quality, JPEG compression residue, ugly, incomplete, "
- "extra fingers, poorly drawn hands, poorly drawn faces, deformed, "
- "disfigured, misshapen limbs, fused fingers, still picture, messy "
- "background, three legs, many people in the background, walking "
- "backwards"
-)
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-first_frame = load_image("../data/flf2v_input_first_frame.png")
-last_frame = load_image("../data/flf2v_input_last_frame.png")
-
-height = 512
-width = 512
-num_frames = 81
-video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- video=video,
- mask=mask,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 5,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(42),
- ).frames[0]
- return output
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = run_pipe(warmup=False)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-if rank == 0:
-
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"wan-vace.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(output, save_path, fps=16)
-
-maybe_destroy_distributed()
diff --git a/examples/pipeline/run_allegro.py b/examples/pipeline/run_allegro.py
deleted file mode 100644
index b9dcaed82..000000000
--- a/examples/pipeline/run_allegro.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import AllegroPipeline
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("ALLEGRO_DIR", "rhymes-ai/Allegro")
-)
-
-pipe = AllegroPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.float16,
-)
-
-pipe.to("cuda")
-
-pipe.vae.enable_tiling()
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = (
- "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, "
- "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this "
- "location might be a popular spot for docking fishing boats."
-)
-
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- prompt,
- guidance_scale=7.5,
- max_sequence_length=512,
- num_inference_steps=100,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"allegro.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_amused.py b/examples/pipeline/run_amused.py
deleted file mode 100644
index c4d97d177..000000000
--- a/examples/pipeline/run_amused.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import AmusedPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = AmusedPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "AMUSED_DIR",
- "amused/amused-512",
- )
- ),
- variant="fp16",
- torch_dtype=torch.float16,
-)
-pipe = pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "a photo of an astronaut riding a horse on mars"
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=12,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"amused.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_auraflow.py b/examples/pipeline/run_auraflow.py
deleted file mode 100644
index 98bfd04f6..000000000
--- a/examples/pipeline/run_auraflow.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-
-from diffusers import AuraFlowPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("AURAFLOW_DIR", "fal/AuraFlow-v0.3")
-)
-
-pipe = AuraFlowPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.float16,
- use_safetensors=True,
-).to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = "rempage of the iguana character riding F1, fast and furious, cinematic movie poster"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- width=1536,
- height=768,
- num_inference_steps=50,
- generator=torch.Generator("cpu").manual_seed(1),
- guidance_scale=3.5,
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"auraflow.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_chroma.py b/examples/pipeline/run_chroma.py
deleted file mode 100644
index 3422a20a7..000000000
--- a/examples/pipeline/run_chroma.py
+++ /dev/null
@@ -1,71 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import ChromaPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = ChromaPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "CHROMA1_DIR",
- "lodestones/Chroma1-HD",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = [
- "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
-]
-if args.prompt is not None:
- prompt = [args.prompt]
-
-negative_prompt = [
- "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
-]
-if args.negative_prompt is not None:
- negative_prompt = [args.negative_prompt]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.Generator("cpu").manual_seed(433),
- num_inference_steps=40,
- guidance_scale=3.0,
- num_images_per_prompt=1,
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"chroma1-hd.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_chrono_edit.py b/examples/pipeline/run_chrono_edit.py
deleted file mode 100644
index 320ad1030..000000000
--- a/examples/pipeline/run_chrono_edit.py
+++ /dev/null
@@ -1,126 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import numpy as np
-from PIL import Image
-from diffusers import (
- AutoencoderKLWan,
- ChronoEditTransformer3DModel,
- ChronoEditPipeline,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import load_image
-from transformers import CLIPVisionModel
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = args.model_path if args.model_path is not None else "nvidia/ChronoEdit-14B-Diffusers"
-model_id = (
- args.model_path if args.model_path is not None else os.environ.get("CHRONO_EDIT_DIR", model_id)
-)
-
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-transformer = ChronoEditTransformer3DModel.from_pretrained(
- model_id, subfolder="transformer", torch_dtype=torch.bfloat16
-)
-
-enable_quantization = args.quantize and args.quantize_type == "bitsandbytes_4bit"
-
-pipe = ChronoEditPipeline.from_pretrained(
- model_id,
- vae=vae,
- image_encoder=image_encoder,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- # text_encoder: ~ 6GiB, transformer: ~ 8GiB, total: ~14GiB
- components_to_quantize=["text_encoder", "transformer"],
- )
- if enable_quantization
- else None
- ),
-)
-
-if args.cache:
- cachify(args, pipe)
-
-# Enable memory savings
-pipe.enable_model_cpu_offload()
-assert isinstance(pipe.vae, AutoencoderKLWan)
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-
-image = load_image("../data/chrono_edit_example.png")
-
-max_area = 720 * 1280
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-
-prompt = (
- "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
- "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
-)
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- image=image,
- prompt=prompt,
- height=height,
- width=width,
- num_frames=5,
- guidance_scale=5.0,
- enable_temporal_reasoning=False,
- num_temporal_reasoning_steps=0,
- num_inference_steps=((50 if not warmup else 1) if args.steps is None else args.steps),
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- output = Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8"))
- return output
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"chrono-edit.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-output.save(save_path)
diff --git a/examples/pipeline/run_cogvideox.py b/examples/pipeline/run_cogvideox.py
deleted file mode 100644
index 3173d8131..000000000
--- a/examples/pipeline/run_cogvideox.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.utils import export_to_video
-from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("COGVIDEOX_DIR", "THUDM/CogVideoX-2b")
-)
-
-pipe = CogVideoXPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-assert isinstance(pipe.vae, AutoencoderKLCogVideoX) # enable type check for IDE
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = (
- "A panda, dressed in a small, red jacket and a tiny hat, "
- "sits on a wooden stool in a serene bamboo forest. The "
- "panda's fluffy paws strum a miniature acoustic guitar, "
- "producing soft, melodic tunes. Nearby, a few other pandas "
- "gather, watching curiously and some clapping in rhythm. "
- "Sunlight filters through the tall bamboo, casting a gentle "
- "glow on the scene. The panda's face is expressive, showing "
- "concentration and joy as it plays. The background includes "
- "a small, flowing stream and vibrant green foliage, enhancing "
- "the peaceful and magical atmosphere of this unique musical "
- "performance."
-)
-if args.prompt is not None:
- prompt = args.prompt
-video = pipe(
- prompt=prompt,
- num_videos_per_prompt=1,
- num_inference_steps=50,
- num_frames=49,
- guidance_scale=6,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"cogvideox.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_cogvideox_1.5.py b/examples/pipeline/run_cogvideox_1.5.py
deleted file mode 100644
index be36246b6..000000000
--- a/examples/pipeline/run_cogvideox_1.5.py
+++ /dev/null
@@ -1,77 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.utils import export_to_video
-from diffusers import CogVideoXPipeline, AutoencoderKLCogVideoX
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("COGVIDEOX_1_5_DIR", "THUDM/CogVideoX1.5-5b")
-)
-
-pipe = CogVideoXPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- device_map="balanced",
-)
-
-if args.cache:
- cachify(args, pipe)
-
-assert isinstance(pipe.vae, AutoencoderKLCogVideoX) # enable type check for IDE
-pipe.vae.enable_slicing()
-pipe.vae.enable_tiling()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = (
- "A panda, dressed in a small, red jacket and a tiny hat, "
- "sits on a wooden stool in a serene bamboo forest. The "
- "panda's fluffy paws strum a miniature acoustic guitar, "
- "producing soft, melodic tunes. Nearby, a few other pandas "
- "gather, watching curiously and some clapping in rhythm. "
- "Sunlight filters through the tall bamboo, casting a gentle "
- "glow on the scene. The panda's face is expressive, showing "
- "concentration and joy as it plays. The background includes "
- "a small, flowing stream and vibrant green foliage, enhancing "
- "the peaceful and magical atmosphere of this unique musical "
- "performance."
-)
-if args.prompt is not None:
- prompt = args.prompt
-video = pipe(
- prompt=prompt,
- num_videos_per_prompt=1,
- num_inference_steps=50,
- num_frames=16,
- guidance_scale=6,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"cogvideox1.5.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_cogview3_plus.py b/examples/pipeline/run_cogview3_plus.py
deleted file mode 100644
index 8d07dd429..000000000
--- a/examples/pipeline/run_cogview3_plus.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import CogView3PlusPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-
-pipe = CogView3PlusPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "COGVIEW3_DIR",
- "THUDM/CogView3-Plus-3B",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- guidance_scale=7.0,
- num_inference_steps=50,
- width=1024,
- height=1024,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"cogview3_plus.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_cogview4.py b/examples/pipeline/run_cogview4.py
deleted file mode 100644
index 22e0d8f0e..000000000
--- a/examples/pipeline/run_cogview4.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import CogView4Pipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-
-pipe = CogView4Pipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "COGVIEW4_DIR",
- "THUDM/CogView4-6B",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe, enable_separate_cfg=True)
-
-prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background."
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- guidance_scale=3.5, # >1, do separate cfg
- num_inference_steps=50,
- width=1024,
- height=1024,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"cogview4.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_consisid.py b/examples/pipeline/run_consisid.py
deleted file mode 100644
index e0a1f1d24..000000000
--- a/examples/pipeline/run_consisid.py
+++ /dev/null
@@ -1,94 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import ConsisIDPipeline
-from diffusers.pipelines.consisid.consisid_utils import (
- prepare_face_models,
- process_face_embeddings_infer,
-)
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("CONSISID_DIR", "BestWishYsh/ConsisID-preview")
-)
-
-(
- face_helper_1,
- face_helper_2,
- face_clip_model,
- face_main_model,
- eva_transform_mean,
- eva_transform_std,
-) = prepare_face_models(model_id, device="cuda", dtype=torch.bfloat16)
-pipe = ConsisIDPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- device_map="balanced",
-)
-
-# ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body).
-prompt = "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel."
-if args.prompt is not None:
- prompt = args.prompt
-# image = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_input.png?download=true"
-image = "../data/consisid_input.png"
-
-id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
- face_helper_1,
- face_clip_model,
- face_helper_2,
- eva_transform_mean,
- eva_transform_std,
- face_main_model,
- "cuda",
- torch.bfloat16,
- image,
- is_align_face=True,
-)
-
-if args.cache:
- cachify(args, pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- image=image,
- prompt=prompt,
- num_inference_steps=50,
- guidance_scale=6.0,
- use_dynamic_cfg=False,
- id_vit_hidden=id_vit_hidden,
- id_cond=id_cond,
- kps_cond=face_kps,
- generator=torch.Generator("cpu").manual_seed(42),
-).frames[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"consisid.{strify(args, pipe)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_cosmos.py b/examples/pipeline/run_cosmos.py
deleted file mode 100644
index da6ad1e24..000000000
--- a/examples/pipeline/run_cosmos.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import Cosmos2VideoToWorldPipeline
-from diffusers.utils import export_to_video, load_image
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-# Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("COSMOS_DIR", "nvidia/Cosmos-Predict2-2B-Video2World")
-)
-
-pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-image = load_image(
- "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png"
-)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=negative_prompt,
- generator=torch.Generator().manual_seed(1),
-).frames[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"cosmos.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_dit_xl.py b/examples/pipeline/run_dit_xl.py
deleted file mode 100644
index 568b27c1b..000000000
--- a/examples/pipeline/run_dit_xl.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-
-from diffusers import DiTPipeline, DPMSolverMultistepScheduler
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-
-pipe = DiTPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "DIT_XL_DIR",
- "facebook/DiT-XL-2-256",
- )
- ),
- torch_dtype=torch.float16,
-)
-pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
-pipe = pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-words = ["white shark"]
-
-class_ids = pipe.get_label_ids(words)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- class_labels=class_ids,
- num_inference_steps=25,
- generator=torch.Generator("cpu").manual_seed(33),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"dit-xl.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_easyanimate.py b/examples/pipeline/run_easyanimate.py
deleted file mode 100644
index 820505081..000000000
--- a/examples/pipeline/run_easyanimate.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import EasyAnimatePipeline
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("EASY_ANIMATE_DIR", "alibaba-pai/EasyAnimateV5.1-7b-zh")
-)
-
-
-pipe = EasyAnimatePipeline.from_pretrained(
- model_id,
- torch_dtype=torch.float16,
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "A cat walks on the grass, realistic style."
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "bad detailed"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- num_frames=49,
- num_inference_steps=30,
- generator=torch.Generator("cuda").manual_seed(0),
-).frames[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"easyanimate.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_flux.py b/examples/pipeline/run_flux.py
deleted file mode 100644
index 0135fa70d..000000000
--- a/examples/pipeline/run_flux.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxPipeline, FluxTransformer2DModel
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-if args.cache:
- cachify(args, pipe)
-
-assert isinstance(pipe.transformer, FluxTransformer2DModel)
-if args.quantize:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- exclude_layers=[
- "embedder",
- "embed",
- ],
- )
- pipe.text_encoder_2 = cache_dit.quantize(
- pipe.text_encoder_2,
- quant_type=args.quantize_type,
- )
- print(f"Applied quantization: {args.quantize_type} to Transformer and Text Encoder 2.")
-
-pipe.to("cuda")
-
-if args.attn is not None:
- if hasattr(pipe.transformer, "set_attention_backend"):
- pipe.transformer.set_attention_backend(args.attn)
- print(f"Set attention backend to {args.attn}")
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
- pipe.text_encoder = torch.compile(pipe.text_encoder)
- pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2)
- pipe.vae = torch.compile(pipe.vae)
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe():
- image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- num_inference_steps=28 if args.steps is None else args.steps,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-# warmup
-_ = run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_flux_fill.py b/examples/pipeline/run_flux_fill.py
deleted file mode 100644
index 9fb8d9003..000000000
--- a/examples/pipeline/run_flux_fill.py
+++ /dev/null
@@ -1,82 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxFillPipeline
-from diffusers.utils import load_image
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxFillPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_FILL_DIR",
- "black-forest-labs/FLUX.1-Fill-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = "a white paper cup"
-if args.prompt is not None:
- prompt = args.prompt
-
-if args.compile:
- from diffusers import FluxTransformer2DModel
-
- cache_dit.set_compile_configs()
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- image = pipe(
- prompt=prompt,
- image=load_image("../data/cup.png"),
- mask_image=load_image("../data/cup_mask.png"),
- guidance_scale=30,
- num_inference_steps=28,
- max_sequence_length=512,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt=prompt,
- image=load_image("../data/cup.png"),
- mask_image=load_image("../data/cup_mask.png"),
- guidance_scale=30,
- num_inference_steps=28,
- max_sequence_length=512,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux-fill.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_flux_kontext.py b/examples/pipeline/run_flux_kontext.py
deleted file mode 100644
index c3bba8df0..000000000
--- a/examples/pipeline/run_flux_kontext.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxKontextPipeline
-from diffusers.utils import load_image
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = FluxKontextPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_KONTEXT_DIR",
- "black-forest-labs/FLUX.1-Kontext-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = "Add a hat to the cat"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-
-image = pipe(
- image=load_image("../data/cat.png"),
- prompt=prompt,
- guidance_scale=2.5,
- num_inference_steps=28,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux-kontext.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_hidream.py b/examples/pipeline/run_hidream.py
deleted file mode 100644
index bce38e4ae..000000000
--- a/examples/pipeline/run_hidream.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import HiDreamImagePipeline
-from transformers import AutoTokenizer, LlamaForCausalLM
-from diffusers.quantizers import PipelineQuantizationConfig
-from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-tokenizer_4 = AutoTokenizer.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "LLAMA_DIR",
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
- )
- ),
-)
-
-text_encoder_4 = LlamaForCausalLM.from_pretrained(
- os.environ.get(
- "LLAMA_DIR",
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
- ),
- output_hidden_states=True,
- output_attentions=True,
- torch_dtype=torch.bfloat16,
- quantization_config=TransformersBitsAndBytesConfig(
- load_in_4bit=True,
- ),
-)
-
-pipe = HiDreamImagePipeline.from_pretrained(
- os.environ.get(
- "HIDREAM_DIR",
- "HiDream-ai/HiDream-I1-Full",
- ),
- tokenizer_4=tokenizer_4,
- text_encoder_4=text_encoder_4,
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer"],
- ),
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = 'A cute girl holding a sign that says "Hi-Dreams.ai".'
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- height=1024 if args.height is None else args.height,
- width=1024 if args.width is None else args.width,
- guidance_scale=5.0,
- num_inference_steps=50,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"hidream.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_hunyuan_dit.py b/examples/pipeline/run_hunyuan_dit.py
deleted file mode 100644
index fe7e55a35..000000000
--- a/examples/pipeline/run_hunyuan_dit.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import HunyuanDiTPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("HUNYUAN_DIT_DIR", "Tencent-Hunyuan/HunyuanDiT-Diffusers")
-)
-
-pipe = HunyuanDiTPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.float16,
-)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# You may also use English prompt as HunyuanDiT supports both English and Chinese
-# Set default prompt
-prompt = "一个宇航员在骑马"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=50,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"hunyuan_dit.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_hunyuan_image_2.1.py b/examples/pipeline/run_hunyuan_image_2.1.py
deleted file mode 100644
index 312aa7ecd..000000000
--- a/examples/pipeline/run_hunyuan_image_2.1.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- HunyuanImagePipeline,
- HunyuanImageTransformer2DModel,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-enable_quatization = args.quantize and GiB() < 96
-# For now you need to install the latest diffusers as below:
-# pip install git+https://github.com/huggingface/diffusers@main
-pipe: HunyuanImagePipeline = HunyuanImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "HUNYUAN_IMAGE_DIR",
- "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
- )
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"], # ~4GiB
- )
- if enable_quatization
- else None
- ),
-)
-
-if GiB() < 96:
- if enable_quatization:
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type, # float8_weight_only
- )
- pipe.to("cuda")
-else:
- pipe.to("cuda")
-
-
-if args.cache:
- cachify(args, pipe)
-
-torch.cuda.empty_cache()
-assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
-
-if GiB() < 96 and not enable_quatization:
- pipe.enable_model_cpu_offload()
-
-
-def run_pipe(warmup: bool = False):
- prompt = "A cute, cartoon-style anthropomorphic penguin plush toy with fluffy fur, "
- if args.prompt is not None:
- prompt = args.prompt
- "standing in a painting studio, wearing a red knitted scarf and a red beret with "
- "the word “Tencent” on it, holding a paintbrush with a focused expression as it "
- "paints an oil painting of the Mona Lisa, rendered in a photorealistic photographic style."
- image = pipe(
- prompt,
- num_inference_steps=50 if not warmup else 5,
- height=2048,
- width=2048,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"hunyuan_image_2.1.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_hunyuan_video.py b/examples/pipeline/run_hunyuan_video.py
deleted file mode 100644
index c57bf7f45..000000000
--- a/examples/pipeline/run_hunyuan_video.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.utils import export_to_video
-from diffusers import HunyuanVideoPipeline, AutoencoderKLHunyuanVideo
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("HUNYUAN_VIDEO_DIR", "hunyuanvideo-community/HunyuanVideo")
-)
-pipe = HunyuanVideoPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-
-if args.cache:
- cachify(args, pipe)
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
- # Single GPU: use CPU offload for memory efficiency
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-assert isinstance(pipe.vae, AutoencoderKLHunyuanVideo)
-
-# Enable memory savings
-if GiB() <= 48:
- pipe.vae.enable_tiling(
- # Make it runnable on GPUs with 48GB memory
- tile_sample_min_height=128,
- tile_sample_stride_height=96,
- tile_sample_min_width=128,
- tile_sample_stride_width=96,
- tile_sample_min_num_frames=32,
- tile_sample_stride_num_frames=24,
- )
-else:
- pipe.vae.enable_tiling()
-
-prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = pipe(
- prompt=prompt,
- num_frames=18,
- num_inference_steps=50,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"hunyuan_video.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(output, save_path, fps=9)
diff --git a/examples/pipeline/run_kandinsky5_t2v.py b/examples/pipeline/run_kandinsky5_t2v.py
deleted file mode 100644
index aaae49d82..000000000
--- a/examples/pipeline/run_kandinsky5_t2v.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import Kandinsky5T2VPipeline, AutoencoderKLHunyuanVideo
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-# Available models:
-# ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers
-# ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers"
-)
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("KANDINSKY5_T2V_DIR", model_id)
-)
-pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-pipe = pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe, enable_separate_cfg=not ("nocfg" in model_id))
-
-prompt = "A cat and a dog baking a cake together in a kitchen."
-
-if args.prompt is not None:
-
- prompt = args.prompt
-negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-assert isinstance(pipe.vae, AutoencoderKLHunyuanVideo)
-
-pipe.vae.enable_tiling()
-
-
-def run_pipe():
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=512,
- width=768,
- num_frames=121,
- num_inference_steps=50,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- return video
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"kandinsky5.{strify(args, pipe)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=24, quality=9)
diff --git a/examples/pipeline/run_longcat_video.py b/examples/pipeline/run_longcat_video.py
deleted file mode 100644
index 18349c31b..000000000
--- a/examples/pipeline/run_longcat_video.py
+++ /dev/null
@@ -1,273 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import datetime
-import numpy as np
-
-import torch
-import torch.distributed as dist
-
-from transformers import AutoTokenizer, UMT5EncoderModel
-from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig
-from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
-from torchvision.io import write_video
-
-sys.path.append(os.environ.get("LONGCAT_VIDEO_PKG_DIR", ""))
-from longcat_video.pipeline_longcat_video import LongCatVideoPipeline
-from longcat_video.modules.scheduling_flow_match_euler_discrete import (
- FlowMatchEulerDiscreteScheduler,
-)
-from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
-from longcat_video.modules.longcat_video_dit import (
- LongCatVideoTransformer3DModel,
-)
-from longcat_video.context_parallel import context_parallel_util
-from longcat_video.context_parallel.context_parallel_util import (
- init_context_parallel,
-)
-
-from utils import get_args, strify, GiB, MemoryTracker
-import cache_dit
-
-# Example usage:
-# export LONGCAT_VIDEO_PKG_DIR=/path/to/codes/of/LongCat-Video
-# export LONGCAT_VIDEO_DIR=/path/to/models/of/LongCat-Video
-# Add `--quantize` to enable loading models with bitsandbytes
-# for lower memory usage (e.g, GPU w/ < 48GB memory)
-# torchrun --nproc_per_node=4 run_longcat_video.py --quantize --compile
-# torchrun --nproc_per_node=4 run_longcat_video.py --quantize --compile --cache
-# torchrun --nproc_per_node=4 run_longcat_video.py --quantize --compile --cache --Fn 1
-
-
-def torch_gc():
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
-
-
-def generate(args):
- print(args)
- # case setup
- prompt = "In a realistic photography style, a white boy around seven or eight years old sits on a park bench, wearing a light blue T-shirt, denim shorts, and white sneakers. He holds an ice cream cone with vanilla and chocolate flavors, and beside him is a medium-sized golden Labrador. Smiling, the boy offers the ice cream to the dog, who eagerly licks it with its tongue. The sun is shining brightly, and the background features a green lawn and several tall trees, creating a warm and loving scene."
- if args.prompt is not None:
- prompt = args.prompt
- negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
- if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
- # load parsed args
- checkpoint_dir = args.checkpoint_dir
- context_parallel_size = args.context_parallel_size
-
- # prepare distributed environment
- rank = int(os.environ["RANK"])
- num_gpus = torch.cuda.device_count()
- local_rank = rank % num_gpus
- torch.cuda.set_device(local_rank)
- dist.init_process_group(
- backend="nccl",
- timeout=datetime.timedelta(seconds=3600 * 24),
- device_id=local_rank,
- )
- global_rank = dist.get_rank()
- num_processes = dist.get_world_size()
-
- if context_parallel_size is None:
- context_parallel_size = num_processes
-
- # initialize context parallel before loading models
- init_context_parallel(
- context_parallel_size=context_parallel_size,
- global_rank=global_rank,
- world_size=num_processes,
- )
- cp_size = context_parallel_util.get_cp_size()
- cp_split_hw = context_parallel_util.get_optimal_split(cp_size)
-
- tokenizer = AutoTokenizer.from_pretrained(
- checkpoint_dir, subfolder="tokenizer", torch_dtype=torch.bfloat16
- )
-
- # Load text encoder with bnb 4bits quantization if specified
- text_encoder = UMT5EncoderModel.from_pretrained(
- checkpoint_dir,
- subfolder="text_encoder",
- torch_dtype=torch.bfloat16,
- quantization_config=(
- TransformersBitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=torch.bfloat16,
- )
- if args.quantize
- else None
- ),
- )
-
- vae = AutoencoderKLWan.from_pretrained(
- checkpoint_dir, subfolder="vae", torch_dtype=torch.bfloat16
- )
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
- checkpoint_dir, subfolder="scheduler", torch_dtype=torch.bfloat16
- )
-
- # Load DiT with bnb 4bits/8bits quantization if specified
- use_4bits_transformer = False
- if args.quantize:
- if context_parallel_size >= 2 and GiB() >= 40:
- # Activation will be split across multiple GPUs in CP,
- # so we only apply FP8 Weight Only Quantization here
- # to keep higher accuracy.
- dit = LongCatVideoTransformer3DModel.from_pretrained(
- checkpoint_dir,
- subfolder="dit",
- cp_split_hw=cp_split_hw,
- torch_dtype=torch.bfloat16,
- )
- dit = cache_dit.quantize(dit, quant_type=args.quantize_type)
- else:
- dit = LongCatVideoTransformer3DModel.from_pretrained(
- checkpoint_dir,
- subfolder="dit",
- cp_split_hw=cp_split_hw,
- torch_dtype=torch.bfloat16,
- quantization_config=DiffusersBitsAndBytesConfig(
- load_in_4bit=True,
- bnb_4bit_quant_type="nf4",
- bnb_4bit_compute_dtype=torch.bfloat16,
- ),
- )
- use_4bits_transformer = True
-
- pipe = LongCatVideoPipeline(
- tokenizer=tokenizer,
- text_encoder=text_encoder,
- vae=vae,
- scheduler=scheduler,
- dit=dit,
- )
-
- if GiB() <= 48:
- pipe.vae.enable_tiling()
-
- pipe.to(f"cuda:{local_rank}")
-
- if args.cache:
- from cache_dit import (
- BlockAdapter,
- ForwardPattern,
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- assert isinstance(pipe.dit, LongCatVideoTransformer3DModel)
-
- # Using Cache-DiT to cache the DiT transformer blocks of LongCat-Video
- cache_dit.enable_cache(
- BlockAdapter(
- transformer=pipe.dit,
- blocks=pipe.dit.blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=False,
- has_separate_cfg=False,
- ),
- cache_config=DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- # NOTE: num_inference_steps is required for Transformer-only interface
- num_inference_steps=50 if args.steps is None else args.steps,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- )
-
- global_seed = 42
- seed = global_seed + global_rank
-
- def run_t2v(warmup: bool = False):
- # t2v (480p)
- print(f"Generating video, warmup={warmup} ...")
- output = pipe.generate_t2v(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=480 if args.height is None else args.height,
- width=832 if args.width is None else args.width,
- num_frames=93 if args.frames is None else args.frames,
- num_inference_steps=((50 if args.steps is None else args.steps) if not warmup else 4),
- guidance_scale=4.0,
- generator=torch.Generator(device=local_rank).manual_seed(seed),
- )[0]
- return output
-
- if args.compile:
- cache_dit.set_compile_configs()
- pipe.dit = torch.compile(pipe.dit)
-
- # warmup
- _ = run_t2v(warmup=True)
- torch_gc()
-
- memory_tracker = MemoryTracker() if args.track_memory else None
- if memory_tracker:
- memory_tracker.__enter__()
-
- start = time.time()
- output = run_t2v()
- end = time.time()
-
- if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
- if local_rank == 0:
- cache_dit.summary(pipe.dit)
-
- time_cost = end - start
- save_path = f"longcat-video.{strify(args, pipe.dit)}"
- if args.quantize:
- save_path += ".bnb4bits" if use_4bits_transformer else ".fp8wo"
- save_path += ".mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
-
- output_tensor = torch.from_numpy(np.array(output))
- output_tensor = (output_tensor * 255).clamp(0, 255).to(torch.uint8)
- write_video(
- save_path,
- output_tensor,
- fps=15,
- video_codec="libx264",
- options={"crf": f"{18}"},
- )
- del output
- torch_gc()
-
- if dist.is_initialized():
- dist.destroy_process_group()
-
-
-def _parse_args():
- DEAULT_CHECKPOINT_DIR = os.environ.get("LONGCAT_VIDEO_DIR", None)
- parser = get_args(parse=False)
- parser.add_argument("--frames", type=int, default=None)
- parser.add_argument("--context_parallel_size", type=int, default=None)
- parser.add_argument("--checkpoint_dir", type=str, default=DEAULT_CHECKPOINT_DIR)
- args = parser.parse_args()
-
- return args
-
-
-if __name__ == "__main__":
- args = _parse_args()
- generate(args)
diff --git a/examples/pipeline/run_ltx_video.py b/examples/pipeline/run_ltx_video.py
deleted file mode 100644
index b5275283f..000000000
--- a/examples/pipeline/run_ltx_video.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import (
- LTXConditionPipeline,
- LTXLatentUpsamplePipeline,
- AutoencoderKLLTXVideo,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = LTXConditionPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("LTX_VIDEO_DIR", "Lightricks/LTX-Video-0.9.7-dev")
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer", "text_encoder"],
- ),
-)
-
-pipe_upsample = LTXLatentUpsamplePipeline.from_pretrained(
- os.environ.get("LTX_UPSCALER_DIR", "Lightricks/ltxv-spatial-upscaler-0.9.7"),
- vae=pipe.vae,
- torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
-pipe_upsample.to("cuda")
-assert isinstance(pipe.vae, AutoencoderKLLTXVideo)
-pipe.vae.enable_tiling()
-
-if args.cache:
- cachify(args, pipe)
-
-
-def round_to_nearest_resolution_acceptable_by_vae(height, width):
- height = height - (height % pipe.vae_spatial_compression_ratio)
- width = width - (width % pipe.vae_spatial_compression_ratio)
- return height, width
-
-
-prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
-if args.prompt is not None:
- prompt = args.prompt
-
-negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-expected_height, expected_width = 512, 704
-downscale_factor = 2 / 3
-num_frames = 121
-
-# Part 1. Generate video at smaller resolution
-downscaled_height, downscaled_width = int(expected_height * downscale_factor), int(
- expected_width * downscale_factor
-)
-downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
- downscaled_height, downscaled_width
-)
-
-
-def run_pipe(warmup: bool = False):
-
- latents = pipe(
- conditions=None,
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=downscaled_width,
- height=downscaled_height,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 4,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="latent",
- ).frames
-
- # Part 2. Upscale generated video using latent upsampler with fewer inference steps
- # The available latent upsampler upscales the height/width by 2x
- upscaled_height, upscaled_width = (
- downscaled_height * 2,
- downscaled_width * 2,
- )
- upscaled_latents = pipe_upsample(latents=latents, output_type="latent").frames
-
- if warmup:
- return None
-
- # Part 3. Denoise the upscaled video with few steps to improve texture (optional, but recommended)
- video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- width=upscaled_width,
- height=upscaled_height,
- num_frames=num_frames,
- denoise_strength=0.4, # Effectively, 4 inference steps out of 10
- num_inference_steps=10,
- latents=upscaled_latents,
- decode_timestep=0.05,
- image_cond_noise_scale=0.025,
- generator=torch.Generator("cpu").manual_seed(0),
- output_type="pil",
- ).frames[0]
- return video
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-stats = cache_dit.summary(pipe)
-
-# Part 4. Downscale the video to the expected resolution
-video = [frame.resize((expected_width, expected_height)) for frame in video]
-
-time_cost = end - start
-save_path = f"ltx-video.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8)
diff --git a/examples/pipeline/run_lumina2.py b/examples/pipeline/run_lumina2.py
deleted file mode 100644
index 9b3fc28f2..000000000
--- a/examples/pipeline/run_lumina2.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import Lumina2Transformer2DModel, Lumina2Pipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("LUMINA_DIR", "Alpha-VLLM/Lumina-Image-2.0")
-)
-
-ckpt_path = os.path.join(model_id, "consolidated.00-of-01.pth")
-transformer = Lumina2Transformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
-
-pipe = Lumina2Pipeline.from_pretrained(
- model_id, transformer=transformer, torch_dtype=torch.bfloat16
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = "a cute cat holding a sign that says hello 'Lumina2'"
-if args.prompt is not None:
- prompt = args.prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- height=1024,
- width=1024,
- num_inference_steps=30,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"lumina2.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_mochi.py b/examples/pipeline/run_mochi.py
deleted file mode 100644
index 45ffec0bd..000000000
--- a/examples/pipeline/run_mochi.py
+++ /dev/null
@@ -1,78 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import MochiPipeline
-from diffusers.utils import export_to_video
-from diffusers.quantizers import PipelineQuantizationConfig
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-pipe = MochiPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "MOCHI_DIR",
- "genmo/mochi-1-preview",
- )
- ),
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer", "text_encoder"],
- ),
-)
-
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-pipe.enable_vae_tiling()
-
-prompt = (
- "Close-up of a chameleon's eye, with its scaly skin "
- "changing color. Ultra high resolution 4k."
-)
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- prompt,
- num_frames=49,
- num_inference_steps=64,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"mochi.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=10)
diff --git a/examples/pipeline/run_omnigen_v1.py b/examples/pipeline/run_omnigen_v1.py
deleted file mode 100644
index a2b4df9fd..000000000
--- a/examples/pipeline/run_omnigen_v1.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import OmniGenPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("OMNIGEN_DIR", "Shitao/OmniGen-v1-diffusers")
-)
-
-pipe = OmniGenPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = "Realistic photo. A young woman sits on a sofa, holding a book and facing the camera. She wears delicate silver hoop earrings adorned with tiny, sparkling diamonds that catch the light, with her long chestnut hair cascading over her shoulders. Her eyes are focused and gentle, framed by long, dark lashes. She is dressed in a cozy cream sweater, which complements her warm, inviting smile. Behind her, there is a table with a cup of water in a sleek, minimalist blue mug. The background is a serene indoor setting with soft natural light filtering through a window, adorned with tasteful art and flowers, creating a cozy and peaceful ambiance. 4K, HD."
-if args.prompt is not None:
- prompt = args.prompt
-image = pipe(
- prompt=prompt,
- height=1024,
- width=1024,
- guidance_scale=3,
- num_inference_steps=50,
- generator=torch.Generator(device="cpu").manual_seed(111),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"omingen-v1.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_pixart_alpha.py b/examples/pipeline/run_pixart_alpha.py
deleted file mode 100644
index f5d0ae4f9..000000000
--- a/examples/pipeline/run_pixart_alpha.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import PixArtAlphaPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "PIXART_ALPHA_DIR",
- "PixArt-alpha/PixArt-XL-2-1024-MS",
- )
-)
-
-pipe = PixArtAlphaPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = "A small cactus with a happy face in the Sahara desert."
-if args.prompt is not None:
- prompt = args.prompt
-image = pipe(
- prompt,
- num_inference_steps=50,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-time_cost = end - start
-save_path = f"pixart-alpha.{strify(args, stats)}.png"
-
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_pixart_sigma.py b/examples/pipeline/run_pixart_sigma.py
deleted file mode 100644
index 0f081e543..000000000
--- a/examples/pipeline/run_pixart_sigma.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import Transformer2DModel, PixArtSigmaPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "PIXART_SIGMA_DIR",
- "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
- )
-)
-transformer = Transformer2DModel.from_pretrained(
- model_id,
- subfolder="transformer",
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-pipe = PixArtSigmaPipeline.from_pretrained(
- model_id,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- use_safetensors=True,
-)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = "A small cactus with a happy face in the Sahara desert."
-if args.prompt is not None:
- prompt = args.prompt
-image = pipe(
- prompt,
- num_inference_steps=50,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-time_cost = end - start
-save_path = f"pixart-sigma.{strify(args, stats)}.png"
-
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_prx_t2i.py b/examples/pipeline/run_prx_t2i.py
deleted file mode 100644
index b55c9fcc7..000000000
--- a/examples/pipeline/run_prx_t2i.py
+++ /dev/null
@@ -1,66 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import PRXPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-# Load pipeline with from_pretrained
-pipe = PRXPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "PRX_T2I_DIR",
- "Photoroom/prx-512-t2i-sft",
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Set default prompt
-prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe():
- image = pipe(
- prompt,
- num_inference_steps=28,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"prx.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_qwen_image.py b/examples/pipeline/run_qwen_image.py
deleted file mode 100644
index d7b2f9ec9..000000000
--- a/examples/pipeline/run_qwen_image.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-if args.cache:
- cachify(args, pipe)
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
- # Single GPU: use CPU offload for memory efficiency
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-# Use command line args if provided, otherwise default to 16:9
-if args.width is not None and args.height is not None:
- width, height = args.width, args.height
-else:
- width, height = aspect_ratios["16:9"]
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-if args.quantize:
- # Apply Quantization (default: FP8 DQ) to Transformer
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- per_row=False,
- exclude_layers=[
- "img_in",
- "txt_in",
- "embedder",
- "embed",
- "norm_out",
- "proj_out",
- ],
- )
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-# do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_qwen_image_controlnet_inpaint.py b/examples/pipeline/run_qwen_image_controlnet_inpaint.py
deleted file mode 100644
index 2268ffbd5..000000000
--- a/examples/pipeline/run_qwen_image_controlnet_inpaint.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.utils import load_image
-from diffusers import (
- QwenImageControlNetModel,
- QwenImageControlNetInpaintPipeline,
- QwenImageTransformer2DModel,
-)
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-base_model = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
-)
-controlnet_model = os.environ.get(
- "QWEN_IMAGE_CN_DIR",
- "InstantX/Qwen-Image-ControlNet-Inpainting",
-)
-
-controlnet = QwenImageControlNetModel.from_pretrained(
- controlnet_model,
- torch_dtype=torch.bfloat16,
-)
-
-pipe = QwenImageControlNetInpaintPipeline.from_pretrained(
- base_model,
- controlnet=controlnet,
- torch_dtype=torch.bfloat16,
-)
-
-assert isinstance(pipe.controlnet, QwenImageControlNetModel)
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-control_image = load_image(
- "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/images/image1.png"
-)
-mask_image = load_image(
- "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets/masks/mask1.png"
-)
-prompt = "一辆绿色的出租车行驶在路上"
-if args.prompt is not None:
- prompt = args.prompt
-negative_prompt = "worst quality, low quality, blurry, text, watermark, logo" # or " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-if GiB() < 96:
- # FP8 weight only
- if args.quantize:
- # Minimum VRAM required: 42 GiB
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- exclude_layers=[
- "img_in",
- "txt_in",
- "embedder",
- "embed",
- "norm_out",
- "proj_out",
- ],
- )
- pipe.text_encoder = cache_dit.quantize(
- pipe.text_encoder,
- quant_type=args.quantize_type,
- )
-
- pipe.to("cuda")
- else:
- print("Enable Model CPU Offload ...")
- pipe.enable_model_cpu_offload()
- pipe.enable_vae_tiling()
-else:
- pipe.to("cuda")
-
-if args.cache:
-
- cachify(
- args,
- pipe,
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- # (negative_prompt is not None, default None)
- enable_separate_cfg=False if negative_prompt is None else True,
- )
-
-
-def run_pipe():
- image = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- control_image=control_image.convert("RGB"),
- control_mask=mask_image,
- controlnet_conditioning_scale=1.0,
- width=mask_image.size[0],
- height=mask_image.size[1],
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(0),
- ).images[0]
-
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(mode="default")
-
- # warmup
- run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image-controlnet-inpaint.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_qwen_image_edit.py b/examples/pipeline/run_qwen_image_edit.py
deleted file mode 100644
index 1689c861a..000000000
--- a/examples/pipeline/run_qwen_image_edit.py
+++ /dev/null
@@ -1,92 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-
-from PIL import Image
-from diffusers import QwenImageEditPipeline, QwenImageTransformer2DModel
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-pipe = QwenImageEditPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_EDIT_DIR",
- "Qwen/Qwen-Image-Edit",
- )
- ),
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-if args.cache:
- cachify(args, pipe)
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
- # Single GPU: use CPU offload for memory efficiency
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-image = Image.open("../data/bear.png").convert("RGB")
-prompt = "Only change the bear's color to purple"
-if args.prompt is not None:
- prompt = args.prompt
-
-if args.compile:
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
- torch._dynamo.config.recompile_limit = 1024
- torch._dynamo.config.accumulated_recompile_limit = 8192
- pipe.transformer.compile_repeated_blocks(mode="default")
-
- # Warmup
- image = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=" ",
- generator=torch.Generator(device="cpu").manual_seed(0),
- true_cfg_scale=4.0,
- num_inference_steps=50,
- ).images[0]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-
-image = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=" ",
- generator=torch.Generator(device="cpu").manual_seed(0),
- true_cfg_scale=4.0,
- num_inference_steps=50,
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image-edit.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_qwen_image_edit_plus.py b/examples/pipeline/run_qwen_image_edit_plus.py
deleted file mode 100644
index 5757cb1f7..000000000
--- a/examples/pipeline/run_qwen_image_edit_plus.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from PIL import Image
-from diffusers import QwenImageEditPlusPipeline, QwenImageTransformer2DModel
-from io import BytesIO
-import requests
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-pipe = QwenImageEditPlusPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_EDIT_2509_DIR",
- "Qwen/Qwen-Image-Edit-2509",
- )
- ),
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-if args.cache:
- cachify(args, pipe)
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
- # Single GPU: use CPU offload for memory efficiency
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-image1 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg"
- ).content
- )
-)
-image2 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg"
- ).content
- )
-)
-prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square."
-if args.prompt is not None:
- prompt = args.prompt
-inputs = {
- "image": [image1, image2],
- "prompt": prompt,
- "generator": torch.Generator(device="cpu").manual_seed(0),
- "true_cfg_scale": 4.0,
- "negative_prompt": " ",
- "num_inference_steps": 40,
- "guidance_scale": 1.0,
- "num_images_per_prompt": 1,
-}
-
-if args.compile:
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
- torch._dynamo.config.recompile_limit = 1024
- torch._dynamo.config.accumulated_recompile_limit = 8192
- pipe.transformer.compile_repeated_blocks(mode="default")
-
- # Warmup
- image = pipe(**inputs).images[0]
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(**inputs).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image-edit-plus.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_qwen_image_lightning.py b/examples/pipeline/run_qwen_image_lightning.py
deleted file mode 100644
index 28f3ec8cc..000000000
--- a/examples/pipeline/run_qwen_image_lightning.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import math
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- FlowMatchEulerDiscreteScheduler,
-)
-from utils import GiB, get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- scheduler=scheduler,
- torch_dtype=torch.bfloat16,
- # https://huggingface.co/docs/diffusers/main/en/tutorials/inference_with_big_models#device-placement
- device_map=("balanced" if (torch.cuda.device_count() > 1 and GiB() <= 48) else None),
-)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-pipe.load_lora_weights(
- os.environ.get(
- "QWEN_IMAGE_LIGHT_DIR",
- "lightx2v/Qwen-Image-Lightning",
- ),
- weight_name=(
- "Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors"
- if steps > 4
- else "Qwen-Image-Lightning-4steps-V1.0-bf16.safetensors"
- ),
-)
-
-if args.fuse_lora:
- pipe.fuse_lora()
- pipe.unload_lora_weights()
-
-
-if args.cache:
- from cache_dit import DBCacheConfig
-
- cachify(
- args,
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- ),
- )
-
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if torch.cuda.device_count() <= 1:
- # Single GPU: use CPU offload for memory efficiency
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-if args.prompt is not None:
- prompt = args.prompt
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-width, height = aspect_ratios["16:9"]
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-if args.quantize:
- # Apply Quantization (default: FP8 DQ) to Transformer
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- per_row=False,
- exclude_layers=[
- "img_in",
- "txt_in",
- "embedder",
- "embed",
- "norm_out",
- "proj_out",
- ],
- )
-
-
-def run_pipe():
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=steps,
- true_cfg_scale=1.0, # means no separate cfg
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- run_pipe()
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"qwen-image-lightning.{steps}steps.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_sana.py b/examples/pipeline/run_sana.py
deleted file mode 100644
index adfa3dd84..000000000
--- a/examples/pipeline/run_sana.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import SanaPipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("SANA_DIR", "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers")
-)
-
-pipe = SanaPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "a tiny astronaut hatching from an egg on the moon"
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = pipe(
- prompt,
- num_inference_steps=20,
- generator=torch.Generator("cpu").manual_seed(1),
-).images[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"sana.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_sd_3.5.py b/examples/pipeline/run_sd_3.5.py
deleted file mode 100644
index 0589e052f..000000000
--- a/examples/pipeline/run_sd_3.5.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-
-from diffusers import StableDiffusion3Pipeline
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "SD_3_5_DIR",
- "stabilityai/stable-diffusion-3.5-large",
- )
-)
-
-pipe = StableDiffusion3Pipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- device_map="balanced",
-)
-
-if args.cache:
- cachify(args, pipe)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
-if args.prompt is not None:
- prompt = args.prompt
-image = pipe(
- prompt,
- num_inference_steps=50,
- generator=torch.Generator(device="cpu").manual_seed(42),
-).images[0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-time_cost = end - start
-save_path = f"sd_3_5.{strify(args, stats)}.png"
-
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_skyreels_v2.py b/examples/pipeline/run_skyreels_v2.py
deleted file mode 100644
index fc913c00d..000000000
--- a/examples/pipeline/run_skyreels_v2.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import AutoModel, SkyReelsV2Pipeline, UniPCMultistepScheduler
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers.utils import export_to_video
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get("SKYREELS_V2_DIR", "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers")
-)
-
-vae = AutoModel.from_pretrained(
- model_id,
- subfolder="vae",
- torch_dtype=torch.float32,
-).to("cuda")
-
-pipe = SkyReelsV2Pipeline.from_pretrained(
- model_id,
- vae=vae,
- torch_dtype=torch.bfloat16,
- quantization_config=PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["transformer", "text_encoder"],
- ),
-)
-
-pipe.to("cuda")
-
-flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-if args.cache:
- cachify(args, pipe)
-
-prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
-
-
-if args.prompt is not None:
-
- prompt = args.prompt
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- prompt=prompt,
- num_inference_steps=50,
- height=720, # 720 for 720P
- width=1280, # 1280 for 720P
- num_frames=21,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"skyreels_v2.{strify(args, pipe)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=8, quality=8)
diff --git a/examples/pipeline/run_visual_cloze.py b/examples/pipeline/run_visual_cloze.py
deleted file mode 100644
index a37793e92..000000000
--- a/examples/pipeline/run_visual_cloze.py
+++ /dev/null
@@ -1,88 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import VisualClozePipeline
-from diffusers.utils import load_image
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-# Load the VisualClozePipeline
-pipe = VisualClozePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "VISUAL_CLOZE_DIR",
- "VisualCloze/VisualClozePipeline-512",
- )
- ),
- resolution=512,
- torch_dtype=torch.bfloat16,
-)
-pipe.to("cuda")
-
-if args.cache:
- cachify(args, pipe)
-
-# Load in-context images (make sure the paths are correct and accessible)
-# The images are from the VITON-HD dataset at https://github.com/shadow2496/VITON-HD
-image_paths = [
- # in-context examples
- [
- load_image("../data/visualcloze/00700_00.jpg"),
- load_image("../data/visualcloze/03673_00.jpg"),
- load_image("../data/visualcloze/00700_00_tryon_catvton_0.jpg"),
- ],
- # query with the target image
- [
- load_image("../data/visualcloze/00555_00.jpg"),
- load_image("../data/visualcloze/12265_00.jpg"),
- None,
- ],
-]
-
-# Task and content prompt
-task_prompt = "Each row shows a virtual try-on process that aims to put [IMAGE2] the clothing onto [IMAGE1] the person, producing [IMAGE3] the person wearing the new clothing."
-content_prompt = None
-
-# Run the pipeline
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-
-image = pipe(
- task_prompt=task_prompt,
- content_prompt=content_prompt,
- image=image_paths,
- upsampling_height=1632,
- upsampling_width=1232,
- upsampling_strength=0.3,
- guidance_scale=30,
- num_inference_steps=30,
- max_sequence_length=512,
- generator=torch.Generator("cpu").manual_seed(0),
-).images[0][0]
-
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"visualcloze-512.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/pipeline/run_wan.py b/examples/pipeline/run_wan.py
deleted file mode 100644
index a2882d5f7..000000000
--- a/examples/pipeline/run_wan.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import diffusers
-from diffusers import WanPipeline, AutoencoderKLWan
-from diffusers.utils import export_to_video
-from diffusers.schedulers.scheduling_unipc_multistep import (
- UniPCMultistepScheduler,
-)
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-height, width = 480, 832
-pipe = WanPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_DIR",
- "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", # "num_layers": 30,
- )
- ),
- torch_dtype=torch.bfloat16,
-)
-
-# flow shift should be 3.0 for 480p images, 5.0 for 720p images
-if hasattr(pipe, "scheduler") and pipe.scheduler is not None:
- # Use the UniPCMultistepScheduler with the specified flow shift
- flow_shift = 3.0 if height == 480 else 5.0
- pipe.scheduler = UniPCMultistepScheduler.from_config(
- pipe.scheduler.config,
- flow_shift=flow_shift,
- )
-
-
-if args.cache:
- cachify(args, pipe)
-
-# Enable memory savings
-pipe.enable_model_cpu_offload()
-
-# Wan currently requires installing diffusers from source
-assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
-if diffusers.__version__ >= "0.34.0":
- pipe.vae.enable_tiling()
- pipe.vae.enable_slicing()
-else:
- print(
- "Wan pipeline requires diffusers version >= 0.34.0 "
- "for vae tiling and slicing, please install diffusers "
- "from source."
- )
-
-prompt = (
- "An astronaut dancing vigorously on the moon with earth "
- "flying past in the background, hyperrealistic"
-)
-if args.prompt is not None:
- prompt = args.prompt
-
-negative_prompt = ""
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = pipe(
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=49,
- num_inference_steps=35,
- generator=torch.Generator("cpu").manual_seed(0),
-).frames[0]
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"wan.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=16)
diff --git a/examples/pipeline/run_wan_2.2_i2v.py b/examples/pipeline/run_wan_2.2_i2v.py
deleted file mode 100644
index 252ec0ec2..000000000
--- a/examples/pipeline/run_wan_2.2_i2v.py
+++ /dev/null
@@ -1,185 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import diffusers
-from diffusers import (
- AutoencoderKLWan,
- WanTransformer3DModel,
- WanImageToVideoPipeline,
-)
-from diffusers.utils import export_to_video, load_image
-
-from utils import get_args, GiB, strify, cachify, MemoryTracker
-import cache_dit
-import numpy as np
-
-args = get_args()
-print(args)
-
-model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_2_2_I2V_DIR",
- "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
- )
-)
-
-pipe: WanImageToVideoPipeline = WanImageToVideoPipeline.from_pretrained(
- model_id,
- torch_dtype=torch.bfloat16,
- # Based on: https://github.com/huggingface/diffusers/pull/12523
- device_map=("balanced" if GiB() < 96 and torch.cuda.device_count() > 1 else None),
-)
-
-# When device_map is None, we need to explicitly move the model to GPU
-# or enable CPU offload to avoid running on CPU
-if GiB() < 96 and torch.cuda.device_count() <= 1:
- # issue: https://github.com/huggingface/diffusers/issues/12499
- print("Enable model cpu offload for low memory device.")
- pipe.enable_model_cpu_offload()
-elif torch.cuda.device_count() > 1 and pipe.device.type == "cpu":
- # Multi-GPU but model is on CPU (device_map was None): move to default GPU
- pipe.to("cuda")
-
-
-if args.cache:
- from cache_dit import (
- ForwardPattern,
- BlockAdapter,
- ParamsModifier,
- DBCacheConfig,
- )
-
- cachify(
- args,
- BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- params_modifiers=[
- # high-noise transformer only have 30% steps
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=4,
- max_cached_steps=8,
- ),
- ),
- ParamsModifier(
- cache_config=DBCacheConfig().reset(
- max_warmup_steps=2,
- max_cached_steps=20,
- ),
- ),
- ],
- has_separate_cfg=True,
- ),
- )
-
-# Wan currently requires installing diffusers from source
-assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
-if diffusers.__version__ >= "0.34.0":
- pipe.vae.enable_tiling()
- pipe.vae.enable_slicing()
-else:
- print(
- "Wan pipeline requires diffusers version >= 0.34.0 "
- "for vae tiling and slicing, please install diffusers "
- "from source."
- )
-
-assert isinstance(pipe.transformer, WanTransformer3DModel)
-assert isinstance(pipe.transformer_2, WanTransformer3DModel)
-
-if args.quantize:
- assert isinstance(args.quantize_type, str)
- if args.quantize_type.endswith("wo"): # weight only
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- )
- # We only apply activation quantization (default: FP8 DQ)
- # for low-noise transformer to avoid non-trivial precision
- # downgrade.
- pipe.transformer_2 = cache_dit.quantize(
- pipe.transformer_2,
- quant_type=args.quantize_type,
- )
-
-
-image = load_image(
- "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
-)
-
-max_area = 480 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-
-prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
-if args.prompt is not None:
- prompt = args.prompt
-negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-def run_pipe():
- video = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=81, # pipe.vae_scale_factor_temporal=4
- guidance_scale=3.5,
- num_inference_steps=50,
- generator=torch.Generator(device="cpu").manual_seed(0),
- ).frames[0]
-
- return video
-
-
-if args.compile or args.quantize:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
- pipe.transformer_2.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-video = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"wan2.2-i2v.frame{len(video)}.{height}x{width}.{strify(args, pipe)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(video, save_path, fps=16)
diff --git a/examples/pipeline/run_wan_flf2v.py b/examples/pipeline/run_wan_flf2v.py
deleted file mode 100644
index 9d973ebce..000000000
--- a/examples/pipeline/run_wan_flf2v.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import diffusers
-import argparse
-import numpy as np
-import torchvision.transforms.functional as TF
-from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
-from diffusers.utils import export_to_video, load_image
-from transformers import CLIPVisionModel
-
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-def aspect_ratio_resize(image, pipe, max_area=720 * 1280):
- aspect_ratio = image.height / image.width
- mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
- height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
- width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
- image = image.resize((width, height))
- return image, height, width
-
-
-def center_crop_resize(image, height, width):
- # Calculate resize ratio to match first frame dimensions
- resize_ratio = max(width / image.width, height / image.height)
-
- # Resize the image
- width = round(image.width * resize_ratio)
- height = round(image.height * resize_ratio)
- size = [width, height]
- image = TF.center_crop(image, size)
-
- return image, height, width
-
-
-def prepare_pipeline(
- pipe: WanImageToVideoPipeline,
- args: argparse.ArgumentParser,
-):
- if args.cache:
- cachify(args, pipe)
-
- # Enable memory savings
- pipe.enable_model_cpu_offload()
-
- # Wan currently requires installing diffusers from source
- assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
- if diffusers.__version__ >= "0.34.0":
- pipe.vae.enable_tiling()
- pipe.vae.enable_slicing()
- else:
- print(
- "Wan pipeline requires diffusers version >= 0.34.0 "
- "for vae tiling and slicing, please install diffusers "
- "from source."
- )
-
- return pipe
-
-
-def main():
- args = get_args()
- print(args)
-
- model_id = (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "WAN_FLF2V_DIR",
- "Wan-AI/Wan2.1-FLF2V-14B-720P-Diffusers",
- )
- )
- image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
- )
- vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
- pipe = WanImageToVideoPipeline.from_pretrained(
- model_id,
- vae=vae,
- image_encoder=image_encoder,
- torch_dtype=torch.bfloat16,
- )
- pipe.to("cuda")
-
- pipe = prepare_pipeline(pipe, args)
-
- first_frame = load_image("../data/flf2v_input_first_frame.png")
- last_frame = load_image("../data/flf2v_input_last_frame.png")
-
- first_frame, height, width = aspect_ratio_resize(first_frame, pipe)
- if last_frame.size != first_frame.size:
- last_frame, _, _ = center_crop_resize(last_frame, height, width)
-
- # Set default prompt
- prompt = (
- "CG animation style, a small blue bird takes off from the ground, flapping its wings. "
- + "The bird's feathers are delicate, with a unique pattern on its chest. The background shows "
- + "a blue sky with white clouds under bright sunshine. The camera follows the bird upward, "
- + "capturing its flight and the vastness of the sky from a close-up, low-angle perspective."
- )
- if args.prompt is not None:
- prompt = args.prompt
-
- memory_tracker = MemoryTracker() if args.track_memory else None
- if memory_tracker:
- memory_tracker.__enter__()
-
- start = time.time()
- output = pipe(
- image=first_frame,
- last_image=last_frame,
- prompt=prompt,
- height=height,
- width=width,
- guidance_scale=5.5,
- num_frames=49,
- num_inference_steps=35,
- generator=torch.Generator("cpu").manual_seed(0),
- ).frames[0]
- end = time.time()
-
- if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
- stats = cache_dit.summary(pipe)
-
- time_cost = end - start
- save_path = f"wan.flf2v.{strify(args, stats)}.mp4"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving video to {save_path}")
- export_to_video(output, save_path, fps=16)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/pipeline/run_wan_vace.py b/examples/pipeline/run_wan_vace.py
deleted file mode 100644
index 1787a207f..000000000
--- a/examples/pipeline/run_wan_vace.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import PIL.Image
-from diffusers import AutoencoderKLWan, WanVACEPipeline
-from diffusers.schedulers.scheduling_unipc_multistep import (
- UniPCMultistepScheduler,
-)
-from diffusers.utils import export_to_video, load_image
-
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-def prepare_video_and_mask(
- first_img: PIL.Image.Image,
- last_img: PIL.Image.Image,
- height: int,
- width: int,
- num_frames: int,
-):
- first_img = first_img.resize((width, height))
- last_img = last_img.resize((width, height))
- frames = []
- frames.append(first_img)
- # Ideally, this should be 127.5 to match original code, but they perform computation on numpy arrays
- # whereas we are passing PIL images. If you choose to pass numpy arrays, you can set it to 127.5 to
- # match the original code.
- frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
- frames.append(last_img)
- mask_black = PIL.Image.new("L", (width, height), 0)
- mask_white = PIL.Image.new("L", (width, height), 255)
- mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
- return frames, mask
-
-
-model_id = "Wan-AI/Wan2.1-VACE-1.3B-diffusers"
-model_id = (
- args.model_path if args.model_path is not None else os.environ.get("WAN_VACE_DIR", model_id)
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-pipe = WanVACEPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
-flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
-pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
-
-if args.cache:
- cachify(args, pipe)
-
-# Enable memory savings
-pipe.enable_model_cpu_offload()
-
-assert isinstance(pipe.vae, AutoencoderKLWan) # enable type check for IDE
-pipe.vae.enable_tiling()
-pipe.vae.enable_slicing()
-
-prompt = (
- "CG animation style, a small blue bird takes off from the ground, "
- "flapping its wings. The bird's feathers are delicate, with a unique "
- "pattern on its chest. The background shows a blue sky with white "
- "clouds under bright sunshine. The camera follows the bird upward, "
- "capturing its flight and the vastness of the sky from a close-up, "
- "low-angle perspective."
-)
-if args.prompt is not None:
- prompt = args.prompt
-
-negative_prompt = (
- "Bright tones, overexposed, static, blurred details, subtitles, "
- "style, works, paintings, images, static, overall gray, worst "
- "quality, low quality, JPEG compression residue, ugly, incomplete, "
- "extra fingers, poorly drawn hands, poorly drawn faces, deformed, "
- "disfigured, misshapen limbs, fused fingers, still picture, messy "
- "background, three legs, many people in the background, walking "
- "backwards"
-)
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-first_frame = load_image("../data/flf2v_input_first_frame.png")
-last_frame = load_image("../data/flf2v_input_last_frame.png")
-
-height = 512
-width = 512
-num_frames = 81
-video, mask = prepare_video_and_mask(first_frame, last_frame, height, width, num_frames)
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- video=video,
- mask=mask,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=num_frames,
- num_inference_steps=30 if not warmup else 5,
- guidance_scale=5.0,
- generator=torch.Generator("cpu").manual_seed(42),
- ).frames[0]
- return output
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-output = run_pipe(warmup=False)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"wan-vace.{strify(args, stats)}.mp4"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving video to {save_path}")
-export_to_video(output, save_path, fps=16)
diff --git a/examples/quantize/run_flux_ao.py b/examples/quantize/run_flux_ao.py
deleted file mode 100644
index 3d238c7fb..000000000
--- a/examples/quantize/run_flux_ao.py
+++ /dev/null
@@ -1,84 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers import FluxPipeline, FluxTransformer2DModel
-from utils import get_args, strify, cachify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-pipe: FluxPipeline = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "FLUX_DIR",
- "black-forest-labs/FLUX.1-dev",
- )
- ),
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-if args.cache:
- cachify(args, pipe)
-
-
-if args.quantize:
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- pipe.transformer = cache_dit.quantize(
- pipe.transformer,
- quant_type=args.quantize_type,
- )
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(warmup: bool = False):
- image = pipe(
- prompt,
- width=1024 if args.width is None else args.width,
- height=1024 if args.height is None else args.height,
- num_inference_steps=((28 if args.steps is None else args.steps) if not warmup else 5),
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- pipe.transformer.compile_repeated_blocks()
-
-
-# warmup
-_ = run_pipe(warmup=True)
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.ao.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/quantize/run_flux_nunchaku.py b/examples/quantize/run_flux_nunchaku.py
deleted file mode 100644
index 6e4806ba5..000000000
--- a/examples/quantize/run_flux_nunchaku.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-import time
-
-import torch
-from diffusers import FluxPipeline, FluxTransformer2DModel
-
-from nunchaku.models.transformers.transformer_flux_v2 import (
- NunchakuFluxTransformer2DModelV2,
-)
-from utils import get_args, strify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-nunchaku_flux_dir = os.environ.get(
- "NUNCHAKA_FLUX_DIR",
- "nunchaku-tech/nunchaku-flux.1-dev",
-)
-transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
- f"{nunchaku_flux_dir}/svdq-int4_r32-flux.1-dev.safetensors",
-)
-pipe: FluxPipeline = FluxPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get("FLUX_DIR", "black-forest-labs/FLUX.1-dev")
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
-).to("cuda")
-
-
-if args.cache:
- from cache_dit import (
- ParamsModifier,
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- params_modifiers=[
- ParamsModifier(
- # transformer_blocks
- cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt),
- ),
- ParamsModifier(
- # single_transformer_blocks
- cache_config=DBCacheConfig().reset(residual_diff_threshold=args.rdt * 3),
- ),
- ],
- )
-
-# Set default prompt
-prompt = "A cat holding a sign that says hello world"
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe(pipe: FluxPipeline):
- image = pipe(
- prompt,
- num_inference_steps=28,
- generator=torch.Generator("cpu").manual_seed(0),
- ).images[0]
- return image
-
-
-if args.compile:
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
- # warmup
- _ = run_pipe(pipe)
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe(pipe)
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"flux.nunchaku.int4.{strify(args, pipe)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/quantize/run_qwen_image_edit_plus_lightning_nunchaku.py b/examples/quantize/run_qwen_image_edit_plus_lightning_nunchaku.py
deleted file mode 100644
index a761e6261..000000000
--- a/examples/quantize/run_qwen_image_edit_plus_lightning_nunchaku.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import math
-
-import torch
-from PIL import Image
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers import QwenImageEditPlusPipeline, QwenImageTransformer2DModel
-from diffusers import FlowMatchEulerDiscreteScheduler
-from nunchaku import NunchakuQwenImageTransformer2DModel
-
-from io import BytesIO
-import requests
-from utils import get_args, strify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-nunchaku_qwen_image_edit_plus_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_EDIT_2509_DIR",
- "nunchaku-tech/nunchaku-qwen-image-edit-2509",
-)
-
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_edit_plus_dir}/svdq-int4_r128-qwen-image-edit-2509-lightningv2.0-{steps}steps.safetensors"
-)
-
-# Minimize VRAM required: 25GiB if use w4a16_text_encoder else 35GiB
-w4a16_text_encoder = False
-pipe = QwenImageEditPlusPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_EDIT_2509_DIR",
- "Qwen/Qwen-Image-Edit-2509",
- )
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-if args.cache:
- from cache_dit import (
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- warmup_interval=2 if steps > 4 else 1,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- )
-
-
-image1 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg"
- ).content
- )
-)
-image2 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg"
- ).content
- )
-)
-prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe():
- inputs = {
- "image": [image1, image2],
- "prompt": prompt,
- "generator": torch.Generator(device="cpu").manual_seed(0),
- "true_cfg_scale": 1.0,
- "negative_prompt": " ",
- "num_inference_steps": steps,
- }
- return pipe(**inputs).images[0]
-
-
-if args.compile:
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(mode="default")
-
- # Warmup
- run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"qwen-image-edit-plus-lightning.{steps}steps.nunchaku.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/quantize/run_qwen_image_edit_plus_nunchaku.py b/examples/quantize/run_qwen_image_edit_plus_nunchaku.py
deleted file mode 100644
index 5de7149bc..000000000
--- a/examples/quantize/run_qwen_image_edit_plus_nunchaku.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-
-import torch
-from PIL import Image
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers import QwenImageEditPlusPipeline, QwenImageTransformer2DModel
-from nunchaku import NunchakuQwenImageTransformer2DModel
-
-from io import BytesIO
-import requests
-from utils import get_args, strify, MemoryTracker
-import cache_dit
-
-args = get_args()
-print(args)
-
-
-nunchaku_qwen_image_edit_plus_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_EDIT_2509_DIR",
- "nunchaku-tech/nunchaku-qwen-image-edit-2509",
-)
-
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_edit_plus_dir}/svdq-int4_r128-qwen-image-edit-2509.safetensors"
-)
-
-# Minimize VRAM required: 20GiB if use w4a16_text_encoder else 30GiB
-w4a16_text_encoder = False
-pipe = QwenImageEditPlusPipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_EDIT_2509_DIR",
- "Qwen/Qwen-Image-Edit-2509",
- )
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-if args.cache:
- from cache_dit import (
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- )
-
-
-image1 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_1.jpg"
- ).content
- )
-)
-image2 = Image.open(
- BytesIO(
- requests.get(
- "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-Image/edit2509/edit2509_2.jpg"
- ).content
- )
-)
-prompt = "The magician bear is on the left, the alchemist bear is on the right, facing each other in the central park square."
-if args.prompt is not None:
- prompt = args.prompt
-
-
-def run_pipe():
- inputs = {
- "image": [image1, image2],
- "prompt": prompt,
- "generator": torch.Generator(device="cpu").manual_seed(0),
- "true_cfg_scale": 4.0,
- "negative_prompt": " ",
- "num_inference_steps": 40,
- "guidance_scale": 1.0,
- "num_images_per_prompt": 1,
- }
- return pipe(**inputs).images[0]
-
-
-if args.compile:
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(mode="default")
-
- # Warmup
- run_pipe()
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image-edit-plus.nunchaku.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/quantize/run_qwen_image_lightning_nunchaku.py b/examples/quantize/run_qwen_image_lightning_nunchaku.py
deleted file mode 100644
index 5f97a42b5..000000000
--- a/examples/quantize/run_qwen_image_lightning_nunchaku.py
+++ /dev/null
@@ -1,188 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-import math
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- FlowMatchEulerDiscreteScheduler,
- PipelineQuantizationConfig,
-)
-from nunchaku.models.transformers.transformer_qwenimage import (
- NunchakuQwenImageTransformer2DModel,
-)
-from utils import get_args, strify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-
-# From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
-scheduler_config = {
- "base_image_seq_len": 256,
- "base_shift": math.log(3), # We use shift=3 in distillation
- "invert_sigmas": False,
- "max_image_seq_len": 8192,
- "max_shift": math.log(3), # We use shift=3 in distillation
- "num_train_timesteps": 1000,
- "shift": 1.0,
- "shift_terminal": None, # set shift_terminal to None
- "stochastic_sampling": False,
- "time_shift_type": "exponential",
- "use_beta_sigmas": False,
- "use_dynamic_shifting": True,
- "use_exponential_sigmas": False,
- "use_karras_sigmas": False,
-}
-scheduler = FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
-
-steps = 8 if args.steps is None else args.steps
-assert steps in [8, 4]
-
-nunchaku_qwen_image_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_DIR",
- "nunchaku-tech/nunchaku-qwen-image",
-)
-lightning_version = "v1.1" if steps == 8 else "v1.0"
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_dir}/svdq-int4_r32-qwen-image-lightning"
- f"{lightning_version}-{steps}steps.safetensors"
-)
-
-# Minimize VRAM required: 25GiB if use w4a16_text_encoder else 35GiB
-w4a16_text_encoder = False
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- transformer=transformer,
- scheduler=scheduler,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-
-if args.cache:
- from cache_dit import (
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=16,
- Bn_compute_blocks=16,
- max_warmup_steps=4 if steps > 4 else 2,
- warmup_interval=2 if steps > 4 else 1,
- max_cached_steps=2 if steps > 4 else 1,
- max_continuous_cached_steps=1,
- enable_separate_cfg=False, # true_cfg_scale=1.0
- residual_diff_threshold=0.50 if steps > 4 else 0.8,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- )
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-if args.prompt is not None:
- prompt = args.prompt
-
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-width, height = aspect_ratios["16:9"]
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-
-def run_pipe():
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=steps,
- true_cfg_scale=1.0, # means no separate cfg
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer.compile_repeated_blocks(fullgraph=True)
-
- # warmup
- run_pipe()
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe, details=True)
-
-time_cost = end - start
-save_path = f"qwen-image-lightning.nunchaku.{steps}steps.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/quantize/run_qwen_image_nunchaku.py b/examples/quantize/run_qwen_image_nunchaku.py
deleted file mode 100644
index 83e600947..000000000
--- a/examples/quantize/run_qwen_image_nunchaku.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import os
-import sys
-
-sys.path.append("..")
-
-import time
-import torch
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
-from nunchaku.models.transformers.transformer_qwenimage import (
- NunchakuQwenImageTransformer2DModel,
-)
-
-from utils import get_args, strify, MemoryTracker
-import cache_dit
-
-
-args = get_args()
-print(args)
-
-nunchaku_qwen_image_dir = os.environ.get(
- "NUNCHAKA_QWEN_IMAGE_DIR",
- "nunchaku-tech/nunchaku-qwen-image",
-)
-transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
- f"{nunchaku_qwen_image_dir}/svdq-int4_r32-qwen-image.safetensors"
-)
-
-# Minimize VRAM required: 20GiB if use w4a16_text_encoder else 30GiB
-w4a16_text_encoder = False
-pipe = QwenImagePipeline.from_pretrained(
- (
- args.model_path
- if args.model_path is not None
- else os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- )
- ),
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- components_to_quantize=["text_encoder"],
- )
- if w4a16_text_encoder
- else None
- ),
-).to("cuda")
-
-
-if args.cache:
- from cache_dit import (
- DBCacheConfig,
- TaylorSeerCalibratorConfig,
- )
-
- cache_dit.enable_cache(
- pipe,
- cache_config=DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
- )
- if args.taylorseer
- else None
- ),
- )
-
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-if args.prompt is not None:
- prompt = args.prompt
-
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-if args.negative_prompt is not None:
- negative_prompt = args.negative_prompt
-
-
-# Generate with different aspect ratios
-aspect_ratios = {
- "1:1": (1328, 1328),
- "16:9": (1664, 928),
- "9:16": (928, 1664),
- "4:3": (1472, 1140),
- "3:4": (1140, 1472),
- "3:2": (1584, 1056),
- "2:3": (1056, 1584),
-}
-
-width, height = aspect_ratios["16:9"]
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
-
-def run_pipe():
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=width,
- height=height,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
- return image
-
-
-if args.compile:
- cache_dit.set_compile_configs()
- pipe.transformer = torch.compile(pipe.transformer)
-
- # warmup
- run_pipe()
-
-
-memory_tracker = MemoryTracker() if args.track_memory else None
-if memory_tracker:
- memory_tracker.__enter__()
-
-start = time.time()
-image = run_pipe()
-end = time.time()
-
-if memory_tracker:
- memory_tracker.__exit__(None, None, None)
- memory_tracker.report()
-
-stats = cache_dit.summary(pipe)
-
-time_cost = end - start
-save_path = f"qwen-image.nunchaku.{strify(args, stats)}.png"
-print(f"Time cost: {time_cost:.2f}s")
-print(f"Saving image to {save_path}")
-image.save(save_path)
diff --git a/examples/registers.py b/examples/registers.py
new file mode 100644
index 000000000..dddfd64d3
--- /dev/null
+++ b/examples/registers.py
@@ -0,0 +1,1169 @@
+import os
+import math
+import torch
+import argparse
+import PIL.Image
+import cache_dit
+import numpy as np
+from typing import Tuple, List, Optional
+from diffusers.utils import load_image
+from diffusers import FlowMatchEulerDiscreteScheduler
+from cache_dit import DBCacheConfig, ParamsModifier
+from cache_dit.logger import init_logger
+
+from base import (
+ Example,
+ ExampleType,
+ ExampleInputData,
+ ExampleInitConfig,
+ ExampleRegister,
+)
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+ "flux_example",
+ "flux_fill_example",
+ "flux2_example",
+ "flux2_klein_example",
+ "qwen_image_example",
+ "qwen_image_controlnet_example",
+ "qwen_image_edit_example",
+ "qwen_image_layered_example",
+ "skyreels_v2_example",
+ "ltx2_t2v_example",
+ "ltx2_i2v_example",
+ "wan_example",
+ "wan_i2v_example",
+ "wan_vace_example",
+ "ovis_image_example",
+ "zimage_example",
+ "zimage_controlnet_example",
+ "longcat_image_example",
+ "longcat_image_edit_example",
+]
+
+
+# Please note that the following environment variables is only for debugging and
+# development purpose. In practice, users should directly provide the model names
+# or paths. The default values are the official model names on HuggingFace Hub.
+_env_path_mapping = {
+ "FLUX_DIR": "black-forest-labs/FLUX.1-dev",
+ "FLUX_FILL_DIR": "black-forest-labs/FLUX.1-Fill-dev",
+ "NUNCHAKU_FLUX_DIR": "nunchaku-tech/nunchaku-flux.1-dev",
+ "FLUX_2_DIR": "black-forest-labs/FLUX.2-dev",
+ "FLUX_2_KLEIN_4B_DIR": "black-forest-labs/FLUX.2-klein-4B",
+ "FLUX_2_KLEIN_BASE_4B_DIR": "black-forest-labs/FLUX.2-klein-base-4B",
+ "FLUX_2_KLEIN_9B_DIR": "black-forest-labs/FLUX.2-klein-9B",
+ "FLUX_2_KLEIN_BASE_9B_DIR": "black-forest-labs/FLUX.2-klein-base-9B",
+ "OVIS_IMAGE_DIR": "AIDC-AI/Ovis-Image-7B",
+ "LTX2_DIR": "Lightricks/LTX-2",
+ "QWEN_IMAGE_DIR": "Qwen/Qwen-Image",
+ "QWEN_IMAGE_2512_DIR": "Qwen/Qwen-Image-2512",
+ "QWEN_IMAGE_LIGHT_DIR": "lightx2v/Qwen-Image-Lightning",
+ "QWEN_IMAGE_EDIT_2509_DIR": "Qwen/Qwen-Image-Edit-2509",
+ "QWEN_IMAGE_EDIT_2511_DIR": "Qwen/Qwen-Image-Edit-2511",
+ "QWEN_IMAGE_EDIT_2511_LIGHT_DIR": "lightx2v/Qwen-Image-Edit-2511-Lightning",
+ "QWEN_IMAGE_CONTROLNET_DIR": "InstantX/Qwen-Image-ControlNet-Inpainting",
+ "QWEN_IMAGE_LAYERED_DIR": "Qwen/Qwen-Image-Layered",
+ "SKYREELS_V2_DIR": "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ "WAN_DIR": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ "WAN_2_2_DIR": "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ "WAN_I2V_DIR": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
+ "WAN_2_2_I2V_DIR": "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+ "WAN_VACE_DIR": "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
+ "WAN_2_2_VACE_DIR": "linoyts/Wan2.2-VACE-Fun-14B-diffusers",
+ "ZIMAGE_DIR": "Tongyi-MAI/Z-Image-Turbo",
+ "NUNCHAKU_ZIMAGE_DIR": "nunchaku-tech/nunchaku-z-image-turbo",
+ "Z_IMAGE_CONTROLNET_2_1_DIR": "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
+ "Z_IMAGE_CONTROLNET_2_0_DIR": "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
+ "LONGCAT_IMAGE_DIR": "meituan-longcat/LongCat-Image",
+ "LONGCAT_IMAGE_EDIT_DIR": "meituan-longcat/LongCat-Image-Edit",
+}
+_path_env_mapping = {v: k for k, v in _env_path_mapping.items()}
+
+
+def _path(
+ default: str,
+ args: Optional[argparse.Namespace] = None,
+ ENV: Optional[str] = None,
+ lora: bool = False,
+ controlnet: bool = False,
+ transformer: bool = False,
+) -> str:
+ # Prefer command line argument if provided
+ if args is not None:
+ model_path_arg = args.model_path
+ if lora:
+ model_path_arg = args.lora_path
+ if controlnet:
+ model_path_arg = args.controlnet_path
+ if transformer:
+ model_path_arg = args.transformer_path
+ if model_path_arg is not None:
+ return model_path_arg
+ # Next, check environment variable
+ if ENV is None:
+ ENV = _path_env_mapping.get(default, None)
+ if ENV is None:
+ return default
+ return os.environ.get(ENV, default)
+
+
+@ExampleRegister.register("flux", default="black-forest-labs/FLUX.1-dev")
+@ExampleRegister.register("flux_nunchaku", default="nunchaku-tech/nunchaku-flux.1-dev")
+def flux_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import FluxPipeline
+
+ if "nunchaku" in args.example.lower():
+ from nunchaku.models.transformers.transformer_flux_v2 import (
+ NunchakuFluxTransformer2DModelV2,
+ )
+
+ nunchaku_flux_dir = _path(
+ "nunchaku-tech/nunchaku-flux.1-dev",
+ args=args,
+ transformer=True,
+ )
+ transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
+ f"{nunchaku_flux_dir}/svdq-int4_r32-flux.1-dev.safetensors",
+ )
+ else:
+ transformer = None
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("black-forest-labs/FLUX.1-dev"),
+ pipeline_class=FluxPipeline,
+ transformer=transformer, # maybe use Nunchaku Flux transformer
+ # `text_encoder_2` will be quantized when `--quantize-type`
+ # is set to `bnb_4bit`. Only hints for quantization.
+ bnb_4bit_components=["text_encoder_2"],
+ ),
+ input_data=ExampleInputData(
+ prompt="A cat holding a sign that says hello world",
+ height=1024,
+ width=1024,
+ num_inference_steps=28,
+ ),
+ )
+
+
+@ExampleRegister.register("flux_fill", default="black-forest-labs/FLUX.1-Fill-dev")
+def flux_fill_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import FluxFillPipeline
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.IE2I, # Image Editing to Image
+ model_name_or_path=_path("black-forest-labs/FLUX.1-Fill-dev"),
+ pipeline_class=FluxFillPipeline,
+ # `text_encoder_2` will be quantized when `--quantize-type`
+ # is set to `bnb_4bit`. Only hints for quantization.
+ bnb_4bit_components=["text_encoder_2"],
+ ),
+ input_data=ExampleInputData(
+ prompt="a white paper cup",
+ image=load_image("./data/cup.png"),
+ mask_image=load_image("./data/cup_mask.png"),
+ guidance_scale=30,
+ height=1024,
+ width=1024,
+ num_inference_steps=28,
+ ),
+ )
+
+
+def _flux2_params_modifiers(args: argparse.Namespace) -> List[ParamsModifier]:
+ return [
+ ParamsModifier(
+ # Modified config only for transformer_blocks
+ # Must call the `reset` method of DBCacheConfig.
+ cache_config=DBCacheConfig().reset(
+ residual_diff_threshold=args.residual_diff_threshold,
+ ),
+ ),
+ ParamsModifier(
+ # Modified config only for single_transformer_blocks
+ # NOTE: FLUX.2, single_transformer_blocks should have `higher`
+ # residual_diff_threshold because of the precision error
+ # accumulation from previous transformer_blocks
+ cache_config=DBCacheConfig().reset(
+ residual_diff_threshold=args.residual_diff_threshold * 3,
+ ),
+ ),
+ ]
+
+
+@ExampleRegister.register("flux2", default="black-forest-labs/FLUX.2-dev")
+def flux2_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import Flux2Pipeline
+
+ params_modifiers = _flux2_params_modifiers(args)
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("black-forest-labs/FLUX.2-dev"),
+ pipeline_class=Flux2Pipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ # Extra init args for DBCacheConfig, ParamsModifier, etc.
+ extra_optimize_kwargs={
+ "params_modifiers": params_modifiers,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "Realistic macro photograph of a hermit crab using a soda can as its shell, "
+ "partially emerging from the can, captured with sharp detail and natural colors, "
+ "on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean "
+ "waves in the background. The can has the text `BFL Diffusers` on it and it has a color "
+ "gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."
+ ),
+ height=1024,
+ width=1024,
+ num_inference_steps=28,
+ guidance_scale=4,
+ ),
+ )
+
+
+@ExampleRegister.register("flux2_klein_4b", default="black-forest-labs/FLUX.2-klein-4B")
+@ExampleRegister.register("flux2_klein_9b", default="black-forest-labs/FLUX.2-klein-9B")
+@ExampleRegister.register("flux2_klein_base_4b", default="black-forest-labs/FLUX.2-klein-base-4B")
+@ExampleRegister.register("flux2_klein_base_9b", default="black-forest-labs/FLUX.2-klein-base-9B")
+def flux2_klein_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import Flux2KleinPipeline
+
+ # cfg: guidance_scale > 1 and not is_distilled
+ if "base" in args.example.lower():
+ num_inference_steps = 50
+ guidance_scale = 4.0 # typical cfg for base model
+ enable_separate_cfg = True
+ if "4b" in args.example.lower():
+ model_path = _path("black-forest-labs/FLUX.2-klein-base-4B")
+ else:
+ model_path = _path("black-forest-labs/FLUX.2-klein-base-9B")
+ else:
+ num_inference_steps = 4
+ guidance_scale = 1.0 # no cfg for klein
+ enable_separate_cfg = False
+ if "4b" in args.example.lower():
+ model_path = _path("black-forest-labs/FLUX.2-klein-4B")
+ else:
+ model_path = _path("black-forest-labs/FLUX.2-klein-9B")
+
+ params_modifiers = _flux2_params_modifiers(args)
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=model_path,
+ pipeline_class=Flux2KleinPipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ # Extra init args for DBCacheConfig, ParamsModifier, etc.
+ extra_optimize_kwargs={
+ "params_modifiers": params_modifiers,
+ "enable_separate_cfg": enable_separate_cfg,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt="A cat holding a sign that says hello world",
+ height=1024,
+ width=1024,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ ),
+ )
+
+
+def _qwen_light_scheduler() -> FlowMatchEulerDiscreteScheduler:
+ # From https://github.com/ModelTC/Qwen-Image-Lightning/blob/342260e8f5468d2f24d084ce04f55e101007118b/generate_with_diffusers.py#L82C9-L97C10
+ scheduler_config = {
+ "base_image_seq_len": 256,
+ "base_shift": math.log(3), # We use shift=3 in distillation
+ "invert_sigmas": False,
+ "max_image_seq_len": 8192,
+ "max_shift": math.log(3), # We use shift=3 in distillation
+ "num_train_timesteps": 1000,
+ "shift": 1.0,
+ "shift_terminal": None, # set shift_terminal to None
+ "stochastic_sampling": False,
+ "time_shift_type": "exponential",
+ "use_beta_sigmas": False,
+ "use_dynamic_shifting": True,
+ "use_exponential_sigmas": False,
+ "use_karras_sigmas": False,
+ }
+ return FlowMatchEulerDiscreteScheduler.from_config(scheduler_config)
+
+
+def _qwen_light_cache_config(args: argparse.Namespace) -> Optional[DBCacheConfig]:
+ if not args.cache:
+ return None
+ steps = 8 if args.num_inference_steps is None else args.num_inference_steps
+ return DBCacheConfig(
+ Fn_compute_blocks=16,
+ Bn_compute_blocks=16,
+ max_warmup_steps=4 if steps > 4 else 2,
+ max_cached_steps=2 if steps > 4 else 1,
+ max_continuous_cached_steps=1,
+ enable_separate_cfg=False, # true_cfg_scale=1.0
+ residual_diff_threshold=0.50 if steps > 4 else 0.8,
+ )
+
+
+@ExampleRegister.register("qwen_image", default="Qwen/Qwen-Image")
+@ExampleRegister.register("qwen_image_2512", default="Qwen/Qwen-Image-2512")
+@ExampleRegister.register("qwen_image_lightning", default="lightx2v/Qwen-Image-Lightning")
+def qwen_image_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import QwenImagePipeline
+
+ if "lightning" in args.example.lower():
+ scheduler = _qwen_light_scheduler()
+ else:
+ scheduler = None
+
+ if "lightning" in args.example.lower():
+ # For lightning model, only 8 or 4 inference steps are supported
+ steps = 8 if args.num_inference_steps is None else args.num_inference_steps
+ assert steps in [8, 4]
+ lora_weights_path = _path("lightx2v/Qwen-Image-Lightning", args=args, lora=True)
+ lora_weight_name = f"Qwen-Image-Lightning-{steps}steps-V1.0-bf16.safetensors"
+ cache_config = _qwen_light_cache_config(args)
+ true_cfg_scale = 1.0 # means no separate cfg for lightning models
+ else:
+ steps = 50 if args.num_inference_steps is None else args.num_inference_steps
+ lora_weights_path = None
+ lora_weight_name = None
+ cache_config = None
+ true_cfg_scale = 4.0
+
+ positive_magic = {
+ "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
+ "zh": ", 超清,4K,电影级构图.", # for chinese prompt
+ }
+ prompt = (
+ "A coffee shop entrance features a chalkboard sign reading "
+ '"Qwen Coffee 😊 $2 per cup," with a neon light beside it '
+ 'displaying "通义千问". Next to it hangs a poster showing a '
+ "beautiful Chinese woman, and beneath the poster is written "
+ '"π≈3.1415926-53589793-23846264-33832795-02384197". '
+ "Ultra HD, 4K, cinematic composition"
+ )
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("Qwen/Qwen-Image"),
+ pipeline_class=QwenImagePipeline,
+ scheduler=scheduler,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ lora_weights_path=lora_weights_path,
+ lora_weights_name=lora_weight_name,
+ force_fuse_lora=True, # For parallelism compatibility
+ extra_optimize_kwargs={
+ "cache_config": cache_config,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=prompt + positive_magic["en"],
+ negative_prompt=" ",
+ height=1024,
+ width=1024,
+ num_inference_steps=steps,
+ true_cfg_scale=true_cfg_scale,
+ ),
+ )
+
+
+@ExampleRegister.register("qwen_image_edit", default="Qwen/Qwen-Image-Edit-2509")
+@ExampleRegister.register("qwen_image_edit_lightning", default="lightx2v/Qwen-Image-Lightning")
+@ExampleRegister.register("qwen_image_edit_2511", default="Qwen/Qwen-Image-Edit-2511")
+@ExampleRegister.register(
+ "qwen_image_edit_2511_lightning", default="lightx2v/Qwen-Image-Edit-2511-Lightning"
+)
+def qwen_image_edit_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import QwenImageEditPlusPipeline
+
+ if "lightning" in args.example.lower():
+ scheduler = _qwen_light_scheduler()
+ else:
+ scheduler = None
+
+ if "lightning" in args.example.lower():
+ # For lightning model, only 8 or 4 inference steps are supported
+ steps = 8 if args.num_inference_steps is None else args.num_inference_steps
+ assert steps in [8, 4]
+ if "2511" in args.example.lower():
+ assert steps == 4, "Qwen-Image-Edit-2511-Lightning only supports 4 steps."
+ lora_weights_path = _path("lightx2v/Qwen-Image-Edit-2511-Lightning", args, lora=True)
+ lora_weight_name = f"Qwen-Image-Edit-2511-Lightning-{steps}steps-V1.0-bf16.safetensors"
+ else:
+ lora_weights_path = os.path.join(
+ _path("lightx2v/Qwen-Image-Lightning", args, lora=True),
+ "Qwen-Image-Edit-2509",
+ )
+ lora_weight_name = f"Qwen-Image-Edit-2509-Lightning-{steps}steps-V1.0-bf16.safetensors"
+ cache_config = _qwen_light_cache_config(args)
+ true_cfg_scale = 1.0 # means no separate cfg for lightning models
+ else:
+ steps = 50 if args.num_inference_steps is None else args.num_inference_steps
+ lora_weights_path = None
+ lora_weight_name = None
+ cache_config = None
+ true_cfg_scale = 4.0
+
+ if "2511" in args.example.lower():
+ model_path_or_name = _path("Qwen/Qwen-Image-Edit-2511", args)
+ else:
+ model_path_or_name = _path("Qwen/Qwen-Image-Edit-2509", args)
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.IE2I, # Image Editing to Image
+ model_name_or_path=model_path_or_name,
+ pipeline_class=QwenImageEditPlusPipeline,
+ scheduler=scheduler,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ lora_weights_path=lora_weights_path,
+ lora_weights_name=lora_weight_name,
+ force_fuse_lora=True, # For parallelism compatibility
+ extra_optimize_kwargs={
+ "cache_config": cache_config,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "The magician bear is on the left, the alchemist bear is on the right, "
+ "facing each other in the central park square."
+ ),
+ negative_prompt=" ",
+ height=1024,
+ width=1024,
+ num_inference_steps=steps,
+ true_cfg_scale=true_cfg_scale, # 1.0 means no separate cfg for lightning models
+ # image1, image2
+ image=[
+ load_image("./data/edit2509_1.jpg"),
+ load_image("./data/edit2509_2.jpg"),
+ ],
+ ),
+ )
+
+
+@ExampleRegister.register(
+ "qwen_image_controlnet", default="InstantX/Qwen-Image-ControlNet-Inpainting"
+)
+def qwen_image_controlnet_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import QwenImageControlNetModel, QwenImageControlNetInpaintPipeline
+
+ # make sure controlnet is on cuda to avoid device mismatch while using cpu offload
+ controlnet = QwenImageControlNetModel.from_pretrained(
+ _path(
+ "InstantX/Qwen-Image-ControlNet-Inpainting",
+ args=args,
+ controlnet=True,
+ ),
+ torch_dtype=torch.bfloat16,
+ )
+
+ base_image_url = (
+ "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Inpainting/resolve/main/assets"
+ )
+ control_image = load_image(f"{base_image_url}/images/image1.png").convert("RGB")
+ control_mask = load_image(f"{base_image_url}/masks/mask1.png")
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("Qwen/Qwen-Image"),
+ pipeline_class=QwenImageControlNetInpaintPipeline,
+ controlnet=controlnet,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt="一辆绿色的出租车行驶在路上",
+ negative_prompt="worst quality, low quality, blurry, text, watermark, logo",
+ control_image=control_image,
+ control_mask=control_mask,
+ controlnet_conditioning_scale=1.0,
+ height=control_mask.size[1] if args.height is None else args.height,
+ width=control_mask.size[0] if args.width is None else args.width,
+ num_inference_steps=50,
+ true_cfg_scale=4.0,
+ ),
+ )
+
+
+@ExampleRegister.register("qwen_image_layered", default="Qwen/Qwen-Image-Layered")
+def qwen_image_layered_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import QwenImageLayeredPipeline
+
+ model_name_or_path = _path("Qwen/Qwen-Image-Layered", args=args)
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=model_name_or_path,
+ pipeline_class=QwenImageLayeredPipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ extra_optimize_kwargs={
+ "enable_separate_cfg": False, # negative prompt is not used in example
+ },
+ ),
+ input_data=ExampleInputData(
+ image=load_image("./data/yarn-art-pikachu.png").convert("RGBA"),
+ prompt="",
+ num_inference_steps=50,
+ true_cfg_scale=4.0,
+ layers=4,
+ resolution=640,
+ cfg_normalize=False,
+ use_en_prompt=True,
+ ),
+ )
+
+
+@ExampleRegister.register("skyreels_v2", default="Skywork/SkyReels-V2-T2V-14B-720P-Diffusers")
+def skyreels_v2_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import AutoModel, SkyReelsV2Pipeline, UniPCMultistepScheduler
+
+ model_name_or_path = _path(
+ "Skywork/SkyReels-V2-T2V-14B-720P-Diffusers",
+ args=args,
+ )
+ vae = AutoModel.from_pretrained(
+ model_name_or_path if args.model_path is None else args.model_path,
+ subfolder="vae",
+ torch_dtype=torch.float32,
+ ) # Use float32 VAE to reduce video generation artifacts
+
+ def post_init_hook(pipe: SkyReelsV2Pipeline, **kwargs):
+ flow_shift = 8.0 # 8.0 for T2V, 5.0 for I2V
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
+ pipe.scheduler.config, flow_shift=flow_shift
+ )
+ logger.info(
+ f"Set UniPCMultistepScheduler with flow_shift={flow_shift} "
+ f"for {pipe.__class__.__name__}."
+ )
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2V, # Text to Video
+ model_name_or_path=model_name_or_path,
+ pipeline_class=SkyReelsV2Pipeline,
+ vae=vae,
+ post_init_hook=post_init_hook,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "A cat and a dog baking a cake together in a kitchen. The cat is "
+ "carefully measuring flour, while the dog is stirring the batter "
+ "with a wooden spoon. The kitchen is cozy, with sunlight streaming "
+ "through the window."
+ ),
+ height=720,
+ width=1280,
+ num_frames=21,
+ num_inference_steps=50,
+ ),
+ )
+
+
+@ExampleRegister.register("ltx2_t2v", default="Lightricks/LTX-2")
+def ltx2_t2v_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import LTX2Pipeline
+
+ model_name_or_path = _path(
+ "Lightricks/LTX-2",
+ args=args,
+ )
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2V, # Text to Video
+ model_name_or_path=model_name_or_path,
+ pipeline_class=LTX2Pipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "A cinematic tracking shot through a neon-lit rainy cyberpunk street at night. "
+ "Reflections shimmer on wet asphalt, holographic signs flicker, and steam rises from vents. "
+ "Smooth camera motion, natural parallax, ultra-realistic detail, cinematic lighting."
+ ),
+ negative_prompt=(
+ "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion artifacts, "
+ "bad anatomy, ugly, transition, static, text, watermark"
+ ),
+ height=512,
+ width=768,
+ num_frames=121,
+ num_inference_steps=40,
+ guidance_scale=4.0,
+ extra_input_kwargs={
+ "frame_rate": 24.0,
+ },
+ ),
+ )
+
+
+@ExampleRegister.register("ltx2_i2v", default="Lightricks/LTX-2")
+def ltx2_i2v_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import LTX2ImageToVideoPipeline
+
+ model_name_or_path = _path(
+ "Lightricks/LTX-2",
+ args=args,
+ )
+
+ height = 512 if args.height is None else args.height
+ width = 768 if args.width is None else args.width
+ if args.image_path is not None:
+ image = load_image(args.image_path)
+ else:
+ image = load_image(
+ "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
+ )
+ image = image.resize((width, height))
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.I2V, # Image to Video
+ model_name_or_path=model_name_or_path,
+ pipeline_class=LTX2ImageToVideoPipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "A young girl stands calmly in the foreground, looking directly at the camera, "
+ "as a house fire rages in the background."
+ ),
+ negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
+ image=image,
+ height=height,
+ width=width,
+ num_frames=121,
+ num_inference_steps=40,
+ guidance_scale=4.0,
+ extra_input_kwargs={
+ "frame_rate": 24.0,
+ },
+ ),
+ )
+
+
+def _wan_2_2_params_modifiers(args: argparse.Namespace) -> List[ParamsModifier]:
+ if not args.cache:
+ return None
+ return [
+ ParamsModifier(
+ # high-noise transformer only have 30% steps
+ cache_config=DBCacheConfig().reset(
+ max_warmup_steps=4,
+ max_cached_steps=8,
+ ),
+ ),
+ ParamsModifier(
+ cache_config=DBCacheConfig().reset(
+ max_warmup_steps=2,
+ max_cached_steps=20,
+ ),
+ ),
+ ]
+
+
+@ExampleRegister.register("wan2.1_t2v", default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
+@ExampleRegister.register("wan2.2_t2v", default="Wan-AI/Wan2.2-T2V-A14B-Diffusers")
+def wan_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import WanPipeline
+
+ if "wan2.2" in args.example.lower():
+ model_name_or_path = _path(
+ "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
+ args=args,
+ )
+ else:
+ model_name_or_path = _path(
+ "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
+ args=args,
+ )
+
+ if "wan2.2" in args.example.lower():
+ params_modifiers = _wan_2_2_params_modifiers(args)
+ else:
+ params_modifiers = None
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2V, # Text to Video
+ model_name_or_path=model_name_or_path,
+ pipeline_class=WanPipeline,
+ bnb_4bit_components=(
+ ["text_encoder", "transformer", "transformer_2"]
+ if "wan2.2" in args.example.lower()
+ else ["text_encoder", "transformer"]
+ ),
+ extra_optimize_kwargs={
+ "params_modifiers": params_modifiers,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt="A cat walks on the grass, realistic",
+ negative_prompt=(
+ "Bright tones, overexposed, static, blurred details, subtitles, "
+ "style, works, paintings, images, static, overall gray, worst quality, "
+ "low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
+ "poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen "
+ "limbs, fused fingers, still picture, messy background, three legs, many "
+ "people in the background, walking backwards"
+ ),
+ height=480,
+ width=832,
+ num_frames=49,
+ guidance_scale=5.0,
+ num_inference_steps=30,
+ ),
+ )
+
+
+@ExampleRegister.register("wan2.1_i2v", default="Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
+@ExampleRegister.register("wan2.2_i2v", default="Wan-AI/Wan2.2-I2V-A14B-Diffusers")
+def wan_i2v_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import WanImageToVideoPipeline
+
+ if "wan2.2" in args.example.lower():
+ model_name_or_path = _path(
+ "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
+ args=args,
+ )
+ else:
+ model_name_or_path = _path(
+ "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers",
+ args=args,
+ )
+
+ if "wan2.2" in args.example.lower():
+ params_modifiers = _wan_2_2_params_modifiers(args)
+ else:
+ params_modifiers = None
+
+ if args.image_path is not None:
+ image = load_image(args.image_path).convert("RGB")
+ else:
+ image = load_image(
+ "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
+ )
+
+ max_area = 480 * 832
+ aspect_ratio = image.height / image.width
+ vae_scale_factor_spatial = 8 # for Wan VAE
+ patch_size = 2 # for Wan transformer, [1, 2, 2]
+ mod_value = vae_scale_factor_spatial * patch_size
+ height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+ width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+ image = image.resize((width, height))
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.I2V, # Image to Video
+ model_name_or_path=model_name_or_path,
+ pipeline_class=WanImageToVideoPipeline,
+ bnb_4bit_components=(
+ ["text_encoder", "transformer", "transformer_2"]
+ if "wan2.2" in args.example.lower()
+ else ["text_encoder", "transformer"]
+ ),
+ extra_optimize_kwargs={
+ "params_modifiers": params_modifiers,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a "
+ "surfboard. The fluffy-furred feline gazes directly at the camera with "
+ "a relaxed expression. Blurred beach scenery forms the background featuring "
+ "crystal-clear waters, distant green hills, and a blue sky dotted with white "
+ "clouds. The cat assumes a naturally relaxed posture, as if savoring the sea "
+ "breeze and warm sunlight. A close-up shot highlights the feline's intricate "
+ "details and the refreshing atmosphere of the seaside."
+ ),
+ negative_prompt=(
+ "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,"
+ "低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,"
+ "毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
+ ),
+ image=image,
+ height=height,
+ width=width,
+ num_frames=49,
+ guidance_scale=3.5,
+ num_inference_steps=50,
+ ),
+ )
+
+
+@ExampleRegister.register("wan2.1_vace", default="Wan-AI/Wan2.1-VACE-1.3B-diffusers")
+@ExampleRegister.register("wan2.2_vace", default="linoyts/Wan2.2-VACE-Fun-14B-diffusers")
+def wan_vace_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import WanVACEPipeline, AutoencoderKLWan, UniPCMultistepScheduler
+
+ if "wan2.2" in args.example.lower():
+ model_name_or_path = _path(
+ "linoyts/Wan2.2-VACE-Fun-14B-diffusers",
+ args=args,
+ )
+ else:
+ model_name_or_path = _path(
+ "Wan-AI/Wan2.1-VACE-1.3B-diffusers",
+ args=args,
+ )
+
+ vae = AutoencoderKLWan.from_pretrained(
+ model_name_or_path,
+ subfolder="vae",
+ torch_dtype=torch.float32,
+ )
+
+ def post_init_hook(pipe: WanVACEPipeline, **kwargs):
+ flow_shift = 5.0 # 5.0 for 720P, 3.0 for 480P
+ pipe.scheduler = UniPCMultistepScheduler.from_config(
+ pipe.scheduler.config,
+ flow_shift=flow_shift,
+ )
+ logger.info(
+ f"Set UniPCMultistepScheduler with flow_shift={flow_shift} "
+ f"for {pipe.__class__.__name__}."
+ )
+
+ if "wan2.2" in args.example.lower():
+ params_modifiers = _wan_2_2_params_modifiers(args)
+ else:
+ params_modifiers = None
+
+ def _video_and_mask(
+ first_img: PIL.Image.Image,
+ last_img: PIL.Image.Image,
+ height: int,
+ width: int,
+ num_frames: int,
+ ) -> Tuple[List[PIL.Image.Image], List[PIL.Image.Image]]:
+ first_img = first_img.resize((width, height))
+ last_img = last_img.resize((width, height))
+ frames = []
+ frames.append(first_img)
+ # Ideally, this should be 127.5 to match original code, but they perform
+ # computation on numpy arrays whereas we are passing PIL images. If you
+ # choose to pass numpy arrays, you can set it to 127.5 to match the original code.
+ frames.extend([PIL.Image.new("RGB", (width, height), (128, 128, 128))] * (num_frames - 2))
+ frames.append(last_img)
+ mask_black = PIL.Image.new("L", (width, height), 0)
+ mask_white = PIL.Image.new("L", (width, height), 255)
+ mask = [mask_black, *[mask_white] * (num_frames - 2), mask_black]
+ return frames, mask
+
+ first_frame = load_image("./data/flf2v_input_first_frame.png")
+ last_frame = load_image("./data/flf2v_input_last_frame.png")
+
+ height = 512 if args.height is None else args.height
+ width = 512 if args.width is None else args.width
+ num_frames = 81 if args.num_frames is None else args.num_frames
+ video, mask = _video_and_mask(first_frame, last_frame, height, width, num_frames)
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.VACE, # Video All-in-one Creation and Editing
+ model_name_or_path=model_name_or_path,
+ pipeline_class=WanVACEPipeline,
+ vae=vae,
+ post_init_hook=post_init_hook,
+ bnb_4bit_components=(
+ ["text_encoder", "transformer", "transformer_2"]
+ if "wan2.2" in args.example.lower()
+ else ["text_encoder", "transformer"]
+ ),
+ extra_optimize_kwargs={
+ "params_modifiers": params_modifiers,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "CG animation style, a small blue bird takes off from the ground, "
+ "flapping its wings. The bird's feathers are delicate, with a unique "
+ "pattern on its chest. The background shows a blue sky with white "
+ "clouds under bright sunshine. The camera follows the bird upward, "
+ "capturing its flight and the vastness of the sky from a close-up, "
+ "low-angle perspective."
+ ),
+ negative_prompt=(
+ "Bright tones, overexposed, static, blurred details, subtitles, "
+ "style, works, paintings, images, static, overall gray, worst "
+ "quality, low quality, JPEG compression residue, ugly, incomplete, "
+ "extra fingers, poorly drawn hands, poorly drawn faces, deformed, "
+ "disfigured, misshapen limbs, fused fingers, still picture, messy "
+ "background, three legs, many people in the background, walking "
+ "backwards"
+ ),
+ video=video,
+ mask=mask,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ guidance_scale=5.0,
+ num_inference_steps=30,
+ ),
+ )
+
+
+@ExampleRegister.register("ovis_image", default="AIDC-AI/Ovis-Image-7B")
+def ovis_image_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import OvisImagePipeline
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("AIDC-AI/Ovis-Image-7B"),
+ pipeline_class=OvisImagePipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ 'A creative 3D artistic render where the text "OVIS-IMAGE" is written in a bold, '
+ "expressive handwritten brush style using thick, wet oil paint. The paint is a mix "
+ "of vibrant rainbow colors (red, blue, yellow) swirling together like toothpaste "
+ "or impasto art. You can see the ridges of the brush bristles and the glossy, wet "
+ "texture of the paint. The background is a clean artist's canvas. Dynamic lighting "
+ "creates soft shadows behind the floating paint strokes. Colorful, expressive, tactile "
+ "texture, 4k detail."
+ ),
+ height=1024,
+ width=1024,
+ num_inference_steps=25,
+ guidance_scale=5.0, # has separate cfg for ovis image
+ ),
+ )
+
+
+def _zimage_turbo_steps_mask(
+ args: argparse.Namespace,
+) -> Optional[List[int]]:
+ if not args.cache:
+ return None
+ return (
+ cache_dit.steps_mask(
+ # slow, medium, fast, ultra.
+ mask_policy=args.mask_policy,
+ total_steps=9 if args.num_inference_steps is None else args.num_inference_steps,
+ )
+ if args.mask_policy is not None
+ else (
+ cache_dit.steps_mask(
+ compute_bins=[5, 1, 1], # = 7 (compute steps)
+ cache_bins=[1, 1], # = 2 (dynamic cache steps)
+ )
+ if args.steps_mask
+ else None
+ )
+ )
+
+
+@ExampleRegister.register("zimage", default="Tongyi-MAI/Z-Image-Turbo")
+@ExampleRegister.register("zimage_nunchaku", default="nunchaku-tech/nunchaku-z-image-turbo")
+def zimage_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import ZImagePipeline
+
+ if args.cache:
+ # Only warmup 4 steps (total 9 steps) for distilled models
+ args.max_warmup_steps = min(4, args.max_warmup_steps)
+
+ if "nunchaku" in args.example.lower():
+ from nunchaku import NunchakuZImageTransformer2DModel
+
+ nunchaku_zimage_dir = _path(
+ "nunchaku-tech/nunchaku-z-image-turbo",
+ args=args,
+ transformer=True,
+ )
+ transformer = NunchakuZImageTransformer2DModel.from_pretrained(
+ f"{nunchaku_zimage_dir}/svdq-int4_r128-z-image-turbo.safetensors"
+ )
+ else:
+ transformer = None
+
+ steps_computation_mask = _zimage_turbo_steps_mask(args)
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("Tongyi-MAI/Z-Image-Turbo"),
+ pipeline_class=ZImagePipeline,
+ transformer=transformer, # maybe use Nunchaku zimage transformer
+ bnb_4bit_components=["text_encoder"],
+ extra_optimize_kwargs={
+ "steps_computation_mask": steps_computation_mask,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, "
+ "red floral forehead pattern. Elaborate high bun, golden phoenix headdress, "
+ "red flowers, beads. Holds round folding fan with lady, trees, bird. Neon "
+ "lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. "
+ "Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), "
+ "blurred colorful distant lights."
+ ),
+ height=1024,
+ width=1024,
+ guidance_scale=0.0, # Guidance should be 0 for the Turbo models
+ num_inference_steps=9,
+ ),
+ )
+
+
+@ExampleRegister.register(
+ "zimage_controlnet_2.1", default="alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1"
+)
+@ExampleRegister.register(
+ "zimage_controlnet_2.0", default="alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0"
+)
+def zimage_controlnet_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import ZImageControlNetPipeline, ZImageControlNetModel
+
+ if args.cache:
+ # Only warmup 4 steps (total 9 steps) for distilled models
+ args.max_warmup_steps = min(4, args.max_warmup_steps)
+
+ if "2.0" in args.example.lower():
+ controlnet_dir = _path(
+ "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
+ args=args,
+ controlnet=True,
+ )
+ controlnet_path = os.path.join(
+ controlnet_dir, "Z-Image-Turbo-Fun-Controlnet-Union-2.0.safetensors"
+ )
+ controlnet = ZImageControlNetModel.from_single_file(
+ controlnet_path,
+ torch_dtype=torch.bfloat16,
+ config="hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0",
+ )
+ else:
+ controlnet_dir = _path(
+ "alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1",
+ args=args,
+ controlnet=True,
+ )
+ controlnet_path = os.path.join(
+ controlnet_dir, "Z-Image-Turbo-Fun-Controlnet-Union-2.1.safetensors"
+ )
+ controlnet = ZImageControlNetModel.from_single_file(
+ controlnet_path,
+ torch_dtype=torch.bfloat16,
+ )
+
+ control_image = load_image("./data/pose.jpg")
+ steps_computation_mask = _zimage_turbo_steps_mask(args)
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("Tongyi-MAI/Z-Image-Turbo"),
+ pipeline_class=ZImageControlNetPipeline,
+ controlnet=controlnet,
+ bnb_4bit_components=["text_encoder"],
+ extra_optimize_kwargs={
+ "steps_computation_mask": steps_computation_mask,
+ },
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,"
+ "发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,透着一股甜美的青春气息;"
+ "神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,仿佛沉浸在思绪之中。在她身后,"
+ "是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕。"
+ ),
+ control_image=control_image,
+ controlnet_conditioning_scale=0.75,
+ height=1728,
+ width=992,
+ num_inference_steps=9,
+ guidance_scale=0.0,
+ ),
+ )
+
+
+@ExampleRegister.register("longcat_image", default="meituan-longcat/LongCat-Image")
+def longcat_image_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import LongCatImagePipeline
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.T2I, # Text to Image
+ model_name_or_path=_path("meituan-longcat/LongCat-Image"),
+ pipeline_class=LongCatImagePipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=(
+ "A young Asian woman wearing a yellow knit sweater paired with a white necklace. "
+ "Her hands rest on her knees, with a serene expression. The background features a "
+ "rough brick wall, with warm afternoon sunlight casting upon her, creating a tranquil "
+ "and cozy atmosphere. The shot uses a medium-distance perspective, highlighting her "
+ "demeanor and the details of her attire. Soft lighting illuminates her face, emphasizing "
+ "her facial features and the texture of her accessories, adding depth and warmth to the image. "
+ "The overall composition is simple and elegant, with the brick wall's texture complementing "
+ "the interplay of sunlight and shadows, showcasing the character's grace and composure."
+ ),
+ height=1024,
+ width=1024,
+ num_inference_steps=50,
+ guidance_scale=4.5,
+ ),
+ )
+
+
+@ExampleRegister.register("longcat_image_edit", default="meituan-longcat/LongCat-Image-Edit")
+def longcat_image_edit_example(args: argparse.Namespace, **kwargs) -> Example:
+ from diffusers import LongCatImageEditPipeline
+
+ if args.image_path is not None:
+ image = load_image(args.image_path).convert("RGB")
+ else:
+ image_url = (
+ "https://huggingface.co/meituan-longcat/LongCat-Image-Edit/resolve/main/assets/test.png"
+ )
+ image = load_image(image_url).convert("RGB")
+
+ return Example(
+ args=args,
+ init_config=ExampleInitConfig(
+ task_type=ExampleType.IE2I, # Image Editing to Image
+ model_name_or_path=_path("meituan-longcat/LongCat-Image-Edit"),
+ pipeline_class=LongCatImageEditPipeline,
+ bnb_4bit_components=["text_encoder", "transformer"],
+ ),
+ input_data=ExampleInputData(
+ prompt=("Turn the cat into a dog"),
+ negative_prompt="",
+ num_inference_steps=50,
+ guidance_scale=4.5,
+ image=image,
+ ),
+ )
diff --git a/examples/requirements.txt b/examples/requirements.txt
index 489c49513..56034930e 100644
--- a/examples/requirements.txt
+++ b/examples/requirements.txt
@@ -1,5 +1,2 @@
imageio-ffmpeg
-# wan currently requires installing from source
-diffusers>=0.35.1
-torchao>=0.12.0
ftfy
diff --git a/examples/utils.py b/examples/utils.py
index c68636516..31a60edcf 100644
--- a/examples/utils.py
+++ b/examples/utils.py
@@ -1,11 +1,22 @@
-import argparse
-
import torch
+import argparse
import torch.distributed as dist
+from diffusers import DiffusionPipeline
+from typing import Optional, List, Tuple
+from diffusers.quantizers import PipelineQuantizationConfig
import cache_dit
from cache_dit import init_logger
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
+from cache_dit.quantize.utils import normalize_quantize_type
+from cache_dit import (
+ BlockAdapter,
+ DBCacheConfig,
+ ParallelismBackend,
+ ParallelismConfig,
+ TaylorSeerCalibratorConfig,
+)
+
+from cache_dit.platforms import current_platform
logger = init_logger(__name__)
@@ -14,20 +25,20 @@ class MemoryTracker:
"""Track peak GPU memory usage during execution."""
def __init__(self, device=None):
- self.device = device if device is not None else torch.cuda.current_device()
- self.enabled = torch.cuda.is_available()
+ self.device = device if device is not None else current_platform.current_device()
+ self.enabled = current_platform.is_accelerator_available()
self.peak_memory = 0
def __enter__(self):
if self.enabled:
- torch.cuda.reset_peak_memory_stats(self.device)
- torch.cuda.synchronize(self.device)
+ current_platform.reset_peak_memory_stats(self.device)
+ current_platform.synchronize(self.device)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.enabled:
- torch.cuda.synchronize(self.device)
- self.peak_memory = torch.cuda.max_memory_allocated(self.device)
+ current_platform.synchronize(self.device)
+ self.peak_memory = current_platform.max_memory_allocated(self.device)
def get_peak_memory_gb(self):
"""Get peak memory in GB."""
@@ -44,10 +55,10 @@ def report(self):
def GiB():
try:
- if not torch.cuda.is_available():
+ if not current_platform.is_accelerator_available():
return 0
- total_memory_bytes = torch.cuda.get_device_properties(
- torch.cuda.current_device(),
+ total_memory_bytes = current_platform.get_device_properties(
+ current_platform.current_device(),
).total_memory
total_memory_gib = total_memory_bytes / (1024**3)
return int(total_memory_gib)
@@ -59,37 +70,306 @@ def get_args(
parse: bool = True,
) -> argparse.ArgumentParser | argparse.Namespace:
parser = argparse.ArgumentParser()
- parser.add_argument("--cache", action="store_true", default=False)
- parser.add_argument("--compile", action="store_true", default=False)
- parser.add_argument("--fuse-lora", action="store_true", default=False)
- parser.add_argument("--steps", type=int, default=None)
- parser.add_argument("--Fn", type=int, default=8)
- parser.add_argument("--Bn", type=int, default=0)
- parser.add_argument("--rdt", type=float, default=0.08)
- parser.add_argument("--max-warmup-steps", "--w", type=int, default=8)
- parser.add_argument("--warmup-interval", "--wi", type=int, default=1)
- parser.add_argument("--max-cached-steps", "--mc", type=int, default=-1)
- parser.add_argument("--max-continuous-cached-steps", "--mcc", type=int, default=-1)
- parser.add_argument("--taylorseer", action="store_true", default=False)
- parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
- parser.add_argument("--height", type=int, default=None)
- parser.add_argument("--width", type=int, default=None)
- parser.add_argument("--quantize", "-q", action="store_true", default=False)
+ # Model and data paths
+ parser.add_argument(
+ "--model-path",
+ type=str,
+ default=None,
+ help="Override model path if provided",
+ )
+ parser.add_argument(
+ "--controlnet-path",
+ type=str,
+ default=None,
+ help="Override controlnet model path if provided",
+ )
+ parser.add_argument(
+ "--lora-path",
+ type=str,
+ default=None,
+ help="Override lora model path if provided",
+ )
+ parser.add_argument(
+ "--transformer-path",
+ type=str,
+ default=None,
+ help="Override transformer model path if provided",
+ )
+ parser.add_argument(
+ "--image-path",
+ type=str,
+ default=None,
+ help="Override image path if provided",
+ )
+ parser.add_argument(
+ "--mask-image-path",
+ type=str,
+ default=None,
+ help="Override mask image path if provided",
+ )
+ # Acceleration Config path
+ parser.add_argument(
+ "--config-path",
+ "--config",
+ type=str,
+ default=None,
+ help="Path to CacheDiT configuration YAML file",
+ )
+ # Sampling settings
+ parser.add_argument(
+ "--prompt",
+ type=str,
+ default=None,
+ help="Override default prompt if provided",
+ )
+ parser.add_argument(
+ "--negative-prompt",
+ type=str,
+ default=None,
+ help="Override default negative prompt if provided",
+ )
+ parser.add_argument(
+ "--num_inference_steps",
+ "--steps",
+ type=int,
+ default=None,
+ help="Number of inference steps",
+ )
+ parser.add_argument(
+ "--warmup",
+ type=int,
+ default=1,
+ help="Number of warmup steps before measuring performance",
+ )
+ parser.add_argument(
+ "--warmup-num-inference-steps",
+ "--warmup-steps",
+ type=int,
+ default=None,
+ help="Number of warmup inference steps per warmup before measuring performance",
+ )
+ parser.add_argument(
+ "--repeat",
+ type=int,
+ default=1,
+ help="Number of times to repeat the inference for performance measurement",
+ )
+ parser.add_argument(
+ "--height",
+ type=int,
+ default=None,
+ help="Height of the generated image",
+ )
+ parser.add_argument(
+ "--width",
+ type=int,
+ default=None,
+ help="Width of the generated image",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=None,
+ help="Random seed for reproducibility",
+ )
+ parser.add_argument(
+ "--num-frames",
+ "--frames",
+ type=int,
+ default=None,
+ help="Number of frames to generate for video",
+ )
+ # Output settings
+ parser.add_argument(
+ "--save-path",
+ type=str,
+ default=None,
+ help="Path to save the generated output, e.g., output.png or output.mp4",
+ )
+ # Cache specific settings
+ parser.add_argument(
+ "--cache",
+ action="store_true",
+ default=False,
+ help="Enable Cache Acceleration",
+ )
+ parser.add_argument(
+ "--cache-summary",
+ "--summary",
+ action="store_true",
+ default=False,
+ help="Enable Cache Summary logging",
+ )
+ parser.add_argument(
+ "--Fn-compute-blocks",
+ "--Fn",
+ type=int,
+ default=1,
+ help="CacheDiT Fn_compute_blocks parameter",
+ )
+ parser.add_argument(
+ "--Bn-compute-blocks",
+ "--Bn",
+ type=int,
+ default=0,
+ help="CacheDiT Bn_compute_blocks parameter",
+ )
+ parser.add_argument(
+ "--residual-diff-threshold",
+ "--rdt",
+ type=float,
+ default=0.24,
+ help="CacheDiT residual diff threshold",
+ )
+ parser.add_argument(
+ "--max-warmup-steps",
+ "--ws",
+ type=int,
+ default=8,
+ help="Maximum warmup steps for CacheDiT",
+ )
+ parser.add_argument(
+ "--warmup-interval",
+ "--wi",
+ type=int,
+ default=1,
+ help="Warmup interval for CacheDiT",
+ )
+ parser.add_argument(
+ "--max-cached-steps",
+ "--mc",
+ type=int,
+ default=-1,
+ help="Maximum cached steps for CacheDiT",
+ )
+ parser.add_argument(
+ "--max-continuous-cached-steps",
+ "--mcc",
+ type=int,
+ default=3,
+ help="Maximum continuous cached steps for CacheDiT",
+ )
+ parser.add_argument(
+ "--taylorseer",
+ action="store_true",
+ default=False,
+ help="Enable TaylorSeer for CacheDiT",
+ )
+ parser.add_argument(
+ "--taylorseer-order",
+ "-order",
+ type=int,
+ default=1,
+ help="TaylorSeer order",
+ )
+ parser.add_argument(
+ "--steps-mask",
+ action="store_true",
+ default=False,
+ help="Enable steps mask for CacheDiT",
+ )
+ parser.add_argument(
+ "--mask-policy",
+ "--scm",
+ type=str,
+ default=None,
+ choices=[
+ None,
+ "slow",
+ "s",
+ "medium",
+ "m",
+ "fast",
+ "f",
+ "ultra",
+ "u",
+ ],
+ help="Pre-defined steps computation mask policy",
+ )
+ # Quantization settings
+ parser.add_argument(
+ "--quantize",
+ "--q",
+ action="store_true",
+ default=False,
+ help="Enable quantization for transformer",
+ )
# float8, float8_weight_only, int8, int8_weight_only, int4, int4_weight_only
parser.add_argument(
"--quantize-type",
+ "--q-type",
type=str,
- default="float8_weight_only",
+ default=None,
choices=[
+ None,
"float8",
"float8_weight_only",
+ "float8_wo", # alias for float8_weight_only
"int8",
"int8_weight_only",
+ "int8_wo", # alias for int8_weight_only
"int4",
"int4_weight_only",
+ "int4_wo", # alias for int4_weight_only
"bitsandbytes_4bit",
+ "bnb_4bit", # alias for bitsandbytes_4bit
],
)
+ parser.add_argument(
+ "--quantize-text-encoder",
+ "--q-text",
+ action="store_true",
+ default=False,
+ help="Enable quantization for text encoder",
+ )
+ parser.add_argument(
+ "--quantize-text-type",
+ "--q-text-type",
+ type=str,
+ default=None,
+ choices=[
+ None,
+ "float8",
+ "float8_weight_only",
+ "float8_wo", # alias for float8_weight_only
+ "int8",
+ "int8_weight_only",
+ "int8_wo", # alias for int8_weight_only
+ "int4",
+ "int4_weight_only",
+ "int4_wo", # alias for int4_weight_only
+ "bitsandbytes_4bit",
+ "bnb_4bit", # alias for bitsandbytes_4bit
+ ],
+ )
+ parser.add_argument(
+ "--quantize-controlnet",
+ "--q-controlnet",
+ action="store_true",
+ default=False,
+ help="Enable quantization for text encoder",
+ )
+ parser.add_argument(
+ "--quantize-controlnet-type",
+ "--q-controlnet-type",
+ type=str,
+ default=None,
+ choices=[
+ None,
+ "float8",
+ "float8_weight_only",
+ "float8_wo", # alias for float8_weight_only
+ "int8",
+ "int8_weight_only",
+ "int8_wo", # alias for int8_weight_only
+ "int4",
+ "int4_weight_only",
+ "int4_wo", # alias for int4_weight_only
+ "bitsandbytes_4bit",
+ "bnb_4bit", # alias for bitsandbytes_4bit
+ ],
+ )
+ # Parallelism settings
parser.add_argument(
"--parallel-type",
"--parallel",
@@ -102,6 +382,25 @@ def get_args(
"ring",
],
)
+ parser.add_argument(
+ "--parallel-vae",
+ action="store_true",
+ default=False,
+ help="Enable VAE parallelism if applicable.",
+ )
+ parser.add_argument(
+ "--parallel-text-encoder",
+ "--parallel-text",
+ action="store_true",
+ default=False,
+ help="Enable text encoder parallelism if applicable.",
+ )
+ parser.add_argument(
+ "--parallel-controlnet",
+ action="store_true",
+ default=False,
+ help="Enable ControlNet parallelism if applicable.",
+ )
parser.add_argument(
"--attn", # attention backend for context parallelism
type=str,
@@ -109,19 +408,111 @@ def get_args(
choices=[
None,
"flash",
+ "_flash_3", # FlashAttention-3
# Based on this fix: https://github.com/huggingface/diffusers/pull/12563
"native", # native pytorch attention: sdpa
"_native_cudnn",
+ # '_sdpa_cudnn' is only in cache-dit to support context parallelism
+ # with attn masks, e.g., ZImage. It is not in diffusers yet.
+ "_sdpa_cudnn",
"sage", # Need install sageattention: https://github.com/thu-ml/SageAttention
+ "_native_npu", # native npu attention
],
)
- parser.add_argument("--perf", action="store_true", default=False)
- # New arguments for customization
- parser.add_argument("--prompt", type=str, default=None, help="Override default prompt")
parser.add_argument(
- "--negative-prompt", type=str, default=None, help="Override default negative prompt"
+ "--ulysses-anything",
+ "--uaa",
+ action="store_true",
+ default=False,
+ help="Enable Ulysses Anything Attention for context parallelism",
+ )
+ parser.add_argument(
+ "--ulysses-float8",
+ "--ufp8",
+ action="store_true",
+ default=False,
+ help="Enable Ulysses Attention/UAA Float8 for context parallelism",
+ )
+ parser.add_argument(
+ "--ulysses-async",
+ "--uaqkv",
+ action="store_true",
+ default=False,
+ help="Enabled experimental Async QKV Projection with Ulysses for context parallelism",
+ )
+ # Offload settings
+ parser.add_argument(
+ "--cpu-offload",
+ "--cpu-offload-model",
+ action="store_true",
+ default=False,
+ help="Enable CPU offload for model if applicable.",
+ )
+ parser.add_argument(
+ "--sequential-cpu-offload",
+ action="store_true",
+ default=False,
+ help="Enable sequential GPU offload for model if applicable.",
+ )
+ parser.add_argument(
+ "--device-map-balance",
+ "--device-map",
+ action="store_true",
+ default=False,
+ help="Enable automatic device map balancing model if multiple GPUs are available.",
+ )
+ # Vae tiling/slicing settings
+ parser.add_argument(
+ "--vae-tiling",
+ action="store_true",
+ default=False,
+ help="Enable VAE tiling for low memory device.",
)
- parser.add_argument("--model-path", type=str, default=None, help="Override model path")
+ parser.add_argument(
+ "--vae-slicing",
+ action="store_true",
+ default=False,
+ help="Enable VAE slicing for low memory device.",
+ )
+ # Compiling settings
+ parser.add_argument(
+ "--compile",
+ action="store_true",
+ default=False,
+ help="Enable compile for transformer",
+ )
+ parser.add_argument(
+ "--compile-repeated-blocks",
+ action="store_true",
+ default=False,
+ help="Enable compile for repeated blocks in transformer",
+ )
+ parser.add_argument(
+ "--compile-vae",
+ action="store_true",
+ default=False,
+ help="Enable compile for VAE",
+ )
+ parser.add_argument(
+ "--compile-text-encoder",
+ "--compile-text",
+ action="store_true",
+ default=False,
+ help="Enable compile for text encoder",
+ )
+ parser.add_argument(
+ "--compile-controlnet",
+ action="store_true",
+ default=False,
+ help="Enable compile for ControlNet",
+ )
+ parser.add_argument(
+ "--max-autotune",
+ action="store_true",
+ default=False,
+ help="Enable max-autotune mode for torch.compile",
+ )
+ # Profiling and memory tracking settings
parser.add_argument(
"--track-memory",
action="store_true",
@@ -129,131 +520,965 @@ def get_args(
help="Track and report peak GPU memory usage",
)
parser.add_argument(
- "--ulysses-anything",
- "--uaa",
+ "--profile",
action="store_true",
default=False,
- help="Enable Ulysses Anything Attention for context parallelism",
+ help="Enable profiling with torch.profiler",
)
parser.add_argument(
- "--disable-compute-comm-overlap",
- "--dcco",
+ "--profile-name",
+ type=str,
+ default=None,
+ help="Name for the profiling session",
+ )
+ parser.add_argument(
+ "--profile-dir",
+ type=str,
+ default=None,
+ help="Directory to save profiling results",
+ )
+ parser.add_argument(
+ "--profile-activities",
+ type=str,
+ nargs="+",
+ default=["CPU", "GPU"],
+ choices=["CPU", "GPU", "MEM"],
+ help="Activities to profile (CPU, GPU, MEM)",
+ )
+ parser.add_argument(
+ "--profile-with-stack",
action="store_true",
- default=False,
- help="Disable compute-communication overlap during compilation",
+ default=True,
+ help="profile with stack for better traceability",
)
- return parser.parse_args() if parse else parser
+ parser.add_argument(
+ "--profile-record-shapes",
+ action="store_true",
+ default=True,
+ help="profile record shapes for better analysis",
+ )
+ # Lora settings
+ parser.add_argument(
+ "--disable-fuse-lora",
+        action="store_true",
+        default=False,
+ help="Disable fuse_lora even if lora weights are provided.",
+ )
+ # Generator device
+ parser.add_argument(
+ "--generator-device",
+ "--gen-device",
+ type=str,
+ default=None,
+ help="Device for torch.Generator, e.g., 'cuda' or 'cpu'. "
+ "If not set, use 'cpu' for better reproducibility across "
+ "different hardware.",
+ )
+
+ args_or_parser = parser.parse_args() if parse else parser
+ if parse:
+ return maybe_postprocess_args(args_or_parser)
+ return args_or_parser
+
+
+def get_base_args(parse: bool = True) -> argparse.Namespace | argparse.ArgumentParser:
+ return get_args(parse=parse) # For future extension if needed
+
+
+def maybe_postprocess_args(args: argparse.Namespace) -> argparse.Namespace:
+ # Force enable quantization if quantize_type is specified
+ if args.quantize_type is not None:
+ args.quantize = True
+
+ # Handle alias for quantize_type
+ if args.quantize and args.quantize_type is None:
+ args.quantize_type = "float8_weight_only" # default type
+
+ args.quantize_type = normalize_quantize_type(args.quantize_type)
+
+ # Force enable quantization for text encoder if quantize_text_type is specified
+ if args.quantize_text_type is not None:
+ args.quantize_text_encoder = True
+ # Handle alias for quantize_text_type
+ if args.quantize_text_encoder and args.quantize_text_type is None:
+ # default to same as quantize_type
+ args.quantize_text_type = args.quantize_type
+
+ args.quantize_text_type = normalize_quantize_type(args.quantize_text_type)
+
+ # Force enable quantization for controlnet if quantize_controlnet_type is specified
+ if args.quantize_controlnet_type is not None:
+ args.quantize_controlnet = True
+ # Handle alias for quantize_controlnet_type
+ if args.quantize_controlnet and args.quantize_controlnet_type is None:
+ # default to same as quantize_type
+ args.quantize_controlnet_type = args.quantize_type
+
+ args.quantize_controlnet_type = normalize_quantize_type(args.quantize_controlnet_type)
+
+ if args.mask_policy is not None and not args.steps_mask:
+ # Enable steps mask if mask_policy is specified
+ args.steps_mask = True
+ # Handle alias for mask_policy
+ if args.mask_policy == "s": # alias
+ args.mask_policy = "slow"
+ if args.mask_policy == "m": # alias
+ args.mask_policy = "medium"
+ if args.mask_policy == "f": # alias
+ args.mask_policy = "fast"
+ if args.mask_policy == "u": # alias
+ args.mask_policy = "ultra"
+ return args
+
+
+def get_text_encoder_from_pipe(
+ pipe: DiffusionPipeline,
+) -> Tuple[Optional[torch.nn.Module], Optional[str]]:
+ pipe_cls_name = pipe.__class__.__name__
+ if (
+ hasattr(pipe, "text_encoder_2")
+ and not pipe_cls_name.startswith("Hunyuan")
+ and not pipe_cls_name.startswith("Kandinsky")
+ ):
+ # Specific for FluxPipeline, FLUX.1-dev
+ return getattr(pipe, "text_encoder_2"), "text_encoder_2"
+ elif hasattr(pipe, "text_encoder_3"): # HiDream pipeline
+ return getattr(pipe, "text_encoder_3"), "text_encoder_3"
+ elif hasattr(pipe, "text_encoder"): # General case
+ return getattr(pipe, "text_encoder"), "text_encoder"
+ else:
+ return None, None
+
+
+def prepare_extra_parallel_modules(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+ custom_extra_modules: Optional[List[torch.nn.Module]] = None,
+) -> list:
+ if custom_extra_modules is not None:
+ return custom_extra_modules
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please set extra_parallel_modules manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ extra_parallel_modules = []
+
+ if args.parallel_text_encoder:
+ text_encoder, _ = get_text_encoder_from_pipe(pipe)
+ if text_encoder is not None:
+ extra_parallel_modules.append(text_encoder)
+ else:
+ logger.warning(
+ "parallel-text-encoder is set but no text encoder found in the pipeline."
+ )
+
+ if args.parallel_vae:
+ assert not args.vae_tiling, "VAE tiling is not compatible with VAE parallelism."
+ assert not args.vae_slicing, "VAE slicing is not compatible with VAE parallelism."
+ if hasattr(pipe, "vae"):
+ extra_parallel_modules.append(getattr(pipe, "vae"))
+ else:
+ logger.warning("parallel-vae is set but no VAE found in the pipeline.")
+
+ if args.parallel_controlnet:
+ if hasattr(pipe, "controlnet"):
+ extra_parallel_modules.append(getattr(pipe, "controlnet"))
+ else:
+ logger.warning("parallel-controlnet is set but no ControlNet found in the pipeline.")
+
+ return extra_parallel_modules
+
+
+def maybe_compile_text_encoder(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ if args.compile_text_encoder:
+ torch.set_float32_matmul_precision("high")
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please compile text encoder manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ text_encoder, name = get_text_encoder_from_pipe(pipe)
+ if text_encoder is not None and not isinstance(
+ text_encoder,
+ torch._dynamo.OptimizedModule, # already compiled
+ ):
+ # Find module to be compiled, [encoder, model, model.language_model, ...]
+ _module_to_compile = text_encoder
+ if hasattr(_module_to_compile, "model"):
+ if hasattr(_module_to_compile.model, "language_model"):
+ _module_to_compile = _module_to_compile.model.language_model
+ else:
+ _module_to_compile = _module_to_compile.model
+
+ if hasattr(_module_to_compile, "encoder"):
+ _module_to_compile = _module_to_compile.encoder
+
+ _module_to_compile_cls_name = _module_to_compile.__class__.__name__
+ _text_encoder_cls_name = text_encoder.__class__.__name__
+ if isinstance(_module_to_compile, torch.nn.Module):
+ logger.info(
+ f"Compiling text encoder module {name}:{_text_encoder_cls_name}:"
+ f"{_module_to_compile_cls_name} ..."
+ )
+ _module_to_compile = torch.compile(
+ _module_to_compile,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ # Set back the compiled text encoder
+ if hasattr(text_encoder, "model"):
+ if hasattr(text_encoder.model, "language_model"):
+ text_encoder.model.language_model = _module_to_compile
+ else:
+ text_encoder.model = _module_to_compile
+ if hasattr(text_encoder, "encoder"):
+ text_encoder.encoder = _module_to_compile
+
+ setattr(pipe, name, text_encoder)
+ else:
+ logger.warning(
+ f"Cannot compile text encoder module {name}:{_text_encoder_cls_name}:"
+ f"{_module_to_compile_cls_name} Not a torch.nn.Module."
+ )
+ else:
+ logger.warning("compile-text-encoder is set but no text encoder found in the pipeline.")
+ return pipe_or_adapter
+
+
+def maybe_compile_controlnet(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ if args.compile_controlnet:
+ cache_dit.set_compile_configs()
+ torch.set_float32_matmul_precision("high")
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please compile transformer manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ if hasattr(pipe, "controlnet"):
+ controlnet = getattr(pipe, "controlnet", None)
+ if controlnet is not None and not isinstance(
+ controlnet,
+ torch._dynamo.OptimizedModule, # already compiled
+ ):
+ controlnet_cls_name = controlnet.__class__.__name__
+ if isinstance(controlnet, torch.nn.Module):
+ logger.info(f"Compiling controlnet module: {controlnet_cls_name} ...")
+ controlnet = torch.compile(
+ controlnet,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ setattr(pipe, "controlnet", controlnet)
+ else:
+ logger.warning(
+ f"Cannot compile controlnet module: {controlnet_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "controlnet", controlnet)
+ else:
+ logger.warning("compile is set but no controlnet found in the pipeline.")
+
+
+def maybe_compile_vae(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ if args.compile_vae:
+ torch.set_float32_matmul_precision("high")
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please compile VAE manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ if hasattr(pipe, "vae"):
+ vae = getattr(pipe, "vae", None)
+ if vae is not None and not isinstance(
+ vae,
+ torch._dynamo.OptimizedModule, # already compiled
+ ):
+ vae_cls_name = vae.__class__.__name__
+ if hasattr(vae, "encoder"):
+ _encoder_to_compile = vae.encoder
+ if isinstance(_encoder_to_compile, torch.nn.Module):
+ logger.info(f"Compiling VAE encoder module: {vae_cls_name}.encoder ...")
+ vae.encoder = torch.compile(
+ _encoder_to_compile,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ else:
+ logger.warning(
+ f"Cannot compile VAE encoder module: {vae_cls_name}.encoder Not a"
+ " torch.nn.Module."
+ )
+ if hasattr(vae, "decoder"):
+ _decoder_to_compile = vae.decoder
+ if isinstance(_decoder_to_compile, torch.nn.Module):
+ logger.info(f"Compiling VAE decoder module: {vae_cls_name}.decoder ...")
+ vae.decoder = torch.compile(
+ _decoder_to_compile,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ else:
+ logger.warning(
+ f"Cannot compile VAE decoder module: {vae_cls_name}.decoder Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "vae", vae)
+ else:
+ logger.warning(f"Cannot compile VAE module: {vae_cls_name} Not a torch.nn.Module.")
+ else:
+ logger.warning("compile-vae is set but no VAE found in the pipeline.")
+ return pipe_or_adapter
-def cachify(
+def maybe_compile_transformer(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ if args.compile:
+ cache_dit.set_compile_configs()
+ torch.set_float32_matmul_precision("high")
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please compile transformer manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ if hasattr(pipe, "transformer"):
+ transformer = getattr(pipe, "transformer", None)
+ if transformer is not None and not isinstance(
+ transformer,
+ torch._dynamo.OptimizedModule, # already compiled
+ ):
+ transformer_cls_name = transformer.__class__.__name__
+ if isinstance(transformer, torch.nn.Module):
+ logger.info(f"Compiling transformer module: {transformer_cls_name} ...")
+ transformer = torch.compile(
+ transformer,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ setattr(pipe, "transformer", transformer)
+ else:
+ logger.warning(
+ f"Cannot compile transformer module: {transformer_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "transformer", transformer)
+ else:
+ logger.warning("compile is set but no transformer found in the pipeline.")
+
+ if hasattr(pipe, "transformer_2"):
+ transformer_2 = getattr(pipe, "transformer_2", None)
+ if transformer_2 is not None and not isinstance(
+ transformer_2,
+ torch._dynamo.OptimizedModule, # already compiled
+ ):
+ transformer_2_cls_name = transformer_2.__class__.__name__
+ if isinstance(transformer_2, torch.nn.Module):
+ logger.info(f"Compiling transformer_2 module: {transformer_2_cls_name} ...")
+ transformer_2 = torch.compile(
+ transformer_2,
+ mode="max-autotune-no-cudagraphs" if args.max_autotune else "default",
+ )
+ setattr(pipe, "transformer_2", transformer_2)
+ else:
+ logger.warning(
+ f"Cannot compile transformer_2 module: {transformer_2_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "transformer_2", transformer_2)
+
+ return pipe_or_adapter
+
+
+def maybe_quantize_transformer(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ # Quantize transformer by default if quantization is enabled
+ if args.quantize:
+ if args.quantize_type in ("bitsandbytes_4bit", "bnb_4bit"):
+ logger.debug(
+ "bitsandbytes_4bit quantization should be handled by"
+ " PipelineQuantizationConfig in from_pretrained."
+ )
+ return pipe_or_adapter
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please quantize transformer manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ _class_not_supported_per_row = [
+ "QwenImageTransformer2DModel",
+ ]
+
+ def is_per_row_supported(transformer):
+ transformer_cls_name = transformer.__class__.__name__
+ return transformer_cls_name not in _class_not_supported_per_row
+
+ if hasattr(pipe, "transformer"):
+ transformer = getattr(pipe, "transformer", None)
+ if transformer is not None:
+ transformer_cls_name = transformer.__class__.__name__
+ if isinstance(transformer, torch.nn.Module):
+ logger.info(
+ f"Quantizing transformer module: {transformer_cls_name} to"
+ f" {args.quantize_type} ..."
+ )
+ transformer = cache_dit.quantize(
+ transformer,
+ quant_type=args.quantize_type,
+ per_row=is_per_row_supported(transformer),
+ )
+ setattr(pipe, "transformer", transformer)
+ else:
+ logger.warning(
+ f"Cannot quantize transformer module: {transformer_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "transformer", transformer)
+ else:
+ logger.warning("quantize is set but no transformer found in the pipeline.")
+
+ if hasattr(pipe, "transformer_2"):
+ transformer_2 = getattr(pipe, "transformer_2", None)
+ if transformer_2 is not None:
+ transformer_2_cls_name = transformer_2.__class__.__name__
+ if isinstance(transformer_2, torch.nn.Module):
+ logger.info(
+ f"Quantizing transformer_2 module: {transformer_2_cls_name} to"
+ f" {args.quantize_type} ..."
+ )
+ transformer_2 = cache_dit.quantize(
+ transformer_2,
+ quant_type=args.quantize_type,
+ per_row=is_per_row_supported(transformer_2),
+ )
+ setattr(pipe, "transformer_2", transformer_2)
+ else:
+ logger.warning(
+ f"Cannot quantize transformer_2 module: {transformer_2_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "transformer_2", transformer_2)
+
+ return pipe_or_adapter
+
+
+def maybe_quantize_text_encoder(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ # Quantize text encoder by default if quantize_text_encoder is enabled
+ if args.quantize_text_encoder:
+ if args.quantize_text_type in ("bitsandbytes_4bit", "bnb_4bit"):
+ logger.debug(
+ "bitsandbytes_4bit quantization should be handled by"
+ " PipelineQuantizationConfig in from_pretrained."
+ )
+ return pipe_or_adapter
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please quantize text encoder manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ text_encoder, name = get_text_encoder_from_pipe(pipe)
+ if text_encoder is not None:
+ text_encoder_cls_name = text_encoder.__class__.__name__
+ if isinstance(text_encoder, torch.nn.Module):
+ logger.info(
+ f"Quantizing text encoder module: {name}:{text_encoder_cls_name} to"
+ f" {args.quantize_text_type} ..."
+ )
+ text_encoder = cache_dit.quantize(
+ text_encoder,
+ quant_type=args.quantize_text_type,
+ )
+ setattr(pipe, name, text_encoder)
+ else:
+ logger.warning(
+ f"Cannot quantize text encoder module: {name}:{text_encoder_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ else:
+ logger.warning("quantize is set but no text encoder found in the pipeline.")
+ return pipe_or_adapter
+
+
+def maybe_quantize_controlnet(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ # Quantize controlnet by default if quantize_controlnet is enabled
+ if args.quantize_controlnet:
+ if args.quantize_controlnet_type in ("bitsandbytes_4bit", "bnb_4bit"):
+ logger.debug(
+ "bitsandbytes_4bit quantization should be handled by"
+ " PipelineQuantizationConfig in from_pretrained."
+ )
+ return pipe_or_adapter
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please quantize controlnet manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ if hasattr(pipe, "controlnet"):
+ controlnet = getattr(pipe, "controlnet", None)
+ if controlnet is not None:
+ controlnet_cls_name = controlnet.__class__.__name__
+ if isinstance(controlnet, torch.nn.Module):
+ logger.info(
+ f"Quantizing controlnet module: {controlnet_cls_name} to"
+ f" {args.quantize_controlnet_type} ..."
+ )
+ controlnet = cache_dit.quantize(
+ controlnet,
+ quant_type=args.quantize_controlnet_type,
+ )
+ setattr(pipe, "controlnet", controlnet)
+ else:
+ logger.warning(
+ f"Cannot quantize controlnet module: {controlnet_cls_name} Not a"
+ " torch.nn.Module."
+ )
+ setattr(pipe, "controlnet", controlnet)
+ else:
+ logger.warning("quantize_controlnet is set but no controlnet found in the pipeline.")
+ return pipe_or_adapter
+
+
+def pipe_quant_bnb_4bit_config(
+ args,
+ components_to_quantize: Optional[List[str]] = ["text_encoder"],
+) -> Optional[PipelineQuantizationConfig]:
+ if not args.quantize_text_encoder and not args.quantize:
+ return None
+
+ if components_to_quantize:
+ # Remove all components if quantize type is not bitsandbytes_4bit
+ if args.quantize_type != "bitsandbytes_4bit":
+ if "transformer" in components_to_quantize:
+ components_to_quantize.remove("transformer")
+ if "transformer_2" in components_to_quantize:
+ components_to_quantize.remove("transformer_2")
+ if args.quantize_text_type != "bitsandbytes_4bit":
+ if "text_encoder" in components_to_quantize:
+ components_to_quantize.remove("text_encoder")
+ if "text_encoder_2" in components_to_quantize:
+ components_to_quantize.remove("text_encoder_2")
+
+ # Remove text encoder if parallel_text_encoder is enabled
+ if args.parallel_text_encoder:
+ if "text_encoder" in components_to_quantize:
+ components_to_quantize.remove("text_encoder")
+ if "text_encoder_2" in components_to_quantize:
+ components_to_quantize.remove("text_encoder_2")
+
+ if components_to_quantize:
+ quantization_config = (
+ (
+ PipelineQuantizationConfig(
+ quant_backend="bitsandbytes_4bit",
+ quant_kwargs={
+ "load_in_4bit": True,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": torch.bfloat16,
+ },
+ components_to_quantize=components_to_quantize,
+ )
+ )
+ if args.quantize or args.quantize_text_encoder
+ else None
+ )
+ else:
+ quantization_config = None
+
+ return quantization_config
+
+
+def maybe_vae_tiling_or_slicing(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> DiffusionPipeline | BlockAdapter:
+ if args.vae_tiling or args.vae_slicing:
+ assert not args.parallel_vae, "VAE tiling/slicing is not compatible with VAE parallelism."
+
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please enable VAE tiling/slicing manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ if hasattr(pipe, "vae"):
+ vae = getattr(pipe, "vae", None)
+ if vae is not None:
+ vae_cls_name = vae.__class__.__name__
+ if args.vae_tiling:
+ if hasattr(vae, "enable_tiling"):
+ logger.info(f"Enabling VAE tiling for module: {vae_cls_name} ...")
+ vae.enable_tiling()
+ else:
+ logger.warning(
+ f"Cannot enable VAE tiling for module: {vae_cls_name} No enable_tiling"
+ " method."
+ )
+ if args.vae_slicing:
+ if hasattr(vae, "enable_slicing"):
+ logger.info(f"Enabling VAE slicing for module: {vae_cls_name} ...")
+ vae.enable_slicing()
+ else:
+ logger.warning(
+ f"Cannot enable VAE slicing for module: {vae_cls_name} No enable_slicing"
+ " method."
+ )
+ setattr(pipe, "vae", vae)
+ else:
+ logger.warning("vae-tiling is set but no VAE found in the pipeline.")
+ return pipe_or_adapter
+
+
+def maybe_cpu_offload(
+ args,
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+) -> bool:
+ _, device = get_rank_device()
+ if args.cpu_offload or args.sequential_cpu_offload:
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ assert pipe is not None, "Please enable cpu offload manually if pipe is None."
+ else:
+ pipe = pipe_or_adapter
+
+ pipe_cls_name = pipe.__class__.__name__
+ if args.sequential_cpu_offload:
+ logger.info(f"Enabling Sequential CPU offload for the model {pipe_cls_name} ...")
+ pipe.enable_sequential_cpu_offload(device=device)
+ else:
+ logger.info(f"Enabling CPU offload for the model {pipe_cls_name} ...")
+ pipe.enable_model_cpu_offload(device=device)
+
+ return True
+
+ return False
+
+
+def maybe_apply_optimization(
args,
pipe_or_adapter,
**kwargs,
):
- if args.disable_compute_comm_overlap:
- # Enable compute comm overlap default for torch.compile if used
- # cache_dit.set_compile_flags(), users need to disable it explicitly.
- cache_dit.disable_compute_comm_overlap()
+ if args.attn is not None and args.parallel_type is None:
+ # NON-parallelism case: set attention backend directly
+ try:
+ from cache_dit.parallelism.attention import _maybe_register_custom_attn_backends
- if args.cache or args.parallel_type is not None:
- import torch.distributed as dist
+ _maybe_register_custom_attn_backends()
+ except Exception as e:
+ logger.warning(
+ "Failed to register custom attention backends. "
+ f"Proceeding to set attention backend anyway. Error: {e}"
+ )
+
+ def _set_backend(module):
+ if module is None:
+ return
+ if hasattr(module, "set_attention_backend"):
+ module.set_attention_backend(args.attn)
+ logger.info(
+ f"Set attention backend to {args.attn} for module: {module.__class__.__name__}."
+ )
+ else:
+ logger.warning(
+ "--attn was provided but module does not support set_attention_backend: "
+ f"{module.__class__.__name__}."
+ )
- from cache_dit import DBCacheConfig, ParallelismConfig, TaylorSeerCalibratorConfig
+ try:
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ transformer = pipe_or_adapter.transformer
+ if isinstance(transformer, list):
+ for t in transformer:
+ _set_backend(t)
+ else:
+ _set_backend(transformer)
+ else:
+ pipe = pipe_or_adapter
+ if hasattr(pipe, "transformer"):
+ _set_backend(getattr(pipe, "transformer"))
+ else:
+ _set_backend(pipe)
+ except Exception as e:
+ raise RuntimeError(
+ f"Failed to set attention backend to {args.attn}. "
+ "This usually means the backend is unavailable (e.g., FlashAttention-3 not installed) "
+ "or the model/shape/dtype is unsupported. "
+ f"Original error: {e}"
+ ) from e
- cache_config = kwargs.pop("cache_config", None)
- parallelism_config = kwargs.pop("parallelism_config", None)
+ default_num_inference_steps = kwargs.pop("default_num_inference_steps", None)
+ if args.cache or args.parallel_type is not None or args.config_path is not None:
- backend = (
- ParallelismBackend.NATIVE_PYTORCH
- if args.parallel_type in ["tp"]
- else ParallelismBackend.NATIVE_DIFFUSER
- )
+ if args.config_path is None:
+ # Construct acceleration configs from command line args if config path is not provided
+ cache_config = kwargs.pop("cache_config", None)
+ parallelism_config = kwargs.pop("parallelism_config", None)
+
+ backend = (
+ ParallelismBackend.NATIVE_PYTORCH
+ if args.parallel_type in ["tp"]
+ else ParallelismBackend.NATIVE_DIFFUSER
+ )
+
+ extra_parallel_modules = prepare_extra_parallel_modules(
+ args,
+ pipe_or_adapter,
+ custom_extra_modules=kwargs.get("extra_parallel_modules", None),
+ )
- parallel_kwargs = (
- {
- "attention_backend": ("_native_cudnn" if not args.attn else args.attn),
- "experimental_ulysses_anything": args.ulysses_anything,
+ parallel_kwargs = {
+ "attention_backend": ("native" if not args.attn else args.attn),
+ # e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline
+ "extra_parallel_modules": extra_parallel_modules,
}
- if backend == ParallelismBackend.NATIVE_DIFFUSER
- else None
- )
- cache_dit.enable_cache(
- pipe_or_adapter,
- cache_config=(
- DBCacheConfig(
- Fn_compute_blocks=args.Fn,
- Bn_compute_blocks=args.Bn,
- max_warmup_steps=args.max_warmup_steps,
- warmup_interval=args.warmup_interval,
- max_cached_steps=args.max_cached_steps,
- max_continuous_cached_steps=args.max_continuous_cached_steps,
- residual_diff_threshold=args.rdt,
- enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
+ if backend == ParallelismBackend.NATIVE_PYTORCH:
+ if args.attn is None:
+ parallel_kwargs["attention_backend"] = None
+
+ if backend == ParallelismBackend.NATIVE_DIFFUSER:
+ parallel_kwargs.update(
+ {
+ "experimental_ulysses_anything": args.ulysses_anything,
+ "experimental_ulysses_float8": args.ulysses_float8,
+ "experimental_ulysses_async": args.ulysses_async,
+ }
)
- if cache_config is None and args.cache
- else cache_config
- ),
- calibrator_config=(
- TaylorSeerCalibratorConfig(
- taylorseer_order=args.taylorseer_order,
+
+ # Caching and Parallelism
+ if args.steps_mask and args.mask_policy is not None:
+ logger.info(
+ f"Using steps computation mask with policy: {args.mask_policy} for caching."
)
- if args.taylorseer
- else None
- ),
- parallelism_config=(
- ParallelismConfig(
- ulysses_size=(
- dist.get_world_size() if args.parallel_type == "ulysses" else None
- ),
- ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
- tp_size=(dist.get_world_size() if args.parallel_type == "tp" else None),
- backend=backend,
- parallel_kwargs=parallel_kwargs,
+ if default_num_inference_steps is None:
+ assert (
+ args.num_inference_steps is not None
+ ), "num_inference_steps (--steps) must be provided for steps mask."
+ num_inference_steps = args.num_inference_steps
+ else:
+ num_inference_steps = default_num_inference_steps
+ steps_computation_mask = cache_dit.steps_mask(
+ total_steps=num_inference_steps,
+ mask_policy=args.mask_policy,
)
- if parallelism_config is None and args.parallel_type in ["ulysses", "ring", "tp"]
- else parallelism_config
- ),
- )
+ else:
+ steps_computation_mask = None
+
+ cache_dit.enable_cache(
+ pipe_or_adapter,
+ cache_config=(
+ DBCacheConfig(
+ Fn_compute_blocks=args.Fn_compute_blocks,
+ Bn_compute_blocks=args.Bn_compute_blocks,
+ max_warmup_steps=args.max_warmup_steps,
+ warmup_interval=args.warmup_interval,
+ max_cached_steps=args.max_cached_steps,
+ max_continuous_cached_steps=args.max_continuous_cached_steps,
+ residual_diff_threshold=args.residual_diff_threshold,
+ enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
+ steps_computation_mask=steps_computation_mask,
+ )
+ if cache_config is None and args.cache
+ else cache_config
+ ),
+ calibrator_config=(
+ TaylorSeerCalibratorConfig(
+ taylorseer_order=args.taylorseer_order,
+ )
+ if args.taylorseer
+ else None
+ ),
+ params_modifiers=kwargs.get("params_modifiers", None),
+ parallelism_config=(
+ ParallelismConfig(
+ ulysses_size=(
+ dist.get_world_size() if args.parallel_type == "ulysses" else None
+ ),
+ ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
+ tp_size=(dist.get_world_size() if args.parallel_type == "tp" else None),
+ backend=backend,
+ parallel_kwargs=parallel_kwargs,
+ )
+ if parallelism_config is None
+ and args.parallel_type in ["ulysses", "ring", "tp"]
+ else parallelism_config
+ ),
+ )
+ else:
+ # Apply acceleration configs from config path
+ cache_dit.enable_cache(
+ pipe_or_adapter,
+ **cache_dit.load_configs(args.config_path),
+ )
+ logger.info(f"Applied acceleration from {args.config_path}.")
+
+ # Quantization
+ # WARN: Must apply quantization after tensor parallelism is applied.
+ # torchao is compatible with tensor parallelism but requires to be
+ # applied after TP.
+ maybe_quantize_transformer(args, pipe_or_adapter)
+ maybe_quantize_text_encoder(args, pipe_or_adapter)
+ maybe_quantize_controlnet(args, pipe_or_adapter)
+
+ # VAE Tiling or Slicing
+ maybe_vae_tiling_or_slicing(args, pipe_or_adapter)
+
+ # Compilation
+ maybe_compile_transformer(args, pipe_or_adapter)
+ maybe_compile_text_encoder(args, pipe_or_adapter)
+ maybe_compile_controlnet(args, pipe_or_adapter)
+ maybe_compile_vae(args, pipe_or_adapter)
+
+ # CPU Offload
+ _, device = get_rank_device()
+ if not maybe_cpu_offload(args, pipe_or_adapter):
+ # Set device if no cpu offload
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ else:
+ pipe = pipe_or_adapter
+ if pipe is not None and not args.device_map_balance:
+ pipe.to(device)
return pipe_or_adapter
def strify(args, pipe_or_stats):
+ base_str = ""
+ if args.height is not None and args.width is not None:
+ base_str += f"{args.height}x{args.width}_"
quantize_type = args.quantize_type if args.quantize else ""
if quantize_type != "":
quantize_type = f"_{quantize_type}"
- base_str = (
+ base_str += (
f"C{int(args.compile)}_Q{int(args.quantize)}{quantize_type}_"
f"{cache_dit.strify(pipe_or_stats)}"
)
if args.ulysses_anything:
base_str += "_ulysses_anything"
+ if args.ulysses_float8:
+ base_str += "_float8"
+ else:
+ if args.ulysses_float8:
+ base_str += "_ulysses_float8"
+ if args.ulysses_async:
+ base_str += "_ulysses_async"
+ if args.parallel_text_encoder:
+ if "_TEP" not in base_str:
+ base_str += "_TEP" # Text Encoder Parallelism
+ if args.parallel_vae:
+ if "_VAEP" not in base_str:
+ base_str += "_VAEP" # VAE Parallelism
+ if args.parallel_controlnet:
+ if "_CNP" not in base_str:
+ base_str += "_CNP" # ControlNet Parallelism
+ if args.attn is not None:
+ base_str += f"_{args.attn.strip('_')}"
return base_str
+def get_rank_device():
+ available = current_platform.is_accelerator_available()
+ device_type = current_platform.device_type
+ if dist.is_initialized():
+ rank = dist.get_rank()
+ device = torch.device(device_type, rank % current_platform.device_count())
+ return rank, device
+ return 0, torch.device(device_type if available else "cpu")
+
+
def maybe_init_distributed(args=None):
+ from cache_dit.platforms.platform import CpuPlatform
+
+ platform_full_backend = current_platform.full_dist_backend
+ cpu_full_backend = CpuPlatform.full_dist_backend
+ backend = (
+ f"{cpu_full_backend},{platform_full_backend}"
+ if args.ulysses_anything
+ else current_platform.dist_backend
+ )
if args is not None:
if args.parallel_type is not None:
dist.init_process_group(
- backend="cpu:gloo,cuda:nccl" if args.ulysses_anything else "nccl",
+ backend=backend,
+ )
+ rank, device = get_rank_device()
+ current_platform.set_device(device)
+ return rank, device
+ elif args.config_path is not None:
+ # check if distributed is needed from config file
+ has_parallelism_config = cache_dit.load_parallelism_config(
+ args.config_path,
+ check_only=True,
)
- rank = dist.get_rank()
- device = torch.device("cuda", rank % torch.cuda.device_count())
- torch.cuda.set_device(device)
+ if has_parallelism_config:
+ if not dist.is_initialized():
+ dist.init_process_group(
+ backend=backend,
+ )
+ rank, device = get_rank_device()
+ current_platform.set_device(device)
+ return rank, device
+ else:
+ # no distributed needed
+ rank, device = get_rank_device()
+ return rank, device
+ else:
+ # no distributed needed
+ rank, device = get_rank_device()
return rank, device
else:
# always init distributed for other examples
if not dist.is_initialized():
dist.init_process_group(
- backend="nccl",
+ backend=platform_full_backend,
)
- rank = dist.get_rank()
- device = torch.device("cuda", rank % torch.cuda.device_count())
- torch.cuda.set_device(device)
+ rank, device = get_rank_device()
+ current_platform.set_device(device)
return rank, device
- return 0, torch.device("cuda" if torch.cuda.is_available() else "cpu")
def maybe_destroy_distributed():
if dist.is_initialized():
dist.destroy_process_group()
+
+
+def create_profiler_from_args(args, profile_name=None):
+ from cache_dit.profiler import ProfilerContext
+
+ return ProfilerContext(
+ enabled=args.profile,
+ activities=getattr(args, "profile_activities", ["CPU", "GPU"]),
+ output_dir=getattr(args, "profile_dir", None),
+ profile_name=profile_name or getattr(args, "profile_name", None),
+ with_stack=getattr(args, "profile_with_stack", True),
+ record_shapes=getattr(args, "profile_record_shapes", True),
+ )
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 000000000..68439d9fc
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,135 @@
+site_name: Cache-DiT
+site_description: A PyTorch-native and Flexible Inference Engine with Hybrid Cache Acceleration and Parallelism for DiTs
+site_url: https://github.com/vipshop/cache-dit
+site_author: Cache-DiT Team
+
+repo_url: https://github.com/vipshop/cache-dit
+repo_name: vipshop/cache-dit
+edit_uri: edit/main/docs/
+
+theme:
+ name: material
+ language: en
+ logo: https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-logo-v2.png
+ favicon: https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-logo-v2.png
+ palette:
+ # Palette toggle for automatic mode
+ - media: "(prefers-color-scheme)"
+ toggle:
+ icon: material/brightness-auto
+ name: Switch to light mode
+ # Palette toggle for light mode
+ - media: "(prefers-color-scheme: light)"
+ scheme: default
+ primary: white
+ toggle:
+ icon: material/brightness-7
+ name: Switch to dark mode
+ # Palette toggle for dark mode
+ - media: "(prefers-color-scheme: dark)"
+ scheme: slate
+ primary: black
+ toggle:
+ icon: material/brightness-2
+ name: Switch to system preference
+ features:
+ - content.action.edit
+ - content.code.copy
+ - navigation.instant
+ - navigation.instant.progress
+ - navigation.tracking
+ - navigation.tabs
+ - navigation.tabs.sticky
+ - navigation.sections
+ - navigation.indexes
+ - navigation.top
+ - navigation.footer
+ - navigation.sidebar
+ - search.suggest
+ - search.highlight
+ - search.share
+ - content.code.annotate
+ - content.tabs
+ - content.tooltips
+ - toc.follow
+
+nav:
+ - Home: README.md
+ - Documentation:
+ - Overviews: user_guide/OVERVIEWS.md
+ - Installation: user_guide/INSTALL.md
+ - Unified Cache APIs: user_guide/CACHE_API.md
+ - DBCache Design: user_guide/DBCACHE_DESIGN.md
+ - Context Parallelism: user_guide/CONTEXT_PARALLEL.md
+ - Tensor Parallelism: user_guide/TENSOR_PARALLEL.md
+ - Extra Modules Parallelism: user_guide/EXTRA_PARALLEL.md
+ - Low-Bits Quantization: user_guide/QUANTIZATION.md
+ - Attention Backends: user_guide/ATTENTION.md
+ - Use Torch Compile: user_guide/COMPILE.md
+ - Ascend NPU Support: user_guide/ASCEND_NPU.md
+ - Config with YAML: user_guide/LOAD_CONFIGS.md
+ - Serving Deployment: user_guide/SERVING.md
+ - Metrics Tools: user_guide/METRICS.md
+ - Profiler Usage: user_guide/PROFILER.md
+ - API Documentation: user_guide/API_DOCS.md
+ - Supported Matrix:
+ - NVIDIA GPU: supported_matrix/NVIDIA_GPU.md
+ - Ascend NPU: supported_matrix/ASCEND_NPU.md
+ - Benchmark:
+ - Hybrid Cache: benchmark/HYBRID_CACHE.md
+ - NVIDIA GPU: benchmark/NVIDIA_GPU.md
+ - Ascend NPU: benchmark/ASCEND_NPU.md
+ - Developer Guide:
+ - pre-commit: developer_guide/PRE_COMMIT.md
+ - Support New Model: developer_guide/SUPPORT_NEW_MODEL.md
+ - Quick Examples: EXAMPLES.md
+ - Community Integration: COMMUNITY.md
+ - FAQ: FAQ.md
+
+plugins:
+ - search
+
+markdown_extensions:
+ - admonition
+ - attr_list
+ - md_in_html
+ - tables
+ - fenced_code
+ - codehilite:
+ guess_lang: false
+ - toc:
+ permalink: true
+ toc_depth: 3
+ - pymdownx.highlight:
+ anchor_linenums: true
+ line_spans: __span
+ pygments_lang_class: true
+ linenums_style: pymdownx
+ - pymdownx.inlinehilite
+ - pymdownx.snippets
+ - pymdownx.superfences:
+ custom_fences:
+ - name: mermaid
+ class: mermaid
+ format: !!python/name:pymdownx.superfences.fence_code_format
+ - pymdownx.tabbed:
+ alternate_style: true
+ - pymdownx.emoji:
+ emoji_index: !!python/name:material.extensions.emoji.twemoji
+ emoji_generator: !!python/name:material.extensions.emoji.to_svg
+ - pymdownx.tilde
+ - pymdownx.arithmatex:
+ generic: true
+
+extra:
+ social:
+ - icon: fontawesome/brands/github
+ link: https://github.com/vipshop/cache-dit
+ - icon: fontawesome/brands/python
+ link: https://pypi.org/project/cache-dit/
+ version:
+ provider: mike
+
+extra_javascript:
+ - https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js
+ - https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
diff --git a/pyproject.toml b/pyproject.toml
index 78d90eeba..ed58e9e1c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ readme = "README.md"
dependencies = [
"pyyaml",
"torch>=2.7.1",
- "diffusers>=0.35.1",
+ "diffusers>=0.36.0",
"transformers>=4.55.2",
]
@@ -33,6 +33,13 @@ metrics = [
"lpips==0.1.4",
]
+serving = [
+ "fastapi>=0.104.0",
+ "uvicorn>=0.24.0",
+ "pydantic>=2.0.0",
+ "peft",
+]
+
dev = [
"packaging",
"pre-commit",
@@ -49,10 +56,26 @@ dev = [
"scikit-image",
]
+docs = [
+ "mkdocs>=1.5.0",
+ "mkdocs-api-autonav",
+ "mkdocs-material",
+ "mkdocstrings-python",
+ "mkdocs-gen-files",
+ "mkdocs-awesome-nav",
+ "mkdocs-glightbox",
+ "mkdocs-git-revision-date-localized-plugin",
+ "mkdocs-minify-plugin",
+ "regex",
+ "ruff",
+ "pydantic",
+]
+
all = [
"cache-dit[parallelism]",
"cache-dit[quantization]",
"cache-dit[metrics]",
+ "cache-dit[serving]",
]
[project.urls]
diff --git a/setup.cfg b/setup.cfg
index b7019148b..c42243fe4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,6 +10,8 @@ max-line-length = 100
ignore = E731, E203, E402, W503, W504, F821, E501, B, C4, EXE
per-file-ignores =
__init__.py: F401, F403, F405
+ # ignore all errors in the mkdocs config file
+ mkdocs.yml: ALL
exclude = venv
[pydocstyle]
diff --git a/src/cache_dit/__init__.py b/src/cache_dit/__init__.py
index 8371de4cd..bd917b817 100644
--- a/src/cache_dit/__init__.py
+++ b/src/cache_dit/__init__.py
@@ -5,86 +5,43 @@
version_tuple = (0, 0, "unknown version")
-from cache_dit.utils import disable_print
-from cache_dit.logger import init_logger
-from cache_dit.caching import load_options
-from cache_dit.caching import enable_cache
-from cache_dit.caching import steps_mask
-from cache_dit.caching import disable_cache
-from cache_dit.caching import cache_type
-from cache_dit.caching import block_range
-from cache_dit.caching import CacheType
-from cache_dit.caching import BlockAdapter
-from cache_dit.caching import ParamsModifier
-from cache_dit.caching import ForwardPattern
-from cache_dit.caching import PatchFunctor
-from cache_dit.caching import BasicCacheConfig
-from cache_dit.caching import DBCacheConfig
-from cache_dit.caching import DBPruneConfig
-from cache_dit.caching import CalibratorConfig
-from cache_dit.caching import TaylorSeerCalibratorConfig
-from cache_dit.caching import FoCaCalibratorConfig
-from cache_dit.caching import supported_pipelines
-from cache_dit.caching import get_adapter
-from cache_dit.parallelism import ParallelismBackend
-from cache_dit.parallelism import ParallelismConfig
-from cache_dit.compile import set_compile_configs
-from cache_dit.summary import supported_matrix
-from cache_dit.summary import summary
-from cache_dit.summary import strify
-
-try:
- from cache_dit.quantize import quantize
-except ImportError as e: # noqa: F841
- err_msg = str(e)
-
- def quantize(*args, **kwargs):
- raise ImportError(
- "Quantization requires additional dependencies. "
- "Please install cache-dit[quantization] or cache-dit[all] "
- f"to use this feature. Error message: {err_msg}"
- )
-
-
-def enable_compute_comm_overlap():
- try:
- from cache_dit.compile import enable_compile_compute_comm_overlap
-
- enable_compile_compute_comm_overlap()
- except: # noqa: E722
- pass
-
-
-def disable_compute_comm_overlap():
- try:
- from cache_dit.compile import disable_compile_compute_comm_overlap
-
- disable_compile_compute_comm_overlap()
- except: # noqa: E722
- pass
-
-
-try:
- from cache_dit.parallelism import disable_ulysses_anything
- from cache_dit.parallelism import enable_ulysses_anything
-
-except ImportError as e: # noqa: F841
- err_msg = str(e)
-
- def enable_ulysses_anything(*args, **kwargs):
- raise ImportError(
- "Ulysses Anything Attention requires additional dependencies. "
- "Please install cache-dit[parallelism] or cache-dit[all] "
- f"to use this feature. Error message: {err_msg}"
- )
-
- def disable_ulysses_anything(*args, **kwargs):
- raise ImportError(
- "Ulysses Anything Attention requires additional dependencies. "
- "Please install cache-dit[parallelism] or cache-dit[all] "
- f"to use this feature. Error message: {err_msg}"
- )
-
+from .utils import disable_print
+from .logger import init_logger
+from .caching import load_options # deprecated
+from .caching import load_cache_config
+from .caching import load_parallelism_config
+from .caching import load_configs
+from .caching import enable_cache
+from .caching import refresh_context
+from .caching import steps_mask
+from .caching import disable_cache
+from .caching import cache_type
+from .caching import block_range
+from .caching import CacheType
+from .caching import BlockAdapter
+from .caching import ParamsModifier
+from .caching import ForwardPattern
+from .caching import PatchFunctor
+from .caching import BasicCacheConfig
+from .caching import DBCacheConfig
+from .caching import DBPruneConfig
+from .caching import CalibratorConfig
+from .caching import TaylorSeerCalibratorConfig
+from .caching import FoCaCalibratorConfig
+from .caching import supported_pipelines
+from .caching import get_adapter
+from .parallelism import ParallelismBackend
+from .parallelism import ParallelismConfig
+from .compile import set_compile_configs
+from .summary import supported_matrix
+from .summary import summary
+from .summary import strify
+from .profiler import ProfilerContext
+from .profiler import profile_function
+from .profiler import create_profiler_context
+from .profiler import get_profiler_output_dir
+from .profiler import set_profiler_output_dir
+from .quantize import quantize
NONE = CacheType.NONE
DBCache = CacheType.DBCache
diff --git a/src/cache_dit/caching/__init__.py b/src/cache_dit/caching/__init__.py
index 35ebabec3..ac193746b 100644
--- a/src/cache_dit/caching/__init__.py
+++ b/src/cache_dit/caching/__init__.py
@@ -1,37 +1,41 @@
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.cache_types import cache_type
-from cache_dit.caching.cache_types import block_range
+from .cache_types import CacheType
+from .cache_types import cache_type
+from .cache_types import block_range
-from cache_dit.caching.forward_pattern import ForwardPattern
-from cache_dit.caching.params_modifier import ParamsModifier
-from cache_dit.caching.patch_functors import PatchFunctor
+from .forward_pattern import ForwardPattern
+from .params_modifier import ParamsModifier
+from .patch_functors import PatchFunctor
-from cache_dit.caching.block_adapters import BlockAdapter
-from cache_dit.caching.block_adapters import BlockAdapterRegister
-from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+from .block_adapters import BlockAdapter
+from .block_adapters import BlockAdapterRegister
+from .block_adapters import FakeDiffusionPipeline
-from cache_dit.caching.cache_contexts import BasicCacheConfig
-from cache_dit.caching.cache_contexts import DBCacheConfig
-from cache_dit.caching.cache_contexts import CachedContext
-from cache_dit.caching.cache_contexts import CachedContextManager
-from cache_dit.caching.cache_contexts import DBPruneConfig
-from cache_dit.caching.cache_contexts import PrunedContext
-from cache_dit.caching.cache_contexts import PrunedContextManager
-from cache_dit.caching.cache_contexts import ContextManager
-from cache_dit.caching.cache_contexts import CalibratorConfig
-from cache_dit.caching.cache_contexts import TaylorSeerCalibratorConfig
-from cache_dit.caching.cache_contexts import FoCaCalibratorConfig
+from .cache_contexts import BasicCacheConfig
+from .cache_contexts import DBCacheConfig
+from .cache_contexts import CachedContext
+from .cache_contexts import CachedContextManager
+from .cache_contexts import DBPruneConfig
+from .cache_contexts import PrunedContext
+from .cache_contexts import PrunedContextManager
+from .cache_contexts import ContextManager
+from .cache_contexts import CalibratorConfig
+from .cache_contexts import TaylorSeerCalibratorConfig
+from .cache_contexts import FoCaCalibratorConfig
-from cache_dit.caching.cache_blocks import CachedBlocks
-from cache_dit.caching.cache_blocks import PrunedBlocks
-from cache_dit.caching.cache_blocks import UnifiedBlocks
+from .cache_blocks import CachedBlocks
+from .cache_blocks import PrunedBlocks
+from .cache_blocks import UnifiedBlocks
-from cache_dit.caching.cache_adapters import CachedAdapter
+from .cache_adapters import CachedAdapter
-from cache_dit.caching.cache_interface import enable_cache
-from cache_dit.caching.cache_interface import disable_cache
-from cache_dit.caching.cache_interface import supported_pipelines
-from cache_dit.caching.cache_interface import get_adapter
-from cache_dit.caching.cache_interface import steps_mask
+from .cache_interface import enable_cache
+from .cache_interface import refresh_context
+from .cache_interface import disable_cache
+from .cache_interface import supported_pipelines
+from .cache_interface import get_adapter
+from .cache_interface import steps_mask
-from cache_dit.caching.utils import load_options
+from .utils import load_options # deprecated
+from .utils import load_cache_config
+from .utils import load_parallelism_config
+from .utils import load_configs
diff --git a/src/cache_dit/caching/block_adapters/__init__.py b/src/cache_dit/caching/block_adapters/__init__.py
index 8df83cf73..a3cb08d26 100644
--- a/src/cache_dit/caching/block_adapters/__init__.py
+++ b/src/cache_dit/caching/block_adapters/__init__.py
@@ -1,739 +1,70 @@
-import os
-from cache_dit.caching.forward_pattern import ForwardPattern
-from cache_dit.caching.block_adapters.block_adapters import BlockAdapter
-from cache_dit.caching.block_adapters.block_adapters import (
- FakeDiffusionPipeline,
-)
-from cache_dit.caching.block_adapters.block_adapters import ParamsModifier
-from cache_dit.caching.block_adapters.block_registers import (
- BlockAdapterRegister,
-)
+import importlib
+from typing import Callable
+from .block_adapters import BlockAdapter
+from .block_adapters import FakeDiffusionPipeline
+from .block_adapters import ParamsModifier
+from .block_registers import BlockAdapterRegister
+from cache_dit.logger import init_logger
+logger = init_logger(__name__)
-@BlockAdapterRegister.register("Flux")
-def flux_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import FluxTransformer2DModel
- from cache_dit.utils import is_diffusers_at_least_0_3_5
- from cache_dit.caching.patch_functors import FluxPatchFunctor
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- transformer_cls_name: str = pipe.transformer.__class__.__name__
- if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith("Nunchaku"):
- # NOTE(DefTruth): Users should never use this variable directly,
- # it is only for developers to control whether to enable dummy
- # blocks, default to enabled.
- _CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS = (
- os.environ.get("CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS", "1") == "1"
- )
-
- if not _CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS:
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_1,
- ],
- check_forward_pattern=True,
- **kwargs,
- )
- else:
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=(
- pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks
- ),
- blocks_name="transformer_blocks",
- dummy_blocks_names=["single_transformer_blocks"],
- patch_functor=FluxPatchFunctor(),
- forward_pattern=ForwardPattern.Pattern_1,
- **kwargs,
- )
- else:
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Mochi")
-def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import MochiTransformer3DModel
-
- assert isinstance(pipe.transformer, MochiTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("CogVideoX")
-def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import CogVideoXTransformer3DModel
-
- assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Wan")
-def wan_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import (
- WanTransformer3DModel,
- WanVACETransformer3DModel,
- )
- from cache_dit.caching.patch_functors import WanVACEPatchFunctor
-
- assert isinstance(
- pipe.transformer,
- (WanTransformer3DModel, WanVACETransformer3DModel),
- )
- cls_name = pipe.transformer.__class__.__name__
- patch_functor = WanVACEPatchFunctor() if cls_name.startswith("WanVACE") else None
-
- if getattr(pipe, "transformer_2", None):
- assert isinstance(
- pipe.transformer_2,
- (WanTransformer3DModel, WanVACETransformer3DModel),
- )
- # Wan 2.2 MoE
- return BlockAdapter(
- pipe=pipe,
- transformer=[
- pipe.transformer,
- pipe.transformer_2,
- ],
- blocks=[
- pipe.transformer.blocks,
- pipe.transformer_2.blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_2,
- ForwardPattern.Pattern_2,
- ],
- patch_functor=patch_functor,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
- else:
- # Wan 2.1
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_2,
- patch_functor=patch_functor,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("HunyuanVideo")
-def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import HunyuanVideoTransformer3DModel
-
- assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_0,
- ],
- check_forward_pattern=True,
- # The type hint in diffusers is wrong
- check_num_outputs=False,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("QwenImage")
-def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import QwenImageTransformer2DModel
-
- assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-
- pipe_cls_name: str = pipe.__class__.__name__
- if pipe_cls_name.startswith("QwenImageControlNet"):
- from cache_dit.caching.patch_functors import (
- QwenImageControlNetPatchFunctor,
- )
-
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- patch_functor=QwenImageControlNetPatchFunctor(),
- check_forward_pattern=True,
- has_separate_cfg=True,
- )
- else:
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("LTX")
-def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import LTXVideoTransformer3DModel
-
- assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_2,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Allegro")
-def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import AllegroTransformer3DModel
-
- assert isinstance(pipe.transformer, AllegroTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_2,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("CogView3Plus")
-def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import CogView3PlusTransformer2DModel
-
- assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("CogView4")
-def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import CogView4Transformer2DModel
-
- assert isinstance(pipe.transformer, CogView4Transformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Cosmos")
-def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import CosmosTransformer3DModel
-
- assert isinstance(pipe.transformer, CosmosTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_2,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("EasyAnimate")
-def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import EasyAnimateTransformer3DModel
-
- assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("SkyReelsV2")
-def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import SkyReelsV2Transformer3DModel
-
- assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- # NOTE: Use Pattern_3 instead of Pattern_2 because the
- # encoder_hidden_states will never change in the blocks
- # forward loop.
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("StableDiffusion3")
-def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import SD3Transformer2DModel
-
- assert isinstance(pipe.transformer, SD3Transformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_1,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("ConsisID")
-def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import ConsisIDTransformer3DModel
-
- assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_0,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("DiT")
-def dit_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import DiTTransformer2DModel
- from cache_dit.caching.patch_functors import DiTPatchFunctor
-
- assert isinstance(pipe.transformer, DiTTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- patch_functor=DiTPatchFunctor(),
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Amused")
-def amused_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import UVit2DModel
-
- assert isinstance(pipe.transformer, UVit2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_layers,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Bria")
-def bria_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import BriaTransformer2DModel
-
- assert isinstance(pipe.transformer, BriaTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_0,
- ],
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Lumina")
-def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import Lumina2Transformer2DModel
- from diffusers import LuminaNextDiT2DModel
-
- assert isinstance(pipe.transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel))
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.layers,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("OmniGen")
-def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import OmniGenTransformer2DModel
-
- assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.layers,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("PixArt")
-def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import PixArtTransformer2DModel
-
- assert isinstance(pipe.transformer, PixArtTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Sana")
-def sana_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import SanaTransformer2DModel
-
- assert isinstance(pipe.transformer, SanaTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("StableAudio")
-def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import StableAudioDiTModel
-
- assert isinstance(pipe.transformer, StableAudioDiTModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("VisualCloze")
-def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import FluxTransformer2DModel
- from cache_dit.utils import is_diffusers_at_least_0_3_5
-
- assert isinstance(pipe.transformer, FluxTransformer2DModel)
- if is_diffusers_at_least_0_3_5():
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_1,
- ],
- check_forward_pattern=True,
- **kwargs,
- )
- else:
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("AuraFlow")
-def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import AuraFlowTransformer2DModel
-
- assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.single_transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Chroma")
-def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import ChromaTransformer2DModel
- from cache_dit.caching.patch_functors import ChromaPatchFunctor
-
- assert isinstance(pipe.transformer, ChromaTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_1,
- ForwardPattern.Pattern_3,
- ],
- patch_functor=ChromaPatchFunctor(),
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
+def import_error_adapter(
+ *args,
+ **kwargs,
+) -> BlockAdapter:
+ raise ImportError(
+ "This BlockAdapter requires latest diffusers to be installed. "
+ "Please install diffusers from source."
)
-@BlockAdapterRegister.register("ShapE")
-def shape_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import PriorTransformer
-
- assert isinstance(pipe.prior, PriorTransformer)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.prior,
- blocks=pipe.prior.transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("HiDream")
-def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
- # NOTE: Need to patch Transformer forward to fully support
- # double_stream_blocks and single_stream_blocks, namely, need
- # to remove the logics inside the blocks forward loop:
- # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
- # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
- from diffusers import HiDreamImageTransformer2DModel
- from cache_dit.caching.patch_functors import HiDreamPatchFunctor
-
- assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.double_stream_blocks,
- pipe.transformer.single_stream_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_3,
- ],
- patch_functor=HiDreamPatchFunctor(),
- # NOTE: The type hint in diffusers is wrong
- check_forward_pattern=True,
- check_num_outputs=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("HunyuanDiT")
-def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
- from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
-
- assert isinstance(
- pipe.transformer,
- (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
- )
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- patch_functor=HunyuanDiTPatchFunctor(),
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("HunyuanDiTPAG")
-def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
- from diffusers import HunyuanDiT2DModel
- from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
-
- assert isinstance(pipe.transformer, HunyuanDiT2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- patch_functor=HunyuanDiTPatchFunctor(),
- check_forward_pattern=True,
- **kwargs,
- )
-
-
-@BlockAdapterRegister.register("Kandinsky5")
-def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
- try:
- from diffusers import Kandinsky5Transformer3DModel
-
- assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.visual_transformer_blocks,
- forward_pattern=ForwardPattern.Pattern_3, # or Pattern_2
- has_separate_cfg=True,
- check_forward_pattern=False,
- check_num_outputs=False,
- **kwargs,
- )
- except ImportError:
- raise ImportError(
- "Kandinsky5Transformer3DModel is not available in the current diffusers version. "
- "Please upgrade diffusers>=0.36.dev0 to use this adapter."
- )
-
-
-@BlockAdapterRegister.register("PRX")
-def prx_adapter(pipe, **kwargs) -> BlockAdapter:
- try:
- from diffusers import PRXTransformer2DModel
-
- assert isinstance(pipe.transformer, PRXTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_3,
- check_forward_pattern=True,
- check_num_outputs=False,
- **kwargs,
- )
- except ImportError:
- raise ImportError(
- "PRXTransformer2DModel is not available in the current diffusers version. "
- "Please upgrade diffusers>=0.36.dev0 to use this adapter."
- )
-
-
-@BlockAdapterRegister.register("HunyuanImage")
-def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
- try:
- from diffusers import HunyuanImageTransformer2DModel
-
- assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=[
- pipe.transformer.transformer_blocks,
- pipe.transformer.single_transformer_blocks,
- ],
- forward_pattern=[
- ForwardPattern.Pattern_0,
- ForwardPattern.Pattern_0,
- ],
- # set `has_separate_cfg` as True to enable separate cfg caching
- # since in hyimage-2.1 the `guider_state` contains 2 input batches.
- # The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
- has_separate_cfg=True,
- check_forward_pattern=True,
- **kwargs,
- )
- except ImportError:
- raise ImportError(
- "HunyuanImageTransformer2DModel is not available in the current diffusers version. "
- "Please upgrade diffusers>=0.36.dev0 to use this adapter."
- )
-
-
-@BlockAdapterRegister.register("ChronoEdit")
-def chronoedit_adapter(pipe, **kwargs) -> BlockAdapter:
+def _safe_import(module_name: str, func_name: str) -> Callable[..., BlockAdapter]:
try:
- from diffusers import ChronoEditTransformer3DModel
-
- assert isinstance(pipe.transformer, ChronoEditTransformer3DModel)
- # Same as Wan 2.1 adapter
- return BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.blocks,
- forward_pattern=ForwardPattern.Pattern_2,
- check_forward_pattern=True,
- has_separate_cfg=True,
- **kwargs,
- )
- except ImportError:
- raise ImportError(
- "ChronoEditTransformer3DModel is not available in the current diffusers version. "
- "Please upgrade diffusers>=0.36.dev0 to use this adapter."
- )
+ # e.g., module_name = ".adapters", func_name = "flux_adapter"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_func = getattr(module, func_name)
+ return target_func
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {func_name} from {module_name}: {e}")
+ return import_error_adapter
+
+
+flux_adapter = _safe_import(".adapters", "flux_adapter")
+mochi_adapter = _safe_import(".adapters", "mochi_adapter")
+cogvideox_adapter = _safe_import(".adapters", "cogvideox_adapter")
+wan_adapter = _safe_import(".adapters", "wan_adapter")
+hunyuanvideo_adapter = _safe_import(".adapters", "hunyuanvideo_adapter")
+qwenimage_adapter = _safe_import(".adapters", "qwenimage_adapter")
+ltxvideo_adapter = _safe_import(".adapters", "ltxvideo_adapter")
+allegro_adapter = _safe_import(".adapters", "allegro_adapter")
+cogview3plus_adapter = _safe_import(".adapters", "cogview3plus_adapter")
+cogview4_adapter = _safe_import(".adapters", "cogview4_adapter")
+cosmos_adapter = _safe_import(".adapters", "cosmos_adapter")
+easyanimate_adapter = _safe_import(".adapters", "easyanimate_adapter")
+skyreelsv2_adapter = _safe_import(".adapters", "skyreelsv2_adapter")
+sd3_adapter = _safe_import(".adapters", "sd3_adapter")
+consisid_adapter = _safe_import(".adapters", "consisid_adapter")
+dit_adapter = _safe_import(".adapters", "dit_adapter")
+amused_adapter = _safe_import(".adapters", "amused_adapter")
+bria_adapter = _safe_import(".adapters", "bria_adapter")
+lumina2_adapter = _safe_import(".adapters", "lumina2_adapter")
+omnigen_adapter = _safe_import(".adapters", "omnigen_adapter")
+pixart_adapter = _safe_import(".adapters", "pixart_adapter")
+sana_adapter = _safe_import(".adapters", "sana_adapter")
+stabledudio_adapter = _safe_import(".adapters", "stabledudio_adapter")
+visualcloze_adapter = _safe_import(".adapters", "visualcloze_adapter")
+auraflow_adapter = _safe_import(".adapters", "auraflow_adapter")
+chroma_adapter = _safe_import(".adapters", "chroma_adapter")
+shape_adapter = _safe_import(".adapters", "shape_adapter")
+hidream_adapter = _safe_import(".adapters", "hidream_adapter")
+hunyuandit_adapter = _safe_import(".adapters", "hunyuandit_adapter")
+hunyuanditpag_adapter = _safe_import(".adapters", "hunyuanditpag_adapter")
+kandinsky5_adapter = _safe_import(".adapters", "kandinsky5_adapter")
+prx_adapter = _safe_import(".adapters", "prx_adapter")
+hunyuan_image_adapter = _safe_import(".adapters", "hunyuan_image_adapter")
+chronoedit_adapter = _safe_import(".adapters", "chronoedit_adapter")
+zimage_adapter = _safe_import(".adapters", "zimage_adapter")
+ovis_image_adapter = _safe_import(".adapters", "ovis_image_adapter")
+longcat_image_adapter = _safe_import(".adapters", "longcat_image_adapter")
diff --git a/src/cache_dit/caching/block_adapters/adapters.py b/src/cache_dit/caching/block_adapters/adapters.py
new file mode 100644
index 000000000..50b411ce7
--- /dev/null
+++ b/src/cache_dit/caching/block_adapters/adapters.py
@@ -0,0 +1,904 @@
+import torch
+from typing import List, Tuple, Union, Optional
+from ..forward_pattern import ForwardPattern
+from .block_adapters import BlockAdapter
+from .block_registers import BlockAdapterRegister
+from cache_dit.envs import ENV
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def _relaxed_assert(
+ transformer: torch.nn.Module,
+ allow_classes: Optional[
+ Union[
+ torch.nn.Module,
+ List[torch.nn.Module],
+ Tuple[torch.nn.Module],
+ ]
+ ] = None,
+) -> None:
+ if allow_classes is not None and not isinstance(allow_classes, (list, tuple)):
+ allow_classes = (allow_classes,)
+ _imported_module_ = transformer.__module__
+ if _imported_module_.startswith("diffusers"):
+ # Only apply strict check for Diffusers transformers
+ if allow_classes is not None:
+ assert isinstance(transformer, allow_classes), (
+ f"Transformer class {transformer.__class__.__name__} not in "
+ f"allowed classes: {[cls.__name__ for cls in allow_classes]}"
+ )
+ else:
+ logger.warning(
+ "No allowed classes provided for transformer strict type check "
+ "in BlockAdapter. Skipping strict type check."
+ )
+ else:
+ # Otherwise, just log a warning and skip strict type check, e.g:
+ # sglang/multimodal_gen/runtime/models/dits/flux.py#L411
+ logger.warning(
+ f"Transformer class {transformer.__class__.__name__} is from "
+ f"{_imported_module_.split('.')[0]} not diffusers, skipping strict type check "
+ "in BlockAdapter."
+ )
+
+
@BlockAdapterRegister.register("Flux")
def flux_adapter(pipe, **kwargs) -> BlockAdapter:
    # Build a BlockAdapter for Flux-family pipelines. Three cases:
    # 1) diffusers>=0.35, plain Flux, dummy blocks disabled -> two block lists;
    # 2) diffusers>=0.35, plain Flux, dummy blocks enabled  -> concatenated
    #    blocks plus FluxPatchFunctor;
    # 3) Flux2 / Nunchaku variants (or older diffusers)      -> two block lists
    #    with mixed patterns.
    from diffusers import FluxTransformer2DModel
    from cache_dit.utils import is_diffusers_at_least_0_3_5
    from cache_dit.caching.patch_functors import FluxPatchFunctor

    supported_transformers = (FluxTransformer2DModel,)
    try:
        from diffusers import Flux2Transformer2DModel

        supported_transformers += (Flux2Transformer2DModel,)
    except ImportError:
        Flux2Transformer2DModel = None  # requires diffusers>=0.36.dev0

    _relaxed_assert(pipe.transformer, supported_transformers)

    transformer_cls_name: str = pipe.transformer.__class__.__name__
    if (
        is_diffusers_at_least_0_3_5()
        and not transformer_cls_name.startswith("Nunchaku")
        and not transformer_cls_name.startswith("Flux2")
    ):
        # NOTE(DefTruth): Users should never use this variable directly,
        # it is only for developers to control whether to enable dummy
        # blocks, default to enabled.

        if not ENV.CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS:
            # Dummy blocks disabled: cache the two block lists separately,
            # both using Pattern_1.
            return BlockAdapter(
                pipe=pipe,
                transformer=pipe.transformer,
                blocks=[
                    pipe.transformer.transformer_blocks,
                    pipe.transformer.single_transformer_blocks,
                ],
                forward_pattern=[
                    ForwardPattern.Pattern_1,
                    ForwardPattern.Pattern_1,
                ],
                check_forward_pattern=True,
                **kwargs,
            )
        else:
            # Dummy blocks enabled: concatenate both lists under
            # `transformer_blocks` and let FluxPatchFunctor rewrite the
            # forward; `single_transformer_blocks` become dummies.
            return BlockAdapter(
                pipe=pipe,
                transformer=pipe.transformer,
                blocks=(
                    pipe.transformer.transformer_blocks + pipe.transformer.single_transformer_blocks
                ),
                blocks_name="transformer_blocks",
                dummy_blocks_names=["single_transformer_blocks"],
                patch_functor=FluxPatchFunctor(),
                forward_pattern=ForwardPattern.Pattern_1,
                **kwargs,
            )
    else:
        # Case for Flux2Transformer2DModel and NunchakuFluxTransformer2DModel
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=[
                pipe.transformer.transformer_blocks,
                pipe.transformer.single_transformer_blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_1,
                ForwardPattern.Pattern_3,
            ],
            check_forward_pattern=True,
            **kwargs,
        )
+
+
@BlockAdapterRegister.register("Mochi")
def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
    # Mochi exposes a single `transformer_blocks` list following Pattern_0.
    from diffusers import MochiTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, MochiTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("CogVideoX")
def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
    # CogVideoX: one `transformer_blocks` list, Pattern_0 forward IO.
    from diffusers import CogVideoXTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, CogVideoXTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Wan")
def wan_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import (
        WanTransformer3DModel,
        WanVACETransformer3DModel,
    )
    from cache_dit.caching.patch_functors import WanVACEPatchFunctor

    allowed = (WanTransformer3DModel, WanVACETransformer3DModel)
    _relaxed_assert(pipe.transformer, allowed)
    # Only the VACE variant needs a forward patch.
    is_vace = pipe.transformer.__class__.__name__.startswith("WanVACE")
    functor = WanVACEPatchFunctor() if is_vace else None

    second_transformer = getattr(pipe, "transformer_2", None)
    if second_transformer:
        _relaxed_assert(second_transformer, allowed)
        # Wan 2.2 MoE: two expert transformers share one adapter.
        return BlockAdapter(
            pipe=pipe,
            transformer=[
                pipe.transformer,
                second_transformer,
            ],
            blocks=[
                pipe.transformer.blocks,
                second_transformer.blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_2,
                ForwardPattern.Pattern_2,
            ],
            patch_functor=functor,
            check_forward_pattern=True,
            has_separate_cfg=True,
            **kwargs,
        )
    # Wan 2.1 or transformer-only case.
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_2,
        patch_functor=functor,
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("HunyuanVideo")
def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import HunyuanVideoTransformer3DModel

    transformer_cls_name: str = pipe.transformer.__class__.__name__
    supported_transformers = (HunyuanVideoTransformer3DModel,)
    try:
        from diffusers import HunyuanVideo15Transformer3DModel

        supported_transformers += (HunyuanVideo15Transformer3DModel,)
    except ImportError:
        HunyuanVideo15Transformer3DModel = None  # requires diffusers>=0.36.dev0

    _relaxed_assert(pipe.transformer, supported_transformers)

    if transformer_cls_name.startswith("HunyuanVideo15"):
        # HunyuanVideo 1.5, has separate cfg for conditional and unconditional forward
        # Reference:
        # - https://huggingface.co/hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-480p_t2v/blob/main/guider/guider_config.json#L4
        # - https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/hunyuan_video1_5/pipeline_hunyuan_video1_5.py#L753
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=ForwardPattern.Pattern_0,
            check_forward_pattern=True,
            has_separate_cfg=True,
            **kwargs,
        )
    else:
        # HunyuanVideo 1.0: double-stream and single-stream block lists,
        # both following Pattern_0.
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=[
                pipe.transformer.transformer_blocks,
                pipe.transformer.single_transformer_blocks,
            ],
            forward_pattern=[
                ForwardPattern.Pattern_0,
                ForwardPattern.Pattern_0,
            ],
            check_forward_pattern=True,
            # The type hint in diffusers is wrong
            check_num_outputs=False,
            **kwargs,
        )
+
+
@BlockAdapterRegister.register("QwenImage")
def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
    # QwenImageTransformer2DModel only exists in diffusers>=0.35.2; fall back
    # to None so _relaxed_assert degrades to a warning on older versions.
    try:
        from diffusers import QwenImageTransformer2DModel
    except ImportError:
        QwenImageTransformer2DModel = None  # requires diffusers>=0.35.2

    _relaxed_assert(pipe.transformer, QwenImageTransformer2DModel)

    pipe_cls_name: str = pipe.__class__.__name__
    if pipe_cls_name.startswith("QwenImageControlNet"):
        from cache_dit.caching.patch_functors import (
            QwenImageControlNetPatchFunctor,
        )

        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=ForwardPattern.Pattern_1,
            patch_functor=QwenImageControlNetPatchFunctor(),
            check_forward_pattern=True,
            has_separate_cfg=True,
            # Fix: forward user-provided kwargs in the ControlNet branch too,
            # matching every other adapter (previously they were dropped).
            **kwargs,
        )
    else:
        return BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            forward_pattern=ForwardPattern.Pattern_1,
            check_forward_pattern=True,
            has_separate_cfg=True,
            **kwargs,
        )
+
+
@BlockAdapterRegister.register("LTX")
def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    # LTX-1 (LTXVideoTransformer3DModel) and LTX-2 (LTX2VideoTransformer3DModel) share
    # the `transformer_blocks` structure, but differ in block forward IO:
    # - LTX-1 blocks return only `hidden_states` -> Pattern_2
    # - LTX-2 blocks return `(hidden_states, audio_hidden_states)` -> Pattern_0
    from diffusers import LTXVideoTransformer3DModel

    cls_name: str = pipe.transformer.__class__.__name__
    is_ltx2: bool = cls_name.startswith("LTX2")
    forward_pattern = ForwardPattern.Pattern_0 if is_ltx2 else ForwardPattern.Pattern_2

    # LTX-2 support needs a newer diffusers and a patch functor; degrade
    # gracefully when either import is unavailable.
    try:
        from diffusers import LTX2VideoTransformer3DModel
        from cache_dit.caching.patch_functors import LTX2PatchFunctor

        patch_functor = LTX2PatchFunctor() if is_ltx2 else None
    except Exception:
        LTX2VideoTransformer3DModel = None  # requires newer diffusers
        patch_functor = None

    supported_transformers = (LTXVideoTransformer3DModel,)
    if LTX2VideoTransformer3DModel is not None:
        supported_transformers = supported_transformers + (LTX2VideoTransformer3DModel,)

    _relaxed_assert(pipe.transformer, supported_transformers)

    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        forward_pattern=forward_pattern,
        patch_functor=patch_functor,
        # Tips: Treat the audio_hidden_states in LTX-2 as encoder_hidden_states in Pattern_0
        # while using cache. These values will not affect correctness since audio_hidden_states
        # will be cached and updated normally.
        check_forward_pattern=False,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Allegro")
def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
    # Allegro: one `transformer_blocks` list following Pattern_2.
    from diffusers import AllegroTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, AllegroTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_2,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("CogView3Plus")
def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
    # CogView3Plus: one `transformer_blocks` list, Pattern_0 forward IO.
    from diffusers import CogView3PlusTransformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, CogView3PlusTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("CogView4")
def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
    # CogView4: Pattern_0 blocks with separate cfg caching enabled.
    from diffusers import CogView4Transformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, CogView4Transformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Cosmos")
def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
    # Cosmos: Pattern_2 blocks with separate cfg caching enabled.
    from diffusers import CosmosTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, CosmosTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_2,
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("EasyAnimate")
def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
    # EasyAnimate: one `transformer_blocks` list, Pattern_0 forward IO.
    from diffusers import EasyAnimateTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, EasyAnimateTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("SkyReelsV2")
def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import SkyReelsV2Transformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, SkyReelsV2Transformer3DModel)
    # Pattern_3 (not Pattern_2) because encoder_hidden_states never changes
    # across the blocks forward loop.
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("StableDiffusion3")
def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
    # SD3: one `transformer_blocks` list, Pattern_1 forward IO.
    from diffusers import SD3Transformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, SD3Transformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_1,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("ConsisID")
def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
    # ConsisID: one `transformer_blocks` list, Pattern_0 forward IO.
    from diffusers import ConsisIDTransformer3DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, ConsisIDTransformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_0,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("DiT")
def dit_adapter(pipe, **kwargs) -> BlockAdapter:
    # DiT needs DiTPatchFunctor applied to its forward before caching.
    from diffusers import DiTTransformer2DModel
    from cache_dit.caching.patch_functors import DiTPatchFunctor

    transformer = pipe.transformer
    _relaxed_assert(transformer, DiTTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        patch_functor=DiTPatchFunctor(),
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Amused")
def amused_adapter(pipe, **kwargs) -> BlockAdapter:
    # Amused's UVit2D model names its block list `transformer_layers`.
    from diffusers import UVit2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, UVit2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_layers,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Bria")
def bria_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import BriaTransformer2DModel
    except ImportError:
        BriaTransformer2DModel = None  # requires diffusers>=0.36.dev0

    transformer = pipe.transformer
    _relaxed_assert(transformer, BriaTransformer2DModel)
    # Double-stream and single-stream block lists, both Pattern_0.
    block_lists = [
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=block_lists,
        forward_pattern=[ForwardPattern.Pattern_0] * len(block_lists),
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Lumina")
def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
    # Covers both Lumina2 and LuminaNext; block list is named `layers`.
    from diffusers import Lumina2Transformer2DModel
    from diffusers import LuminaNextDiT2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, (Lumina2Transformer2DModel, LuminaNextDiT2DModel))
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("OmniGen")
def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
    # OmniGen: block list is named `layers`, Pattern_3 forward IO.
    from diffusers import OmniGenTransformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, OmniGenTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("PixArt")
def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
    # PixArt: one `transformer_blocks` list, Pattern_3 forward IO.
    from diffusers import PixArtTransformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, PixArtTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Sana")
def sana_adapter(pipe, **kwargs) -> BlockAdapter:
    # Sana: one `transformer_blocks` list, Pattern_3 forward IO.
    from diffusers import SanaTransformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, SanaTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("StableAudio")
def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
    # NOTE(review): "stabledudio" looks like a typo for "stableaudio", but the
    # name is mirrored by the package __init__'s lazy import — rename both
    # together or not at all.
    from diffusers import StableAudioDiTModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, StableAudioDiTModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("VisualCloze")
def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import FluxTransformer2DModel
    from cache_dit.utils import is_diffusers_at_least_0_3_5

    transformer = pipe.transformer
    _relaxed_assert(transformer, FluxTransformer2DModel)
    # The two branches only differed in the pattern assigned to the
    # single-stream blocks: Pattern_1 on diffusers>=0.35, Pattern_3 before.
    single_blocks_pattern = (
        ForwardPattern.Pattern_1
        if is_diffusers_at_least_0_3_5()
        else ForwardPattern.Pattern_3
    )
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=[
            transformer.transformer_blocks,
            transformer.single_transformer_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_1,
            single_blocks_pattern,
        ],
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("AuraFlow")
def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
    # AuraFlow: only the `single_transformer_blocks` list is cached.
    from diffusers import AuraFlowTransformer2DModel

    transformer = pipe.transformer
    _relaxed_assert(transformer, AuraFlowTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.single_transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Chroma")
def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import ChromaTransformer2DModel
    from cache_dit.caching.patch_functors import ChromaPatchFunctor

    transformer = pipe.transformer
    _relaxed_assert(transformer, ChromaTransformer2DModel)
    # Double-stream blocks use Pattern_1; single-stream blocks use Pattern_3.
    block_lists = [
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
    ]
    patterns = [
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_3,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=block_lists,
        forward_pattern=patterns,
        patch_functor=ChromaPatchFunctor(),
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("ShapE")
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
    # Shap-E uses `pipe.prior` (a PriorTransformer) instead of a transformer.
    from diffusers import PriorTransformer

    prior = pipe.prior
    _relaxed_assert(prior, PriorTransformer)
    return BlockAdapter(
        pipe=pipe,
        transformer=prior,
        blocks=prior.transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("HiDream")
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
    # NOTE: Need to patch Transformer forward to fully support
    # double_stream_blocks and single_stream_blocks, namely, need
    # to remove the logics inside the blocks forward loop:
    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
    from diffusers import HiDreamImageTransformer2DModel
    from cache_dit.caching.patch_functors import HiDreamPatchFunctor

    _relaxed_assert(pipe.transformer, HiDreamImageTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=[
            pipe.transformer.double_stream_blocks,
            pipe.transformer.single_stream_blocks,
        ],
        forward_pattern=[
            ForwardPattern.Pattern_0,
            ForwardPattern.Pattern_3,
        ],
        patch_functor=HiDreamPatchFunctor(),
        # NOTE: The type hint in diffusers is wrong
        # NOTE(review): despite the comment above, check_num_outputs stays True
        # here while HunyuanVideo sets it False for the same reason — confirm
        # this is intentional (the patch functor may normalize the outputs).
        check_forward_pattern=True,
        check_num_outputs=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("HunyuanDiT")
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor

    transformer = pipe.transformer
    _relaxed_assert(
        transformer,
        (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
    )
    # HunyuanDiT needs its forward patched before the blocks can be cached.
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        patch_functor=HunyuanDiTPatchFunctor(),
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("HunyuanDiTPAG")
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
    # PAG variant of HunyuanDiT; same patching and pattern as the base model.
    from diffusers import HunyuanDiT2DModel
    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor

    transformer = pipe.transformer
    _relaxed_assert(transformer, HunyuanDiT2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        patch_functor=HunyuanDiTPatchFunctor(),
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("Kandinsky5")
def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
    # The class only exists in diffusers>=0.36.dev; fall back to None so the
    # type check degrades to a warning on older versions.
    try:
        from diffusers import Kandinsky5Transformer3DModel
    except ImportError:
        Kandinsky5Transformer3DModel = None  # requires diffusers>=0.36.dev

    transformer = pipe.transformer
    _relaxed_assert(transformer, Kandinsky5Transformer3DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.visual_transformer_blocks,
        forward_pattern=ForwardPattern.Pattern_3,  # or Pattern_2
        has_separate_cfg=True,
        check_forward_pattern=False,
        check_num_outputs=False,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("PRX")
def prx_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import PRXTransformer2DModel
    except ImportError:
        PRXTransformer2DModel = None  # requires diffusers>=0.36.dev0

    transformer = pipe.transformer
    _relaxed_assert(transformer, PRXTransformer2DModel)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_3,
        check_forward_pattern=True,
        check_num_outputs=False,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("HunyuanImage")
def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import HunyuanImageTransformer2DModel
    except ImportError:
        HunyuanImageTransformer2DModel = None  # requires diffusers>=0.36

    transformer = pipe.transformer
    _relaxed_assert(transformer, HunyuanImageTransformer2DModel)
    # Double-stream and single-stream block lists, both Pattern_0.
    block_lists = [
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=block_lists,
        forward_pattern=[ForwardPattern.Pattern_0] * len(block_lists),
        # set `has_separate_cfg` as True to enable separate cfg caching
        # since in hyimage-2.1 the `guider_state` contains 2 input batches.
        # The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
        has_separate_cfg=True,
        check_forward_pattern=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("ChronoEdit")
def chronoedit_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import ChronoEditTransformer3DModel
    except ImportError:
        ChronoEditTransformer3DModel = None  # requires diffusers>=0.36.dev0

    transformer = pipe.transformer
    _relaxed_assert(transformer, ChronoEditTransformer3DModel)
    # Same layout as the Wan 2.1 adapter: `blocks` list, Pattern_2,
    # separate cfg caching.
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.blocks,
        forward_pattern=ForwardPattern.Pattern_2,
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("ZImage")
def zimage_adapter(pipe, **kwargs) -> BlockAdapter:
    from cache_dit.caching.patch_functors import ZImageControlNetPatchFunctor

    try:
        from diffusers import ZImageTransformer2DModel
    except ImportError:
        ZImageTransformer2DModel = None  # requires diffusers>=0.36.dev0

    transformer = pipe.transformer
    _relaxed_assert(transformer, ZImageTransformer2DModel)
    # Apply the ControlNet patch only when the pipeline carries a controlnet.
    controlnet = getattr(pipe, "controlnet", None)
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=transformer.layers,
        forward_pattern=ForwardPattern.Pattern_3,
        patch_functor=ZImageControlNetPatchFunctor() if controlnet is not None else None,
        # ZImage DON'T have 'hidden_states' (use 'x') in its block
        # forward signature. So we disable the forward pattern check here.
        check_forward_pattern=False,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("OvisImage")
def ovis_image_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import OvisImageTransformer2DModel
    except ImportError:
        OvisImageTransformer2DModel = None  # requires diffusers>=0.36.dev

    transformer = pipe.transformer
    _relaxed_assert(transformer, OvisImageTransformer2DModel)
    # Double-stream and single-stream block lists, both Pattern_1.
    block_lists = [
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=block_lists,
        forward_pattern=[ForwardPattern.Pattern_1] * len(block_lists),
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
+
+
@BlockAdapterRegister.register("LongCatImage")
def longcat_image_adapter(pipe, **kwargs) -> BlockAdapter:
    try:
        from diffusers import LongCatImageTransformer2DModel
    except ImportError:
        LongCatImageTransformer2DModel = None  # requires diffusers>=0.36.dev

    transformer = pipe.transformer
    _relaxed_assert(transformer, LongCatImageTransformer2DModel)
    # Double-stream and single-stream block lists, both Pattern_1.
    block_lists = [
        transformer.transformer_blocks,
        transformer.single_transformer_blocks,
    ]
    return BlockAdapter(
        pipe=pipe,
        transformer=transformer,
        blocks=block_lists,
        forward_pattern=[ForwardPattern.Pattern_1] * len(block_lists),
        check_forward_pattern=True,
        has_separate_cfg=True,
        **kwargs,
    )
diff --git a/src/cache_dit/caching/block_adapters/block_adapters.py b/src/cache_dit/caching/block_adapters/block_adapters.py
index 97aa7cdac..fef187a13 100644
--- a/src/cache_dit/caching/block_adapters/block_adapters.py
+++ b/src/cache_dit/caching/block_adapters/block_adapters.py
@@ -7,9 +7,9 @@
from typing import Any, Tuple, List, Optional, Union
from diffusers import DiffusionPipeline, ModelMixin
-from cache_dit.caching.patch_functors import PatchFunctor
-from cache_dit.caching.forward_pattern import ForwardPattern
-from cache_dit.caching.params_modifier import ParamsModifier
+from ..patch_functors import PatchFunctor
+from ..forward_pattern import ForwardPattern
+from ..params_modifier import ParamsModifier
from cache_dit.logger import init_logger
@@ -223,9 +223,19 @@ def maybe_patchify(self, *args, **kwargs):
# Process some specificial cases, specific for transformers
# that has different forward patterns between single_transformer_blocks
# and transformer_blocks , such as Flux (diffusers < 0.35.0).
+
if self.patch_functor is not None:
if self.transformer is not None:
- self.patch_functor.apply(self.transformer, *args, **kwargs)
+ if self.nested_depth(self.transformer) == 0:
+ self.patch_functor.apply(self.transformer, *args, **kwargs)
+ elif self.nested_depth(self.transformer) == 1:
+ for transformer in self.transformer:
+ self.patch_functor.apply(transformer, *args, **kwargs)
+ else:
+ raise ValueError(
+ "transformer nested depth can't more than 1, "
+ f"current is: {self.nested_depth(self.transformer)}"
+ )
else:
assert hasattr(self.pipe, "transformer"), (
"pipe.transformer can not be None when patch_functor "
diff --git a/src/cache_dit/caching/block_adapters/block_registers.py b/src/cache_dit/caching/block_adapters/block_registers.py
index 15ce1a8df..f5ee44a71 100644
--- a/src/cache_dit/caching/block_adapters/block_registers.py
+++ b/src/cache_dit/caching/block_adapters/block_registers.py
@@ -2,7 +2,7 @@
from typing import Any, Tuple, List, Dict, Callable, Union
from diffusers import DiffusionPipeline
-from cache_dit.caching.block_adapters.block_adapters import (
+from .block_adapters import (
BlockAdapter,
FakeDiffusionPipeline,
)
@@ -24,6 +24,8 @@ class BlockAdapterRegister:
"Lumina2",
"Kandinsky5",
"ChronoEdit",
+ "HunyuanVideo15",
+ "OvisImage",
]
@classmethod
@@ -66,25 +68,25 @@ def get_adapter(
@classmethod
def has_separate_cfg(
cls,
- pipe_or_adapter: Union[
+ pipe_or_adapter_or_module: Union[
DiffusionPipeline,
FakeDiffusionPipeline,
BlockAdapter,
- Any,
+ torch.nn.Module, # e.g., transformer-only case
],
) -> bool:
- # Prefer custom setting from block adapter.
- if isinstance(pipe_or_adapter, BlockAdapter):
- return pipe_or_adapter.has_separate_cfg
+ # 0. Prefer custom setting from block adapter.
+ if isinstance(pipe_or_adapter_or_module, BlockAdapter):
+ return pipe_or_adapter_or_module.has_separate_cfg
has_separate_cfg = False
- if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
+ if isinstance(pipe_or_adapter_or_module, FakeDiffusionPipeline):
return False
- if isinstance(pipe_or_adapter, DiffusionPipeline):
+ if isinstance(pipe_or_adapter_or_module, (DiffusionPipeline, torch.nn.Module)):
adapter = cls.get_adapter(
- pipe_or_adapter,
+ pipe_or_adapter_or_module,
skip_post_init=True, # check cfg setting only
)
if adapter is not None:
@@ -93,7 +95,7 @@ def has_separate_cfg(
if has_separate_cfg:
return True
- pipe_cls_name = pipe_or_adapter.__class__.__name__
+ pipe_cls_name = pipe_or_adapter_or_module.__class__.__name__
for name in cls._predefined_adapters_has_separate_cfg:
if pipe_cls_name.startswith(name):
return True
diff --git a/src/cache_dit/caching/cache_adapters/cache_adapter.py b/src/cache_dit/caching/cache_adapters/cache_adapter.py
index 784565f38..d3ea7e43b 100644
--- a/src/cache_dit/caching/cache_adapters/cache_adapter.py
+++ b/src/cache_dit/caching/cache_adapters/cache_adapter.py
@@ -7,17 +7,25 @@
from diffusers import DiffusionPipeline, ModelMixin
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.block_adapters import BlockAdapter
-from cache_dit.caching.block_adapters import FakeDiffusionPipeline
-from cache_dit.caching.block_adapters import ParamsModifier
-from cache_dit.caching.block_adapters import BlockAdapterRegister
-from cache_dit.caching.cache_contexts import ContextManager
-from cache_dit.caching.cache_contexts import BasicCacheConfig
-from cache_dit.caching.cache_contexts import CalibratorConfig
-from cache_dit.caching.cache_blocks import UnifiedBlocks
+from ..cache_types import CacheType
+from ..block_adapters import BlockAdapter
+from ..block_adapters import FakeDiffusionPipeline
+from ..block_adapters import ParamsModifier
+from ..block_adapters import BlockAdapterRegister
+from ..cache_contexts import ContextManager
+from ..cache_contexts import BasicCacheConfig
+from ..cache_contexts import CalibratorConfig
+from ..cache_blocks import UnifiedBlocks
from cache_dit.logger import init_logger
+try:
+ from accelerate import hooks
+
+ _accelerate_is_availble = True
+except ImportError:
+ _accelerate_is_availble = False
+
+
logger = init_logger(__name__)
@@ -173,7 +181,8 @@ def create_context(
BlockAdapter.assert_normalized(block_adapter)
if BlockAdapter.is_cached(block_adapter.pipe):
- return block_adapter.pipe
+ logger.warning("Pipeline has been already cached, skip creating cache context again.")
+ return None, block_adapter.pipe
# Check context_kwargs
context_kwargs = cls.check_context_kwargs(block_adapter, **context_kwargs)
@@ -266,45 +275,71 @@ def modify_context_params(
for i in range(
min(len(contexts_kwargs), len(flatten_modifiers)),
):
- if "cache_config" in flatten_modifiers[i]._context_kwargs:
- modifier_cache_config = flatten_modifiers[i]._context_kwargs.get(
- "cache_config", None
- )
- modifier_calibrator_config = flatten_modifiers[i]._context_kwargs.get(
- "calibrator_config", None
+ contexts_kwargs[i] = cls._modify_context_params(
+ flatten_modifiers[i]._context_kwargs,
+ contexts_kwargs[i],
+ )
+ cls._config_messages(**contexts_kwargs[i])
+
+ return flatten_contexts, contexts_kwargs
+
+ @classmethod
+ def _modify_context_params(
+ cls,
+ new_context_kwargs: Dict[str, Any],
+ old_context_kwargs: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ modified_context_kwargs = copy.deepcopy(old_context_kwargs)
+ if "cache_config" in new_context_kwargs:
+ new_cache_config = new_context_kwargs.get("cache_config", None)
+ new_calibrator_config = new_context_kwargs.get("calibrator_config", None)
+ # Modify cache_config
+ if new_cache_config is not None:
+ assert isinstance(new_cache_config, BasicCacheConfig), (
+ f"cache_config must be BasicCacheConfig, but got " f"{type(new_cache_config)}."
)
- if modifier_cache_config is not None:
- assert isinstance(modifier_cache_config, BasicCacheConfig), (
+ if modified_context_kwargs.get("cache_config", None) is None:
+ modified_context_kwargs["cache_config"] = new_cache_config
+ else:
+ assert isinstance(modified_context_kwargs["cache_config"], BasicCacheConfig), (
f"cache_config must be BasicCacheConfig, but got "
- f"{type(modifier_cache_config)}."
+ f"{type(modified_context_kwargs['cache_config'])}."
)
- contexts_kwargs[i]["cache_config"].update(**modifier_cache_config.as_dict())
- if modifier_calibrator_config is not None:
- assert isinstance(modifier_calibrator_config, CalibratorConfig), (
+ modified_context_kwargs["cache_config"].update(**new_cache_config.as_dict())
+ # Modify calibrator_config
+ if new_calibrator_config is not None:
+ assert isinstance(new_calibrator_config, CalibratorConfig), (
+ f"calibrator_config must be CalibratorConfig, but got "
+ f"{type(new_calibrator_config)}."
+ )
+ if modified_context_kwargs.get("calibrator_config", None) is None:
+ modified_context_kwargs["calibrator_config"] = new_calibrator_config
+ else:
+ assert isinstance(
+ modified_context_kwargs["calibrator_config"], CalibratorConfig
+ ), (
f"calibrator_config must be CalibratorConfig, but got "
- f"{type(modifier_calibrator_config)}."
+ f"{type(modified_context_kwargs['calibrator_config'])}."
)
- if contexts_kwargs[i].get("calibrator_config", None) is None:
- contexts_kwargs[i]["calibrator_config"] = modifier_calibrator_config
- else:
- contexts_kwargs[i]["calibrator_config"].update(
- **modifier_calibrator_config.as_dict()
- )
- cls._config_messages(**contexts_kwargs[i])
-
- return flatten_contexts, contexts_kwargs
+ modified_context_kwargs["calibrator_config"].update(
+ **new_calibrator_config.as_dict()
+ )
+ return modified_context_kwargs
@classmethod
- def _config_messages(cls, **contexts_kwargs):
+ def _config_messages(cls, logging: bool = True, **contexts_kwargs):
cache_config: BasicCacheConfig = contexts_kwargs.get("cache_config", None)
calibrator_config: CalibratorConfig = contexts_kwargs.get("calibrator_config", None)
+ message = ""
if cache_config is not None:
message = f"Collected Context Config: {cache_config.strify()}"
if calibrator_config is not None:
message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
else:
message += ", Calibrator Config: None"
+ if logging:
logger.info(message)
+ return message
@classmethod
def mock_blocks(
@@ -362,19 +397,19 @@ def mock_transformer(
assert isinstance(dummy_blocks_names, list)
- from accelerate import hooks
-
- _hf_hook: Optional[hooks.ModelHook] = None
-
- if getattr(transformer, "_hf_hook", None) is not None:
- _hf_hook = transformer._hf_hook # hooks from accelerate.hooks
- if hasattr(transformer, "_old_forward"):
- logger.warning(
- "_hf_hook is not None, so, we have to re-direct transformer's "
- f"original_forward({id(original_forward)}) to transformer's "
- f"_old_forward({id(transformer._old_forward)})"
- )
- original_forward = transformer._old_forward
+ if _accelerate_is_availble:
+ _hf_hook: Optional[hooks.ModelHook] = None
+ if getattr(transformer, "_hf_hook", None) is not None:
+ _hf_hook = transformer._hf_hook # hooks from accelerate.hooks
+ if hasattr(transformer, "_old_forward"):
+ logger.warning(
+ "_hf_hook is not None, so, we have to re-direct transformer's "
+ f"original_forward({id(original_forward)}) to transformer's "
+ f"_old_forward({id(transformer._old_forward)})"
+ )
+ original_forward = transformer._old_forward
+ else:
+ _hf_hook = None
# TODO: remove group offload hooks the re-apply after cache applied.
# hooks = _diffusers_hook.hooks.copy(); _diffusers_hook.hooks.clear()
@@ -383,6 +418,9 @@ def mock_transformer(
# from diffusers.hooks.group_offloading import apply_group_offloading
context_manager: ContextManager = block_adapter.pipe._context_manager
assert isinstance(context_manager, ContextManager._supported_managers)
+ # NOTE: Also assign context manager to transformer for transformer-only case
+ transformer._context_manager = context_manager # instance level
+ transformer._context_names = unique_blocks_name # instance level
def new_forward(self, *args, **kwargs):
with ExitStack() as stack:
@@ -500,7 +538,7 @@ def apply_stats_hooks(
cls,
block_adapter: BlockAdapter,
):
- from cache_dit.caching.cache_blocks import (
+ from ..cache_blocks import (
apply_stats,
)
@@ -528,6 +566,7 @@ def maybe_release_hooks(
pipe_or_adapter: Union[
DiffusionPipeline,
BlockAdapter,
+ torch.nn.Module, # Transformer-only
],
):
# release model hooks
@@ -541,6 +580,16 @@ def _release_transformer_hooks(transformer):
del transformer._original_forward
if hasattr(transformer, "_is_cached"):
del transformer._is_cached
+ if hasattr(transformer, "_context_manager"):
+ context_manager = transformer._context_manager
+ if isinstance(context_manager, ContextManager._supported_managers):
+ context_manager.clear_contexts()
+ try:
+ del transformer._context_manager
+ except Exception:
+ pass
+ if hasattr(transformer, "_context_names"):
+ del transformer._context_names
def _release_pipeline_hooks(pipe):
if hasattr(pipe, "_original_call"):
@@ -591,14 +640,14 @@ def _release_pipeline_params(pipe):
)
# release stats hooks
- from cache_dit.caching.cache_blocks import (
+ from ..cache_blocks import (
remove_stats,
)
cls.release_hooks(pipe_or_adapter, remove_stats, remove_stats, remove_stats)
# maybe release parallelism stats
- from cache_dit.parallelism.parallel_interface import (
+ from cache_dit.parallelism import (
remove_parallelism_stats,
)
@@ -616,22 +665,90 @@ def release_hooks(
DiffusionPipeline,
BlockAdapter,
],
- _release_blocks: Callable,
- _release_transformer: Callable,
- _release_pipeline: Callable,
+ _release_blocks: Optional[Callable] = None,
+ _release_transformer: Optional[Callable] = None,
+ _release_pipeline: Optional[Callable] = None,
):
if isinstance(pipe_or_adapter, DiffusionPipeline):
pipe = pipe_or_adapter
- _release_pipeline(pipe)
+ if _release_pipeline is not None:
+ _release_pipeline(pipe)
if hasattr(pipe, "transformer"):
- _release_transformer(pipe.transformer)
+ if _release_transformer is not None:
+ _release_transformer(pipe.transformer)
if hasattr(pipe, "transformer_2"): # Wan 2.2
- _release_transformer(pipe.transformer_2)
+ if _release_transformer is not None:
+ _release_transformer(pipe.transformer_2)
elif isinstance(pipe_or_adapter, BlockAdapter):
adapter = pipe_or_adapter
BlockAdapter.assert_normalized(adapter)
- _release_pipeline(adapter.pipe)
+ if _release_pipeline is not None:
+ _release_pipeline(adapter.pipe)
for transformer in BlockAdapter.flatten(adapter.transformer):
- _release_transformer(transformer)
+ if _release_transformer is not None:
+ _release_transformer(transformer)
for blocks in BlockAdapter.flatten(adapter.blocks):
- _release_blocks(blocks)
+ if _release_blocks is not None:
+ _release_blocks(blocks)
+ elif isinstance(pipe_or_adapter, torch.nn.Module):
+ transformer = pipe_or_adapter
+ if _release_transformer is not None:
+ _release_transformer(transformer)
+ for blocks in BlockAdapter.find_blocks(transformer):
+ if _release_blocks is not None:
+ _release_blocks(blocks)
+
+ @classmethod
+ def maybe_refresh_context(
+ cls,
+ transformer: torch.nn.Module,
+ **force_refresh_kwargs,
+ ):
+ verbose = force_refresh_kwargs.pop("verbose", False)
+ # Get context manager from transformer
+ if not hasattr(transformer, "_context_manager"):
+ logger.warning(
+ "Transformer has no attribute '_context_manager', skip refreshing cache context."
+ )
+ return
+ context_manager: ContextManager = transformer._context_manager
+ assert isinstance(context_manager, ContextManager._supported_managers)
+ if not context_manager.persistent_context:
+ logger.warning(
+ "Transformer's context manager is not persistent, skip refreshing cache context."
+ )
+ return
+ context_names: List[str] = getattr(transformer, "_context_names", [])
+ if not context_names:
+ logger.warning(
+ "Transformer has no attribute '_context_names' or it's empty, "
+ "skip refreshing cache context."
+ )
+ return
+
+ for context_name in context_names:
+ current_context = context_manager.get_context(context_name)
+ old_init_kwargs = getattr(current_context, "_init_kwargs", {}) # type: dict
+ new_init_kwargs = copy.deepcopy(old_init_kwargs)
+ # Remove old context
+ context_manager.remove_context(context_name)
+ new_init_kwargs = cls._modify_context_params(
+ force_refresh_kwargs,
+ new_init_kwargs,
+ )
+ # Re-create new context with old init kwargs updated by
+ # force_refresh_kwargs.
+ context_manager.reset_context(
+ context_name,
+ **new_init_kwargs,
+ )
+ if verbose:
+ logger.info(
+ f"✅ Refreshed cache context: {context_name}, "
+ f"{cls._config_messages(logging=False, **new_init_kwargs)}"
+ )
+ # reset _context_kwargs for transformer
+ if hasattr(transformer, "_context_kwargs"):
+ # Will overwrite the _context_kwargs by last context kwargs.
+ # Only used for strify utilization.
+ transformer._context_kwargs = new_init_kwargs
diff --git a/src/cache_dit/caching/cache_blocks/__init__.py b/src/cache_dit/caching/cache_blocks/__init__.py
index f1ff43f31..bb79234fd 100644
--- a/src/cache_dit/caching/cache_blocks/__init__.py
+++ b/src/cache_dit/caching/cache_blocks/__init__.py
@@ -1,26 +1,26 @@
import torch
-from cache_dit.caching import ForwardPattern
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.cache_contexts.cache_context import CachedContext
-from cache_dit.caching.cache_contexts.prune_context import PrunedContext
-from cache_dit.caching.cache_contexts.cache_manager import (
+from ..forward_pattern import ForwardPattern
+from ..cache_types import CacheType
+from ..cache_contexts.cache_context import CachedContext
+from ..cache_contexts.prune_context import PrunedContext
+from ..cache_contexts.cache_manager import (
CachedContextManager,
)
-from cache_dit.caching.cache_contexts.prune_manager import (
+from ..cache_contexts.prune_manager import (
PrunedContextManager,
)
-from cache_dit.caching.cache_blocks.pattern_0_1_2 import (
+from .pattern_0_1_2 import (
CachedBlocks_Pattern_0_1_2,
PrunedBlocks_Pattern_0_1_2,
)
-from cache_dit.caching.cache_blocks.pattern_3_4_5 import (
+from .pattern_3_4_5 import (
CachedBlocks_Pattern_3_4_5,
PrunedBlocks_Pattern_3_4_5,
)
-from cache_dit.caching.cache_blocks.pattern_utils import apply_stats
-from cache_dit.caching.cache_blocks.pattern_utils import remove_stats
+from .pattern_utils import apply_stats
+from .pattern_utils import remove_stats
from cache_dit.logger import init_logger
diff --git a/src/cache_dit/caching/cache_blocks/pattern_0_1_2.py b/src/cache_dit/caching/cache_blocks/pattern_0_1_2.py
index d47295e0c..06bfc4777 100644
--- a/src/cache_dit/caching/cache_blocks/pattern_0_1_2.py
+++ b/src/cache_dit/caching/cache_blocks/pattern_0_1_2.py
@@ -1,5 +1,5 @@
-from cache_dit.caching import ForwardPattern
-from cache_dit.caching.cache_blocks.pattern_base import (
+from ..forward_pattern import ForwardPattern
+from .pattern_base import (
CachedBlocks_Pattern_Base,
PrunedBlocks_Pattern_Base,
)
diff --git a/src/cache_dit/caching/cache_blocks/pattern_3_4_5.py b/src/cache_dit/caching/cache_blocks/pattern_3_4_5.py
index 4b1c90dcf..657990fb8 100644
--- a/src/cache_dit/caching/cache_blocks/pattern_3_4_5.py
+++ b/src/cache_dit/caching/cache_blocks/pattern_3_4_5.py
@@ -1,17 +1,17 @@
import torch
-from cache_dit.caching import ForwardPattern
-from cache_dit.caching.cache_contexts.cache_manager import (
+from ..forward_pattern import ForwardPattern
+from ..cache_contexts.cache_manager import (
ContextNotExistError,
)
-from cache_dit.caching.cache_blocks.pattern_base import (
+from .pattern_base import (
CachedBlocks_Pattern_Base,
)
-from cache_dit.caching.cache_contexts.prune_context import PrunedContext
-from cache_dit.caching.cache_contexts.prune_manager import (
+from ..cache_contexts.prune_context import PrunedContext
+from ..cache_contexts.prune_manager import (
PrunedContextManager,
)
-from cache_dit.caching.cache_types import CacheType
+from ..cache_types import CacheType
from cache_dit.logger import init_logger
diff --git a/src/cache_dit/caching/cache_blocks/pattern_base.py b/src/cache_dit/caching/cache_blocks/pattern_base.py
index 1b964021a..3f5c5f014 100644
--- a/src/cache_dit/caching/cache_blocks/pattern_base.py
+++ b/src/cache_dit/caching/cache_blocks/pattern_base.py
@@ -3,17 +3,17 @@
import torch
import torch.distributed as dist
from diffusers.hooks import HookRegistry
-from cache_dit.caching.cache_contexts.cache_context import CachedContext
-from cache_dit.caching.cache_contexts.prune_context import PrunedContext
-from cache_dit.caching.cache_contexts.cache_manager import (
+from ..cache_contexts.cache_context import CachedContext
+from ..cache_contexts.prune_context import PrunedContext
+from ..cache_contexts.cache_manager import (
CachedContextManager,
ContextNotExistError,
)
-from cache_dit.caching.cache_contexts.prune_manager import (
+from ..cache_contexts.prune_manager import (
PrunedContextManager,
)
-from cache_dit.caching import ForwardPattern
-from cache_dit.caching.cache_types import CacheType
+from ..forward_pattern import ForwardPattern
+from ..cache_types import CacheType
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -22,9 +22,9 @@
from diffusers.hooks.context_parallel import ContextParallelSplitHook
except ImportError:
ContextParallelSplitHook = None
- logger.warning(
- "Context parallelism requires the 'diffusers>=0.36.dev0'."
- "Please install latest version of diffusers from source: \n"
+ logger.debug(
+ "Context parallelism in cache-dit requires 'diffusers>=0.36.dev0.\n"
+ "Please install latest version of diffusers from source via: \n"
"pip3 install git+https://github.com/huggingface/diffusers.git"
)
@@ -444,6 +444,7 @@ def call_Mn_blocks(
):
original_hidden_states = hidden_states
original_encoder_hidden_states = encoder_hidden_states
+
for block in self._Mn_blocks():
hidden_states = block(
hidden_states,
diff --git a/src/cache_dit/caching/cache_blocks/pattern_utils.py b/src/cache_dit/caching/cache_blocks/pattern_utils.py
index a81a763ef..4d12aef2a 100644
--- a/src/cache_dit/caching/cache_blocks/pattern_utils.py
+++ b/src/cache_dit/caching/cache_blocks/pattern_utils.py
@@ -1,9 +1,9 @@
import torch
from typing import Any
-from cache_dit.caching import CachedContext
-from cache_dit.caching import CachedContextManager
-from cache_dit.caching import PrunedContextManager
+from ..cache_contexts import CachedContext
+from ..cache_contexts import CachedContextManager
+from ..cache_contexts import PrunedContextManager
def apply_stats(
diff --git a/src/cache_dit/caching/cache_contexts/__init__.py b/src/cache_dit/caching/cache_contexts/__init__.py
index 46e8e3d4a..8a1e32cb9 100644
--- a/src/cache_dit/caching/cache_contexts/__init__.py
+++ b/src/cache_dit/caching/cache_contexts/__init__.py
@@ -1,28 +1,28 @@
-from cache_dit.caching.cache_contexts.calibrators import (
+from .calibrators import (
Calibrator,
CalibratorBase,
CalibratorConfig,
TaylorSeerCalibratorConfig,
FoCaCalibratorConfig,
)
-from cache_dit.caching.cache_contexts.cache_config import (
+from .cache_config import (
BasicCacheConfig,
DBCacheConfig,
)
-from cache_dit.caching.cache_contexts.cache_context import (
+from .cache_context import (
CachedContext,
)
-from cache_dit.caching.cache_contexts.cache_manager import (
+from .cache_manager import (
CachedContextManager,
ContextNotExistError,
)
-from cache_dit.caching.cache_contexts.prune_config import DBPruneConfig
-from cache_dit.caching.cache_contexts.prune_context import (
+from .prune_config import DBPruneConfig
+from .prune_context import (
PrunedContext,
)
-from cache_dit.caching.cache_contexts.prune_manager import (
+from .prune_manager import (
PrunedContextManager,
)
-from cache_dit.caching.cache_contexts.context_manager import (
+from .context_manager import (
ContextManager,
)
diff --git a/src/cache_dit/caching/cache_contexts/cache_config.py b/src/cache_dit/caching/cache_contexts/cache_config.py
index c3ad76ae5..d3fc3fd0c 100644
--- a/src/cache_dit/caching/cache_contexts/cache_config.py
+++ b/src/cache_dit/caching/cache_contexts/cache_config.py
@@ -1,7 +1,7 @@
import torch
import dataclasses
from typing import Optional, Union, List
-from cache_dit.caching.cache_types import CacheType
+from ..cache_types import CacheType
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -117,6 +117,12 @@ def strify(self) -> str:
base_str += f"_SCM{''.join(map(str, self.steps_computation_mask))}"
base_str += f"_{self.steps_computation_policy}"
+ if self.num_inference_steps is not None:
+ base_str += f"_N{self.num_inference_steps}"
+
+ if self.enable_separate_cfg is not None:
+ base_str += f"_CFG{int(self.enable_separate_cfg)}"
+
return base_str
diff --git a/src/cache_dit/caching/cache_contexts/cache_context.py b/src/cache_dit/caching/cache_contexts/cache_context.py
index f2dadd533..136a308d3 100644
--- a/src/cache_dit/caching/cache_contexts/cache_context.py
+++ b/src/cache_dit/caching/cache_contexts/cache_context.py
@@ -7,12 +7,12 @@
import torch
-from cache_dit.caching.cache_contexts.cache_config import (
+from .cache_config import (
BasicCacheConfig,
ExtraCacheConfig,
DBCacheConfig,
)
-from cache_dit.caching.cache_contexts.calibrators import (
+from .calibrators import (
Calibrator,
CalibratorBase,
CalibratorConfig,
diff --git a/src/cache_dit/caching/cache_contexts/cache_manager.py b/src/cache_dit/caching/cache_contexts/cache_manager.py
index 999ac5983..7d652a639 100644
--- a/src/cache_dit/caching/cache_contexts/cache_manager.py
+++ b/src/cache_dit/caching/cache_contexts/cache_manager.py
@@ -5,11 +5,8 @@
import torch
import torch.distributed as dist
-from cache_dit.caching.cache_contexts.calibrators import CalibratorBase
-from cache_dit.caching.cache_contexts.cache_context import (
- BasicCacheConfig,
- CachedContext,
-)
+from .calibrators import CalibratorBase
+from .cache_context import CachedContext
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -20,7 +17,6 @@ class ContextNotExistError(Exception):
class CachedContextManager:
- # Each Pipeline should have it's own context manager instance.
def __init__(self, name: str = None, persistent_context: bool = False):
self.name = name
@@ -56,17 +52,13 @@ def is_pre_refreshed(self) -> bool:
if num_inference_steps is not None:
current_step = _context.get_current_step() # e.g, 0~49,50~99,...
return current_step == num_inference_steps - 1
- return False
+ # If num_inference_steps is None, always return True, so that
+ # `apply_stats_hooks` is called after each forward when persistent_context is True.
+ # Otherwise, we would lose the accurate cached stats after each request.
+ return True
@torch.compiler.disable
def new_context(self, *args, **kwargs) -> CachedContext:
- if self._persistent_context:
- cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
- assert cache_config is not None and cache_config.num_inference_steps is not None, (
- "When persistent_context is True, num_inference_steps "
- "must be set in cache_config for proper cache refreshing."
- f"\nkwargs: {kwargs}"
- )
_context = CachedContext(*args, **kwargs)
# NOTE: Patch args and kwargs for implicit refresh.
_context._init_args = args # maybe empty tuple: ()
@@ -90,12 +82,6 @@ def maybe_refresh(
raise ContextNotExistError("Context not exist!")
_context = self._cached_context_manager[cached_context]
- if self._persistent_context:
- assert _context.cache_config.num_inference_steps is not None, (
- "When persistent_context is True, num_inference_steps must be set "
- "in cache_config for proper cache refreshing."
- )
-
num_inference_steps = _context.cache_config.num_inference_steps
if num_inference_steps is not None:
current_step = _context.get_current_step() # e.g, 0~49,50~99,...
diff --git a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py
index d64ee67e9..6ba65c533 100644
--- a/src/cache_dit/caching/cache_contexts/calibrators/__init__.py
+++ b/src/cache_dit/caching/cache_contexts/calibrators/__init__.py
@@ -57,9 +57,11 @@ def update(self, **kwargs) -> "CalibratorConfig":
def empty(self, **kwargs) -> "CalibratorConfig":
# Set all fields to None
+ skip_constants = {"calibrator_type"}
for field in dataclasses.fields(self):
if hasattr(self, field.name):
- setattr(self, field.name, None)
+ if field.name not in skip_constants:
+ setattr(self, field.name, None)
if kwargs:
self.update(**kwargs)
return self
diff --git a/src/cache_dit/caching/cache_contexts/context_manager.py b/src/cache_dit/caching/cache_contexts/context_manager.py
index 9f3ec9b21..1276f1b7d 100644
--- a/src/cache_dit/caching/cache_contexts/context_manager.py
+++ b/src/cache_dit/caching/cache_contexts/context_manager.py
@@ -1,10 +1,6 @@
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.cache_contexts.cache_manager import (
- CachedContextManager,
-)
-from cache_dit.caching.cache_contexts.prune_manager import (
- PrunedContextManager,
-)
+from ..cache_types import CacheType
+from .cache_manager import CachedContextManager
+from .prune_manager import PrunedContextManager
from cache_dit.logger import init_logger
logger = init_logger(__name__)
diff --git a/src/cache_dit/caching/cache_contexts/prune_config.py b/src/cache_dit/caching/cache_contexts/prune_config.py
index 154f4dc36..ca4956a86 100644
--- a/src/cache_dit/caching/cache_contexts/prune_config.py
+++ b/src/cache_dit/caching/cache_contexts/prune_config.py
@@ -1,9 +1,7 @@
import dataclasses
from typing import List
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.cache_contexts.cache_config import (
- BasicCacheConfig,
-)
+from ..cache_types import CacheType
+from .cache_config import BasicCacheConfig
from cache_dit.logger import init_logger
diff --git a/src/cache_dit/caching/cache_contexts/prune_context.py b/src/cache_dit/caching/cache_contexts/prune_context.py
index 6c285a42d..019b788d1 100644
--- a/src/cache_dit/caching/cache_contexts/prune_context.py
+++ b/src/cache_dit/caching/cache_contexts/prune_context.py
@@ -2,13 +2,9 @@
import logging
import dataclasses
from typing import List
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.cache_contexts.prune_config import (
- DBPruneConfig,
-)
-from cache_dit.caching.cache_contexts.cache_context import (
- CachedContext,
-)
+from ..cache_types import CacheType
+from .prune_config import DBPruneConfig
+from .cache_context import CachedContext
from cache_dit.logger import init_logger
diff --git a/src/cache_dit/caching/cache_contexts/prune_manager.py b/src/cache_dit/caching/cache_contexts/prune_manager.py
index 6b141ec33..bd81d90e1 100644
--- a/src/cache_dit/caching/cache_contexts/prune_manager.py
+++ b/src/cache_dit/caching/cache_contexts/prune_manager.py
@@ -2,13 +2,8 @@
import functools
from typing import Dict, List, Tuple, Union
-from cache_dit.caching.cache_contexts.cache_manager import (
- BasicCacheConfig,
- CachedContextManager,
-)
-from cache_dit.caching.cache_contexts.prune_context import (
- PrunedContext,
-)
+from .cache_manager import CachedContextManager
+from .prune_context import PrunedContext
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -25,12 +20,6 @@ def __init__(self, name: str = None, **kwargs):
# Overwrite for Dynamic Block Prune
def new_context(self, *args, **kwargs) -> PrunedContext:
- if self._persistent_context:
- cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
- assert cache_config is not None and cache_config.num_inference_steps is not None, (
- "When persistent_context is True, num_inference_steps "
- "must be set in cache_config for proper cache refreshing."
- )
_context = PrunedContext(*args, **kwargs)
# NOTE: Patch args and kwargs for implicit refresh.
_context._init_args = args # maybe empty tuple: ()
diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py
index 864f0b8e7..ac84fcfee 100644
--- a/src/cache_dit/caching/cache_interface.py
+++ b/src/cache_dit/caching/cache_interface.py
@@ -1,17 +1,18 @@
+import copy
import torch
from typing import Any, Tuple, List, Union, Optional
from diffusers import DiffusionPipeline, ModelMixin
-from cache_dit.caching.cache_types import CacheType
-from cache_dit.caching.block_adapters import BlockAdapter
-from cache_dit.caching.block_adapters import BlockAdapterRegister
-from cache_dit.caching.cache_adapters import CachedAdapter
-from cache_dit.caching.cache_contexts import BasicCacheConfig
-from cache_dit.caching.cache_contexts import DBCacheConfig
-from cache_dit.caching.cache_contexts import DBPruneConfig
-from cache_dit.caching.cache_contexts import CalibratorConfig
-from cache_dit.caching.params_modifier import ParamsModifier
-from cache_dit.parallelism import ParallelismConfig
-from cache_dit.parallelism import enable_parallelism
+from .cache_types import CacheType
+from .block_adapters import BlockAdapter
+from .block_adapters import BlockAdapterRegister
+from .cache_adapters import CachedAdapter
+from .cache_contexts import BasicCacheConfig
+from .cache_contexts import DBCacheConfig
+from .cache_contexts import DBPruneConfig
+from .cache_contexts import CalibratorConfig
+from .params_modifier import ParamsModifier
+from ..parallelism import ParallelismConfig
+from ..parallelism import enable_parallelism
from cache_dit.logger import init_logger
@@ -253,7 +254,7 @@ def enable_cache(
"deprecated and will be removed in the future, please use "
"`calibrator_config` parameter instead!"
)
- from cache_dit.caching.cache_contexts.calibrators import (
+ from .cache_contexts.calibrators import (
TaylorSeerCalibratorConfig,
)
@@ -298,6 +299,31 @@ def enable_cache(
parallelism_config, ParallelismConfig
), "parallelism_config should be of type ParallelismConfig."
+ # Prefer custom has_controlnet flag from users if provided, otherwise,
+ # we will automatically check whether the pipeline has controlnet.
+ if "has_controlnet" not in parallelism_config.parallel_kwargs:
+ # This flag is used to decide whether to use the special parallelism
+ # plan due to the addition of ControlNet, e.g., Z-Image-ControlNet.
+ parallelism_config.parallel_kwargs["has_controlnet"] = _has_controlnet(
+ pipe_or_adapter,
+ )
+ parallelism_config._has_controlnet = parallelism_config.parallel_kwargs[
+ "has_controlnet"
+ ]
+
+ # Parse extra parallel modules from names to actual modules
+ if (
+ extra_parallel_module := parallelism_config.parallel_kwargs.get(
+ "extra_parallel_modules", None
+ )
+ ) is not None:
+ parallelism_config.parallel_kwargs["extra_parallel_modules"] = (
+ _parse_extra_parallel_modules(
+ pipe_or_adapter,
+ extra_parallel_module,
+ )
+ )
+
transformers = []
if isinstance(pipe_or_adapter, DiffusionPipeline):
adapter = BlockAdapterRegister.get_adapter(
@@ -336,14 +362,151 @@ def enable_cache(
return pipe_or_adapter
+def _has_controlnet(pipe_or_adapter: DiffusionPipeline | BlockAdapter) -> bool:
+ """Check if the given pipeline has ControlNet."""
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ else:
+ pipe = pipe_or_adapter
+ if hasattr(pipe, "controlnet") and getattr(pipe, "controlnet") is not None:
+ return True
+ return False
+
+
+def _parse_text_encoder(
+ pipe: DiffusionPipeline,
+) -> Tuple[Optional[torch.nn.Module], Optional[str]]:
+ pipe_cls_name = pipe.__class__.__name__
+ if (
+ hasattr(pipe, "text_encoder_2")
+ and not pipe_cls_name.startswith("Hunyuan")
+ and not pipe_cls_name.startswith("Kandinsky")
+ ):
+ # Specific for FluxPipeline, FLUX.1-dev
+ return getattr(pipe, "text_encoder_2"), "text_encoder_2"
+ elif hasattr(pipe, "text_encoder_3"): # HiDream pipeline
+ return getattr(pipe, "text_encoder_3"), "text_encoder_3"
+ elif hasattr(pipe, "text_encoder"): # General case
+ return getattr(pipe, "text_encoder"), "text_encoder"
+ else:
+ return None, None
+
+
+def _parse_extra_parallel_modules(
+ pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+ extra_parallel_module: List[str | torch.nn.Module],
+) -> Union[List[torch.nn.Module], List]:
+ if isinstance(pipe_or_adapter, BlockAdapter):
+ pipe = pipe_or_adapter.pipe
+ else:
+ pipe = pipe_or_adapter
+
+ if not extra_parallel_module: # empty list
+ return []
+
+ parsed_extra_parallel_modules: List[torch.nn.Module] = []
+ for module_or_name in extra_parallel_module:
+ if isinstance(module_or_name, torch.nn.Module):
+ parsed_extra_parallel_modules.append(module_or_name)
+ continue
+
+ if hasattr(pipe, module_or_name):
+ if module_or_name == "text_encoder":
+ # Special handling for text encoder
+ text_encoder, _ = _parse_text_encoder(pipe)
+ if text_encoder is not None:
+ parsed_extra_parallel_modules.append(text_encoder)
+ else:
+ logger.warning(
+ "Text encoder not found in the pipeline for extra parallel module."
+ )
+ else:
+ parsed_extra_parallel_modules.append(getattr(pipe, module_or_name))
+ else:
+ logger.warning(
+ f"Extra parallel module name {module_or_name} not found in the pipeline."
+ )
+ return parsed_extra_parallel_modules
+
+
+def refresh_context(
+ transformer: torch.nn.Module,
+ **force_refresh_kwargs,
+):
+ r"""Refresh cache context for the given transformer. This is useful when
+ the users run into transformer-only case with dynamic num_inference_steps.
+ For example, when num_inference_steps changes significantly between different
+ requests, the cache context should be refreshed to avoid potential
+ precision degradation. Usage:
+ ```py
+ >>> import cache_dit
+ >>> from cache_dit import DBCacheConfig
+ >>> from diffusers import DiffusionPipeline
+ >>> # Init cache context with num_inference_steps=None (default)
+ >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
+ >>> cache_dit.enable_cache(pipe.transformer, cache_config=DBCacheConfig(...))
+ >>> # Assume num_inference_steps is 28, and we want to refresh the context
+ >>> cache_dit.refresh_context(pipe.transformer, num_inference_steps=28, verbose=True)
+ >>> output = pipe(...) # Just call the pipe as normal.
+ >>> stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+ >>> # Update the cache context with new num_inference_steps=50.
+ >>> cache_dit.refresh_context(pipe.transformer, num_inference_steps=50, verbose=True)
+ >>> output = pipe(...) # Just call the pipe as normal.
+ >>> stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+ >>> # Update the cache context with new cache_config.
+ >>> cache_dit.refresh_context(
+ pipe.transformer,
+ cache_config=DBCacheConfig(
+ residual_diff_threshold=0.1,
+ max_warmup_steps=10,
+ max_cached_steps=20,
+ max_continuous_cached_steps=4,
+ num_inference_steps=50,
+ ),
+ verbose=True,
+ )
+ >>> output = pipe(...) # Just call the pipe as normal.
+ >>> stats = cache_dit.summary(pipe.transformer) # Then, get the summary
+ ```
+ """
+ if force_refresh_kwargs:
+ if "cache_config" not in force_refresh_kwargs:
+ # Assume force_refresh_kwargs is passed as dict, e.g.,
+ # {"num_inference_steps": 50}
+ from .utils import load_cache_config
+
+ cache_config, calibrator_config = load_cache_config(
+ force_refresh_kwargs,
+ reset=True,
+ )
+ force_refresh_kwargs["cache_config"] = copy.deepcopy(cache_config)
+ if calibrator_config is not None:
+ force_refresh_kwargs["calibrator_config"] = copy.deepcopy(calibrator_config)
+ else:
+ allowed_keys = {"cache_config", "calibrator_config", "verbose"}
+ not_allowed_keys = set(force_refresh_kwargs.keys()) - allowed_keys
+ if not_allowed_keys:
+ logger.warning(
+ f"force_refresh_kwargs contains cache_config, please put the extra "
+ f"kwargs: {not_allowed_keys} into cache_config directly. Otherwise, "
+ f"these kwargs will be ignored."
+ )
+ CachedAdapter.maybe_refresh_context(
+ transformer,
+ **force_refresh_kwargs,
+ )
+
+
def disable_cache(
    pipe_or_adapter: Union[
        DiffusionPipeline,
        BlockAdapter,
        torch.nn.Module,  # Transformer-only
    ],
):
    """Release all cache acceleration hooks previously installed on the
    given pipeline, block adapter or bare transformer module.

    Args:
        pipe_or_adapter: The `DiffusionPipeline`, `BlockAdapter`, or raw
            `torch.nn.Module` transformer whose acceleration hooks should
            be removed. Passing an object that was never cache-enabled is
            handled by `CachedAdapter.maybe_release_hooks` — presumably a
            no-op, but confirm against its implementation.
    """
    cls_name = pipe_or_adapter.__class__.__name__
    CachedAdapter.maybe_release_hooks(pipe_or_adapter)
    # Warning level so disabling is visible in default log output.
    logger.warning(f"Acceleration hooks is disabled for: {cls_name}.")
def supported_pipelines(
@@ -358,7 +521,7 @@ def get_adapter(
return BlockAdapterRegister.get_adapter(pipe)
-def steps_mask(
+def _steps_mask(
compute_bins: List[int],
cache_bins: List[int],
total_steps: Optional[int] = None,
@@ -393,3 +556,245 @@ def steps_mask(
break
return mask[:total_steps]
+
+
def steps_mask(
    compute_bins: Optional[List[int]] = None,
    cache_bins: Optional[List[int]] = None,
    total_steps: Optional[int] = None,
    mask_policy: Optional[str] = "medium",
) -> list[int]:
    r"""
    Define a step computation mask based on compute and cache bins.

    Args:
        compute_bins (`List[int]`, *optional*, defaults to None):
            A list specifying the number of consecutive steps to compute.
            For example, [4, 2] means compute 4 steps, then 2 steps.
        cache_bins (`List[int]`, *optional*, defaults to None):
            A list specifying the number of consecutive steps to cache.
            For example, [2, 4] means cache 2 steps, then 4 steps.
        total_steps (`int`, *optional*, defaults to None):
            Total number of steps for which the mask is generated.
            If provided, the sum of compute_bins and cache_bins must be at
            least total_steps. Required when using a predefined mask_policy.
        mask_policy (`str`, *optional*, defaults to "medium"):
            Predefined mask policy. Options are "slow", "medium", "fast", "ultra".
            For example, if total_steps=28, each policy corresponds to specific
            compute and cache bin configurations:
            - "slow": compute_bins=[8, 3, 3, 2, 1, 1], cache_bins=[1, 2, 2, 2, 3]
            - "medium": compute_bins=[6, 2, 2, 2, 2, 1], cache_bins=[1, 3, 3, 3, 3]
            - "fast": compute_bins=[6, 1, 1, 1, 1, 1], cache_bins=[1, 3, 4, 5, 4]
            - "ultra": compute_bins=[4, 1, 1, 1, 1], cache_bins=[2, 5, 6, 7]

    Returns:
        `List[int]`: A list representing the step computation mask, where 1
        indicates a compute step and 0 indicates a cache step.
    """
    # Explicit bins take precedence over any predefined policy.
    if compute_bins is not None and cache_bins is not None:
        return _steps_mask(
            compute_bins=compute_bins,
            cache_bins=cache_bins,
            total_steps=total_steps,
        )

    assert (
        total_steps is not None
    ), "total_steps must be provided when using predefined mask_policy."

    # Predefined policies, calibrated for 28 inference steps.
    predefined_policies = {
        # NOTE: last step will never cache by default
        # mask: 11111111 0 111 00 111 00 11 00 1 000 1
        "slow": [
            [8, 3, 3, 2, 1, 1],  # = 18 compute steps
            [1, 2, 2, 2, 3],  # = 10 cache steps
        ],
        "medium": [
            [6, 2, 2, 2, 2, 1],  # = 15 compute steps
            [1, 3, 3, 3, 3],  # = 13 cache steps
        ],
        "fast": [
            [6, 1, 1, 1, 1, 1],  # = 11 compute steps
            [1, 3, 4, 5, 4],  # = 17 cache steps
        ],
        "ultra": [
            [4, 1, 1, 1, 1],  # = 8 compute steps
            [2, 5, 6, 7],  # = 20 cache steps
        ],
    }

    def _sum_policy(policy: List[List[int]]) -> int:
        # Total steps a policy covers: compute bins + cache bins.
        return sum(policy[0]) + sum(policy[1])

    def _truncate_policy(policy: List[List[int]], target_steps: int) -> List[List[int]]:
        # Shrink a policy *in place*, alternately trimming the tail cache bin
        # and the tail compute bin, until it covers at most target_steps.
        compute_bins, cache_bins = policy  # references; mutated in place
        while _sum_policy(policy) > target_steps:
            if cache_bins:
                cache_bins[-1] -= 1
                if cache_bins[-1] == 0:
                    cache_bins.pop()
                if _sum_policy(policy) <= target_steps:
                    break
            if compute_bins:
                compute_bins[-1] -= 1
                if compute_bins[-1] == 0:
                    compute_bins.pop()
                if _sum_policy(policy) <= target_steps:
                    break
        return [compute_bins, cache_bins]

    def _truncate_predefined_policies(
        policies: dict[str, List[List[int]]],
        target_steps: int,
    ) -> dict[str, List[List[int]]]:
        # Truncate every policy so each one fits within target_steps.
        return {
            name: _truncate_policy(policy, target_steps)
            for name, policy in policies.items()
        }

    if total_steps > 28:
        # Expand bins if total_steps exceeds the predefined sum.
        # For example, for total_steps=50, we will expand the bins
        # of each policy until they can cover total_steps. This keeps
        # the relative ratio of compute/cache steps roughly consistent
        # with the predefined 28-step policies.
        for policy in predefined_policies.values():
            min_bins_len = min(len(policy[0]), len(policy[1]))
            # Snapshot the original bin sizes so the growth increment stays
            # proportional to the *initial* bins, not the growing ones.
            compute_bins = copy.deepcopy(policy[0])
            cache_bins = copy.deepcopy(policy[1])
            while _sum_policy(policy) < total_steps:
                for i in range(min_bins_len):
                    # Grow compute bin i proportionally, e.g., total_steps=50,
                    # slow: 8 -> 8 + int(8 * (50 / 28) * 0.5) = 14
                    #       3 -> 3 + int(3 * (50 / 28) * 0.5) = 5
                    policy[0][i] += max(int(compute_bins[i] * ((total_steps / 28) * 0.5)), 1)
                    if _sum_policy(policy) >= total_steps:
                        break
                    # Grow cache bin i proportionally, e.g., total_steps=50,
                    # slow: 1 -> 1 + int(1 * (50 / 28) * 0.5) = 2
                    #       2 -> 2 + int(2 * (50 / 28) * 0.5) = 4
                    policy[1][i] += max(int(cache_bins[i] * ((total_steps / 28) * 0.5)), 1)
                    if _sum_policy(policy) >= total_steps:
                        break
                if _sum_policy(policy) >= total_steps:
                    break
                # Pad the final compute bin (compute_bins is always the
                # longer list) until the policy covers total_steps.
                policy[0][-1] += 1
                if _sum_policy(policy) >= total_steps:
                    break

        # Truncate back to exactly total_steps.
        predefined_policies = _truncate_predefined_policies(
            predefined_policies,
            total_steps,
        )
    elif 16 <= total_steps < 28:
        # Truncate bins to fit total_steps.
        predefined_policies = _truncate_predefined_policies(
            predefined_policies,
            total_steps,
        )
    elif 8 <= total_steps < 16:
        # Mainly for distilled models with fewer steps: smaller bins.
        if total_steps == 9:
            # Special case for Z-Image-Turbo with 9 steps.
            predefined_policies = {
                "slow": [
                    [5, 2, 1],  # = 8
                    [1],  # = 1
                ],
                "medium": [
                    [5, 1, 1],  # = 7
                    [1, 1],  # = 2
                ],
                "fast": [
                    [4, 1, 1],  # = 6
                    [1, 2],  # = 3
                ],
                "ultra": [
                    [3, 1, 1],  # = 5
                    [2, 2],  # = 4
                ],
            }
        elif total_steps == 8:
            # Case: 8-step distilled models.
            predefined_policies = {
                "slow": [
                    [5, 1, 1],  # = 7
                    [1],  # = 1
                ],
                "medium": [
                    [4, 1, 1],  # = 6
                    [1, 1],  # = 2
                ],
                "fast": [
                    [3, 1, 1],  # = 5
                    [1, 2],  # = 3
                ],
                "ultra": [
                    [2, 1, 1],  # = 4
                    [2, 2],  # = 4
                ],
            }
        else:
            # 10-15 steps. BUGFIX: previously this table was assigned first
            # and then unconditionally overwritten by the `total_steps == 8`
            # else-branch, so 10-15 steps silently used the 8-step policies.
            predefined_policies = {
                "slow": [
                    [4, 2, 2, 2, 1],  # = 11
                    [1, 1, 1, 1],  # = 4
                ],
                "medium": [
                    [4, 2, 1, 1, 1],  # = 9
                    [1, 1, 2, 2],  # = 6
                ],
                "fast": [
                    [3, 1, 1, 1, 1],  # = 7
                    [1, 2, 2, 3],  # = 8
                ],
                "ultra": [
                    [2, 1, 1, 1, 1],  # = 6
                    [1, 2, 3, 3],  # = 9
                ],
            }
        # Truncate once for all policies (the old code re-truncated the
        # whole dict once per policy, which was redundant).
        predefined_policies = _truncate_predefined_policies(
            predefined_policies,
            total_steps,
        )
    else:
        # Case: 4 or 6 steps distilled models.
        assert total_steps in (4, 6), (
            "Only total_steps=4 or 6 is supported for predefined masks "
            f"while total_steps < 8. Got total_steps={total_steps}."
        )
        constant_policy = [[2, 1], [1]] if total_steps == 4 else [[3, 1], [2]]
        predefined_policies = {
            "slow": constant_policy,
            "medium": constant_policy,
            "fast": constant_policy,
            "ultra": constant_policy,
        }

    if mask_policy not in predefined_policies:
        raise ValueError(
            f"mask_policy {mask_policy} is not valid. "
            f"Choose from {list(predefined_policies.keys())}."
        )
    compute_bins, cache_bins = predefined_policies[mask_policy]
    # _steps_mask truncates the mask if the bins exceed total_steps.
    compute_mask = _steps_mask(
        compute_bins=compute_bins, cache_bins=cache_bins, total_steps=total_steps
    )
    # Force last step to compute.
    compute_mask[-1] = 1
    return compute_mask
diff --git a/src/cache_dit/caching/params_modifier.py b/src/cache_dit/caching/params_modifier.py
index 43d01c209..50024b0fa 100644
--- a/src/cache_dit/caching/params_modifier.py
+++ b/src/cache_dit/caching/params_modifier.py
@@ -1,7 +1,7 @@
from typing import Optional
-from cache_dit.caching.cache_contexts import BasicCacheConfig
-from cache_dit.caching.cache_contexts import CalibratorConfig
+from .cache_contexts import BasicCacheConfig
+from .cache_contexts import CalibratorConfig
from cache_dit.logger import init_logger
diff --git a/src/cache_dit/caching/patch_functors/__init__.py b/src/cache_dit/caching/patch_functors/__init__.py
index feb7e5fac..58e6ccf71 100644
--- a/src/cache_dit/caching/patch_functors/__init__.py
+++ b/src/cache_dit/caching/patch_functors/__init__.py
@@ -1,18 +1,44 @@
-from cache_dit.caching.patch_functors.functor_base import PatchFunctor
-from cache_dit.caching.patch_functors.functor_dit import DiTPatchFunctor
-from cache_dit.caching.patch_functors.functor_flux import FluxPatchFunctor
-from cache_dit.caching.patch_functors.functor_chroma import (
- ChromaPatchFunctor,
+import importlib
+from cache_dit.logger import init_logger
+from .functor_base import PatchFunctor
+
+logger = init_logger(__name__)
+
+
class ImportErrorPatchFunctor(PatchFunctor):
    """Fallback functor used when a real patch functor cannot be imported
    (typically because the installed diffusers is too old). Importing the
    package still succeeds; the error is raised only at apply-time."""

    def _apply(self, transformer, **kwargs):
        # Fail loudly the moment someone actually tries to use the functor.
        raise ImportError(
            "This PatchFunctor requires latest diffusers to be installed. "
            "Please install diffusers from source."
        )
+
+
+def _safe_import(module_name: str, class_name: str) -> type[PatchFunctor]:
+ try:
+ # e.g., module_name = ".functor_dit", class_name = "DiTPatchFunctor"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Warning: Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorPatchFunctor
+
+
# Each functor is imported defensively: if the corresponding diffusers model
# class is unavailable (e.g. an older diffusers release), the name resolves
# to ImportErrorPatchFunctor, which raises ImportError only when applied.
DiTPatchFunctor = _safe_import(".functor_dit", "DiTPatchFunctor")
FluxPatchFunctor = _safe_import(".functor_flux", "FluxPatchFunctor")
ChromaPatchFunctor = _safe_import(".functor_chroma", "ChromaPatchFunctor")
HiDreamPatchFunctor = _safe_import(".functor_hidream", "HiDreamPatchFunctor")
HunyuanDiTPatchFunctor = _safe_import(".functor_hunyuan_dit", "HunyuanDiTPatchFunctor")
QwenImageControlNetPatchFunctor = _safe_import(
    ".functor_qwen_image_controlnet", "QwenImageControlNetPatchFunctor"
)
WanVACEPatchFunctor = _safe_import(".functor_wan_vace", "WanVACEPatchFunctor")
LTX2PatchFunctor = _safe_import(".functor_ltx2", "LTX2PatchFunctor")
ZImageControlNetPatchFunctor = _safe_import(
    ".functor_zimage_controlnet", "ZImageControlNetPatchFunctor"
)
diff --git a/src/cache_dit/caching/patch_functors/functor_base.py b/src/cache_dit/caching/patch_functors/functor_base.py
index 17b6f4b6a..844b35c2d 100644
--- a/src/cache_dit/caching/patch_functors/functor_base.py
+++ b/src/cache_dit/caching/patch_functors/functor_base.py
@@ -2,17 +2,38 @@
from abc import abstractmethod
from cache_dit.logger import init_logger
+from cache_dit.envs import ENV
logger = init_logger(__name__)
class PatchFunctor:
    """Base class for transformer patch functors.

    Subclasses implement `_apply` to monkey-patch a diffusers transformer so
    that its forward pass is compatible with cache-dit. `apply` is the public
    entry point and guards `_apply` with an is-it-from-diffusers check.
    """

    def apply(
        self,
        transformer: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        # Skip patching for non-diffusers transformers unless the env
        # override disables the check; the transformer is returned unchanged
        # in that case.
        if not ENV.CACHE_DIT_PATCH_FUNCTOR_DISABLE_DIFFUSERS_CHECK:
            if not self.is_from_diffusers(transformer):
                return transformer
        return self._apply(transformer, *args, **kwargs)

    @abstractmethod
    def _apply(
        self,
        transformer: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        # Subclasses must override; the base raises unconditionally.
        raise NotImplementedError("_apply method is not implemented.")

    @classmethod
    def is_from_diffusers(cls, transformer: torch.nn.Module) -> bool:
        """Return True when `transformer` is defined in a diffusers module
        (or the check is disabled via environment); warn and return False
        otherwise."""
        # NOTE(review): this env check duplicates the guard in `apply`;
        # kept so direct callers of is_from_diffusers see the same override.
        if ENV.CACHE_DIT_PATCH_FUNCTOR_DISABLE_DIFFUSERS_CHECK:
            return True
        if transformer.__module__.startswith("diffusers"):
            return True
        logger.warning("Found transformer not from diffusers. Skipping patch functor.")
        return False
diff --git a/src/cache_dit/caching/patch_functors/functor_chroma.py b/src/cache_dit/caching/patch_functors/functor_chroma.py
index 1c1814d82..046090bfe 100644
--- a/src/cache_dit/caching/patch_functors/functor_chroma.py
+++ b/src/cache_dit/caching/patch_functors/functor_chroma.py
@@ -13,9 +13,7 @@
unscale_lora_layers,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -23,7 +21,7 @@
class ChromaPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: ChromaTransformer2DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_dit.py b/src/cache_dit/caching/patch_functors/functor_dit.py
index 65c7d04cc..7503f83d2 100644
--- a/src/cache_dit/caching/patch_functors/functor_dit.py
+++ b/src/cache_dit/caching/patch_functors/functor_dit.py
@@ -6,9 +6,7 @@
DiTTransformer2DModel,
Transformer2DModelOutput,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -16,7 +14,7 @@
class DiTPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: DiTTransformer2DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_flux.py b/src/cache_dit/caching/patch_functors/functor_flux.py
index f84030d17..0c1b837f7 100644
--- a/src/cache_dit/caching/patch_functors/functor_flux.py
+++ b/src/cache_dit/caching/patch_functors/functor_flux.py
@@ -14,9 +14,7 @@
unscale_lora_layers,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -24,7 +22,7 @@
class FluxPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: FluxTransformer2DModel,
blocks: torch.nn.ModuleList = None,
diff --git a/src/cache_dit/caching/patch_functors/functor_hidream.py b/src/cache_dit/caching/patch_functors/functor_hidream.py
index 3e0f84a53..981b58cb1 100644
--- a/src/cache_dit/caching/patch_functors/functor_hidream.py
+++ b/src/cache_dit/caching/patch_functors/functor_hidream.py
@@ -13,9 +13,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -23,7 +21,7 @@
class HiDreamPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: HiDreamImageTransformer2DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_hunyuan_dit.py b/src/cache_dit/caching/patch_functors/functor_hunyuan_dit.py
index 22630c128..db40ecc9d 100644
--- a/src/cache_dit/caching/patch_functors/functor_hunyuan_dit.py
+++ b/src/cache_dit/caching/patch_functors/functor_hunyuan_dit.py
@@ -5,9 +5,7 @@
HunyuanDiTBlock,
Transformer2DModelOutput,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -15,7 +13,7 @@
class HunyuanDiTPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: HunyuanDiT2DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_ltx2.py b/src/cache_dit/caching/patch_functors/functor_ltx2.py
new file mode 100644
index 000000000..1c56ad6e9
--- /dev/null
+++ b/src/cache_dit/caching/patch_functors/functor_ltx2.py
@@ -0,0 +1,329 @@
+import torch
+from typing import Optional, Dict, Any
+
+try:
+ from diffusers.models.transformers.transformer_ltx2 import (
+ LTX2VideoTransformer3DModel,
+ AudioVisualModelOutput,
+ )
+except ImportError:
+ raise ImportError(
+ "LTX2VideoTransformer3DModel is not available. "
+ "Please install the latest version of diffusers."
+ )
+
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ scale_lora_layers,
+ unscale_lora_layers,
+)
+from .functor_base import PatchFunctor
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
class LTX2PatchFunctor(PatchFunctor):
    """Patch functor for LTX-2.0: replaces the transformer's `forward`
    with `__patch_transformer_forward__`, whose block-call signature is
    compatible with cache-dit caching."""

    def _apply(
        self,
        transformer: LTX2VideoTransformer3DModel,
        **kwargs,
    ) -> torch.nn.Module:
        """Patch `transformer` in place and return it.

        Idempotent: a transformer that already carries `_is_patched`
        (set by a previous call) is returned unchanged.
        """
        if hasattr(transformer, "_is_patched"):
            return transformer

        assert isinstance(transformer, LTX2VideoTransformer3DModel)

        cls_name = transformer.__class__.__name__

        # Check ordering fixed: assert *before* overwriting `forward`, so a
        # failed assert leaves the model unpatched. The old code also kept a
        # dead `is_patched = False` flag that was unconditionally set True.
        assert not getattr(transformer, "_is_parallelized", False), (
            "Please call `cache_dit.enable_cache` before Parallelize, "
            "the __patch_transformer_forward__ will overwrite the "
            "parallelized forward and cause a downgrade of performance."
        )

        # Bind the patched forward as a method on this instance.
        transformer.forward = __patch_transformer_forward__.__get__(transformer)
        transformer._is_patched = True

        logger.warning(f"Patched {cls_name} for cache-dit.")
        logger.info(f"Applied {self.__class__.__name__} for {cls_name}, Patch: True.")
        return transformer
+
+
def __patch_transformer_forward__(
    self: LTX2VideoTransformer3DModel,
    hidden_states: torch.Tensor,
    audio_hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    audio_encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    audio_timestep: Optional[torch.LongTensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    audio_encoder_attention_mask: Optional[torch.Tensor] = None,
    num_frames: Optional[int] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    fps: float = 24.0,
    audio_num_frames: Optional[int] = None,
    video_coords: Optional[torch.Tensor] = None,
    audio_coords: Optional[torch.Tensor] = None,
    attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
) -> torch.Tensor:
    """
    Forward pass for LTX-2.0 audiovisual video transformer.

    Cache-dit patched copy of the upstream forward: the only functional
    difference is that the transformer-block call passes the four hidden/
    encoder states positionally (see the comment in step 5), which keeps the
    block signature compatible with cache-dit's caching hooks.

    Args:
        hidden_states (`torch.Tensor`):
            Input patchified video latents of shape `(batch_size, num_video_tokens, in_channels)`.
        audio_hidden_states (`torch.Tensor`):
            Input patchified audio latents of shape `(batch_size, num_audio_tokens, audio_in_channels)`.
        encoder_hidden_states (`torch.Tensor`):
            Input video text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
        audio_encoder_hidden_states (`torch.Tensor`):
            Input audio text embeddings of shape `(batch_size, text_seq_len, self.config.caption_channels)`.
        timestep (`torch.Tensor`):
            Input timestep of shape `(batch_size, num_video_tokens)`. These should already be scaled by
            `self.config.timestep_scale_multiplier`.
        audio_timestep (`torch.Tensor`, *optional*):
            Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
            params. This is only used by certain pipelines such as the I2V pipeline.
        encoder_attention_mask (`torch.Tensor`, *optional*):
            Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
        audio_encoder_attention_mask (`torch.Tensor`, *optional*):
            Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
        num_frames (`int`, *optional*):
            The number of latent video frames. Used if calculating the video coordinates for RoPE.
        height (`int`, *optional*):
            The latent video height. Used if calculating the video coordinates for RoPE.
        width (`int`, *optional*):
            The latent video width. Used if calculating the video coordinates for RoPE.
        fps: (`float`, *optional*, defaults to `24.0`):
            The desired frames per second of the generated video. Used if calculating the video coordinates for
            RoPE.
        audio_num_frames: (`int`, *optional*):
            The number of latent audio frames. Used if calculating the audio coordinates for RoPE.
        video_coords (`torch.Tensor`, *optional*):
            The video coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
            `(batch_size, 3, num_video_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
        audio_coords (`torch.Tensor`, *optional*):
            The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
            `(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
        attention_kwargs (`Dict[str, Any]`, *optional*):
            Optional dict of keyword args to be passed to the attention processor.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.

    Returns:
        `AudioVisualModelOutput` or `tuple`:
            If `return_dict` is `True`, returns a structured output of type `AudioVisualModelOutput`, otherwise a
            `tuple` is returned where the first element is the denoised video latent patch sequence and the second
            element is the denoised audio latent patch sequence.
    """
    # Extract the LoRA scale without mutating the caller's dict.
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    else:
        if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
            logger.warning(
                "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
            )

    # Determine timestep for audio.
    audio_timestep = audio_timestep if audio_timestep is not None else timestep

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    if audio_encoder_attention_mask is not None and audio_encoder_attention_mask.ndim == 2:
        audio_encoder_attention_mask = (
            1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)
        ) * -10000.0
        audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

    batch_size = hidden_states.size(0)

    # 1. Prepare RoPE positional embeddings
    if video_coords is None:
        video_coords = self.rope.prepare_video_coords(
            batch_size, num_frames, height, width, hidden_states.device, fps=fps
        )
    if audio_coords is None:
        audio_coords = self.audio_rope.prepare_audio_coords(
            batch_size, audio_num_frames, audio_hidden_states.device
        )

    video_rotary_emb = self.rope(video_coords, device=hidden_states.device)
    audio_rotary_emb = self.audio_rope(audio_coords, device=audio_hidden_states.device)

    # Cross-attention RoPE uses only the first coordinate axis.
    video_cross_attn_rotary_emb = self.cross_attn_rope(
        video_coords[:, 0:1, :], device=hidden_states.device
    )
    audio_cross_attn_rotary_emb = self.cross_attn_audio_rope(
        audio_coords[:, 0:1, :], device=audio_hidden_states.device
    )

    # 2. Patchify input projections
    hidden_states = self.proj_in(hidden_states)
    audio_hidden_states = self.audio_proj_in(audio_hidden_states)

    # 3. Prepare timestep embeddings and modulation parameters
    timestep_cross_attn_gate_scale_factor = (
        self.config.cross_attn_timestep_scale_multiplier / self.config.timestep_scale_multiplier
    )

    # 3.1. Prepare global modality (video and audio) timestep embedding and modulation parameters
    # temb is used in the transformer blocks (as expected), while embedded_timestep is used for the output layer
    # modulation with scale_shift_table (and similarly for audio)
    temb, embedded_timestep = self.time_embed(
        timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )
    temb = temb.view(batch_size, -1, temb.size(-1))
    embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.size(-1))

    temb_audio, audio_embedded_timestep = self.audio_time_embed(
        audio_timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=audio_hidden_states.dtype,
    )
    temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
    audio_embedded_timestep = audio_embedded_timestep.view(
        batch_size, -1, audio_embedded_timestep.size(-1)
    )

    # 3.2. Prepare global modality cross attention modulation parameters
    video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
        timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
        timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        batch_size=batch_size,
        hidden_dtype=hidden_states.dtype,
    )
    video_cross_attn_scale_shift = video_cross_attn_scale_shift.view(
        batch_size, -1, video_cross_attn_scale_shift.shape[-1]
    )
    video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(
        batch_size, -1, video_cross_attn_a2v_gate.shape[-1]
    )

    audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
        audio_timestep.flatten(),
        batch_size=batch_size,
        hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
        audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
        batch_size=batch_size,
        hidden_dtype=audio_hidden_states.dtype,
    )
    audio_cross_attn_scale_shift = audio_cross_attn_scale_shift.view(
        batch_size, -1, audio_cross_attn_scale_shift.shape[-1]
    )
    audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(
        batch_size, -1, audio_cross_attn_v2a_gate.shape[-1]
    )

    # 4. Prepare prompt embeddings
    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))

    audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
    audio_encoder_hidden_states = audio_encoder_hidden_states.view(
        batch_size, -1, audio_hidden_states.size(-1)
    )

    # 5. Run transformer blocks
    for block in self.transformer_blocks:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                audio_hidden_states,
                encoder_hidden_states,
                audio_encoder_hidden_states,
                temb,
                temb_audio,
                video_cross_attn_scale_shift,
                audio_cross_attn_scale_shift,
                video_cross_attn_a2v_gate,
                audio_cross_attn_v2a_gate,
                video_rotary_emb,
                audio_rotary_emb,
                video_cross_attn_rotary_emb,
                audio_cross_attn_rotary_emb,
                encoder_attention_mask,
                audio_encoder_attention_mask,
            )
        else:
            hidden_states, audio_hidden_states = block(
                # Make block forward args consistent with original signature,
                # thus, also make it compatible with caching in cache-dit.
                # - Begin patching:
                #   hidden_states=hidden_states,
                #   audio_hidden_states=audio_hidden_states,
                #   encoder_hidden_states=encoder_hidden_states,
                #   audio_encoder_hidden_states=audio_encoder_hidden_states,
                # - After patching:
                hidden_states,
                audio_hidden_states,
                encoder_hidden_states,
                audio_encoder_hidden_states,
                temb=temb,
                temb_audio=temb_audio,
                temb_ca_scale_shift=video_cross_attn_scale_shift,
                temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
                temb_ca_gate=video_cross_attn_a2v_gate,
                temb_ca_audio_gate=audio_cross_attn_v2a_gate,
                video_rotary_emb=video_rotary_emb,
                audio_rotary_emb=audio_rotary_emb,
                ca_video_rotary_emb=video_cross_attn_rotary_emb,
                ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
                encoder_attention_mask=encoder_attention_mask,
                audio_encoder_attention_mask=audio_encoder_attention_mask,
            )

    # 6. Output layers (including unpatchification)
    scale_shift_values = self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
    shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

    hidden_states = self.norm_out(hidden_states)
    hidden_states = hidden_states * (1 + scale) + shift
    output = self.proj_out(hidden_states)

    audio_scale_shift_values = (
        self.audio_scale_shift_table[None, None] + audio_embedded_timestep[:, :, None]
    )
    audio_shift, audio_scale = audio_scale_shift_values[:, :, 0], audio_scale_shift_values[:, :, 1]

    audio_hidden_states = self.audio_norm_out(audio_hidden_states)
    audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
    audio_output = self.audio_proj_out(audio_hidden_states)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output, audio_output)
    return AudioVisualModelOutput(sample=output, audio_sample=audio_output)
diff --git a/src/cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py b/src/cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py
index 6725f2838..569cd263f 100644
--- a/src/cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py
+++ b/src/cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py
@@ -11,9 +11,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -21,7 +19,7 @@
class QwenImageControlNetPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: QwenImageTransformer2DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_wan_vace.py b/src/cache_dit/caching/patch_functors/functor_wan_vace.py
index 45e02814a..ad73dc056 100644
--- a/src/cache_dit/caching/patch_functors/functor_wan_vace.py
+++ b/src/cache_dit/caching/patch_functors/functor_wan_vace.py
@@ -10,9 +10,7 @@
scale_lora_layers,
unscale_lora_layers,
)
-from cache_dit.caching.patch_functors.functor_base import (
- PatchFunctor,
-)
+from .functor_base import PatchFunctor
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -20,7 +18,7 @@
class WanVACEPatchFunctor(PatchFunctor):
- def apply(
+ def _apply(
self,
transformer: WanVACETransformer3DModel,
**kwargs,
diff --git a/src/cache_dit/caching/patch_functors/functor_zimage_controlnet.py b/src/cache_dit/caching/patch_functors/functor_zimage_controlnet.py
new file mode 100644
index 000000000..385d16734
--- /dev/null
+++ b/src/cache_dit/caching/patch_functors/functor_zimage_controlnet.py
@@ -0,0 +1,245 @@
+import torch
+from typing import Optional, Dict, List
+from diffusers.models.transformers.transformer_z_image import (
+ ZImageTransformer2DModel,
+ ZImageTransformerBlock,
+ Transformer2DModelOutput,
+ SEQ_MULTI_OF,
+ pad_sequence,
+)
+
+from .functor_base import PatchFunctor
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ZImageControlNetPatchFunctor(PatchFunctor):
+ def _apply(
+ self,
+ transformer: ZImageTransformer2DModel,
+ **kwargs,
+ ) -> ZImageTransformer2DModel:
+ if hasattr(transformer, "_is_patched"):
+ return transformer
+
+ is_patched = False
+
+ for layer_idx, layer in enumerate(transformer.layers):
+ if not hasattr(layer, "_is_patched"):
+ layer._layer_idx = layer_idx # type: ignore
+ layer.forward = __patch_block_forward__.__get__(layer)
+
+ is_patched = True
+ cls_name = transformer.__class__.__name__
+
+ if is_patched:
+ logger.warning(f"Patched {cls_name} for cache-dit.")
+ assert not getattr(transformer, "_is_parallelized", False), (
+ "Please call `cache_dit.enable_cache` before Parallelize, "
+ "the __patch_transformer_forward__ will overwrite the "
+ "parallelized forward and cause a downgrade of performance."
+ )
+ transformer.forward = __patch_transformer_forward__.__get__(transformer)
+
+ transformer._is_patched = is_patched # True or False
+
+ logger.info(f"Applied {self.__class__.__name__} for {cls_name}, " f"Patch: {is_patched}.")
+ return transformer
+
+
+def __patch_block_forward__(
+ self: ZImageTransformerBlock,
+ x: torch.Tensor,
+ attn_mask: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ adaln_input: Optional[torch.Tensor] = None,
+ controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
+):
+ if self.modulation:
+ assert adaln_input is not None
+ scale_msa, gate_msa, scale_mlp, gate_mlp = (
+ self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
+ )
+ gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+ scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+
+ # Attention block
+ attn_out = self.attention(
+ self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
+ )
+ x = x + gate_msa * self.attention_norm2(attn_out)
+
+ # FFN block
+ x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
+ else:
+ # Attention block
+ attn_out = self.attention(
+ self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis
+ )
+ x = x + self.attention_norm2(attn_out)
+
+ # FFN block
+ x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
+
+ # ControlNet addition
+ if controlnet_block_samples is not None:
+ layer_idx = self._layer_idx # type: ignore
+ if layer_idx in controlnet_block_samples:
+ controlnet_sample = controlnet_block_samples[layer_idx]
+
+ # NOTE: Make it compatible for context parallelism
+ _parallel_config = getattr(self.attention.processor, "_parallel_config", None)
+ if _parallel_config is not None:
+ cp_config = _parallel_config.context_parallel_config
+ if cp_config is not None and cp_config._world_size > 1:
+ # Split controlnet_sample for each device using tensor split
+ # at sequence dim, which is dim=1.
+ controlnet_sample = torch.tensor_split(
+ controlnet_sample, cp_config._world_size, dim=1
+ )[cp_config._rank]
+
+ x = x + controlnet_sample
+
+ return x
+
+
+def __patch_transformer_forward__(
+ self: ZImageTransformer2DModel,
+ x: List[torch.Tensor],
+ t,
+ cap_feats: List[torch.Tensor],
+ controlnet_block_samples: Optional[Dict[int, torch.Tensor]] = None,
+ patch_size=2,
+ f_patch_size=1,
+ return_dict: bool = True,
+):
+ assert patch_size in self.all_patch_size
+ assert f_patch_size in self.all_f_patch_size
+
+ bsz = len(x)
+ device = x[0].device
+ t = t * self.t_scale
+ t = self.t_embedder(t)
+
+ (
+ x,
+ cap_feats,
+ x_size,
+ x_pos_ids,
+ cap_pos_ids,
+ x_inner_pad_mask,
+ cap_inner_pad_mask,
+ ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
+
+ # x embed & refine
+ x_item_seqlens = [len(_) for _ in x]
+ assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
+ x_max_item_seqlen = max(x_item_seqlens)
+
+ x = torch.cat(x, dim=0)
+ x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+
+ # Match t_embedder output dtype to x for layerwise casting compatibility
+ adaln_input = t.type_as(x)
+ x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+ x = list(x.split(x_item_seqlens, dim=0))
+ x_freqs_cis = list(
+ self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0)
+ )
+
+ x = pad_sequence(x, batch_first=True, padding_value=0.0)
+ x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
+ # Truncate freqs_cis to x's sequence length so the shapes match; this satisfies Dynamo's symbolic shape inference and avoids compilation errors
+ x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
+ x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(x_item_seqlens):
+ x_attn_mask[i, :seq_len] = 1
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer in self.noise_refiner:
+ x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
+ else:
+ for layer in self.noise_refiner:
+ x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
+
+ # cap embed & refine
+ cap_item_seqlens = [len(_) for _ in cap_feats]
+ cap_max_item_seqlen = max(cap_item_seqlens)
+
+ cap_feats = torch.cat(cap_feats, dim=0)
+ cap_feats = self.cap_embedder(cap_feats)
+ cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+ cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
+ cap_freqs_cis = list(
+ self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(
+ [len(_) for _ in cap_pos_ids], dim=0
+ )
+ )
+
+ cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
+ cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
+ # Truncate freqs_cis to cap_feats' sequence length so the shapes match; this satisfies Dynamo's symbolic shape inference and avoids compilation errors
+ cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
+
+ cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(cap_item_seqlens):
+ cap_attn_mask[i, :seq_len] = 1
+
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer in self.context_refiner:
+ cap_feats = self._gradient_checkpointing_func(
+ layer, cap_feats, cap_attn_mask, cap_freqs_cis
+ )
+ else:
+ for layer in self.context_refiner:
+ cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
+
+ # unified
+ unified = []
+ unified_freqs_cis = []
+ for i in range(bsz):
+ x_len = x_item_seqlens[i]
+ cap_len = cap_item_seqlens[i]
+ unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
+ unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
+ unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
+ assert unified_item_seqlens == [len(_) for _ in unified]
+ unified_max_item_seqlen = max(unified_item_seqlens)
+
+ unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
+ unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
+ unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
+ for i, seq_len in enumerate(unified_item_seqlens):
+ unified_attn_mask[i, :seq_len] = 1
+
+ # NOTE: Already fused controlnet_block_samples into each block forward function.
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ for layer_idx, layer in enumerate(self.layers):
+ unified = self._gradient_checkpointing_func(
+ layer,
+ unified,
+ unified_attn_mask,
+ unified_freqs_cis,
+ adaln_input,
+ controlnet_block_samples,
+ )
+ else:
+ for layer_idx, layer in enumerate(self.layers):
+ unified = layer(
+ unified,
+ unified_attn_mask,
+ unified_freqs_cis,
+ adaln_input,
+ controlnet_block_samples,
+ )
+
+ unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
+ unified = list(unified.unbind(dim=0))
+ x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
+
+ if not return_dict:
+ return (x,)
+
+ return Transformer2DModelOutput(sample=x)
diff --git a/src/cache_dit/caching/utils.py b/src/cache_dit/caching/utils.py
index ecd7157ac..303f2f220 100644
--- a/src/cache_dit/caching/utils.py
+++ b/src/cache_dit/caching/utils.py
@@ -1,73 +1,313 @@
import yaml
import copy
+from typing import Tuple, Optional, Union
+from .cache_contexts import (
+ DBCacheConfig,
+ TaylorSeerCalibratorConfig,
+ DBPruneConfig,
+ CalibratorConfig,
+)
+from ..parallelism import ParallelismConfig, ParallelismBackend
+from cache_dit.logger import init_logger
+logger = init_logger(__name__)
-def load_cache_options_from_dict(cache_kwargs: dict) -> dict:
+
+def load_cache_options_from_dict(cache_kwargs: dict, reset: bool = False) -> dict:
+ r"""
+ Load cache options from a dictionary. We keep this function for backward compatibility.
+ Args:
+ cache_kwargs (`dict`):
+ A dictionary containing the cache configuration.
+ reset (`bool`, *optional*, defaults to `False`):
+ Whether to reset the configuration to its default values before applying the loaded settings.
+ This is useful when you want to ensure that only the settings specified in the dictionary
+ are applied, without retaining any previous configurations (e.g., when using ParaModifier to modify
+ existing configurations).
+ Returns:
+ `dict`: A dictionary containing the loaded cache options.
+ """
try:
# deep copy to avoid modifying original kwargs
kwargs: dict = copy.deepcopy(cache_kwargs)
cache_context_kwargs = {}
if kwargs.get("enable_taylorseer", False):
- from cache_dit.caching.cache_contexts.calibrators import (
- TaylorSeerCalibratorConfig,
- )
-
- cache_context_kwargs["calibrator_config"] = TaylorSeerCalibratorConfig(
- enable_calibrator=kwargs.get("enable_taylorseer"),
- enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer", False),
- calibrator_cache_type=kwargs.get("taylorseer_cache_type", "residual"),
- taylorseer_order=kwargs.get("taylorseer_order", 1),
+ cache_context_kwargs["calibrator_config"] = (
+ TaylorSeerCalibratorConfig(
+ enable_calibrator=kwargs.get("enable_taylorseer"),
+ enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer", False),
+ calibrator_cache_type=kwargs.get("taylorseer_cache_type", "residual"),
+ taylorseer_order=kwargs.get("taylorseer_order", 1),
+ )
+ if not reset
+ else TaylorSeerCalibratorConfig().reset(
+ enable_calibrator=kwargs.get("enable_taylorseer"),
+ enable_encoder_calibrator=kwargs.get("enable_encoder_taylorseer", False),
+ calibrator_cache_type=kwargs.get("taylorseer_cache_type", "residual"),
+ taylorseer_order=kwargs.get("taylorseer_order", 1),
+ )
)
if "cache_type" not in kwargs:
- from cache_dit.caching.cache_contexts import BasicCacheConfig
-
- cache_context_kwargs["cache_config"] = BasicCacheConfig()
+ # Assume DBCache if cache_type is not specified
+ cache_context_kwargs["cache_config"] = (
+ DBCacheConfig() if not reset else DBCacheConfig().reset()
+ )
cache_context_kwargs["cache_config"].update(**kwargs)
else:
cache_type = str(kwargs.get("cache_type", None))
if cache_type == "DBCache":
- from cache_dit.caching.cache_contexts import DBCacheConfig
- cache_context_kwargs["cache_config"] = DBCacheConfig()
+ cache_context_kwargs["cache_config"] = (
+ DBCacheConfig() if not reset else DBCacheConfig().reset()
+ )
cache_context_kwargs["cache_config"].update(**kwargs)
elif cache_type == "DBPrune":
- from cache_dit.caching.cache_contexts import DBPruneConfig
- cache_context_kwargs["cache_config"] = DBPruneConfig()
+ cache_context_kwargs["cache_config"] = (
+ DBPruneConfig() if not reset else DBPruneConfig().reset()
+ )
cache_context_kwargs["cache_config"].update(**kwargs)
else:
raise ValueError(f"Unsupported cache_type: {cache_type}.")
- if "parallelism_config" in kwargs:
- from cache_dit.parallelism.parallel_config import (
- ParallelismConfig,
- )
-
- parallelism_kwargs = kwargs.get("parallelism_config", {})
- cache_context_kwargs["parallelism_config"] = ParallelismConfig(**parallelism_kwargs)
-
return cache_context_kwargs
except Exception as e:
raise ValueError(f"Error parsing cache configuration. {str(e)}")
-def load_cache_options_from_yaml(yaml_file_path: str) -> dict:
+def load_cache_options_from_yaml(yaml_file_path: str, reset: bool = False) -> dict:
try:
with open(yaml_file_path, "r") as f:
kwargs: dict = yaml.safe_load(f)
- return load_cache_options_from_dict(kwargs)
+ return load_cache_options_from_dict(kwargs, reset)
except FileNotFoundError:
raise FileNotFoundError(f"Configuration file not found: {yaml_file_path}")
except yaml.YAMLError as e:
raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
-def load_options(path_or_dict: str | dict) -> dict:
+def load_options(path_or_dict: str | dict, reset: bool = False) -> dict:
+ r"""
+ Load cache options from a YAML file or a dictionary.
+ Args:
+ path_or_dict (`str` or `dict`):
+ The file path to the YAML configuration file or a dictionary containing the configuration.
+ reset (`bool`, *optional*, defaults to `False`):
+ Whether to reset the configuration to its default values before applying the loaded settings.
+ This is useful when you want to ensure that only the settings specified in the file or dictionary
+ are applied, without retaining any previous configurations (e.g., when using ParaModifier to modify
+ existing configurations).
+ Returns:
+ `dict`: A dictionary containing the loaded cache options.
+ """
+ # Deprecated function warning
+ logger.warning(
+ "`load_options` is deprecated and will be removed in future versions. "
+ "Please use `load_configs` instead."
+ )
if isinstance(path_or_dict, str):
- return load_cache_options_from_yaml(path_or_dict)
+ return load_cache_options_from_yaml(path_or_dict, reset)
elif isinstance(path_or_dict, dict):
- return load_cache_options_from_dict(path_or_dict)
+ return load_cache_options_from_dict(path_or_dict, reset)
else:
raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).")
+
+
+def load_cache_config(
+ path_or_dict: str | dict, **kwargs
+) -> Tuple[DBCacheConfig, Optional[CalibratorConfig]]:
+ r"""
+ New API that only loads the cache configuration from a YAML file or a dictionary. Assumes
+ that the yaml contains a 'cache_config' section, and returns only that section.
+ Raise ValueError if not found.
+ Args:
+ path_or_dict (`str` or `dict`):
+ The file path to the YAML configuration file or a dictionary containing the configuration.
+ reset (`bool`, *optional*, defaults to `False`):
+ Whether to reset the configuration to its default values before applying the loaded settings.
+ This is useful when you want to ensure that only the settings specified in the file or dictionary
+ are applied, without retaining any previous configurations (e.g., when using ParaModifier to modify
+ existing configurations).
+ Returns:
+ `Tuple[DBCacheConfig, Optional[CalibratorConfig]]`: The loaded cache configuration and optional calibrator configuration.
+ """
+ if isinstance(path_or_dict, str):
+ try:
+ with open(path_or_dict, "r") as f:
+ cache_kwargs: dict = yaml.safe_load(f)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found: {path_or_dict}")
+ except yaml.YAMLError as e:
+ raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
+ elif isinstance(path_or_dict, dict):
+ cache_kwargs: dict = copy.deepcopy(path_or_dict)
+ else:
+ raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).")
+
+ if "cache_config" not in cache_kwargs:
+ if "parallelism_config" in cache_kwargs:
+ # Allow missing cache_config for only parallelism_config checking
+ return None, None
+ # Try to load full cache options for backward compatibility if cache_config not found
+ # and the parallelism_config is also not provided. This is to support old config files
+ # and refresh_context api that only contains cache options (already used in vllm-omni).
+ cache_context_kwargs = load_cache_options_from_dict(
+ cache_kwargs, kwargs.get("reset", False)
+ )
+ cache_config: DBCacheConfig = cache_context_kwargs.get("cache_config", None)
+ calibrator_config = cache_context_kwargs.get("calibrator_config", None)
+ if cache_config is None:
+ raise ValueError("Failed to load 'cache_config'. Got None.")
+ return cache_config, calibrator_config
+
+ cache_config_kwargs = cache_kwargs["cache_config"]
+ # Parse steps_mask if exists
+ if "steps_computation_mask" in cache_config_kwargs:
+ steps_computation_mask = cache_config_kwargs["steps_computation_mask"]
+ if isinstance(steps_computation_mask, str):
+ assert (
+ "num_inference_steps" in cache_config_kwargs
+ ), "To parse steps_mask from str, 'num_inference_steps' must be provided in cache_config."
+ from .cache_interface import steps_mask
+
+ num_inference_steps = cache_config_kwargs["num_inference_steps"]
+ cache_config_kwargs["steps_computation_mask"] = steps_mask(
+ total_steps=num_inference_steps, mask_policy=steps_computation_mask
+ )
+ # Reuse load_cache_options_from_dict to parse cache_config
+ cache_context_kwargs = load_cache_options_from_dict(
+ cache_config_kwargs, kwargs.get("reset", False)
+ )
+ cache_config: DBCacheConfig = cache_context_kwargs.get("cache_config", None)
+ calibrator_config = cache_context_kwargs.get("calibrator_config", None)
+ if cache_config is None:
+ raise ValueError("Failed to load 'cache_config'. Got None.")
+ return cache_config, calibrator_config
+
+
+def load_parallelism_config(
+ path_or_dict: str | dict, **kwargs
+) -> Optional[ParallelismConfig] | bool:
+ r"""
+ Load parallelism configuration from a YAML file or a dictionary. Assumes that the yaml
+ contains a 'parallelism_config' section, and returns only that section. Raise ValueError
+ if not found.
+ Args:
+ path_or_dict (`str` or `dict`):
+ The file path to the YAML configuration file or a dictionary containing the configuration.
+ Returns:
+ `ParallelismConfig`: An instance of ParallelismConfig containing the loaded parallelism configuration.
+ """
+ if isinstance(path_or_dict, str):
+ try:
+ with open(path_or_dict, "r") as f:
+ parallel_kwargs: dict = yaml.safe_load(f)
+ except FileNotFoundError:
+ raise FileNotFoundError(f"Configuration file not found: {path_or_dict}")
+ except yaml.YAMLError as e:
+ raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
+ elif isinstance(path_or_dict, dict):
+ parallel_kwargs: dict = copy.deepcopy(path_or_dict)
+ else:
+ raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).")
+
+ if kwargs.get("check_only", False):
+ return "parallelism_config" in parallel_kwargs
+
+ # Allow missing parallelism_config
+ if "parallelism_config" not in parallel_kwargs:
+ return None
+
+ parallelism_config_kwargs = parallel_kwargs["parallelism_config"]
+ if "backend" in parallelism_config_kwargs:
+ backend_str = parallelism_config_kwargs["backend"]
+ parallelism_config_kwargs["backend"] = ParallelismBackend.from_str(backend_str)
+
+ def _maybe_auto_parallel_size(size: str | int | None) -> Optional[int]:
+ if size is None:
+ return None
+ if isinstance(size, int):
+ return size
+ if isinstance(size, str) and size.lower() == "auto":
+ import torch.distributed as dist
+
+ size = 1
+ if dist.is_initialized():
+ # Assume world size is the parallel size
+ size = dist.get_world_size()
+ if size == 1:
+ logger.warning(
+ "Auto parallel size selected as 1. Make sure to run with torch.distributed "
+ "to utilize multiple devices for parallelism."
+ )
+ else:
+ logger.info(f"Auto selected parallel size to {size}.")
+ return size
+ raise ValueError(f"Invalid parallel size value: {size}. Must be int or 'auto'.")
+
+ if kwargs.get("auto_parallel_size", True):
+ if "ulysses_size" in parallelism_config_kwargs:
+ parallelism_config_kwargs["ulysses_size"] = _maybe_auto_parallel_size(
+ parallelism_config_kwargs["ulysses_size"]
+ )
+ if "ring_size" in parallelism_config_kwargs:
+ parallelism_config_kwargs["ring_size"] = _maybe_auto_parallel_size(
+ parallelism_config_kwargs["ring_size"]
+ )
+ if "tp_size" in parallelism_config_kwargs:
+ parallelism_config_kwargs["tp_size"] = _maybe_auto_parallel_size(
+ parallelism_config_kwargs["tp_size"]
+ )
+
+ parallelism_config = ParallelismConfig(**parallelism_config_kwargs)
+ return parallelism_config
+
+
+def load_configs(
+ path_or_dict: str | dict,
+ return_dict: bool = True,
+ **kwargs,
+) -> Union[Tuple[DBCacheConfig, Optional[CalibratorConfig], ParallelismConfig], dict]:
+ r"""
+ Load both cache and parallelism configurations from a YAML file or a dictionary. For example,
+ the YAML file can be structured as follows:
+ ```yaml
+ cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
+ parallelism_config:
+ ulysses_size: 4
+ parallel_kwargs:
+ attention_backend: native
+ experimental_ulysses_anything: true
+ experimental_ulysses_float8: true
+ extra_parallel_modules: ["text_encoder", "vae"]
+ ```
+ Args:
+ path_or_dict (`str` or `dict`):
+ The file path to the YAML configuration file or a dictionary containing the configuration.
+ Returns:
+ `Tuple[DBCacheConfig, Optional[CalibratorConfig], ParallelismConfig]`: A tuple containing the loaded
+ cache configuration, optional calibrator configuration, and parallelism configuration. If `return_dict`
+ is set to `True`, returns a dictionary with keys "cache_config", "calibrator_config", and "parallelism_config".
+ """
+ cache_config, calibrator_config = load_cache_config(path_or_dict, **kwargs)
+ parallelism_config = load_parallelism_config(path_or_dict, **kwargs)
+ if isinstance(parallelism_config, bool):
+ parallelism_config = None
+ if return_dict:
+ return {
+ "cache_config": cache_config,
+ "calibrator_config": calibrator_config,
+ "parallelism_config": parallelism_config,
+ }
+ return cache_config, calibrator_config, parallelism_config
diff --git a/src/cache_dit/compile/__init__.py b/src/cache_dit/compile/__init__.py
index 730222984..5d8f9058d 100644
--- a/src/cache_dit/compile/__init__.py
+++ b/src/cache_dit/compile/__init__.py
@@ -1,4 +1 @@
-from cache_dit.compile.utils import set_compile_configs
-from cache_dit.compile.utils import enable_compile_compute_comm_overlap
-from cache_dit.compile.utils import disable_compile_compute_comm_overlap
-from cache_dit.compile.utils import is_compile_compute_comm_overlap_enabled
+from .utils import set_compile_configs
diff --git a/src/cache_dit/compile/utils.py b/src/cache_dit/compile/utils.py
index f26875217..a25174b7c 100644
--- a/src/cache_dit/compile/utils.py
+++ b/src/cache_dit/compile/utils.py
@@ -1,60 +1,23 @@
-import os
-
import torch
import torch.distributed as dist
-from cache_dit.logger import init_logger, logging_rank_0
+from ..envs import ENV
+from ..platforms import current_platform
+from cache_dit.logger import init_logger
logger = init_logger(__name__)
-def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
- mode = kwargs.get("epilogue_prologue_fusion", False)
- CACHE_DIT_EPILOGUE_PROLOGUE_FUSION = bool(
- int(os.environ.get("CACHE_DIT_EPILOGUE_PROLOGUE_FUSION", "0"))
- )
-
- if CACHE_DIT_EPILOGUE_PROLOGUE_FUSION:
- logging_rank_0(
- logger,
- "CACHE_DIT_EPILOGUE_PROLOGUE_FUSION is set to 1. \n"
- "Force enable epilogue and prologue fusion.",
- )
-
- return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
-
-
-_CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = (
- os.environ.get(
- "CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP",
- "1",
- )
- == "1"
-)
-
-
-def enable_compile_compute_comm_overlap():
- global _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
- _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = True
- logger.info("Enabled compile compute-communication overlap manually.")
-
-
-def disable_compile_compute_comm_overlap():
- global _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
- _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP = False
- logger.info("Disabled compile compute-communication overlap manually.")
-
-
-def is_compile_compute_comm_overlap_enabled() -> bool:
- global _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
- return _CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
-
-
def set_compile_configs(
descent_tuning: bool = False,
cuda_graphs: bool = False,
force_disable_compile_caches: bool = False,
+ fx_graph_cache: bool = True,
+ fx_graph_remote_cache: bool = False,
+ autotune_local_cache: bool = False,
use_fast_math: bool = False,
compute_comm_overlap: bool = True,
+ capture_scalar_outputs: bool = False,
+ capture_dynamic_output_shape_ops: bool = False,
**kwargs, # other kwargs
):
# Alway increase recompile_limit for dynamic shape compilation
@@ -62,31 +25,32 @@ def set_compile_configs(
torch._dynamo.config.accumulated_recompile_limit = 8192 # default is 256
# Handle compiler caches
# https://github.com/vllm-project/vllm/blob/23baa2180b0ebba5ae94073ba9b8e93f88b75486/vllm/compilation/compiler_interface.py#L270
- torch._inductor.config.fx_graph_cache = True
- torch._inductor.config.fx_graph_remote_cache = False
+ torch._inductor.config.fx_graph_cache = fx_graph_cache
+ torch._inductor.config.fx_graph_remote_cache = fx_graph_remote_cache
# https://github.com/pytorch/pytorch/issues/153791
- torch._inductor.config.autotune_local_cache = False
+ torch._inductor.config.autotune_local_cache = autotune_local_cache
if dist.is_initialized():
# Enable compute comm overlap
torch._inductor.config.reorder_for_compute_comm_overlap = (
- compute_comm_overlap and is_compile_compute_comm_overlap_enabled()
+ compute_comm_overlap and ENV.CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP
)
# L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
if torch._inductor.config.reorder_for_compute_comm_overlap:
torch._inductor.config.intra_node_bw = (
- 64 if "L20" in torch.cuda.get_device_name() else 300
+ 64 if "L20" in current_platform.get_device_name() else 300
)
+ # https://docs.pytorch.org/docs/stable/nested.html#data-dependent-operation-within-torch-compile
+ if hasattr(torch._dynamo.config, "capture_scalar_outputs"):
+ torch._dynamo.config.capture_scalar_outputs = capture_scalar_outputs
+ torch._dynamo.config.capture_dynamic_output_shape_ops = capture_dynamic_output_shape_ops
+
if not descent_tuning:
return
- FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
- os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0") == "1"
- )
- if FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
- logging_rank_0(
- logger,
+ if ENV.CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG:
+ logger.info(
"CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG is set to 1. \n"
"Force disable custom compile config.",
)
@@ -107,7 +71,7 @@ def set_compile_configs(
torch._inductor.config.epilogue_fusion = False
# Enable epilogue and prologue fusion
- if epilogue_prologue_fusion_enabled(**kwargs):
+ if ENV.CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or kwargs.get("epilogue_prologue_fusion", False):
torch._inductor.config.epilogue_fusion = True
torch._inductor.config.prologue_fusion = True
torch._inductor.config.epilogue_fusion_first = True
diff --git a/src/cache_dit/envs.py b/src/cache_dit/envs.py
new file mode 100644
index 000000000..44eb30160
--- /dev/null
+++ b/src/cache_dit/envs.py
@@ -0,0 +1,90 @@
+import os
+
+
+class ENV(object):
+ # ENVs for cache-dit
+
+ # Logging ENVs
+ CACHE_DIT_LOG_LEVEL: str = os.environ.get("CACHE_DIT_LOG_LEVEL", "info")
+ CACHE_DIT_LOG_DIR: str = os.environ.get("CACHE_DIT_LOG_DIR", None)
+
+ # Parallelism ENVs
+
+ # Enable custom attention backend dispatch for context parallelism
+ # in cache-dit by default. Users can set the environment variable
+ # to 0 to disable this behavior. Default to enabled for better
+ # compatibility and performance.
+ CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH: bool = bool(
+ int(os.getenv("CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH", "1"))
+ )
+
+ # Avoid re-registering custom attention backend dispatch in cache-dit.
+ # Inner use only. Users should not set this variable directly.
+ CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH: bool = bool(
+ int(os.getenv("CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH", "0"))
+ )
+
+ # Environment variable flags for Ulysses Attention variants in cache-dit.
+ # Enable Ulysses Anything Attention by setting the environment variable to 1.
+ # Otherwise, users can set it by 'experimental_ulysses_anything' argument in
+ # ContextParallelism.
+ CACHE_DIT_ENABELD_ULYSSES_ANYTHING: bool = bool(
+ int(os.environ.get("CACHE_DIT_ENABELD_ULYSSES_ANYTHING", "0"))
+ )
+
+ # Enable Ulysses Anything Attention Float8 by setting the environment variable to 1.
+ # Otherwise, users can set it by 'experimental_ulysses_anything=True' and
+ # 'experimental_ulysses_float8=True' arguments in ContextParallelism.
+ CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8: bool = bool(
+ int(os.environ.get("CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8", "0"))
+ )
+
+ # Enable Ulysses Attention Float8 by setting the environment variable to 1.
+ # Otherwise, users can set it by 'experimental_ulysses_float8' argument in
+ # ContextParallelism.
+ CACHE_DIT_ENABELD_ULYSSES_FLOAT8: bool = bool(
+ int(os.environ.get("CACHE_DIT_ENABELD_ULYSSES_FLOAT8", "0"))
+ )
+
+ # Enable unpadded communication for uneven attention heads without padding
+ # by setting the environment variable to 1.
+ CACHE_DIT_UNEVEN_HEADS_COMM_NO_PAD: bool = bool(
+ int(os.environ.get("CACHE_DIT_UNEVEN_HEADS_COMM_NO_PAD", "0"))
+ )
+
+ # Models ENVs
+
+ # Users should never use this variable directly, it is only for developers
+ # to control whether to enable dummy blocks for FLUX, default to enabled.
+ CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS: bool = bool(
+ int(os.environ.get("CACHE_DIT_FLUX_ENABLE_DUMMY_BLOCKS", "1"))
+ )
+
+ # Torch compile ENVs
+
+ # Enable epilogue and prologue fusion in cache-dit compile optimizations
+ CACHE_DIT_EPILOGUE_PROLOGUE_FUSION: bool = bool(
+ int(os.environ.get("CACHE_DIT_EPILOGUE_PROLOGUE_FUSION", "0"))
+ )
+
+ # Enable compile compute-communication (all reduce) overlap in cache-dit by
+ # default. Users can set the environment variable to 0 to disable this behavior.
+ # Default to enabled for better performance.
+ CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP: bool = bool(
+ int(os.environ.get("CACHE_DIT_ENABLE_COMPILE_COMPUTE_COMM_OVERLAP", "1"))
+ )
+
+ # Force disable custom compile config in cache-dit by setting the environment
+ # variable to 1. Otherwise, cache-dit will set custom compile configs for
+ # better performance during torch.compile.
+ CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG: bool = bool(
+ int(os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0"))
+ )
+
+ # Patch Functors ENVs
+
+ # Force disable the checking of whether the model is from diffusers in patch functors.
+ # Users can set the environment variable to 1 to disable this behavior.
+ CACHE_DIT_PATCH_FUNCTOR_DISABLE_DIFFUSERS_CHECK: bool = bool(
+ int(os.environ.get("CACHE_DIT_PATCH_FUNCTOR_DISABLE_DIFFUSERS_CHECK", "0"))
+ )
diff --git a/src/cache_dit/kernels/__init__.py b/src/cache_dit/kernels/__init__.py
index e69de29bb..a74c825e9 100644
--- a/src/cache_dit/kernels/__init__.py
+++ b/src/cache_dit/kernels/__init__.py
@@ -0,0 +1,6 @@
+from .triton import (
+ per_token_quant_fp8,
+ per_token_dequant_fp8,
+ qkv_permute_quant_fp8,
+ qkv_dequant_permute_fp8,
+)
diff --git a/src/cache_dit/quantize/backends/bitsandbytes/__init__.py b/src/cache_dit/kernels/cuda/__init__.py
similarity index 100%
rename from src/cache_dit/quantize/backends/bitsandbytes/__init__.py
rename to src/cache_dit/kernels/cuda/__init__.py
diff --git a/src/cache_dit/kernels/triton/__init__.py b/src/cache_dit/kernels/triton/__init__.py
new file mode 100644
index 000000000..71079be92
--- /dev/null
+++ b/src/cache_dit/kernels/triton/__init__.py
@@ -0,0 +1,4 @@
+from .per_token_quant_8bit import per_token_quant_fp8
+from .per_token_quant_8bit import per_token_dequant_fp8
+from .per_token_quant_8bit import qkv_permute_quant_fp8
+from .per_token_quant_8bit import qkv_dequant_permute_fp8
diff --git a/src/cache_dit/kernels/triton/per_token_quant_8bit.py b/src/cache_dit/kernels/triton/per_token_quant_8bit.py
new file mode 100644
index 000000000..dc7e115c5
--- /dev/null
+++ b/src/cache_dit/kernels/triton/per_token_quant_8bit.py
@@ -0,0 +1,260 @@
+import torch
+import triton
+import triton.language as tl
+
+__all__ = [
+ "per_token_quant_fp8",
+ "per_token_dequant_fp8",
+ "qkv_permute_quant_fp8",
+ "qkv_dequant_permute_fp8",
+]
+
+
@triton.jit
def _per_token_quant_8bit(
    y_ptr: tl.tensor,  # out: per-row [H 8-bit values | scale (one x-dtype elem)]
    x_ptr: tl.tensor,  # in: (M, H) rows to quantize
    H: int,
    eps: float,
    bit8_min: float,
    bit8_max: float,
    BLOCK: tl.constexpr,
):
    # One program per row: quantize row `s_id` to 8 bits with a single per-row
    # scale. The scale is stored in x's element type (e.g. bf16 = 2 bytes),
    # packed into the 2 trailing 8-bit slots of the (H + 2)-wide output row.
    s_id = tl.program_id(0).to(tl.int64)
    y_ptr += s_id * (H + 2)
    y_s_ptr = y_ptr + H  # scale slot sits right after the H payload elements
    x_ptr += s_id * H

    # Pass 1: per-row absmax, clamped below by eps to avoid a zero scale.
    _absmax = tl.full([BLOCK], value=eps, dtype=tl.float32)
    for h in range(0, H, BLOCK):
        cols = h + tl.arange(0, BLOCK).to(tl.int64)
        mask = cols < H
        x = tl.load(x_ptr + cols, mask=mask, other=0.0, eviction_policy="evict_last").to(tl.float32)
        _absmax = tl.maximum(tl.abs(x), _absmax)

    _absmax = tl.max(_absmax)
    x_s = _absmax / bit8_max
    x_s_inv = 1.0 / x_s
    x_s = x_s.to(x_ptr.dtype.element_ty)

    # Reinterpret the spare output bytes as one element of x's dtype and
    # store the scale there.
    y_s_ptr = y_s_ptr.to(tl.pointer_type(x_ptr.dtype.element_ty, 1))
    tl.store(y_s_ptr, x_s)

    # Pass 2: rescale, clamp to the 8-bit range, store the quantized row.
    for h in range(0, H, BLOCK):
        cols = h + tl.arange(0, BLOCK).to(tl.int64)
        mask = cols < H
        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        x_q = tl.clamp(x * x_s_inv, bit8_min, bit8_max).to(y_ptr.dtype.element_ty)
        tl.store(y_ptr + cols, x_q, mask=mask)
+
+
@triton.jit
def _per_token_dequant_8bit(
    y_ptr: tl.tensor,  # out: (M, H) dequantized rows, in y's dtype
    x_ptr: tl.tensor,  # in: per-row [H 8-bit values | packed scale]
    H: int,
    BLOCK: tl.constexpr,
):
    # One program per row: inverse of _per_token_quant_8bit. Reads the per-row
    # scale from the 2 trailing slots of the (H + 2)-wide input row, then
    # rescales the payload back to y's dtype.
    s_id = tl.program_id(0).to(tl.int64)
    y_ptr += s_id * H
    x_ptr += s_id * (H + 2)

    # The scale was stored in the output dtype (e.g. bf16) packed into the
    # spare 8-bit slots; reinterpret the pointer and load it.
    x_s_ptr = x_ptr + H
    x_s_ptr = x_s_ptr.to(tl.pointer_type(y_ptr.dtype.element_ty, 1))
    x_s = tl.load(x_s_ptr).to(tl.float32)

    for h in range(0, H, BLOCK):
        cols = h + tl.arange(0, BLOCK).to(tl.int64)
        mask = cols < H
        x = tl.load(x_ptr + cols, mask=mask, other=0.0).to(tl.float32)
        x = x * x_s
        tl.store(y_ptr + cols, x, mask=mask)
+
+
def per_token_quant_fp8(x: torch.Tensor) -> torch.Tensor:
    """Quantize a bf16 tensor to fp8 (e4m3) with one scale per token (row).

    The last dimension grows from H to H + 2: the 2 extra fp8 bytes of each
    row hold the bf16 per-row scale, so the result is self-describing and is
    inverted by ``per_token_dequant_fp8``.

    Args:
        x: bfloat16 tensor of shape (..., H); CUDA device expected.

    Returns:
        float8_e4m3fn tensor of shape (..., H + 2).
    """
    assert x.dtype == torch.bfloat16, f"expected bfloat16 but got {x.dtype}"
    dtype = torch.float8_e4m3fn
    finfo = torch.finfo(dtype)
    *shape, H = x.shape
    x = x.reshape(-1, H).contiguous()
    M, N = x.shape
    y = torch.empty((M, N + 2), dtype=dtype, device=x.device)

    # Block size: capped at 8192 lanes and a 64KB working set, rounded to the
    # next power of two covering N, and at least 128.
    BLOCK = max(min(8192, 65536 // x.element_size(), triton.next_power_of_2(N)), 128)
    num_warps = min(max(BLOCK // 256, 1), 8)

    with torch.cuda.device(x.device):
        _per_token_quant_8bit[(M,)](
            y,
            x,
            N,
            eps=1e-4,
            bit8_min=finfo.min,
            bit8_max=finfo.max,
            BLOCK=BLOCK,
            num_warps=num_warps,
        )
    return y.reshape(*shape, H + 2)
+
+
def per_token_dequant_fp8(x: torch.Tensor) -> torch.Tensor:
    """Inverse of ``per_token_quant_fp8``: fp8 rows (payload + embedded scale)
    back to bfloat16. The last dimension shrinks from H to H - 2.

    Args:
        x: float8_e4m3fn tensor of shape (..., H) as produced by
           ``per_token_quant_fp8`` (i.e. H = payload + 2 scale bytes).

    Returns:
        bfloat16 tensor of shape (..., H - 2).
    """
    assert x.dtype == torch.float8_e4m3fn, f"expected float8_e4m3fn but got {x.dtype}"
    *shape, H = x.shape
    x = x.reshape(-1, H).contiguous()
    M, N = x.shape
    N -= 2  # payload width without the 2 packed-scale bytes
    y = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)

    BLOCK = max(min(8192, 65536 // x.element_size(), triton.next_power_of_2(N)), 128)
    num_warps = min(max(BLOCK // 256, 1), 8)

    with torch.cuda.device(x.device):
        _per_token_dequant_8bit[(M,)](
            y,
            x,
            N,
            BLOCK=BLOCK,
            num_warps=num_warps,
        )
    return y.reshape(*shape, H - 2)
+
+
@triton.jit
def _qkv_permute_quant(
    quant_x_ptr: tl.tensor,  # out: (P, S, B, N, D + 4) fp8; scale in last 4 bytes
    x_ptr: tl.tensor,        # in: (B, S, P, N, D)
    qx_stride_b: int,
    qx_stride_n: int,
    x_stride_b: int,
    x_stride_s: int,
    x_stride_p: int,
    B: int,
    S: int,
    N: int,
    D: int,
    EPS: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # One program per (p, s, b) tile: permute (B, S, P, ...) -> (P, S, B, ...)
    # and quantize each head row (n, :) to fp8 e4m3 with a per-head float32
    # scale packed into the 4 trailing bytes of the output row.
    psb_id = tl.program_id(0).to(tl.int64)
    b_id = psb_id % B
    s_id = (psb_id // B) % S
    p_id = psb_id // (S * B)

    x_ptr += b_id * x_stride_b + s_id * x_stride_s + p_id * x_stride_p
    quant_x_ptr += psb_id * qx_stride_b
    # Reinterpret the fp8 buffer as float32 to address the packed scales.
    scale_ptr = quant_x_ptr.to(tl.pointer_type(tl.float32, 1))

    n_offset = tl.arange(0, BLOCK_SIZE_N)[None, :]
    n_mask = n_offset < N
    d_offset = tl.arange(0, BLOCK_SIZE_D)[:, None]
    d_mask = d_offset < D
    mask = n_mask & d_mask

    quant_x_blk = quant_x_ptr + n_offset * qx_stride_n + d_offset
    # Scale slot per head: row stride in float32 units is (D + 4) / 4 and the
    # scale lives at float32 index D // 4.
    # NOTE(review): assumes D is a multiple of 4 -- confirm with callers.
    scale_blk = scale_ptr + n_offset * (D // 4 + 1) + D // 4
    # NOTE(review): assumes the source (n, d) plane is contiguous (row stride
    # D, element stride 1) -- confirm against the wrapper's input layout.
    x_blk = x_ptr + n_offset * D + d_offset

    x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
    # Per-head absmax mapped onto the fp8 e4m3 max magnitude (448), with EPS
    # as a floor to avoid division by zero.
    scale = tl.max(tl.abs(x), axis=0, keep_dims=True) / 448.0
    scale = tl.maximum(scale, EPS)
    quant_x = x / scale
    quant_x = tl.clamp(quant_x, -448.0, 448.0).to(tl.float8e4nv)

    tl.store(quant_x_blk, quant_x, mask=mask)
    tl.store(scale_blk, scale, mask=n_mask)
+
+
@triton.jit
def _qkv_dequant_permute(
    x_ptr: tl.tensor,        # out: (B, S, N, D) dequantized
    quant_x_ptr: tl.tensor,  # in: (S, B, N, D + 4) fp8 with packed fp32 scales
    x_stride_s: int,
    qx_stride_s: int,
    qx_stride_b: int,
    qx_stride_n: int,
    B: int,
    S: int,          # unused in the body; kept for launch-signature symmetry
    N: int,
    D: int,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # One program per (b, s) tile: read the fp8 payload and its packed per-head
    # float32 scale, and write the dequantized (N, D) plane to x.
    bs_id = tl.program_id(0).to(tl.int64)
    b_id = bs_id % B
    s_id = bs_id // B

    quant_x_ptr += s_id * qx_stride_s + b_id * qx_stride_b
    scale_ptr = quant_x_ptr.to(tl.pointer_type(tl.float32, 1))
    # NOTE(review): the output offset is `bs_id * x_stride_s` (a single
    # stride), i.e. planes are written in (s, b)-major order, while the
    # wrapper allocates x as (B, S, N, D). This only lines up when B == 1 or
    # S == 1; for B > 1 and S > 1 it looks scrambled -- verify intended usage.
    x_ptr += bs_id * x_stride_s

    n_offset = tl.arange(0, BLOCK_SIZE_N)[None, :]
    n_mask = n_offset < N
    d_offset = tl.arange(0, BLOCK_SIZE_D)[:, None]
    d_mask = d_offset < D
    mask = n_mask & d_mask

    x_blk = x_ptr + n_offset * D + d_offset
    quant_x_blk = quant_x_ptr + n_offset * qx_stride_n + d_offset
    # Scale addressing mirrors _qkv_permute_quant: fp32 row stride D // 4 + 1,
    # scale at fp32 index D // 4 (assumes D % 4 == 0).
    scale_blk = scale_ptr + n_offset * (D // 4 + 1) + D // 4

    qx = tl.load(quant_x_blk, mask=mask, other=0.0).to(tl.float32)
    scale = tl.load(scale_blk, mask=n_mask, other=0.0).to(tl.float32)

    tl.store(x_blk, qx * scale, mask=mask)
+
+
def qkv_permute_quant_fp8(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Permute a (B, S, P, N, D) tensor to (P, S, B, N, D + 4) and quantize to
    fp8 e4m3, with one float32 per-head scale packed into the 4 trailing
    bytes of each (n, :) row.

    Args:
        x: input of shape (B, S, P, N, D); CUDA device expected.
        eps: lower bound on the per-head scale (avoids divide-by-zero).

    Returns:
        float8_e4m3fn tensor of shape (P, S, B, N, D + 4).
    """
    B, S, P, N, D = x.shape

    quant_x = torch.empty((P, S, B, N, D + 4), dtype=torch.float8_e4m3fn, device=x.device)

    # One program per (p, s, b) tile.
    grid = (P * S * B,)

    with torch.cuda.device(x.device):
        _qkv_permute_quant[grid](
            quant_x,
            x,
            quant_x.stride(2),  # per-(p, s, b) tile stride in the output
            quant_x.stride(3),  # per-head row stride in the output
            x.stride(0),
            x.stride(1),
            x.stride(2),
            B,
            S,
            N,
            D,
            eps,
            triton.next_power_of_2(N),
            triton.next_power_of_2(D),
        )

    return quant_x
+
+
def qkv_dequant_permute_fp8(
    quant_x: torch.Tensor, dtype: torch.dtype = torch.bfloat16
) -> torch.Tensor:
    """Dequantize a (S, B, N, D + 4) fp8 tensor back to (B, S, N, D).

    NOTE(review): ``qkv_permute_quant_fp8`` returns a 5-D (P, S, B, N, D + 4)
    tensor; this function expects 4-D input, so callers presumably pass one
    ``quant_x[p]`` slice -- confirm against call sites.

    Args:
        quant_x: float8_e4m3fn tensor of shape (S, B, N, D + 4).
        dtype: output dtype (default bfloat16).

    Returns:
        Tensor of shape (B, S, N, D) in ``dtype``.
    """
    S, B, N, _D = quant_x.shape
    D = _D - 4  # strip the 4 packed-scale bytes per head row
    x = torch.empty((B, S, N, D), dtype=dtype, device=quant_x.device)

    # One program per (b, s) tile.
    grid = (B * S,)

    with torch.cuda.device(x.device):
        _qkv_dequant_permute[grid](
            x,
            quant_x,
            x.stride(1),
            quant_x.stride(0),
            quant_x.stride(1),
            quant_x.stride(2),
            B,
            S,
            N,
            D,
            triton.next_power_of_2(N),
            triton.next_power_of_2(D),
        )

    return x
diff --git a/src/cache_dit/kernels/triton_taylorseer.py b/src/cache_dit/kernels/triton/taylorseer.py
similarity index 100%
rename from src/cache_dit/kernels/triton_taylorseer.py
rename to src/cache_dit/kernels/triton/taylorseer.py
diff --git a/src/cache_dit/logger.py b/src/cache_dit/logger.py
index 4ab791841..2f88fb8fc 100644
--- a/src/cache_dit/logger.py
+++ b/src/cache_dit/logger.py
@@ -2,13 +2,14 @@
import os
import sys
import torch.distributed as dist
+from .envs import ENV
_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"
-_LOG_LEVEL = os.environ.get("CACHE_DIT_LOG_LEVEL", "info")
+_LOG_LEVEL = ENV.CACHE_DIT_LOG_LEVEL
_LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
-_LOG_DIR = os.environ.get("CACHE_DIT_LOG_DIR", None)
+_LOG_DIR = ENV.CACHE_DIT_LOG_DIR
class NewLineFormatter(logging.Formatter):
diff --git a/src/cache_dit/metrics/__init__.py b/src/cache_dit/metrics/__init__.py
index a6fd4bd38..609c2d160 100644
--- a/src/cache_dit/metrics/__init__.py
+++ b/src/cache_dit/metrics/__init__.py
@@ -9,18 +9,18 @@
"Install with:\npip install cache-dit[metrics]"
)
-from cache_dit.metrics.metrics import compute_psnr
-from cache_dit.metrics.metrics import compute_ssim
-from cache_dit.metrics.metrics import compute_mse
-from cache_dit.metrics.metrics import compute_video_psnr
-from cache_dit.metrics.metrics import compute_video_ssim
-from cache_dit.metrics.metrics import compute_video_mse
-from cache_dit.metrics.fid import FrechetInceptionDistance
-from cache_dit.metrics.fid import compute_fid
-from cache_dit.metrics.fid import compute_video_fid
-from cache_dit.metrics.config import set_metrics_verbose
-from cache_dit.metrics.config import get_metrics_verbose
-from cache_dit.metrics.metrics import entrypoint
+from .metrics import compute_psnr
+from .metrics import compute_ssim
+from .metrics import compute_mse
+from .metrics import compute_video_psnr
+from .metrics import compute_video_ssim
+from .metrics import compute_video_mse
+from .fid import FrechetInceptionDistance
+from .fid import compute_fid
+from .fid import compute_video_fid
+from .config import set_metrics_verbose
+from .config import get_metrics_verbose
+from .metrics import entrypoint
def main():
diff --git a/src/cache_dit/metrics/clip_score.py b/src/cache_dit/metrics/clip_score.py
index 7d43efc21..1d0e83896 100644
--- a/src/cache_dit/metrics/clip_score.py
+++ b/src/cache_dit/metrics/clip_score.py
@@ -8,8 +8,9 @@
from transformers import CLIPProcessor, CLIPModel
from typing import Tuple, Union
-from cache_dit.metrics.config import _IMAGE_EXTENSIONS
-from cache_dit.metrics.config import get_metrics_verbose
+from .config import _IMAGE_EXTENSIONS
+from .config import get_metrics_verbose
+from ..platforms import current_platform
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -21,7 +22,9 @@
class CLIPScore:
def __init__(
self,
- device="cuda" if torch.cuda.is_available() else "cpu",
+ device=(
+ current_platform.device_type if current_platform.is_accelerator_available() else "cpu"
+ ),
clip_model_path: str = None,
):
self.device = device
diff --git a/src/cache_dit/metrics/fid.py b/src/cache_dit/metrics/fid.py
index c7aaffbdc..28863a335 100644
--- a/src/cache_dit/metrics/fid.py
+++ b/src/cache_dit/metrics/fid.py
@@ -12,11 +12,12 @@
from torch.nn.functional import adaptive_avg_pool2d
from typing import Tuple, Union
-from cache_dit.metrics.inception import InceptionV3
-from cache_dit.metrics.config import _IMAGE_EXTENSIONS
-from cache_dit.metrics.config import _VIDEO_EXTENSIONS
-from cache_dit.metrics.config import get_metrics_verbose
-from cache_dit.utils import disable_print
+from .inception import InceptionV3
+from .config import _IMAGE_EXTENSIONS
+from .config import _VIDEO_EXTENSIONS
+from .config import get_metrics_verbose
+from ..platforms import current_platform
+from ..utils import disable_print
from cache_dit.logger import init_logger
warnings.filterwarnings("ignore")
@@ -225,7 +226,9 @@ def calculate_activation_statistics(
class FrechetInceptionDistance:
def __init__(
self,
- device="cuda" if torch.cuda.is_available() else "cpu",
+ device=(
+ current_platform.device_type if current_platform.is_accelerator_available() else "cpu"
+ ),
dims: int = 2048,
num_workers: int = 1,
batch_size: int = 1,
diff --git a/src/cache_dit/metrics/image_reward.py b/src/cache_dit/metrics/image_reward.py
index e52b5731b..e2c71f170 100644
--- a/src/cache_dit/metrics/image_reward.py
+++ b/src/cache_dit/metrics/image_reward.py
@@ -12,9 +12,10 @@
import torchvision.transforms.v2 as T
from typing import Tuple, Union
-from cache_dit.metrics.config import _IMAGE_EXTENSIONS
-from cache_dit.metrics.config import get_metrics_verbose
-from cache_dit.utils import disable_print
+from .config import _IMAGE_EXTENSIONS
+from .config import get_metrics_verbose
+from ..platforms import current_platform
+from ..utils import disable_print
from cache_dit.logger import init_logger
warnings.filterwarnings("ignore")
@@ -29,7 +30,9 @@
class ImageRewardScore:
def __init__(
self,
- device="cuda" if torch.cuda.is_available() else "cpu",
+ device=(
+ current_platform.device_type if current_platform.is_accelerator_available() else "cpu"
+ ),
imagereward_model_path: str = None,
):
self.device = device
diff --git a/src/cache_dit/metrics/lpips.py b/src/cache_dit/metrics/lpips.py
index 614e5497a..972814d1f 100644
--- a/src/cache_dit/metrics/lpips.py
+++ b/src/cache_dit/metrics/lpips.py
@@ -2,7 +2,7 @@
import lpips
import torch
-from cache_dit.utils import disable_print
+from ..utils import disable_print
warnings.filterwarnings("ignore")
diff --git a/src/cache_dit/metrics/metrics.py b/src/cache_dit/metrics/metrics.py
index 7bf83f9eb..3ebbd0b97 100644
--- a/src/cache_dit/metrics/metrics.py
+++ b/src/cache_dit/metrics/metrics.py
@@ -10,16 +10,16 @@
from skimage.metrics import mean_squared_error
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
-from cache_dit.metrics.config import set_metrics_verbose
-from cache_dit.metrics.config import get_metrics_verbose
-from cache_dit.metrics.config import _IMAGE_EXTENSIONS
-from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+from .config import set_metrics_verbose
+from .config import get_metrics_verbose
+from .config import _IMAGE_EXTENSIONS
+from .config import _VIDEO_EXTENSIONS
+from .fid import compute_fid
+from .fid import compute_video_fid
+from .lpips import compute_lpips_img
+from .clip_score import compute_clip_score
+from .image_reward import compute_reward_score
from cache_dit.logger import init_logger
-from cache_dit.metrics.fid import compute_fid
-from cache_dit.metrics.fid import compute_video_fid
-from cache_dit.metrics.lpips import compute_lpips_img
-from cache_dit.metrics.clip_score import compute_clip_score
-from cache_dit.metrics.image_reward import compute_reward_score
logger = init_logger(__name__)
@@ -1074,6 +1074,10 @@ def _format_table(format_strs: List[str], metric: str):
if metric.upper() in key or metric.lower() in key:
selected_items[key] = METRICS_META[key]
+ # skip unselected metric
+ if len(selected_items) == 0:
+ continue
+
reverse = (
True
if metric.lower()
diff --git a/src/cache_dit/parallelism/__init__.py b/src/cache_dit/parallelism/__init__.py
index 16516d10b..a575ae3f2 100644
--- a/src/cache_dit/parallelism/__init__.py
+++ b/src/cache_dit/parallelism/__init__.py
@@ -1,6 +1,4 @@
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
-from cache_dit.parallelism.backends.native_diffusers import enable_ulysses_anything
-from cache_dit.parallelism.backends.native_diffusers import disable_ulysses_anything
-from cache_dit.parallelism.parallel_interface import enable_parallelism
-from cache_dit.parallelism.parallel_interface import maybe_pad_prompt
+from .backend import ParallelismBackend
+from .config import ParallelismConfig
+from .dispatch import enable_parallelism
+from .dispatch import remove_parallelism_stats
diff --git a/src/cache_dit/parallelism/attention/__init__.py b/src/cache_dit/parallelism/attention/__init__.py
new file mode 100644
index 000000000..b3321a430
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/__init__.py
@@ -0,0 +1,33 @@
+from cache_dit.envs import ENV
+from ._distributed_primitives import (
+ _unified_all_to_all_o_async_fn,
+ _unified_all_to_all_qkv_async_fn,
+ _prepare_ulysses_comm_metadata,
+)
+from ._experimental_utils import (
+ _is_diffusers_parallelism_available,
+ _maybe_patch_find_submodule,
+)
+from ._templated_ulysses import (
+ enable_ulysses_anything,
+ enable_ulysses_float8,
+)
+
+
def _maybe_register_custom_attn_backends():
    """Maybe re-register native attention backends to enable context parallelism.

    Importing ``._attention_dispatch`` performs the (re-)registration as a
    module-level side effect; the imported names themselves are not used here.
    Gated by ``ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH`` and made idempotent
    via the ``..._ALREADY_DISPATCH`` flag.
    """
    if not ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH:
        return

    if ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH:
        return

    # Set the guard before importing, so a second call during the import
    # cannot trigger the registration again.
    ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH = True

    # Imported for side effects only: module import re-registers the backends.
    from ._attention_dispatch import (
        _native_attention,
        _sdpa_cudnn_attention,
        _sage_attention,
        _flash_attention_3,
        _native_npu_attention,
    )
diff --git a/src/cache_dit/parallelism/attention/_attention_dispatch.py b/src/cache_dit/parallelism/attention/_attention_dispatch.py
new file mode 100644
index 000000000..a51733881
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/_attention_dispatch.py
@@ -0,0 +1,747 @@
+import torch
+import math
+from typing import Optional
+from cache_dit.platforms import current_platform
+
# Diffusers' attention-dispatch internals are required for backend
# re-registration; they only exist in diffusers >= 0.36.dev0.
try:
    from diffusers.models.attention_dispatch import (
        _AttentionBackendRegistry,
        AttentionBackendName,
        _check_device,
        _check_shape,
        _check_qkv_dtype_bf16_or_fp16,
        _check_device_cuda,
    )

    # For sage attention backend re-registration
    from diffusers.models.attention_dispatch import (
        sageattn,
        _sage_attention_forward_op,
        _sage_attention_backward_op,
    )

    # For flash attention 3 backend re-registration (optional dependency;
    # absence is tolerated and recorded in `_flash_attn_3_available`).
    try:
        from flash_attn_interface import flash_attn_func as flash_attn_3_func

        _flash_attn_3_available = True
    except ImportError:
        flash_attn_3_func = None
        _flash_attn_3_available = False

    # For native npu attention backend re-registration
    if current_platform.device_type == "npu":
        from torch_npu import npu_fusion_attention
    else:
        npu_fusion_attention = None

    from diffusers.models._modeling_parallel import ParallelConfig
except ImportError as e:
    # Chain the original error so the missing symbol is visible to the user.
    raise ImportError(
        "Context parallelism requires 'diffusers>=0.36.dev0'. "
        "Please install the latest version of diffusers from source: \n"
        "pip3 install git+https://github.com/huggingface/diffusers.git"
    ) from e
+from cache_dit.logger import init_logger
+from cache_dit.envs import ENV
+
+from ._templated_ring import UnifiedTemplatedRingAttention
+from ._templated_ulysses import UnifiedTemplatedUlyssesAttention
+
+logger = init_logger(__name__)
+MAX_TOKEN = 2147483647
+
+__all__ = [
+ "_native_attention",
+ "_sdpa_cudnn_attention",
+ "_sage_attention",
+ "_flash_attention_3",
+ "_native_npu_attention",
+]
+
+
def _registry_pop_attn_backend(attn_backend: AttentionBackendName):
    # Remove every trace of `attn_backend` from diffusers' attention backend
    # registry so it can be re-registered with relaxed constraints.
    _AttentionBackendRegistry._backends.pop(attn_backend, None)
    _AttentionBackendRegistry._constraints.pop(attn_backend, None)
    _AttentionBackendRegistry._supported_arg_names.pop(attn_backend, None)
    # `_supports_context_parallel` changed representation across diffusers
    # versions: a dict keyed by backend vs. a collection of backend values.
    if isinstance(_AttentionBackendRegistry._supports_context_parallel, dict):
        _AttentionBackendRegistry._supports_context_parallel.pop(attn_backend, None)
    # NOTE(review): membership is tested with the enum member but removal uses
    # `.value`; this is consistent only if AttentionBackendName is str-backed
    # (as `_set_new_attn_backend`'s use of str.__new__ suggests) -- confirm
    # against the installed diffusers version.
    elif attn_backend in _AttentionBackendRegistry._supports_context_parallel:
        _AttentionBackendRegistry._supports_context_parallel.remove(attn_backend.value)
+
+
def _set_new_attn_backend(member: str, value: str):
    # Dynamically extend the AttentionBackendName enum with a new member
    # (Enum classes are otherwise closed to extension), e.g.:
    #   _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
    # Uses str.__new__, i.e. relies on AttentionBackendName being a
    # str-backed Enum, and patches the enum's internal lookup tables so
    # name/value resolution both find the new member.
    new_member = str.__new__(AttentionBackendName, value)
    new_member._name_ = member
    new_member._value_ = value
    setattr(AttentionBackendName, member, new_member)
    AttentionBackendName._member_map_[member] = new_member
    AttentionBackendName._member_names_.append(member)
    AttentionBackendName._value2member_map_[value] = new_member
+
+
+# Enable custom native attention backend with context parallelism
+# by default. Users can set the environment variable to 0 to disable
+# this behavior. Default to enabled for better compatibility.
+if ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH:
    # Forward ops (matched by __name__) that accept an `attn_mask` argument
    # under templated context-parallel attention.
    _ATTENTION_OPS_ALLOW_ATTN_MASK = [
        "_native_attention_forward_op",
        "_sdpa_cudnn_attention_forward_op",
        "_npu_attention_forward_op",
    ]
+
    # Re-define templated context parallel attention to support attn mask
    def _unified_templated_context_parallel_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        *,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Route templated context-parallel attention to ring or Ulysses.

        Unlike the stock diffusers implementation, `attn_mask` is accepted --
        but only for forward ops listed in _ATTENTION_OPS_ALLOW_ATTN_MASK.
        Raises ValueError for causal or GQA attention (not yet supported).

        NOTE(review): `_parallel_config` is dereferenced unconditionally, so
        despite the `= None` default a real config must be supplied.
        """
        if attn_mask is not None:
            # NOTE(DefTruth): Check if forward_op is native attention forward op
            forward_op_name = forward_op.__name__
            if forward_op_name not in _ATTENTION_OPS_ALLOW_ATTN_MASK:
                raise ValueError(
                    "Templated context parallel attention with attn_mask "
                    "is only supported for native attention backend, "
                    f"but got forward_op: {forward_op_name}."
                )
        if is_causal:
            raise ValueError("Causal attention is not yet supported for templated attention.")
        if enable_gqa:
            raise ValueError("GQA is not yet supported for templated attention.")

        # TODO: add support for unified attention with ring/ulysses degree both being > 1
        if _parallel_config.context_parallel_config.ring_degree > 1:
            return UnifiedTemplatedRingAttention.apply(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op,
                backward_op,
                _parallel_config,
            )
        elif _parallel_config.context_parallel_config.ulysses_degree > 1:
            return UnifiedTemplatedUlyssesAttention.apply(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op,
                backward_op,
                _parallel_config,
            )
        else:
            raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
+
    # NOTE: Remove NATIVE attention backend constraints and re-register it.
    # Here is a temporary workaround to enable context parallelism with
    # native attention backend for attn mask support until diffusers
    # officially support it.

    def _native_attention_forward_op(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _save_ctx: bool = True,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Native SDPA forward op used by templated ring/Ulysses attention.

        When `return_lse` is set, falls back to the aten flash-attention op
        (which exposes the log-sum-exp); otherwise runs plain SDPA, permuting
        dims 1 and 2 to match torch's expected layout and permuting back.
        """
        # used for backward pass
        if _save_ctx:
            ctx.save_for_backward(query, key, value)
            ctx.attn_mask = attn_mask
            ctx.dropout_p = dropout_p
            ctx.is_causal = is_causal
            ctx.scale = scale
            ctx.enable_gqa = enable_gqa

        if return_lse:
            # Use native flash attention to get lse if return_lse is True
            if attn_mask is not None:
                raise ValueError(
                    "`attn_mask` is not yet supported for native flash attention with lse."
                )
            # Take only (out, lse) from the op's multi-value return.
            out, lse = torch.ops.aten._scaled_dot_product_flash_attention(
                query.transpose(1, 2),
                key.transpose(1, 2),
                value.transpose(1, 2),
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
            )[:2]
            out = out.transpose(1, 2)
            lse = lse.transpose(1, 2)
            return out, lse

        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
        out = out.permute(0, 2, 1, 3)

        return out
+
    def _native_attention_backward_op(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Recompute-based backward for the native SDPA forward op.

        Re-runs SDPA on the q/k/v saved in `ctx` (and the mask/flags stashed
        there), then differentiates through that recomputation with
        torch.autograd.grad to obtain input gradients.
        """
        query, key, value = ctx.saved_tensors

        # Enable grad tracking on the saved leaves so autograd.grad can
        # differentiate the recomputed forward.
        query.requires_grad_(True)
        key.requires_grad_(True)
        value.requires_grad_(True)

        query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        out = torch.nn.functional.scaled_dot_product_attention(
            query=query_t,
            key=key_t,
            value=value_t,
            attn_mask=ctx.attn_mask,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
            enable_gqa=ctx.enable_gqa,
        )
        out = out.permute(0, 2, 1, 3)

        grad_out_t = grad_out.permute(0, 2, 1, 3)
        grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
            outputs=out,
            inputs=[query_t, key_t, value_t],
            grad_outputs=grad_out_t,
            retain_graph=False,
        )

        # Undo the layout permutation on the gradients.
        grad_query = grad_query_t.permute(0, 2, 1, 3)
        grad_key = grad_key_t.permute(0, 2, 1, 3)
        grad_value = grad_value_t.permute(0, 2, 1, 3)

        return grad_query, grad_key, grad_value
+
    # Re-register NATIVE attention backend to allow attn mask while using context parallelism
    _registry_pop_attn_backend(AttentionBackendName.NATIVE)

    @_AttentionBackendRegistry.register(
        AttentionBackendName.NATIVE,
        constraints=[_check_device, _check_shape],
        supports_context_parallel=True,
    )
    def _native_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        """NATIVE backend entry point: plain SDPA when not parallel, otherwise
        the templated ring/Ulysses path (which is what adds attn_mask support
        under context parallelism). `return_lse` is rejected outright.
        """
        if return_lse:
            raise ValueError("Native attention backend does not support setting `return_lse=True`.")
        if _parallel_config is None:
            query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
            out = torch.nn.functional.scaled_dot_product_attention(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
                enable_gqa=enable_gqa,
            )
            out = out.permute(0, 2, 1, 3)
        else:
            out = _unified_templated_context_parallel_attention(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op=_native_attention_forward_op,
                backward_op=_native_attention_backward_op,
                _parallel_config=_parallel_config,
            )
        return out

    logger.info(
        "Re-registered NATIVE attention backend to enable context parallelism "
        "with attn mask in cache-dit. You can disable this behavior by: "
        "export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0."
    )
+
    def _sdpa_cudnn_attention_forward_op(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _save_ctx: bool = True,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """SDPA forward op pinned to the cuDNN backend via sdpa_kernel.

        Same layout convention as the native op (permute dims 1 and 2 in and
        out). Does not expose the log-sum-exp, so `return_lse` is rejected.
        """
        # Native attention does not return_lse
        if return_lse:
            raise ValueError("cudnn attention with sdpa does not support return_lse=True")

        # used for backward pass
        if _save_ctx:
            ctx.save_for_backward(query, key, value)
            ctx.attn_mask = attn_mask
            ctx.dropout_p = dropout_p
            ctx.is_causal = is_causal
            ctx.scale = scale
            ctx.enable_gqa = enable_gqa

        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        # Restrict SDPA to the cuDNN kernel for this call only.
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
            out = torch.nn.functional.scaled_dot_product_attention(
                query=query,
                key=key,
                value=value,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=is_causal,
                scale=scale,
                enable_gqa=enable_gqa,
            )
        out = out.permute(0, 2, 1, 3)

        return out

    def _sdpa_cudnn_attention_backward_op(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
        **kwargs,
    ):
        # Inference-only backend: backward is intentionally unimplemented.
        raise NotImplementedError("Backward for cudnn attention with sdpa is not implemented yet.")
+
    # Register _sdpa_cudnn_attention backend to allow attn mask while using context parallelism
    _set_new_attn_backend("_SDPA_CUDNN", "_sdpa_cudnn")
    assert hasattr(AttentionBackendName, "_SDPA_CUDNN")

    @_AttentionBackendRegistry.register(
        AttentionBackendName._SDPA_CUDNN,  # type: AttentionBackendName
        constraints=[_check_device, _check_shape],
        supports_context_parallel=True,
    )
    def _sdpa_cudnn_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        """New `_SDPA_CUDNN` backend: cuDNN-pinned SDPA when not parallel,
        templated ring/Ulysses path otherwise.

        NOTE(review): with `_parallel_config is None` and `return_lse=True`
        this falls into the templated branch with a None config, which the
        templated function dereferences -- confirm this combination is never
        requested by callers.
        """
        lse = None
        if _parallel_config is None and not return_lse:
            query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
            with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
                out = torch.nn.functional.scaled_dot_product_attention(
                    query=query,
                    key=key,
                    value=value,
                    attn_mask=attn_mask,
                    dropout_p=dropout_p,
                    is_causal=is_causal,
                    scale=scale,
                    enable_gqa=enable_gqa,
                )
            out = out.permute(0, 2, 1, 3)
        else:
            out = _unified_templated_context_parallel_attention(
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                return_lse,
                forward_op=_sdpa_cudnn_attention_forward_op,
                backward_op=_sdpa_cudnn_attention_backward_op,
                _parallel_config=_parallel_config,
            )
        if return_lse:
            out, lse = out

        return (out, lse) if return_lse else out

    logger.info(
        "Registered new attention backend: _SDPA_CUDNN to enable context "
        "parallelism with attn mask in cache-dit. You can disable it by: "
        "export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0."
    )
+
    # Re-register the SAGE backend so it can run under context parallelism.
    _registry_pop_attn_backend(AttentionBackendName.SAGE)

    @_AttentionBackendRegistry.register(
        AttentionBackendName.SAGE,
        constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
        supports_context_parallel=True,
    )
    def _sage_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        """SAGE backend: direct sageattn call when not parallel, templated
        ring/Ulysses path (with diffusers' sage forward/backward ops)
        otherwise. No attn_mask/dropout/GQA parameters -- sage does not
        take them.
        """
        lse = None
        if _parallel_config is None:
            out = sageattn(
                q=query,
                k=key,
                v=value,
                tensor_layout="NHD",
                is_causal=is_causal,
                sm_scale=scale,
                return_lse=return_lse,
            )
            if return_lse:
                # sageattn may return extra values; keep only (out, lse).
                out, lse, *_ = out
        else:
            out = _unified_templated_context_parallel_attention(
                query,
                key,
                value,
                None,
                0.0,
                is_causal,
                scale,
                False,
                return_lse,
                forward_op=_sage_attention_forward_op,
                backward_op=_sage_attention_backward_op,
                _parallel_config=_parallel_config,
            )
            if return_lse:
                out, lse = out

        return (out, lse) if return_lse else out

    logger.info(
        "Re-registered SAGE attention backend to enable context parallelism "
        "with FP8 Attention in cache-dit. You can disable this behavior by: "
        "export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0."
    )
+
    # Flash Attention 3 forward op implementation (inference only)
    def _flash_attention_3_forward_op(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _save_ctx: bool = True,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Flash Attention 3 forward operation for cache-dit (inference only).

        Rejects attn_mask, GQA and dropout (unsupported by this wrapper);
        warns instead of saving ctx since no backward exists.
        """
        if attn_mask is not None:
            raise ValueError("`attn_mask` is not yet supported for flash-attn 3.")
        if enable_gqa:
            raise ValueError("`enable_gqa` is not yet supported for flash-attn 3.")
        if dropout_p > 0.0:
            raise ValueError("`dropout_p` > 0 is not yet supported for flash-attn 3.")

        # Default softmax scale: 1 / sqrt(head_dim).
        if scale is None:
            scale = query.shape[-1] ** (-0.5)

        if _save_ctx:
            logger.warning(
                "Flash Attention 3 is configured for inference only, but _save_ctx=True was passed. "
                "Context will not be saved."
            )

        # Hardcoded parameters for FA3
        window_size = (-1, -1)
        softcap = 0.0
        deterministic = False

        out = flash_attn_3_func(
            q=query,
            k=key,
            v=value,
            softmax_scale=scale,
            causal=is_causal,
            qv=None,
            q_descale=None,
            k_descale=None,
            v_descale=None,
            window_size=window_size,
            attention_chunk=0,
            softcap=softcap,
            num_splits=1,
            pack_gqa=None,
            deterministic=deterministic,
            sm_margin=0,
            return_attn_probs=return_lse,
        )
        if return_lse:
            out, lse = out
            # FA3 returns lse as (B, H, S); permute to match the (B, S, H)
            # convention used by the templated attention path.
            lse = lse.permute(0, 2, 1)
            return out, lse
        else:
            return out

    def _flash_attention_3_backward_op(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
        **kwargs,
    ):
        """Flash Attention 3 backward operation for cache-dit."""
        # Inference-only: training through FA3 + context parallelism is not
        # supported.
        raise NotImplementedError(
            "Backward pass for Flash Attention 3 with context parallelism is not implemented yet in cache-dit."
        )
+
    # Re-register Flash Attention 3 backend
    if _flash_attn_3_available:
        # Reuse diffusers' _FLASH_3 enum member when present, otherwise add it.
        if hasattr(AttentionBackendName, "_FLASH_3"):
            _registry_pop_attn_backend(AttentionBackendName._FLASH_3)
        else:
            logger.info("AttentionBackendName._FLASH_3 not found, creating new backend.")
            _set_new_attn_backend("_FLASH_3", "_flash_3")
        assert hasattr(AttentionBackendName, "_FLASH_3")

        @_AttentionBackendRegistry.register(
            AttentionBackendName._FLASH_3,
            constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
            supports_context_parallel=True,
        )
        def _flash_attention_3(
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            scale: Optional[float] = None,
            is_causal: bool = False,
            return_lse: bool = False,
            _parallel_config: Optional["ParallelConfig"] = None,
        ) -> torch.Tensor:
            """_FLASH_3 backend: native flash-attn-3 when not parallel,
            templated ring/Ulysses path otherwise (inference only)."""
            lse = None
            if _parallel_config is None:
                # Non-parallel: use native flash-attn-3
                window_size = (-1, -1)
                softcap = 0.0
                deterministic = False
                out = flash_attn_3_func(
                    q=query,
                    k=key,
                    v=value,
                    softmax_scale=scale,
                    causal=is_causal,
                    qv=None,
                    q_descale=None,
                    k_descale=None,
                    v_descale=None,
                    window_size=window_size,
                    attention_chunk=0,
                    softcap=softcap,
                    num_splits=1,
                    pack_gqa=None,
                    deterministic=deterministic,
                    sm_margin=0,
                    return_attn_probs=return_lse,
                )
                if return_lse:
                    out, lse = out
                    # (B, H, S) -> (B, S, H), matching the templated path.
                    lse = lse.permute(0, 2, 1)
            else:
                # Parallel: use cache-dit's optimized implementation
                out = _unified_templated_context_parallel_attention(
                    query,
                    key,
                    value,
                    None,  # attn_mask not supported by FA3
                    0.0,  # dropout_p
                    is_causal,
                    scale,
                    False,  # enable_gqa
                    return_lse,
                    forward_op=_flash_attention_3_forward_op,
                    backward_op=_flash_attention_3_backward_op,
                    _parallel_config=_parallel_config,
                )
                if return_lse:
                    out, lse = out

            return (out, lse) if return_lse else out

        logger.info(
            "Re-registered FLASH_3 attention backend to enable context parallelism "
            "with Ulysses Anything/Float8 in cache-dit. You can disable this behavior by: "
            "export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0."
        )
    else:
        # Keep the module attribute defined so `from ... import _flash_attention_3`
        # in the package __init__ does not fail when FA3 is absent.
        _flash_attention_3 = None  # type: ignore[assignment]
        logger.info("Flash Attention 3 not available, skipping _FLASH_3 backend registration.")
+
    _registry_pop_attn_backend(AttentionBackendName._NATIVE_NPU)

    @_AttentionBackendRegistry.register(
        AttentionBackendName._NATIVE_NPU,
        constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
        supports_context_parallel=True,
    )
    def _native_npu_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        scale: Optional[float] = None,
        return_lse: bool = False,
        _parallel_config: Optional["ParallelConfig"] = None,
    ) -> torch.Tensor:
        """Ascend NPU fused attention backend, re-registered with context-parallel
        support. Tensors are BSND layout; the NPU kernel cannot return LSE."""
        if return_lse:
            raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
        if _parallel_config is None:
            out = npu_fusion_attention(
                query,
                key,
                value,
                atten_mask=None,
                input_layout="BSND",
                # Default softmax scale 1/sqrt(head_dim) when not supplied.
                scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
                pre_tockens=MAX_TOKEN,
                next_tockens=MAX_TOKEN,
                head_num=query.size(2),
            )[0]
        else:
            # Positional args: attn_mask=None, dropout_p, is_causal=None,
            # scale, enable_gqa=None, return_lse.
            out = _unified_templated_context_parallel_attention(
                query,
                key,
                value,
                None,
                dropout_p,
                None,
                1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
                None,
                return_lse,
                forward_op=_npu_attention_forward_op,
                backward_op=_npu_attention_backward_op,
                _parallel_config=_parallel_config,
            )
        return out

    logger.info(
        "Re-registered _NATIVE_NPU attention backend to enable context parallelism "
        "You can disable this behavior by: "
        "export CACHE_DIT_ENABLE_CUSTOM_ATTN_DISPATCH=0."
    )
+
    def _npu_attention_forward_op(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        enable_gqa: bool = False,
        return_lse: bool = False,
        _save_ctx: bool = True,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        """Per-shard NPU attention forward used by the templated context-parallel path.

        Inference-only: `ctx`/`_save_ctx` are accepted for interface parity but nothing
        is saved for backward. `dropout_p`, `is_causal` and `enable_gqa` are not
        forwarded to the kernel here — NOTE(review): confirm callers never rely on them.
        """
        if return_lse:
            raise ValueError("NPU attention backend does not support setting `return_lse=True`.")

        if attn_mask is not None:
            # npu_fusion_attention takes an inverted boolean mask
            # (True = masked out) — NOTE(review): confirm against CANN docs.
            attn_mask = ~attn_mask.to(torch.bool)
        out = npu_fusion_attention(
            query,
            key,
            value,
            atten_mask=attn_mask,
            input_layout="BSND",
            scale=scale,
            pre_tockens=MAX_TOKEN,
            next_tockens=MAX_TOKEN,
            head_num=query.size(2),
        )[0]  # [0] = attention output; remaining tuple entries are unused here

        return out
+
+ def _npu_attention_backward_op(
+ ctx: torch.autograd.function.FunctionCtx,
+ grad_out: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
+
+else:
+ from diffusers.models.attention_dispatch import (
+ _native_attention,
+ _sage_attention,
+ ) # noqa: F401
+
+ try:
+ from diffusers.models.attention_dispatch import _flash_attention_3 # noqa: F401
+ except ImportError:
+ _flash_attention_3 = None # type: ignore[assignment]
+
+ _sdpa_cudnn_attention = None # type: ignore[assignment]
+ _native_npu_attention = None # type: ignore[assignment]
+
+ logger.info("Skipped custom attention backend registration in cache-dit.")
diff --git a/src/cache_dit/parallelism/attention/_distributed_primitives.py b/src/cache_dit/parallelism/attention/_distributed_primitives.py
new file mode 100644
index 000000000..8a226be1d
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/_distributed_primitives.py
@@ -0,0 +1,671 @@
+import functools
+from typing import Tuple, List, Callable, Optional
+
+import torch
+import torch.distributed as dist
+import torch.distributed._functional_collectives as fc
+import torch.nn.functional as F
+
+from cache_dit.platforms import current_platform
+
try:
    from cache_dit.kernels import (
        per_token_quant_fp8,
        per_token_dequant_fp8,
        qkv_permute_quant_fp8,
        qkv_dequant_permute_fp8,
    )
except ImportError:
    # Triton (or the kernels extension) is unavailable: bind every FP8 entry
    # point to a single stub that fails loudly when called.

    def _fp8_kernel_unavailable(*args, **kwargs):
        raise RuntimeError(
            "FP8 kernels could not be imported (e.g., Triton may not be available on this "
            "platform). FP8 async operations are not supported. Please install the required "
            "dependencies or disable FP8 mode."
        )

    per_token_quant_fp8 = per_token_dequant_fp8 = _fp8_kernel_unavailable
    qkv_permute_quant_fp8 = qkv_dequant_permute_fp8 = _fp8_kernel_unavailable
from cache_dit.logger import init_logger

logger = init_logger(__name__)

# Some helper distributed primitive functions for context parallel attention.
# Public surface of this module; everything else is an implementation detail.
__all__ = [
    # All to all for Ulysses Attention
    "_all_to_all_single_qkv_async",
    "_all_to_all_single_o_async",
    "_all_to_all_single_qkv_uneven_heads_async",
    "_all_to_all_single_o_uneven_heads_async",
    "_all_to_all_single_qkv_fp8_async",
    "_all_to_all_single_o_fp8_async",
    # All to all for Ulysses Anything Attention
    "_all_to_all_single_any_qkv_async",
    "_all_to_all_single_any_o_async",
    "_all_to_all_single_any_qkv_fp8_async",
    "_all_to_all_single_any_o_fp8_async",
    # Helper functions for preparing communication metadata
    "_prepare_ulysses_comm_metadata",
    # Unified functions for Async Ulysses QKV/O Projection
    "_unified_all_to_all_qkv_async_fn",
    "_unified_all_to_all_o_async_fn",
]

# NOTE: We should always use the asynchronous all to all variants to keep the unified
# input/output shape for any_qkvo and non-any_qkvo cases; otherwise, the input/output
# shape will be different, which makes the unified function implementation complex and ugly.
+
+
+# Reference:
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
+# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
+# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py#L1012
+# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
+def _wait_tensor(tensor) -> torch.Tensor:
+ if isinstance(tensor, fc.AsyncCollectiveTensor):
+ tensor = tensor.wait()
+
+ return tensor
+
+
def _get_rank_world_size(
    group: dist.ProcessGroup,
) -> Tuple[int, int]:
    """Return (rank, world_size) of the calling process within `group`."""
    return dist.get_rank(group=group), dist.get_world_size(group=group)
+
+
@functools.lru_cache(maxsize=128)
def _gather_size_by_comm(size: int, group: dist.ProcessGroup) -> List[int]:
    r"""Gather the local size from all ranks.
    size: int, local size
    return: List[int], list of size from all ranks

    NOTE: results are memoized per (size, group), so the all_gather below runs
    only on a cache miss. All ranks must therefore call this with the same
    sequence of sizes, or ranks will desynchronize on the collective.
    """
    world_size = dist.get_world_size(group=group)
    # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
    comm_backends = str(dist.get_backend(group=group))
    # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
    gather_device = "cpu" if "cpu" in comm_backends else current_platform.default_device()
    gathered_sizes = [
        torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)
    ]
    dist.all_gather(
        gathered_sizes,
        torch.tensor([size], device=gather_device, dtype=torch.int64),
        group=group,
    )

    # Extract python ints element-wise.
    gathered_sizes = [s[0].item() for s in gathered_sizes]
    # NOTE: DON'T use tolist here due to graph break - Explanation:
    # Backend compiler `inductor` failed with aten._local_scalar_dense.default
    return gathered_sizes
+
+
def _split_head_sizes(
    H: int,
    group: dist.ProcessGroup,
) -> List[int]:
    r"""Split the head dimension size by world_size.

    The first `H % world_size` ranks each take one extra head so the split is
    as even as possible.
    H: int, global head num
    return: List[int], list of local head num for each rank
    """
    assert H is not None, "Global head num H must be provided."
    # rank itself is not needed here, only the group size.
    _, world_size = _get_rank_world_size(group)
    # e.g., H = 30, world_size = 4 -> [8, 8, 7, 7]
    # (the previous example "[8, 8, 8, 6]" was wrong: remainder heads are
    # distributed one per rank, never stacked on a single rank)
    base_head_num, remainder = divmod(H, world_size)
    return [base_head_num + 1 if i < remainder else base_head_num for i in range(world_size)]
+
+
+# Helper functions to pad/unpad head dimension for QKV and O projections
def _maybe_pad_qkv_head(
    x: torch.Tensor,
    H: int,
    group: dist.ProcessGroup,
) -> Tuple[torch.Tensor, int]:
    r"""Pad the head dim of a QKV tensor so it divides evenly across ranks.

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    H: int, original global head num
    return: Tuple[torch.Tensor, int], padded tensor (B, S_LOCAL, H + H_PAD, D) and H_PAD
    """
    _, world_size = _get_rank_world_size(group)
    remainder = H % world_size
    if remainder == 0:
        return x, 0

    H_PAD = world_size - remainder
    NEW_H_LOCAL = (H + H_PAD) // world_size
    # Padding must stay strictly below one local shard, e.g. allow
    # H=30/ws=8 (pad 2, local 4) but reject H=30/ws=16 (pad 14, local 2).
    assert (
        H_PAD < NEW_H_LOCAL
    ), f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
    return F.pad(x, (0, 0, 0, H_PAD)).contiguous(), H_PAD
+
+
def _maybe_unpad_qkv_head(
    x: torch.Tensor,
    H_PAD: int,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    r"""Drop the padded heads introduced by `_maybe_pad_qkv_head`.

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL + H_PAD, D)
    H_PAD: int, head padding num
    return: torch.Tensor, unpadded tensor (B, S_GLOBAL, H_LOCAL, D)
    """
    rank, world_size = _get_rank_world_size(group)
    # Padding was appended at the tail of the head dim, so after the heads are
    # scattered across ranks only the last rank holds padded heads.
    if rank == world_size - 1 and H_PAD > 0:
        x = x[:, :, : x.shape[2] - H_PAD, :]
    return x.contiguous()
+
+
def _maybe_pad_o_head(
    x: torch.Tensor,
    H: int,
    group: dist.ProcessGroup,
) -> Tuple[torch.Tensor, int]:
    r"""Pad the head dim of the attention output so it divides evenly across ranks.

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    H: int, original global head num (None -> no padding possible, return as-is)
    return: Tuple[torch.Tensor, int], padded tensor (B, S_GLOBAL, H_LOCAL + H_PAD, D) and H_PAD
    """
    if H is None:
        return x, 0

    rank, world_size = _get_rank_world_size(group)
    remainder = H % world_size
    if remainder == 0:
        return x, 0

    # H_PAD is computed identically on every rank so the later unpad step stays
    # consistent, even though only the last rank actually pads its tensor.
    H_PAD = world_size - remainder
    NEW_H_LOCAL = (H + H_PAD) // world_size
    assert (
        H_PAD < NEW_H_LOCAL
    ), f"Padding head num {H_PAD} should be less than new local head num {NEW_H_LOCAL}"
    if rank == world_size - 1:
        x = F.pad(x, (0, 0, 0, H_PAD)).contiguous()
    return x, H_PAD
+
+
+def _maybe_unpad_o_head(
+ x: torch.Tensor,
+ H_PAD: int,
+ group: dist.ProcessGroup,
+) -> torch.Tensor:
+ r"""Maybe unpad the head dimension.
+ x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL + H_PAD, D)
+ H_PAD: int, head padding num
+ return: torch.Tensor, unpadded tensor (B, S_LOCAL, H_GLOBAL, D)
+ """
+ if H_PAD > 0:
+ x = x[:, :, :-H_PAD, :]
+ return x.contiguous()
+
+
+# Helper functions to for all-to-all communication with Ulysses Attention
+def _prepare_ulysses_comm_metadata(
+ query: torch.Tensor,
+ **kwargs,
+) -> dict:
+ # query: (B, S_LOCAL, H_GLOBAL, D)
+ assert (
+ len(query.shape) == 4
+ ), "Query tensor must be 4-dimensional of shape (B, S_LOCAL, H_GLOBAL, D)"
+ extra_kwargs = {}
+ extra_kwargs["NUM_QO_HEAD"] = query.shape[2]
+ extra_kwargs["Q_S_LOCAL"] = query.shape[1]
+ # Add other kwargs if needed in future
+ return extra_kwargs
+
+
def _all_to_all_single_qkv_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""Head-scatter / sequence-gather all-to-all for QKV projections (async).

    Fixed the return annotation: this function returns a wait callable, not a
    tensor (consistent with the fp8 variants and the docstring).

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that waits for the collective and returns (B, S_GLOBAL, H_LOCAL, D)
    """
    _, world_size = _get_rank_world_size(group)
    B, S_LOCAL, H, D = x.shape
    # Pad H so it divides evenly across ranks; padding is removed in wait().
    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
    H_LOCAL = (H + H_PAD) // world_size
    # Put the rank dimension first so a flat all_to_all_single exchanges
    # contiguous per-rank chunks: (world_size, S_LOCAL, B, H_LOCAL, D).
    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    _shape = x.shape  # (world_size, S_LOCAL, B, H_LOCAL, D)

    x = x.flatten()
    x = fc.all_to_all_single(x, None, None, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        # (world_size, S_LOCAL, B, H_LOCAL, D)
        # -> (S_GLOBAL, B, H_LOCAL, D)
        # -> (B, S_GLOBAL, H_LOCAL, D)
        x = x.reshape(_shape).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
        x = _maybe_unpad_qkv_head(x, H_PAD, group)
        return x

    return wait
+
+
def _all_to_all_single_o_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""Head-gather / sequence-scatter all-to-all for the attention output (async).

    Fixed the return annotation: this function returns a wait callable, not a
    tensor (consistent with the fp8 variants and the docstring).

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    return: Callable that waits for the collective and returns (B, S_LOCAL, H_GLOBAL, D)
    """
    # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The padding logic needs H to determine if padding is necessary.
    H = kwargs.get("NUM_QO_HEAD", None)
    _, world_size = _get_rank_world_size(group)
    x, H_PAD = _maybe_pad_o_head(x, H, group)
    B, S_GLOBAL, H_LOCAL, D = x.shape
    S_LOCAL = S_GLOBAL // world_size
    # (B, S_GLOBAL, H_LOCAL, D) -> (world_size, H_LOCAL, B, S_LOCAL, D)
    x = x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
    _shape = x.shape  # (world_size, H_LOCAL, B, S_LOCAL, D)

    x = x.flatten()
    x = fc.all_to_all_single(x, None, None, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        # (world_size, H_LOCAL, B, S_LOCAL, D)
        # -> (H_GLOBAL, B, S_LOCAL, D)
        # -> (B, S_LOCAL, H_GLOBAL, D)
        x = x.reshape(_shape).flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        x = _maybe_unpad_o_head(x, H_PAD, group)
        return x

    return wait
+
+
def _all_to_all_single_qkv_uneven_heads_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""QKV all-to-all variant for uneven head splits without padding.

    Fixed the return annotation: this function returns a wait callable, not a
    tensor (consistent with the fp8 variants and the docstring).

    x: torch.Tensor, shape (B, S_LOCAL, H_GLOBAL, D)
    return: Callable that waits for the collective and returns (B, S_GLOBAL, H_LOCAL, D)
    """
    rank, world_size = _get_rank_world_size(group)
    B, S_LOCAL, H_GLOBAL, D = x.shape
    # NOTE: May use tensor_split here to ensure the same split policy
    # that we have used in the EquipartitionSharder sharding strategy. Please
    # note that 'tensor_split' splits a tensor into multiple sub-tensors,
    # all of which are views of input, thus may not introduce extra IO access.
    input_split_sizes = [i.size(2) for i in torch.tensor_split(x, world_size, dim=2)]
    H_LOCAL = input_split_sizes[rank]
    # [H_GLOBAL, B, S_LOCAL, D] — head-major so per-rank chunks are contiguous.
    x = x.permute(2, 0, 1, 3).contiguous()
    # Every peer sends this rank exactly H_LOCAL heads.
    output_split_sizes = [H_LOCAL] * world_size
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x
        x = _wait_tensor(x)
        # [world_size, H_LOCAL, B, S_LOCAL, D]
        x = x.reshape(world_size, H_LOCAL, B, S_LOCAL, D)
        # [B, world_size, S_LOCAL, H_LOCAL, D]
        x = x.permute(2, 0, 3, 1, 4).contiguous()
        # [B, S_GLOBAL, H_LOCAL, D]
        x = x.reshape(B, world_size * S_LOCAL, H_LOCAL, D)
        return x

    return wait
+
+
def _all_to_all_single_o_uneven_heads_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""O-projection all-to-all variant for uneven head splits without padding.

    Fixed the return annotation (returns a wait callable, not a tensor),
    removed the unused `rank` local, and corrected the split-size example.

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    return: Callable that waits for the collective and returns (B, S_LOCAL, H_GLOBAL, D)
    """
    # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The split-size logic needs the global head num H.
    H = kwargs.get("NUM_QO_HEAD", None)
    B, S_GLOBAL, H_LOCAL, D = x.shape
    _, world_size = _get_rank_world_size(group)
    # e.g., H = 30, world_size = 4 -> output_split_sizes = [8, 8, 7, 7]
    output_split_sizes = _split_head_sizes(H, group)

    H_GLOBAL = sum(output_split_sizes)
    S_LOCAL = S_GLOBAL // world_size
    # [B, world_size, S_LOCAL, H_LOCAL, D]
    x = x.reshape(B, world_size, S_LOCAL, H_LOCAL, D)
    # [world_size, H_LOCAL, B, S_LOCAL, D]
    x = x.permute(1, 3, 0, 2, 4).contiguous()
    # [world_size * H_LOCAL, B, S_LOCAL, D]
    x = x.flatten(0, 1)
    input_split_sizes = [H_LOCAL] * world_size
    # [world_size * H_LOCAL, B, S_LOCAL, D]
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x
        x = _wait_tensor(x)
        # [H_GLOBAL, B, S_LOCAL, D]
        x = x.reshape(H_GLOBAL, B, S_LOCAL, D)
        # [B, S_LOCAL, H_GLOBAL, D]
        x = x.permute(1, 2, 0, 3).contiguous()
        return x

    return wait
+
+
def _all_to_all_single_qkv_fp8_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""FP8-compressed variant of the QKV all-to-all (async).

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
    """
    _, world_size = _get_rank_world_size(group)
    B, S_LOCAL, H, D = x.shape
    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
    H_LOCAL = (H + H_PAD) // world_size
    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D)
    # NOTE(review): the rank-major (2, 1, 0, 3, 4) permute appears to be fused
    # into the quant kernel together with per-token FP8 quantization — confirm
    # against qkv_permute_quant_fp8's contract.
    x = qkv_permute_quant_fp8(x)
    # Last dim is widened to carry the per-token scale next to the payload.
    shape_with_scale = x.shape  # (world_size, S_LOCAL, B, H_LOCAL, D + itemsize)
    x = x.flatten()
    x = fc.all_to_all_single(x, None, None, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        x = x.reshape(shape_with_scale).flatten(0, 1)
        # Dequantize and restore the (B, S_GLOBAL, H_LOCAL, D) layout
        # (presumably fused in the kernel — verify).
        x = qkv_dequant_permute_fp8(x)
        x = _maybe_unpad_qkv_head(x, H_PAD, group)
        return x

    return wait
+
+
def _all_to_all_single_o_fp8_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""FP8-compressed variant of the O-projection all-to-all (async).

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
    """
    # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The padding logic needs H to determine if padding is necessary.
    H = kwargs.get("NUM_QO_HEAD", None)
    _, world_size = _get_rank_world_size(group)
    x, H_PAD = _maybe_pad_o_head(x, H, group)
    B, S_GLOBAL, H_LOCAL, D = x.shape
    S_LOCAL = S_GLOBAL // world_size
    # (B, S_GLOBAL, H_LOCAL, D) -> (world_size, H_LOCAL, B, S_LOCAL, D)
    x = x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
    _shape = x.shape  # (world_size, H_LOCAL, B, S_LOCAL, D)

    # Quantize after the permute so the wire payload is already rank-major.
    x = per_token_quant_fp8(x)
    shape_with_scale = x.shape  # (world_size, H_LOCAL, B, S_LOCAL, D + itemsize)
    x = x.flatten()
    x = fc.all_to_all_single(x, None, None, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        x = x.reshape(shape_with_scale)
        x = per_token_dequant_fp8(x)
        # (world_size, H_LOCAL, B, S_LOCAL, D)
        # -> (H_GLOBAL, B, S_LOCAL, D)
        # -> (B, S_LOCAL, H_GLOBAL, D)   [previous comment had the dims swapped]
        x = x.reshape(_shape).flatten(0, 1).permute(1, 2, 0, 3).contiguous()
        x = _maybe_unpad_o_head(x, H_PAD, group)
        return x

    return wait
+
+
@torch.compiler.allow_in_graph
def _all_to_all_single_any_qkv_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""QKV all-to-all for "Ulysses Anything": ranks may hold different S_LOCAL.

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
    """
    _, world_size = _get_rank_world_size(group)
    B, S_LOCAL, H, D = x.shape
    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
    H_LOCAL = (H + H_PAD) // world_size
    # (world_size, S_LOCAL, B, H_LOCAL, D)
    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()

    # Each peer receives this rank's full S_LOCAL worth of its head shard.
    input_split_sizes = [S_LOCAL] * world_size
    # S_LOCAL maybe not equal for all ranks in dynamic shape case,
    # since we don't know the actual shape before this timing, thus,
    # we have to use all gather to collect the S_LOCAL first.
    output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
    # NOTE: The `if` branch will introduce graph break for torch.compile,
    # so, we choose to disable the even split optimization implementation
    # _all_to_all_single for now.
    x = x.flatten(0, 1)  # (world_size * S_LOCAL, B, H_LOCAL, D)
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)  # (S_GLOBAL, B, H_LOCAL, D)
        # (S_GLOBAL, B, H_LOCAL, D)
        # -> (B, S_GLOBAL, H_LOCAL, D)
        x = x.permute(1, 0, 2, 3).contiguous()
        x = _maybe_unpad_qkv_head(x, H_PAD, group)
        return x

    return wait
+
+
@torch.compiler.allow_in_graph
def _all_to_all_single_any_o_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""O-projection all-to-all for "Ulysses Anything" (uneven S_LOCAL).

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    kwargs: must carry "NUM_QO_HEAD" and "Q_S_LOCAL" as produced by
        _prepare_ulysses_comm_metadata; Q_S_LOCAL cannot be re-derived here.
    return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
    """
    # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The padding logic needs H to determine if padding is necessary.
    H = kwargs.get("NUM_QO_HEAD", None)
    rank, world_size = _get_rank_world_size(group)
    x, H_PAD = _maybe_pad_o_head(x, H, group)
    shape = x.shape  # (B, S_GLOBAL, H_LOCAL, D)
    (B, S_GLOBAL, H_LOCAL, D) = shape
    # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
    # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]

    # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
    # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
    # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
    # b.tensor_split(4)[0].shape[1])

    # input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)]
    # S_LOCAL = input_split_sizes[rank]

    S_LOCAL = kwargs.get("Q_S_LOCAL")
    input_split_sizes = _gather_size_by_comm(S_LOCAL, group)

    x = x.permute(1, 0, 2, 3).contiguous()  # (S_GLOBAL, B, H_LOCAL, D)
    # Every peer sends this rank S_LOCAL rows back.
    output_split_sizes = [S_LOCAL] * world_size
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)  # (world_size * S_LOCAL, B, H_LOCAL, D)
        x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
        x = x.permute(2, 1, 0, 3, 4).contiguous()
        # Concatenate the per-rank head shards back into H_GLOBAL (+ padding).
        x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
        x = _maybe_unpad_o_head(x, H_PAD, group)
        return x

    return wait
+
+
@torch.compiler.allow_in_graph
def _all_to_all_single_any_qkv_fp8_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""FP8-compressed QKV all-to-all for "Ulysses Anything" (uneven S_LOCAL).

    x: torch.Tensor, shape (B, S_LOCAL, H, D)
    return: Callable that returns (B, S_GLOBAL, H_LOCAL, D)
    """
    _, world_size = _get_rank_world_size(group)
    B, S_LOCAL, H, D = x.shape
    x, H_PAD = _maybe_pad_qkv_head(x, H, group)
    H_LOCAL = (H + H_PAD) // world_size
    # (world_size, S_LOCAL, B, H_LOCAL, D)
    x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D)

    input_split_sizes = [S_LOCAL] * world_size
    # S_LOCAL maybe not equal for all ranks in dynamic shape case,
    # since we don't know the actual shape before this timing, thus,
    # we have to use all gather to collect the S_LOCAL first.
    output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
    # NOTE: The `if` branch will introduce graph break for torch.compile,
    # so, we choose to disable the even split optimization implementation
    # _all_to_all_single for now.
    x = qkv_permute_quant_fp8(x)
    x = x.flatten(0, 1)
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)
        # NOTE(review): unlike the even-split fp8 path there is no reshape
        # before dequant — presumably qkv_dequant_permute_fp8 accepts the
        # flattened wire layout and restores (B, S_GLOBAL, H_LOCAL, D); verify.
        x = qkv_dequant_permute_fp8(x)
        x = _maybe_unpad_qkv_head(x, H_PAD, group)
        return x

    return wait
+
+
@torch.compiler.allow_in_graph
def _all_to_all_single_any_o_fp8_async(
    x: torch.Tensor,
    group: dist.ProcessGroup,
    **kwargs,
) -> Callable[..., torch.Tensor]:
    r"""FP8-compressed O-projection all-to-all for "Ulysses Anything".

    x: torch.Tensor, shape (B, S_GLOBAL, H_LOCAL, D)
    kwargs: must carry "NUM_QO_HEAD" and "Q_S_LOCAL" as produced by
        _prepare_ulysses_comm_metadata.
    return: Callable that returns (B, S_LOCAL, H_GLOBAL, D)
    """
    # Assume H is provided in kwargs, since we can't infer H from x's shape.
    # The padding logic needs H to determine if padding is necessary.
    H = kwargs.get("NUM_QO_HEAD", None)
    rank, world_size = _get_rank_world_size(group)
    x, H_PAD = _maybe_pad_o_head(x, H, group)
    shape = x.shape  # (B, S_GLOBAL, H_LOCAL, D)
    # Quantize before the permute; shapes below track the quantized payload.
    x = per_token_quant_fp8(x)
    (B, S_GLOBAL, H_LOCAL, D) = shape
    # input_split: e.g, S_GLOBAL=9 input splits across ranks [[5,4], [5,4],..]
    # output_split: e.g, S_GLOBAL=9 output splits across ranks [[5,5], [4,4],..]

    # WARN: In some cases, e.g, joint attn in Qwen-Image, the S_LOCAL can not infer
    # from tensor split due to: if c = torch.cat((a, b)), world_size=4, then,
    # c.tensor_split(4)[0].shape[1] may != to (a.tensor_split(4)[0].shape[1] +
    # b.tensor_split(4)[0].shape[1])

    # input_split_sizes = [o.size(1) for o in torch.tensor_split(x, world_size, dim=1)]
    # S_LOCAL = input_split_sizes[rank]

    S_LOCAL = kwargs.get("Q_S_LOCAL")
    input_split_sizes = _gather_size_by_comm(S_LOCAL, group)

    x = x.permute(1, 0, 2, 3).contiguous()  # (S_GLOBAL, B, H_LOCAL, D)
    output_split_sizes = [S_LOCAL] * world_size
    x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)

    def wait() -> torch.Tensor:
        nonlocal x, H_PAD
        x = _wait_tensor(x)  # (world_size * S_LOCAL, B, H_LOCAL, D)
        x = per_token_dequant_fp8(x)
        x = x.reshape(world_size, S_LOCAL, B, H_LOCAL, D)
        x = x.permute(2, 1, 0, 3, 4).contiguous()
        x = x.reshape(B, S_LOCAL, world_size * H_LOCAL, D)
        x = _maybe_unpad_o_head(x, H_PAD, group)
        return x

    return wait
+
+
+# Unified functions to select proper all to all implementations according to
+# Ulysses Float8 or other settings. Mainly used in Async Ulysses Attention.
+# TODO: Refactor basic any_qkvo and non-any_qkvo all2all functions to have
+# the same output shape, thus make the unified functions more general and clean.
+
+
def _unified_all_to_all_qkv_async_fn(
    fp8: Optional[bool] = None,
) -> Callable[..., torch.Tensor]:
    """Pick the QKV all-to-all implementation matching the active Ulysses settings.

    fp8: None = follow global float8 setting; False = force-disable float8.
    """
    from ._templated_ulysses import is_ulysses_float8_enabled
    from ._templated_ulysses import is_ulysses_anything_enabled
    from ._templated_ulysses import is_ulysses_heads_no_padding

    use_fp8 = is_ulysses_float8_enabled() and not ((fp8 is not None) and (not fp8))

    if is_ulysses_anything_enabled():
        # Dynamic-sequence ("anything") variants.
        return _all_to_all_single_any_qkv_fp8_async if use_fp8 else _all_to_all_single_any_qkv_async

    if use_fp8:
        assert (
            not is_ulysses_heads_no_padding()
        ), "FP8 and ulysses heads no padding both enabled is not supported."
        return _all_to_all_single_qkv_fp8_async
    if is_ulysses_heads_no_padding():
        return _all_to_all_single_qkv_uneven_heads_async
    return _all_to_all_single_qkv_async
+
+
def _unified_all_to_all_o_async_fn(
    fp8: Optional[bool] = None,
) -> Callable[..., torch.Tensor]:
    """Pick the O-projection all-to-all implementation matching the active Ulysses settings.

    fp8: None = follow global float8 setting; False = force-disable float8.
    """
    from ._templated_ulysses import is_ulysses_float8_enabled
    from ._templated_ulysses import is_ulysses_anything_enabled
    from ._templated_ulysses import is_ulysses_heads_no_padding

    use_fp8 = is_ulysses_float8_enabled() and not ((fp8 is not None) and (not fp8))

    if is_ulysses_anything_enabled():
        # Dynamic-sequence ("anything") variants.
        return _all_to_all_single_any_o_fp8_async if use_fp8 else _all_to_all_single_any_o_async

    if use_fp8:
        assert (
            not is_ulysses_heads_no_padding()
        ), "FP8 and ulysses heads no padding both enabled is not supported."
        return _all_to_all_single_o_fp8_async
    if is_ulysses_heads_no_padding():
        return _all_to_all_single_o_uneven_heads_async
    return _all_to_all_single_o_async
diff --git a/src/cache_dit/parallelism/attention/_experimental_utils.py b/src/cache_dit/parallelism/attention/_experimental_utils.py
new file mode 100644
index 000000000..c335ac319
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/_experimental_utils.py
@@ -0,0 +1,96 @@
+import torch
+import functools
+import diffusers
+from typing import List, Union
+
+try:
+ from diffusers import ContextParallelConfig # noqa: F401
+ from diffusers.hooks.context_parallel import (
+ _find_submodule_by_name as _find_submodule_by_name_for_context_parallel,
+ )
+
+ def _is_diffusers_parallelism_available() -> bool:
+ return True
+
+except ImportError:
+ ContextParallelConfig = None
+ _find_submodule_by_name_for_context_parallel = None
+
+ def _is_diffusers_parallelism_available() -> bool:
+ return False
+
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+__all__ = [
+ "_is_diffusers_parallelism_available",
+ "_maybe_patch_find_submodule",
+]
+
# NOTE: Add this utility function to diffusers to support ModuleDict, such as 'all_final_layer' in ZImage.
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L283
# This function is only used when diffusers native context parallelism is enabled and is backward
# compatible with the original implementation.
if (
    _is_diffusers_parallelism_available()
    and _find_submodule_by_name_for_context_parallel is not None
):

    @functools.wraps(_find_submodule_by_name_for_context_parallel)
    def _patch_find_submodule_by_name(
        model: torch.nn.Module, name: str
    ) -> Union[torch.nn.Module, List[torch.nn.Module]]:
        """Resolve a dotted submodule path on `model`.

        Extends diffusers' resolver with: '*' wildcard over ModuleList entries,
        and terminal ModuleDict names expanding to the dict's values.
        """
        if name == "":
            return model
        # Peel one path atom per recursion step.
        first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
        if first_atom == "*":
            if not isinstance(model, torch.nn.ModuleList):
                raise ValueError("Wildcard '*' can only be used with ModuleList")
            submodules = []
            for submodule in model:
                subsubmodules = _patch_find_submodule_by_name(submodule, remaining_name)
                # Flatten nested results; ModuleDict results expand to their values.
                if not isinstance(subsubmodules, list):
                    if isinstance(subsubmodules, torch.nn.ModuleDict):
                        subsubmodules = list(subsubmodules.values())
                    else:
                        subsubmodules = [subsubmodules]
                submodules.extend(subsubmodules)
            return submodules
        else:
            if hasattr(model, first_atom):
                submodule = getattr(model, first_atom)
                if isinstance(submodule, torch.nn.ModuleDict):  # e.g, 'all_final_layer' in ZImage
                    if remaining_name == "":
                        submodule = list(submodule.values())
                        # Make sure all values are Modules, not support other complex cases.
                        for v in submodule:
                            if not isinstance(v, torch.nn.Module):
                                raise ValueError(
                                    f"Value '{v}' in ModuleDict '{first_atom}' is not a Module"
                                )
                        return submodule
                    else:
                        raise ValueError(
                            f"Cannot access submodule '{remaining_name}' of ModuleDict '{first_atom}' directly. "
                            f"Please specify the key of the ModuleDict first."
                        )
                return _patch_find_submodule_by_name(submodule, remaining_name)
            else:
                raise ValueError(
                    f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'"
                )

    def _maybe_patch_find_submodule():
        # Idempotent: only swap in the patched resolver if it isn't installed yet.
        if (
            diffusers.hooks.context_parallel._find_submodule_by_name
            != _patch_find_submodule_by_name
        ):
            diffusers.hooks.context_parallel._find_submodule_by_name = _patch_find_submodule_by_name
            logger.debug("Patched _find_submodule_by_name to support ModuleDict.")

else:

    def _maybe_patch_find_submodule():
        # No-op when diffusers context-parallel support is unavailable.
        pass
diff --git a/src/cache_dit/parallelism/attention/_templated_ring.py b/src/cache_dit/parallelism/attention/_templated_ring.py
new file mode 100644
index 000000000..22a1f8e59
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/_templated_ring.py
@@ -0,0 +1,56 @@
+# TODO: Support TemplatedRingAttention in cache-dit with PyTorch context-parallel api.
+# Reference: https://docs.pytorch.org/tutorials/unstable/context_parallel.html
+import torch
+from typing import Optional
+
+try:
+ from diffusers.models.attention_dispatch import TemplatedRingAttention
+ from diffusers.models._modeling_parallel import ParallelConfig
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+
+
+__all__ = ["UnifiedTemplatedRingAttention"]
+
+
class UnifiedTemplatedRingAttention(torch.autograd.Function):
    """Unified entry point for Ring Attention in cache-dit.

    Currently a thin pass-through: every argument is forwarded unchanged to
    the single available implementation, ``_TemplatedRingAttention``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        # Collect the arguments once so the delegation reads as a single call.
        ring_args = (
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            forward_op,
            backward_op,
            _parallel_config,
        )
        return _TemplatedRingAttention.apply(*ring_args)
+
+
class _TemplatedRingAttention(TemplatedRingAttention):
    """A wrapper of diffusers' TemplatedRingAttention to avoid name conflict."""

    # Intentionally empty: all behavior is inherited unchanged from diffusers.
    pass
diff --git a/src/cache_dit/parallelism/attention/_templated_ulysses.py b/src/cache_dit/parallelism/attention/_templated_ulysses.py
new file mode 100644
index 000000000..ee51876f9
--- /dev/null
+++ b/src/cache_dit/parallelism/attention/_templated_ulysses.py
@@ -0,0 +1,809 @@
+import copy
+import functools
+from typing import Optional, Tuple, List
+
+import torch
+import torch.distributed as dist
+
+try:
+ from diffusers.models._modeling_parallel import ParallelConfig
+ from diffusers.hooks.context_parallel import EquipartitionSharder
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from ._distributed_primitives import (
+ _get_rank_world_size,
+ _gather_size_by_comm,
+ # All to all for Ulysses Attention
+ _all_to_all_single_o_async,
+ _all_to_all_single_qkv_fp8_async,
+ _all_to_all_single_o_fp8_async,
+ _all_to_all_single_qkv_uneven_heads_async,
+ _all_to_all_single_o_uneven_heads_async,
+ # All to all for Ulysses Anything Attention
+ _all_to_all_single_any_o_async,
+ _all_to_all_single_any_qkv_async,
+ _all_to_all_single_any_o_fp8_async,
+ _all_to_all_single_any_qkv_fp8_async,
+ _all_to_all_single_qkv_async,
+ # Helper functions for preparing communication metadata
+ _prepare_ulysses_comm_metadata,
+)
+
+from cache_dit.envs import ENV
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+ "UnifiedTemplatedUlyssesAttention",
+ "EquipartitionSharder",
+ "enable_ulysses_anything",
+ "is_ulysses_anything_enabled",
+ "disable_ulysses_anything",
+ "enable_ulysses_float8",
+ "is_ulysses_float8_enabled",
+ "disable_ulysses_float8",
+ "is_ulysses_heads_no_padding",
+]
+
+
class UnifiedTemplatedUlyssesAttention(torch.autograd.Function):
    """A unified wrapper for all Ulysses Attention variants in cache-dit.

    Selects the concrete ``torch.autograd.Function`` implementation from the
    runtime feature flags and forwards every argument unchanged to it.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        # Dispatch priority mirrors the flag semantics:
        #   1. "anything" mode (any sequence length / head count), optionally float8;
        #   2. float8 comm for the standard Ulysses path;
        #   3. uneven-heads comm without padding;
        #   4. plain Ulysses (even sequence length, any head count).
        if is_ulysses_anything_enabled():
            if is_ulysses_float8_enabled():
                impl = _TemplatedUlyssesAnythingAttentionFloat8
            else:
                impl = _TemplatedUlyssesAnythingAttention
        elif is_ulysses_float8_enabled():
            impl = _TemplatedUlyssesAttentionFloat8
        elif is_ulysses_heads_no_padding():
            impl = _TemplatedUlyssesAttentionUnEvenHeads
        else:
            impl = _TemplatedUlyssesAttention
        return impl.apply(
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            forward_op,
            backward_op,
            _parallel_config,
        )
+
+
# Re-implement Ulysses Attention with custom async all-to-all communication in cache-dit
# Use '_' prefix to avoid name conflict with diffusers' TemplatedUlyssesAttention.
class _TemplatedUlyssesAttention(torch.autograd.Function):
    """Ulysses (sequence-parallel) attention using async all-to-all comm.

    Q/K/V arrive sequence-sharded; an all-to-all re-shards them head-wise so
    each rank attends over the full sequence for a local subset of heads, then
    a second all-to-all restores sequence sharding of the output. Supports
    even sequence lengths (see the "anything" variant for arbitrary lengths).
    Forward-only: backward raises NotImplementedError.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
        group = ulysses_mesh.get_group()

        # Saved for a future backward implementation (backward currently raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx._parallel_config = _parallel_config

        metadata = _prepare_ulysses_comm_metadata(query)
        # Launch all three all-to-alls before waiting on any so they can overlap.
        query_wait = _all_to_all_single_qkv_async(query, group, **metadata)
        key_wait = _all_to_all_single_qkv_async(key, group, **metadata)
        value_wait = _all_to_all_single_qkv_async(value, group, **metadata)

        query = query_wait()  # type: torch.Tensor
        key = key_wait()  # type: torch.Tensor
        value = value_wait()  # type: torch.Tensor
        # Local attention over the full sequence with this rank's head shard.
        out = forward_op(
            ctx,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            _save_ctx=False,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse, *_ = out

        # Scatter heads / gather sequence back to the caller's sharding.
        out_wait = _all_to_all_single_o_async(out, group, **metadata)

        if return_lse:
            # NOTE: DON'T use float8 all_to_all for out and lse, as it may
            # cause more numerical instability.
            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
            lse_wait = _all_to_all_single_o_async(lse, group, **metadata)
            out = out_wait()  # type: torch.Tensor
            lse = lse_wait()  # type: torch.Tensor
            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
        else:
            out = out_wait()  # type: torch.Tensor
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError(
            "Backward pass for Ulysses Attention in cache-dit is not implemented yet."
        )
+
+
class _TemplatedUlyssesAttentionUnEvenHeads(torch.autograd.Function):
    """Ulysses attention variant whose all-to-alls tolerate head counts that
    are not divisible by the world size (no head padding).

    Same overall dataflow as ``_TemplatedUlyssesAttention``; only the comm
    primitives differ. Forward-only: backward raises NotImplementedError.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
        group = ulysses_mesh.get_group()

        # Saved for a future backward implementation (backward currently raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx._parallel_config = _parallel_config

        metadata = _prepare_ulysses_comm_metadata(query)
        # Async all to all for query, key, value with uneven heads communication
        query_wait = _all_to_all_single_qkv_uneven_heads_async(query, group, **metadata)
        key_wait = _all_to_all_single_qkv_uneven_heads_async(key, group, **metadata)
        value_wait = _all_to_all_single_qkv_uneven_heads_async(value, group, **metadata)

        query = query_wait()  # type: torch.Tensor
        key = key_wait()  # type: torch.Tensor
        value = value_wait()  # type: torch.Tensor

        # Local attention over the full sequence with this rank's head shard.
        out = forward_op(
            ctx,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            _save_ctx=False,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse, *_ = out

        # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
        out_wait = _all_to_all_single_o_uneven_heads_async(out, group, **metadata)

        if return_lse:
            # NOTE: DON'T use float8 all_to_all for out and lse, as it may
            # cause more numerical instability.
            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
            lse_wait = _all_to_all_single_o_uneven_heads_async(lse, group, **metadata)
            out = out_wait()  # type: torch.Tensor
            lse = lse_wait()  # type: torch.Tensor
            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
        else:
            out = out_wait()  # type: torch.Tensor
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError(
            "Backward pass for Ulysses Attention in cache-dit is not implemented yet."
        )
+
+
class _TemplatedUlyssesAttentionFloat8(torch.autograd.Function):
    """Ulysses attention with float8-quantized all-to-all for Q, V and O.

    K stays in its original dtype (see NOTE below) to protect the softmax.
    The statement order (launch K first, wait on K last) is deliberate: it
    maximizes overlap of the plain K transfer with the quant/dequant work of
    the float8 Q/V transfers. Forward-only: backward raises.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
    ):
        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
        group = ulysses_mesh.get_group()

        # Saved for a future backward implementation (backward currently raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx._parallel_config = _parallel_config

        metadata = _prepare_ulysses_comm_metadata(query)
        # Use async all_to_all to overlap comm and quant/dequant computation
        # NOTE: Currently, we choose to keep K in FP16/BF16 format to keep higher
        # precision during softmax computation: Softmax(Q@K^T) which is sensitive to
        # numerical instability. So we only use float8 all_to_all for Q, V and O.
        # TODO: We should relax this design and support all QKV in float8 format while
        # the K-per-channel-smooth (e.g., in SageAttention) is used to improve numerical
        # stability. Using this smooth technique before All-to-All on K may introduce
        # extra AllReduce communication overhead.
        key_wait = _all_to_all_single_qkv_async(key, group, **metadata)
        query_wait = _all_to_all_single_qkv_fp8_async(query, group, **metadata)
        value_wait = _all_to_all_single_qkv_fp8_async(value, group, **metadata)

        query = query_wait()  # type: torch.Tensor
        value = value_wait()  # type: torch.Tensor
        key = key_wait()  # type: torch.Tensor

        # Local attention over the full sequence with this rank's head shard.
        out = forward_op(
            ctx,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            _save_ctx=False,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse, *_ = out

        out_wait = _all_to_all_single_o_fp8_async(out, group, **metadata)

        if return_lse:
            # NOTE: DON'T use float8 all_to_all for out and lse, as it may
            # cause more numerical instability.
            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
            lse_wait = _all_to_all_single_o_async(lse, group, **metadata)
            out = out_wait()  # type: torch.Tensor
            lse = lse_wait()  # type: torch.Tensor
            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
        else:
            out = out_wait()  # type: torch.Tensor
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError(
            "Backward pass for Ulysses Attention Float8 in cache-dit is not implemented yet."
        )
+
+
class _TemplatedUlyssesAnythingAttention(torch.autograd.Function):
    """'Ulysses Anything' attention: any sequence length and any head count.

    Uses the `_any_` comm primitives, which handle uneven shards across
    ranks; otherwise follows the same all-to-all dataflow as
    ``_TemplatedUlyssesAttention``. Forward-only: backward raises.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
        **kwargs,
    ):
        # TODO: Should we only use float8 all_to_all for VO not QK? The softmax in
        # QK may cause more numerical instability than P@V matrix multiplication.
        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
        group = ulysses_mesh.get_group()

        # Saved for a future backward implementation (backward currently raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx._parallel_config = _parallel_config

        metadata = _prepare_ulysses_comm_metadata(query)
        # Launch all three all-to-alls before waiting on any so they can overlap.
        query_wait = _all_to_all_single_any_qkv_async(query, group, **metadata)
        key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata)
        value_wait = _all_to_all_single_any_qkv_async(value, group, **metadata)

        query = query_wait()  # type: torch.Tensor
        key = key_wait()  # type: torch.Tensor
        value = value_wait()  # type: torch.Tensor

        # Local attention over the full sequence with this rank's head shard.
        out = forward_op(
            ctx,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            _save_ctx=False,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse, *_ = out

        # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
        out_wait = _all_to_all_single_any_o_async(out, group, **metadata)

        if return_lse:
            # lse: (B, S_Q_GLOBAL, H_LOCAL)
            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
            lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata)
            out = out_wait()  # type: torch.Tensor
            lse = lse_wait()  # type: torch.Tensor
            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
        else:
            out = out_wait()  # type: torch.Tensor
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError(
            "Backward pass for Ulysses Anything Attention in cache-dit is not implemented yet."
        )
+
+
class _TemplatedUlyssesAnythingAttentionFloat8(torch.autograd.Function):
    """'Ulysses Anything' attention with float8-quantized comm for Q, V and O.

    Combines the uneven-shard `_any_` primitives with float8 quantization;
    K stays in its original dtype (see NOTE below). As in the non-anything
    float8 variant, K is launched first and waited on last to overlap its
    plain transfer with the float8 quant/dequant work. Forward-only.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        is_causal: bool,
        scale: Optional[float],
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: Optional["ParallelConfig"] = None,
        **kwargs,
    ):
        # TODO: Should we only use float8 all_to_all for VO not QK? The softmax in
        # QK may cause more numerical instability than P@V matrix multiplication.
        ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
        group = ulysses_mesh.get_group()

        # Saved for a future backward implementation (backward currently raises).
        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx._parallel_config = _parallel_config

        metadata = _prepare_ulysses_comm_metadata(query)
        # Use async all_to_all to overlap comm and quant/dequant computation
        # NOTE: Currently, we choose to keep K in FP16/BF16 format to keep higher
        # precision during softmax computation: Softmax(Q@K^T) which is sensitive to
        # numerical instability. So we only use float8 all_to_all for Q, V and O.
        # TODO: We should relax this design and support all QKV in float8 format while
        # the K-per-channel-smooth (e.g., in SageAttention) is used to improve numerical
        # stability. Using this smooth technique before All-to-All on K may introduce
        # extra AllReduce communication overhead.
        key_wait = _all_to_all_single_any_qkv_async(key, group, **metadata)
        query_wait = _all_to_all_single_any_qkv_fp8_async(query, group, **metadata)
        value_wait = _all_to_all_single_any_qkv_fp8_async(value, group, **metadata)

        query = query_wait()  # type: torch.Tensor
        value = value_wait()  # type: torch.Tensor
        key = key_wait()  # type: torch.Tensor

        # Local attention over the full sequence with this rank's head shard.
        out = forward_op(
            ctx,
            query,
            key,
            value,
            attn_mask,
            dropout_p,
            is_causal,
            scale,
            enable_gqa,
            return_lse,
            _save_ctx=False,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse, *_ = out

        out_wait = _all_to_all_single_any_o_fp8_async(out, group, **metadata)

        if return_lse:
            # NOTE: DON'T use float8 all_to_all for out and lse, as it may
            # cause more numerical instability.
            lse = lse.unsqueeze(-1)  # (B, S_Q_GLOBAL, H_LOCAL, D=1)
            lse_wait = _all_to_all_single_any_o_async(lse, group, **metadata)
            out = out_wait()  # type: torch.Tensor
            lse = lse_wait()  # type: torch.Tensor
            lse = lse.squeeze(-1).contiguous()  # (B, S_Q_LOCAL, H_GLOBAL)
        else:
            out = out_wait()  # type: torch.Tensor
            lse = None

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError(
            "Backward pass for Ulysses Anything Attention Float8 in cache-dit is not implemented yet."
        )
+
+
+@functools.lru_cache(maxsize=64)
+def _fill_gather_shapes(
+ shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int
+) -> List[List[int]]:
+ gather_shapes = []
+ for i in range(world_size):
+ # WARN: deepcopy to avoid modifying the original shape
+ rank_shape = list(copy.deepcopy(shape))
+ rank_shape[dim] = gather_dims[i]
+ gather_shapes.append(rank_shape)
+ return gather_shapes
+
+
@torch.compiler.allow_in_graph
def _all_gather_anything(  # noqa: F811
    tensor: torch.Tensor,
    dim: int,
    group: dist.device_mesh.DeviceMesh,
) -> torch.Tensor:
    """All-gather tensors whose size along `dim` may differ across ranks.

    Per-rank sizes along `dim` are first exchanged via `_gather_size_by_comm`,
    then `dist.all_gather` fills one correctly-shaped buffer per rank, and the
    results are concatenated along `dim`.
    """
    _, world_size = _get_rank_world_size(group)
    # all_gather requires a contiguous input buffer.
    tensor = tensor.contiguous()
    shape = tensor.shape
    rank_dim = shape[dim]
    # Collect each rank's size along `dim` (they may be uneven).
    gather_dims = _gather_size_by_comm(rank_dim, group)

    # NOTE: The `if` branch will introduce graph break for torch.compile,
    # so, we choose to disable the even split optimization for now.

    # Tuples keep the arguments hashable for the lru_cache on _fill_gather_shapes.
    gather_shapes = _fill_gather_shapes(
        tuple(shape),
        tuple(gather_dims),
        dim,
        world_size,
    )

    gathered_tensors = [
        torch.empty(
            shape,
            device=tensor.device,
            dtype=tensor.dtype,
        )
        for shape in gather_shapes
    ]

    dist.all_gather(gathered_tensors, tensor, group=group)
    gathered_tensor = torch.cat(gathered_tensors, dim=dim)
    return gathered_tensor
+
+
# NOTE: dist.all_gather, Gathers tensors from the whole group in a list.
# Complex and uneven sized tensors are supported.
class AllGatherAnythingFunction(torch.autograd.Function):
    """Autograd-aware uneven all-gather along an arbitrary dimension.

    Forward gathers and concatenates the per-rank shards; backward slices the
    incoming gradient back down to this rank's shard (the other positional
    inputs, `dim` and `group`, receive no gradient).
    """

    @staticmethod
    def forward(
        ctx,
        tensor: torch.Tensor,
        dim: int,
        group: dist.device_mesh.DeviceMesh,
    ):
        # Stash what backward needs to locate this rank's slice of the gradient.
        ctx.dim = dim
        ctx.group = group
        ctx.world_size = dist.get_world_size(group)
        ctx.rank = dist.get_rank(group)
        gathered_tensor = _all_gather_anything(tensor, dim, group)
        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # NOTE: We use `tensor_split` instead of chunk, because the `chunk`
        # function may return fewer than the specified number of chunks!
        grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
        # None for the non-tensor inputs (dim, group).
        return grad_splits[ctx.rank], None, None
+
+
# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
# function may return fewer than the specified number of chunks! For example,
# x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks:
# (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to
# inconsistencies when sharding tensors across multiple devices. In contrast,
# tensor_split will always return the specified number of chunks, the last chunk
# may be smaller if the tensor size is not divisible by the number of chunks.
# For example, torch.tensor_split(x, 4) will return 4 chunks:
# (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])).
# This module-level classmethod is assigned over EquipartitionSharder.shard
# by enable_ulysses_anything(); functools.wraps keeps the original metadata.
@classmethod
@functools.wraps(EquipartitionSharder.shard)
def shard_anything(
    cls: EquipartitionSharder,
    tensor: torch.Tensor,
    dim: int,
    mesh: dist.device_mesh.DeviceMesh,
    **kwargs,
) -> torch.Tensor:
    # A rank would otherwise receive an empty shard, which the attention
    # kernels cannot handle.
    assert tensor.size()[dim] >= mesh.size(), (
        f"Cannot shard tensor of size {tensor.size()} along dim {dim} "
        f"across mesh of size {mesh.size()}."
    )
    # Return only this rank's split (uneven last chunk is allowed).
    return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
+
+
# NOTE: We use AllGatherAnythingFunction to support gathering
# tensors with complex and uneven sizes across all ranks. It handles the
# case where the tensor size (the seq_len of hidden_states) along the
# specified dimension is not divisible by the number of ranks in the mesh.
# This module-level classmethod is assigned over EquipartitionSharder.unshard
# by enable_ulysses_anything(); functools.wraps keeps the original metadata.
@classmethod
@functools.wraps(EquipartitionSharder.unshard)
def unshard_anything(
    cls,
    tensor: torch.Tensor,
    dim: int,
    mesh: torch.distributed.device_mesh.DeviceMesh,
    **kwargs,
) -> torch.Tensor:
    # The collective requires a contiguous buffer.
    tensor = tensor.contiguous()
    tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
    return tensor
+
+
# Environment variable flags for Ulysses Attention variants in cache-dit.
def is_ulysses_heads_no_padding() -> bool:
    """Whether uneven-heads all-to-all should run without padding heads
    (controlled by the CACHE_DIT_UNEVEN_HEADS_COMM_NO_PAD env flag)."""
    return ENV.CACHE_DIT_UNEVEN_HEADS_COMM_NO_PAD
+
+
def enable_ulysses_anything(**kwargs):
    """Enable the experimental 'Ulysses Anything' attention mode.

    Sets the ENV flag and swaps EquipartitionSharder.shard/unshard for the
    uneven-size-aware shard_anything/unshard_anything implementations.
    Never raises: on failure the flag is rolled back and the error logged.
    """
    try:
        if ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING:
            # Already enabled: only repair the sharder hooks if some other
            # code has overwritten them since we last set them.
            if EquipartitionSharder.shard != shard_anything:
                EquipartitionSharder.shard = shard_anything
                EquipartitionSharder.unshard = unshard_anything
                logger.warning(
                    "Ulysses Anything Attention is already enabled in cache-dit. "
                    "but EquipartitionSharder.shard/unshard is not set correctly, "
                    "resetting it to the correct shard/unshard_anything function."
                )
            return

        ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING = True

        logger.warning(
            "Ulysses Anything Attention is enabled in cache-dit. "
            "Please note that this is an experimental feature and "
            "may not be fully tested."
        )

        # Install the uneven-size-aware sharder hooks required by
        # TemplatedUlyssesAnythingAttention.
        if EquipartitionSharder.shard != shard_anything:
            EquipartitionSharder.shard = shard_anything
            EquipartitionSharder.unshard = unshard_anything
            logger.info(
                "EquipartitionSharder.shard/unshard is set to shard/unshard_anything function "
                "for Ulysses Anything Attention."
            )
    except Exception as e:
        # Best-effort: roll the flag back and report instead of raising.
        ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING = False
        logger.error(f"Failed to enable Ulysses Anything Attention in cache-dit due to error: {e}")
+
+
def is_ulysses_anything_enabled(**kwargs) -> bool:
    """Whether 'Ulysses Anything' mode is currently enabled."""
    return ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING
+
+
def disable_ulysses_anything(**kwargs) -> None:
    """Disable 'Ulysses Anything' mode.

    NOTE(review): this only clears the flag; EquipartitionSharder.shard/unshard
    are left pointing at shard/unshard_anything — confirm that is intended.
    """
    ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING = False
    logger.info("Ulysses Anything Attention is manually disabled in cache-dit.")
+
+
+# Float8 flags for Ulysses/Ulysses Anything Attention
+def _enable_ulysses_anything_float8(**kwargs):
+ try:
+ if ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8:
+ # function for TemplatedUlyssesAnythingAttention.
+ if EquipartitionSharder.shard != shard_anything:
+ EquipartitionSharder.shard = shard_anything
+ EquipartitionSharder.unshard = unshard_anything
+ logger.warning(
+ "Ulysses Anything Attention Float8 is already enabled in cache-dit. "
+ "but EquipartitionSharder.shard/unshard is not set correctly, "
+ "resetting it to the correct shard/unshard_anything function."
+ )
+ return
+
+ ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8 = True
+
+ logger.warning(
+ "Ulysses Anything Attention Float8 is enabled in cache-dit. "
+ "Please note that this is an experimental feature and "
+ "may not be fully tested."
+ )
+
+ # Ensure the EquipartitionSharder uses our modified shard_anything
+ # function for TemplatedUlyssesAnythingAttention.
+ if EquipartitionSharder.shard != shard_anything:
+ EquipartitionSharder.shard = shard_anything
+ EquipartitionSharder.unshard = unshard_anything
+ logger.info(
+ "EquipartitionSharder.shard/unshard is set to shard/unshard_anything function "
+ "for Ulysses Anything Attention Float8."
+ )
+ except Exception as e:
+ ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8 = False
+ logger.error(
+ f"Failed to enable Ulysses Anything Attention Float8 in cache-dit due to error: {e}"
+ )
+ pass
+
+
def _is_ulysses_anything_float8_enabled(**kwargs) -> bool:
    """Whether float8 comm for 'Ulysses Anything' mode is enabled."""
    return ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8
+
+
def _disable_ulysses_anything_float8(**kwargs) -> None:
    """Disable float8 comm for 'Ulysses Anything' mode.

    NOTE(review): the original annotation said `-> bool` but nothing is
    returned; corrected to `-> None`.
    """
    ENV.CACHE_DIT_ENABELD_ULYSSES_ANYTHING_FLOAT8 = False
    logger.info("Ulysses Anything Attention Float8 is manually disabled in cache-dit.")
+
+
def enable_ulysses_float8(**kwargs) -> None:
    """Enable float8 all-to-all communication for Ulysses Attention.

    If 'Ulysses Anything' mode is active, the anything-specific float8 flag
    is set instead of the plain Ulysses one.
    """

    # Check if Ulysses Anything Attention is already enabled
    if is_ulysses_anything_enabled():
        _enable_ulysses_anything_float8()
        return

    ENV.CACHE_DIT_ENABELD_ULYSSES_FLOAT8 = True
    logger.warning(
        "Ulysses Attention Float8 is enabled in cache-dit. "
        "Please note that this is an experimental feature and "
        "may not be fully tested."
    )
+
+
def is_ulysses_float8_enabled(**kwargs) -> bool:
    """Whether float8 comm is enabled, via either the plain Ulysses flag or
    the 'Ulysses Anything' float8 flag."""
    return ENV.CACHE_DIT_ENABELD_ULYSSES_FLOAT8 or _is_ulysses_anything_float8_enabled()
+
+
def disable_ulysses_float8(**kwargs) -> None:
    """Disable float8 all-to-all communication for Ulysses Attention,
    including the 'Ulysses Anything' float8 flag when that mode is active.

    NOTE(review): the original annotation said `-> bool` but nothing is
    returned; corrected to `-> None`.
    """
    ENV.CACHE_DIT_ENABELD_ULYSSES_FLOAT8 = False
    logger.info("Ulysses Attention Float8 is manually disabled in cache-dit.")
    if is_ulysses_anything_enabled():
        _disable_ulysses_anything_float8()
diff --git a/src/cache_dit/parallelism/autoencoders/__init__.py b/src/cache_dit/parallelism/autoencoders/__init__.py
new file mode 100644
index 000000000..134f2258f
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/__init__.py
@@ -0,0 +1 @@
+from .dispatch import maybe_enable_parallelism_for_auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/__init__.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/__init__.py
new file mode 100644
index 000000000..94dbd76c4
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/__init__.py
@@ -0,0 +1,39 @@
+import torch
+from typing import Optional
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+try:
+ from .dp_plan_registers import AutoEncoderDataParallelismPlannerRegister
+ from .dp_planners import _activate_auto_encoder_dp_planners
+
+ _activate_auto_encoder_dp_planners()
+except ImportError as e:
+ raise ImportError(e)
+
+logger = init_logger(__name__)
+
+
def maybe_enable_data_parallelism(
    auto_encoder: torch.nn.Module,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Apply the registered data-parallel plan for `auto_encoder`, if any
    parallelism is configured; otherwise return the module unchanged."""
    assert isinstance(
        auto_encoder, torch.nn.Module
    ), f"auto_encoder must be an instance of torch.nn.Module, but got {type(auto_encoder)}"

    if parallelism_config is None:
        return auto_encoder

    # The backend is intentionally not validated here: the auto encoder may
    # use a different parallelism backend than the transformer.
    extra_parallel_kwargs = parallelism_config.parallel_kwargs
    if extra_parallel_kwargs is None:
        extra_parallel_kwargs = {}

    planner_cls = AutoEncoderDataParallelismPlannerRegister.get_planner(auto_encoder)
    return planner_cls().apply(
        auto_encoder=auto_encoder,
        parallelism_config=parallelism_config,
        **extra_parallel_kwargs,
    )
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl.py
new file mode 100644
index 000000000..c17543b98
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl.py
@@ -0,0 +1,205 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention.git
+import functools
+
+import torch
+import torch.distributed as dist
+from diffusers import AutoencoderKL
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+logger = init_logger(__name__)
+
+
@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKL")
class AutoencoderKLDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
    """Data-parallel tiling planner for diffusers' AutoencoderKL.

    Tiles of the input are distributed round-robin across ranks; rank 0
    collects all encoded/decoded tiles, blends the seams, and the final
    result is chain-forwarded rank-to-rank so every rank returns it.
    """

    def apply(
        self,
        auto_encoder: torch.nn.Module,
        parallelism_config: ParallelismConfig,
        **kwargs,
    ) -> torch.nn.Module:
        """Build a 1-D device mesh of `auto_encoder_world_size` ranks and
        monkey-patch the VAE's tiled encode/decode to run data-parallel."""
        assert isinstance(
            auto_encoder, AutoencoderKL
        ), "AutoencoderKLDataParallelismPlanner can only be applied to AutoencoderKL"
        auto_encoder_world_size = parallelism_config.auto_encoder_world_size
        device_type = torch.accelerator.current_accelerator().type
        dp_mesh = dist.init_device_mesh(
            device_type=device_type,
            mesh_shape=[auto_encoder_world_size],
        )

        auto_encoder = self.parallelize_tiling(
            auto_encoder=auto_encoder,
            dp_mesh=dp_mesh,
        )

        return auto_encoder

    def parallelize_tiling(
        self,
        auto_encoder: AutoencoderKL,
        dp_mesh: dist.DeviceMesh,
    ):
        """Replace `_tiled_encode` / `tiled_decode` on this instance with
        data-parallel versions that split tiles round-robin across ranks."""
        group = dp_mesh.get_group()
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)

        # Tiling must be on for the patched tiled paths to be exercised.
        auto_encoder.enable_tiling()

        @functools.wraps(auto_encoder.__class__._tiled_encode)
        def new_tiled_encode(
            self: AutoencoderKL,
            x: torch.Tensor,
            *args,
            **kwargs,
        ):
            overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
            blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
            row_limit = self.tile_latent_min_size - blend_extent

            # Split the image into 512x512 tiles and encode them separately.
            # Tiles are assigned round-robin: this rank encodes tile `count`
            # only when count % world_size == rank; others stay None.
            count = 0
            rows = []
            for i in range(0, x.shape[2], overlap_size):
                row = []
                for j in range(0, x.shape[3], overlap_size):
                    if count % world_size == rank:
                        tile = x[
                            :,
                            :,
                            i : i + self.tile_sample_min_size,
                            j : j + self.tile_sample_min_size,
                        ]
                        tile = self.encoder(tile)
                        if self.config.use_quant_conv:
                            tile = self.quant_conv(tile)
                    else:
                        tile = None
                    row.append(tile)
                    count += 1
                rows.append(row)

            # Rank 0 receives every tile it did not compute, in global tile
            # order; other ranks send their tiles to rank 0 in the same order
            # (per-pair p2p ordering keeps sends and recvs matched).
            if rank == 0:
                count = 0
                for i in range(len(rows)):
                    for j in range(len(rows[i])):
                        if count % world_size != rank:
                            rows[i][j] = recv_tensor(
                                count % world_size, group, device=x.device, dtype=x.dtype
                            )
                        count += 1
            else:
                for i in range(len(rows)):
                    for j in range(len(rows[i])):
                        tile = rows[i][j]
                        if tile is not None:
                            send_tensor(tile, 0, group)

            # Rank 0 blends seams and assembles the full latent; other ranks
            # obtain it via the chain-forward below.
            if rank == 0:
                result_rows = []
                for i, row in enumerate(rows):
                    result_row = []
                    for j, tile in enumerate(row):
                        # blend the above tile and the left tile
                        # to the current tile and add the current tile to the result row
                        if i > 0:
                            tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                        if j > 0:
                            tile = self.blend_h(row[j - 1], tile, blend_extent)
                        result_row.append(tile[:, :, :row_limit, :row_limit])
                    result_rows.append(torch.cat(result_row, dim=3))

                enc = torch.cat(result_rows, dim=2)
            else:
                enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
            # Chain broadcast: rank r forwards the assembled result to r+1 so
            # that every rank returns the same tensor.
            if rank < world_size - 1:
                send_tensor(enc, rank + 1, group)
            return enc

        # Bind as an instance method so patching affects only this VAE object.
        auto_encoder._tiled_encode = new_tiled_encode.__get__(auto_encoder)

        @functools.wraps(auto_encoder.__class__.tiled_decode)
        def new_tiled_decode(
            self: AutoencoderKL,
            z: torch.Tensor,
            *args,
            return_dict: bool = False,
            **kwargs,
        ):
            overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
            blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
            row_limit = self.tile_sample_min_size - blend_extent

            # Split z into overlapping 64x64 tiles and decode them separately.
            # The tiles have an overlap to avoid seams between tiles.
            # Same round-robin assignment as the encode path above.
            count = 0
            rows = []
            for i in range(0, z.shape[2], overlap_size):
                row = []
                for j in range(0, z.shape[3], overlap_size):
                    if count % world_size == rank:
                        tile = z[
                            :,
                            :,
                            i : i + self.tile_latent_min_size,
                            j : j + self.tile_latent_min_size,
                        ]
                        if self.config.use_post_quant_conv:
                            tile = self.post_quant_conv(tile)
                        decoded = self.decoder(tile)
                    else:
                        decoded = None
                    row.append(decoded)
                    count += 1
                rows.append(row)

            # Gather all decoded tiles onto rank 0 (see encode path for the
            # ordering argument).
            if rank == 0:
                count = 0
                for i in range(len(rows)):
                    for j in range(len(rows[i])):
                        if count % world_size != rank:
                            rows[i][j] = recv_tensor(
                                count % world_size, group, device=z.device, dtype=z.dtype
                            )
                        count += 1
            else:
                for i in range(len(rows)):
                    for j in range(len(rows[i])):
                        decoded = rows[i][j]
                        if decoded is not None:
                            send_tensor(decoded, 0, group)

            if rank == 0:
                result_rows = []
                for i, row in enumerate(rows):
                    result_row = []
                    for j, tile in enumerate(row):
                        # blend the above tile and the left tile
                        # to the current tile and add the current tile to the result row
                        if i > 0:
                            tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                        if j > 0:
                            tile = self.blend_h(row[j - 1], tile, blend_extent)
                        result_row.append(tile[:, :, :row_limit, :row_limit])
                    result_rows.append(torch.cat(result_row, dim=3))

                dec = torch.cat(result_rows, dim=2)
            else:
                dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
            # Chain broadcast of the assembled sample to all ranks.
            if rank < world_size - 1:
                send_tensor(dec, rank + 1, group)
            if not return_dict:
                return (dec,)

            return DecoderOutput(sample=dec)

        # Bind as an instance method so patching affects only this VAE object.
        auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)

        return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_flux2.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_flux2.py
new file mode 100644
index 000000000..f86a4d32e
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_flux2.py
@@ -0,0 +1,204 @@
+import functools
+
+import torch
+import torch.distributed as dist
+from diffusers import AutoencoderKLFlux2
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+logger = init_logger(__name__)
+
+
+@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKLFlux2")
+class AutoencoderKLFlux2DataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+    """Data-parallel tiling planner for ``AutoencoderKLFlux2``.
+
+    Patches the VAE's tiled encode/decode so spatial tiles are assigned
+    round-robin across the ranks of a 1-D device mesh, gathered on rank 0
+    for seam blending, and the blended result is relayed rank-by-rank so
+    every rank returns the same tensor.
+    """
+
+    def apply(
+        self,
+        auto_encoder: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        # Validate the module type, build a 1-D device mesh of
+        # `auto_encoder_world_size` ranks, then install the parallel
+        # tiled encode/decode methods on this instance.
+        assert isinstance(
+            auto_encoder, AutoencoderKLFlux2
+        ), "AutoencoderKLFlux2DataParallelismPlanner can only be applied to AutoencoderKLFlux2"
+        auto_encoder_world_size = parallelism_config.auto_encoder_world_size
+        device_type = torch.accelerator.current_accelerator().type
+        dp_mesh = dist.init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[auto_encoder_world_size],
+        )
+
+        auto_encoder = self.parallelize_tiling(
+            auto_encoder=auto_encoder,
+            dp_mesh=dp_mesh,
+        )
+
+        return auto_encoder
+
+    def parallelize_tiling(
+        self,
+        auto_encoder: AutoencoderKLFlux2,
+        dp_mesh: dist.DeviceMesh,
+    ):
+        # Monkey-patch `_tiled_encode` / `tiled_decode` on this *instance*
+        # (bound via `__get__`), leaving the class untouched.
+        group = dp_mesh.get_group()
+        world_size = dist.get_world_size(group)
+        rank = dist.get_rank(group)
+
+        auto_encoder.enable_tiling()
+
+        @functools.wraps(auto_encoder.__class__._tiled_encode)
+        def new_tiled_encode(
+            self: AutoencoderKLFlux2,
+            x: torch.Tensor,
+            *args,
+            **kwargs,
+        ):
+            overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+            blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+            row_limit = self.tile_latent_min_size - blend_extent
+
+            # Split the image into 512x512 tiles and encode them separately.
+            # Tiles are assigned round-robin: this rank encodes tile indices
+            # with `count % world_size == rank`; other slots stay None.
+            count = 0
+            rows = []
+            for i in range(0, x.shape[2], overlap_size):
+                row = []
+                for j in range(0, x.shape[3], overlap_size):
+                    if count % world_size == rank:
+                        tile = x[
+                            :,
+                            :,
+                            i : i + self.tile_sample_min_size,
+                            j : j + self.tile_sample_min_size,
+                        ]
+                        tile = self.encoder(tile)
+                        if self.config.use_quant_conv:
+                            tile = self.quant_conv(tile)
+                    else:
+                        tile = None
+                    row.append(tile)
+                    count += 1
+                rows.append(row)
+
+            # Gather: rank 0 receives every tile it did not compute itself.
+            # Both sides walk tiles in the same deterministic order, so the
+            # point-to-point sends/receives pair up without tags.
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=x.device, dtype=x.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        tile = rows[i][j]
+                        if tile is not None:
+                            send_tensor(tile, 0, group)
+
+            # Rank 0 blends overlapping tile borders and assembles the latent.
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # blend the above tile and the left tile
+                        # to the current tile and add the current tile to the result row
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_extent)
+                        result_row.append(tile[:, :, :row_limit, :row_limit])
+                    result_rows.append(torch.cat(result_row, dim=3))
+
+                enc = torch.cat(result_rows, dim=2)
+            else:
+                enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
+            # Relay the final latent down the rank chain (0 -> 1 -> ... -> N-1)
+            # so every rank returns an identical result.
+            if rank < world_size - 1:
+                send_tensor(enc, rank + 1, group)
+            return enc
+
+        auto_encoder._tiled_encode = new_tiled_encode.__get__(auto_encoder)
+
+        @functools.wraps(auto_encoder.__class__.tiled_decode)
+        def new_tiled_decode(
+            self: AutoencoderKLFlux2,
+            z: torch.Tensor,
+            *args,
+            # NOTE(review): defaults to False here while diffusers' tiled_decode
+            # conventionally defaults to True — confirm callers pass it explicitly.
+            return_dict: bool = False,
+            **kwargs,
+        ):
+            overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+            blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+            row_limit = self.tile_sample_min_size - blend_extent
+
+            # Split z into overlapping 64x64 tiles and decode them separately.
+            # The tiles have an overlap to avoid seams between tiles.
+            count = 0
+            rows = []
+            for i in range(0, z.shape[2], overlap_size):
+                row = []
+                for j in range(0, z.shape[3], overlap_size):
+                    if count % world_size == rank:
+                        tile = z[
+                            :,
+                            :,
+                            i : i + self.tile_latent_min_size,
+                            j : j + self.tile_latent_min_size,
+                        ]
+                        if self.config.use_post_quant_conv:
+                            tile = self.post_quant_conv(tile)
+                        decoded = self.decoder(tile)
+                    else:
+                        decoded = None
+                    row.append(decoded)
+                    count += 1
+                rows.append(row)
+
+            # Gather decoded tiles on rank 0 (same deterministic pairing as encode).
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=z.device, dtype=z.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        decoded = rows[i][j]
+                        if decoded is not None:
+                            send_tensor(decoded, 0, group)
+
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # blend the above tile and the left tile
+                        # to the current tile and add the current tile to the result row
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_extent)
+                        result_row.append(tile[:, :, :row_limit, :row_limit])
+                    result_rows.append(torch.cat(result_row, dim=3))
+
+                dec = torch.cat(result_rows, dim=2)
+            else:
+                dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
+            # Relay the assembled sample through the rank chain.
+            if rank < world_size - 1:
+                send_tensor(dec, rank + 1, group)
+            if not return_dict:
+                return (dec,)
+
+            return DecoderOutput(sample=dec)
+
+        auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)
+
+        return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_hunyuanvideo.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_hunyuanvideo.py
new file mode 100644
index 000000000..d13c6d930
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_hunyuanvideo.py
@@ -0,0 +1,242 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention.git
+import functools
+
+import torch
+import torch.distributed as dist
+from diffusers import AutoencoderKLHunyuanVideo
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+logger = init_logger(__name__)
+
+
+@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKLHunyuanVideo")
+class AutoencoderKLHunyuanVideoDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+ def apply(
+ self,
+ auto_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ auto_encoder, AutoencoderKLHunyuanVideo
+ ), "AutoencoderKLHunyuanVideoDataParallelismPlanner can only be applied to AutoencoderKLHunyuanVideo"
+ auto_encoder_world_size = parallelism_config.auto_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ dp_mesh = dist.init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[auto_encoder_world_size],
+ )
+
+ auto_encoder = self.parallelize_tiling(
+ auto_encoder=auto_encoder,
+ dp_mesh=dp_mesh,
+ )
+
+ return auto_encoder
+
+ def parallelize_tiling(
+ self,
+ auto_encoder: AutoencoderKLHunyuanVideo,
+ dp_mesh: dist.DeviceMesh,
+ ):
+ group = dp_mesh.get_group()
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+
+ auto_encoder.enable_tiling()
+
+ @functools.wraps(auto_encoder.__class__.tiled_encode)
+ def new_tiled_encode(
+ self: AutoencoderKLHunyuanVideo,
+ x: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ batch_size, num_channels, num_frames, height, width = x.shape
+ latent_height = height // self.spatial_compression_ratio
+ latent_width = width // self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = (
+ self.tile_sample_stride_height // self.spatial_compression_ratio
+ )
+ tile_latent_stride_width = (
+ self.tile_sample_stride_width // self.spatial_compression_ratio
+ )
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ if hasattr(self, "tile_sample_min_height"):
+ tile_sample_min_height = self.tile_sample_min_height
+ else:
+ tile_sample_min_height = self.tile_sample_min_size
+
+ if hasattr(self, "tile_sample_min_width"):
+ tile_sample_min_width = self.tile_sample_min_width
+ else:
+ tile_sample_min_width = self.tile_sample_min_size
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ count = 0
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+ if count % world_size == rank:
+ tile = x[
+ :, :, :, i : i + tile_sample_min_height, j : j + tile_sample_min_width
+ ]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ else:
+ tile = None
+ row.append(tile)
+ count += 1
+ rows.append(row)
+
+ if rank == 0:
+ count = 0
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ if count % world_size != rank:
+ rows[i][j] = recv_tensor(
+ count % world_size, group, device=x.device, dtype=x.dtype
+ )
+ count += 1
+ else:
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ tile = rows[i][j]
+ if tile is not None:
+ send_tensor(tile, 0, group)
+
+ if rank == 0:
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(
+ tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]
+ )
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ else:
+ enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
+ if rank < world_size - 1:
+ send_tensor(enc, rank + 1, group)
+ return enc
+
+ auto_encoder.tiled_encode = new_tiled_encode.__get__(auto_encoder)
+
+ @functools.wraps(auto_encoder.__class__.tiled_decode)
+ def new_tiled_decode(
+ self: AutoencoderKLHunyuanVideo,
+ z: torch.Tensor,
+ *args,
+ return_dict: bool = False,
+ **kwargs,
+ ):
+ batch_size, num_channels, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = (
+ self.tile_sample_stride_height // self.spatial_compression_ratio
+ )
+ tile_latent_stride_width = (
+ self.tile_sample_stride_width // self.spatial_compression_ratio
+ )
+
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ count = 0
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+ if count % world_size == rank:
+ tile = z[
+ :, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width
+ ]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ else:
+ decoded = None
+ row.append(decoded)
+ count += 1
+ rows.append(row)
+
+ if rank == 0:
+ count = 0
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ if count % world_size != rank:
+ rows[i][j] = recv_tensor(
+ count % world_size, group, device=z.device, dtype=z.dtype
+ )
+ count += 1
+ else:
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ decoded = rows[i][j]
+ if decoded is not None:
+ send_tensor(decoded, 0, group)
+
+ if rank == 0:
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(
+ tile[
+ :,
+ :,
+ :,
+ : self.tile_sample_stride_height,
+ : self.tile_sample_stride_width,
+ ]
+ )
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+ else:
+ dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
+ if rank < world_size - 1:
+ send_tensor(dec, rank + 1, group)
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(dec, dec)
+
+ auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)
+
+ return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_ltx2.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_ltx2.py
new file mode 100644
index 000000000..bcbe412bd
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_ltx2.py
@@ -0,0 +1,231 @@
+import functools
+
+import torch
+import torch.distributed as dist
+from typing import Optional
+from diffusers import AutoencoderKLLTX2Video
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+
+@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKLLTX2Video")
+class AutoencoderKLLTX2VideoDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+    """Data-parallel tiling planner for ``AutoencoderKLLTX2Video``.
+
+    Patches the VAE's tiled encode/decode so spatial tiles are assigned
+    round-robin across the ranks of a 1-D device mesh, gathered on rank 0
+    for seam blending, and the blended result is relayed rank-by-rank so
+    every rank returns the same tensor.
+    """
+
+    def apply(
+        self,
+        auto_encoder: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        # Validate the module type, build a 1-D device mesh of
+        # `auto_encoder_world_size` ranks, then install the parallel
+        # tiled encode/decode methods on this instance.
+        assert isinstance(
+            auto_encoder, AutoencoderKLLTX2Video
+        ), "AutoencoderKLLTX2VideoDataParallelismPlanner can only be applied to AutoencoderKLLTX2Video"
+
+        auto_encoder_world_size = parallelism_config.auto_encoder_world_size
+        device_type = torch.accelerator.current_accelerator().type
+        dp_mesh = dist.init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[auto_encoder_world_size],
+        )
+
+        auto_encoder = self.parallelize_tiling(
+            auto_encoder=auto_encoder,
+            dp_mesh=dp_mesh,
+        )
+
+        return auto_encoder
+
+    def parallelize_tiling(
+        self,
+        auto_encoder: AutoencoderKLLTX2Video,
+        dp_mesh: dist.DeviceMesh,
+    ):
+        # Monkey-patch `tiled_encode` / `tiled_decode` on this *instance*
+        # (bound via `__get__`), leaving the class untouched.
+        group = dp_mesh.get_group()
+        world_size = dist.get_world_size(group)
+        rank = dist.get_rank(group)
+
+        auto_encoder.enable_tiling()
+
+        @functools.wraps(auto_encoder.__class__.tiled_encode)
+        def new_tiled_encode(
+            self: AutoencoderKLLTX2Video,
+            x: torch.Tensor,
+            causal=None,
+            *args,
+            **kwargs,
+        ):
+            _, _, num_frames, height, width = x.shape
+            latent_height = height // self.spatial_compression_ratio
+            latent_width = width // self.spatial_compression_ratio
+
+            tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+            tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+            tile_latent_stride_height = (
+                self.tile_sample_stride_height // self.spatial_compression_ratio
+            )
+            tile_latent_stride_width = (
+                self.tile_sample_stride_width // self.spatial_compression_ratio
+            )
+
+            blend_height = tile_latent_min_height - tile_latent_stride_height
+            blend_width = tile_latent_min_width - tile_latent_stride_width
+
+            # Split x into overlapping tiles; tiles are assigned round-robin:
+            # this rank encodes tile indices with `count % world_size == rank`,
+            # other slots stay None.
+            count = 0
+            rows = []
+            for i in range(0, height, self.tile_sample_stride_height):
+                row = []
+                for j in range(0, width, self.tile_sample_stride_width):
+                    if count % world_size == rank:
+                        tile = x[
+                            :,
+                            :,
+                            :,
+                            i : i + self.tile_sample_min_height,
+                            j : j + self.tile_sample_min_width,
+                        ]
+                        # `causal` is forwarded through to the LTX2 encoder.
+                        tile = self.encoder(tile, causal=causal)
+                    else:
+                        tile = None
+                    row.append(tile)
+                    count += 1
+                rows.append(row)
+
+            # Gather: rank 0 receives every tile it did not compute itself.
+            # Both sides walk tiles in the same deterministic order, so the
+            # point-to-point sends/receives pair up without tags.
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=x.device, dtype=x.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        tile = rows[i][j]
+                        if tile is not None:
+                            send_tensor(tile, 0, group)
+
+            # Rank 0 blends overlapping tile borders and assembles the latent.
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # Blend with the tile above and the tile to the left.
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_width)
+                        result_row.append(
+                            tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]
+                        )
+                    result_rows.append(torch.cat(result_row, dim=4))
+
+                enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+            else:
+                enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
+
+            # Relay the final latent down the rank chain (0 -> 1 -> ... -> N-1)
+            # so every rank returns an identical result.
+            if rank < world_size - 1:
+                send_tensor(enc, rank + 1, group)
+
+            return enc
+
+        auto_encoder.tiled_encode = new_tiled_encode.__get__(auto_encoder)
+
+        @functools.wraps(auto_encoder.__class__.tiled_decode)
+        def new_tiled_decode(
+            self: AutoencoderKLLTX2Video,
+            z: torch.Tensor,
+            temb: Optional[torch.Tensor],
+            causal=None,
+            return_dict: bool = True,
+            *args,
+            **kwargs,
+        ):
+            _, _, num_frames, height, width = z.shape
+            sample_height = height * self.spatial_compression_ratio
+            sample_width = width * self.spatial_compression_ratio
+
+            tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+            tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+            tile_latent_stride_height = (
+                self.tile_sample_stride_height // self.spatial_compression_ratio
+            )
+            tile_latent_stride_width = (
+                self.tile_sample_stride_width // self.spatial_compression_ratio
+            )
+
+            blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+            # Split z into overlapping latent tiles, decoded round-robin.
+            count = 0
+            rows = []
+            for i in range(0, height, tile_latent_stride_height):
+                row = []
+                for j in range(0, width, tile_latent_stride_width):
+                    if count % world_size == rank:
+                        tile = z[
+                            :, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width
+                        ]
+                        # `temb` / `causal` are forwarded to the LTX2 decoder.
+                        decoded = self.decoder(tile, temb, causal=causal)
+                    else:
+                        decoded = None
+                    row.append(decoded)
+                    count += 1
+                rows.append(row)
+
+            # Gather decoded tiles on rank 0 (same deterministic pairing as encode).
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=z.device, dtype=z.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        decoded = rows[i][j]
+                        if decoded is not None:
+                            send_tensor(decoded, 0, group)
+
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # Blend with the tile above and the tile to the left.
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_width)
+                        result_row.append(
+                            tile[
+                                :,
+                                :,
+                                :,
+                                : self.tile_sample_stride_height,
+                                : self.tile_sample_stride_width,
+                            ]
+                        )
+                    result_rows.append(torch.cat(result_row, dim=4))
+
+                dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+            else:
+                dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
+
+            # Relay the assembled sample through the rank chain.
+            if rank < world_size - 1:
+                send_tensor(dec, rank + 1, group)
+
+            if not return_dict:
+                return (dec,)
+
+            return DecoderOutput(sample=dec)
+
+        auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)
+
+        return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_qwen_image.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_qwen_image.py
new file mode 100644
index 000000000..b1ba6fa04
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_qwen_image.py
@@ -0,0 +1,295 @@
+import functools
+
+import torch
+import torch.distributed as dist
+from diffusers import AutoencoderKLQwenImage
+from diffusers.models.autoencoders.vae import DecoderOutput
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+logger = init_logger(__name__)
+
+
+@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKLQwenImage")
+class AutoencoderKLQwenImageDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+    """Data-parallel tiling planner for ``AutoencoderKLQwenImage``.
+
+    Patches the VAE's tiled encode/decode so spatial tiles are assigned
+    round-robin across the ranks of a 1-D device mesh, gathered on rank 0
+    for seam blending, and the blended result is relayed rank-by-rank so
+    every rank returns the same tensor. Frames within a tile are processed
+    in temporal chunks using the VAE's internal feature cache.
+    """
+
+    def apply(
+        self,
+        auto_encoder: torch.nn.Module,
+        parallelism_config: ParallelismConfig,
+        **kwargs,
+    ) -> torch.nn.Module:
+        # Validate the module type, build a 1-D device mesh of
+        # `auto_encoder_world_size` ranks, then install the parallel
+        # tiled encode/decode methods on this instance.
+        assert isinstance(
+            auto_encoder, AutoencoderKLQwenImage
+        ), "AutoencoderKLQwenImageDataParallelismPlanner can only be applied to AutoencoderKLQwenImage"
+        auto_encoder_world_size = parallelism_config.auto_encoder_world_size
+        device_type = torch.accelerator.current_accelerator().type
+        dp_mesh = dist.init_device_mesh(
+            device_type=device_type,
+            mesh_shape=[auto_encoder_world_size],
+        )
+
+        auto_encoder = self.parallelize_tiling(
+            auto_encoder=auto_encoder,
+            dp_mesh=dp_mesh,
+        )
+
+        return auto_encoder
+
+    def parallelize_tiling(
+        self,
+        auto_encoder: AutoencoderKLQwenImage,
+        dp_mesh: dist.DeviceMesh,
+    ):
+        # Monkey-patch `tiled_encode` / `tiled_decode` on this *instance*
+        # (bound via `__get__`), leaving the class untouched.
+        group = dp_mesh.get_group()
+        world_size = dist.get_world_size(group)
+        rank = dist.get_rank(group)
+
+        auto_encoder.enable_tiling()
+
+        @functools.wraps(auto_encoder.__class__.tiled_encode)
+        def new_tiled_encode(
+            self: AutoencoderKLQwenImage,
+            x: torch.Tensor,
+            *args,
+            **kwargs,
+        ):
+            _, _, num_frames, height, width = x.shape
+
+            # Overwrite tile size and stride for better performance while
+            # still reducing memory usage.
+            # NOTE(review): this mutates instance state persistently — later
+            # calls with smaller inputs keep the 512/384 tiling; confirm intended.
+            if min(height, width) >= 1024:
+                self.tile_sample_min_height = 512
+                self.tile_sample_min_width = 512
+                self.tile_sample_stride_height = 384
+                self.tile_sample_stride_width = 384
+
+            latent_height = height // self.spatial_compression_ratio
+            latent_width = width // self.spatial_compression_ratio
+
+            tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+            tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+            tile_latent_stride_height = (
+                self.tile_sample_stride_height // self.spatial_compression_ratio
+            )
+            tile_latent_stride_width = (
+                self.tile_sample_stride_width // self.spatial_compression_ratio
+            )
+
+            blend_height = tile_latent_min_height - tile_latent_stride_height
+            blend_width = tile_latent_min_width - tile_latent_stride_width
+
+            # Split x into overlapping tiles and encode them separately.
+            # Tiles are assigned round-robin: this rank encodes tile indices
+            # with `count % world_size == rank`; other slots stay None.
+            count = 0
+            rows = []
+            for i in range(0, height, self.tile_sample_stride_height):
+                row = []
+                for j in range(0, width, self.tile_sample_stride_width):
+                    if count % world_size == rank:
+                        # num_frames = 1 for image model
+                        self.clear_cache()
+                        time = []
+                        # First chunk holds 1 frame, subsequent chunks 4 frames;
+                        # the encoder feature cache carries state across chunks.
+                        frame_range = 1 + (num_frames - 1) // 4
+                        for k in range(frame_range):
+                            self._enc_conv_idx = [0]
+                            if k == 0:
+                                tile = x[
+                                    :,
+                                    :,
+                                    :1,
+                                    i : i + self.tile_sample_min_height,
+                                    j : j + self.tile_sample_min_width,
+                                ]
+                            else:
+                                tile = x[
+                                    :,
+                                    :,
+                                    1 + 4 * (k - 1) : 1 + 4 * k,
+                                    i : i + self.tile_sample_min_height,
+                                    j : j + self.tile_sample_min_width,
+                                ]
+                            tile = self.encoder(
+                                tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
+                            )
+                            tile = self.quant_conv(tile)
+                            time.append(tile)
+                        tile = torch.cat(time, dim=2)
+                    else:
+                        tile = None
+                    row.append(tile)
+                    count += 1
+                rows.append(row)
+            self.clear_cache()
+
+            # Gather all tiles to rank 0
+            # Both sides walk tiles in the same deterministic order, so the
+            # point-to-point sends/receives pair up without tags.
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=x.device, dtype=x.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        tile = rows[i][j]
+                        if tile is not None:
+                            send_tensor(tile, 0, group)
+
+            # Blend tiles on rank 0
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # blend the above tile and the left tile
+                        # to the current tile and add the current tile to the result row
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_width)
+                        result_row.append(
+                            tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]
+                        )
+                    result_rows.append(torch.cat(result_row, dim=-1))
+
+                enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+            else:
+                enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
+
+            # Propagate result through all ranks
+            if rank < world_size - 1:
+                send_tensor(enc, rank + 1, group)
+
+            return enc
+
+        auto_encoder.tiled_encode = new_tiled_encode.__get__(auto_encoder)
+
+        @functools.wraps(auto_encoder.__class__.tiled_decode)
+        def new_tiled_decode(
+            self: AutoencoderKLQwenImage,
+            z: torch.Tensor,
+            *args,
+            return_dict: bool = True,
+            **kwargs,
+        ):
+            _, _, num_frames, height, width = z.shape
+
+            sample_height = height * self.spatial_compression_ratio
+            sample_width = width * self.spatial_compression_ratio
+
+            # Overwrite tile size and stride for better performance while
+            # still reducing memory usage.
+            if min(sample_height, sample_width) >= 1024:
+                self.tile_sample_min_height = 512
+                self.tile_sample_min_width = 512
+                self.tile_sample_stride_height = 384
+                self.tile_sample_stride_width = 384
+
+            tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+            tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+            tile_latent_stride_height = (
+                self.tile_sample_stride_height // self.spatial_compression_ratio
+            )
+            tile_latent_stride_width = (
+                self.tile_sample_stride_width // self.spatial_compression_ratio
+            )
+
+            blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+            blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+            # Split z into overlapping tiles and decode them separately.
+            count = 0
+            rows = []
+            for i in range(0, height, tile_latent_stride_height):
+                row = []
+                for j in range(0, width, tile_latent_stride_width):
+                    if count % world_size == rank:
+                        # num_frames = 1 for image model
+                        self.clear_cache()
+                        time = []
+                        # Decode one latent frame at a time; the decoder
+                        # feature cache carries state across frames.
+                        for k in range(num_frames):
+                            self._conv_idx = [0]
+                            tile = z[
+                                :,
+                                :,
+                                k : k + 1,
+                                i : i + tile_latent_min_height,
+                                j : j + tile_latent_min_width,
+                            ]
+                            tile = self.post_quant_conv(tile)
+                            decoded = self.decoder(
+                                tile, feat_cache=self._feat_map, feat_idx=self._conv_idx
+                            )
+                            time.append(decoded)
+                        decoded = torch.cat(time, dim=2)
+                    else:
+                        decoded = None
+                    row.append(decoded)
+                    count += 1
+                rows.append(row)
+            self.clear_cache()
+
+            # Gather all tiles to rank 0
+            if rank == 0:
+                count = 0
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        if count % world_size != rank:
+                            rows[i][j] = recv_tensor(
+                                count % world_size, group, device=z.device, dtype=z.dtype
+                            )
+                        count += 1
+            else:
+                for i in range(len(rows)):
+                    for j in range(len(rows[i])):
+                        decoded = rows[i][j]
+                        if decoded is not None:
+                            send_tensor(decoded, 0, group)
+
+            # Blend tiles on rank 0
+            if rank == 0:
+                result_rows = []
+                for i, row in enumerate(rows):
+                    result_row = []
+                    for j, tile in enumerate(row):
+                        # blend the above tile and the left tile
+                        # to the current tile and add the current tile to the result row
+                        if i > 0:
+                            tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                        if j > 0:
+                            tile = self.blend_h(row[j - 1], tile, blend_width)
+                        result_row.append(
+                            tile[
+                                :,
+                                :,
+                                :,
+                                : self.tile_sample_stride_height,
+                                : self.tile_sample_stride_width,
+                            ]
+                        )
+                    result_rows.append(torch.cat(result_row, dim=-1))
+
+                dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+            else:
+                dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
+
+            # Propagate result through all ranks
+            if rank < world_size - 1:
+                send_tensor(dec, rank + 1, group)
+
+            if not return_dict:
+                return (dec,)
+
+            return DecoderOutput(sample=dec)
+
+        auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)
+
+        return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_wan.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_wan.py
new file mode 100644
index 000000000..9a28843cb
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_autoencoder_kl_wan.py
@@ -0,0 +1,294 @@
+# Adapted from: https://github.com/chengzeyi/ParaAttention.git
+import functools
+
+import torch
+import torch.distributed as dist
+from diffusers import AutoencoderKLWan
+from diffusers.models.autoencoders.vae import DecoderOutput
+from diffusers.models.autoencoders.autoencoder_kl_wan import unpatchify
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from .dp_plan_registers import (
+ AutoEncoderDataParallelismPlanner,
+ AutoEncoderDataParallelismPlannerRegister,
+)
+from .utils import send_tensor, recv_tensor
+
+logger = init_logger(__name__)
+
+
+@AutoEncoderDataParallelismPlannerRegister.register("AutoencoderKLWan")
+class AutoencoderKLWanDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+ def apply(
+ self,
+ auto_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ auto_encoder, AutoencoderKLWan
+ ), "AutoencoderKLWanDataParallelismPlanner can only be applied to AutoencoderKLWan"
+ auto_encoder_world_size = parallelism_config.auto_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ dp_mesh = dist.init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[auto_encoder_world_size],
+ )
+
+ auto_encoder = self.parallelize_tiling(
+ auto_encoder=auto_encoder,
+ dp_mesh=dp_mesh,
+ )
+ return auto_encoder
+
+ def parallelize_tiling(
+ self,
+ auto_encoder: AutoencoderKLWan,
+ dp_mesh: dist.DeviceMesh,
+ ):
+ group = dp_mesh.get_group()
+ world_size = dist.get_world_size(group)
+ rank = dist.get_rank(group)
+
+ auto_encoder.enable_tiling()
+
+ @functools.wraps(auto_encoder.__class__.tiled_encode)
+ def new_tiled_encode(
+ self: AutoencoderKLWan,
+ x: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ _, _, num_frames, height, width = x.shape
+ encode_spatial_compression_ratio = self.spatial_compression_ratio
+ if self.config.patch_size is not None:
+ assert encode_spatial_compression_ratio % self.config.patch_size == 0
+ encode_spatial_compression_ratio = (
+ self.spatial_compression_ratio // self.config.patch_size
+ )
+
+ latent_height = height // encode_spatial_compression_ratio
+ latent_width = width // encode_spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // encode_spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // encode_spatial_compression_ratio
+ tile_latent_stride_height = (
+ self.tile_sample_stride_height // encode_spatial_compression_ratio
+ )
+ tile_latent_stride_width = (
+ self.tile_sample_stride_width // encode_spatial_compression_ratio
+ )
+
+ blend_height = tile_latent_min_height - tile_latent_stride_height
+ blend_width = tile_latent_min_width - tile_latent_stride_width
+
+ # Split x into overlapping tiles and encode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ count = 0
+ rows = []
+ for i in range(0, height, self.tile_sample_stride_height):
+ row = []
+ for j in range(0, width, self.tile_sample_stride_width):
+                # TODO(DefTruth): Reduce computation overhead caused by iteration over all
+ # frames inside each tile. We can try to encode all frames in a tile at once.
+ if count % world_size == rank:
+ self.clear_cache()
+ time = []
+ frame_range = 1 + (num_frames - 1) // 4
+ for k in range(frame_range):
+ self._enc_conv_idx = [0]
+ if k == 0:
+ tile = x[
+ :,
+ :,
+ :1,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ else:
+ tile = x[
+ :,
+ :,
+ 1 + 4 * (k - 1) : 1 + 4 * k,
+ i : i + self.tile_sample_min_height,
+ j : j + self.tile_sample_min_width,
+ ]
+ tile = self.encoder(
+ tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
+ )
+ tile = self.quant_conv(tile)
+ time.append(tile)
+ enc = torch.cat(time, dim=2)
+ self.clear_cache()
+ else:
+ enc = None
+ row.append(enc)
+ count += 1
+ rows.append(row)
+
+ if rank == 0:
+ count = 0
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ if count % world_size != rank:
+ rows[i][j] = recv_tensor(
+ count % world_size, group, device=x.device, dtype=x.dtype
+ )
+ count += 1
+ else:
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ tile = rows[i][j]
+ if tile is not None:
+ send_tensor(tile, 0, group)
+
+ if rank == 0:
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(
+ tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]
+ )
+ result_rows.append(torch.cat(result_row, dim=-1))
+
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+ else:
+ enc = recv_tensor(rank - 1, group, device=x.device, dtype=x.dtype)
+ if rank < world_size - 1:
+ send_tensor(enc, rank + 1, group)
+ return enc
+
+ auto_encoder.tiled_encode = new_tiled_encode.__get__(auto_encoder)
+
+ @functools.wraps(auto_encoder.__class__.tiled_decode)
+ def new_tiled_decode(
+ self: AutoencoderKLWan,
+ z: torch.Tensor,
+ *args,
+ return_dict: bool = True,
+ **kwargs,
+ ):
+ _, _, num_frames, height, width = z.shape
+ sample_height = height * self.spatial_compression_ratio
+ sample_width = width * self.spatial_compression_ratio
+
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+ tile_latent_stride_height = (
+ self.tile_sample_stride_height // self.spatial_compression_ratio
+ )
+ tile_latent_stride_width = (
+ self.tile_sample_stride_width // self.spatial_compression_ratio
+ )
+ tile_sample_stride_height = self.tile_sample_stride_height
+ tile_sample_stride_width = self.tile_sample_stride_width
+ if self.config.patch_size is not None:
+ sample_height = sample_height // self.config.patch_size
+ sample_width = sample_width // self.config.patch_size
+ tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size
+ tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size
+ blend_height = (
+ self.tile_sample_min_height // self.config.patch_size
+ - tile_sample_stride_height
+ )
+ blend_width = (
+ self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width
+ )
+ else:
+ blend_height = self.tile_sample_min_height - tile_sample_stride_height
+ blend_width = self.tile_sample_min_width - tile_sample_stride_width
+
+ # Split z into overlapping tiles and decode them separately.
+ # The tiles have an overlap to avoid seams between tiles.
+ count = 0
+ rows = []
+ for i in range(0, height, tile_latent_stride_height):
+ row = []
+ for j in range(0, width, tile_latent_stride_width):
+                # TODO(DefTruth): Reduce computation overhead caused by iteration over all
+ # frames inside each tile. We can try to decode all frames in a tile at once.
+ if count % world_size == rank:
+ self.clear_cache()
+ time = []
+ for k in range(num_frames):
+ self._conv_idx = [0]
+ tile = z[
+ :,
+ :,
+ k : k + 1,
+ i : i + tile_latent_min_height,
+ j : j + tile_latent_min_width,
+ ]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(
+ tile,
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx,
+ first_chunk=(k == 0),
+ )
+ time.append(decoded)
+ decoded = torch.cat(time, dim=2)
+ self.clear_cache()
+ else:
+ decoded = None
+ row.append(decoded)
+ count += 1
+ rows.append(row)
+
+ if rank == 0:
+ count = 0
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ if count % world_size != rank:
+ rows[i][j] = recv_tensor(
+ count % world_size, group, device=z.device, dtype=z.dtype
+ )
+ count += 1
+ else:
+ for i in range(len(rows)):
+ for j in range(len(rows[i])):
+ decoded = rows[i][j]
+ if decoded is not None:
+ send_tensor(decoded, 0, group)
+
+ if rank == 0:
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_width)
+ result_row.append(
+ tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]
+ )
+ result_rows.append(torch.cat(result_row, dim=-1))
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
+ else:
+ dec = recv_tensor(rank - 1, group, device=z.device, dtype=z.dtype)
+ if rank < world_size - 1:
+ send_tensor(dec, rank + 1, group)
+
+ if self.config.patch_size is not None:
+ dec = unpatchify(dec, patch_size=self.config.patch_size)
+
+ dec = torch.clamp(dec, min=-1.0, max=1.0)
+
+ if not return_dict:
+ return (dec,)
+ return DecoderOutput(sample=dec)
+
+ auto_encoder.tiled_decode = new_tiled_decode.__get__(auto_encoder)
+
+ return auto_encoder
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_registers.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_registers.py
new file mode 100644
index 000000000..44332238c
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_plan_registers.py
@@ -0,0 +1,59 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Dict
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class AutoEncoderDataParallelismPlanner:
+
+ @abstractmethod
+ def apply(
+ self,
+ auto_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ raise NotImplementedError("apply method must be implemented by subclasses")
+
+
+class AutoEncoderDataParallelismPlannerRegister:
+ _auto_encoder_dp_planner_registry: Dict[str, AutoEncoderDataParallelismPlanner] = {}
+
+ @classmethod
+ def register(cls, name: str):
+ def decorator(planner_cls: type[AutoEncoderDataParallelismPlanner]):
+ assert (
+ name not in cls._auto_encoder_dp_planner_registry
+ ), f"AutoEncoderDataParallelismPlanner with name {name} is already registered."
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(f"Registering AutoEncoderDataParallelismPlanner: {name}")
+ cls._auto_encoder_dp_planner_registry[name] = planner_cls
+ return planner_cls
+
+ return decorator
+
+ @classmethod
+ def get_planner(
+ cls, auto_encoder: str | torch.nn.Module
+ ) -> type[AutoEncoderDataParallelismPlanner]:
+ if isinstance(auto_encoder, torch.nn.Module):
+ name = auto_encoder.__class__.__name__
+ else:
+ name = auto_encoder
+ planner_cls = None
+ if name in cls._auto_encoder_dp_planner_registry:
+ planner_cls = cls._auto_encoder_dp_planner_registry[name]
+ if planner_cls is None:
+ raise ValueError(f"No planner registered under name: {name}")
+ return planner_cls
+
+ @classmethod
+ def supported_planners(
+ cls,
+ ) -> tuple[int, list[str]]:
+ val_planners = cls._auto_encoder_dp_planner_registry.keys()
+ return len(val_planners), [p for p in val_planners]
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py
new file mode 100644
index 000000000..2288fd356
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/dp_planners.py
@@ -0,0 +1,54 @@
+import importlib
+from cache_dit.logger import init_logger
+from .dp_plan_registers import AutoEncoderDataParallelismPlanner
+
+logger = init_logger(__name__)
+
+
+class ImportErrorAutoEncoderDataParallelismPlanner(AutoEncoderDataParallelismPlanner):
+ def plan(
+ self,
+ auto_encoder,
+ **kwargs,
+ ):
+ raise ImportError(
+ "This AutoEncoderDataParallelismPlanner requires latest diffusers to be installed. "
+ "Please install diffusers from source."
+ )
+
+
+def _safe_import(module_name: str, class_name: str) -> type[AutoEncoderDataParallelismPlanner]:
+ try:
+ # e.g., module_name = ".dp_plan_autoencoder_kl", class_name = "AutoencoderKLDataParallelismPlanner"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorAutoEncoderDataParallelismPlanner
+
+
+def _activate_auto_encoder_dp_planners():
+ """Function to register all built-in auto encoder data parallelism planners."""
+ AutoencoderKLDataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl", "AutoencoderKLDataParallelismPlanner"
+ )
+ AutoencoderKLLTX2VideoDataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl_ltx2", "AutoencoderKLLTX2VideoDataParallelismPlanner"
+ )
+ AutoencoderKLQwenImageDataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl_qwen_image", "AutoencoderKLQwenImageDataParallelismPlanner"
+ )
+ AutoencoderKLWanDataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl_wan", "AutoencoderKLWanDataParallelismPlanner"
+ )
+ AutoencoderKLHunyuanVideoDataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl_hunyuanvideo", "AutoencoderKLHunyuanVideoDataParallelismPlanner"
+ )
+ AutoencoderKLFlux2DataParallelismPlanner = _safe_import( # noqa: F841
+ ".dp_plan_autoencoder_kl_flux2", "AutoencoderKLFlux2DataParallelismPlanner"
+ )
+
+
+__all__ = ["_activate_auto_encoder_dp_planners"]
diff --git a/src/cache_dit/parallelism/autoencoders/data_parallelism/utils.py b/src/cache_dit/parallelism/autoencoders/data_parallelism/utils.py
new file mode 100644
index 000000000..e5f816741
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/data_parallelism/utils.py
@@ -0,0 +1,25 @@
+import torch
+import torch.distributed as dist
+
+
+def send_tensor(
+ tensor: torch.Tensor,
+ dst: int,
+ group: dist.ProcessGroup,
+) -> None:
+ tensor = tensor.contiguous()
+ dist.send_object_list([tensor.shape], dst=dst, group=group)
+ dist.send(tensor, dst=dst, group=group)
+
+
+def recv_tensor(
+ src: int,
+ group: dist.ProcessGroup,
+ device=None,
+ dtype=None,
+) -> torch.Tensor:
+ objects = [None]
+ dist.recv_object_list(objects, src=src, group=group)
+ t = torch.empty(objects[0], device=device, dtype=dtype)
+ dist.recv(t, src=src, group=group)
+ return t
diff --git a/src/cache_dit/parallelism/autoencoders/dispatch.py b/src/cache_dit/parallelism/autoencoders/dispatch.py
new file mode 100644
index 000000000..28e4b57e4
--- /dev/null
+++ b/src/cache_dit/parallelism/autoencoders/dispatch.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch
+
+from cache_dit.parallelism.config import ParallelismConfig
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism_for_auto_encoder(
+ auto_encoder: torch.nn.Module,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(
+ auto_encoder, torch.nn.Module
+ ), f"auto_encoder must be an instance of torch.nn.Module, but got {type(auto_encoder)}"
+ if getattr(auto_encoder, "_is_parallelized", False):
+ logger.warning("The auto encoder is already parallelized. Skipping parallelism enabling.")
+ return auto_encoder
+
+ if parallelism_config is None:
+ return auto_encoder
+
+ from .data_parallelism import maybe_enable_data_parallelism
+
+ auto_encoder = maybe_enable_data_parallelism(
+ auto_encoder=auto_encoder,
+ parallelism_config=parallelism_config,
+ )
+
+ auto_encoder._is_parallelized = True # type: ignore[attr-defined]
+ auto_encoder._parallelism_config = parallelism_config # type: ignore[attr-defined]
+
+ logger.info(
+ f"Parallelize Auto Encoder: {auto_encoder.__class__.__name__}, "
+ f"id:{id(auto_encoder)}, {parallelism_config.strify(True, vae=True)}"
+ )
+
+ return auto_encoder
diff --git a/src/cache_dit/parallelism/parallel_backend.py b/src/cache_dit/parallelism/backend.py
similarity index 66%
rename from src/cache_dit/parallelism/parallel_backend.py
rename to src/cache_dit/parallelism/backend.py
index 9f05a35bd..b27179f9f 100644
--- a/src/cache_dit/parallelism/parallel_backend.py
+++ b/src/cache_dit/parallelism/backend.py
@@ -2,13 +2,16 @@
class ParallelismBackend(Enum):
+ AUTO = "Auto"
NATIVE_DIFFUSER = "Native_Diffuser"
NATIVE_PYTORCH = "Native_PyTorch"
NONE = "None"
@classmethod
def is_supported(cls, backend: "ParallelismBackend") -> bool:
- if backend == cls.NATIVE_PYTORCH:
+ if backend == cls.AUTO:
+ return True
+ elif backend == cls.NATIVE_PYTORCH:
return True
elif backend == cls.NATIVE_DIFFUSER:
try:
@@ -24,3 +27,13 @@ def is_supported(cls, backend: "ParallelismBackend") -> bool:
)
return True
return False
+
+ @classmethod
+ def from_str(cls, backend_str: str) -> "ParallelismBackend":
+ for backend in cls:
+ if backend.value.lower() == backend_str.lower():
+ return backend
+ raise ValueError(f"Unsupported parallelism backend: {backend_str}.")
+
+ def __str__(self) -> str:
+ return self.value
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/__init__.py b/src/cache_dit/parallelism/backends/native_diffusers/__init__.py
deleted file mode 100644
index bffa816ce..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/__init__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from cache_dit.parallelism.backends.native_diffusers.context_parallelism import (
- ContextParallelismPlannerRegister,
-)
-from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention import (
- enable_ulysses_anything,
-)
-from cache_dit.parallelism.backends.native_diffusers.context_parallelism.attention import (
- disable_ulysses_anything,
-)
-from cache_dit.parallelism.backends.native_diffusers.parallel_difffusers import (
- maybe_enable_parallelism,
-)
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py b/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py
deleted file mode 100644
index 9e5a9247d..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py
+++ /dev/null
@@ -1,11 +0,0 @@
-def maybe_resigter_native_attention_backend():
- """Maybe re-register native attention backend to enable context parallelism."""
- # Import custom attention backend ensuring registration
- from ._attention_dispatch import _native_attention
-
-
-from ._templated_ulysses_anything import (
- enable_ulysses_anything,
- is_ulysses_anything_enabled,
- disable_ulysses_anything,
-)
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py b/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py
deleted file mode 100644
index 1ac9abc7b..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py
+++ /dev/null
@@ -1,296 +0,0 @@
-import os
-import torch
-from typing import Optional
-
-try:
- from diffusers.models.attention_dispatch import (
- _AttentionBackendRegistry,
- AttentionBackendName,
- _check_device,
- _check_shape,
- TemplatedRingAttention,
- TemplatedUlyssesAttention,
- )
- from diffusers.models._modeling_parallel import ParallelConfig
-except ImportError:
- raise ImportError(
- "Context parallelism requires the 'diffusers>=0.36.dev0'."
- "Please install latest version of diffusers from source: \n"
- "pip3 install git+https://github.com/huggingface/diffusers.git"
- )
-from cache_dit.logger import init_logger
-from ._templated_ulysses_anything import TemplatedUlyssesAnythingAttention
-from ._templated_ulysses_anything import is_ulysses_anything_enabled
-
-
-logger = init_logger(__name__)
-
-
-__all__ = [
- "_native_attention",
-]
-
-# Enable custom native attention backend with context parallelism
-# by default. Users can set the environment variable to 0 to disable
-# this behavior. Default to enabled for better compatibility.
-_CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH = bool(
- int(os.getenv("CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH", "1"))
-)
-
-
-def _is_native_attn_supported_context_parallel() -> bool:
- try:
- return (
- AttentionBackendName.NATIVE in _AttentionBackendRegistry._supports_context_parallel
- and _AttentionBackendRegistry._supports_context_parallel[AttentionBackendName.NATIVE]
- )
- except Exception:
- assert isinstance(_AttentionBackendRegistry._supports_context_parallel, set)
- return (
- AttentionBackendName.NATIVE.value
- in _AttentionBackendRegistry._supports_context_parallel
- )
-
-
-if _CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH:
- logger.warning(
- "Re-registering NATIVE attention backend to enable context parallelism. "
- "This is a temporary workaround and should be removed after the native "
- "attention backend supports context parallelism natively. Please check: "
- "https://github.com/huggingface/diffusers/pull/12563 for more details. "
- "Or, you can disable this behavior by setting the environment variable "
- "`CACHE_DIT_ENABLE_CUSTOM_CP_NATIVE_ATTN_DISPATCH=0`."
- )
- _AttentionBackendRegistry._backends.pop(AttentionBackendName.NATIVE)
- _AttentionBackendRegistry._constraints.pop(AttentionBackendName.NATIVE)
- _AttentionBackendRegistry._supported_arg_names.pop(AttentionBackendName.NATIVE)
- if _is_native_attn_supported_context_parallel():
- if isinstance(_AttentionBackendRegistry._supports_context_parallel, dict):
- _AttentionBackendRegistry._supports_context_parallel.pop(AttentionBackendName.NATIVE)
- else:
- _AttentionBackendRegistry._supports_context_parallel.remove(
- AttentionBackendName.NATIVE.value
- )
-
- # Re-define templated context parallel attention to support attn mask
- def _templated_context_parallel_attention_v2(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: Optional[float] = None,
- enable_gqa: bool = False,
- return_lse: bool = False,
- *,
- forward_op,
- backward_op,
- _parallel_config: Optional["ParallelConfig"] = None,
- ):
- if attn_mask is not None:
- # NOTE(DefTruth): Check if forward_op is native attention forward op
- forward_op_name = forward_op.__name__
- if not forward_op_name == "_native_attention_forward_op":
- raise ValueError(
- "Templated context parallel attention with attn_mask "
- "is only supported for native attention backend, "
- f"but got forward_op: {forward_op_name}."
- )
- if is_causal:
- raise ValueError("Causal attention is not yet supported for templated attention.")
- if enable_gqa:
- raise ValueError("GQA is not yet supported for templated attention.")
-
- # TODO: add support for unified attention with ring/ulysses degree both being > 1
- if _parallel_config.context_parallel_config.ring_degree > 1:
- return TemplatedRingAttention.apply(
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- enable_gqa,
- return_lse,
- forward_op,
- backward_op,
- _parallel_config,
- )
- elif _parallel_config.context_parallel_config.ulysses_degree > 1:
- if is_ulysses_anything_enabled():
- return TemplatedUlyssesAnythingAttention.apply(
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- enable_gqa,
- return_lse,
- forward_op,
- backward_op,
- _parallel_config,
- )
- else:
- return TemplatedUlyssesAttention.apply(
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- enable_gqa,
- return_lse,
- forward_op,
- backward_op,
- _parallel_config,
- )
- else:
- raise ValueError("Reaching this branch of code is unexpected. Please report a bug.")
-
- # NOTE:Remove NATIVE attention backend constraints and re-register it.
- # Here is a temporary workaround to enable context parallelism with
- # native attention backend. We should remove this workaround after
- # the native attention backend supports context parallelism natively.
- # Adapted from: https://github.com/huggingface/diffusers/pull/12563
-
- def _native_attention_forward_op(
- ctx: torch.autograd.function.FunctionCtx,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: Optional[float] = None,
- enable_gqa: bool = False,
- return_lse: bool = False,
- _save_ctx: bool = True,
- _parallel_config: Optional["ParallelConfig"] = None,
- ):
- # Native attention does not return_lse
- if return_lse:
- raise ValueError("Native attention does not support return_lse=True")
-
- # used for backward pass
- if _save_ctx:
- ctx.save_for_backward(query, key, value)
- ctx.attn_mask = attn_mask
- ctx.dropout_p = dropout_p
- ctx.is_causal = is_causal
- ctx.scale = scale
- ctx.enable_gqa = enable_gqa
-
- query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- out = torch.nn.functional.scaled_dot_product_attention(
- query=query,
- key=key,
- value=value,
- attn_mask=attn_mask,
- dropout_p=dropout_p,
- is_causal=is_causal,
- scale=scale,
- enable_gqa=enable_gqa,
- )
- out = out.permute(0, 2, 1, 3)
-
- return out
-
- def _native_attention_backward_op(
- ctx: torch.autograd.function.FunctionCtx,
- grad_out: torch.Tensor,
- *args,
- **kwargs,
- ):
- query, key, value = ctx.saved_tensors
-
- query.requires_grad_(True)
- key.requires_grad_(True)
- value.requires_grad_(True)
-
- query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- out = torch.nn.functional.scaled_dot_product_attention(
- query=query_t,
- key=key_t,
- value=value_t,
- attn_mask=ctx.attn_mask,
- dropout_p=ctx.dropout_p,
- is_causal=ctx.is_causal,
- scale=ctx.scale,
- enable_gqa=ctx.enable_gqa,
- )
- out = out.permute(0, 2, 1, 3)
-
- grad_out_t = grad_out.permute(0, 2, 1, 3)
- grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
- outputs=out,
- inputs=[query_t, key_t, value_t],
- grad_outputs=grad_out_t,
- retain_graph=False,
- )
-
- grad_query = grad_query_t.permute(0, 2, 1, 3)
- grad_key = grad_key_t.permute(0, 2, 1, 3)
- grad_value = grad_value_t.permute(0, 2, 1, 3)
-
- return grad_query, grad_key, grad_value
-
- @_AttentionBackendRegistry.register(
- AttentionBackendName.NATIVE,
- constraints=[_check_device, _check_shape],
- supports_context_parallel=True,
- )
- def _native_attention(
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor] = None,
- dropout_p: float = 0.0,
- is_causal: bool = False,
- scale: Optional[float] = None,
- enable_gqa: bool = False,
- return_lse: bool = False,
- _parallel_config: Optional["ParallelConfig"] = None,
- ) -> torch.Tensor:
- if return_lse:
- raise ValueError("Native attention backend does not support setting `return_lse=True`.")
- if _parallel_config is None:
- query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
- out = torch.nn.functional.scaled_dot_product_attention(
- query=query,
- key=key,
- value=value,
- attn_mask=attn_mask,
- dropout_p=dropout_p,
- is_causal=is_causal,
- scale=scale,
- enable_gqa=enable_gqa,
- )
- out = out.permute(0, 2, 1, 3)
- else:
- out = _templated_context_parallel_attention_v2(
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- enable_gqa,
- return_lse,
- forward_op=_native_attention_forward_op,
- backward_op=_native_attention_backward_op,
- _parallel_config=_parallel_config,
- )
- return out
-
-else:
- from diffusers.models.attention_dispatch import (
- _native_attention,
- ) # noqa: F401
-
- logger.info("Native attention backend already supports context parallelism.")
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_templated_ulysses_anything.py b/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_templated_ulysses_anything.py
deleted file mode 100644
index 6ce27fb43..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_templated_ulysses_anything.py
+++ /dev/null
@@ -1,401 +0,0 @@
-import os
-import copy
-import functools
-from typing import Optional, Tuple, List
-
-import torch
-import torch.distributed as dist
-import torch.distributed._functional_collectives as fc
-
-try:
- from diffusers.models._modeling_parallel import ParallelConfig
- from diffusers.hooks.context_parallel import EquipartitionSharder
-except ImportError:
- raise ImportError(
- "Context parallelism requires the 'diffusers>=0.36.dev0'."
- "Please install latest version of diffusers from source: \n"
- "pip3 install git+https://github.com/huggingface/diffusers.git"
- )
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-__all__ = [
- "TemplatedUlyssesAnythingAttention",
- "EquipartitionSharder",
- "enable_ulysses_anything",
- "is_ulysses_anything_enabled",
- "disable_ulysses_anything",
-]
-
-
-# Reference:
-# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L827
-# - https://github.com/pytorch/pytorch/blob/f58a680d09e13658a52c6ba05c63c15759846bcc/torch/distributed/_functional_collectives.py#L246
-# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_dispatch.py#L1012
-# For fullgraph=True tracing compatibility (since FakeTensor does not have a `wait` method):
-def _wait_tensor(tensor):
- if isinstance(tensor, fc.AsyncCollectiveTensor):
- tensor = tensor.wait()
-
- return tensor
-
-
-def _get_rank_world_size(
- group: dist.ProcessGroup,
-) -> Tuple[int, int]:
- world_size = dist.get_world_size(group=group)
- rank = dist.get_rank(group=group)
- return rank, world_size
-
-
-@functools.lru_cache(maxsize=128)
-def _gather_size_by_comm(S_LOCAL: int, group: dist.ProcessGroup) -> List[int]:
- world_size = dist.get_world_size(group=group)
- # HACK: Use Gloo backend for all_gather to avoid H2D and D2H overhead
- comm_backends = str(dist.get_backend(group=group))
- # NOTE: e.g., dist.init_process_group(backend="cpu:gloo,cuda:nccl")
- gather_device = "cpu" if "cpu" in comm_backends else torch.device("cuda")
- gathered_sizes = [
- torch.empty((1,), device=gather_device, dtype=torch.int64) for _ in range(world_size)
- ]
- dist.all_gather(
- gathered_sizes,
- torch.tensor([S_LOCAL], device=gather_device, dtype=torch.int64),
- group=group,
- )
-
- gathered_sizes = [s[0].item() for s in gathered_sizes]
- # NOTE: DON'T use tolist here due to graph break - Explanation:
- # Backend compiler `inductor` failed with aten._local_scalar_dense.default
- return gathered_sizes
-
-
-@torch.compiler.allow_in_graph
-def _all_to_all_single_any_qkv(
- x: torch.Tensor,
- group: dist.ProcessGroup,
-) -> torch.Tensor:
- shape = x.shape # (world_size, S_LOCAL, B, H_LOCAL, D)
- (world_size, S_LOCAL, B, H_LOCAL, D) = shape
- input_split_sizes = [S_LOCAL] * world_size
- # S_LOCAL maybe not equal for all ranks in dynamic shape case,
- # since we don't know the actual shape before this timing, thus,
- # we have to use all gather to collect the S_LOCAL first.
- output_split_sizes = _gather_size_by_comm(S_LOCAL, group)
- # NOTE: The `if` branch will introduce graph break for torch.compile,
- # so, we choose to disable the even split optimization implementation
- # _all_to_all_single for now.
- x = x.flatten(0, 1) # (world_size * S_LOCAL, B, H_LOCAL, D)
- x = fc.all_to_all_single(x, output_split_sizes, input_split_sizes, group)
- x = _wait_tensor(x) # (S_GLOBAL, B, H_LOCAL, D)
- return x
-
-
-@torch.compiler.allow_in_graph
-def _all_to_all_single_any_o(
- out: torch.Tensor,
- group: dist.ProcessGroup,
-) -> torch.Tensor:
- rank, world_size = _get_rank_world_size(group)
- shape = out.shape # (B, S_GLOBAL, H_LOCAL, D)
- (B, S_GLOBAL, H_LOCAL, D) = shape
-
- # NOTE: The `if` branch will introduce graph break for torch.compile,
- # so, we choose to disable the even split optimization implementation
- # _all_to_all_single for now.
- out = out.flatten(0, 1).contiguous() # (B*S_GLOBAL, H_LOCAL, D)
- # NOTE: May use tensor_split here to ensure the same split policy
- # that we have used in the EquipartitionSharder sharding strategy. Please
- # note that the 'tensor_split' Splits a tensor into multiple sub-tensors,
- # all of which are views of input, thus may not introduce extra IO access.
- input_split_sizes = [o.shape[0] for o in torch.tensor_split(out, world_size, dim=0)]
- # input_split: e.g, B*S_GLOBAL=1*9 input splits across ranks [[5,4], [5,4],..]
- # output_split: e.g, B*S_GLOBAL=1*9 output splits across ranks [[5,5], [4,4],..]
- output_split_sizes = [input_split_sizes[rank]] * world_size
- out = fc.all_to_all_single(out, output_split_sizes, input_split_sizes, group)
- out = _wait_tensor(out) # (S_LOCAL*world_size, H_LOCAL, D)
- # NOTE: We can not simply reshape here, because the collective tensors
- # are stacked at dim=0(SeqLen), we need to first split them and then concat at
- # dim=1(Head), otherwise the result will be incorrect due to the linear layout
- # of the tensor in memory.
- H_GLOBAL = H_LOCAL * world_size
- S_LOCAL = out.shape[0] // world_size
- # TODO: How to avoid extra memory IO access here?
- out = torch.cat(out.tensor_split(world_size, dim=0), dim=1) # (B*S_LOCAL, H_GLOBAL, D)
- out = out.reshape(B, S_LOCAL, H_GLOBAL, D) # (B, S_LOCAL, H_GLOBAL, D)
- return out
-
-
-@torch.compiler.allow_in_graph
-def _gather_split_any_o( # noqa: F811
- out: torch.Tensor,
- group: dist.ProcessGroup,
-) -> torch.Tensor:
- # NOTE: This is an alternative implementation of _all_to_all_single
- # for any o. It use all_gather and split, which may be less efficient.
- rank, world_size = _get_rank_world_size(group)
- # (B, S_GLOBAL, H_LOCAL, D)
- # all gather to get (B, S_GLOBAL, H_GLOBAL, D) at H_GLOBAL dim
- out_gathered = [torch.empty_like(out) for _ in range(world_size)]
- dist.all_gather(out_gathered, out, group=group)
- out_gathered = torch.cat(out_gathered, dim=2)
- # (B, S_GLOBAL, H_GLOBAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
- out = out_gathered.tensor_split(world_size, dim=1)[rank]
- return out
-
-
-class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
-
- @staticmethod
- def forward(
- ctx: torch.autograd.function.FunctionCtx,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attn_mask: Optional[torch.Tensor],
- dropout_p: float,
- is_causal: bool,
- scale: Optional[float],
- enable_gqa: bool,
- return_lse: bool,
- forward_op,
- backward_op,
- _parallel_config: Optional["ParallelConfig"] = None,
- **kwargs,
- ):
- ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
- world_size = _parallel_config.context_parallel_config.ulysses_degree
- group = ulysses_mesh.get_group()
-
- ctx.forward_op = forward_op
- ctx.backward_op = backward_op
- ctx._parallel_config = _parallel_config
-
- B, S_Q_LOCAL, H, D = query.shape
- _, S_KV_LOCAL, _, _ = key.shape
- H_LOCAL = H // world_size
- # (world_size, S_LOCAL, B, H_LOCAL, D)
- query = (
- query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
- )
- key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
- value = (
- value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
- )
- query, key, value = (_all_to_all_single_any_qkv(x, group) for x in (query, key, value))
- # (S_GLOBAL, B, H_LOCAL, D) -> (B, S_GLOBAL, H_LOCAL, D)
- query, key, value = (x.permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
-
- out = forward_op(
- ctx,
- query,
- key,
- value,
- attn_mask,
- dropout_p,
- is_causal,
- scale,
- enable_gqa,
- return_lse,
- _save_ctx=True,
- _parallel_config=_parallel_config,
- )
- if return_lse:
- out, lse, *_ = out
-
- # out: (B, S_Q_GLOBAL, H_LOCAL, D) -> (B, S_Q_LOCAL, H_GLOBAL, D)
- out = _all_to_all_single_any_o(out, group).contiguous()
-
- if return_lse:
- # lse: (B, S_Q_GLOBAL, H_LOCAL)
- lse = lse.unsqueeze(-1) # (B, S_Q_GLOBAL, H_LOCAL, D=1)
- lse = (
- _all_to_all_single_any_o(lse, group).squeeze(-1).contiguous()
- ) # (B, S_Q_LOCAL, H_GLOBAL)
- else:
- lse = None
-
- return (out, lse) if return_lse else out
-
- @staticmethod
- def backward(
- ctx: torch.autograd.function.FunctionCtx,
- grad_out: torch.Tensor,
- *args,
- ):
- raise NotImplementedError(
- "Backward pass for Ulysses Anything Attention is not implemented yet."
- )
-
-
-@functools.lru_cache(maxsize=64)
-def _fill_gather_shapes(
- shape: Tuple[int], gather_dims: Tuple[int], dim: int, world_size: int
-) -> List[List[int]]:
- gather_shapes = []
- for i in range(world_size):
- # WARN: deepcopy to avoid modifying the original shape
- rank_shape = list(copy.deepcopy(shape))
- rank_shape[dim] = gather_dims[i]
- gather_shapes.append(rank_shape)
- return gather_shapes
-
-
-@torch.compiler.allow_in_graph
-def _all_gather_anything( # noqa: F811
- tensor: torch.Tensor,
- dim: int,
- group: dist.device_mesh.DeviceMesh,
-) -> torch.Tensor:
- _, world_size = _get_rank_world_size(group)
- tensor = tensor.contiguous()
- shape = tensor.shape
- rank_dim = shape[dim]
- gather_dims = _gather_size_by_comm(rank_dim, group)
-
- # NOTE: The `if` branch will introduce graph break for torch.compile,
- # so, we choose to disable the even split optimization for now.
-
- gather_shapes = _fill_gather_shapes(
- tuple(shape),
- tuple(gather_dims),
- dim,
- world_size,
- )
-
- gathered_tensors = [
- torch.empty(
- shape,
- device=tensor.device,
- dtype=tensor.dtype,
- )
- for shape in gather_shapes
- ]
-
- dist.all_gather(gathered_tensors, tensor, group=group)
- gathered_tensor = torch.cat(gathered_tensors, dim=dim)
- return gathered_tensor
-
-
-# NOTE: dist.all_gather, Gathers tensors from the whole group in a list.
-# Complex and uneven sized tensors are supported.
-class AllGatherAnythingFunction(torch.autograd.Function):
-
- @staticmethod
- def forward(
- ctx,
- tensor: torch.Tensor,
- dim: int,
- group: dist.device_mesh.DeviceMesh,
- ):
- ctx.dim = dim
- ctx.group = group
- ctx.world_size = dist.get_world_size(group)
- ctx.rank = dist.get_rank(group)
- gathered_tensor = _all_gather_anything(tensor, dim, group)
- return gathered_tensor
-
- @staticmethod
- def backward(ctx, grad_output):
- # NOTE: We use `tensor_split` instead of chunk, because the `chunk`
- # function may return fewer than the specified number of chunks!
- grad_splits = torch.tensor_split(grad_output, ctx.world_size, dim=ctx.dim)
- return grad_splits[ctx.rank], None, None
-
-
-# NOTE: We use `tensor_split` instead of chunk, because the `chunk`
-# function may return fewer than the specified number of chunks! For example,
-# x = torch.tensor([1,2,3,4,5]), torch.chunk(x, 4) will return only 3 chunks:
-# (tensor([1, 2]), tensor([3, 4]), tensor([5])). This behavior can lead to
-# inconsistencies when sharding tensors across multiple devices. In contrast,
-# tensor_split will always return the specified number of chunks, the last chunk
-# may be smaller if the tensor size is not divisible by the number of chunks.
-# For example, torch.tensor_split(x, 4) will return 4 chunks:
-# (tensor([1, 2]), tensor([3]), tensor([4]), tensor([5])).
-@classmethod
-@functools.wraps(EquipartitionSharder.shard)
-def shard_anything(
- cls: EquipartitionSharder,
- tensor: torch.Tensor,
- dim: int,
- mesh: dist.device_mesh.DeviceMesh,
- **kwargs,
-) -> torch.Tensor:
- assert tensor.size()[dim] >= mesh.size(), (
- f"Cannot shard tensor of size {tensor.size()} along dim {dim} "
- f"across mesh of size {mesh.size()}."
- )
- return tensor.tensor_split(mesh.size(), dim=dim)[dist.get_rank(mesh.get_group())]
-
-
-# NOTE: We use AllGatherAnythingFunction to support gathering
-# tensors with complex and uneven sizes across all ranks. It handles the
-# case where the tensor size (the seq_len of hidden_states) along the
-# specified dimension is not divisible by the number of ranks in the mesh.
-@classmethod
-@functools.wraps(EquipartitionSharder.unshard)
-def unshard_anything(
- cls,
- tensor: torch.Tensor,
- dim: int,
- mesh: torch.distributed.device_mesh.DeviceMesh,
- **kwargs,
-) -> torch.Tensor:
- tensor = tensor.contiguous()
- tensor = AllGatherAnythingFunction.apply(tensor, dim, mesh.get_group())
- return tensor
-
-
-_CACHE_DIT_ENABELD_ULYSSES_ANYTHING = (
- os.environ.get("CACHE_DIT_ENABELD_ULYSSES_ANYTHING", "0") == "1"
-)
-
-
-def enable_ulysses_anything(**kwargs):
- global _CACHE_DIT_ENABELD_ULYSSES_ANYTHING
- try:
- if _CACHE_DIT_ENABELD_ULYSSES_ANYTHING:
- # function for TemplatedUlyssesAnythingAttention.
- if EquipartitionSharder.shard != shard_anything:
- EquipartitionSharder.shard = shard_anything
- EquipartitionSharder.unshard = unshard_anything
- logger.warning(
- "Ulysses Anything Attention is already enabled in cache-dit. "
- "but EquipartitionSharder.shard/unshard is not set correctly, "
- "resetting it to the correct shard/unshard_anything function."
- )
- return
-
- _CACHE_DIT_ENABELD_ULYSSES_ANYTHING = True
-
- logger.warning(
- "Ulysses Anything Attention is enabled in cache-dit. "
- "Please note that this is an experimental feature and "
- "may not be fully tested."
- )
-
- # Ensure the EquipartitionSharder uses our modified shard_anything
- # function for TemplatedUlyssesAnythingAttention.
- if EquipartitionSharder.shard != shard_anything:
- EquipartitionSharder.shard = shard_anything
- EquipartitionSharder.unshard = unshard_anything
- logger.info(
- "EquipartitionSharder.shard/unshard is set to shard/unshard_anything function "
- "for Ulysses Anything Attention."
- )
- except Exception as e:
- _CACHE_DIT_ENABELD_ULYSSES_ANYTHING = False
- logger.error(f"Failed to enable Ulysses Anything Attention in cache-dit due to error: {e}")
- pass
-
-
-def is_ulysses_anything_enabled(**kwargs) -> bool:
- global _CACHE_DIT_ENABELD_ULYSSES_ANYTHING
- return _CACHE_DIT_ENABELD_ULYSSES_ANYTHING
-
-
-def disable_ulysses_anything(**kwargs):
- global _CACHE_DIT_ENABELD_ULYSSES_ANYTHING
- _CACHE_DIT_ENABELD_ULYSSES_ANYTHING = False
- logger.info("Ulysses Anything Attention is manually disabled in cache-dit.")
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py b/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py
deleted file mode 100644
index 711836f18..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import torch
-from typing import Optional
-from diffusers.models.modeling_utils import ModelMixin
-
-try:
- from diffusers.models._modeling_parallel import (
- ContextParallelInput,
- ContextParallelOutput,
- ContextParallelModelPlan,
- )
-except ImportError:
- raise ImportError(
- "Context parallelism requires the 'diffusers>=0.36.dev0'."
- "Please install latest version of diffusers from source: \n"
- "pip3 install git+https://github.com/huggingface/diffusers.git"
- )
-from .cp_plan_registers import (
- ContextParallelismPlanner,
- ContextParallelismPlannerRegister,
-)
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-@ContextParallelismPlannerRegister.register("QwenImage")
-class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
- def apply(
- self,
- transformer: Optional[torch.nn.Module | ModelMixin] = None,
- **kwargs,
- ) -> ContextParallelModelPlan:
-
- # NOTE: Set it as False to use custom CP plan defined here.
- self._cp_planner_preferred_native_diffusers = False
-
- if transformer is not None and self._cp_planner_preferred_native_diffusers:
- from diffusers import QwenImageTransformer2DModel
-
- assert isinstance(
- transformer, QwenImageTransformer2DModel
- ), "Transformer must be an instance of QwenImageTransformer2DModel"
- if hasattr(transformer, "_cp_plan"):
- if transformer._cp_plan is not None:
- return transformer._cp_plan
-
- # Otherwise, use the custom CP plan defined here, this maybe
- # a little different from the native diffusers implementation
- # for some models.
- _cp_plan = {
- # Here is a Transformer level CP plan for Qwen-Image, which will
- # only apply the only 1 split hook (pre_forward) on the forward
- # of Transformer, and gather the output after Transformer forward.
- # Pattern of transformer forward, split_output=False:
- # un-split input -> splited input (inside transformer)
- # Pattern of the transformer_blocks, single_transformer_blocks:
- # splited input (previous splited output) -> to_qkv/...
- # -> all2all
- # -> attn (local head, full seqlen)
- # -> all2all
- # -> splited output
- # The `hidden_states` and `encoder_hidden_states` will still keep
- # itself splited after block forward (namely, automatic split by
- # the all2all comm op after attn) for the all blocks.
- "": {
- "hidden_states": ContextParallelInput(
- split_dim=1, expected_dims=3, split_output=False
- ),
- # NOTE: Due to the joint attention implementation of
- # QwenImageTransformerBlock, we must split the
- # encoder_hidden_states as well.
- "encoder_hidden_states": ContextParallelInput(
- split_dim=1, expected_dims=3, split_output=False
- ),
- # NOTE: But encoder_hidden_states_mask seems never used in
- # QwenImageTransformerBlock, so we do not split it here.
- # "encoder_hidden_states_mask": ContextParallelInput(
- # split_dim=1, expected_dims=2, split_output=False
- # ),
- },
- # Pattern of pos_embed, split_output=True (split output rather than input):
- # un-split input
- # -> keep input un-split
- # -> rope
- # -> splited output
- "pos_embed": {
- 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
- 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
- },
- # Then, the final proj_out will gather the splited output.
- # splited input (previous splited output)
- # -> all gather
- # -> un-split output
- "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
- }
- return _cp_plan
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py b/src/cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py
deleted file mode 100644
index 4d7b5ac5f..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import torch
-
-from typing import Optional
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-from diffusers.models.modeling_utils import ModelMixin
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
-from .context_parallelism import maybe_enable_context_parallelism
-
-
-def maybe_enable_parallelism(
- transformer: torch.nn.Module,
- parallelism_config: Optional[ParallelismConfig],
-) -> torch.nn.Module:
- assert isinstance(transformer, ModelMixin), (
- "transformer must be an instance of diffusers' ModelMixin, " f"but got {type(transformer)}"
- )
- if parallelism_config is None:
- return transformer
-
- assert isinstance(parallelism_config, ParallelismConfig), (
- "parallelism_config must be an instance of ParallelismConfig"
- f" but got {type(parallelism_config)}"
- )
-
- assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
- f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
- f"but got {parallelism_config.backend}"
- )
-
- if parallelism_config.ulysses_size is not None or parallelism_config.ring_size is not None:
- transformer = maybe_enable_context_parallelism(
- transformer,
- parallelism_config,
- )
- else:
- raise ValueError(
- "NATIVE_DIFFUSER backend only support context parallelism now. "
- "Please set ulysses_size or ring_size in parallelism_config."
- )
- return transformer
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/utils.py b/src/cache_dit/parallelism/backends/native_diffusers/utils.py
deleted file mode 100644
index eda94ab47..000000000
--- a/src/cache_dit/parallelism/backends/native_diffusers/utils.py
+++ /dev/null
@@ -1,11 +0,0 @@
-try:
- from diffusers import ContextParallelConfig
-
- def native_diffusers_parallelism_available() -> bool:
- return True
-
-except ImportError:
- ContextParallelConfig = None
-
- def native_diffusers_parallelism_available() -> bool:
- return False
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/__init__.py b/src/cache_dit/parallelism/backends/native_pytorch/__init__.py
deleted file mode 100644
index 3015a9672..000000000
--- a/src/cache_dit/parallelism/backends/native_pytorch/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from cache_dit.parallelism.backends.native_pytorch.tensor_parallelism import (
- TensorParallelismPlannerRegister,
-)
-from cache_dit.parallelism.backends.native_pytorch.parallel_torch import (
- maybe_enable_parallelism,
-)
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/parallel_torch.py b/src/cache_dit/parallelism/backends/native_pytorch/parallel_torch.py
deleted file mode 100644
index 7fc1d3d6c..000000000
--- a/src/cache_dit/parallelism/backends/native_pytorch/parallel_torch.py
+++ /dev/null
@@ -1,54 +0,0 @@
-from typing import Optional
-
-import torch
-
-from diffusers.models.modeling_utils import ModelMixin
-
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def maybe_enable_parallelism(
- transformer: torch.nn.Module | ModelMixin,
- parallelism_config: Optional[ParallelismConfig],
-) -> torch.nn.Module:
- assert isinstance(transformer, torch.nn.Module), (
- "transformer must be an instance of torch.nn.Module, " f"but got {type(transformer)}"
- )
- assert isinstance(transformer, ModelMixin), (
- "transformer must be an instance of diffusers' ModelMixin, " f"but got {type(transformer)}"
- )
- if parallelism_config is None:
- return transformer
-
- assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
- "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
- f"but got {parallelism_config.backend}"
- )
-
- assert isinstance(parallelism_config, ParallelismConfig), (
- "parallelism_config must be an instance of ParallelismConfig"
- f" but got {type(parallelism_config)}"
- )
- assert parallelism_config.ulysses_size is None and parallelism_config.ring_size is None, (
- "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
- "Please set it to None in parallelism_config."
- )
-
- if parallelism_config.tp_size is not None and parallelism_config.tp_size > 1:
- from .tensor_parallelism import maybe_enable_tensor_parallelism
-
- transformer = maybe_enable_tensor_parallelism(
- transformer=transformer,
- parallelism_config=parallelism_config,
- )
- else:
- raise ValueError(
- "NATIVE_PYTORCH only supported tensor parallelism now. "
- "Please set tp_size > 1 for tensor parallelism."
- )
- return transformer
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py b/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py
deleted file mode 100644
index ab3a0e41b..000000000
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py
+++ /dev/null
@@ -1,26 +0,0 @@
-# NOTE: must import all planner classes to register them
-from .tp_plan_cogview import CogViewTensorParallelismPlanner
-from .tp_plan_flux import FluxTensorParallelismPlanner
-from .tp_plan_hunyuan_dit import HunyuanDiTTensorParallelismPlanner
-from .tp_plan_kandinsky5 import Kandinsky5TensorParallelismPlanner
-from .tp_plan_mochi import MochiTensorParallelismPlanner
-from .tp_plan_ltx_video import LTXVideoTensorParallelismPlanner
-from .tp_plan_pixart import PixArtTensorParallelismPlanner
-from .tp_plan_qwen_image import QwenImageTensorParallelismPlanner
-from .tp_plan_registers import TensorParallelismPlannerRegister
-from .tp_plan_wan import WanTensorParallelismPlanner
-from .tp_plan_skyreels import SkyReelsV2TensorParallelismPlanner
-
-__all__ = [
- "CogViewTensorParallelismPlanner",
- "FluxTensorParallelismPlanner",
- "HunyuanDiTTensorParallelismPlanner",
- "Kandinsky5TensorParallelismPlanner",
- "MochiTensorParallelismPlanner",
- "LTXVideoTensorParallelismPlanner",
- "PixArtTensorParallelismPlanner",
- "QwenImageTensorParallelismPlanner",
- "TensorParallelismPlannerRegister",
- "WanTensorParallelismPlanner",
- "SkyReelsV2TensorParallelismPlanner",
-]
diff --git a/src/cache_dit/parallelism/config.py b/src/cache_dit/parallelism/config.py
new file mode 100644
index 000000000..e1533ee41
--- /dev/null
+++ b/src/cache_dit/parallelism/config.py
@@ -0,0 +1,188 @@
+import dataclasses
+from typing import Optional, Dict, Any
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class ParallelismConfig:
+ # Parallelism backend, defaults to AUTO. We will auto select the backend
+ # based on the parallelism configuration.
+ backend: ParallelismBackend = ParallelismBackend.AUTO
+ # Context parallelism config
+ # ulysses_size (`int`, *optional*):
+ # The degree of ulysses parallelism.
+ ulysses_size: Optional[int] = None
+ # ring_size (`int`, *optional*):
+ # The degree of ring parallelism.
+ ring_size: Optional[int] = None
+ # Tensor parallelism config
+ # tp_size (`int`, *optional*):
+ # The degree of tensor parallelism.
+ tp_size: Optional[int] = None
+ # parallel_kwargs (`dict`, *optional*):
+ # Additional kwargs for parallelism backends. For example, for
+ # NATIVE_DIFFUSER backend, it can include `cp_plan` and
+ # `attention_backend` arguments for `Context Parallelism`.
+ parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)
+ # Some internal fields for utils usage
+ _has_text_encoder: bool = False
+ _has_auto_encoder: bool = False
+ _has_controlnet: bool = False
+
+ def __post_init__(self):
+ assert ParallelismBackend.is_supported(self.backend), (
+ f"Parallel backend {self.backend} is not supported. "
+ f"Please make sure the required packages are installed."
+ )
+ if self.backend == ParallelismBackend.AUTO:
+ # Auto select the backend based on the parallelism configuration
+ if (self.ulysses_size is not None and self.ulysses_size > 1) or (
+ self.ring_size is not None and self.ring_size > 1
+ ):
+ self.backend = ParallelismBackend.NATIVE_DIFFUSER
+ elif self.tp_size is not None and self.tp_size > 1:
+ self.backend = ParallelismBackend.NATIVE_PYTORCH
+ else:
+ self.backend = ParallelismBackend.NONE
+ logger.info(f"Auto selected parallelism backend for transformer: {self.backend}")
+
+ # Validate the parallelism configuration and auto adjust the backend if needed
+ if self.tp_size is not None and self.tp_size > 1:
+ assert (
+ self.ulysses_size is None or self.ulysses_size == 1
+ ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
+ assert (
+ self.ring_size is None or self.ring_size == 1
+ ), "Tensor parallelism plus Ring parallelism is not supported right now."
+ if self.backend != ParallelismBackend.NATIVE_PYTORCH:
+ logger.warning(
+ "Tensor parallelism is only supported for NATIVE_PYTORCH backend "
+ "right now. Force set backend to NATIVE_PYTORCH."
+ )
+ self.backend = ParallelismBackend.NATIVE_PYTORCH
+ elif (
+ self.ulysses_size is not None
+ and self.ulysses_size > 1
+ and self.ring_size is not None
+ and self.ring_size > 1
+ ):
+ raise ValueError(
+ "Ulysses parallelism plus Ring parallelism is not fully supported right now."
+ )
+ else:
+ if (self.ulysses_size is not None and self.ulysses_size > 1) or (
+ self.ring_size is not None and self.ring_size > 1
+ ):
+ if self.backend != ParallelismBackend.NATIVE_DIFFUSER:
+ logger.warning(
+ "Ulysses/Ring parallelism is only supported for NATIVE_DIFFUSER "
+ "backend right now. Force set backend to NATIVE_DIFFUSER."
+ )
+ self.backend = ParallelismBackend.NATIVE_DIFFUSER
+
+ def enabled(self) -> bool:
+ return (
+ (self.ulysses_size is not None and self.ulysses_size > 1)
+ or (self.ring_size is not None and self.ring_size > 1)
+ or (self.tp_size is not None and self.tp_size > 1)
+ )
+
+ def strify(
+ self,
+ details: bool = False,
+ text_encoder: bool = False,
+ vae: bool = False,
+ controlnet: bool = False,
+ ) -> str:
+ if details:
+ if text_encoder or vae:
+ extra_module_world_size = self._get_extra_module_world_size()
+ # Currently, only support tensor parallelism or data parallelism
+ # for extra modules using pytorch native backend or pure pytorch
+ # implementation. So we just hardcode the backend here.
+ parallel_str = f"ParallelismConfig(backend={ParallelismBackend.NATIVE_PYTORCH}, "
+
+ if text_encoder:
+ parallel_str += f"tp_size={extra_module_world_size}, "
+ elif controlnet:
+ parallel_str += f"ulysses_size={extra_module_world_size}, "
+ else:
+ parallel_str += f"dp_size={extra_module_world_size}, "
+ parallel_str = parallel_str.rstrip(", ") + ")"
+ return parallel_str
+
+ parallel_str = f"ParallelismConfig(backend={self.backend}, "
+ if self.ulysses_size is not None:
+ parallel_str += f"ulysses_size={self.ulysses_size}, "
+ if self.ring_size is not None:
+ parallel_str += f"ring_size={self.ring_size}, "
+ if self.tp_size is not None:
+ parallel_str += f"tp_size={self.tp_size}, "
+ parallel_str = parallel_str.rstrip(", ") + ")"
+ return parallel_str
+ else:
+ parallel_str = ""
+ if self.ulysses_size is not None:
+ parallel_str += f"Ulysses{self.ulysses_size}"
+ if self.ring_size is not None:
+ parallel_str += f"Ring{self.ring_size}"
+ if self.tp_size is not None:
+ parallel_str += f"TP{self.tp_size}"
+ if text_encoder or self._has_text_encoder:
+ parallel_str += "_TEP" # Text Encoder Parallelism
+ if vae or self._has_auto_encoder:
+ parallel_str += "_VAEP" # VAE Parallelism
+ if controlnet or self._has_controlnet:
+ parallel_str += "_CNP" # ControlNet Parallelism
+ return parallel_str
+
+ def _get_extra_module_world_size(self) -> Optional[int]:
+ """Get the world size for extra parallel modules, e.g., text encoder and VAE."""
+ # Maximize the parallel size for extra modules: max(tp_size, ulysses_size, ring_size)
+ sizes = []
+ if self.tp_size is not None and self.tp_size > 1:
+ sizes.append(self.tp_size)
+ if self.ulysses_size is not None and self.ulysses_size > 1:
+ sizes.append(self.ulysses_size)
+ if self.ring_size is not None and self.ring_size > 1:
+ sizes.append(self.ring_size)
+ if sizes:
+ return max(sizes)
+ return None
+
+ @property
+ def text_encoder_world_size(self) -> Optional[int]:
+ """Get the world size for text encoder parallelism."""
+ world_size = self._get_extra_module_world_size()
+ assert (
+ world_size is None or world_size > 1
+ ), "Text encoder world size must be None or greater than 1 for parallelism."
+ self._has_text_encoder = True
+ return world_size
+
+ @property
+ def auto_encoder_world_size(self) -> Optional[int]:
+ """Get the world size for VAE parallelism."""
+ world_size = self._get_extra_module_world_size()
+ assert (
+ world_size is None or world_size > 1
+ ), "VAE world size must be None or greater than 1 for parallelism."
+ self._has_auto_encoder = True
+ return world_size
+
+ @property
+ def vae_world_size(self) -> Optional[int]: # alias of auto_encoder_world_size
+ return self.auto_encoder_world_size
+
+ @property
+ def controlnet_world_size(self) -> Optional[int]:
+ """Get the world size for ControlNet parallelism."""
+ world_size = self._get_extra_module_world_size()
+ assert (
+ world_size is None or world_size > 1
+ ), "ControlNet world size must be None or greater than 1 for parallelism."
+ self._has_controlnet = True
+ return world_size
diff --git a/src/cache_dit/parallelism/controlnets/__init__.py b/src/cache_dit/parallelism/controlnets/__init__.py
new file mode 100644
index 000000000..ba6e9d596
--- /dev/null
+++ b/src/cache_dit/parallelism/controlnets/__init__.py
@@ -0,0 +1 @@
+from .dispatch import maybe_enable_parallelism_for_controlnet
diff --git a/src/cache_dit/parallelism/controlnets/context_parallelism/__init__.py b/src/cache_dit/parallelism/controlnets/context_parallelism/__init__.py
new file mode 100644
index 000000000..db0e92950
--- /dev/null
+++ b/src/cache_dit/parallelism/controlnets/context_parallelism/__init__.py
@@ -0,0 +1,91 @@
+import torch
+from typing import Optional
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+try:
+ from diffusers import ContextParallelConfig # noqa: F401
+ from cache_dit.parallelism.attention import (
+ _maybe_register_custom_attn_backends,
+ _is_diffusers_parallelism_available,
+ enable_ulysses_anything,
+ enable_ulysses_float8,
+ )
+ from .cp_plan_registers import ControlNetContextParallelismPlannerRegister
+ from .cp_planners import _activate_controlnet_cp_planners
+
+ _maybe_register_custom_attn_backends()
+ _activate_controlnet_cp_planners()
+except ImportError as e:
+    raise ImportError(e) from e
+
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_context_parallelism(
+ controlnet: torch.nn.Module,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(controlnet, ModelMixin), (
+ "controlnet must be an instance of diffusers' ModelMixin, " f"but got {type(controlnet)}"
+ )
+ if parallelism_config is None:
+ return controlnet
+
+ assert isinstance(parallelism_config, ParallelismConfig), (
+ "parallelism_config must be an instance of ParallelismConfig"
+ f" but got {type(parallelism_config)}"
+ )
+
+ if (
+ parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
+ and _is_diffusers_parallelism_available()
+ ):
+ cp_config = None
+ if parallelism_config.ulysses_size is not None or parallelism_config.ring_size is not None:
+ cp_config = ContextParallelConfig(
+ ulysses_degree=parallelism_config.ulysses_size,
+ ring_degree=parallelism_config.ring_size,
+ )
+ if cp_config is not None:
+ experimental_ulysses_anything = parallelism_config.parallel_kwargs.get(
+ "experimental_ulysses_anything", False
+ )
+ # Float8 all_to_all for Ulysses Attention/Ulysses Anything Attention
+ experimental_ulysses_float8 = parallelism_config.parallel_kwargs.get(
+ "experimental_ulysses_float8", False
+ )
+
+ # Must call enable_ulysses_anything before enable_ulysses_float8.
+ if experimental_ulysses_anything:
+ enable_ulysses_anything()
+
+ if experimental_ulysses_float8:
+ enable_ulysses_float8()
+
+ if hasattr(controlnet, "enable_parallelism"):
+ # Prefer custom cp_plan if provided
+ cp_plan = parallelism_config.parallel_kwargs.get("cp_plan", None)
+ if cp_plan is not None:
+ logger.info(f"Using custom context parallelism plan: {cp_plan}")
+ else:
+ # Try get context parallelism plan from register if not provided
+ extra_parallel_kwargs = {}
+ if parallelism_config.parallel_kwargs is not None:
+ extra_parallel_kwargs = parallelism_config.parallel_kwargs
+ cp_plan = ControlNetContextParallelismPlannerRegister.get_planner(
+ controlnet
+ )().apply(controlnet=controlnet, **extra_parallel_kwargs)
+
+ controlnet.enable_parallelism(config=cp_config, cp_plan=cp_plan)
+
+ else:
+ raise ValueError(
+ f"{controlnet.__class__.__name__} does not support context parallelism."
+ )
+
+ return controlnet
diff --git a/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_registers.py b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_registers.py
new file mode 100644
index 000000000..9aa5e8e61
--- /dev/null
+++ b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_registers.py
@@ -0,0 +1,82 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Optional
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+__all__ = [
+ "ControlNetContextParallelismPlanner",
+ "ControlNetContextParallelismPlannerRegister",
+]
+
+
+class ControlNetContextParallelismPlanner:
+ # Prefer native diffusers implementation if available
+ _cp_planner_preferred_native_diffusers: bool = True
+
+ @abstractmethod
+ def apply(
+ self,
+ # NOTE: Keep this kwarg for future extensions
+ controlnet: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+ # NOTE: This method should only return the CP plan dictionary.
+ raise NotImplementedError("apply method must be implemented by subclasses")
+
+
+class ControlNetContextParallelismPlannerRegister:
+ _cp_planner_registry: dict[str, ControlNetContextParallelismPlanner] = {}
+
+ @classmethod
+ def register(cls, name: str):
+ def decorator(planner_cls: type[ControlNetContextParallelismPlanner]):
+ assert (
+ name not in cls._cp_planner_registry
+ ), f"ControlNetContextParallelismPlanner with name {name} is already registered."
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(f"Registering ControlNetContextParallelismPlanner: {name}")
+ cls._cp_planner_registry[name] = planner_cls
+ return planner_cls
+
+ return decorator
+
+ @classmethod
+ def get_planner(
+ cls, controlnet: str | torch.nn.Module | ModelMixin
+ ) -> type[ControlNetContextParallelismPlanner]:
+ if isinstance(controlnet, (torch.nn.Module, ModelMixin)):
+ name = controlnet.__class__.__name__
+ else:
+ name = controlnet
+ planner_cls = None
+ for planner_name in cls._cp_planner_registry:
+ if name.startswith(planner_name):
+ planner_cls = cls._cp_planner_registry.get(planner_name)
+ break
+ if planner_cls is None:
+ raise ValueError(f"No planner registered under name: {name}")
+ return planner_cls
+
+ @classmethod
+ def supported_planners(
+ cls,
+ ) -> tuple[int, list[str]]:
+ val_planners = cls._cp_planner_registry.keys()
+ return len(val_planners), [p for p in val_planners]
diff --git a/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py
new file mode 100644
index 000000000..ef71e2e75
--- /dev/null
+++ b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_plan_zimage_controlnet.py
@@ -0,0 +1,223 @@
+import torch
+import functools
+from typing import Optional
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers import ZImageControlNetModel
+from diffusers.models.controlnets.controlnet_z_image import (
+ ZSingleStreamAttnProcessor,
+ dispatch_attention_fn,
+ Attention,
+)
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ControlNetContextParallelismPlanner,
+ ControlNetContextParallelismPlannerRegister,
+)
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+from cache_dit.platforms import current_platform
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ControlNetContextParallelismPlannerRegister.register("ZImageControlNetModel")
+class ZImageControlNetContextParallelismPlanner(ControlNetContextParallelismPlanner):
+ def apply(
+ self,
+ controlnet: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ # NOTE: Diffusers native CP plan still not supported for ZImageControlNetModel
+ self._cp_planner_preferred_native_diffusers = False
+
+ if controlnet is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ controlnet, ZImageControlNetModel
+ ), "controlnet must be an instance of ZImageControlNetModel"
+ if hasattr(controlnet, "_cp_plan"):
+ if controlnet._cp_plan is not None:
+ return controlnet._cp_plan
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ ZSingleStreamAttnProcessor.__call__ = (
+ __patch_ZSingleStreamAttnProcessor_ulysses_async__call__
+ )
+
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for ZImageControlNetModel."
+ )
+
+ # The cp plan for ZImage ControlNet is very complicated, I [HATE] it.
+ n_control_layers = len(controlnet.control_layers) # 15
+ n_control_noise_refiner_layers = len(controlnet.control_noise_refiner) # 2
+ _cp_plan = {
+ "control_noise_refiner.0": {
+ "c": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "control_noise_refiner.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ f"control_noise_refiner.{n_control_noise_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=2, expected_dims=4
+ ),
+ "control_layers.0": {
+ "c": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "control_layers.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ f"control_layers.{n_control_layers - 1}": ContextParallelOutput(
+ gather_dim=2, expected_dims=4
+ ),
+ }
+ return _cp_plan
+
+
+# NOTE: Support Async Ulysses QKV projection for Z-Image ControlNet
+def _ulysses_attn_with_async_qkv_proj_zimage_controlnet(
+ self: ZSingleStreamAttnProcessor,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ # Apply RoPE
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ with torch.amp.autocast(current_platform.device_type, enabled=False):
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+ return x_out.type_as(x_in) # todo
+
+ dtype = hidden_states.dtype
+ query = attn.to_q(hidden_states) # type: torch.Tensor
+ query = query.unflatten(-1, (attn.heads, -1))
+ if attn.norm_q is not None: # Apply Norms
+ query = attn.norm_q(query)
+ if freqs_cis is not None: # Apply RoPE
+ query = apply_rotary_emb(query, freqs_cis)
+
+ metadata = _prepare_ulysses_comm_metadata(query)
+
+ # Async all to all for query
+ query_wait = _all_to_all_qv_async_func(query, group, **metadata)
+
+ key = attn.to_k(hidden_states) # type: torch.Tensor
+ key = key.unflatten(-1, (attn.heads, -1))
+ if attn.norm_k is not None: # Apply Norms
+ key = attn.norm_k(key)
+ if freqs_cis is not None: # Apply RoPE
+ key = apply_rotary_emb(key, freqs_cis)
+
+ # Async all to all for key
+ key_wait = _all_to_all_k_async_func(key, group, **metadata)
+
+ value = attn.to_v(hidden_states) # type: torch.Tensor
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ # Async all to all for value
+ value_wait = _all_to_all_qv_async_func(value, group, **metadata)
+
+ # Ensure the query, key, value are ready
+ query = query_wait()
+ key = key_wait()
+ value = value_wait()
+
+ # Cast to correct dtype
+ query, key = query.to(dtype), key.to(dtype)
+
+ # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+ if attention_mask is not None and attention_mask.ndim == 2:
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Compute joint attention
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ hidden_states = out_wait() # type: torch.Tensor
+
+ # Reshape back
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(dtype)
+
+ output = attn.to_out[0](hidden_states)
+ if len(attn.to_out) > 1: # dropout
+ output = attn.to_out[1](output)
+
+ return output
+
+
+ZSingleStreamAttnProcessor_original__call__ = ZSingleStreamAttnProcessor.__call__
+
+
+@functools.wraps(ZSingleStreamAttnProcessor_original__call__)
+def __patch_ZSingleStreamAttnProcessor_ulysses_async__call__(
+ self: ZSingleStreamAttnProcessor,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_zimage_controlnet(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ freqs_cis,
+ )
+ else:
+ return ZSingleStreamAttnProcessor_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ freqs_cis,
+ )
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_planners.py
similarity index 57%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py
rename to src/cache_dit/parallelism/controlnets/context_parallelism/cp_planners.py
index 78ad49804..49f539a48 100644
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py
+++ b/src/cache_dit/parallelism/controlnets/context_parallelism/cp_planners.py
@@ -56,68 +56,42 @@
# ContextParallelOutput:
# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
-from .cp_plan_registers import (
- ContextParallelismPlanner,
- ContextParallelismPlannerRegister,
-)
-from .cp_plan_flux import FluxContextParallelismPlanner
-from .cp_plan_qwen_image import QwenImageContextParallelismPlanner
-from .cp_plan_wan import WanContextParallelismPlanner
-from .cp_plan_wan import WanVACEContextParallelismPlanner
-from .cp_plan_ltxvideo import LTXVideoContextParallelismPlanner
-from .cp_plan_hunyuan import HunyuanImageContextParallelismPlanner
-from .cp_plan_hunyuan import HunyuanVideoContextParallelismPlanner
-from .cp_plan_cogvideox import CogVideoXContextParallelismPlanner
-from .cp_plan_cogview import CogView3PlusContextParallelismPlanner
-from .cp_plan_cogview import CogView4ContextParallelismPlanner
-from .cp_plan_cosisid import CosisIDContextParallelismPlanner
-from .cp_plan_chroma import ChromaContextParallelismPlanner
-from .cp_plan_pixart import PixArtContextParallelismPlanner
-from .cp_plan_dit import DiTContextParallelismPlanner
-from .cp_plan_kandinsky import Kandinsky5ContextParallelismPlanner
-from .cp_plan_skyreels import SkyReelsV2ContextParallelismPlanner
+import importlib
+from cache_dit.logger import init_logger
+from .cp_plan_registers import ControlNetContextParallelismPlanner
-try:
- import nunchaku # noqa: F401
+logger = init_logger(__name__)
- _nunchaku_available = True
-except ImportError:
- _nunchaku_available = False
-if _nunchaku_available:
- from .cp_plan_nunchaku import ( # noqa: F401
- NunchakuFluxContextParallelismPlanner,
- )
- from .cp_plan_nunchaku import ( # noqa: F401
- NunchakuQwenImageContextParallelismPlanner,
- )
+class ImportErrorContextParallelismPlanner(ControlNetContextParallelismPlanner):
+    def apply(
+        self,
+        controlnet=None,
+        **kwargs,
+    ):
+ raise ImportError(
+ "This ControlNetContextParallelismPlanner requires latest diffusers to be installed. "
+ "Please install diffusers from source."
+ )
+
+def _safe_import(module_name: str, class_name: str) -> type[ControlNetContextParallelismPlanner]:
+ try:
+ # e.g., module_name = ".cp_plan_zimage_controlnet", class_name = "ZImageControlNetContextParallelismPlanner"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorContextParallelismPlanner
-__all__ = [
- "ContextParallelismPlanner",
- "ContextParallelismPlannerRegister",
- "FluxContextParallelismPlanner",
- "QwenImageContextParallelismPlanner",
- "WanContextParallelismPlanner",
- "WanVACEContextParallelismPlanner",
- "LTXVideoContextParallelismPlanner",
- "HunyuanImageContextParallelismPlanner",
- "HunyuanVideoContextParallelismPlanner",
- "CogVideoXContextParallelismPlanner",
- "CogView3PlusContextParallelismPlanner",
- "CogView4ContextParallelismPlanner",
- "CosisIDContextParallelismPlanner",
- "ChromaContextParallelismPlanner",
- "PixArtContextParallelismPlanner",
- "DiTContextParallelismPlanner",
- "Kandinsky5ContextParallelismPlanner",
- "SkyReelsV2ContextParallelismPlanner",
-]
-if _nunchaku_available:
- __all__.extend(
- [
- "NunchakuFluxContextParallelismPlanner",
- "NunchakuQwenImageContextParallelismPlanner",
- ]
+def _activate_controlnet_cp_planners():
+ """Function to register all built-in context parallelism planners."""
+ ZImageControlNetContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_zimage_controlnet", "ZImageControlNetContextParallelismPlanner"
)
+
+
+__all__ = ["_activate_controlnet_cp_planners"]
diff --git a/src/cache_dit/parallelism/controlnets/dispatch.py b/src/cache_dit/parallelism/controlnets/dispatch.py
new file mode 100644
index 000000000..1ef24dce2
--- /dev/null
+++ b/src/cache_dit/parallelism/controlnets/dispatch.py
@@ -0,0 +1,50 @@
+import torch
+
+from typing import Optional
+from cache_dit.logger import init_logger
+
+from diffusers.models.modeling_utils import ModelMixin
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.parallelism.config import ParallelismConfig
+from .context_parallelism import maybe_enable_context_parallelism
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism_for_controlnet(
+ controlnet: torch.nn.Module | ModelMixin,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(controlnet, (torch.nn.Module, ModelMixin)), (
+ "controlnet must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(controlnet)}"
+ )
+
+ if parallelism_config is None:
+ return controlnet
+
+ if parallelism_config.backend != ParallelismBackend.NATIVE_DIFFUSER:
+ logger.warning(
+ f"Parallelism backend {parallelism_config.backend} is not supported "
+ "for ControlNet now, skip context parallelism for ControlNet."
+ )
+ return controlnet
+
+ if parallelism_config.ulysses_size is not None or parallelism_config.ring_size is not None:
+ controlnet = maybe_enable_context_parallelism(
+ controlnet=controlnet,
+ parallelism_config=parallelism_config,
+ )
+ controlnet._is_parallelized = True # type: ignore[attr-defined]
+ # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
+ controlnet._parallelism_config = parallelism_config # type: ignore[attr-defined]
+ logger.info(
+ f"Parallelize ControlNet: {controlnet.__class__.__name__}, "
+ f"id:{id(controlnet)}, {parallelism_config.strify(True)}"
+ )
+ else:
+ logger.warning(
+ "Please set ulysses_size or ring_size in parallelism_config to enable "
+ "context parallelism for ControlNet. Skipping parallelism for ControlNet."
+ )
+ return controlnet
diff --git a/src/cache_dit/parallelism/dispatch.py b/src/cache_dit/parallelism/dispatch.py
new file mode 100644
index 000000000..8dc98a3ad
--- /dev/null
+++ b/src/cache_dit/parallelism/dispatch.py
@@ -0,0 +1,175 @@
+import torch
+from diffusers.models.modeling_utils import ModelMixin
+from .backend import ParallelismBackend
+from .config import ParallelismConfig
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.logger import init_logger
+from cache_dit.envs import ENV
+
+
+logger = init_logger(__name__)
+
+
+def enable_parallelism(
+ transformer: torch.nn.Module | ModelMixin,
+ parallelism_config: ParallelismConfig,
+) -> torch.nn.Module:
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
+ )
+ if getattr(transformer, "_is_parallelized", False):
+ logger.warning("The transformer is already parallelized. Skipping parallelism enabling.")
+ return transformer
+
+ # Parallelize Transformer: The check of parallelism backend is only for transformer
+ # here. Text Encoder and VAE does not have different parallelism backends now.
+ from .transformers import maybe_enable_parallelism_for_transformer
+
+ transformer = maybe_enable_parallelism_for_transformer(
+ transformer=transformer,
+ parallelism_config=parallelism_config,
+ )
+ # Set attention backend for both context parallelism and tensor parallelism if the
+ # transformer is from diffusers and supports setting attention backend.
+ _maybe_set_module_attention_backend(
+ module=transformer,
+ parallelism_config=parallelism_config,
+ )
+
+ # Check text encoder and VAE for extra parallel modules
+ extra_parallel_modules: list[torch.nn.Module] = []
+ if parallelism_config.parallel_kwargs is not None:
+ extra_parallel_modules = parallelism_config.parallel_kwargs.get(
+ "extra_parallel_modules", []
+ )
+
+ if extra_parallel_modules:
+ for module in extra_parallel_modules:
+ # Enable parallelism for text encoder
+ if _is_text_encoder(module) and not _is_parallelized(module):
+ from .text_encoders import (
+ maybe_enable_parallelism_for_text_encoder,
+ )
+
+ maybe_enable_parallelism_for_text_encoder(
+ text_encoder=module,
+ parallelism_config=parallelism_config,
+ )
+ # Enable parallelism for ControlNet
+ elif _is_controlnet(module) and not _is_parallelized(module):
+ from .controlnets import (
+ maybe_enable_parallelism_for_controlnet,
+ )
+
+ maybe_enable_parallelism_for_controlnet(
+ controlnet=module,
+ parallelism_config=parallelism_config,
+ )
+ _maybe_set_module_attention_backend(
+ module=module,
+ parallelism_config=parallelism_config,
+ )
+ # Enable parallelism for VAE
+ elif _is_auto_encoder(module) and not _is_parallelized(module):
+ from .autoencoders import (
+ maybe_enable_parallelism_for_auto_encoder,
+ )
+
+ maybe_enable_parallelism_for_auto_encoder(
+ auto_encoder=module,
+ parallelism_config=parallelism_config,
+ )
+
+ transformer._extra_parallel_modules = extra_parallel_modules # type: ignore[attr-defined]
+ # NOTE: Workaround for potential memory peak issue after parallelism
+ # enabling, specially for tensor parallelism in native pytorch backend.
+ maybe_empty_cache()
+
+ return transformer
+
+
+def remove_parallelism_stats(
+ module: torch.nn.Module,
+) -> torch.nn.Module:
+
+ if not getattr(module, "_is_parallelized", False):
+ return module
+
+ def _remove_parallel_stats(module: torch.nn.Module) -> None:
+ if hasattr(module, "_is_parallelized"):
+ del module._is_parallelized
+ if hasattr(module, "_parallelism_config"):
+ del module._parallelism_config
+
+ _remove_parallel_stats(module)
+
+ # remove parallelism stats for extra parallel modules
+ if not hasattr(module, "_extra_parallel_modules"):
+ return module
+
+ extra_parallel_modules = getattr(module, "_extra_parallel_modules", [])
+ for extra_module in extra_parallel_modules:
+ _remove_parallel_stats(extra_module)
+
+ del module._extra_parallel_modules # type: ignore[attr-defined]
+ return module
+
+
+# Some helper functions for parallelism enabling
+def _maybe_set_module_attention_backend(
+ module: torch.nn.Module | ModelMixin,
+ parallelism_config: ParallelismConfig,
+) -> None:
+ # Set attention backend for both context parallelism and tensor parallelism if the
+ # transformer is from diffusers and supports setting attention backend.
+ module_cls_name = module.__class__.__name__
+ if hasattr(module, "set_attention_backend") and isinstance(module, ModelMixin):
+ attention_backend = parallelism_config.parallel_kwargs.get("attention_backend", None)
+ # native, _native_cudnn, flash, etc.
+ if attention_backend is None:
+ # Default to native for context parallelism due to:
+ # - attn mask support (re-registered in cache-dit)
+ # - general compatibility with various models
+ # NOTE: We only set default attention backend for NATIVE_DIFFUSER backend here
+ # while using context parallelism. For other backends, we do not change the
+ # attention backend if it is None.
+ if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+ module.set_attention_backend("native")
+ logger.warning(
+ "attention_backend is None, set default attention backend of "
+ f"{module_cls_name} to native for context parallelism."
+ )
+ else:
+ # Ensure custom attention backends are registered in cache-dit.
+ if not ENV.CACHE_DIT_ENABLE_CUSTOM_ATTN_ALREADY_DISPATCH:
+ from .attention import (
+ _maybe_register_custom_attn_backends,
+ )
+
+ _maybe_register_custom_attn_backends()
+
+ module.set_attention_backend(attention_backend)
+ logger.info(
+ "Found attention_backend from config, set attention backend of "
+ f"{module_cls_name} to: {attention_backend}."
+ )
+
+
+def _is_text_encoder(module: torch.nn.Module) -> bool:
+ _import_module = module.__class__.__module__
+ return _import_module.startswith("transformers")
+
+
+def _is_controlnet(module: torch.nn.Module) -> bool:
+ _import_module = module.__class__.__module__
+ return _import_module.startswith("diffusers.models.controlnet")
+
+
+def _is_auto_encoder(module: torch.nn.Module) -> bool:
+ _import_module = module.__class__.__module__
+ return _import_module.startswith("diffusers.models.autoencoder")
+
+
+def _is_parallelized(module: torch.nn.Module) -> bool:
+ return getattr(module, "_is_parallelized", False)
diff --git a/src/cache_dit/parallelism/parallel_config.py b/src/cache_dit/parallelism/parallel_config.py
deleted file mode 100644
index d1ab09ce4..000000000
--- a/src/cache_dit/parallelism/parallel_config.py
+++ /dev/null
@@ -1,86 +0,0 @@
-import dataclasses
-from typing import Optional, Dict, Any
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-@dataclasses.dataclass
-class ParallelismConfig:
- # Parallelism backend, defaults to NATIVE_DIFFUSER
- backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER
- # Context parallelism config
- # ulysses_size (`int`, *optional*):
- # The degree of ulysses parallelism.
- ulysses_size: int = None
- # ring_size (`int`, *optional*):
- # The degree of ring parallelism.
- ring_size: int = None
- # Tensor parallelism config
- # tp_size (`int`, *optional*):
- # The degree of tensor parallelism.
- tp_size: int = None
- # parallel_kwargs (`dict`, *optional*):
- # Additional kwargs for parallelism backends. For example, for
- # NATIVE_DIFFUSER backend, it can include `cp_plan` and
- # `attention_backend` arguments for `Context Parallelism`.
- parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict)
-
- def __post_init__(self):
- assert ParallelismBackend.is_supported(self.backend), (
- f"Parallel backend {self.backend} is not supported. "
- f"Please make sure the required packages are installed."
- )
-
- # Validate the parallelism configuration and auto adjust the backend if needed
- if self.tp_size is not None and self.tp_size > 1:
- assert (
- self.ulysses_size is None or self.ulysses_size == 1
- ), "Tensor parallelism plus Ulysses parallelism is not supported right now."
- assert (
- self.ring_size is None or self.ring_size == 1
- ), "Tensor parallelism plus Ring parallelism is not supported right now."
- if self.backend != ParallelismBackend.NATIVE_PYTORCH:
- logger.warning(
- "Tensor parallelism is only supported for NATIVE_PYTORCH backend "
- "right now. Force set backend to NATIVE_PYTORCH."
- )
- self.backend = ParallelismBackend.NATIVE_PYTORCH
- elif (
- self.ulysses_size is not None
- and self.ulysses_size > 1
- and self.ring_size is not None
- and self.ring_size > 1
- ):
- raise ValueError(
- "Ulysses parallelism plus Ring parallelism is not fully supported right now."
- )
- else:
- if (self.ulysses_size is not None and self.ulysses_size > 1) or (
- self.ring_size is not None and self.ring_size > 1
- ):
- if self.backend != ParallelismBackend.NATIVE_DIFFUSER:
- logger.warning(
- "Ulysses/Ring parallelism is only supported for NATIVE_DIFFUSER "
- "backend right now. Force set backend to NATIVE_DIFFUSER."
- )
- self.backend = ParallelismBackend.NATIVE_DIFFUSER
-
- def strify(self, details: bool = False) -> str:
- if details:
- return (
- f"ParallelismConfig(backend={self.backend}, "
- f"ulysses_size={self.ulysses_size}, "
- f"ring_size={self.ring_size}, "
- f"tp_size={self.tp_size})"
- )
- else:
- parallel_str = ""
- if self.ulysses_size is not None:
- parallel_str += f"Ulysses{self.ulysses_size}"
- if self.ring_size is not None:
- parallel_str += f"Ring{self.ring_size}"
- if self.tp_size is not None:
- parallel_str += f"TP{self.tp_size}"
- return parallel_str
diff --git a/src/cache_dit/parallelism/parallel_interface.py b/src/cache_dit/parallelism/parallel_interface.py
deleted file mode 100644
index e371c53d4..000000000
--- a/src/cache_dit/parallelism/parallel_interface.py
+++ /dev/null
@@ -1,147 +0,0 @@
-import torch
-import torch.distributed as dist
-from typing import Union, Optional
-from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
-from cache_dit.utils import maybe_empty_cache
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def enable_parallelism(
- transformer: torch.nn.Module,
- parallelism_config: ParallelismConfig,
-) -> torch.nn.Module:
- assert isinstance(transformer, torch.nn.Module), (
- "transformer must be an instance of torch.nn.Module, " f"but got {type(transformer)}"
- )
- if getattr(transformer, "_is_parallelized", False):
- logger.warning("The transformer is already parallelized. " "Skipping parallelism enabling.")
- return transformer
-
- if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
- from cache_dit.parallelism.backends.native_diffusers import (
- maybe_enable_parallelism,
- )
-
- transformer = maybe_enable_parallelism(
- transformer,
- parallelism_config,
- )
- elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
- from cache_dit.parallelism.backends.native_pytorch import (
- maybe_enable_parallelism,
- )
-
- transformer = maybe_enable_parallelism(
- transformer,
- parallelism_config,
- )
- else:
- raise ValueError(f"Parallel backend {parallelism_config.backend} is not supported yet.")
-
- transformer._is_parallelized = True # type: ignore[attr-defined]
- # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
- transformer._parallelism_config = parallelism_config # type: ignore[attr-defined]
- logger.info(
- f"Enabled parallelism: {parallelism_config.strify(True)}, "
- f"transformer id:{id(transformer)}"
- )
-
- # NOTE: Workaround for potential memory peak issue after parallelism
- # enabling, specially for tensor parallelism in native pytorch backend.
- maybe_empty_cache()
-
- return transformer
-
-
-def remove_parallelism_stats(
- transformer: torch.nn.Module,
-) -> torch.nn.Module:
- if not getattr(transformer, "_is_parallelized", False):
- logger.warning("The transformer is not parallelized. " "Skipping removing parallelism.")
- return transformer
-
- if hasattr(transformer, "_is_parallelized"):
- del transformer._is_parallelized # type: ignore[attr-defined]
- if hasattr(transformer, "_parallelism_config"):
- del transformer._parallelism_config # type: ignore[attr-defined]
- return transformer
-
-
-def maybe_pad_prompt(
- tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
- prompt: str,
- extra_prompt: Optional[str] = None, # e.g., negative prompt
- num_parition: Optional[int] = None, # e.g., dist.get_world_size()
- pad_token: Optional[str] = None, # e.g., default tokenizer.pad_token
- num_extra_tokens: Optional[int] = 0, # e.g., negative prompt tokens length
- verbose: bool = True,
-) -> str:
- """Pad the prompt to make sure the number of tokens is divisible by num_partition."""
- assert isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)), (
- f"tokenizer must be an instance of PreTrainedTokenizer or PreTrainedTokenizerFast, "
- f"but got {type(tokenizer)}"
- )
- inputs_ids = tokenizer(prompt, return_tensors="pt")
-
- if num_parition is None:
- if dist.is_initialized():
- num_parition = dist.get_world_size()
- else:
- num_parition = 1
-
- if num_parition <= 1:
- return prompt
-
- if pad_token is None:
- pad_token = tokenizer.pad_token
- if pad_token is None:
- pad_token = tokenizer.eos_token
- if pad_token is None:
- pad_token = " "
- logger.warning(
- "pad_token and eos_token are not set in the tokenizer. "
- "Using space ' ' as the pad_token."
- )
-
- seq_len = inputs_ids.input_ids.shape[1] # [batch_size, seq_len]
-
- # Add extra tokens length, e.g., negative prompt tokens length
- partition_seq_len = seq_len
- partition_seq_len += num_extra_tokens
- if extra_prompt is not None:
- extra_inputs_ids = tokenizer(extra_prompt, return_tensors="pt")
- partition_seq_len += extra_inputs_ids.input_ids.shape[1]
- num_extra_tokens += extra_inputs_ids.input_ids.shape[1]
-
- if partition_seq_len % num_parition != 0:
- pad_len = num_parition - (partition_seq_len % num_parition)
- if verbose:
- logger.info(
- f"Padding the prompt from seq_len {seq_len} to "
- f"{seq_len + pad_len} to make {seq_len + pad_len} + "
- f"{num_extra_tokens} = {seq_len + pad_len + num_extra_tokens} "
- f"divisible by num_partition {num_parition}."
- )
- pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
- assert isinstance(pad_token_id, int), f"pad_token {pad_token} has more than one token."
-
- pad_ids = torch.full(
- (1, pad_len),
- pad_token_id,
- dtype=inputs_ids.input_ids.dtype,
- )
- inputs_ids.input_ids = torch.cat([inputs_ids.input_ids, pad_ids], dim=1)
-
- prompt = tokenizer.decode(inputs_ids.input_ids[0])
-
- new_seq_len = tokenizer(prompt, return_tensors="pt").input_ids.shape[1]
- new_partition_seq_len = new_seq_len + num_extra_tokens
- assert new_partition_seq_len % num_parition == 0, (
- f"Failed to pad the prompt to make it divisible by num_partition {num_parition}. "
- f"Got new_seq_len {new_seq_len}, new_partition_seq_len {new_partition_seq_len}."
- )
- return prompt
diff --git a/src/cache_dit/parallelism/text_encoders/__init__.py b/src/cache_dit/parallelism/text_encoders/__init__.py
new file mode 100644
index 000000000..24763b884
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/__init__.py
@@ -0,0 +1 @@
+from .dispatch import maybe_enable_parallelism_for_text_encoder # noqa: F401
diff --git a/src/cache_dit/parallelism/text_encoders/dispatch.py b/src/cache_dit/parallelism/text_encoders/dispatch.py
new file mode 100644
index 000000000..1bb83896c
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/dispatch.py
@@ -0,0 +1,41 @@
+from typing import Optional
+
+import torch
+
+from cache_dit.parallelism.config import ParallelismConfig
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism_for_text_encoder(
+ text_encoder: torch.nn.Module,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, torch.nn.Module
+ ), f"text_encoder must be an instance of torch.nn.Module, but got {type(text_encoder)}"
+ if getattr(text_encoder, "_is_parallelized", False):
+ logger.warning("The text encoder is already parallelized. Skipping parallelism enabling.")
+ return text_encoder
+
+ if parallelism_config is None:
+ return text_encoder
+
+ from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+ text_encoder = maybe_enable_tensor_parallelism(
+ text_encoder=text_encoder,
+ parallelism_config=parallelism_config,
+ )
+
+ text_encoder._is_parallelized = True # type: ignore[attr-defined]
+ text_encoder._parallelism_config = parallelism_config # type: ignore[attr-defined]
+
+ logger.info(
+ f"Parallelize Text Encoder: {text_encoder.__class__.__name__}, "
+ f"id:{id(text_encoder)}, {parallelism_config.strify(True, text_encoder=True)}"
+ )
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/__init__.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/__init__.py
new file mode 100644
index 000000000..8a313daf1
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/__init__.py
@@ -0,0 +1,47 @@
+try:
+ import einops # noqa: F401
+except ImportError:
+ raise ImportError(
+ "parallelism functionality requires the 'parallelism' extra dependencies. "
+ "Install with:\npip install cache-dit[parallelism]"
+ )
+
+import torch
+from typing import Optional
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+try:
+ from .tp_plan_registers import TextEncoderTensorParallelismPlannerRegister
+ from .tp_planners import _activate_text_encoder_tp_planners
+
+ _activate_text_encoder_tp_planners()
+except ImportError as e:
+    raise
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_tensor_parallelism(
+ text_encoder: torch.nn.Module,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, torch.nn.Module
+ ), f"text_encoder must be an instance of torch.nn.Module, but got {type(text_encoder)}"
+
+ if parallelism_config is None:
+ return text_encoder
+
+ # We don't check backend here because text encoder may use different
+ # parallelism backend with transformer.
+
+ extra_parallel_kwargs = {}
+ if parallelism_config.parallel_kwargs is not None:
+ extra_parallel_kwargs = parallelism_config.parallel_kwargs
+
+ return TextEncoderTensorParallelismPlannerRegister.get_planner(text_encoder)().apply(
+ text_encoder=text_encoder,
+ parallelism_config=parallelism_config,
+ **extra_parallel_kwargs,
+ )
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_gemma.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_gemma.py
new file mode 100644
index 000000000..5709f849a
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_gemma.py
@@ -0,0 +1,181 @@
+import torch
+from typing import Union
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
+from transformers import (
+ GemmaModel,
+ Gemma2Model,
+ Gemma3Model,
+ GemmaForCausalLM,
+ Gemma2ForCausalLM,
+ Gemma3ForCausalLM,
+ Gemma3ForConditionalGeneration,
+)
+from transformers.models.gemma.modeling_gemma import GemmaDecoderLayer
+from transformers.models.gemma2.modeling_gemma2 import Gemma2DecoderLayer
+from transformers.models.gemma3.modeling_gemma3 import Gemma3DecoderLayer
+from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoderLayer
+
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+_supported_gemma_classes = (
+ T5GemmaEncoder,
+ GemmaModel,
+ Gemma2Model,
+ Gemma3Model,
+ GemmaForCausalLM,
+ Gemma2ForCausalLM,
+ Gemma3ForCausalLM,
+ Gemma3ForConditionalGeneration,
+)
+
+
+# Text Encoder for Lumina-Image, PRX series models.
+@TextEncoderTensorParallelismPlannerRegister.register("T5GemmaEncoder")
+@TextEncoderTensorParallelismPlannerRegister.register("GemmaModel")
+@TextEncoderTensorParallelismPlannerRegister.register("Gemma2Model")
+@TextEncoderTensorParallelismPlannerRegister.register("Gemma3Model")
+@TextEncoderTensorParallelismPlannerRegister.register("GemmaForCausalLM")
+@TextEncoderTensorParallelismPlannerRegister.register("Gemma2ForCausalLM")
+@TextEncoderTensorParallelismPlannerRegister.register("Gemma3ForCausalLM")
+@TextEncoderTensorParallelismPlannerRegister.register("Gemma3ForConditionalGeneration")
+class GemmaTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, _supported_gemma_classes
+ ), "GemmaTensorParallelismPlanner can only be applied to Gemma Language Models."
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[
+ T5GemmaEncoder,
+ GemmaModel,
+ Gemma2Model,
+ Gemma3Model,
+ GemmaForCausalLM,
+ Gemma2ForCausalLM,
+ Gemma3ForCausalLM,
+ Gemma3ForConditionalGeneration,
+ ],
+ tp_mesh: DeviceMesh,
+ ):
+
+ # NOTE: Gemma3 can be used as a multi-modal backbone. In those cases the actual
+ # language model is nested under `language_model` (and sometimes `language_model.model`).
+ # We need to unwrap to the module that has `layers`.
+ if isinstance(text_encoder, Gemma3ForConditionalGeneration) and hasattr(
+ text_encoder, "language_model"
+ ):
+ model_container = getattr(text_encoder, "language_model")
+ model = getattr(model_container, "model", model_container)
+ elif isinstance(
+ text_encoder,
+ (
+ GemmaForCausalLM,
+ Gemma2ForCausalLM,
+ Gemma3ForCausalLM,
+ ),
+ ):
+ model = text_encoder.model
+ else:
+ model = text_encoder
+ if not hasattr(model, "layers") and hasattr(model, "language_model"):
+ model_container = getattr(model, "language_model")
+ model = getattr(model_container, "model", model_container)
+
+ if not hasattr(model, "layers"):
+ raise AttributeError(
+ f"{model.__class__.__name__} object has no attribute 'layers'. "
+ "If this is a multi-modal Gemma3 model, expected the language model to be "
+ "under `language_model` (and optionally `language_model.model`)."
+ )
+
+ assert isinstance(model, torch.nn.Module), "model must be a torch.nn.Module"
+ for _, block in model.layers.named_children():
+ assert isinstance(
+ block,
+ (
+ GemmaDecoderLayer,
+ Gemma2DecoderLayer,
+ Gemma3DecoderLayer,
+ T5GemmaEncoderLayer,
+ ),
+ ), (
+ f"Unsupported layer type {block.__class__.__name__} for Gemma TP. "
+ "Expected a GemmaDecoderLayer/Gemma2DecoderLayer/Gemma3DecoderLayer/T5GemmaEncoderLayer."
+ )
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_proj": ColwiseParallel(),
+ "mlp.up_proj": ColwiseParallel(),
+ "mlp.down_proj": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(
+ text_encoder,
+ (
+ GemmaForCausalLM,
+ Gemma2ForCausalLM,
+ Gemma3ForCausalLM,
+ Gemma3ForConditionalGeneration,
+ ),
+ ):
+ # NOTE: Gemma3ForConditionalGeneration may store the LM under `language_model`.
+ if isinstance(text_encoder, Gemma3ForConditionalGeneration) and hasattr(
+ text_encoder, "language_model"
+ ):
+ language_model = getattr(text_encoder, "language_model")
+ if hasattr(language_model, "model"):
+ language_model.model = model
+ else:
+ text_encoder.language_model = model
+ else:
+ text_encoder.model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_glm.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_glm.py
new file mode 100644
index 000000000..4e4bab3d9
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_glm.py
@@ -0,0 +1,101 @@
+import torch
+from typing import Union
+from transformers import GlmModel, GlmForCausalLM, Glm4Model, Glm4ForCausalLM
+from transformers.models.glm.modeling_glm import GlmDecoderLayer
+from transformers.models.glm4.modeling_glm4 import Glm4DecoderLayer
+
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor import Replicate
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+_supported_glm_classes = (
+ GlmModel,
+ GlmForCausalLM,
+ Glm4Model,
+ Glm4ForCausalLM,
+)
+
+
+# Text Encoder for CogView4 series models.
+@TextEncoderTensorParallelismPlannerRegister.register("GlmModel")
+@TextEncoderTensorParallelismPlannerRegister.register("Glm4Model")
+@TextEncoderTensorParallelismPlannerRegister.register("GlmForCausalLM")
+@TextEncoderTensorParallelismPlannerRegister.register("Glm4ForCausalLM")
+class GlmTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, _supported_glm_classes
+ ), "GlmTensorParallelismPlanner can only be applied to Glm Language Models."
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[GlmModel, GlmForCausalLM, Glm4Model, Glm4ForCausalLM],
+ tp_mesh: DeviceMesh,
+ ):
+
+ if isinstance(text_encoder, (GlmForCausalLM, Glm4ForCausalLM)):
+ model = text_encoder.model
+ else:
+ model = text_encoder
+
+        assert isinstance(model, (GlmModel, Glm4Model)), "model must be an instance of GlmModel or Glm4Model."
+ for _, block in model.layers.named_children():
+ assert isinstance(block, (GlmDecoderLayer, Glm4DecoderLayer))
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_up_proj": ColwiseParallel(output_layouts=Replicate()),
+ "mlp.down_proj": RowwiseParallel(output_layouts=Replicate()),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(text_encoder, (GlmForCausalLM, Glm4ForCausalLM)):
+ text_encoder.model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_llama.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_llama.py
new file mode 100644
index 000000000..4afdebb28
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_llama.py
@@ -0,0 +1,89 @@
+import torch
+from typing import Union
+from transformers import LlamaModel, LlamaForCausalLM
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+# Text Encoder for HunyuanVideo series models.
+@TextEncoderTensorParallelismPlannerRegister.register("LlamaModel")
+@TextEncoderTensorParallelismPlannerRegister.register("LlamaForCausalLM")
+class LlamaTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, (LlamaModel, LlamaForCausalLM)
+        ), "LlamaTensorParallelismPlanner can only be applied to Llama Language Models."
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[LlamaModel, LlamaForCausalLM],
+ tp_mesh: DeviceMesh,
+ ):
+ from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+
+ if isinstance(text_encoder, LlamaForCausalLM):
+ model = text_encoder.model
+ else:
+ model = text_encoder
+
+ assert isinstance(model, LlamaModel), "model must be an instance of LlamaModel."
+ for _, block in model.layers.named_children():
+ assert isinstance(block, LlamaDecoderLayer)
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_proj": ColwiseParallel(),
+ "mlp.up_proj": ColwiseParallel(),
+ "mlp.down_proj": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(text_encoder, LlamaForCausalLM):
+ text_encoder.model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_mistral.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_mistral.py
new file mode 100644
index 000000000..1a27735a9
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_mistral.py
@@ -0,0 +1,130 @@
+import torch
+from typing import Union
+from transformers import (
+ MistralModel,
+ Mistral3Model,
+ MistralForCausalLM,
+ Mistral3ForConditionalGeneration,
+)
+from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+_supported_mistral_classes = (
+ MistralModel,
+ Mistral3Model,
+ MistralForCausalLM,
+ Mistral3ForConditionalGeneration,
+)
+
+
+# Text Encoder for FLUX.2 series models.
+@TextEncoderTensorParallelismPlannerRegister.register("MistralModel")
+@TextEncoderTensorParallelismPlannerRegister.register("Mistral3Model")
+@TextEncoderTensorParallelismPlannerRegister.register("MistralForCausalLM")
+@TextEncoderTensorParallelismPlannerRegister.register("Mistral3ForConditionalGeneration")
+class MistralTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, _supported_mistral_classes
+ ), "MistralTensorParallelismPlanner can only be applied to Mistral Language Models."
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[
+ MistralModel,
+ Mistral3Model,
+ MistralForCausalLM,
+ Mistral3ForConditionalGeneration,
+ ],
+ tp_mesh: DeviceMesh,
+ ):
+
+ if isinstance(
+ text_encoder,
+ (
+ Mistral3Model,
+ MistralForCausalLM,
+ Mistral3ForConditionalGeneration,
+ ),
+ ):
+ if isinstance(text_encoder, MistralForCausalLM):
+ model = text_encoder.model
+ else:
+ # Mistral3ForConditionalGeneration, Mistral3Model
+ model = text_encoder.language_model
+ else:
+ model = text_encoder
+
+ for _, block in model.layers.named_children():
+ assert isinstance(block, MistralDecoderLayer)
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_proj": ColwiseParallel(),
+ "mlp.up_proj": ColwiseParallel(),
+ "mlp.down_proj": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(
+ text_encoder,
+ (
+ Mistral3Model,
+ MistralForCausalLM,
+ Mistral3ForConditionalGeneration,
+ ),
+ ):
+ if isinstance(text_encoder, MistralForCausalLM):
+ text_encoder.model = model
+ else:
+ # Mistral3ForConditionalGeneration, Mistral3Model
+ text_encoder.language_model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen2_5.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen2_5.py
new file mode 100644
index 000000000..1e50eaa7e
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen2_5.py
@@ -0,0 +1,92 @@
+import torch
+from typing import Union
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLTextModel
+from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLDecoderLayer
+
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+# Text Encoder for Qwen-Image, HunyuanImage-2.1, HunyuanVideo-1.5, Kandinsky-5 series models.
+@TextEncoderTensorParallelismPlannerRegister.register("Qwen2_5_VLTextModel")
+@TextEncoderTensorParallelismPlannerRegister.register("Qwen2_5_VLForConditionalGeneration")
+class Qwen2_5_VLTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, (Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLTextModel)
+ ), (
+ "Qwen2_5_VLTensorParallelismPlanner can only be applied to "
+ "Qwen2_5_VLForConditionalGeneration or Qwen2_5_VLTextModel"
+ )
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLTextModel],
+ tp_mesh: DeviceMesh,
+ ):
+
+ if isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration):
+ model = text_encoder.model.language_model
+ else:
+ model = text_encoder
+
+ for _, block in model.layers.named_children():
+ assert isinstance(block, Qwen2_5_VLDecoderLayer)
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_proj": ColwiseParallel(),
+ "mlp.up_proj": ColwiseParallel(),
+ "mlp.down_proj": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(text_encoder, Qwen2_5_VLForConditionalGeneration):
+ text_encoder.model.language_model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen3.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen3.py
new file mode 100644
index 000000000..fbae0d396
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_qwen3.py
@@ -0,0 +1,89 @@
+import torch
+from typing import Union
+from transformers import Qwen3Model, Qwen3ForCausalLM
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+# Text Encoder for Z-Image, Ovis-Image
+@TextEncoderTensorParallelismPlannerRegister.register("Qwen3Model")
+@TextEncoderTensorParallelismPlannerRegister.register("Qwen3ForCausalLM")
+class Qwen3TensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, (Qwen3Model, Qwen3ForCausalLM)
+ ), "Qwen3TensorParallelismPlanner can only be applied to Qwen3 Language Models."
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: Union[Qwen3Model, Qwen3ForCausalLM],
+ tp_mesh: DeviceMesh,
+ ):
+ from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer
+
+ if isinstance(text_encoder, Qwen3ForCausalLM):
+ model = text_encoder.model
+ else:
+ model = text_encoder
+
+ assert isinstance(model, Qwen3Model), "model must be an instance of Qwen3Model."
+ for _, block in model.layers.named_children():
+ assert isinstance(block, Qwen3DecoderLayer)
+ layer_plan = {
+ "self_attn.q_proj": ColwiseParallel(),
+ "self_attn.k_proj": ColwiseParallel(),
+ "self_attn.v_proj": ColwiseParallel(),
+ "self_attn.o_proj": RowwiseParallel(),
+ "mlp.gate_proj": ColwiseParallel(),
+ "mlp.up_proj": ColwiseParallel(),
+ "mlp.down_proj": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ if isinstance(text_encoder, Qwen3ForCausalLM):
+ text_encoder.model = model
+ else:
+ text_encoder = model
+
+ maybe_empty_cache()
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_registers.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_registers.py
new file mode 100644
index 000000000..bc4eed0e4
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_registers.py
@@ -0,0 +1,61 @@
+import torch
+import logging
+from abc import abstractmethod
+from typing import Dict
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class TextEncoderTensorParallelismPlanner:
+
+ @abstractmethod
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ raise NotImplementedError("apply method must be implemented by subclasses")
+
+
+class TextEncoderTensorParallelismPlannerRegister:
+    _text_encoder_tp_planner_registry: Dict[str, type[TextEncoderTensorParallelismPlanner]] = {}
+
+ @classmethod
+ def register(cls, name: str):
+ def decorator(planner_cls: type[TextEncoderTensorParallelismPlanner]):
+ assert (
+ name not in cls._text_encoder_tp_planner_registry
+ ), f"TextEncoderTensorParallelismPlanner with name {name} is already registered."
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug(f"Registering TextEncoderTensorParallelismPlanner: {name}")
+ cls._text_encoder_tp_planner_registry[name] = planner_cls
+ return planner_cls
+
+ return decorator
+
+ @classmethod
+ def get_planner(
+ cls, text_encoder: str | torch.nn.Module
+ ) -> type[TextEncoderTensorParallelismPlanner]:
+ if isinstance(text_encoder, torch.nn.Module):
+ name = text_encoder.__class__.__name__
+ else:
+ name = text_encoder
+ planner_cls = None
+ for planner_name in cls._text_encoder_tp_planner_registry:
+ if name.startswith(planner_name):
+ planner_cls = cls._text_encoder_tp_planner_registry.get(planner_name)
+ break
+ if planner_cls is None:
+ raise ValueError(f"No planner registered under name: {name}")
+ return planner_cls
+
+ @classmethod
+ def supported_planners(
+ cls,
+ ) -> tuple[int, list[str]]:
+ val_planners = cls._text_encoder_tp_planner_registry.keys()
+ return len(val_planners), [p for p in val_planners]
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_t5_encoder.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_t5_encoder.py
new file mode 100644
index 000000000..b2a3e2d90
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_t5_encoder.py
@@ -0,0 +1,96 @@
+import torch
+from transformers import T5EncoderModel
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+# Text Encoder for FLUX.1, Chroma1-HD, CogVideoX1.5, CogView3-Plus, VisualCloze,
+# HiDream, HunyuanImage 2.1, LTXVideo, mochi-preview, PixArt series models.
+@TextEncoderTensorParallelismPlannerRegister.register("T5EncoderModel")
+class T5EncoderTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, T5EncoderModel
+ ), "T5EncoderTensorParallelismPlanner can only be applied to T5EncoderModel"
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: T5EncoderModel,
+ tp_mesh: DeviceMesh,
+ ):
+ from transformers.models.t5.modeling_t5 import (
+ T5Block,
+ T5Attention,
+ T5DenseActDense,
+ T5DenseGatedActDense,
+ )
+
+ for i, block in enumerate(text_encoder.encoder.block):
+ assert isinstance(block, T5Block)
+ assert isinstance(block.layer[0].SelfAttention, T5Attention)
+ block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
+ block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
+ if isinstance(block.layer[1].DenseReluDense, T5DenseActDense):
+ layer_plan = {
+ "layer.0.SelfAttention.q": ColwiseParallel(),
+ "layer.0.SelfAttention.k": ColwiseParallel(),
+ "layer.0.SelfAttention.v": ColwiseParallel(),
+ "layer.0.SelfAttention.o": RowwiseParallel(),
+ "layer.1.DenseReluDense.wi": ColwiseParallel(),
+ "layer.1.DenseReluDense.wo": RowwiseParallel(),
+ }
+ elif isinstance(block.layer[1].DenseReluDense, T5DenseGatedActDense):
+ layer_plan = {
+ "layer.0.SelfAttention.q": ColwiseParallel(),
+ "layer.0.SelfAttention.k": ColwiseParallel(),
+ "layer.0.SelfAttention.v": ColwiseParallel(),
+ "layer.0.SelfAttention.o": RowwiseParallel(),
+ "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
+ "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
+ "layer.1.DenseReluDense.wo": RowwiseParallel(),
+ }
+ else:
+ raise NotImplementedError(
+ f"Unsupported feed-forward layer type: {type(block.layer[1].DenseReluDense)}"
+ )
+ if block.layer[0].SelfAttention.has_relative_attention_bias:
+ layer_plan["layer.0.SelfAttention.relative_attention_bias"] = ColwiseParallel()
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_umt5_encoder.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_umt5_encoder.py
new file mode 100644
index 000000000..fcba7b03a
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_plan_umt5_encoder.py
@@ -0,0 +1,97 @@
+import torch
+from transformers import UMT5EncoderModel
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TextEncoderTensorParallelismPlanner,
+ TextEncoderTensorParallelismPlannerRegister,
+)
+
+logger = init_logger(__name__)
+
+
+# Text Encoder for Wan2.1, Wan2.2, ChronoEdit, LongCat-Video, SkyReelsV2 series models.
+@TextEncoderTensorParallelismPlannerRegister.register("UMT5EncoderModel")
+class UMT5EncoderTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert isinstance(
+ text_encoder, UMT5EncoderModel
+ ), "UMT5EncoderTensorParallelismPlanner can only be applied to UMT5EncoderModel"
+ text_encoder_world_size = parallelism_config.text_encoder_world_size
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[text_encoder_world_size],
+ )
+
+ text_encoder = self.parallelize_text_encoder(
+ text_encoder=text_encoder,
+ tp_mesh=tp_mesh,
+ )
+
+ return text_encoder
+
+ def parallelize_text_encoder(
+ self,
+ text_encoder: UMT5EncoderModel,
+ tp_mesh: DeviceMesh,
+ ):
+ from transformers.models.umt5.modeling_umt5 import (
+ UMT5Block,
+ UMT5Attention,
+ UMT5DenseActDense,
+ UMT5DenseGatedActDense,
+ )
+
+ for i, block in enumerate(text_encoder.encoder.block):
+ assert isinstance(block, UMT5Block)
+ assert isinstance(block.layer[0].SelfAttention, UMT5Attention)
+ block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
+ block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
+ if isinstance(block.layer[1].DenseReluDense, UMT5DenseActDense):
+ layer_plan = {
+ "layer.0.SelfAttention.q": ColwiseParallel(),
+ "layer.0.SelfAttention.k": ColwiseParallel(),
+ "layer.0.SelfAttention.v": ColwiseParallel(),
+ "layer.0.SelfAttention.o": RowwiseParallel(),
+ "layer.1.DenseReluDense.wi": ColwiseParallel(),
+ "layer.1.DenseReluDense.wo": RowwiseParallel(),
+ }
+ elif isinstance(block.layer[1].DenseReluDense, UMT5DenseGatedActDense):
+ layer_plan = {
+ "layer.0.SelfAttention.q": ColwiseParallel(),
+ "layer.0.SelfAttention.k": ColwiseParallel(),
+ "layer.0.SelfAttention.v": ColwiseParallel(),
+ "layer.0.SelfAttention.o": RowwiseParallel(),
+ "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
+ "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
+ "layer.1.DenseReluDense.wo": RowwiseParallel(),
+ }
+ else:
+ raise NotImplementedError(
+ f"Unsupported feed-forward layer type: {type(block.layer[1].DenseReluDense)}"
+ )
+ # SelfAttention in UMT5Attention always has relative_attention_bias, nn.Embedding layer.
+ if block.layer[0].SelfAttention.has_relative_attention_bias:
+ layer_plan["layer.0.SelfAttention.relative_attention_bias"] = ColwiseParallel()
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ return text_encoder
diff --git a/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_planners.py b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_planners.py
new file mode 100644
index 000000000..1834ca4e9
--- /dev/null
+++ b/src/cache_dit/parallelism/text_encoders/tensor_parallelism/tp_planners.py
@@ -0,0 +1,60 @@
+import importlib
+from cache_dit.logger import init_logger
+from .tp_plan_registers import TextEncoderTensorParallelismPlanner
+
+logger = init_logger(__name__)
+
+
+class ImportErrorTextEncoderTensorParallelismPlanner(TextEncoderTensorParallelismPlanner):
+ def apply(
+ self,
+ text_encoder,
+ parallelism_config=None,
+ **kwargs,
+ ):
+ raise ImportError(
+ "This TextEncoderTensorParallelismPlanner's dependencies are missing. Please install latest transformers/diffusers."
+ )
+
+
+def _safe_import(module_name: str, class_name: str) -> type[TextEncoderTensorParallelismPlanner]:
+ try:
+ # e.g., module_name = ".tp_plan_t5_encoder", class_name = "T5EncoderTensorParallelismPlanner"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorTextEncoderTensorParallelismPlanner
+
+
+def _activate_text_encoder_tp_planners():
+ """Function to register all built-in tensor parallelism planners."""
+ T5EncoderTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_t5_encoder", "T5EncoderTensorParallelismPlanner"
+ )
+ UMT5EncoderTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_umt5_encoder", "UMT5EncoderTensorParallelismPlanner"
+ )
+ MistralTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_mistral", "MistralTensorParallelismPlanner"
+ )
+ Qwen2_5_VLTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_qwen2_5", "Qwen2_5_VLTensorParallelismPlanner"
+ )
+ Qwen3TensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_qwen3", "Qwen3TensorParallelismPlanner"
+ )
+ LlamaTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_llama", "LlamaTensorParallelismPlanner"
+ )
+ GemmaTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_gemma", "GemmaTensorParallelismPlanner"
+ )
+ GlmTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_glm", "GlmTensorParallelismPlanner"
+ )
+
+
+__all__ = ["_activate_text_encoder_tp_planners"]
diff --git a/src/cache_dit/parallelism/transformers/__init__.py b/src/cache_dit/parallelism/transformers/__init__.py
new file mode 100644
index 000000000..729b48b5a
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/__init__.py
@@ -0,0 +1 @@
+from .dispatch import maybe_enable_parallelism_for_transformer
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py b/src/cache_dit/parallelism/transformers/context_parallelism/__init__.py
similarity index 59%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/__init__.py
index e5176e4ac..c05f1f45d 100644
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/__init__.py
@@ -2,31 +2,37 @@
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.parallelism.config import ParallelismConfig
from cache_dit.logger import init_logger
-from ..utils import (
- native_diffusers_parallelism_available,
- ContextParallelConfig,
-)
-from .attention import maybe_resigter_native_attention_backend
-from .attention import enable_ulysses_anything
-from .cp_planners import *
try:
- maybe_resigter_native_attention_backend()
+ from diffusers import ContextParallelConfig # noqa: F401
+ from cache_dit.parallelism.attention import (
+ _maybe_register_custom_attn_backends,
+ _is_diffusers_parallelism_available,
+ enable_ulysses_anything,
+ enable_ulysses_float8,
+ )
+ from .cp_plan_registers import ContextParallelismPlannerRegister
+ from .cp_planners import _activate_cp_planners
+
+ _maybe_register_custom_attn_backends()
+ _activate_cp_planners()
except ImportError as e:
raise ImportError(e)
+
logger = init_logger(__name__)
def maybe_enable_context_parallelism(
- transformer: torch.nn.Module,
+ transformer: torch.nn.Module | ModelMixin,
parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
- assert isinstance(transformer, ModelMixin), (
- "transformer must be an instance of diffusers' ModelMixin, " f"but got {type(transformer)}"
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
)
if parallelism_config is None:
return transformer
@@ -38,7 +44,7 @@ def maybe_enable_context_parallelism(
if (
parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
- and native_diffusers_parallelism_available()
+ and _is_diffusers_parallelism_available()
):
cp_config = None
if parallelism_config.ulysses_size is not None or parallelism_config.ring_size is not None:
@@ -50,28 +56,19 @@ def maybe_enable_context_parallelism(
experimental_ulysses_anything = parallelism_config.parallel_kwargs.get(
"experimental_ulysses_anything", False
)
+ # Float8 all_to_all for Ulysses Attention/Ulysses Anything Attention
+ experimental_ulysses_float8 = parallelism_config.parallel_kwargs.get(
+ "experimental_ulysses_float8", False
+ )
+
+ # Must call enable_ulysses_anything before enable_ulysses_float8.
if experimental_ulysses_anything:
enable_ulysses_anything()
- attention_backend = parallelism_config.parallel_kwargs.get("attention_backend", None)
+ if experimental_ulysses_float8:
+ enable_ulysses_float8()
+
if hasattr(transformer, "enable_parallelism"):
- if hasattr(transformer, "set_attention_backend"):
- # native, _native_cudnn, flash, etc.
- if attention_backend is None:
- # Now only _native_cudnn is supported for parallelism
- # issue: https://github.com/huggingface/diffusers/pull/12443
- transformer.set_attention_backend("_native_cudnn")
- logger.warning(
- "attention_backend is None, set default attention backend "
- "to _native_cudnn for parallelism because of the issue: "
- "https://github.com/huggingface/diffusers/pull/12443"
- )
- else:
- transformer.set_attention_backend(attention_backend)
- logger.info(
- "Found attention_backend from config, set attention "
- f"backend to: {attention_backend}"
- )
# Prefer custom cp_plan if provided
cp_plan = parallelism_config.parallel_kwargs.get("cp_plan", None)
if cp_plan is not None:
@@ -86,7 +83,7 @@ def maybe_enable_context_parallelism(
)
transformer.enable_parallelism(config=cp_config, cp_plan=cp_plan)
- _maybe_patch_native_parallel_config(transformer)
+ _maybe_patch_native_parallel_config(transformer, **extra_parallel_kwargs)
else:
raise ValueError(
f"{transformer.__class__.__name__} does not support context parallelism."
@@ -97,14 +94,13 @@ def maybe_enable_context_parallelism(
def _maybe_patch_native_parallel_config(
transformer: torch.nn.Module,
+ **kwargs,
) -> torch.nn.Module:
cls_name = transformer.__class__.__name__
if not cls_name.startswith("Nunchaku"):
return transformer
- from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel
-
try:
from nunchaku.models.transformers.transformer_flux_v2 import (
NunchakuFluxTransformer2DModelV2,
@@ -116,35 +112,41 @@ def _maybe_patch_native_parallel_config(
NunchakuQwenImageNaiveFA2Processor,
NunchakuQwenImageTransformer2DModel,
)
+ from nunchaku.models.transformers.transformer_zimage import (
+ NunchakuZImageTransformer2DModel,
+ NunchakuZSingleStreamAttnProcessor,
+ NunchakuZImageAttention,
+ )
except ImportError:
raise ImportError(
- "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
- "requires the 'nunchaku' package. Please install nunchaku before using "
- "the context parallelism for nunchaku 4-bits models."
+ "NunchakuZImageTransformer2DModel, NunchakuFluxTransformer2DModelV2 and "
+ "NunchakuQwenImageTransformer2DModel requires the 'nunchaku' package. "
+ "Please install nunchaku>=1.10 before using the context parallelism for "
+ "nunchaku 4-bits models."
)
+
assert isinstance(
transformer,
(
NunchakuFluxTransformer2DModelV2,
- FluxTransformer2DModel,
- ),
- ) or isinstance(
- transformer,
- (
NunchakuQwenImageTransformer2DModel,
- QwenImageTransformer2DModel,
+ NunchakuZImageTransformer2DModel,
),
- ), (
- "transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
- f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
)
- config = transformer._parallel_config
+ config = getattr(transformer, "_parallel_config", None)
+ if config is None:
+ # NOTE: logger.warning() returns None, so 'raise logger.warning(...)' was a TypeError.
+ logger.warning(f"The transformer {cls_name} does not have _parallel_config "
+ "attribute. Skipping patching native parallel config.")
+ return transformer
attention_classes = (
NunchakuFluxAttention,
NunchakuFluxFA2Processor,
NunchakuQwenAttention,
NunchakuQwenImageNaiveFA2Processor,
+ NunchakuZImageAttention,
+ NunchakuZSingleStreamAttnProcessor,
)
for module in transformer.modules():
if not isinstance(module, attention_classes):
@@ -152,6 +154,12 @@ def _maybe_patch_native_parallel_config(
processor = getattr(module, "processor", None)
if processor is None or not hasattr(processor, "_parallel_config"):
continue
+ if getattr(processor, "_parallel_config", None) is not None:
+ logger.warning(
+ f"The attention processor {processor.__class__.__name__} already has "
+ "_parallel_config attribute set. Skipping patching native parallel config."
+ )
+ continue
processor._parallel_config = config
return transformer
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_chroma.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_chroma.py
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_chrono_edit.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_chrono_edit.py
new file mode 100644
index 000000000..94dfe5be7
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_chrono_edit.py
@@ -0,0 +1,200 @@
+import torch
+import functools
+from typing import Optional, Tuple
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+ from diffusers.models.transformers.transformer_chronoedit import (
+ _get_added_kv_projections,
+ _get_qkv_projections,
+ dispatch_attention_fn,
+ )
+ from diffusers.models.transformers.transformer_chronoedit import (
+ WanAttention as ChronoEditWanAttention,
+ )
+ from diffusers.models.transformers.transformer_chronoedit import (
+ WanAttnProcessor as ChronoEditWanAttnProcessor,
+ )
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("ChronoEditTransformer3D")
+class ChronoEditContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ self._cp_planner_preferred_native_diffusers = False
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+ # Otherwise, use the custom CP plan defined here, this maybe
+ # a little different from the native diffusers implementation
+ # for some models.
+ ChronoEditWanAttnProcessor.__call__ = __patch_ChronoEditWanAttnProcessor__call__
+ _cp_plan = {
+ # Pattern of rope, split_output=True (split output rather than input):
+ # un-split input
+ # -> keep input un-split
+ # -> rope
+ # -> splited output
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ # Pattern of blocks.0, split_output=False:
+ # un-split input -> split -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+ # -> splited output
+ # (only split hidden_states, not encoder_hidden_states)
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ # Pattern of the all blocks, split_output=False:
+ # un-split input -> split -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+ # -> splited output
+ # (only split encoder_hidden_states, not hidden_states.
+ # hidden_states has been automatically split in previous
+ # block by all2all comm op after attn)
+ # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+ # so we need to split it at [ALL] block by the inserted split hook.
+ # NOTE(DefTruth): We need to disable the splitting of encoder_hidden_states because
+ # the image_encoder consistently generates 257 tokens for image_embed. This causes
+ # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
+ # after concatenation—to be indivisible by the number of devices in the CP.
+ # "blocks.*": {
+ # "encoder_hidden_states": ContextParallelInput(
+ # split_dim=1, expected_dims=3, split_output=False
+ # ),
+ # },
+ # Then, the final proj_out will gather the splited output.
+ # splited input (previous splited output)
+ # -> all gather
+ # -> un-split output
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
+
+
+@functools.wraps(ChronoEditWanAttnProcessor.__call__)
+def __patch_ChronoEditWanAttnProcessor__call__(
+ self: ChronoEditWanAttnProcessor,
+ attn: ChronoEditWanAttention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+) -> torch.Tensor:
+ encoder_hidden_states_img = None
+ if attn.add_k_proj is not None:
+ # 512 is the context length of the text encoder, hardcoded for now
+ image_context_length = encoder_hidden_states.shape[1] - 512
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+ query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ if rotary_emb is not None:
+
+ def apply_rotary_emb(
+ hidden_states: torch.Tensor,
+ freqs_cos: torch.Tensor,
+ freqs_sin: torch.Tensor,
+ ):
+ x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+ cos = freqs_cos[..., 0::2]
+ sin = freqs_sin[..., 1::2]
+ out = torch.empty_like(hidden_states)
+ out[..., 0::2] = x1 * cos - x2 * sin
+ out[..., 1::2] = x1 * sin + x2 * cos
+ return out.type_as(hidden_states)
+
+ query = apply_rotary_emb(query, *rotary_emb)
+ key = apply_rotary_emb(key, *rotary_emb)
+
+ # I2V task
+ hidden_states_img = None
+ if encoder_hidden_states_img is not None:
+ key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+ key_img = attn.norm_added_k(key_img)
+
+ key_img = key_img.unflatten(2, (attn.heads, -1))
+ value_img = value_img.unflatten(2, (attn.heads, -1))
+
+ hidden_states_img = dispatch_attention_fn(
+ query,
+ key_img,
+ value_img,
+ attn_mask=None,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ # FIXME(DefTruth): Since the key/value in cross-attention depends
+ # solely on encoder_hidden_states_img (img), the (q_chunk * k) * v
+ # computation can be parallelized independently. Thus, there is
+ # no need to pass the config here.
+ parallel_config=None,
+ )
+ hidden_states_img = hidden_states_img.flatten(2, 3)
+ hidden_states_img = hidden_states_img.type_as(query)
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ # FIXME(DefTruth): Since the key/value in cross-attention depends
+ # solely on encoder_hidden_states (text), the (q_chunk * k) * v
+ # computation can be parallelized independently. Thus, there is
+ # no need to pass the config here.
+ parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.type_as(query)
+
+ if hidden_states_img is not None:
+ hidden_states = hidden_states + hidden_states_img
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cogvideox.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cogvideox.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cogview.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cogview.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cosisid.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_cosisid.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_dit.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_dit.py
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux.py
new file mode 100644
index 000000000..1e1030497
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux.py
@@ -0,0 +1,294 @@
+import torch
+import functools
+from typing import Optional, Tuple, Dict, Any
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers import FluxTransformer2DModel
+from diffusers.models.transformers.transformer_flux import (
+ FluxSingleTransformerBlock,
+ FluxAttnProcessor,
+ FluxAttention,
+ apply_rotary_emb,
+ dispatch_attention_fn,
+)
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("FluxTransformer2DModel")
+class FluxContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ FluxAttnProcessor.__call__ = __patch_FluxAttnProcessor_ulysses_async__call__
+ FluxSingleTransformerBlock.forward = (
+ __patch_FluxSingleTransformerBlock_ulysses_async_forward__
+ )
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for FluxTransformer2DModel."
+ )
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ transformer, FluxTransformer2DModel
+ ), "Transformer must be an instance of FluxTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+        # Otherwise, use the custom CP plan defined here, this may be
+ # a little different from the native diffusers implementation
+ # for some models.
+ _cp_plan = {
+ # Here is a Transformer level CP plan for Flux, which will
+ # only apply the only 1 split hook (pre_forward) on the forward
+ # of Transformer, and gather the output after Transformer forward.
+ # Pattern of transformer forward, split_output=False:
+            # un-split input -> split input (inside transformer)
+ # Pattern of the transformer_blocks, single_transformer_blocks:
+            # split input (previous split output) -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+            # -> split output
+            # The `hidden_states` and `encoder_hidden_states` will remain
+            # split after each block forward (namely, automatically split by
+            # the all2all comm op after attn) for all blocks.
+            # img_ids and txt_ids will only be split once at the very beginning,
+            # and remain split through the whole transformer forward. The all2all
+ # comm op only happens on the `out` tensor after local attn not on
+ # img_ids and txt_ids.
+ "": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+            # Then, the final proj_out will gather the split output.
+            # split input (previous split output)
+ # -> all gather
+ # -> un-split output
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
+
+
+# Async Ulysses QKV Proj for FLUX model
+# Reference:
+# - https://github.com/ByteDance-Seed/VeOmni/blob/main/veomni/distributed/sequence_parallel/async_ulysses.py#L43
+# - https://github.com/huggingface/diffusers/pull/12727 by @zhangtao0408
+def _ulysses_attn_with_async_qkv_proj_flux(
+ self: FluxAttnProcessor,
+ attn: FluxAttention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ value = attn.to_v(hidden_states) # type: torch.Tensor
+ value = value.unflatten(-1, (attn.heads, -1))
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_value = attn.add_v_proj(encoder_hidden_states) # type: torch.Tensor
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+ value = torch.cat([encoder_value, value], dim=1)
+
+ metadata = _prepare_ulysses_comm_metadata(value)
+
+ # Async all to all for value
+ value_wait = _all_to_all_qv_async_func(value, group, **metadata)
+
+ query = attn.to_q(hidden_states)
+ query = query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ query = attn.norm_q(query)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_query = attn.norm_added_q(encoder_query)
+ query = torch.cat([encoder_query, query], dim=1)
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for query
+ query_wait = _all_to_all_qv_async_func(query, group, **metadata)
+
+ key = attn.to_k(hidden_states) # type: torch.Tensor
+ key = key.unflatten(-1, (attn.heads, -1))
+ key = attn.norm_k(key)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_key = attn.norm_added_k(encoder_key)
+ key = torch.cat([encoder_key, key], dim=1)
+ if image_rotary_emb is not None:
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for key
+ key_wait = _all_to_all_k_async_func(key, group, **metadata)
+
+ # Ensure the query, key, value are ready
+ value = value_wait()
+ query = query_wait()
+ key = key_wait()
+
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ if encoder_hidden_states is not None:
+ # Must be sync all to all for out when encoder_hidden_states is used
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ out = out_wait() # type: torch.Tensor
+
+ hidden_states = out.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [
+ encoder_hidden_states.shape[1],
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
+ ],
+ dim=1,
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ # Can be async all to all for out when no encoder_hidden_states
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ return out_wait
+
+
+FluxAttnProcessor_original__call__ = FluxAttnProcessor.__call__
+
+
+@functools.wraps(FluxAttnProcessor_original__call__)
+def __patch_FluxAttnProcessor_ulysses_async__call__(
+ self: FluxAttnProcessor,
+ attn: "FluxAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_flux(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Otherwise, use the original call for non-ulysses case
+ return FluxAttnProcessor_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+
+@functools.wraps(FluxSingleTransformerBlock.forward)
+def __patch_FluxSingleTransformerBlock_ulysses_async_forward__(
+ self: FluxSingleTransformerBlock,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ # Perform attention with Ulysses async QKV proj, the attn_output
+    # may be an instance of AsyncCollectiveTensor.
+ attn_output_wait = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ # NOTE: Enable the out all2all overlap with mlp computation
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ # NOTE: Then ensure the attn_output is ready
+ if not isinstance(attn_output_wait, torch.Tensor):
+ attn_output = attn_output_wait() # type: torch.Tensor
+ else:
+ attn_output = attn_output_wait
+ attn_output = attn_output.contiguous()
+ if attn_output.ndim == 4:
+ attn_output = attn_output.flatten(2, 3)
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, :text_seq_len],
+ hidden_states[:, text_seq_len:],
+ )
+ return encoder_hidden_states, hidden_states
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux2.py
similarity index 81%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux2.py
index dda3ca595..21d86a0c3 100644
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_flux2.py
@@ -1,7 +1,7 @@
import torch
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
-from diffusers import FluxTransformer2DModel
+from diffusers import Flux2Transformer2DModel
try:
from diffusers.models._modeling_parallel import (
@@ -25,17 +25,21 @@
logger = init_logger(__name__)
-@ContextParallelismPlannerRegister.register("Flux")
-class FluxContextParallelismPlanner(ContextParallelismPlanner):
+@ContextParallelismPlannerRegister.register("Flux2Transformer2DModel")
+class Flux2ContextParallelismPlanner(ContextParallelismPlanner):
def apply(
self,
transformer: Optional[torch.nn.Module | ModelMixin] = None,
**kwargs,
) -> ContextParallelModelPlan:
+
+ # NOTE: Diffusers native CP plan still have bugs for Flux2 now.
+ self._cp_planner_preferred_native_diffusers = False
+
if transformer is not None and self._cp_planner_preferred_native_diffusers:
assert isinstance(
- transformer, FluxTransformer2DModel
- ), "Transformer must be an instance of FluxTransformer2DModel"
+ transformer, Flux2Transformer2DModel
+ ), "Transformer must be an instance of Flux2Transformer2DModel"
if hasattr(transformer, "_cp_plan"):
if transformer._cp_plan is not None:
return transformer._cp_plan
@@ -69,8 +73,8 @@ def apply(
"encoder_hidden_states": ContextParallelInput(
split_dim=1, expected_dims=3, split_output=False
),
- "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
- "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "img_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
# Then, the final proj_out will gather the splited output.
# splited input (previous splited output)
@@ -79,3 +83,6 @@ def apply(
"proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
return _cp_plan
+
+
+# TODO: Add async Ulysses QKV proj for FLUX2 model
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_hunyuan.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_hunyuan.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_kandinsky.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_kandinsky.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_kandinsky.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_kandinsky.py
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_longcat_image.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_longcat_image.py
new file mode 100644
index 000000000..570fb9f23
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_longcat_image.py
@@ -0,0 +1,280 @@
+import torch
+import functools
+from typing import Optional, Tuple, Dict, Any
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+ from diffusers import LongCatImageTransformer2DModel
+ from diffusers.models.transformers.transformer_longcat_image import (
+ LongCatImageSingleTransformerBlock,
+ LongCatImageAttnProcessor,
+ LongCatImageAttention,
+ apply_rotary_emb,
+ dispatch_attention_fn,
+ ) # requires diffusers>=0.37.0.dev0
+
+ _longcat_image_is_available = True
+except ImportError:
+ _longcat_image_is_available = False
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("LongCatImageTransformer2DModel")
+class LongCatImageContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ if not _longcat_image_is_available:
+ logger.warning(
+ "Diffusers LongCatImageTransformer2DModel or related classes are not found. "
+ "Please install diffusers>=0.37.0.dev0 from source. Skipping CP plan for LongCatImage."
+ )
+ return transformer
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ LongCatImageAttnProcessor.__call__ = (
+ __patch_LongCatImageAttnProcessor_ulysses_async__call__
+ )
+ LongCatImageSingleTransformerBlock.forward = (
+ __patch_LongCatImageSingleTransformerBlock_ulysses_async_forward__
+ )
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for LongCatImageTransformer2DModel."
+ )
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ transformer, LongCatImageTransformer2DModel
+ ), "Transformer must be an instance of LongCatImageTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
+
+
+# Async Ulysses QKV Proj for LongCatImage
+if _longcat_image_is_available:
+
+ def _ulysses_attn_with_async_qkv_proj_longcat_image(
+ self: LongCatImageAttnProcessor,
+ attn: LongCatImageAttention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ value = attn.to_v(hidden_states) # type: torch.Tensor
+ value = value.unflatten(-1, (attn.heads, -1))
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_value = attn.add_v_proj(encoder_hidden_states) # type: torch.Tensor
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+ value = torch.cat([encoder_value, value], dim=1)
+
+ metadata = _prepare_ulysses_comm_metadata(value)
+
+ # Async all to all for value
+ value_wait = _all_to_all_qv_async_func(value, group, **metadata)
+
+ query = attn.to_q(hidden_states)
+ query = query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ query = attn.norm_q(query)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_query = attn.norm_added_q(encoder_query)
+ query = torch.cat([encoder_query, query], dim=1)
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for query
+ query_wait = _all_to_all_qv_async_func(query, group, **metadata)
+
+ key = attn.to_k(hidden_states) # type: torch.Tensor
+ key = key.unflatten(-1, (attn.heads, -1))
+ key = attn.norm_k(key)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_key = attn.norm_added_k(encoder_key)
+ key = torch.cat([encoder_key, key], dim=1)
+ if image_rotary_emb is not None:
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for key
+ key_wait = _all_to_all_k_async_func(key, group, **metadata)
+
+ # Ensure the query, key, value are ready
+ value = value_wait()
+ query = query_wait()
+ key = key_wait()
+
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ if encoder_hidden_states is not None:
+ # Must be sync all to all for out when encoder_hidden_states is used
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ out = out_wait() # type: torch.Tensor
+
+ hidden_states = out.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [
+ encoder_hidden_states.shape[1],
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
+ ],
+ dim=1,
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ # Can be async all to all for out when no encoder_hidden_states
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ return out_wait
+
+ LongCatImageAttnProcessor_original__call__ = LongCatImageAttnProcessor.__call__
+
+ @functools.wraps(LongCatImageAttnProcessor_original__call__)
+ def __patch_LongCatImageAttnProcessor_ulysses_async__call__(
+ self: LongCatImageAttnProcessor,
+ attn: "LongCatImageAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_longcat_image(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Otherwise, use the original call for non-ulysses case
+ return LongCatImageAttnProcessor_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ @functools.wraps(LongCatImageSingleTransformerBlock.forward)
+ def __patch_LongCatImageSingleTransformerBlock_ulysses_async_forward__(
+ self: LongCatImageSingleTransformerBlock,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ # Perform attention with Ulysses async QKV proj, the attn_output
+        # may be an instance of AsyncCollectiveTensor.
+ attn_output_wait = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ # NOTE: Enable the out all2all overlap with mlp computation
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+
+ # NOTE: Then ensure the attn_output is ready
+ if not isinstance(attn_output_wait, torch.Tensor):
+ attn_output = attn_output_wait() # type: torch.Tensor
+ else:
+ attn_output = attn_output_wait
+ attn_output = attn_output.contiguous()
+ if attn_output.ndim == 4:
+ attn_output = attn_output.flatten(2, 3)
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, :text_seq_len],
+ hidden_states[:, text_seq_len:],
+ )
+ return encoder_hidden_states, hidden_states
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltx2.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltx2.py
new file mode 100644
index 000000000..fab42af1b
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltx2.py
@@ -0,0 +1,279 @@
+# Mostly copy from https://github.com/vipshop/cache-dit/blob/main/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltxvideo.py
+import functools
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.transformer_ltx2 import (
+ LTX2Attention,
+ LTX2AudioVideoAttnProcessor,
+ LTX2VideoTransformer3DModel,
+)
+from diffusers.models.attention_dispatch import dispatch_attention_fn
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelModelPlan,
+ ContextParallelOutput,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("LTX2")
+class LTX2ContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+ assert transformer is not None, "Transformer must be provided."
+ assert isinstance(
+ transformer, LTX2VideoTransformer3DModel
+ ), "Transformer must be an instance of LTX2VideoTransformer3DModel"
+
+ # NOTE:
+ # - LTX2ImageToVideoPipeline passes `timestep` as a 2D tensor (B, seq_len) named `video_timestep`.
+ # - diffusers native LTX2 `_cp_plan` does NOT shard `timestep`, causing shape mismatch under CP:
+ # hidden_states: (B, seq_len/world, C) but temb built from timestep.flatten(): (B, seq_len, ...)
+ # leading to: RuntimeError size mismatch (1536 vs 6144).
+ # So we must use a custom plan for correctness under Ulysses/Ring CP.
+ self._cp_planner_preferred_native_diffusers = False
+
+ # Patch attention_mask preparation for CP head sharding + global seq padding
+ LTX2Attention.prepare_attention_mask = __patch__LTX2Attention_prepare_attention_mask__ # type: ignore[assignment]
+ LTX2AudioVideoAttnProcessor.__call__ = __patch__LTX2AudioVideoAttnProcessor__call__ # type: ignore[assignment]
+
+ rope_type = getattr(getattr(transformer, "config", None), "rope_type", "interleaved")
+ if rope_type == "split":
+ # split RoPE returns (B, H, T, D/2), shard along T dim
+ rope_expected_dims = 4
+ rope_split_dim = 2
+ else:
+ # interleaved RoPE returns (B, T, D), shard along T dim
+ rope_expected_dims = 3
+ rope_split_dim = 1
+
+ _cp_plan: ContextParallelModelPlan = {
+ "": {
+ # Shard video/audio latents across sequence
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "audio_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ # Shard prompt embeds across sequence
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "audio_encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ # IMPORTANT: shard video timestep (B, seq_len) to match sharded hidden_states
+ "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ # NOTE: do NOT shard attention masks; handled in patched attention processor
+ },
+ # Split RoPE outputs to match CP-sharded sequence length
+ "rope": {
+ 0: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ 1: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ },
+ "audio_rope": {
+ 0: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ 1: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ },
+ "cross_attn_rope": {
+ 0: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ 1: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ },
+ "cross_attn_audio_rope": {
+ 0: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ 1: ContextParallelInput(
+ split_dim=rope_split_dim, expected_dims=rope_expected_dims, split_output=True
+ ),
+ },
+ # Gather outputs before returning
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ "audio_proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ return _cp_plan
+
+
+# Upstream links (for cross-checking when updating diffusers):
+# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_ltx2.py
+# - https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+
+
+@functools.wraps(LTX2Attention.prepare_attention_mask)
+def __patch__LTX2Attention_prepare_attention_mask__(
+ self: LTX2Attention,
+ attention_mask: torch.Tensor,
+ target_length: int,
+ batch_size: int,
+ out_dim: int = 3,
+ # NOTE: Allow specifying head_size for CP
+ head_size: Optional[int] = None,
+) -> torch.Tensor:
+ # Differences vs diffusers:
+ # - diffusers signature does not accept `head_size` and always uses `self.heads`.
+ # - under Context Parallelism, each rank only owns `attn.heads // world_size` heads.
+ # If we keep repeating the mask with the full `self.heads`, the mask shape will not
+ # match the sharded attention computation.
+ if head_size is None:
+ head_size = self.heads
+ if attention_mask is None:
+ return attention_mask
+
+ current_length: int = attention_mask.shape[-1]
+ if current_length != target_length:
+ if attention_mask.device.type == "mps":
+ padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
+ padding = torch.zeros(
+ padding_shape, dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
+ else:
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+
+ if out_dim == 3:
+ if attention_mask.shape[0] < batch_size * head_size:
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+ elif out_dim == 4:
+ attention_mask = attention_mask.unsqueeze(1)
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+
+ return attention_mask
+
+
+@functools.wraps(LTX2AudioVideoAttnProcessor.__call__)
+def __patch__LTX2AudioVideoAttnProcessor__call__(
+ self: LTX2AudioVideoAttnProcessor,
+ attn: "LTX2Attention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+ key_rotary_emb: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+) -> torch.Tensor:
+ # Differences vs diffusers (transformer_ltx2.py):
+ # - diffusers always prepares attention_mask using the *local* `sequence_length` and
+ # reshapes it with `attn.heads`.
+ # - when Context Parallelism is enabled, `hidden_states` is sharded on seq dim, so
+ # `sequence_length` here is per-rank. However attention_mask typically corresponds
+ # to the *global* sequence length (before sharding), and each rank only uses a shard
+ # of heads (`attn.heads // world_size`).
+ # - this patch therefore:
+ # 1) uses `target_length = sequence_length * world_size` when CP is active
+ # 2) repeats/reshapes the mask using `head_size = attn.heads // world_size`
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ if attention_mask is not None:
+ if self._parallel_config is None:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+ else:
+ cp_config = getattr(self._parallel_config, "context_parallel_config", None)
+ if cp_config is not None and cp_config._world_size > 1:
+ head_size = attn.heads // cp_config._world_size
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask,
+ sequence_length * cp_config._world_size,
+ batch_size,
+ 3,
+ head_size,
+ )
+ attention_mask = attention_mask.view(
+ batch_size, head_size, -1, attention_mask.shape[-1]
+ )
+ else:
+ attention_mask = attn.prepare_attention_mask(
+ attention_mask, sequence_length, batch_size
+ )
+ attention_mask = attention_mask.view(
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
+ )
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query = attn.norm_q(query)
+ key = attn.norm_k(key)
+
+ if query_rotary_emb is not None:
+ # Keep RoPE logic identical to upstream: for v2a/a2v cross-attn, K can use separate RoPE.
+ if attn.rope_type == "interleaved":
+ from diffusers.models.transformers.transformer_ltx2 import apply_interleaved_rotary_emb
+
+ query = apply_interleaved_rotary_emb(query, query_rotary_emb)
+ key = apply_interleaved_rotary_emb(
+ key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
+ )
+ elif attn.rope_type == "split":
+ from diffusers.models.transformers.transformer_ltx2 import apply_split_rotary_emb
+
+ query = apply_split_rotary_emb(query, query_rotary_emb)
+ key = apply_split_rotary_emb(
+ key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
+ )
+
+ query = query.unflatten(2, (attn.heads, -1))
+ key = key.unflatten(2, (attn.heads, -1))
+ value = value.unflatten(2, (attn.heads, -1))
+
+ hidden_states = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=self._parallel_config,
+ )
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltxvideo.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ltxvideo.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_nunchaku.py
similarity index 67%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_nunchaku.py
index 961de08ba..4bc8f7581 100644
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_nunchaku.py
@@ -19,11 +19,17 @@
NunchakuQwenImageNaiveFA2Processor,
NunchakuQwenImageTransformer2DModel,
)
+ from nunchaku.models.transformers.transformer_zimage import (
+ NunchakuZImageTransformer2DModel,
+ NunchakuZSingleStreamAttnProcessor,
+ NunchakuZImageAttention,
+ )
except ImportError:
raise ImportError(
- "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
- "requires the 'nunchaku' package. Please install nunchaku before using "
- "the context parallelism for nunchaku 4-bits models."
+ "NunchakuZImageTransformer2DModel, NunchakuFluxTransformer2DModelV2 and "
+ "NunchakuQwenImageTransformer2DModel requires the 'nunchaku' package. "
+ "Please install nunchaku>=1.10 before using the context parallelism for "
+ "nunchaku 4-bits models."
)
try:
@@ -43,6 +49,7 @@
ContextParallelismPlannerRegister,
)
+from cache_dit.parallelism.attention import _maybe_patch_find_submodule
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -298,8 +305,8 @@ def apply(
@functools.wraps(NunchakuQwenImageNaiveFA2Processor.__call__)
def __patch_NunchakuQwenImageNaiveFA2Processor__call__(
- self,
- attn,
+ self: NunchakuQwenImageNaiveFA2Processor,
+ attn: NunchakuQwenAttention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
encoder_hidden_states_mask: torch.FloatTensor = None,
@@ -383,3 +390,139 @@ def __patch_NunchakuQwenImageNaiveFA2Processor__call__(
txt_attn_output = attn.to_add_out(txt_attn_output)
return img_attn_output, txt_attn_output
+
+
+@ContextParallelismPlannerRegister.register("NunchakuZImageTransformer2DModel")
+class NunchakuZImageContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ # NOTE: Diffusers native CP plan still not supported for ZImageTransformer2DModel
+ self._cp_planner_preferred_native_diffusers = False
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ transformer, NunchakuZImageTransformer2DModel
+ ), "Transformer must be an instance of NunchakuZImageTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+        # NOTE: This is only a temporary workaround for ZImage to make context parallelism
+ # work compatible with DBCache FnB0. The better way is to make DBCache fully
+ # compatible with diffusers native context parallelism, e.g., check the split/gather
+ # hooks in each block/layer in the initialization of DBCache.
+ # Issue: https://github.com/vipshop/cache-dit/issues/498
+ _maybe_patch_find_submodule()
+ if not hasattr(NunchakuZSingleStreamAttnProcessor, "_parallel_config"):
+ NunchakuZSingleStreamAttnProcessor._parallel_config = None
+ if not hasattr(NunchakuZSingleStreamAttnProcessor, "_attention_backend"):
+ NunchakuZSingleStreamAttnProcessor._attention_backend = None
+ if not hasattr(NunchakuZImageAttention, "_parallel_config"):
+ NunchakuZImageAttention._parallel_config = None
+ if not hasattr(NunchakuZImageAttention, "_attention_backend"):
+ NunchakuZImageAttention._attention_backend = None
+
+ n_noise_refiner_layers = len(transformer.noise_refiner) # 2
+ n_context_refiner_layers = len(transformer.context_refiner) # 2
+ n_layers = len(transformer.layers) # 30
+ # controlnet layer idx: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
+ # num_controlnet_samples = len(transformer.layers) // 2 # 15
+ has_controlnet = kwargs.get("has_controlnet", None)
+ if not has_controlnet:
+ # cp plan for ZImageTransformer2DModel if no controlnet
+ _cp_plan = {
+ # 0. Hooks for noise_refiner layers, 2
+ "noise_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "noise_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"noise_refiner.{n_noise_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 1. Hooks for context_refiner layers, 2
+ "context_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "context_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 2. Hooks for main transformer layers, num_layers=30
+ "layers.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "layers.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
+ "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ # NOTE: The 'all_final_layer' is a ModuleDict of several final layers,
+ # each for a specific patch size combination, so we do not add hooks for it here.
+ # So, we have to gather the output of the last transformer layer.
+ # f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ else:
+ # Special cp plan for NunchakuZImageTransformer2DModel with ZImageControlNetModel
+ logger.warning(
+ "Using special context parallelism plan for NunchakuZImageTransformer2DModel "
+ "due to the 'has_controlnet' flag is set to True."
+ )
+ _cp_plan = {
+                # zimage controlnet shares the same refiner as zimage, so we need to
+ # add gather hooks for all layers in noise_refiner and context_refiner.
+ # 0. Hooks for noise_refiner layers, 2
+                # Insert gather hook after each layer due to the ops: (controlnet)
+ # - x = x + noise_refiner_block_samples[layer_idx]
+ "noise_refiner.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ **{
+ f"noise_refiner.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
+ for i in range(n_noise_refiner_layers)
+ },
+ # 1. Hooks for context_refiner layers, 2
+ "context_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "context_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 2. Hooks for main transformer layers, num_layers=30
+                # Insert gather hook after each layer due to the ops: (main transformer)
+ # - unified + controlnet_block_samples[layer_idx]
+ "layers.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ **{
+ f"layers.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
+ for i in range(n_layers)
+ },
+ # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
+ "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ovis_image.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ovis_image.py
new file mode 100644
index 000000000..16f14ca63
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_ovis_image.py
@@ -0,0 +1,302 @@
+import torch
+import functools
+from typing import Optional, Tuple, Dict, Any
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+
+try:
+ from diffusers import OvisImageTransformer2DModel
+ from diffusers.models.transformers.transformer_ovis_image import (
+ OvisImageSingleTransformerBlock,
+ OvisImageAttnProcessor,
+ OvisImageAttention,
+ apply_rotary_emb,
+ dispatch_attention_fn,
+ )
+except ImportError:
+ raise ImportError(
+ "OvisImageTransformer2DModel requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.logger import init_logger
+
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("OvisImageTransformer2DModel")
+class OvisImageContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ OvisImageAttnProcessor.__call__ = __patch_OvisImageAttnProcessor_ulysses_async__call__
+ OvisImageSingleTransformerBlock.forward = (
+ __patch_OvisImageSingleTransformerBlock_ulysses_async_forward__
+ )
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for OvisImageTransformer2DModel."
+ )
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ transformer, OvisImageTransformer2DModel
+ ), "Transformer must be an instance of OvisImageTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+        # Otherwise, use the custom CP plan defined here, this may be
+ # a little different from the native diffusers implementation
+ # for some models.
+ _cp_plan = {
+ # Here is a Transformer level CP plan for OvisImage, which will
+ # only apply the only 1 split hook (pre_forward) on the forward
+ # of Transformer, and gather the output after Transformer forward.
+ # Pattern of transformer forward, split_output=False:
+ # un-split input -> splited input (inside transformer)
+ # Pattern of the transformer_blocks, single_transformer_blocks:
+ # splited input (previous splited output) -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+ # -> splited output
+ # The `hidden_states` and `encoder_hidden_states` will still keep
+            # themselves splited after block forward (namely, automatically split by
+ # the all2all comm op after attn) for the all blocks.
+ # img_ids and txt_ids will only be splited once at the very beginning,
+ # and keep splited through the whole transformer forward. The all2all
+ # comm op only happens on the `out` tensor after local attn not on
+ # img_ids and txt_ids.
+ "": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "img_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ "txt_ids": ContextParallelInput(split_dim=0, expected_dims=2, split_output=False),
+ },
+ # Then, the final proj_out will gather the splited output.
+ # splited input (previous splited output)
+ # -> all gather
+ # -> un-split output
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
+
+
+# Async Ulysses QKV Proj for OvisImage model
+def _ulysses_attn_with_async_qkv_proj_ovis_image(
+ self: OvisImageAttnProcessor,
+ attn: OvisImageAttention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ value = attn.to_v(hidden_states) # type: torch.Tensor
+ value = value.unflatten(-1, (attn.heads, -1))
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_value = attn.add_v_proj(encoder_hidden_states) # type: torch.Tensor
+ encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
+ value = torch.cat([encoder_value, value], dim=1)
+
+ metadata = _prepare_ulysses_comm_metadata(value)
+
+ # Async all to all for value
+ value_wait = _all_to_all_qv_async_func(value, group, **metadata)
+
+ query = attn.to_q(hidden_states)
+ query = query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ query = attn.norm_q(query)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_query = attn.add_q_proj(encoder_hidden_states)
+ encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_query = attn.norm_added_q(encoder_query)
+ query = torch.cat([encoder_query, query], dim=1)
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for query
+ query_wait = _all_to_all_qv_async_func(query, group, **metadata)
+
+ key = attn.to_k(hidden_states) # type: torch.Tensor
+ key = key.unflatten(-1, (attn.heads, -1))
+ key = attn.norm_k(key)
+ if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
+ encoder_key = attn.add_k_proj(encoder_hidden_states)
+ encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) # type: torch.Tensor
+ encoder_key = attn.norm_added_k(encoder_key)
+ key = torch.cat([encoder_key, key], dim=1)
+ if image_rotary_emb is not None:
+ key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
+
+ # Async all to all for key
+ key_wait = _all_to_all_k_async_func(key, group, **metadata)
+
+ # Ensure the query, key, value are ready
+ value = value_wait()
+ query = query_wait()
+ key = key_wait()
+
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ if encoder_hidden_states is not None:
+ # Must be sync all to all for out when encoder_hidden_states is used
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ out = out_wait() # type: torch.Tensor
+
+ hidden_states = out.flatten(2, 3)
+ hidden_states = hidden_states.to(query.dtype)
+
+ encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
+ [
+ encoder_hidden_states.shape[1],
+ hidden_states.shape[1] - encoder_hidden_states.shape[1],
+ ],
+ dim=1,
+ )
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ return hidden_states, encoder_hidden_states
+ else:
+ # Can be async all to all for out when no encoder_hidden_states
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ return out_wait
+
+
+OvisImageAttnProcessor_original__call__ = OvisImageAttnProcessor.__call__
+
+
+@functools.wraps(OvisImageAttnProcessor_original__call__)
+def __patch_OvisImageAttnProcessor_ulysses_async__call__(
+ self: OvisImageAttnProcessor,
+ attn: "OvisImageAttention",
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_ovis_image(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+ # Otherwise, use the original call for non-ulysses case
+ return OvisImageAttnProcessor_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+
+
+@functools.wraps(OvisImageSingleTransformerBlock.forward)
+def __patch_OvisImageSingleTransformerBlock_ulysses_async_forward__(
+ self: OvisImageSingleTransformerBlock,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ temb: torch.Tensor,
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+ text_seq_len = encoder_hidden_states.shape[1]
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+ residual = hidden_states
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+
+ joint_attention_kwargs = joint_attention_kwargs or {}
+ # Perform attention with Ulysses async QKV proj, the attn_output
+    # may be an instance of AsyncCollectiveTensor.
+ attn_output_wait = self.attn(
+ hidden_states=norm_hidden_states,
+ image_rotary_emb=image_rotary_emb,
+ **joint_attention_kwargs,
+ )
+ # NOTE: Enable the out all2all overlap with mlp computation
+ mlp_hidden_states, mlp_hidden_gate = torch.split(
+ self.proj_mlp(norm_hidden_states), [self.mlp_hidden_dim, self.mlp_hidden_dim], dim=-1
+ )
+ mlp_hidden_states = self.act_mlp(mlp_hidden_gate) * mlp_hidden_states
+
+ # NOTE: Then ensure the attn_output is ready
+ if not isinstance(attn_output_wait, torch.Tensor):
+ attn_output = attn_output_wait() # type: torch.Tensor
+ else:
+ attn_output = attn_output_wait
+ attn_output = attn_output.contiguous()
+ if attn_output.ndim == 4:
+ attn_output = attn_output.flatten(2, 3)
+
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+ gate = gate.unsqueeze(1)
+ hidden_states = gate * self.proj_out(hidden_states)
+ hidden_states = residual + hidden_states
+ if hidden_states.dtype == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+
+ encoder_hidden_states, hidden_states = (
+ hidden_states[:, :text_seq_len],
+ hidden_states[:, text_seq_len:],
+ )
+ return encoder_hidden_states, hidden_states
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_pixart.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_pixart.py
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_qwen_image.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_qwen_image.py
new file mode 100644
index 000000000..cda176836
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_qwen_image.py
@@ -0,0 +1,279 @@
+import torch
+import functools
+import diffusers
+from typing import Optional
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers import QwenImageTransformer2DModel
+from diffusers.models.transformers.transformer_qwenimage import (
+ QwenDoubleStreamAttnProcessor2_0,
+ dispatch_attention_fn,
+ apply_rotary_emb_qwen,
+ Attention,
+)
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("QwenImage")
+class QwenImageContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ # NOTE: Set it as False to use custom CP plan defined here.
+ self._cp_planner_preferred_native_diffusers = False
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ QwenDoubleStreamAttnProcessor2_0.__call__ = (
+ __patch_QwenDoubleStreamAttnProcessor2_0_ulysses_async__call__
+ )
+
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for QwenImageTransformer2DModel."
+ )
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+
+ assert isinstance(
+ transformer, QwenImageTransformer2DModel
+ ), "Transformer must be an instance of QwenImageTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+ if diffusers.__version__ <= "0.36.0":
+ _cp_plan = {
+ "": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states_mask": ContextParallelInput(
+ split_dim=1, expected_dims=2, split_output=False
+ ),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ else:
+ # Make CP plan compatible with https://github.com/huggingface/diffusers/pull/12702
+ _cp_plan = {
+ "transformer_blocks.0": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ "encoder_hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ "pos_embed": {
+ 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+ },
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+
+ zero_cond_t = getattr(transformer, "zero_cond_t", False)
+ if zero_cond_t:
+ # modulate_index: [b, l=seq_len], Qwen-Image-Edit-2511
+ _cp_plan.update(
+ {
+ "transformer_blocks.*": {
+ "modulate_index": ContextParallelInput(
+ split_dim=1, expected_dims=2, split_output=False
+ ),
+ }
+ }
+ )
+
+ return _cp_plan
+
+
+# NOTE: Support Async Ulysses QKV projection for Qwen-Image
+def _ulysses_attn_with_async_qkv_proj_qwen_image(
+ self: QwenDoubleStreamAttnProcessor2_0,
+ attn: Attention,
+ hidden_states: torch.FloatTensor, # Image stream
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
+ encoder_hidden_states_mask: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.FloatTensor:
+ if encoder_hidden_states is None:
+ raise ValueError(
+ "QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)"
+ )
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ seq_txt = encoder_hidden_states.shape[1]
+
+ img_value = attn.to_v(hidden_states)
+ txt_value = attn.add_v_proj(encoder_hidden_states)
+ img_value = img_value.unflatten(-1, (attn.heads, -1))
+ txt_value = txt_value.unflatten(-1, (attn.heads, -1))
+ joint_value = torch.cat([txt_value, img_value], dim=1)
+
+ metadata = _prepare_ulysses_comm_metadata(joint_value)
+
+ # Async all to all for value
+ joint_value_wait = _all_to_all_qv_async_func(joint_value, group, **metadata)
+
+ # Compute QKV for image stream (sample projections)
+ img_query = attn.to_q(hidden_states)
+ # Compute QKV for text stream (context projections)
+ txt_query = attn.add_q_proj(encoder_hidden_states)
+ # Reshape for multi-head attention
+ img_query = img_query.unflatten(-1, (attn.heads, -1))
+ txt_query = txt_query.unflatten(-1, (attn.heads, -1))
+ # Apply QK normalization
+ if attn.norm_q is not None:
+ img_query = attn.norm_q(img_query)
+ if attn.norm_added_q is not None:
+ txt_query = attn.norm_added_q(txt_query)
+ # Apply RoPE
+ if image_rotary_emb is not None:
+ img_freqs, txt_freqs = image_rotary_emb
+ img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
+ txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
+ # Concatenate for joint attention
+ # Order: [text, image]
+ joint_query = torch.cat([txt_query, img_query], dim=1)
+
+ # Async all to all for query
+ joint_query_wait = _all_to_all_qv_async_func(joint_query, group, **metadata)
+
+ img_key = attn.to_k(hidden_states)
+ txt_key = attn.add_k_proj(encoder_hidden_states)
+ img_key = img_key.unflatten(-1, (attn.heads, -1))
+ txt_key = txt_key.unflatten(-1, (attn.heads, -1))
+ if attn.norm_k is not None:
+ img_key = attn.norm_k(img_key)
+ if attn.norm_added_k is not None:
+ txt_key = attn.norm_added_k(txt_key)
+ # Apply RoPE
+ if image_rotary_emb is not None:
+ img_freqs, txt_freqs = image_rotary_emb
+ img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
+ txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
+ joint_key = torch.cat([txt_key, img_key], dim=1)
+
+ # Async all to all for key
+ joint_key_wait = _all_to_all_k_async_func(joint_key, group, **metadata)
+
+ # (S_GLOBAL, B, H_LOCAL, D) -> (B, S_GLOBAL, H_LOCAL, D)
+ joint_value = joint_value_wait() # type: torch.Tensor
+ joint_query = joint_query_wait() # type: torch.Tensor
+ joint_key = joint_key_wait() # type: torch.Tensor
+
+ # Compute joint attention
+ out = dispatch_attention_fn(
+ joint_query,
+ joint_key,
+ joint_value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ # TODO: Split output before all to all to apply Async all to all,
+ # overlap comm and compute of _to_out and to_add_out projections.
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ joint_hidden_states = out_wait() # type: torch.Tensor
+
+ # Reshape back
+ joint_hidden_states = joint_hidden_states.flatten(2, 3)
+ joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
+
+ # Split attention outputs back
+ txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
+ img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
+
+ # Apply output projections
+ img_attn_output = attn.to_out[0](img_attn_output)
+ if len(attn.to_out) > 1:
+ img_attn_output = attn.to_out[1](img_attn_output) # dropout
+
+ txt_attn_output = attn.to_add_out(txt_attn_output)
+
+ return img_attn_output, txt_attn_output
+
+
+QwenDoubleStreamAttnProcessor2_0_original__call__ = QwenDoubleStreamAttnProcessor2_0.__call__
+
+
+@functools.wraps(QwenDoubleStreamAttnProcessor2_0_original__call__)
+def __patch_QwenDoubleStreamAttnProcessor2_0_ulysses_async__call__(
+ self: QwenDoubleStreamAttnProcessor2_0,
+ attn: Attention,
+ hidden_states: torch.FloatTensor, # Image stream
+ encoder_hidden_states: torch.FloatTensor = None, # Text stream
+ encoder_hidden_states_mask: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+) -> torch.FloatTensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_qwen_image(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ attention_mask,
+ image_rotary_emb,
+ )
+ else:
+ return QwenDoubleStreamAttnProcessor2_0_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ encoder_hidden_states_mask,
+ attention_mask,
+ image_rotary_emb,
+ )
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_registers.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_registers.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_skyreels.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_skyreels.py
similarity index 100%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_skyreels.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_skyreels.py
diff --git a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_wan.py
similarity index 60%
rename from src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py
rename to src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_wan.py
index 8a3853856..3cd0db0b3 100644
--- a/src/cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_wan.py
@@ -1,17 +1,18 @@
import torch
import functools
-from typing import Optional, Tuple
+from typing import Tuple, Optional
from diffusers.models.modeling_utils import ModelMixin
from diffusers import WanVACETransformer3DModel
+from diffusers.models.transformers.transformer_wan import (
+ WanAttention,
+ WanAttnProcessor,
+ _get_added_kv_projections,
+ _get_qkv_projections,
+ dispatch_attention_fn,
+)
+
try:
- from diffusers.models.transformers.transformer_chronoedit import (
- WanAttention,
- WanAttnProcessor,
- _get_added_kv_projections,
- _get_qkv_projections,
- dispatch_attention_fn,
- )
from diffusers.models._modeling_parallel import (
ContextParallelInput,
ContextParallelOutput,
@@ -33,7 +34,6 @@
logger = init_logger(__name__)
-@ContextParallelismPlannerRegister.register("ChronoEditTransformer3D")
@ContextParallelismPlannerRegister.register("WanTransformer3D")
class WanContextParallelismPlanner(ContextParallelismPlanner):
def apply(
@@ -42,10 +42,7 @@ def apply(
**kwargs,
) -> ContextParallelModelPlan:
- cls_name = transformer.__class__.__name__ if transformer else ""
-
- if cls_name.startswith("ChronoEditTransformer3D"):
- self._cp_planner_preferred_native_diffusers = False
+ self._cp_planner_preferred_native_diffusers = False
if transformer is not None and self._cp_planner_preferred_native_diffusers:
if hasattr(transformer, "_cp_plan"):
@@ -55,108 +52,67 @@ def apply(
# Otherwise, use the custom CP plan defined here, this maybe
# a little different from the native diffusers implementation
# for some models.
- if cls_name.startswith("ChronoEditTransformer3D"):
- WanAttnProcessor.__call__ = __patch_WanAttnProcessor__call__
- _cp_plan = {
- # Pattern of rope, split_output=True (split output rather than input):
- # un-split input
- # -> keep input un-split
- # -> rope
- # -> splited output
- "rope": {
- 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
- 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
- },
- # Pattern of blocks.0, split_output=False:
- # un-split input -> split -> to_qkv/...
- # -> all2all
- # -> attn (local head, full seqlen)
- # -> all2all
- # -> splited output
- # (only split hidden_states, not encoder_hidden_states)
- "blocks.0": {
- "hidden_states": ContextParallelInput(
- split_dim=1, expected_dims=3, split_output=False
- ),
- },
- # Pattern of the all blocks, split_output=False:
- # un-split input -> split -> to_qkv/...
- # -> all2all
- # -> attn (local head, full seqlen)
- # -> all2all
- # -> splited output
- # (only split encoder_hidden_states, not hidden_states.
- # hidden_states has been automatically split in previous
- # block by all2all comm op after attn)
- # The `encoder_hidden_states` will [NOT] be changed after each block forward,
- # so we need to split it at [ALL] block by the inserted split hook.
- # NOTE(DefTruth): We need to disable the splitting of encoder_hidden_states because
- # the image_encoder consistently generates 257 tokens for image_embed. This causes
- # the shape of encoder_hidden_states—whose token count is always 769 (512 + 257)
- # after concatenation—to be indivisible by the number of devices in the CP.
- # "blocks.*": {
- # "encoder_hidden_states": ContextParallelInput(
- # split_dim=1, expected_dims=3, split_output=False
- # ),
- # },
- # Then, the final proj_out will gather the splited output.
- # splited input (previous splited output)
- # -> all gather
- # -> un-split output
- "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
- }
- else:
- _cp_plan = {
- # Pattern of rope, split_output=True (split output rather than input):
- # un-split input
- # -> keep input un-split
- # -> rope
- # -> splited output
- "rope": {
- 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
- 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
- },
- # Pattern of blocks.0, split_output=False:
- # un-split input -> split -> to_qkv/...
- # -> all2all
- # -> attn (local head, full seqlen)
- # -> all2all
- # -> splited output
- # (only split hidden_states, not encoder_hidden_states)
- "blocks.0": {
- "hidden_states": ContextParallelInput(
- split_dim=1, expected_dims=3, split_output=False
- ),
- },
- # Pattern of the all blocks, split_output=False:
- # un-split input -> split -> to_qkv/...
- # -> all2all
- # -> attn (local head, full seqlen)
- # -> all2all
- # -> splited output
- # (only split encoder_hidden_states, not hidden_states.
- # hidden_states has been automatically split in previous
- # block by all2all comm op after attn)
- # The `encoder_hidden_states` will [NOT] be changed after each block forward,
- # so we need to split it at [ALL] block by the inserted split hook.
- "blocks.*": {
- "encoder_hidden_states": ContextParallelInput(
- split_dim=1, expected_dims=3, split_output=False
- ),
- },
- # Then, the final proj_out will gather the splited output.
- # splited input (previous splited output)
- # -> all gather
- # -> un-split output
- "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
- }
+ WanAttnProcessor.__call__ = __patch_WanAttnProcessor__call__
+ _cp_plan = {
+ # Pattern of rope, split_output=True (split output rather than input):
+ # un-split input
+ # -> keep input un-split
+ # -> rope
+ # -> splited output
+ "rope": {
+ 0: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ 1: ContextParallelInput(split_dim=1, expected_dims=4, split_output=True),
+ },
+ # Pattern of blocks.0, split_output=False:
+ # un-split input -> split -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+ # -> splited output
+ # (only split hidden_states, not encoder_hidden_states)
+ "blocks.0": {
+ "hidden_states": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ # Pattern of the all blocks, split_output=False:
+ # un-split input -> split -> to_qkv/...
+ # -> all2all
+ # -> attn (local head, full seqlen)
+ # -> all2all
+ # -> splited output
+ # (only split encoder_hidden_states, not hidden_states.
+ # hidden_states has been automatically split in previous
+ # block by all2all comm op after attn)
+ # The `encoder_hidden_states` will [NOT] be changed after each block forward,
+ # so we need to split it at [ALL] block by the inserted split hook.
+ # NOTE(DefTruth): We need to disable the splitting of encoder_hidden_states because
+ # the image_encoder (Wan 2.1 I2V) consistently generates 257 tokens for image_embed.
+ # This causes the shape of encoder_hidden_states—whose token count is always
+ # 769 (512 + 257) after concatenation—to be indivisible by the number of devices
+ # in the CP.
+ # "blocks.*": {
+ # "encoder_hidden_states": ContextParallelInput(
+ # split_dim=1, expected_dims=3, split_output=False
+ # ),
+ # },
+ # Then, the final proj_out will gather the splited output.
+ # splited input (previous splited output)
+ # -> all gather
+ # -> un-split output
+ "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ # Wan 2.2 TI2V: https://github.com/huggingface/diffusers/pull/12562
+ "": {
+ "timestep": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+ },
+ }
return _cp_plan
@functools.wraps(WanAttnProcessor.__call__)
def __patch_WanAttnProcessor__call__(
self: WanAttnProcessor,
- attn: "WanAttention",
+ attn: WanAttention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
@@ -216,7 +172,7 @@ def apply_rotary_emb(
# FIXME(DefTruth): Since the key/value in cross-attention depends
# solely on encoder_hidden_states_img (img), the (q_chunk * k) * v
# computation can be parallelized independently. Thus, there is
- # no need to pass the parallel_config here.
+ # no need to pass the config here.
parallel_config=None,
)
hidden_states_img = hidden_states_img.flatten(2, 3)
@@ -233,7 +189,7 @@ def apply_rotary_emb(
# FIXME(DefTruth): Since the key/value in cross-attention depends
# solely on encoder_hidden_states (text), the (q_chunk * k) * v
# computation can be parallelized independently. Thus, there is
- # no need to pass the parallel_config here.
+ # no need to pass the config here.
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py
new file mode 100644
index 000000000..9b1c6d7ba
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_plan_zimage.py
@@ -0,0 +1,304 @@
+import torch
+import functools
+from typing import Optional
+from torch.distributed import DeviceMesh
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers import ZImageTransformer2DModel
+from diffusers.models.transformers.transformer_z_image import (
+ ZSingleStreamAttnProcessor,
+ dispatch_attention_fn,
+ Attention,
+)
+
+try:
+ from diffusers.models._modeling_parallel import (
+ ContextParallelInput,
+ ContextParallelOutput,
+ ContextParallelModelPlan,
+ )
+except ImportError:
+ raise ImportError(
+ "Context parallelism requires the 'diffusers>=0.36.dev0'."
+ "Please install latest version of diffusers from source: \n"
+ "pip3 install git+https://github.com/huggingface/diffusers.git"
+ )
+from .cp_plan_registers import (
+ ContextParallelismPlanner,
+ ContextParallelismPlannerRegister,
+)
+from cache_dit.parallelism.attention import _unified_all_to_all_o_async_fn
+from cache_dit.parallelism.attention import _unified_all_to_all_qkv_async_fn
+from cache_dit.parallelism.attention import _prepare_ulysses_comm_metadata
+from cache_dit.parallelism.attention import _maybe_patch_find_submodule
+from cache_dit.platforms import current_platform
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ContextParallelismPlannerRegister.register("ZImageTransformer2DModel")
+class ZImageContextParallelismPlanner(ContextParallelismPlanner):
+ def apply(
+ self,
+ transformer: Optional[torch.nn.Module | ModelMixin] = None,
+ **kwargs,
+ ) -> ContextParallelModelPlan:
+
+ # NOTE: Diffusers native CP plan is still not supported for ZImageTransformer2DModel
+ self._cp_planner_preferred_native_diffusers = False
+
+ if transformer is not None and self._cp_planner_preferred_native_diffusers:
+ assert isinstance(
+ transformer, ZImageTransformer2DModel
+ ), "Transformer must be an instance of ZImageTransformer2DModel"
+ if hasattr(transformer, "_cp_plan"):
+ if transformer._cp_plan is not None:
+ return transformer._cp_plan
+
+ experimental_ulysses_async = kwargs.get("experimental_ulysses_async", False)
+ if experimental_ulysses_async:
+ ZSingleStreamAttnProcessor.__call__ = (
+ __patch_ZSingleStreamAttnProcessor_ulysses_async__call__
+ )
+
+ logger.info(
+ "Enabled experimental Async QKV Projection with Ulysses style "
+ "Context Parallelism for ZImageTransformer2DModel."
+ )
+
+ # NOTE: This is only a temporary workaround for ZImage to make context parallelism
+ # work compatible with DBCache FnB0. The better way is to make DBCache fully
+ # compatible with diffusers native context parallelism, e.g., check the split/gather
+ # hooks in each block/layer in the initialization of DBCache.
+ # Issue: https://github.com/vipshop/cache-dit/issues/498
+ _maybe_patch_find_submodule()
+ n_noise_refiner_layers = len(transformer.noise_refiner) # 2
+ n_context_refiner_layers = len(transformer.context_refiner) # 2
+ n_layers = len(transformer.layers) # 30
+ # controlnet layer idx: [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28]
+ # num_controlnet_samples = len(transformer.layers) // 2 # 15
+ has_controlnet = kwargs.get("has_controlnet", None)
+ if not has_controlnet:
+ # cp plan for ZImageTransformer2DModel if no controlnet
+ _cp_plan = {
+ # 0. Hooks for noise_refiner layers, 2
+ "noise_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "noise_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"noise_refiner.{n_noise_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 1. Hooks for context_refiner layers, 2
+ "context_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "context_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 2. Hooks for main transformer layers, num_layers=30
+ "layers.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "layers.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
+ "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ # NOTE: The 'all_final_layer' is a ModuleDict of several final layers,
+ # each for a specific patch size combination, so we do not add hooks for it here.
+ # So, we have to gather the output of the last transformer layer.
+ # f"layers.{num_layers - 1}": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ else:
+ # Special cp plan for ZImageTransformer2DModel with ZImageControlNetModel
+ logger.warning(
+ "Using special context parallelism plan for ZImageTransformer2DModel "
+ "due to the 'has_controlnet' flag is set to True."
+ )
+ _cp_plan = {
+ # zimage controlnet shares the same refiner as zimage, so we need to
+ # add gather hooks for all layers in noise_refiner and context_refiner.
+ # 0. Hooks for noise_refiner layers, 2
+ # Insert gather hook after each layer due to the ops: (controlnet)
+ # - x = x + noise_refiner_block_samples[layer_idx]
+ "noise_refiner.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ **{
+ f"noise_refiner.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
+ for i in range(n_noise_refiner_layers)
+ },
+ # 1. Hooks for context_refiner layers, 2
+ "context_refiner.0": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ },
+ "context_refiner.*": {
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ f"context_refiner.{n_context_refiner_layers - 1}": ContextParallelOutput(
+ gather_dim=1, expected_dims=3
+ ),
+ # 2. Hooks for main transformer layers, num_layers=30
+ # Insert gather hook after each layer due to the ops: (main transformer)
+ # - unified + controlnet_block_samples[layer_idx]
+ "layers.*": {
+ "x": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+ "freqs_cis": ContextParallelInput(
+ split_dim=1, expected_dims=3, split_output=False
+ ),
+ },
+ **{
+ f"layers.{i}": ContextParallelOutput(gather_dim=1, expected_dims=3)
+ for i in range(n_layers)
+ },
+ # NEED: call _maybe_patch_find_submodule to support ModuleDict like 'all_final_layer'
+ "all_final_layer": ContextParallelOutput(gather_dim=1, expected_dims=3),
+ }
+ return _cp_plan
+
+
+# NOTE: Support Async Ulysses QKV projection for Z-Image
+def _ulysses_attn_with_async_qkv_proj_zimage(
+ self: ZSingleStreamAttnProcessor,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+
+ ulysses_mesh: DeviceMesh = self._parallel_config.context_parallel_config._ulysses_mesh
+ group = ulysses_mesh.get_group()
+
+ _all_to_all_o_async_func = _unified_all_to_all_o_async_fn()
+ _all_to_all_qv_async_func = _unified_all_to_all_qkv_async_fn()
+ _all_to_all_k_async_func = _unified_all_to_all_qkv_async_fn(fp8=False)
+
+ # Apply RoPE
+ def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
+ with torch.amp.autocast(current_platform.device_type, enabled=False):
+ x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
+ freqs_cis = freqs_cis.unsqueeze(2)
+ x_out = torch.view_as_real(x * freqs_cis).flatten(3)
+ return x_out.type_as(x_in) # todo
+
+ dtype = hidden_states.dtype
+ query = attn.to_q(hidden_states) # type: torch.Tensor
+ query = query.unflatten(-1, (attn.heads, -1))
+ if attn.norm_q is not None: # Apply Norms
+ query = attn.norm_q(query)
+ if freqs_cis is not None: # Apply RoPE
+ query = apply_rotary_emb(query, freqs_cis)
+
+ metadata = _prepare_ulysses_comm_metadata(query)
+
+ # Async all to all for query
+ query_wait = _all_to_all_qv_async_func(query, group, **metadata)
+
+ key = attn.to_k(hidden_states) # type: torch.Tensor
+ key = key.unflatten(-1, (attn.heads, -1))
+ if attn.norm_k is not None: # Apply Norms
+ key = attn.norm_k(key)
+ if freqs_cis is not None: # Apply RoPE
+ key = apply_rotary_emb(key, freqs_cis)
+
+ # Async all to all for key
+ key_wait = _all_to_all_k_async_func(key, group, **metadata)
+
+ value = attn.to_v(hidden_states) # type: torch.Tensor
+ value = value.unflatten(-1, (attn.heads, -1))
+
+ # Async all to all for value
+ value_wait = _all_to_all_qv_async_func(value, group, **metadata)
+
+ # Ensure the query, key, value are ready
+ query = query_wait()
+ key = key_wait()
+ value = value_wait()
+
+ # Cast to correct dtype
+ query, key = query.to(dtype), key.to(dtype)
+
+ # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
+ if attention_mask is not None and attention_mask.ndim == 2:
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Compute joint attention
+ out = dispatch_attention_fn(
+ query,
+ key,
+ value,
+ attn_mask=attention_mask,
+ dropout_p=0.0,
+ is_causal=False,
+ backend=self._attention_backend,
+ parallel_config=None, # set to None to avoid double parallelism
+ ) # (B, S_GLOBAL, H_LOCAL, D)
+
+ out_wait = _all_to_all_o_async_func(out, group, **metadata) # (B, S_LOCAL, H_GLOBAL, D)
+ hidden_states = out_wait() # type: torch.Tensor
+
+ # Reshape back
+ hidden_states = hidden_states.flatten(2, 3)
+ hidden_states = hidden_states.to(dtype)
+
+ output = attn.to_out[0](hidden_states)
+ if len(attn.to_out) > 1: # dropout
+ output = attn.to_out[1](output)
+
+ return output
+
+
+ZSingleStreamAttnProcessor_original__call__ = ZSingleStreamAttnProcessor.__call__
+
+
+@functools.wraps(ZSingleStreamAttnProcessor_original__call__)
+def __patch_ZSingleStreamAttnProcessor_ulysses_async__call__(
+ self: ZSingleStreamAttnProcessor,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ if (
+ self._parallel_config is not None
+ and hasattr(self._parallel_config, "context_parallel_config")
+ and self._parallel_config.context_parallel_config is not None
+ and self._parallel_config.context_parallel_config.ulysses_degree > 1
+ ):
+ return _ulysses_attn_with_async_qkv_proj_zimage(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ freqs_cis,
+ )
+ else:
+ return ZSingleStreamAttnProcessor_original__call__(
+ self,
+ attn,
+ hidden_states,
+ encoder_hidden_states,
+ attention_mask,
+ freqs_cis,
+ )
diff --git a/src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py b/src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py
new file mode 100644
index 000000000..d91ce71f5
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/context_parallelism/cp_planners.py
@@ -0,0 +1,177 @@
+# Docstring references: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/_modeling_parallel.py#L185
+# A dictionary where keys denote the input to be split across context parallel region, and the
+# value denotes the sharding configuration.
+# If the key is a string, it denotes the name of the parameter in the forward function.
+# If the key is an integer, split_output must be set to True, and it denotes the index of the output
+# to be split across context parallel region.
+# ContextParallelInputType = Dict[
+# Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
+# ]
+
+# A dictionary where keys denote the output to be gathered across context parallel region, and the
+# value denotes the gathering configuration.
+# ContextParallelOutputType = Union[
+# ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
+# ]
+
+# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
+# the module should be split/gathered across context parallel region.
+# ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
+
+# Example of a ContextParallelModelPlan (QwenImageTransformer2DModel):
+#
+# Each model should define a _cp_plan attribute that contains information on how to shard/gather
+# tensors at different stages of the forward:
+#
+# ```python
+# _cp_plan = {
+# "": {
+# "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
+# "encoder_hidden_states_mask": ContextParallelInput(split_dim=1, expected_dims=2, split_output=False),
+# },
+# "pos_embed": {
+# 0: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# 1: ContextParallelInput(split_dim=0, expected_dims=2, split_output=True),
+# },
+# "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
+# }
+# ```
+#
+# The dictionary is a set of module names mapped to their respective CP plan. The inputs/outputs of layers will be
+# split/gathered according to this at the respective module level. Here, the following happens:
+# - "":
+# we specify that we want to split the various inputs across the sequence dim in the pre-forward hook (i.e. before
+# the actual forward logic of the QwenImageTransformer2DModel is run, we will split the inputs)
+# - "pos_embed":
+# we specify that we want to split the outputs of the RoPE layer. Since there are two outputs (image & text freqs),
+# we can individually specify how they should be split
+# - "proj_out":
+# before returning to the user, we gather the entire sequence on each rank in the post-forward hook (after the linear
+# layer forward has run).
+#
+# ContextParallelInput:
+# specifies how to split the input tensor in the pre-forward or post-forward hook of the layer it is attached to
+#
+# ContextParallelOutput:
+# specifies how to gather the input tensor in the post-forward hook in the layer it is attached to
+import importlib
+from cache_dit.logger import init_logger
+from .cp_plan_registers import ContextParallelismPlanner
+
+logger = init_logger(__name__)
+
+
+class ImportErrorContextParallelismPlanner(ContextParallelismPlanner):
+ def plan(
+ self,
+ transformer,
+ **kwargs,
+ ):
+ raise ImportError(
+ "This ContextParallelismPlanner requires latest diffusers to be installed. "
+ "Please install diffusers from source."
+ )
+
+
+def _safe_import(module_name: str, class_name: str) -> type[ContextParallelismPlanner]:
+ try:
+ # e.g., module_name = ".cp_plan_dit", class_name = "DiTContextParallelismPlanner"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorContextParallelismPlanner
+
+
+def _activate_cp_planners():
+ """Function to register all built-in context parallelism planners."""
+ FluxContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_flux", "FluxContextParallelismPlanner"
+ )
+ QwenImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_qwen_image", "QwenImageContextParallelismPlanner"
+ )
+ WanContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_wan", "WanContextParallelismPlanner"
+ )
+ WanVACEContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_wan", "WanVACEContextParallelismPlanner"
+ )
+ LTXVideoContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_ltxvideo", "LTXVideoContextParallelismPlanner"
+ )
+ LTX2ContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_ltx2", "LTX2ContextParallelismPlanner"
+ )
+ HunyuanImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_hunyuan", "HunyuanImageContextParallelismPlanner"
+ )
+ HunyuanVideoContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_hunyuan", "HunyuanVideoContextParallelismPlanner"
+ )
+ CogVideoXContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_cogvideox", "CogVideoXContextParallelismPlanner"
+ )
+ CogView3PlusContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_cogview", "CogView3PlusContextParallelismPlanner"
+ )
+ CogView4ContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_cogview", "CogView4ContextParallelismPlanner"
+ )
+ CosisIDContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_cosisid", "CosisIDContextParallelismPlanner"
+ )
+ ChromaContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_chroma", "ChromaContextParallelismPlanner"
+ )
+ PixArtContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_pixart", "PixArtContextParallelismPlanner"
+ )
+ DiTContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_dit", "DiTContextParallelismPlanner"
+ )
+ Kandinsky5ContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_kandinsky", "Kandinsky5ContextParallelismPlanner"
+ )
+ SkyReelsV2ContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_skyreels", "SkyReelsV2ContextParallelismPlanner"
+ )
+ Flux2ContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_flux2", "Flux2ContextParallelismPlanner"
+ )
+ ZImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_zimage", "ZImageContextParallelismPlanner"
+ )
+ ChronoEditContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_chrono_edit", "ChronoEditContextParallelismPlanner"
+ )
+ OvisImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_ovis_image", "OvisImageContextParallelismPlanner"
+ )
+ LongCatImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_longcat_image", "LongCatImageContextParallelismPlanner"
+ )
+
+ try:
+ import nunchaku # noqa: F401
+
+ _nunchaku_available = True
+ except ImportError:
+ _nunchaku_available = False
+
+ if _nunchaku_available:
+ NunchakuFluxContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_nunchaku", "NunchakuFluxContextParallelismPlanner"
+ )
+ NunchakuQwenImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_nunchaku", "NunchakuQwenImageContextParallelismPlanner"
+ )
+ NunchakuZImageContextParallelismPlanner = _safe_import( # noqa: F841
+ ".cp_plan_nunchaku", "NunchakuZImageContextParallelismPlanner"
+ )
+
+
+__all__ = ["_activate_cp_planners"]
diff --git a/src/cache_dit/parallelism/transformers/dispatch.py b/src/cache_dit/parallelism/transformers/dispatch.py
new file mode 100644
index 000000000..8a68b89f6
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/dispatch.py
@@ -0,0 +1,134 @@
+import torch
+
+from typing import Optional
+from cache_dit.logger import init_logger
+
+from diffusers.models.modeling_utils import ModelMixin
+
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.parallelism.config import ParallelismConfig
+
+logger = init_logger(__name__)
+
+
+def maybe_enable_parallelism_for_transformer(
+ transformer: torch.nn.Module | ModelMixin,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
+ )
+
+ if parallelism_config is None:
+ return transformer
+
+ # Currently, we can dispatch the parallelism based on the backend type.
+ # Now, the context parallelism is only supported in NATIVE_DIFFUSER backend,
+ # and the tensor parallelism is only supported in NATIVE_PYTORCH backend.
+ if parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER:
+ return maybe_enable_context_parallelism_for_transformer(
+ transformer=transformer,
+ parallelism_config=parallelism_config,
+ )
+ elif parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH:
+ return maybe_enable_tensor_parallelism_for_transformer(
+ transformer=transformer,
+ parallelism_config=parallelism_config,
+ )
+ else:
+ raise ValueError(f"{parallelism_config.backend} backend is not supported yet")
+
+
+def maybe_enable_context_parallelism_for_transformer(
+ transformer: torch.nn.Module | ModelMixin,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
+ )
+
+ if parallelism_config is None:
+ return transformer
+
+ assert isinstance(parallelism_config, ParallelismConfig), (
+ "parallelism_config must be an instance of ParallelismConfig"
+ f" but got {type(parallelism_config)}"
+ )
+
+ assert parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER, (
+ f"parallelism backend must be {ParallelismBackend.NATIVE_DIFFUSER}, "
+ f"but got {parallelism_config.backend}"
+ )
+
+ if parallelism_config.ulysses_size is not None or parallelism_config.ring_size is not None:
+ from .context_parallelism import maybe_enable_context_parallelism
+
+ transformer = maybe_enable_context_parallelism(
+ transformer,
+ parallelism_config,
+ )
+ transformer._is_parallelized = True # type: ignore[attr-defined]
+ # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
+ transformer._parallelism_config = parallelism_config # type: ignore[attr-defined]
+ logger.info(
+ f"Parallelize Transformer: {transformer.__class__.__name__}, "
+ f"id:{id(transformer)}, {parallelism_config.strify(True)}"
+ )
+
+ else:
+ raise ValueError(
+ "NATIVE_DIFFUSER backend only support context parallelism now. "
+ "Please set ulysses_size or ring_size in parallelism_config."
+ )
+ return transformer
+
+
+def maybe_enable_tensor_parallelism_for_transformer(
+ transformer: torch.nn.Module | ModelMixin,
+ parallelism_config: Optional[ParallelismConfig],
+) -> torch.nn.Module:
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
+ )
+
+ if parallelism_config is None:
+ return transformer
+
+ assert parallelism_config.backend == ParallelismBackend.NATIVE_PYTORCH, (
+ "parallelism_config.backend must be ParallelismBackend.NATIVE_PYTORCH "
+ f"but got {parallelism_config.backend}"
+ )
+
+ assert isinstance(parallelism_config, ParallelismConfig), (
+ "parallelism_config must be an instance of ParallelismConfig"
+ f" but got {type(parallelism_config)}"
+ )
+ assert parallelism_config.ulysses_size is None and parallelism_config.ring_size is None, (
+ "Ulysses/Ring parallelism is not supported in Native_PyTorch backend. "
+ "Please set it to None in parallelism_config."
+ )
+
+ if parallelism_config.tp_size is not None and parallelism_config.tp_size > 1:
+ from .tensor_parallelism import maybe_enable_tensor_parallelism
+
+ transformer = maybe_enable_tensor_parallelism(
+ transformer=transformer,
+ parallelism_config=parallelism_config,
+ )
+ transformer._is_parallelized = True # type: ignore[attr-defined]
+ # Use `parallelism` not `parallel` to avoid name conflict with diffusers.
+ transformer._parallelism_config = parallelism_config # type: ignore[attr-defined]
+ logger.info(
+ f"Parallelize Transformer: {transformer.__class__.__name__}, "
+ f"id:{id(transformer)}, {parallelism_config.strify(True)}"
+ )
+
+ else:
+ raise ValueError(
+ "NATIVE_PYTORCH only supported tensor parallelism now. "
+ "Please set tp_size > 1 for tensor parallelism."
+ )
+ return transformer
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/__init__.py
similarity index 72%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/__init__.py
index c606ed5f7..ab1ba2759 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/__init__.py
@@ -9,10 +9,17 @@
import torch
from typing import Optional
from diffusers.models.modeling_utils import ModelMixin
-from cache_dit.parallelism.parallel_backend import ParallelismBackend
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.backend import ParallelismBackend
+from cache_dit.parallelism.config import ParallelismConfig
from cache_dit.logger import init_logger
-from .tp_planners import *
+
+try:
+ from .tp_plan_registers import TensorParallelismPlannerRegister
+ from .tp_planners import _activate_tp_planners
+
+ _activate_tp_planners()
+except ImportError as e:
+ raise
logger = init_logger(__name__)
@@ -21,9 +28,11 @@ def maybe_enable_tensor_parallelism(
transformer: torch.nn.Module | ModelMixin,
parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
- assert isinstance(transformer, torch.nn.Module), (
- "transformer must be an instance of torch.nn.Module, " f"but got {type(transformer)}"
+ assert isinstance(transformer, (torch.nn.Module, ModelMixin)), (
+ "transformer must be an instance of torch.nn.Module or ModelMixin, "
+ f"but got {type(transformer)}"
)
+
assert isinstance(transformer, ModelMixin), (
"transformer must be an instance of diffusers' ModelMixin, " f"but got {type(transformer)}"
)
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_cogview.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_cogview.py
similarity index 88%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_cogview.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_cogview.py
index 6b4873683..9d149454d 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_cogview.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_cogview.py
@@ -9,12 +9,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -54,7 +55,13 @@ def parallelize_transformer(
):
for _, block in transformer.transformer_blocks.named_children():
# Reduce attention heads for tensor parallelism
- block.attn1.heads //= tp_mesh.size()
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_mesh.size(),
+ what="attn1",
+ context="CogViewTensorParallelismPlanner",
+ )
layer_plan = {
# Self-attention projections
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux.py
similarity index 77%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux.py
index dbbc443e6..34804617b 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux.py
@@ -13,12 +13,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -26,7 +27,7 @@
@TensorParallelismPlannerRegister.register("Chroma")
@TensorParallelismPlannerRegister.register("HunyuanImage")
@TensorParallelismPlannerRegister.register("HunyuanVideo")
-@TensorParallelismPlannerRegister.register("Flux")
+@TensorParallelismPlannerRegister.register("FluxTransformer")
class FluxTensorParallelismPlanner(TensorParallelismPlanner):
def apply(
self,
@@ -48,45 +49,22 @@ def apply(
transformer=transformer,
tp_mesh=tp_mesh,
)
- # TODO: Parallelize t5 text encoder via `apply_extra`
- # abstract method and `extra_parallel_kwargs` ?
return transformer
- def parallelize_t5(
- self,
- text_encoder: nn.Module,
- tp_mesh: DeviceMesh,
- ):
- for i, block in enumerate(text_encoder.encoder.block):
- block.layer[0].SelfAttention.n_heads //= tp_mesh.size()
- block.layer[0].SelfAttention.inner_dim //= tp_mesh.size()
- layer_plan = {
- "layer.0.SelfAttention.q": ColwiseParallel(),
- "layer.0.SelfAttention.k": ColwiseParallel(),
- "layer.0.SelfAttention.v": ColwiseParallel(),
- "layer.0.SelfAttention.o": RowwiseParallel(),
- "layer.1.DenseReluDense.wi_0": ColwiseParallel(),
- "layer.1.DenseReluDense.wi_1": ColwiseParallel(),
- "layer.1.DenseReluDense.wo": RowwiseParallel(),
- }
- if i == 0:
- layer_plan["layer.0.SelfAttention.relative_attention_bias"] = ColwiseParallel()
- parallelize_module(
- module=block,
- device_mesh=tp_mesh,
- parallelize_plan=layer_plan,
- )
-
- return text_encoder
-
def parallelize_transformer(
self,
transformer: nn.Module,
tp_mesh: DeviceMesh,
):
for _, block in transformer.transformer_blocks.named_children():
- block.attn.heads //= tp_mesh.size()
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="FluxTensorParallelismPlanner",
+ )
layer_plan = {
"attn.to_q": ColwiseParallel(),
"attn.to_k": ColwiseParallel(),
@@ -131,7 +109,13 @@ def rearrange_proj_out_weight(single_block: FluxSingleTransformerBlock, tp_group
for _, block in transformer.single_transformer_blocks.named_children():
rearrange_proj_out_weight(block, tp_mesh.size())
- block.attn.heads //= tp_mesh.size()
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="FluxTensorParallelismPlanner(single_block)",
+ )
layer_plan = {
"attn.to_q": ColwiseParallel(),
"attn.to_k": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py
new file mode 100644
index 000000000..178c95bf3
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_flux2.py
@@ -0,0 +1,179 @@
+import torch
+from diffusers.models.transformers.transformer_flux2 import (
+ Flux2SingleTransformerBlock,
+ Flux2TransformerBlock,
+ Flux2Transformer2DModel,
+)
+from einops import rearrange
+from torch.distributed import DeviceMesh, init_device_mesh
+
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+from cache_dit.utils import maybe_empty_cache
+from cache_dit.platforms import current_platform
+
+from .tp_plan_registers import TensorParallelismPlanner, TensorParallelismPlannerRegister
+from .tp_utils import shard_divisible_attr
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("Flux2Transformer")
+class Flux2TensorParallelismPlanner(TensorParallelismPlanner):
+ def apply(
+ self,
+ transformer: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert parallelism_config.tp_size is not None and parallelism_config.tp_size > 1, (
+ "parallelism_config.tp_size must be set and greater than 1 for tensor parallelism"
+ )
+
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[parallelism_config.tp_size],
+ )
+
+ transformer = self.parallelize_transformer(
+ transformer=transformer,
+ tp_mesh=tp_mesh,
+ )
+
+ return transformer
+
+ @classmethod
+ def rearrange_swiglu_weight(cls, weight: torch.Tensor, tp_size: int):
+ weight = rearrange(weight, "r (g h d) -> r (h g d)", g=2, h=tp_size)
+ return weight
+
+ @classmethod
+ def rearrange_feedforward_weight(cls, block: Flux2TransformerBlock, tp_size: int):
+
+ block.ff.linear_in.weight.data = cls.rearrange_swiglu_weight(
+ block.ff.linear_in.weight.data.T, tp_size
+ ).T
+ block.ff_context.linear_in.weight.data = cls.rearrange_swiglu_weight(
+ block.ff_context.linear_in.weight.data.T, tp_size
+ ).T
+
+ @classmethod
+ def rearrange_singleblock_weight(cls, block: Flux2SingleTransformerBlock, tp_size: int):
+ attn = block.attn
+ to_qkv_mlp_proj_weight = attn.to_qkv_mlp_proj.weight.data.T
+ qkv, mlp = torch.split(
+ to_qkv_mlp_proj_weight,
+ [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor],
+ dim=-1,
+ )
+
+ mlp = cls.rearrange_swiglu_weight(mlp, tp_size)
+
+ def rearrange_qkv_weight(weight: torch.Tensor, tp_size: int):
+ weight = rearrange(weight, "r (g h d) -> r (h g d)", g=3, h=tp_size)
+ return weight
+
+ qkv = rearrange_qkv_weight(qkv, tp_size)
+ qkv = rearrange(qkv, "r (h d) -> r h d", h=tp_size)
+ mlp = rearrange(mlp, "r (h d) -> r h d", h=tp_size)
+ to_qkv_mlp_proj_weight = torch.cat([qkv, mlp], dim=-1)
+ to_qkv_mlp_proj_weight = to_qkv_mlp_proj_weight.flatten(1)
+ attn.to_qkv_mlp_proj.weight.data = to_qkv_mlp_proj_weight.T
+
+ # rearrange out projection weight
+ out_weight = attn.to_out.weight.data.T
+ # FLUX.2-dev, FLUX.2-klein, divide by 4
+ attn_out_dim = out_weight.shape[0] // 4
+ attn_out_weight = out_weight[:attn_out_dim, ...]
+ mlp_out_weight = out_weight[attn_out_dim:, ...]
+
+ attn_out_weight = rearrange(attn_out_weight, "(g d) c -> g d c", g=tp_size)
+ mlp_out_weight = rearrange(mlp_out_weight, "(g d) c -> g d c", g=tp_size)
+
+ new_out_weight = torch.cat([attn_out_weight, mlp_out_weight], dim=1)
+ new_out_weight = rearrange(new_out_weight, "g d c -> (g d) c")
+ attn.to_out.weight.data = new_out_weight.T
+
+ def parallelize_transformer(
+ self,
+ transformer: Flux2Transformer2DModel,
+ tp_mesh: DeviceMesh,
+ ):
+ tp_size = tp_mesh.get_group().size()
+ for _, block in transformer.transformer_blocks.named_children():
+ # Moving to the accelerator device speeds up the weight rearrangement significantly.
+ old_device = next(block.parameters()).device
+ block.to(current_platform.device_type)
+ self.rearrange_feedforward_weight(block, tp_size)
+ block.to(old_device)
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_size,
+ what="attn",
+ context="Flux2TensorParallelismPlanner",
+ )
+ layer_plan = {
+ "attn.to_q": ColwiseParallel(),
+ "attn.to_k": ColwiseParallel(),
+ "attn.to_v": ColwiseParallel(),
+ "attn.to_out.0": RowwiseParallel(),
+ "ff.linear_in": ColwiseParallel(),
+ "ff.linear_out": RowwiseParallel(),
+ "attn.add_q_proj": ColwiseParallel(),
+ "attn.add_k_proj": ColwiseParallel(),
+ "attn.add_v_proj": ColwiseParallel(),
+ "attn.to_add_out": RowwiseParallel(),
+ "ff_context.linear_in": ColwiseParallel(),
+ "ff_context.linear_out": RowwiseParallel(),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+ maybe_empty_cache()
+
+ for _, block in transformer.single_transformer_blocks.named_children():
+ # Moving to the accelerator device speeds up the weight rearrangement significantly.
+ old_device = next(block.parameters()).device
+ block.to(current_platform.device_type)
+ self.rearrange_singleblock_weight(block, tp_size)
+ block.to(old_device)
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_size,
+ what="attn",
+ context="Flux2TensorParallelismPlanner(single_block)",
+ )
+ shard_divisible_attr(
+ block.attn,
+ "inner_dim",
+ tp_size,
+ what="attn",
+ context="Flux2TensorParallelismPlanner(single_block)",
+ )
+ shard_divisible_attr(
+ block.attn,
+ "mlp_hidden_dim",
+ tp_size,
+ what="attn",
+ context="Flux2TensorParallelismPlanner(single_block)",
+ )
+ layer_plan = {
+ "attn.to_qkv_mlp_proj": ColwiseParallel(),
+ "attn.to_out": RowwiseParallel(),
+ }
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+ maybe_empty_cache()
+
+ return transformer
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_hunyuan_dit.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_hunyuan_dit.py
similarity index 84%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_hunyuan_dit.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_hunyuan_dit.py
index 6f3f0bc1c..677ec2465 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_hunyuan_dit.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_hunyuan_dit.py
@@ -11,12 +11,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -66,8 +67,21 @@ def parallelize_transformer(
assert isinstance(block, HunyuanDiTBlock)
# Split attention heads across TP devices
- block.attn1.heads //= tp_mesh.size()
- block.attn2.heads //= tp_mesh.size()
+ tp_size = tp_mesh.size()
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_size,
+ what="attn1",
+ context="HunyuanDiTTensorParallelismPlanner",
+ )
+ shard_divisible_attr(
+ block.attn2,
+ "heads",
+ tp_size,
+ what="attn2",
+ context="HunyuanDiTTensorParallelismPlanner",
+ )
# Create layer plan for tensor parallelism
layer_plan = {
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_kandinsky5.py
similarity index 79%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_kandinsky5.py
index 7786a76c7..013ad8384 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_kandinsky5.py
@@ -9,12 +9,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -50,8 +51,21 @@ def parallelize_transformer(
tp_mesh: DeviceMesh,
):
for _, block in transformer.visual_transformer_blocks.named_children():
- block.self_attention.num_heads //= tp_mesh.size()
- block.cross_attention.num_heads //= tp_mesh.size()
+ tp_size = tp_mesh.size()
+ shard_divisible_attr(
+ block.self_attention,
+ "num_heads",
+ tp_size,
+ what="self_attention",
+ context="Kandinsky5TensorParallelismPlanner",
+ )
+ shard_divisible_attr(
+ block.cross_attention,
+ "num_heads",
+ tp_size,
+ what="cross_attention",
+ context="Kandinsky5TensorParallelismPlanner",
+ )
layer_plan = {
"self_attention.to_query": ColwiseParallel(),
"self_attention.to_key": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_longcat_image.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_longcat_image.py
new file mode 100644
index 000000000..51b19a452
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_longcat_image.py
@@ -0,0 +1,132 @@
+import torch
+from diffusers.models.transformers.transformer_longcat_image import (
+ LongCatImageSingleTransformerBlock,
+)
+from einops import rearrange
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TensorParallelismPlanner,
+ TensorParallelismPlannerRegister,
+)
+from .tp_utils import shard_divisible_attr
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("LongCatImageTransformer2DModel")
+class LongCatImageTensorParallelismPlanner(TensorParallelismPlanner):
+ def apply(
+ self,
+ transformer: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert (
+ parallelism_config.tp_size is not None and parallelism_config.tp_size > 1
+ ), "parallelism_config.tp_size must be set and greater than 1 for tensor parallelism"
+
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[parallelism_config.tp_size],
+ )
+
+ transformer = self.parallelize_transformer(
+ transformer=transformer,
+ tp_mesh=tp_mesh,
+ )
+
+ return transformer
+
+ def parallelize_transformer(
+ self,
+ transformer: nn.Module,
+ tp_mesh: DeviceMesh,
+ ):
+ for _, block in transformer.transformer_blocks.named_children():
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="LongCatImageTensorParallelismPlanner",
+ )
+ layer_plan = {
+ "attn.to_q": ColwiseParallel(),
+ "attn.to_k": ColwiseParallel(),
+ "attn.to_v": ColwiseParallel(),
+ "attn.to_out.0": RowwiseParallel(),
+ "ff.net.0.proj": ColwiseParallel(),
+ "ff.net.2": RowwiseParallel(),
+ "attn.add_q_proj": ColwiseParallel(),
+ "attn.add_k_proj": ColwiseParallel(),
+ "attn.add_v_proj": ColwiseParallel(),
+ "attn.to_add_out": RowwiseParallel(),
+ "ff_context.net.0.proj": ColwiseParallel(),
+ "ff_context.net.2": RowwiseParallel(),
+ }
+
+ if getattr(block.norm1, "linear", None) is not None:
+ layer_plan["norm1.linear"] = ColwiseParallel(output_layouts=Replicate())
+ if getattr(block.norm1_context, "linear", None) is not None:
+ layer_plan["norm1_context.linear"] = ColwiseParallel(output_layouts=Replicate())
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ # NOTE: special handling for LongCatImageSingleTransformerBlock, we have to
+ # rearrange the proj_out weight because it contains both out and down
+ # projection weights in a single matrix.
+ def rearrange_proj_out_weight(
+ single_block: LongCatImageSingleTransformerBlock, tp_group_size
+ ):
+ # rowwise
+ hidden_dim = single_block.attn.to_q.weight.shape[0]
+ requires_grad = single_block.proj_out.weight.requires_grad
+ linear2_weight_data = single_block.proj_out.weight.data.T.detach().clone()
+ out_weight = linear2_weight_data[:hidden_dim, ...]
+ out_weight = rearrange(out_weight, "(G D) C -> G D C", G=tp_group_size)
+ down_weight = linear2_weight_data.data[hidden_dim:, ...]
+ down_weight = rearrange(down_weight, "(G D) C -> G D C", G=tp_group_size)
+ new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
+ new_linear2_weight = rearrange(new_linear2_weight, "G D C -> (G D) C")
+ single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
+ single_block.proj_out.weight.requires_grad_(requires_grad)
+
+ for _, block in transformer.single_transformer_blocks.named_children():
+ rearrange_proj_out_weight(block, tp_mesh.size())
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="LongCatImageTensorParallelismPlanner(single_block)",
+ )
+ layer_plan = {
+ "attn.to_q": ColwiseParallel(),
+ "attn.to_k": ColwiseParallel(),
+ "attn.to_v": ColwiseParallel(),
+ "proj_mlp": ColwiseParallel(),
+ "proj_out": RowwiseParallel(),
+ }
+ if getattr(block.norm, "linear", None) is not None:
+ layer_plan["norm.linear"] = ColwiseParallel(output_layouts=Replicate())
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+ return transformer
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx2_video.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx2_video.py
new file mode 100644
index 000000000..7c7af960d
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx2_video.py
@@ -0,0 +1,249 @@
+from __future__ import annotations
+
+from typing import Optional, Union
+
+import torch
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
+
+from diffusers.models.transformers.transformer_ltx2 import (
+ LTX2AudioVideoAttnProcessor,
+ LTX2VideoTransformer3DModel,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import TensorParallelismPlanner, TensorParallelismPlannerRegister
+from .tp_utils import shard_divisible_attr
+
+logger = init_logger(__name__)
+
+
+class DistributedRMSNorm(nn.Module):
+ def __init__(
+ self,
+ tp_mesh: DeviceMesh,
+ normalized_shape: Union[int, list[int], torch.Size],
+ eps: Optional[float],
+ elementwise_affine: bool,
+ weight: torch.nn.parameter.Parameter,
+ ):
+ super().__init__()
+ self.tp_mesh = tp_mesh
+ self.elementwise_affine = elementwise_affine
+ self.normalized_shape = normalized_shape
+ self.eps = eps
+ if self.elementwise_affine:
+ assert weight is not None
+ self.weight = weight
+
+ @classmethod
+ def from_rmsnorm(cls, tp_mesh: DeviceMesh, rmsnorm: nn.RMSNorm):
+ assert len(rmsnorm.normalized_shape) == 1
+ if rmsnorm.weight is not None:
+ tp_size = tp_mesh.get_group().size()
+ tp_rank = tp_mesh.get_group().rank()
+ weight = rmsnorm.weight.chunk(tp_size, dim=0)[tp_rank]
+ else:
+ weight = None
+ norm = cls(
+ tp_mesh=tp_mesh,
+ normalized_shape=rmsnorm.normalized_shape,
+ eps=rmsnorm.eps,
+ elementwise_affine=rmsnorm.elementwise_affine,
+ weight=weight,
+ )
+ return norm
+
+ def forward(self, x):
+ if self.elementwise_affine:
+ assert x.shape[-1] == self.weight.shape[0]
+ mean_square = torch.mean(x * x, dim=-1, keepdim=True)
+ torch.distributed.all_reduce(
+ mean_square,
+ op=torch.distributed.ReduceOp.AVG,
+ group=self.tp_mesh.get_group(),
+ )
+ root_mean_square = torch.sqrt(mean_square + (self.eps if self.eps is not None else torch.finfo(x.dtype).eps))
+ x_normed = x / root_mean_square
+ if self.elementwise_affine:
+ x_normed = x_normed * self.weight.to(device=x.device)
+ assert x_normed.device.type != "cpu"
+ return x_normed
+
+
+class ShardRotaryEmbProcessor:
+ """Shard query/key rotary embeddings to match TP-sharded heads/channels.
+
+ - interleaved RoPE: cos/sin are (B, T, D) -> shard along last dim
+ - split RoPE: cos/sin are (B, H, T, D/2) -> shard along head dim
+ """
+
+ def __init__(self, processor: LTX2AudioVideoAttnProcessor, tp_size: int, tp_rank: int):
+ self.processor = processor
+ self.tp_size = tp_size
+ self.tp_rank = tp_rank
+
+ @classmethod
+ def from_attn_processor(
+ cls, processor: LTX2AudioVideoAttnProcessor, tp_size: int, tp_rank: int
+ ) -> "ShardRotaryEmbProcessor":
+ return cls(processor=processor, tp_size=tp_size, tp_rank=tp_rank)
+
+ def _shard_rope(self, emb):
+ if emb is None:
+ return None
+ cos, sin = emb
+ if cos is None or sin is None:
+ return emb
+ # split rope: (B, H, T, D/2)
+ if cos.ndim == 4:
+ cos = torch.chunk(cos, self.tp_size, dim=1)[self.tp_rank]
+ sin = torch.chunk(sin, self.tp_size, dim=1)[self.tp_rank]
+ else:
+ # interleaved rope: (B, T, D)
+ cos = torch.chunk(cos, self.tp_size, dim=-1)[self.tp_rank]
+ sin = torch.chunk(sin, self.tp_size, dim=-1)[self.tp_rank]
+ return (cos, sin)
+
+ def __call__(
+ self,
+ attn,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ query_rotary_emb=None,
+ key_rotary_emb=None,
+ ) -> torch.Tensor:
+ query_rotary_emb = self._shard_rope(query_rotary_emb)
+ key_rotary_emb = self._shard_rope(key_rotary_emb)
+ return self.processor(
+ attn,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ query_rotary_emb=query_rotary_emb,
+ key_rotary_emb=key_rotary_emb,
+ )
+
+
+@TensorParallelismPlannerRegister.register("LTX2")
+class LTX2VideoTensorParallelismPlanner(TensorParallelismPlanner):
+ def apply(
+ self,
+ transformer: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert (
+ parallelism_config.tp_size is not None and parallelism_config.tp_size > 1
+ ), "parallelism_config.tp_size must be set and greater than 1 for tensor parallelism"
+ assert isinstance(transformer, LTX2VideoTransformer3DModel), (
+ "Transformer must be an instance of LTX2VideoTransformer3DModel, "
+ f"but got {type(transformer)}"
+ )
+
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[parallelism_config.tp_size],
+ )
+
+ return self.parallelize_transformer(transformer=transformer, tp_mesh=tp_mesh)
+
+ def parallelize_transformer(self, transformer: nn.Module, tp_mesh: DeviceMesh):
+ tp_size = tp_mesh.get_group().size()
+ tp_rank = tp_mesh.get_group().rank()
+
+ def _shard_attention(attn: nn.Module, what: str):
+ shard_divisible_attr(
+ attn,
+ "heads",
+ tp_size,
+ what=what,
+ context="LTX2VideoTensorParallelismPlanner",
+ )
+
+ def prepare_block(block: nn.Module):
+ # Shard heads for all attention modules inside the block
+ _shard_attention(block.attn1, "attn1")
+ _shard_attention(block.attn2, "attn2")
+ _shard_attention(block.audio_attn1, "audio_attn1")
+ _shard_attention(block.audio_attn2, "audio_attn2")
+ _shard_attention(block.audio_to_video_attn, "audio_to_video_attn")
+ _shard_attention(block.video_to_audio_attn, "video_to_audio_attn")
+
+ layer_plan = {
+ # video self-attn / text cross-attn
+ "attn1.to_q": ColwiseParallel(),
+ "attn1.to_k": ColwiseParallel(),
+ "attn1.to_v": ColwiseParallel(),
+ "attn1.to_out.0": RowwiseParallel(),
+ "attn2.to_q": ColwiseParallel(),
+ "attn2.to_k": ColwiseParallel(),
+ "attn2.to_v": ColwiseParallel(),
+ "attn2.to_out.0": RowwiseParallel(),
+ # audio self-attn / text cross-attn
+ "audio_attn1.to_q": ColwiseParallel(),
+ "audio_attn1.to_k": ColwiseParallel(),
+ "audio_attn1.to_v": ColwiseParallel(),
+ "audio_attn1.to_out.0": RowwiseParallel(),
+ "audio_attn2.to_q": ColwiseParallel(),
+ "audio_attn2.to_k": ColwiseParallel(),
+ "audio_attn2.to_v": ColwiseParallel(),
+ "audio_attn2.to_out.0": RowwiseParallel(),
+ # a2v / v2a cross-attn
+ "audio_to_video_attn.to_q": ColwiseParallel(),
+ "audio_to_video_attn.to_k": ColwiseParallel(),
+ "audio_to_video_attn.to_v": ColwiseParallel(),
+ "audio_to_video_attn.to_out.0": RowwiseParallel(),
+ "video_to_audio_attn.to_q": ColwiseParallel(),
+ "video_to_audio_attn.to_k": ColwiseParallel(),
+ "video_to_audio_attn.to_v": ColwiseParallel(),
+ "video_to_audio_attn.to_out.0": RowwiseParallel(),
+ # FFNs
+ "ff.net.0.proj": ColwiseParallel(),
+ "ff.net.2": RowwiseParallel(),
+ "audio_ff.net.0.proj": ColwiseParallel(),
+ "audio_ff.net.2": RowwiseParallel(),
+ }
+
+ parallelize_module(module=block, device_mesh=tp_mesh, parallelize_plan=layer_plan)
+
+ # Shard qk norms
+ for attn in (
+ block.attn1,
+ block.attn2,
+ block.audio_attn1,
+ block.audio_attn2,
+ block.audio_to_video_attn,
+ block.video_to_audio_attn,
+ ):
+ attn.norm_q = DistributedRMSNorm.from_rmsnorm(tp_mesh, attn.norm_q)
+ attn.norm_k = DistributedRMSNorm.from_rmsnorm(tp_mesh, attn.norm_k)
+
+ # Shard RoPE frequencies for all attention processors in every block.
+ # NOTE: This assumes rotary embedding head counts align with attention heads (true for LTX-2 configs).
+ for _, block in transformer.transformer_blocks.named_children():
+ for attn_name in (
+ "attn1",
+ "attn2",
+ "audio_attn1",
+ "audio_attn2",
+ "audio_to_video_attn",
+ "video_to_audio_attn",
+ ):
+ attn = getattr(block, attn_name)
+ if hasattr(attn, "processor") and isinstance(
+ attn.processor, LTX2AudioVideoAttnProcessor
+ ):
+ attn.processor = ShardRotaryEmbProcessor.from_attn_processor(
+ processor=attn.processor,
+ tp_size=tp_size,
+ tp_rank=tp_rank,
+ )
+ prepare_block(block)
+
+ return transformer
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_ltx_video.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx_video.py
similarity index 91%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_ltx_video.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx_video.py
index ae57a0f88..823f1b170 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_ltx_video.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ltx_video.py
@@ -11,12 +11,13 @@
from diffusers.models.transformers.transformer_ltx import LTXVideoAttnProcessor
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -139,8 +140,20 @@ def parallelize_transformer(
tp_rank = tp_mesh.get_group().rank()
def prepare_block(block: nn.Module):
- block.attn1.heads //= tp_size
- block.attn2.heads //= tp_size
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_size,
+ what="attn1",
+ context="LTXVideoTensorParallelismPlanner",
+ )
+ shard_divisible_attr(
+ block.attn2,
+ "heads",
+ tp_size,
+ what="attn2",
+ context="LTXVideoTensorParallelismPlanner",
+ )
layer_plan = {
"attn1.to_q": ColwiseParallel(),
"attn1.to_k": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_mochi.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_mochi.py
similarity index 93%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_mochi.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_mochi.py
index 7096e69a5..0cf86e4a0 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_mochi.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_mochi.py
@@ -14,12 +14,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -110,7 +111,13 @@ def parallelize_transformer(
)
self.rearrange_feedforward_weight(block, tp_size)
- block.attn1.heads //= tp_size
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_size,
+ what="attn1",
+ context="MochiTensorParallelismPlanner",
+ )
layer_plan = {
"attn1.to_q": ColwiseParallel(),
"attn1.to_k": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ovis_image.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ovis_image.py
new file mode 100644
index 000000000..6b01e94ba
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_ovis_image.py
@@ -0,0 +1,221 @@
+import torch
+from diffusers.models.transformers.transformer_ovis_image import (
+ OvisImageSingleTransformerBlock,
+ OvisImageTransformerBlock,
+ OvisImageTransformer2DModel,
+)
+from einops import rearrange
+from torch import nn
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed._tensor import Replicate
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ RowwiseParallel,
+ parallelize_module,
+)
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import (
+ TensorParallelismPlanner,
+ TensorParallelismPlannerRegister,
+)
+from .tp_utils import shard_divisible_attr
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("OvisImage")
+class OvisImageTensorParallelismPlanner(TensorParallelismPlanner):
+ def apply(
+ self,
+ transformer: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+ assert (
+ parallelism_config.tp_size is not None and parallelism_config.tp_size > 1
+ ), "parallel_config.tp_size must be set and greater than 1 for tensor parallelism"
+
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[parallelism_config.tp_size],
+ )
+
+ transformer = self.parallelize_transformer(
+ transformer=transformer,
+ tp_mesh=tp_mesh,
+ )
+
+ return transformer
+
+ def parallelize_transformer(
+ self,
+ transformer: nn.Module,
+ tp_mesh: DeviceMesh,
+ ):
+ assert isinstance(transformer, OvisImageTransformer2DModel)
+
+ for _, block in transformer.transformer_blocks.named_children():
+ assert isinstance(block, OvisImageTransformerBlock)
+ rearrange_ffn_0_swiglu_proj_weight(block.ff.net[0].proj, tp_mesh.size())
+ rearrange_ffn_0_swiglu_proj_weight(block.ff_context.net[0].proj, tp_mesh.size())
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="OvisImageTensorParallelismPlanner(transformer_blocks)",
+ )
+ layer_plan = {
+ "attn.to_q": ColwiseParallel(),
+ "attn.to_k": ColwiseParallel(),
+ "attn.to_v": ColwiseParallel(),
+ "attn.to_out.0": RowwiseParallel(),
+ "ff.net.0.proj": ColwiseParallel(),
+ "ff.net.2": RowwiseParallel(),
+ "attn.add_q_proj": ColwiseParallel(),
+ "attn.add_k_proj": ColwiseParallel(),
+ "attn.add_v_proj": ColwiseParallel(),
+ "attn.to_add_out": RowwiseParallel(),
+ "ff_context.net.0.proj": ColwiseParallel(),
+ "ff_context.net.2": RowwiseParallel(),
+ }
+
+ if getattr(block.norm1, "linear", None) is not None:
+ layer_plan["norm1.linear"] = ColwiseParallel(output_layouts=Replicate())
+ if getattr(block.norm1_context, "linear", None) is not None:
+ layer_plan["norm1_context.linear"] = ColwiseParallel(output_layouts=Replicate())
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+ for _, block in transformer.single_transformer_blocks.named_children():
+ assert isinstance(block, OvisImageSingleTransformerBlock)
+ rearrange_proj_out_weight(block, tp_mesh.size())
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="OvisImageTensorParallelismPlanner(single_transformer_blocks)",
+ )
+ rearrange_proj_mlp_weight(block, tp_mesh.size())
+ shard_divisible_attr(
+ block,
+ "mlp_hidden_dim",
+ tp_mesh.size(),
+ what="block",
+ context="OvisImageTensorParallelismPlanner(single_transformer_blocks)",
+ )
+ # Compute order: proj_mlp, to_q, to_k, to_v, proj_out
+ # proj_mlp: dim -> self.mlp_hidden_dim * 2 -> split by mlp_hidden_dim
+ layer_plan = {
+ "proj_mlp": ColwiseParallel(),
+ "attn.to_q": ColwiseParallel(),
+ "attn.to_k": ColwiseParallel(),
+ "attn.to_v": ColwiseParallel(),
+ "proj_out": RowwiseParallel(),
+ }
+ if getattr(block.norm, "linear", None) is not None:
+ layer_plan["norm.linear"] = ColwiseParallel(output_layouts=Replicate())
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+ return transformer
+
+
+# NOTE: Special handling for OvisImageSingleTransformerBlock, we have to rearrange the
+# proj_out weight because it contains both out and down projection weights in a single matrix.
+def rearrange_proj_out_weight(single_block: OvisImageSingleTransformerBlock, tp_group_size):
+ # Rowwise: rearrange the proj_out weight for RowwiseParallel, (M,K)x(K,N), permute at K (in_dim)
+ hidden_dim = single_block.attn.to_q.weight.shape[0]
+ requires_grad = single_block.proj_out.weight.requires_grad
+ linear2_weight_data = single_block.proj_out.weight.data.T.detach().clone()
+ out_weight = linear2_weight_data[:hidden_dim, ...]
+ out_weight = rearrange(out_weight, "(G D) C -> G D C", G=tp_group_size)
+    down_weight = linear2_weight_data[hidden_dim:, ...]
+ down_weight = rearrange(down_weight, "(G D) C -> G D C", G=tp_group_size)
+ new_linear2_weight = torch.cat([out_weight, down_weight], dim=1)
+ new_linear2_weight = rearrange(new_linear2_weight, "G D C -> (G D) C")
+ single_block.proj_out.weight.data.copy_(new_linear2_weight.T)
+ single_block.proj_out.weight.requires_grad_(requires_grad)
+
+
+def rearrange_proj_mlp_weight(single_block: OvisImageSingleTransformerBlock, tp_group_size):
+ # Colwise: rearrange the proj_mlp weight for ColwiseParallel, (M,K)x(K,N), permute at N (out_dim)
+ # Original tensor shape: [*, Hd + Gd], where Hd = Gd (Hd and Gd have the same dimension size)
+ # Linear transformation definition: y = x * A^T, where
+ # A: [out_dim, in_dim] (transformation matrix)
+ # x: [*, in_dim] (input tensor, * denotes arbitrary leading dimensions)
+ #
+ # Tensor Parallel (TP) dimension permutation logic:
+ # 1. Split Hd and Gd evenly according to the TP group size (tp_group_size)
+ # - When tp_group_size=2: Split [..., Hd+Gd] into [..., (Hd/2+Gd/2) + (Hd/2+Gd/2)]
+ # - When tp_group_size=4: Split [..., Hd+Gd] into [..., (Hd/4+Gd/4)*4]
+ # Expanded form: [..., Hd/4+Gd/4 + Hd/4+Gd/4 + Hd/4+Gd/4 + Hd/4+Gd/4]
+ # 2. Perform dimension permutation and rearrangement on the split tensor
+ # 3. Reshape the tensor back to the original shape [..., (Hd + Gd)] finally
+ mlp_hidden_dim = single_block.proj_mlp.weight.shape[0] // 2
+ requires_grad = single_block.proj_mlp.weight.requires_grad
+ linear1_weight_data = single_block.proj_mlp.weight.data.T.detach().clone() # [in_dim, out_dim]
+ new_linear1_weight = torch.zeros_like(linear1_weight_data)
+ part1_linear1_weight_data = linear1_weight_data[..., :mlp_hidden_dim]
+ part2_linear1_weight_data = linear1_weight_data[..., mlp_hidden_dim:]
+ split_size = mlp_hidden_dim // tp_group_size
+ for i in range(tp_group_size):
+ start_idx = i * split_size
+ end_idx = (i + 1) * split_size
+ new_linear1_weight[..., i * 2 * split_size : (i * 2 + 1) * split_size] = (
+ part1_linear1_weight_data[..., start_idx:end_idx]
+ )
+ new_linear1_weight[..., (i * 2 + 1) * split_size : (i * 2 + 2) * split_size] = (
+ part2_linear1_weight_data[..., start_idx:end_idx]
+ )
+
+ single_block.proj_mlp.weight.data.copy_(new_linear1_weight.T) # [out_dim, in_dim]
+ single_block.proj_mlp.weight.requires_grad_(requires_grad)
+
+
+# Ovis-Image use SwiGLU: self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
+# hidden_states = self.proj(hidden_states); hidden_states, gate = hidden_states.chunk(2, dim=-1)
+# reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/activations.py#L140
+def rearrange_ffn_0_swiglu_proj_weight(proj: torch.nn.Linear, tp_group_size):
+ # Colwise: rearrange the proj_mlp weight for ColwiseParallel, (M,K)x(K,N), permute at N (out_dim)
+ # Original tensor shape: [*, Hd + Gd], where Hd = Gd (Hd and Gd have the same dimension size)
+ # Linear transformation definition: y = x * A^T, where
+ # A: [out_dim, in_dim] (transformation matrix)
+ # x: [*, in_dim] (input tensor, * denotes arbitrary leading dimensions)
+ #
+ # Tensor Parallel (TP) dimension permutation logic:
+ # 1. Split Hd and Gd evenly according to the TP group size (tp_group_size)
+ # - When tp_group_size=2: Split [..., Hd+Gd] into [..., (Hd/2+Gd/2) + (Hd/2+Gd/2)]
+ # - When tp_group_size=4: Split [..., Hd+Gd] into [..., (Hd/4+Gd/4)*4]
+ # Expanded form: [..., Hd/4+Gd/4 + Hd/4+Gd/4 + Hd/4+Gd/4 + Hd/4+Gd/4]
+ # 2. Perform dimension permutation and rearrangement on the split tensor
+ # 3. Reshape the tensor back to the original shape [..., (Hd + Gd)] finally
+ dim_out = proj.weight.shape[0] // 2
+ requires_grad = proj.weight.requires_grad
+ linear1_weight_data = proj.weight.data.T.detach().clone() # [in_dim, out_dim]
+ new_linear1_weight = torch.zeros_like(linear1_weight_data)
+ part1_linear1_weight_data = linear1_weight_data[..., :dim_out]
+ part2_linear1_weight_data = linear1_weight_data[..., dim_out:]
+ split_size = dim_out // tp_group_size
+ for i in range(tp_group_size):
+ start_idx = i * split_size
+ end_idx = (i + 1) * split_size
+ new_linear1_weight[..., i * 2 * split_size : (i * 2 + 1) * split_size] = (
+ part1_linear1_weight_data[..., start_idx:end_idx]
+ )
+ new_linear1_weight[..., (i * 2 + 1) * split_size : (i * 2 + 2) * split_size] = (
+ part2_linear1_weight_data[..., start_idx:end_idx]
+ )
+
+ proj.weight.data.copy_(new_linear1_weight.T) # [out_dim, in_dim]
+ proj.weight.requires_grad_(requires_grad)
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_pixart.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_pixart.py
similarity index 83%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_pixart.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_pixart.py
index 64a113fc9..07a1007b1 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_pixart.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_pixart.py
@@ -8,12 +8,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -59,8 +60,21 @@ def parallelize_transformer(
"""
for i, block in enumerate(transformer.transformer_blocks):
# Split attention heads across TP devices
- block.attn1.heads //= tp_mesh.size()
- block.attn2.heads //= tp_mesh.size()
+ tp_size = tp_mesh.size()
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_size,
+ what="attn1",
+ context="PixArtTensorParallelismPlanner",
+ )
+ shard_divisible_attr(
+ block.attn2,
+ "heads",
+ tp_size,
+ what="attn2",
+ context="PixArtTensorParallelismPlanner",
+ )
# Create layer plan for tensor parallelism
layer_plan = {
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_qwen_image.py
similarity index 71%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_qwen_image.py
index 4f76dda78..94df11dbf 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_qwen_image.py
@@ -1,5 +1,4 @@
import torch
-from torch import nn
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed._tensor import Replicate
from torch.distributed.tensor.parallel import (
@@ -7,18 +6,20 @@
RowwiseParallel,
parallelize_module,
)
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from diffusers import QwenImageTransformer2DModel
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
from cache_dit.logger import init_logger
logger = init_logger(__name__)
-@TensorParallelismPlannerRegister.register("QwenImage")
+@TensorParallelismPlannerRegister.register("QwenImageTransformer2DModel")
class QwenImageTensorParallelismPlanner(TensorParallelismPlanner):
def apply(
self,
@@ -26,9 +27,9 @@ def apply(
parallelism_config: ParallelismConfig,
**kwargs,
) -> torch.nn.Module:
- assert parallelism_config.tp_size is not None and parallelism_config.tp_size > 1, (
- "parallel_config.tp_size must be set and greater than 1 for " "tensor parallelism"
- )
+ assert (
+ parallelism_config.tp_size is not None and parallelism_config.tp_size > 1
+ ), "parallel_config.tp_size must be set and greater than 1 for tensor parallelism"
device_type = torch.accelerator.current_accelerator().type
tp_mesh: DeviceMesh = init_device_mesh(
@@ -45,11 +46,20 @@ def apply(
def parallelize_transformer(
self,
- transformer: nn.Module,
+ transformer: QwenImageTransformer2DModel,
tp_mesh: DeviceMesh,
):
+ from diffusers.models.transformers.transformer_qwenimage import QwenImageTransformerBlock
+
for _, block in transformer.transformer_blocks.named_children():
- block.attn.heads //= tp_mesh.size()
+ assert isinstance(block, QwenImageTransformerBlock)
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_mesh.size(),
+ what="attn",
+ context="QwenImageTensorParallelismPlanner",
+ )
layer_plan = {
"attn.to_q": ColwiseParallel(),
"attn.to_k": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_registers.py
similarity index 92%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_registers.py
index 1019ff889..c286e26f7 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_registers.py
@@ -2,15 +2,13 @@
import logging
from abc import abstractmethod
from typing import Dict
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from cache_dit.logger import init_logger
logger = init_logger(__name__)
class TensorParallelismPlanner:
- # TODO: add `apply_extra` abstract method for extra
- # parallelism kwargs handling
@abstractmethod
def apply(
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_skyreels.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_skyreels.py
similarity index 90%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_skyreels.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_skyreels.py
index 2d654b8fd..d092ad7d2 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_skyreels.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_skyreels.py
@@ -9,12 +9,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -58,7 +59,13 @@ def parallelize_transformer(
# Reduce the number of attention heads per device
if hasattr(block, "attn"):
if hasattr(block.attn, "heads"):
- block.attn.heads //= tp_size
+ shard_divisible_attr(
+ block.attn,
+ "heads",
+ tp_size,
+ what="attn",
+ context="SkyReelsV2TensorParallelismPlanner",
+ )
# Define parallelization plan for each block
# This follows a standard pattern:
diff --git a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_wan.py
similarity index 90%
rename from src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py
rename to src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_wan.py
index 447cdf578..84da045c7 100644
--- a/src/cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_wan.py
@@ -10,12 +10,13 @@
)
from cache_dit.logger import init_logger
-from cache_dit.parallelism.parallel_config import ParallelismConfig
+from cache_dit.parallelism.config import ParallelismConfig
from .tp_plan_registers import (
TensorParallelismPlanner,
TensorParallelismPlannerRegister,
)
+from .tp_utils import shard_divisible_attr
logger = init_logger(__name__)
@@ -107,8 +108,21 @@ def parallelize_transformer(
tp_mesh: DeviceMesh,
):
def prepare_block(block: nn.Module):
- block.attn1.heads //= tp_mesh.size()
- block.attn2.heads //= tp_mesh.size()
+ tp_size = tp_mesh.size()
+ shard_divisible_attr(
+ block.attn1,
+ "heads",
+ tp_size,
+ what="attn1",
+ context="WanTensorParallelismPlanner",
+ )
+ shard_divisible_attr(
+ block.attn2,
+ "heads",
+ tp_size,
+ what="attn2",
+ context="WanTensorParallelismPlanner",
+ )
layer_plan = {
"attn1.to_q": ColwiseParallel(),
"attn1.to_k": ColwiseParallel(),
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_zimage.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_zimage.py
new file mode 100644
index 000000000..786166a43
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_plan_zimage.py
@@ -0,0 +1,84 @@
+import torch
+from torch.distributed import DeviceMesh, init_device_mesh
+from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module
+
+from cache_dit.logger import init_logger
+from cache_dit.parallelism.config import ParallelismConfig
+
+from .tp_plan_registers import TensorParallelismPlanner, TensorParallelismPlannerRegister
+from .tp_utils import shard_divisible_attr
+
+logger = init_logger(__name__)
+
+
+@TensorParallelismPlannerRegister.register("Lumina2")
+@TensorParallelismPlannerRegister.register("ZImage")
+class ZImageTensorParallelismPlanner(TensorParallelismPlanner):
+ def apply(
+ self,
+ transformer: torch.nn.Module,
+ parallelism_config: ParallelismConfig,
+ **kwargs,
+ ) -> torch.nn.Module:
+        assert (
+            parallelism_config.tp_size is not None and parallelism_config.tp_size > 1
+        ), "parallel_config.tp_size must be set and greater than 1 for tensor parallelism"
+
+ device_type = torch.accelerator.current_accelerator().type
+ tp_mesh: DeviceMesh = init_device_mesh(
+ device_type=device_type,
+ mesh_shape=[parallelism_config.tp_size],
+ )
+
+ transformer = self.parallelize_transformer(
+ transformer=transformer,
+ tp_mesh=tp_mesh,
+ )
+
+ return transformer
+
+ def parallelize_transformer(
+ self,
+ transformer: torch.nn.Module,
+ tp_mesh: DeviceMesh,
+ ):
+ class_name = transformer.__class__.__name__
+
+ def tp_shard_block(block, tp_size):
+ attn_mod_name = "attention" if class_name.startswith("ZImage") else "attn"
+ ff_linear_name = "w" if class_name.startswith("ZImage") else "linear_"
+ attn = getattr(block, attn_mod_name)
+ shard_divisible_attr(
+ attn,
+ "heads",
+ tp_size,
+ what=attn_mod_name,
+ context="ZImageTensorParallelismPlanner",
+ )
+ layer_plan = {
+ f"{attn_mod_name}.to_q": ColwiseParallel(),
+ f"{attn_mod_name}.to_k": ColwiseParallel(),
+ f"{attn_mod_name}.to_v": ColwiseParallel(),
+ f"{attn_mod_name}.to_out.0": RowwiseParallel(),
+ f"feed_forward.{ff_linear_name}1": ColwiseParallel(),
+ f"feed_forward.{ff_linear_name}3": ColwiseParallel(),
+ f"feed_forward.{ff_linear_name}2": RowwiseParallel(),
+ # saving more memory at the cost of more communication
+ # "adaLN_modulation.0": ColwiseParallel(output_layouts=Replicate()),
+ }
+
+ parallelize_module(
+ module=block,
+ device_mesh=tp_mesh,
+ parallelize_plan=layer_plan,
+ )
+
+        tp_size = tp_mesh.size()  # equivalent to get_group().size() for the 1-D mesh built in apply()
+ for _, block in transformer.noise_refiner.named_children():
+ tp_shard_block(block, tp_size)
+ for _, block in transformer.context_refiner.named_children():
+ tp_shard_block(block, tp_size)
+ for _, block in transformer.layers.named_children():
+ tp_shard_block(block, tp_size)
+
+ return transformer
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_planners.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_planners.py
new file mode 100644
index 000000000..4b127d93c
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_planners.py
@@ -0,0 +1,81 @@
+import importlib
+from cache_dit.logger import init_logger
+from .tp_plan_registers import TensorParallelismPlanner
+
+logger = init_logger(__name__)
+
+
+class ImportErrorTensorParallelismPlanner(TensorParallelismPlanner):
+    def apply(
+        self,
+        transformer,
+        *args, **kwargs,
+    ):
+ raise ImportError(
+ "This TensorParallelismPlanner requires latest diffusers to be installed. "
+ "Please install diffusers from source."
+ )
+
+
+def _safe_import(module_name: str, class_name: str) -> type[TensorParallelismPlanner]:
+ try:
+ # e.g., module_name = ".tp_plan_dit", class_name = "DiTTensorParallelismPlanner"
+ package = __package__ if __package__ is not None else ""
+ module = importlib.import_module(module_name, package=package)
+ target_class = getattr(module, class_name)
+ return target_class
+ except (ImportError, AttributeError) as e:
+ logger.debug(f"Failed to import {class_name} from {module_name}: {e}")
+ return ImportErrorTensorParallelismPlanner
+
+
+def _activate_tp_planners():
+ """Function to register all built-in tensor parallelism planners."""
+ CogViewTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_cogview", "CogViewTensorParallelismPlanner"
+ )
+ FluxTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_flux", "FluxTensorParallelismPlanner"
+ )
+ Flux2TensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_flux2", "Flux2TensorParallelismPlanner"
+ )
+ HunyuanDiTTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_hunyuan_dit", "HunyuanDiTTensorParallelismPlanner"
+ )
+ Kandinsky5TensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_kandinsky5", "Kandinsky5TensorParallelismPlanner"
+ )
+ MochiTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_mochi", "MochiTensorParallelismPlanner"
+ )
+ LTXVideoTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_ltx_video", "LTXVideoTensorParallelismPlanner"
+ )
+ LTX2VideoTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_ltx2_video", "LTX2VideoTensorParallelismPlanner"
+ )
+ PixArtTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_pixart", "PixArtTensorParallelismPlanner"
+ )
+ QwenImageTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_qwen_image", "QwenImageTensorParallelismPlanner"
+ )
+ WanTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_wan", "WanTensorParallelismPlanner"
+ )
+ SkyReelsV2TensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_skyreels", "SkyReelsV2TensorParallelismPlanner"
+ )
+ ZImageTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_zimage", "ZImageTensorParallelismPlanner"
+ )
+ OvisImageTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_ovis_image", "OvisImageTensorParallelismPlanner"
+ )
+ LongCatImageTensorParallelismPlanner = _safe_import( # noqa: F841
+ ".tp_plan_longcat_image", "LongCatImageTensorParallelismPlanner"
+ )
+
+
+__all__ = ["_activate_tp_planners"]
diff --git a/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_utils.py b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_utils.py
new file mode 100644
index 000000000..c9d08589b
--- /dev/null
+++ b/src/cache_dit/parallelism/transformers/tensor_parallelism/tp_utils.py
@@ -0,0 +1,67 @@
+from __future__ import annotations
+
+from typing import Any, Optional
+
+
+def _divisors(n: int) -> list[int]:
+ n = int(n)
+ if n <= 0:
+ return []
+ small: list[int] = []
+ large: list[int] = []
+ d = 1
+ while d * d <= n:
+ if n % d == 0:
+ small.append(d)
+ if d * d != n:
+ large.append(n // d)
+ d += 1
+ return small + list(reversed(large))
+
+
+def shard_divisible_attr(
+ obj: Any,
+ attr: str,
+ tp_size: int,
+ *,
+ what: Optional[str] = None,
+ context: Optional[str] = None,
+) -> int:
+ """
+ Shard (divide) an integer attribute by tp_size, with a fail-fast divisibility check.
+
+ This is primarily used for sharding attention `heads` / `num_heads` in tensor parallelism
+ planners. If the value is not divisible by tp_size, we raise a clear ValueError during
+ model initialization (before serving / inference).
+ """
+ tp_size = int(tp_size)
+ if tp_size <= 0:
+ raise ValueError(f"[TP] Invalid tp_size={tp_size}.")
+
+ if not hasattr(obj, attr):
+ raise AttributeError(f"[TP] Object {type(obj).__name__} has no attribute '{attr}'.")
+
+ raw = getattr(obj, attr)
+ try:
+ value = int(raw)
+ except Exception as e:
+ raise TypeError(
+ f"[TP] Attribute '{attr}' on {type(obj).__name__} must be int-like, got {raw!r}."
+ ) from e
+
+ if value <= 0:
+ raise ValueError(f"[TP] Attribute '{attr}' must be > 0, got {value}.")
+
+ if value % tp_size != 0:
+ divs = [d for d in _divisors(value) if d > 1]
+ divs_str = ", ".join(map(str, divs)) if divs else "(none)"
+ obj_name = what or type(obj).__name__
+ prefix = f"{context}: " if context else ""
+ raise ValueError(
+ f"[TP] {prefix}Unsupported tp_size={tp_size} for {obj_name}.{attr}={value}. "
+ f"{attr} must be divisible by tp_size. Valid tp_size (>1): {divs_str}."
+ )
+
+ new_value = value // tp_size
+ setattr(obj, attr, new_value)
+ return new_value
diff --git a/src/cache_dit/platforms/__init__.py b/src/cache_dit/platforms/__init__.py
new file mode 100644
index 000000000..a6d8b0b1a
--- /dev/null
+++ b/src/cache_dit/platforms/__init__.py
@@ -0,0 +1,57 @@
+import torch
+import importlib
+from typing import TYPE_CHECKING
+from .platform import BasePlatform, CudaPlatform, CpuPlatform, NPUPlatform # noqa: F401
+
+
+def resolve_obj_by_qualname(qualname: str) -> BasePlatform:
+ """
+ Resolve an object by its fully-qualified class name.
+ """
+ module_name, obj_name = qualname.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, obj_name)
+
+
+def resolve_current_platform_cls_qualname() -> str:
+ if torch.cuda.is_available():
+ return "cache_dit.platforms.platform.CudaPlatform"
+ try:
+ import torch_npu # type: ignore # noqa
+
+ return "cache_dit.platforms.platform.NPUPlatform"
+ except ImportError:
+ return "cache_dit.platforms.platform.CpuPlatform"
+
+
+_current_platform: "BasePlatform | None" = None  # lazily initialized by __getattr__
+
+
+if TYPE_CHECKING:
+ current_platform: BasePlatform
+
+
+def __getattr__(name: str):
+ if name == "current_platform":
+ global _current_platform
+ if _current_platform is None:
+ platform_cls_qualname = resolve_current_platform_cls_qualname()
+ _current_platform = resolve_obj_by_qualname(platform_cls_qualname)()
+ return _current_platform
+ elif name in globals():
+ return globals()[name]
+ else:
+ raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+def __setattr__(name: str, value):  # NOTE(review): Python does not invoke module-level __setattr__ (PEP 562 covers only __getattr__/__dir__); plain `module.attr = x` bypasses this — confirm intended
+ if name == "current_platform":
+ global _current_platform
+ _current_platform = value
+ elif name in globals():
+ globals()[name] = value
+ else:
+ raise AttributeError(f"No attribute named '{name}' exists in {__name__}.")
+
+
+__all__ = ["BasePlatform", "current_platform"]
diff --git a/src/cache_dit/platforms/platform.py b/src/cache_dit/platforms/platform.py
new file mode 100644
index 000000000..e15df5d16
--- /dev/null
+++ b/src/cache_dit/platforms/platform.py
@@ -0,0 +1,214 @@
+# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/platforms
+import torch
+from abc import ABC
+
+
+class BasePlatform(ABC):
+ device_type: str
+ device_control_env_var: str
+ dispatch_key: str
+ dist_backend: str
+ full_dist_backend: str
+
+ @staticmethod
+ def empty_cache(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def ipc_collect(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def get_device_name():
+ raise NotImplementedError
+
+ @staticmethod
+ def device_ctx(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def default_device(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def synchronize(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def device_count(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def is_accelerator_available(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def current_device(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def reset_peak_memory_stats(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def max_memory_allocated(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def get_device_properties(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def set_device(*args, **kwargs):
+ raise NotImplementedError
+
+ @staticmethod
+ def get_device_capability(*args, **kwargs):
+ raise NotImplementedError
+
+
+class CpuPlatform(BasePlatform):
+ device_type: str = "cpu"
+ dispatch_key: str = "CPU"
+ device_control_env_var = "CPU_VISIBLE_MEMORY_NODES"
+ dist_backend: str = "gloo"
+ full_dist_backend: str = "cpu:gloo"
+
+ @staticmethod
+ def default_device():
+ return torch.device("cpu")
+
+ @staticmethod
+ def get_device_name():
+ return "CPU"
+
+ @staticmethod
+ def is_accelerator_available():
+ return False
+
+
+class CudaPlatform(BasePlatform):
+ device_type: str = "cuda"
+ device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
+ dispatch_key: str = "CUDA"
+ dist_backend: str = "nccl"
+ full_dist_backend: str = "cuda:nccl"
+
+ @staticmethod
+ def empty_cache():
+ torch.cuda.empty_cache()
+
+ @staticmethod
+ def ipc_collect():
+ torch.cuda.ipc_collect()
+
+ @staticmethod
+ def get_device_name():
+ return torch.cuda.get_device_name()
+
+ @staticmethod
+ def device_ctx(device):
+ return torch.cuda.device(device)
+
+ @staticmethod
+ def default_device():
+ return torch.device("cuda")
+
+ @staticmethod
+ def synchronize(device=None):
+ torch.cuda.synchronize(device)
+
+ @staticmethod
+ def device_count():
+ return torch.cuda.device_count()
+
+ @staticmethod
+ def is_accelerator_available():
+ return torch.cuda.is_available()
+
+ @staticmethod
+ def current_device():
+ return torch.cuda.current_device()
+
+ @staticmethod
+ def reset_peak_memory_stats(device=None):
+ return torch.cuda.reset_peak_memory_stats(device)
+
+ @staticmethod
+ def max_memory_allocated(device=None):
+ return torch.cuda.max_memory_allocated(device)
+
+ @staticmethod
+ def get_device_properties(device=None):
+ return torch.cuda.get_device_properties(device)
+
+ @staticmethod
+ def set_device(device):
+ return torch.cuda.set_device(device)
+
+ @staticmethod
+ def get_device_capability(device=None):
+ return torch.cuda.get_device_capability(device)
+
+
+class NPUPlatform(BasePlatform):
+ device_type: str = "npu"
+ device_control_env_var: str = "ASCEND_RT_VISIBLE_DEVICES"
+ dispatch_key: str = "PrivateUse1"
+ dist_backend: str = "hccl"
+ full_dist_backend: str = "npu:hccl"
+
+ @staticmethod
+ def empty_cache():
+ torch.npu.empty_cache()
+
+ @staticmethod
+ def ipc_collect():
+ """
+ torch.npu.ipc_collect() is not implemented yet.
+ """
+ pass
+
+ @staticmethod
+ def get_device_name():
+ return torch.npu.get_device_name()
+
+ @staticmethod
+ def device_ctx(device):
+ return torch.npu.device(device)
+
+ @staticmethod
+ def default_device():
+ return torch.device("npu")
+
+ @staticmethod
+ def synchronize(device=None):
+ torch.npu.synchronize(device)
+
+ @staticmethod
+ def device_count():
+ return torch.npu.device_count()
+
+ @staticmethod
+ def is_accelerator_available():
+ return torch.npu.is_available()
+
+ @staticmethod
+ def current_device():
+ return torch.npu.current_device()
+
+ @staticmethod
+ def reset_peak_memory_stats(device=None):
+ return torch.npu.reset_peak_memory_stats(device)
+
+ @staticmethod
+ def max_memory_allocated(device=None):
+ return torch.npu.max_memory_allocated(device)
+
+ @staticmethod
+ def get_device_properties(device=None):
+ return torch.npu.get_device_properties(device)
+
+ @staticmethod
+ def set_device(device):
+ return torch.npu.set_device(device)
diff --git a/src/cache_dit/profiler.py b/src/cache_dit/profiler.py
new file mode 100644
index 000000000..b3f95813d
--- /dev/null
+++ b/src/cache_dit/profiler.py
@@ -0,0 +1,175 @@
+"""
+Torch Profiler for cache-dit.
+
+Reference: Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/bench_one_batch.py
+"""
+
+import logging
+import os
+import time
+from pathlib import Path
+from typing import List, Optional
+
+import torch
+from torch.profiler import ProfilerActivity, profile
+from .platforms import current_platform
+
+logger = logging.getLogger(__name__)
+
+# Default profiler directory
+PROFILER_DIR = os.getenv("CACHE_DIT_TORCH_PROFILER_DIR", "/tmp/cache_dit_profiles")
+
+
+class ProfilerContext:
+
+ def __init__(
+ self,
+ enabled: bool = True,
+ activities: Optional[List[str]] = None,
+ output_dir: Optional[str] = None,
+ profile_name: Optional[str] = None,
+ with_stack: bool = True,
+ record_shapes: bool = True,
+ ):
+ assert (
+ current_platform.is_accelerator_available() and current_platform.device_type == "cuda"
+ ), "Torch ProfilerContext currently only supports CUDA devices."
+ self.enabled = enabled
+ self.activities = activities or ["CPU", "GPU"]
+ self.output_dir = Path(output_dir or PROFILER_DIR).expanduser()
+ self.profile_name = profile_name or f"profile_{int(time.time())}"
+ self.with_stack = with_stack
+ self.record_shapes = record_shapes
+
+ self.profiler = None
+ self.trace_path = None
+ self.memory_snapshot_path = None
+
+ def __enter__(self):
+ if not self.enabled:
+ return self
+
+ assert (
+ current_platform.is_accelerator_available() and current_platform.device_type == "cuda"
+ ), "Torch ProfilerContext currently only supports CUDA devices."
+
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ activity_map = {
+ "CPU": ProfilerActivity.CPU,
+ "GPU": ProfilerActivity.CUDA,
+ }
+ torch_activities = [activity_map[a] for a in self.activities if a in activity_map]
+
+ rank = 0
+ world_size = 1
+ if torch.distributed.is_initialized():
+ rank = torch.distributed.get_rank()
+ world_size = torch.distributed.get_world_size()
+
+ filename_parts = [self.profile_name]
+ if world_size > 1:
+ filename_parts.append(f"rank{rank}")
+ filename = "-".join(filename_parts) + ".trace.json.gz"
+ self.trace_path = self.output_dir / filename
+
+ if "MEM" in self.activities and torch.cuda.is_available():
+ torch.cuda.memory._record_memory_history(max_entries=100000)
+ logger.info("Started CUDA memory profiling")
+
+ if torch_activities:
+ self.profiler = profile(
+ activities=torch_activities,
+ with_stack=self.with_stack,
+ record_shapes=self.record_shapes,
+ )
+
+ self.profiler.start()
+ logger.info(
+ f"Started profiling. Traces will be saved to: {self.output_dir} "
+ f"(activities: {self.activities})"
+ )
+
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not self.enabled:
+ return
+
+ if self.profiler is not None:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ self.profiler.stop()
+
+ logger.info(f"Exporting trace to: {self.trace_path}")
+ self.profiler.export_chrome_trace(str(self.trace_path))
+
+ logger.info(f"Profiling completed. Trace saved to: {self.trace_path}")
+
+ if "MEM" in self.activities and torch.cuda.is_available():
+ timestamp = int(time.time())
+ rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+ memory_snapshot_path = (
+ self.output_dir / f"{self.profile_name}-rank{rank}-memory-{timestamp}.pickle"
+ )
+ torch.cuda.memory._dump_snapshot(str(memory_snapshot_path))
+ torch.cuda.memory._record_memory_history(enabled=None)
+ logger.info(f"Memory snapshot saved to: {memory_snapshot_path}")
+
+ memory_summary_path = (
+ self.output_dir / f"{self.profile_name}-rank{rank}-memory-{timestamp}.txt"
+ )
+ with open(memory_summary_path, "w") as f:
+ f.write(torch.cuda.memory_summary())
+ logger.info(f"Memory summary saved to: {memory_summary_path}")
+
+
+def profile_function(
+ enabled: bool = True,
+ activities: Optional[List[str]] = None,
+ output_dir: Optional[str] = None,
+ profile_name: Optional[str] = None,
+ with_stack: bool = False,
+ record_shapes: bool = True,
+):
+ def decorator(func):
+ def wrapper(*args, **kwargs):
+ name = profile_name or func.__name__
+ with ProfilerContext(
+ enabled=enabled,
+ activities=activities,
+ output_dir=output_dir,
+ profile_name=name,
+ with_stack=with_stack,
+ record_shapes=record_shapes,
+ ):
+ return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def create_profiler_context(
+ enabled: bool = False,
+ activities: Optional[List[str]] = None,
+ output_dir: Optional[str] = None,
+ profile_name: Optional[str] = None,
+ **kwargs,
+) -> ProfilerContext:
+ return ProfilerContext(
+ enabled=enabled,
+ activities=activities,
+ output_dir=output_dir,
+ profile_name=profile_name,
+ **kwargs,
+ )
+
+
+def get_profiler_output_dir() -> str:
+ return os.environ.get("CACHE_DIT_TORCH_PROFILER_DIR", PROFILER_DIR)
+
+
+def set_profiler_output_dir(path: str):
+ os.environ["CACHE_DIT_TORCH_PROFILER_DIR"] = path
diff --git a/src/cache_dit/quantize/__init__.py b/src/cache_dit/quantize/__init__.py
index 5c1e54130..8d6150c98 100644
--- a/src/cache_dit/quantize/__init__.py
+++ b/src/cache_dit/quantize/__init__.py
@@ -1,8 +1,39 @@
-try:
- import torchao
-except ImportError:
- raise ImportError(
- "Quantization functionality requires the 'quantization' extra dependencies. "
- "Install with: pip install cache-dit[quantization]"
- )
-from cache_dit.quantize.quantize_interface import quantize
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize(
+ module: torch.nn.Module,
+ quant_type: Optional[str] = None,
+ backend: str = "ao",
+ # Specific parameters for torchao backend
+ per_row: bool = True,
+ exclude_layers: List[str] = [
+ "embedder",
+ "embed",
+ ],
+ filter_fn: Optional[Callable] = None,
+ **kwargs,
+) -> torch.nn.Module:
+ assert isinstance(module, torch.nn.Module)
+
+ if quant_type is None:
+ quant_type = "float8_weight_only"
+ logger.warning(f"quant_type is not specified, using default: {quant_type}")
+
+ if backend.lower() in ("ao", "torchao"):
+ from .torchao import quantize_ao
+
+ return quantize_ao(
+ module,
+ quant_type=quant_type,
+ per_row=per_row,
+ exclude_layers=exclude_layers,
+ filter_fn=filter_fn,
+ **kwargs,
+ )
+ else:
+ raise ValueError(f"backend: {backend} is not supported now!")
diff --git a/src/cache_dit/quantize/backends/__init__.py b/src/cache_dit/quantize/backends/__init__.py
deleted file mode 100644
index e4ff5cdb2..000000000
--- a/src/cache_dit/quantize/backends/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .torchao import quantize_ao
diff --git a/src/cache_dit/quantize/quantize_backend.py b/src/cache_dit/quantize/bitsandbytes/__init__.py
similarity index 100%
rename from src/cache_dit/quantize/quantize_backend.py
rename to src/cache_dit/quantize/bitsandbytes/__init__.py
diff --git a/src/cache_dit/quantize/quantize_config.py b/src/cache_dit/quantize/quantize_config.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/cache_dit/quantize/quantize_interface.py b/src/cache_dit/quantize/quantize_interface.py
deleted file mode 100644
index 768a1ff79..000000000
--- a/src/cache_dit/quantize/quantize_interface.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import torch
-from typing import Callable, Optional, List
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def quantize(
- module: torch.nn.Module,
- quant_type: str = "float8_weight_only",
- backend: str = "ao",
- exclude_layers: List[str] = [
- "embedder",
- "embed",
- ],
- filter_fn: Optional[Callable] = None,
- **kwargs,
-) -> torch.nn.Module:
- assert isinstance(module, torch.nn.Module)
-
- if backend.lower() in ("ao", "torchao"):
- from cache_dit.quantize.backends.torchao import quantize_ao
-
- return quantize_ao(
- module,
- quant_type=quant_type,
- per_row=kwargs.pop("per_row", True),
- exclude_layers=exclude_layers,
- filter_fn=filter_fn,
- **kwargs,
- )
- else:
- raise ValueError(f"backend: {backend} is not supported now!")
diff --git a/src/cache_dit/quantize/backends/torchao/__init__.py b/src/cache_dit/quantize/torchao/__init__.py
similarity index 100%
rename from src/cache_dit/quantize/backends/torchao/__init__.py
rename to src/cache_dit/quantize/torchao/__init__.py
diff --git a/src/cache_dit/quantize/backends/torchao/quantize_ao.py b/src/cache_dit/quantize/torchao/quantize_ao.py
similarity index 85%
rename from src/cache_dit/quantize/backends/torchao/quantize_ao.py
rename to src/cache_dit/quantize/torchao/quantize_ao.py
index 32f8f301a..51525ff53 100644
--- a/src/cache_dit/quantize/backends/torchao/quantize_ao.py
+++ b/src/cache_dit/quantize/torchao/quantize_ao.py
@@ -1,6 +1,8 @@
import torch
+import copy
from typing import Callable, Optional, List
-from cache_dit.utils import maybe_empty_cache
+from ...utils import maybe_empty_cache
+from ...platforms import current_platform
from cache_dit.logger import init_logger
logger = init_logger(__name__)
@@ -9,29 +11,50 @@
def quantize_ao(
module: torch.nn.Module,
quant_type: str = "float8_weight_only",
+    # Parameters for FP8 DQ quantization
+ # Whether to quantize per row (True) or per tensor (False)
+ per_row: bool = True,
exclude_layers: List[str] = [
"embedder",
"embed",
],
filter_fn: Optional[Callable] = None,
- # paramters for fp8 quantization
- per_row: bool = True,
**kwargs,
) -> torch.nn.Module:
# Apply FP8 DQ for module and skip any `embed` modules
# by default to avoid non-trivial precision downgrade. Please
# set `exclude_layers` as `[]` if you don't want this behavior.
assert isinstance(module, torch.nn.Module)
+ assert (
+ current_platform.is_accelerator_available() and current_platform.device_type == "cuda"
+ ), "Quantization functionality with torchao backend is only supported on CUDA devices."
+ try:
+ import torchao # noqa: F401
+ except ImportError:
+ raise ImportError(
+ "Quantization functionality requires the 'quantization' extra dependencies. "
+ "Install with: pip install cache-dit[quantization]"
+ )
alias_map = {
"float8": "fp8_w8a8_dq",
"float8_weight_only": "fp8_w8a16_wo",
+ "float8_wo": "fp8_w8a16_wo",
"int8": "int8_w8a8_dq",
"int8_weight_only": "int8_w8a16_wo",
+ "int8_wo": "int8_w8a16_wo",
"int4": "int4_w4a8_dq",
"int4_w4a4": "int4_w4a4_dq",
"int4_weight_only": "int4_w4a16_wo",
+ "int4_wo": "int4_w4a16_wo",
}
+ alias_map_rev = copy.deepcopy(alias_map)
+ # remove duplicates *_wo in rev map
+ for key in list(alias_map_rev.keys()):
+ if key.endswith("_wo"):
+ alias_map_rev.pop(key)
+ alias_map_rev = {v: k for k, v in alias_map_rev.items()}
+
if quant_type.lower() in alias_map:
quant_type = alias_map[quant_type.lower()]
@@ -47,7 +70,7 @@ def quantize_ao(
), f"{quant_type} is not supported for torchao backend now!"
if "fp8" in quant_type:
- assert torch.cuda.get_device_capability() >= (
+ assert current_platform.get_device_capability() >= (
8,
9,
), "FP8 is not supported for current device."
@@ -184,7 +207,6 @@ def _quant_config():
maybe_empty_cache()
- alias_map_rev = {v: k for k, v in alias_map.items()}
if quant_type in alias_map_rev:
quant_type = alias_map_rev[quant_type]
diff --git a/src/cache_dit/quantize/utils.py b/src/cache_dit/quantize/utils.py
new file mode 100644
index 000000000..90c83760b
--- /dev/null
+++ b/src/cache_dit/quantize/utils.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+
+def normalize_quantize_type(quantize_type: str | None) -> str | None:
+ if quantize_type is None:
+ return None
+ mapping = {
+ "float8_wo": "float8_weight_only",
+ "int8_wo": "int8_weight_only",
+ "int4_wo": "int4_weight_only",
+ "bnb_4bit": "bitsandbytes_4bit",
+ }
+ return mapping.get(quantize_type, quantize_type)
diff --git a/src/cache_dit/serve/__init__.py b/src/cache_dit/serve/__init__.py
new file mode 100644
index 000000000..677119e35
--- /dev/null
+++ b/src/cache_dit/serve/__init__.py
@@ -0,0 +1,5 @@
+from .model_manager import ModelManager
+from .api_server import create_app
+from .serve import launch_server
+
+__all__ = ["ModelManager", "create_app", "launch_server"]
diff --git a/src/cache_dit/serve/api_server.py b/src/cache_dit/serve/api_server.py
new file mode 100644
index 000000000..7eb393272
--- /dev/null
+++ b/src/cache_dit/serve/api_server.py
@@ -0,0 +1,157 @@
+"""FastAPI HTTP Server for cache-dit.
+
+Adapted from SGLang's HTTP server:
+https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/http_server.py
+"""
+
+import asyncio
+from typing import Optional, Dict, Any, List, Literal
+from fastapi import FastAPI, HTTPException, Response
+from fastapi.responses import JSONResponse
+from pydantic import BaseModel, Field
+
+from .model_manager import ModelManager, GenerateRequest
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+_global_model_manager: Optional[ModelManager] = None
+_request_semaphore: Optional[asyncio.Semaphore] = None
+
+
+class GenerateRequestAPI(BaseModel):
+ """API request model for image/video generation."""
+
+ prompt: str = Field(..., description="Text prompt")
+ negative_prompt: Optional[str] = Field("", description="Negative prompt")
+ width: int = Field(1024, description="Image/Video width", ge=64, le=4096)
+ height: int = Field(1024, description="Image/Video height", ge=64, le=4096)
+ num_inference_steps: int = Field(50, description="Number of inference steps", ge=1, le=200)
+ guidance_scale: float = Field(7.5, description="Guidance scale", ge=0.0, le=20.0)
+ sigmas: Optional[List[float]] = Field(
+ None,
+ description="Custom sigma schedule (e.g. for turbo inference). Length should typically match num_inference_steps.",
+ )
+ seed: Optional[int] = Field(None, description="Random seed")
+ num_images: int = Field(1, description="Number of images to generate", ge=1, le=4)
+ image_urls: Optional[List[str]] = Field(
+ None,
+ description="Input images for image editing. Supports: URLs (http/https), local file paths, base64 strings (with or without data URI prefix)",
+ )
+ num_frames: Optional[int] = Field(
+ None, description="Number of frames for video generation", ge=1, le=200
+ )
+ fps: Optional[int] = Field(16, description="Frames per second for video output", ge=1, le=60)
+ include_stats: bool = Field(False, description="Include stats field in response")
+ output_format: Literal["base64", "path"] = Field(
+ "base64",
+ description="Output format: base64 or path",
+ )
+ output_dir: Optional[str] = Field(
+ None,
+ description="Output directory when output_format=path (server-side path)",
+ )
+
+
+class GenerateResponseAPI(BaseModel):
+ """API response model for image/video generation."""
+
+ images: Optional[list[str]] = Field(None, description="Base64 encoded images or file paths")
+ video: Optional[str] = Field(None, description="Base64 encoded video (mp4) or file path")
+ stats: Optional[Dict[str, Any]] = Field(None, description="Cache statistics")
+ time_cost: Optional[float] = Field(None, description="Generation time in seconds")
+ inference_start_time: Optional[str] = Field(
+ None, description="Inference start time (local time with timezone offset)"
+ )
+ inference_end_time: Optional[str] = Field(
+ None, description="Inference end time (local time with timezone offset)"
+ )
+
+
+def create_app(model_manager: ModelManager) -> FastAPI:
+ """Create FastAPI application."""
+ global _global_model_manager, _request_semaphore
+ _global_model_manager = model_manager
+ _request_semaphore = asyncio.Semaphore(1)
+
+ app = FastAPI(
+ title="Cache-DiT Serving API",
+ description="Text-to-image model serving API with cache-dit acceleration",
+ version="1.0.0",
+ )
+
+ @app.get("/health")
+ async def health():
+ """Health check endpoint."""
+ if _global_model_manager is None or _global_model_manager.pipe is None:
+ return Response(status_code=503, content="Model not loaded")
+ return Response(status_code=200, content="OK")
+
+ @app.get("/get_model_info")
+ async def get_model_info():
+ """Get model information."""
+ if _global_model_manager is None:
+ raise HTTPException(status_code=503, detail="Model manager not initialized")
+
+ return JSONResponse(content=_global_model_manager.get_model_info())
+
+ @app.post("/generate", response_model=GenerateResponseAPI, response_model_exclude_none=True)
+ async def generate(request: GenerateRequestAPI):
+ """Generate images from text prompt."""
+ if _global_model_manager is None:
+ raise HTTPException(status_code=503, detail="Model manager not initialized")
+
+ if _global_model_manager.pipe is None:
+ raise HTTPException(status_code=503, detail="Model not loaded")
+
+ async with _request_semaphore:
+ try:
+ gen_request = GenerateRequest(
+ prompt=request.prompt,
+ negative_prompt=request.negative_prompt,
+ width=request.width,
+ height=request.height,
+ num_inference_steps=request.num_inference_steps,
+ guidance_scale=request.guidance_scale,
+ sigmas=request.sigmas,
+ seed=request.seed,
+ num_images=request.num_images,
+ image_urls=request.image_urls,
+ num_frames=request.num_frames,
+ fps=request.fps,
+ include_stats=request.include_stats,
+ output_format=request.output_format,
+ output_dir=request.output_dir,
+ )
+
+ loop = asyncio.get_event_loop()
+ response = await loop.run_in_executor(
+ None, _global_model_manager.generate, gen_request
+ )
+
+ return GenerateResponseAPI(
+ images=response.images,
+ video=response.video,
+ stats=response.stats,
+ time_cost=response.time_cost,
+ inference_start_time=response.inference_start_time,
+ inference_end_time=response.inference_end_time,
+ )
+
+ except Exception as e:
+ logger.error(f"Error generating image: {type(e).__name__}: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+ @app.post("/flush_cache")
+ async def flush_cache():
+ """Flush cache."""
+ if _global_model_manager is None or _global_model_manager.pipe is None:
+ raise HTTPException(status_code=503, detail="Model not loaded")
+
+ try:
+ return JSONResponse(content={"message": "Cache flushed successfully"})
+ except Exception as e:
+ logger.error(f"Error flushing cache: {str(e)}")
+ raise HTTPException(status_code=500, detail=f"Failed to flush cache: {str(e)}")
+
+ return app
diff --git a/src/cache_dit/serve/cache_alignment.py b/src/cache_dit/serve/cache_alignment.py
new file mode 100644
index 000000000..dc820dd84
--- /dev/null
+++ b/src/cache_dit/serve/cache_alignment.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+from typing import Any, Dict, List, Optional
+
+import cache_dit
+from .. import DBCacheConfig, ParamsModifier
+
+
+def get_default_params_modifiers(
+    *,
+    pipe,
+    model_path: str | None,
+    cache_config_obj,
+) -> Optional[List[object]]:
+    if cache_config_obj is None:
+        return None
+
+    model_path_lower = (model_path or "").lower()
+
+    is_flux2 = (pipe is not None and pipe.__class__.__name__ == "Flux2Pipeline") or (
+        "flux.2" in model_path_lower
+    )
+    if not is_flux2:
+        is_wan_2_2 = "wan2.2" in model_path_lower
+        if not is_wan_2_2:
+            return None
+
+    return [
+        ParamsModifier(
+            cache_config=DBCacheConfig().reset(
+                max_warmup_steps=4,
+                max_cached_steps=8,
+            ),
+        ),
+        ParamsModifier(
+            cache_config=DBCacheConfig().reset(
+                max_warmup_steps=2,
+                max_cached_steps=20,
+            ),
+        ),
+    ]
+    # NOTE(review): unreachable — the unconditional return above always fires, so the
+    rdt = getattr(cache_config_obj, "residual_diff_threshold", 0.24)
+    return [
+        ParamsModifier(
+            cache_config=DBCacheConfig().reset(
+                residual_diff_threshold=rdt,
+            ),
+        ),
+        ParamsModifier(
+            cache_config=DBCacheConfig().reset(
+                residual_diff_threshold=rdt * 3,
+            ),
+        ),
+    ]
+
+
+def align_cache_config(
+ *,
+ model_path: str,
+ args,
+ base_cache_config: Optional[Dict[str, Any]],
+) -> Optional[Dict[str, Any]]:
+ if base_cache_config is None:
+ return None
+
+ model_path_lower = (model_path or "").lower()
+ cache_config = dict(base_cache_config)
+
+ is_qwen_lightning = (
+ "qwen-image-lightning" in model_path_lower
+ or "qwen-image-edit-2511-lightning" in model_path_lower
+ or "qwen-image-edit-2509-lightning" in model_path_lower
+ )
+ if is_qwen_lightning:
+ steps = (
+ 8 if getattr(args, "num_inference_steps", None) is None else args.num_inference_steps
+ )
+ if steps not in (4, 8):
+ raise ValueError("Qwen-Image Lightning only supports 4 or 8 steps.")
+ cache_config.update(
+ {
+ "Fn_compute_blocks": 16,
+ "Bn_compute_blocks": 16,
+ "max_warmup_steps": 4 if steps > 4 else 2,
+ "max_cached_steps": 2 if steps > 4 else 1,
+ "max_continuous_cached_steps": 1,
+ "enable_separate_cfg": False,
+ "residual_diff_threshold": 0.50 if steps > 4 else 0.8,
+ }
+ )
+ return cache_config
+
+ if "qwen-image-layered" in model_path_lower:
+ cache_config.setdefault("enable_separate_cfg", False)
+
+ if "z-image-turbo" in model_path_lower and base_cache_config is not None:
+ cache_config["max_warmup_steps"] = min(
+ int(cache_config.get("max_warmup_steps", 8)),
+ 4,
+ )
+ total_steps = (
+ 9 if getattr(args, "num_inference_steps", None) is None else args.num_inference_steps
+ )
+ steps_computation_mask = None
+ if getattr(args, "mask_policy", None) is not None:
+ steps_computation_mask = cache_dit.steps_mask(
+ mask_policy=args.mask_policy,
+ total_steps=total_steps,
+ )
+ elif getattr(args, "steps_mask", False):
+ steps_computation_mask = cache_dit.steps_mask(
+ compute_bins=[5, 1, 1],
+ cache_bins=[1, 1],
+ )
+ cache_config["steps_computation_mask"] = steps_computation_mask
+
+ return cache_config
diff --git a/src/cache_dit/serve/model_manager.py b/src/cache_dit/serve/model_manager.py
new file mode 100644
index 000000000..484b079df
--- /dev/null
+++ b/src/cache_dit/serve/model_manager.py
@@ -0,0 +1,897 @@
+"""Model Manager for cache-dit serving.
+
+Adapted from SGLang's model management:
+https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/tokenizer_manager.py
+"""
+
+import os
+import base64
+import inspect
+import tempfile
+import math
+import uuid
+from datetime import datetime, timezone
+import torch
+import torch.distributed as dist
+import requests
+from io import BytesIO
+from typing import Optional, Dict, Any, List
+from dataclasses import dataclass
+from diffusers import DiffusionPipeline, FlowMatchEulerDiscreteScheduler
+from diffusers.utils import export_to_video
+from diffusers.loaders.lora_base import LoraBaseMixin
+from PIL import Image
+import cache_dit
+from cache_dit.logger import init_logger
+from diffusers import WanImageToVideoPipeline
+from ..platforms import current_platform
+from .utils import prepare_extra_parallel_modules
+from .cache_alignment import get_default_params_modifiers
+
+logger = init_logger(__name__)
+
+
+def load_pipeline_quant_config(pipeline_quant_config_path: str):
+ """Load pipeline quantization config from a custom module."""
+
+ from diffusers.quantizers import PipelineQuantizationConfig
+
+ logger.info(f"Loading pipeline quantization config from: {pipeline_quant_config_path}")
+
+ try:
+ import importlib.util
+ import sys
+
+ # Load the custom module
+ spec = importlib.util.spec_from_file_location(
+ "pipeline_quant_config", pipeline_quant_config_path
+ )
+ if spec is None or spec.loader is None:
+ raise ValueError(f"Cannot load module from {pipeline_quant_config_path}")
+
+ module = importlib.util.module_from_spec(spec)
+ sys.modules[spec.name] = module
+ spec.loader.exec_module(module)
+
+ # Get the pipeline quantization config from the module
+ if not hasattr(module, "get_pipeline_quant_config"):
+ raise ValueError(
+ f"Module {pipeline_quant_config_path} must have a 'get_pipeline_quant_config()' function"
+ )
+
+ quantization_config = module.get_pipeline_quant_config()
+
+ if not isinstance(quantization_config, PipelineQuantizationConfig):
+ raise ValueError(
+ f"get_pipeline_quant_config() must return a PipelineQuantizationConfig object, "
+ f"got {type(quantization_config)}"
+ )
+
+ logger.info("Successfully loaded quantization config from custom module")
+ return quantization_config
+
+ except Exception as e:
+ logger.error(f"Failed to load quantization config from {pipeline_quant_config_path}: {e}")
+ raise
+
+
+@dataclass
+class GenerateRequest:
+ """Image/Video generation request."""
+
+ prompt: str
+ negative_prompt: Optional[str] = ""
+ width: int = 1024
+ height: int = 1024
+ num_inference_steps: int = 50
+ guidance_scale: float = 7.5
+ sigmas: Optional[List[float]] = None
+ seed: Optional[int] = None
+ num_images: int = 1
+ image_urls: Optional[List[str]] = None
+ num_frames: Optional[int] = None
+ fps: Optional[int] = 16
+ include_stats: bool = False
+ output_format: str = "base64"
+ output_dir: Optional[str] = None
+
+ def __repr__(self):
+ image_urls_repr = None
+ if self.image_urls:
+ image_urls_repr = [
+ f"" if len(url) > 100 else url for url in self.image_urls
+ ]
+ return (
+ f"GenerateRequest(prompt={self.prompt[:50]!r}..., "
+ f"width={self.width}, height={self.height}, "
+ f"num_inference_steps={self.num_inference_steps}, "
+ f"guidance_scale={self.guidance_scale}, seed={self.seed}, "
+ f"num_images={self.num_images}, image_urls={image_urls_repr})"
+ )
+
+
+@dataclass
+class GenerateResponse:
+ """Image/Video generation response."""
+
+ images: Optional[List[str]] = None # Base64 encoded images or file paths
+ video: Optional[str] = None # Base64 encoded video (mp4) or file path
+ stats: Optional[Dict[str, Any]] = None
+ time_cost: Optional[float] = None
+ inference_start_time: Optional[str] = None
+ inference_end_time: Optional[str] = None
+
+
+class ModelManager:
+ """Manages diffusion model loading and inference."""
+
+ def __init__(
+ self,
+ model_path: str,
+ device: Optional[str] = None,
+ generator_device: Optional[str] = None,
+ torch_dtype: Optional[torch.dtype] = torch.bfloat16,
+ enable_cache: bool = True,
+ cache_config: Optional[Dict[str, Any]] = None,
+ enable_cpu_offload: bool = False,
+ device_map: Optional[str] = None,
+ enable_compile: bool = False,
+ parallel_type: Optional[str] = None,
+ parallel_args: Optional[Dict[str, Any]] = None,
+ attn_backend: Optional[str] = None,
+ quantize: bool = False,
+ quantize_type: Optional[str] = None,
+ pipeline_quant_config_path: Optional[str] = None,
+ lora_path: Optional[str] = None,
+ lora_name: Optional[str] = None,
+ fuse_lora: bool = True,
+ ):
+ self.model_path = model_path
+ self.device = device or (
+ current_platform.device_type if current_platform.is_accelerator_available() else "cpu"
+ )
+ self.generator_device = generator_device
+ self.torch_dtype = torch_dtype
+ self.enable_cache = enable_cache
+ self.cache_config = cache_config or {}
+ self.enable_cpu_offload = enable_cpu_offload
+ self.device_map = device_map
+ self.enable_compile = enable_compile
+ self.parallel_type = parallel_type
+ self.parallel_args = parallel_args or {}
+ self.attn_backend = attn_backend
+ self.quantize = quantize
+ self.quantize_type = quantize_type
+ self.pipeline_quant_config_path = pipeline_quant_config_path
+ self.lora_path = lora_path
+ self.lora_name = lora_name
+ self.fuse_lora = fuse_lora
+ self.pipe = None
+ self.warmed_up_shapes = set()
+
+ logger.info(
+ f"Initializing ModelManager: model_path={model_path}, device={self.device}, "
+ f"parallel_type={parallel_type}, attn_backend={attn_backend}"
+ )
+
+ def startup_warmup(self, resolutions: List[tuple[int, int]], prompt: str):
+ if self.pipe is None:
+ raise RuntimeError("Model not loaded. Call load_model() first.")
+
+ for width, height in resolutions:
+ shape_key = (width, height)
+ if shape_key in self.warmed_up_shapes:
+ continue
+
+ if self.parallel_type in ["tp", "ulysses", "ring"]:
+ dist.barrier()
+
+ logger.info(f"Startup warming up for shape {width}x{height}...")
+ _ = self.pipe(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=1,
+ )
+ self.warmed_up_shapes.add(shape_key)
+ logger.info(f"Startup warmup completed for {width}x{height}")
+
+ if self.parallel_type in ["tp", "ulysses", "ring"]:
+ dist.barrier()
+
+ def load_model(self):
+ """Load the diffusion model."""
+ logger.info(f"Loading model: {self.model_path}")
+
+ # Load pipeline quantization config
+ quantization_config = None
+ components_quantized_by_diffusers: set[str] = set()
+ if self.quantize and self.pipeline_quant_config_path:
+ quantization_config = load_pipeline_quant_config(self.pipeline_quant_config_path)
+ components_quantized_by_diffusers = set(
+ getattr(quantization_config, "components_to_quantize", []) or []
+ )
+ elif self.quantize:
+ logger.warning("Quantization enabled but no pipeline_quant_config_path provided")
+
+ # Will we quantize transformer via cache-dit(torchao) after parallelism is applied?
+ # NOTE: This is different from diffusers' PipelineQuantizationConfig (e.g., bitsandbytes_4bit).
+ will_torchao_quantize_transformer = (
+ self.quantize
+ and (self.quantize_type is not None)
+ and (self.quantize_type not in ("bitsandbytes_4bit",))
+ and ("transformer" not in components_quantized_by_diffusers)
+ and ("transformer_2" not in components_quantized_by_diffusers)
+ )
+
+ if "Wan2.2-I2V-A14B-Diffusers" in self.model_path:
+ logger.info("Detected Wan2.2-I2V model, using WanImageToVideoPipeline")
+ self.pipe = WanImageToVideoPipeline.from_pretrained(
+ self.model_path,
+ torch_dtype=self.torch_dtype,
+ device_map=self.device_map,
+ quantization_config=quantization_config,
+ )
+ else:
+ if "LTX-2" in self.model_path:
+ ltx2_pipeline = os.environ.get("CACHE_DIT_LTX2_PIPELINE", "t2v").strip().lower()
+ if ltx2_pipeline in ("t2v", "text2video", "text", "default"):
+ from diffusers import LTX2Pipeline
+
+ logger.info("Detected LTX-2 model, using LTX2Pipeline (text-to-video)")
+ self.pipe = LTX2Pipeline.from_pretrained(
+ self.model_path,
+ torch_dtype=self.torch_dtype,
+ device_map=self.device_map,
+ quantization_config=quantization_config,
+ )
+ elif ltx2_pipeline in ("i2v", "image2video", "image"):
+ from diffusers import LTX2ImageToVideoPipeline
+
+ logger.info(
+ "Detected LTX-2 model, using LTX2ImageToVideoPipeline (image-to-video)"
+ )
+ self.pipe = LTX2ImageToVideoPipeline.from_pretrained(
+ self.model_path,
+ torch_dtype=self.torch_dtype,
+ device_map=self.device_map,
+ quantization_config=quantization_config,
+ )
+ else:
+ raise ValueError(
+ "Invalid CACHE_DIT_LTX2_PIPELINE. Please set it to 't2v' or 'i2v'."
+ )
+ else:
+ self.pipe = DiffusionPipeline.from_pretrained(
+ self.model_path,
+ torch_dtype=self.torch_dtype,
+ device_map=self.device_map,
+ quantization_config=quantization_config,
+ )
+
+ if self.lora_path is not None and self.lora_name is not None:
+ if not isinstance(self.pipe, LoraBaseMixin):
+ logger.error("Pipeline does not support LoRA. Skipping LoRA loading.")
+ else:
+ logger.info(f"Loading LoRA weights from: {self.lora_path}/{self.lora_name}")
+ self.pipe.load_lora_weights(self.lora_path, weight_name=self.lora_name)
+ logger.info("LoRA weights loaded successfully")
+
+ if "qwen" in self.lora_name.lower() and "light" in self.lora_name.lower():
+ logger.info("Detected Qwen-Image-Lightning LoRA, updating scheduler...")
+ scheduler_config = {
+ "base_image_seq_len": 256,
+ "base_shift": math.log(3),
+ "invert_sigmas": False,
+ "max_image_seq_len": 8192,
+ "max_shift": math.log(3),
+ "num_train_timesteps": 1000,
+ "shift": 1.0,
+ "shift_terminal": None,
+ "stochastic_sampling": False,
+ "time_shift_type": "exponential",
+ "use_beta_sigmas": False,
+ "use_dynamic_shifting": True,
+ "use_exponential_sigmas": False,
+ "use_karras_sigmas": False,
+ }
+ self.pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
+ scheduler_config
+ )
+ logger.info("Scheduler updated for Lightning model")
+
+ # If transformer will be quantized (either by diffusers quantization_config
+ # or by cache-dit/torchao quantization), do NOT fuse LoRA into transformer.
+ transformer_quantized_or_will_be = (
+ ("transformer" in components_quantized_by_diffusers)
+ or ("transformer_2" in components_quantized_by_diffusers)
+ or will_torchao_quantize_transformer
+ )
+ should_fuse = self.fuse_lora and (not transformer_quantized_or_will_be)
+
+ if should_fuse:
+ logger.info("Fusing LoRA weights into transformer...")
+ self.pipe.fuse_lora()
+ self.pipe.unload_lora_weights()
+ logger.info("LoRA weights fused and unloaded successfully")
+ else:
+ logger.info(
+ "Keeping LoRA weights separate (fusion disabled or transformer quantized)"
+ )
+ elif self.lora_path is not None or self.lora_name is not None:
+ logger.warning("Both --lora-path and --lora-name must be provided to load LoRA weights")
+
+ cache_config_obj = None
+ if self.enable_cache:
+ logger.info("Enabling DBCache acceleration")
+ from cache_dit import DBCacheConfig
+
+ cache_config_obj = DBCacheConfig(
+ residual_diff_threshold=0.24,
+ )
+ if self.cache_config:
+ for key, value in self.cache_config.items():
+ setattr(cache_config_obj, key, value)
+
+ params_modifiers = None
+ if self.enable_cache and cache_config_obj is not None:
+ params_modifiers = get_default_params_modifiers(
+ pipe=self.pipe,
+ model_path=self.model_path,
+ cache_config_obj=cache_config_obj,
+ )
+
+ parallelism_config = None
+ if self.parallel_type is not None:
+ logger.info(
+ f"Enabling parallelism: type={self.parallel_type}, args={self.parallel_args}"
+ )
+ from cache_dit import ParallelismConfig
+ from cache_dit.parallelism import ParallelismBackend
+ import torch.distributed as dist
+
+ world_size = dist.get_world_size() if dist.is_initialized() else 1
+
+ backend = (
+ ParallelismBackend.NATIVE_PYTORCH
+ if self.parallel_type == "tp"
+ else ParallelismBackend.NATIVE_DIFFUSER
+ )
+
+ # Build extra_parallel_modules for text encoder and vae
+ parallel_text_encoder = self.parallel_args.pop("parallel_text_encoder", False)
+ parallel_vae = self.parallel_args.pop("parallel_vae", False)
+ extra_parallel_modules = prepare_extra_parallel_modules(
+ self.pipe,
+ parallel_text_encoder=parallel_text_encoder,
+ parallel_vae=parallel_vae,
+ )
+ self.parallel_args["extra_parallel_modules"] = extra_parallel_modules
+
+ parallelism_config = ParallelismConfig(
+ backend=backend,
+ ulysses_size=world_size if self.parallel_type == "ulysses" else None,
+ ring_size=world_size if self.parallel_type == "ring" else None,
+ tp_size=world_size if self.parallel_type == "tp" else None,
+ parallel_kwargs=self.parallel_args,
+ )
+
+ if cache_config_obj is not None or parallelism_config is not None:
+ cache_dit.enable_cache(
+ self.pipe,
+ cache_config=cache_config_obj,
+ params_modifiers=params_modifiers,
+ parallelism_config=parallelism_config,
+ )
+
+ # Quantize transformer by quantize_type (torchao backend).
+ # WARN: Must apply torchao quantization after tensor/context parallelism is applied.
+ if self.quantize and self.quantize_type is not None:
+ if self.quantize_type in ("bitsandbytes_4bit",):
+ if quantization_config is None:
+ logger.warning(
+ "Requested bitsandbytes_4bit quantization but no "
+ "--pipeline-quant-config-path provided. "
+ "Please provide a PipelineQuantizationConfig that sets "
+ "quant_backend='bitsandbytes_4bit'."
+ )
+ else:
+ if ("transformer" in components_quantized_by_diffusers) or (
+ "transformer_2" in components_quantized_by_diffusers
+ ):
+ logger.warning(
+ "Transformer is already quantized by diffusers PipelineQuantizationConfig; "
+ f"skipping cache-dit(torchao) quantize_type={self.quantize_type}."
+ )
+ else:
+ # Mirror logic from examples: some models do not support per-row quantization.
+ class_not_supported_per_row = {
+ "QwenImageTransformer2DModel",
+ }
+
+ def is_per_row_supported(m: torch.nn.Module) -> bool:
+ return m.__class__.__name__ not in class_not_supported_per_row
+
+ if hasattr(self.pipe, "transformer"):
+ transformer = getattr(self.pipe, "transformer", None)
+ if isinstance(transformer, torch.nn.Module):
+ logger.info(
+ f"Quantizing transformer module: {transformer.__class__.__name__} "
+ f"to {self.quantize_type} (torchao) ..."
+ )
+ setattr(
+ self.pipe,
+ "transformer",
+ cache_dit.quantize(
+ transformer,
+ quant_type=self.quantize_type,
+ per_row=is_per_row_supported(transformer),
+ ),
+ )
+ elif transformer is not None:
+ logger.warning(
+ "Cannot quantize transformer: it is not a torch.nn.Module "
+ f"(got {type(transformer)})."
+ )
+
+ if hasattr(self.pipe, "transformer_2"):
+ transformer_2 = getattr(self.pipe, "transformer_2", None)
+ if isinstance(transformer_2, torch.nn.Module):
+ logger.info(
+ f"Quantizing transformer_2 module: {transformer_2.__class__.__name__} "
+ f"to {self.quantize_type} (torchao) ..."
+ )
+ setattr(
+ self.pipe,
+ "transformer_2",
+ cache_dit.quantize(
+ transformer_2,
+ quant_type=self.quantize_type,
+ per_row=is_per_row_supported(transformer_2),
+ ),
+ )
+ elif transformer_2 is not None:
+ logger.warning(
+ "Cannot quantize transformer_2: it is not a torch.nn.Module "
+ f"(got {type(transformer_2)})."
+ )
+
+ # Move pipeline to device
+ if self.device_map is None and self.device == current_platform.device_type:
+ logger.info(f"Moving pipeline to {current_platform.device_type}")
+ self.pipe.to(self.device)
+
+ if self.enable_cpu_offload and current_platform.device_count() <= 1:
+ logger.info("Enabling CPU offload")
+ self.pipe.enable_model_cpu_offload()
+
+ if self.attn_backend is not None:
+ if hasattr(self.pipe.transformer, "set_attention_backend"):
+ logger.info(f"Setting attention backend to {self.attn_backend}")
+ self.pipe.transformer.set_attention_backend(self.attn_backend)
+ else:
+ logger.warning(
+ f"Transformer does not support set_attention_backend, ignoring --attn {self.attn_backend}"
+ )
+
+ if self.enable_compile:
+ logger.info("Enabling torch.compile")
+ cache_dit.set_compile_configs()
+ self.pipe.transformer = torch.compile(self.pipe.transformer)
+
+ logger.info("Model loaded successfully")
+
+ def _warmup_if_needed(self, width: int, height: int, prompt: str):
+ shape_key = (width, height)
+ if self.enable_compile and shape_key not in self.warmed_up_shapes:
+ if self.parallel_type in ["tp", "ulysses", "ring"]:
+ dist.barrier()
+
+ logger.info(f"Warming up for shape {width}x{height}...")
+ try:
+ _ = self.pipe(
+ prompt=prompt,
+ height=height,
+ width=width,
+ num_inference_steps=1,
+ )
+ self.warmed_up_shapes.add(shape_key)
+ logger.info(f"Warmup completed for {width}x{height}")
+ except Exception as e:
+ logger.warning(f"Warmup failed: {e}")
+
+ if self.parallel_type in ["tp", "ulysses", "ring"]:
+ dist.barrier()
+
+ def _load_images_from_urls(self, image_urls: List[str]) -> Optional[List[Image.Image]]:
+ """Load images from URLs, local paths, or base64 strings."""
+ if not image_urls:
+ return None
+
+ images = []
+ for idx, url in enumerate(image_urls):
+ try:
+ if url.startswith("data:image/"):
+ log_desc = f"data URI (length: {len(url)})"
+ logger.info(f"Loading image {idx + 1} from {log_desc}")
+ header, base64_data = url.split(",", 1)
+ img_data = base64.b64decode(base64_data)
+ image = Image.open(BytesIO(img_data)).convert("RGB")
+ elif url.startswith(("http://", "https://")):
+ log_desc = f"URL: {url[:80]}{'...' if len(url) > 80 else ''}"
+ logger.info(f"Downloading image {idx + 1} from {log_desc}")
+ response = requests.get(url, timeout=30)
+ response.raise_for_status()
+ image = Image.open(BytesIO(response.content)).convert("RGB")
+ elif len(url) > 100:
+ log_desc = f"raw base64 string (length: {len(url)})"
+ logger.info(f"Loading image {idx + 1} from {log_desc}")
+ try:
+ img_data = base64.b64decode(url, validate=True)
+ image = Image.open(BytesIO(img_data)).convert("RGB")
+ except Exception:
+ raise
+ else:
+ log_desc = f"local path: {url}"
+ logger.info(f"Loading image {idx + 1} from {log_desc}")
+ image = Image.open(url).convert("RGB")
+ images.append(image)
+ logger.info(f"Image {idx + 1} loaded successfully: {image.size}")
+ except Exception as e:
+ if len(url) > 100:
+ error_url = f""
+ else:
+ error_url = url
+ logger.error(f"Failed to load image {idx + 1} from {error_url}: {e}")
+ raise RuntimeError(f"Failed to load image {idx + 1}: {e}")
+
+ return images
+
+ def _resolve_output_dir(self, output_dir: Optional[str]) -> str:
+ if output_dir is not None:
+ return os.path.abspath(output_dir)
+ return os.path.join(os.getcwd(), "outputs")
+
+ def _save_image_to_dir(self, image: Image.Image, output_dir: str, name: str) -> str:
+ os.makedirs(output_dir, exist_ok=True)
+ path = os.path.join(output_dir, name)
+ image.save(path, format="PNG")
+ return os.path.abspath(path)
+
+ def _save_video_to_dir(self, video_frames, output_dir: str, name: str, fps: int) -> str:
+ os.makedirs(output_dir, exist_ok=True)
+ path = os.path.join(output_dir, name)
+ export_to_video(video_frames, path, fps=fps)
+ return os.path.abspath(path)
+
    def generate(self, request: GenerateRequest) -> GenerateResponse:
        """Run one generation request (image, edit, video, or image2video).

        Mode is derived from the request: any `image_urls` selects edit /
        image2video; `num_frames > 1` selects video. The payload is base64
        data or file paths depending on `request.output_format`. In
        distributed mode all ranks execute the pipeline, but only the
        primary rank returns payloads; other ranks return an empty response.

        Raises:
            RuntimeError: if the model has not been loaded yet.
            ValueError: for an invalid output_format, or image2video on a
                pipeline whose __call__ does not accept an `image` input.
        """
        if self.pipe is None:
            raise RuntimeError("Model not loaded. Call load_model() first.")

        if request.output_format not in ("base64", "path"):
            raise ValueError(
                f"Invalid output_format: {request.output_format}. Must be 'base64' or 'path'."
            )

        # Mode detection: input images -> edit; num_frames > 1 -> video;
        # both at once -> image2video.
        is_edit_mode = request.image_urls is not None and len(request.image_urls) > 0
        is_video_mode = request.num_frames is not None and request.num_frames > 1
        is_image2video_mode = is_edit_mode and is_video_mode
        input_images = None
        if is_edit_mode:
            input_images = self._load_images_from_urls(request.image_urls)
            if input_images:
                logger.info(
                    f"Loaded {len(input_images)} input image(s) for {'image2video' if is_image2video_mode else 'editing'}"
                )

        # Compile warmup only applies to plain text-to-image generation.
        if not is_edit_mode and not is_video_mode:
            self._warmup_if_needed(request.width, request.height, request.prompt)

        seed = request.seed
        # Distributed ranks must share a seed so they sample identical latents.
        if seed is None and self.parallel_type in ["tp", "ulysses", "ring"]:
            seed = 0
            logger.info(f"{self.parallel_type} mode: using fixed seed {seed}")

        if is_image2video_mode:
            mode_str = "image2video"
        elif is_video_mode:
            mode_str = "video generation"
        elif is_edit_mode:
            mode_str = "edit"
        else:
            mode_str = "generation"
        logger.info(f"{mode_str}: prompt='{request.prompt[:50]}...', seed={seed}")

        generator = None
        if seed is not None:
            gen_device = self.generator_device
            if gen_device is None:
                gen_device = (
                    current_platform.device_type
                    if current_platform.is_accelerator_available()
                    else "cpu"
                )
            generator = torch.Generator(device=gen_device).manual_seed(seed)
            logger.debug(f"Created generator with seed {seed} on {gen_device}")

        if self.parallel_type in ["tp", "ulysses", "ring"]:
            import torch.distributed as dist

            dist.barrier()

        start_dt_raw = datetime.now(timezone.utc)

        pipe_to_use = self.pipe

        # image2video requires a pipeline whose __call__ accepts `image`;
        # if signature inspection fails, optimistically assume it does.
        if is_image2video_mode:
            try:
                sig = inspect.signature(pipe_to_use.__call__)
                accepts_image = "image" in sig.parameters
            except Exception:
                accepts_image = True
            if not accepts_image:
                raise ValueError(
                    "Current LTX-2 pipeline does not support image2video. "
                    "Please restart server with CACHE_DIT_LTX2_PIPELINE=i2v."
                )

        # Build kwargs for pipe call
        pipe_kwargs = {
            "prompt": request.prompt,
            "width": request.width,
            "height": request.height,
            "num_inference_steps": request.num_inference_steps,
            "guidance_scale": request.guidance_scale,
            "generator": generator,
        }

        # Only forward sigmas when the pipeline advertises support for it.
        if request.sigmas is not None:
            try:
                sig = inspect.signature(self.pipe.__call__)
                if "sigmas" in sig.parameters:
                    pipe_kwargs["sigmas"] = request.sigmas
                else:
                    logger.warning("Pipeline does not support sigmas, ignoring request.sigmas")
            except Exception:
                pipe_kwargs["sigmas"] = request.sigmas

        # Add num_frames for video generation
        if is_video_mode:
            pipe_kwargs["num_frames"] = request.num_frames
            # For some video pipelines (e.g. LTX2), `frame_rate` is an input condition.
            # We unify it with request.fps to avoid redundant parameters.
            try:
                sig = inspect.signature(pipe_to_use.__call__)
                if "frame_rate" in sig.parameters:
                    pipe_kwargs["frame_rate"] = (
                        float(request.fps) if request.fps is not None else 24.0
                    )
                # For LTX2 i2v, exporting + audio handling is easier with numpy output
                if "output_type" in sig.parameters:
                    pipe_kwargs["output_type"] = "np"
                if "return_dict" in sig.parameters:
                    pipe_kwargs["return_dict"] = True
            except Exception:
                pipe_kwargs["frame_rate"] = float(request.fps) if request.fps is not None else 24.0
        else:
            pipe_kwargs["num_images_per_prompt"] = request.num_images

        # Add input images to pipe_kwargs if in edit mode or image2video mode
        if is_edit_mode and input_images:
            # For image2video, always use single image (first one if multiple provided)
            if is_image2video_mode:
                pipe_kwargs["image"] = input_images[0]
                logger.info(f"Using first image for image2video: {input_images[0].size}")
            elif len(input_images) == 1:
                pipe_kwargs["image"] = input_images[0]
            else:
                pipe_kwargs["image"] = input_images

        # Some pipelines (like Flux2Pipeline) don't support negative_prompt
        if request.negative_prompt:
            try:
                sig = inspect.signature(self.pipe.__call__)
                if "negative_prompt" in sig.parameters:
                    pipe_kwargs["negative_prompt"] = request.negative_prompt
            except Exception:
                # If we can't inspect, try to add it anyway
                pipe_kwargs["negative_prompt"] = request.negative_prompt

        output = pipe_to_use(**pipe_kwargs)

        if self.parallel_type in ["tp", "ulysses", "ring"]:
            import torch.distributed as dist

            dist.barrier()

        end_dt_raw = datetime.now(timezone.utc)

        # Truncate to millisecond precision so the reported timestamps and the
        # derived time_cost agree with each other.
        start_dt = start_dt_raw.replace(microsecond=(start_dt_raw.microsecond // 1000) * 1000)
        end_dt = end_dt_raw.replace(microsecond=(end_dt_raw.microsecond // 1000) * 1000)

        time_cost = (end_dt - start_dt).total_seconds()

        inference_start_time = start_dt.isoformat(timespec="milliseconds").replace("+00:00", "Z")
        inference_end_time = end_dt.isoformat(timespec="milliseconds").replace("+00:00", "Z")

        is_primary_rank = True
        if (
            self.parallel_type in ["tp", "ulysses", "ring"]
            and dist.is_available()
            and dist.is_initialized()
        ):
            try:
                is_primary_rank = dist.get_rank() == 0
            except Exception:
                is_primary_rank = True

        # Debug: Check output shape in distributed mode
        if self.parallel_type is not None:
            import torch.distributed as dist

            rank = dist.get_rank()
            if is_video_mode:
                logger.info(f"Rank {rank}: Generated video with {len(output.frames[0])} frames")
            else:
                logger.info(f"Rank {rank}: Generated {len(output.images)} images")

        stats = None
        if is_primary_rank and request.include_stats and self.enable_cache:
            stats_list = cache_dit.summary(self.pipe)
            # Convert List[CacheStats] to dict for JSON serialization
            if stats_list:
                stats = {
                    "cache_stats": [
                        {
                            "cache_options": str(s.cache_options) if s.cache_options else None,
                            "cached_steps": list(s.cached_steps) if s.cached_steps else [],
                            "parallelism_config": (
                                str(s.parallelism_config) if s.parallelism_config else None
                            ),
                        }
                        for s in stats_list
                    ]
                }

        images_payload = None
        video_payload = None

        # Non-primary ranks participated in compute but return no payload.
        if not is_primary_rank:
            return GenerateResponse(
                images=None,
                video=None,
                stats=None,
                time_cost=time_cost,
                inference_start_time=inference_start_time,
                inference_end_time=inference_end_time,
            )

        if is_video_mode:
            video_frames = output.frames[0]
            logger.info(
                f"Video generation completed with {len(video_frames)} frames in {time_cost:.2f}s"
            )

            if request.output_format == "path":
                out_dir = self._resolve_output_dir(request.output_dir)
                # If pipeline returns audio (LTX2), export mp4 with audio track.
                audio = getattr(output, "audio", None)
                if audio is not None:
                    try:
                        from diffusers.pipelines.ltx2.export_utils import encode_video

                        video_np = video_frames
                        # video_np: (T, H, W, C) float in [0,1]
                        video_uint8 = (video_np * 255).round().astype("uint8")
                        video_t = torch.from_numpy(video_uint8)
                        audio_t = audio[0]
                        if not isinstance(audio_t, torch.Tensor):
                            audio_t = torch.from_numpy(audio_t)
                        audio_t = audio_t.float().cpu()
                        # Sample rate comes from the vocoder config when present;
                        # falls back to 24000 Hz.
                        sample_rate = getattr(getattr(pipe_to_use, "vocoder", None), "config", None)
                        sample_rate = getattr(sample_rate, "output_sampling_rate", 24000)

                        os.makedirs(out_dir, exist_ok=True)
                        out_path = os.path.abspath(
                            os.path.join(out_dir, f"video_{uuid.uuid4().hex}.mp4")
                        )
                        encode_video(
                            video_t,
                            fps=float(request.fps),
                            audio=audio_t,
                            audio_sample_rate=sample_rate,
                            output_path=out_path,
                        )
                        video_payload = out_path
                    except Exception as e:
                        logger.warning(
                            f"encode_video(with audio) failed ({type(e).__name__}: {e}), "
                            "falling back to export_to_video(video-only)."
                        )
                        video_payload = self._save_video_to_dir(
                            video_frames,
                            out_dir,
                            name=f"video_{uuid.uuid4().hex}.mp4",
                            fps=request.fps,
                        )
                else:
                    video_payload = self._save_video_to_dir(
                        video_frames,
                        out_dir,
                        name=f"video_{uuid.uuid4().hex}.mp4",
                        fps=request.fps,
                    )
            else:
                # base64 output: encode into a temp mp4, read it back, then delete.
                with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
                    tmp_path = tmp_file.name

                try:
                    audio = getattr(output, "audio", None)
                    if audio is not None:
                        from diffusers.pipelines.ltx2.export_utils import encode_video

                        video_np = video_frames
                        video_uint8 = (video_np * 255).round().astype("uint8")
                        video_t = torch.from_numpy(video_uint8)
                        audio_t = audio[0]
                        if not isinstance(audio_t, torch.Tensor):
                            audio_t = torch.from_numpy(audio_t)
                        audio_t = audio_t.float().cpu()
                        sample_rate = getattr(getattr(pipe_to_use, "vocoder", None), "config", None)
                        sample_rate = getattr(sample_rate, "output_sampling_rate", 24000)

                        encode_video(
                            video_t,
                            fps=float(request.fps),
                            audio=audio_t,
                            audio_sample_rate=sample_rate,
                            output_path=tmp_path,
                        )
                    else:
                        export_to_video(video_frames, tmp_path, fps=request.fps)

                    with open(tmp_path, "rb") as f:
                        video_bytes = f.read()
                    video_payload = base64.b64encode(video_bytes).decode()
                finally:
                    if os.path.exists(tmp_path):
                        os.unlink(tmp_path)
        else:
            images_payload = []
            if request.output_format == "path":
                out_dir = self._resolve_output_dir(request.output_dir)
                for idx, image in enumerate(output.images):
                    images_payload.append(
                        self._save_image_to_dir(
                            image,
                            out_dir,
                            name=f"image_{uuid.uuid4().hex}_{idx}.png",
                        )
                    )
            else:
                for image in output.images:
                    buffered = BytesIO()
                    image.save(buffered, format="PNG")
                    img_str = base64.b64encode(buffered.getvalue()).decode()
                    images_payload.append(img_str)

            logger.info(f"Image generation completed in {time_cost:.2f}s")

        return GenerateResponse(
            images=images_payload,
            video=video_payload,
            stats=stats,
            time_cost=time_cost,
            inference_start_time=inference_start_time,
            inference_end_time=inference_end_time,
        )
+
+ def get_model_info(self) -> Dict[str, Any]:
+ """Get model information."""
+ return {
+ "model_path": self.model_path,
+ "device": self.device,
+ "torch_dtype": str(self.torch_dtype),
+ "enable_cache": self.enable_cache,
+ "is_loaded": self.pipe is not None,
+ }
diff --git a/src/cache_dit/serve/serve.py b/src/cache_dit/serve/serve.py
new file mode 100644
index 000000000..2989ed97e
--- /dev/null
+++ b/src/cache_dit/serve/serve.py
@@ -0,0 +1,464 @@
+"""Server launcher for cache-dit.
+
+Adapted from SGLang's server launcher:
+https://github.com/sgl-project/sglang/blob/main/python/sglang/launch_server.py
+"""
+
+import argparse
+import torch
+import uvicorn
+
+from ..quantize.utils import normalize_quantize_type
+from ..platforms import current_platform, CpuPlatform
+from .model_manager import ModelManager
+from .api_server import create_app
+from .cache_alignment import align_cache_config
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
def get_args(
    parse: bool = True,
) -> argparse.ArgumentParser | argparse.Namespace:
    """Build the shared cache-dit CLI.

    Args:
        parse: when True (default), parse sys.argv and return the normalized
            Namespace; when False, return the un-parsed ArgumentParser so a
            caller can add more options on top (see parse_args()).

    Returns:
        argparse.Namespace when `parse` is True, otherwise the
        argparse.ArgumentParser itself.
    """
    parser = argparse.ArgumentParser()
    # Acceleration toggles (DBCache / torch.compile).
    parser.add_argument("--cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument("--compile-repeated-blocks", action="store_true", default=False)
    parser.add_argument("--max-autotune", action="store_true", default=False)
    # LoRA loading options (both --lora-path and --lora-name are needed to load).
    parser.add_argument(
        "--lora-path", type=str, default=None, help="Path to LoRA weights directory"
    )
    parser.add_argument(
        "--lora-name", type=str, default=None, help="LoRA weight filename (e.g., model.safetensors)"
    )
    parser.add_argument(
        "--disable-fuse-lora",
        action="store_true",
        default=False,
        help="Disable LoRA fusion (keep LoRA weights separate)",
    )
    parser.add_argument(
        "--num-inference-steps",
        "--steps",
        dest="num_inference_steps",
        type=int,
        default=None,
    )
    parser.add_argument("--warmup", type=int, default=None)
    parser.add_argument("--repeat", type=int, default=None)
    # DBCache tuning knobs.
    parser.add_argument(
        "--Fn-compute-blocks",
        "--Fn",
        dest="Fn_compute_blocks",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--Bn-compute-blocks",
        "--Bn",
        dest="Bn_compute_blocks",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--residual-diff-threshold",
        "--rdt",
        dest="residual_diff_threshold",
        type=float,
        default=0.24,
    )
    parser.add_argument("--max-warmup-steps", "--ws", "--w", type=int, default=8)
    parser.add_argument("--warmup-interval", "--wi", type=int, default=1)
    parser.add_argument("--max-cached-steps", "--mc", type=int, default=-1)
    parser.add_argument("--max-continuous-cached-steps", "--mcc", type=int, default=3)
    parser.add_argument("--taylorseer", action="store_true", default=False)
    parser.add_argument("--taylorseer-order", "-order", type=int, default=1)
    parser.add_argument("--steps-mask", action="store_true", default=False)
    # Single-letter shorthands ("s"/"m"/"f"/"u") are expanded after parsing below.
    parser.add_argument(
        "--mask-policy",
        "--scm",
        type=str,
        default=None,
        choices=[
            None,
            "slow",
            "s",
            "medium",
            "m",
            "fast",
            "f",
            "ultra",
            "u",
        ],
        help="Pre-defined steps computation mask policy",
    )
    parser.add_argument("--height", type=int, default=None)
    parser.add_argument("--width", type=int, default=None)
    # Quantization options; --quantize-type implies --quantize (normalized below).
    parser.add_argument("--quantize", "-q", action="store_true", default=False)
    parser.add_argument(
        "--quantize-type",
        "--quant-type",
        type=str,
        default=None,
        choices=[
            None,
            "float8",
            "float8_weight_only",
            "float8_wo",
            "int8",
            "int8_weight_only",
            "int8_wo",
            "int4",
            "int4_weight_only",
            "int4_wo",
            "bitsandbytes_4bit",
            "bnb_4bit",
        ],
    )
    parser.add_argument(
        "--pipeline-quant-config-path",
        type=str,
        default=None,
        help="Path to custom Python module that provides get_pipeline_quant_config() function",
    )
    # Parallelism: tensor parallel or context parallel (ulysses/ring).
    parser.add_argument(
        "--parallel-type",
        "--parallel",
        type=str,
        default=None,
        choices=[
            None,
            "tp",
            "ulysses",
            "ring",
        ],
    )
    # TODO: vae TP will be supported in the future
    parser.add_argument(
        "--parallel-vae",
        action="store_true",
        default=False,
        help="Enable VAE parallelism if applicable.",
    )
    parser.add_argument(
        "--parallel-text-encoder",
        "--parallel-text",
        action="store_true",
        default=False,
        help="Enable text encoder parallelism if applicable.",
    )
    # Attention backend override for the transformer.
    parser.add_argument(
        "--attn",
        type=str,
        default=None,
        choices=[
            None,
            "flash",
            "_flash_3",
            "native",
            "_native_cudnn",
            "_sdpa_cudnn",
            "sage",
        ],
    )
    parser.add_argument("--perf", action="store_true", default=False)
    parser.add_argument("--prompt", type=str, default=None, help="Override default prompt")
    parser.add_argument(
        "--negative-prompt", type=str, default=None, help="Override default negative prompt"
    )
    parser.add_argument("--model-path", type=str, default=None, help="Override model path")
    parser.add_argument("--image-path", type=str, default=None, help="Override image path")
    parser.add_argument(
        "--track-memory",
        action="store_true",
        default=False,
        help="Track and report peak GPU memory usage",
    )
    # Experimental Ulysses features for context parallelism.
    parser.add_argument(
        "--ulysses-anything",
        "--uaa",
        action="store_true",
        default=False,
        help="Enable Ulysses Anything Attention for context parallelism",
    )
    parser.add_argument(
        "--ulysses-float8",
        "--ufp8",
        action="store_true",
        default=False,
        help="Enable Ulysses Attention/UAA Float8 for context parallelism",
    )
    parser.add_argument(
        "--ulysses-async",
        "--uaqkv",
        action="store_true",
        default=False,
        help="Enabled experimental Async QKV Projection with Ulysses for context parallelism",
    )
    args_or_parser = parser.parse_args() if parse else parser
    if parse:
        # Normalize quantization flags: a quantize_type implies --quantize,
        # and --quantize without a type defaults to float8_weight_only.
        args_or_parser.quantize_type = normalize_quantize_type(args_or_parser.quantize_type)
        if args_or_parser.quantize_type is not None:
            args_or_parser.quantize = True
        if args_or_parser.quantize and args_or_parser.quantize_type is None:
            args_or_parser.quantize_type = "float8_weight_only"

        # Expand mask-policy shorthands and make --mask-policy imply --steps-mask.
        if args_or_parser.mask_policy is not None and not args_or_parser.steps_mask:
            args_or_parser.steps_mask = True
        if args_or_parser.mask_policy == "s":
            args_or_parser.mask_policy = "slow"
        if args_or_parser.mask_policy == "m":
            args_or_parser.mask_policy = "medium"
        if args_or_parser.mask_policy == "f":
            args_or_parser.mask_policy = "fast"
        if args_or_parser.mask_policy == "u":
            args_or_parser.mask_policy = "ultra"

    return args_or_parser
+
+
def parse_args():
    """Extend the shared cache-dit CLI with server options and parse argv.

    Builds on get_args(parse=False), adds HTTP/deployment flags, then
    normalizes quantization settings and validates that --model-path is set.

    Returns:
        argparse.Namespace with all serve-related options.
    """
    parser = get_args(parse=False)

    # Server-specific options, declared as (option strings, kwargs) pairs.
    server_options = [
        (("--host",), dict(type=str, default="0.0.0.0", help="Server host")),
        (("--port",), dict(type=int, default=8000, help="Server port")),
        (("--workers",), dict(type=int, default=1, help="Number of worker processes")),
        (
            ("--device",),
            dict(type=str, default=None, help="Device (cuda/cpu), auto-detect by default"),
        ),
        (
            ("--generator-device", "--gen-device"),
            dict(
                type=str,
                default=None,
                help="Device for torch.Generator, e.g., 'cuda' or 'cpu'. If not set, auto-detect by default.",
            ),
        ),
        (
            ("--dtype",),
            dict(
                type=str,
                default="bfloat16",
                choices=["float32", "float16", "bfloat16"],
                help="Model dtype",
            ),
        ),
        (
            ("--enable-cpu-offload",),
            dict(
                action="store_true",
                default=False,
                help="Enable CPU offload (saves GPU memory)",
            ),
        ),
        (
            ("--device-map",),
            dict(type=str, default=None, help="Device map strategy (e.g., balanced)"),
        ),
    ]
    for names, options in server_options:
        parser.add_argument(*names, **options)

    args = parser.parse_args()

    # Handle quantize_type alias: an explicit type implies --quantize;
    # --quantize without a type gets the float8 weight-only default.
    args.quantize_type = normalize_quantize_type(getattr(args, "quantize_type", None))
    if args.quantize_type is not None:
        args.quantize = True
    elif args.quantize:
        args.quantize_type = "float8_weight_only"

    # Ensure model_path is required
    if not args.model_path:
        parser.error("--model-path is required")

    return args
+
+
def get_rank_device():
    """Return (rank, device) for the current process.

    When a torch.distributed process group is initialized, the rank is
    mapped round-robin onto the locally visible accelerator devices.
    Otherwise returns rank 0 on the accelerator, or CPU when no
    accelerator is available.
    """
    import torch.distributed as dist

    accelerator_ok = current_platform.is_accelerator_available()
    device_type = current_platform.device_type

    if not dist.is_initialized():
        return 0, torch.device(device_type if accelerator_ok else "cpu")

    rank = dist.get_rank()
    return rank, torch.device(device_type, rank % current_platform.device_count())
+
+
def maybe_init_distributed(args):
    """Initialize torch.distributed when a parallel mode is requested.

    Chooses the process-group backend (adding the CPU backend alongside
    the accelerator backend for Ulysses-Anything), initializes the group,
    and binds this rank to its device. Returns (rank, device); without a
    parallel mode, returns rank 0 and the default device untouched.
    """
    import torch.distributed as dist

    # Ulysses-Anything needs the CPU backend registered together with the
    # accelerator backend; otherwise use the platform's plain backend.
    if args.ulysses_anything:
        backend = f"{CpuPlatform.full_dist_backend},{current_platform.full_dist_backend}"
    else:
        backend = current_platform.dist_backend

    available = current_platform.is_accelerator_available()
    device_type = current_platform.device_type

    if args.parallel_type is None:
        return 0, torch.device(device_type if available else "cpu")

    dist.init_process_group(backend=backend)
    rank, device = get_rank_device()
    current_platform.set_device(device)
    return rank, device
+
+
def launch_server(args=None):
    """Launch the serving server.

    Parses CLI args (unless provided), optionally initializes
    torch.distributed, builds the ModelManager and loads the model, then
    starts serving: either a plain uvicorn server (single process), or the
    distributed topology where rank 0 runs uvicorn behind a TPCoordinator
    and the remaining ranks run a broadcast-driven worker loop.

    Args:
        args: pre-parsed argparse.Namespace; when None, parse_args() is used.
    """
    if args is None:
        args = parse_args()

    rank, device = maybe_init_distributed(args)
    if args.parallel_type is not None:
        import torch.distributed as dist

        logger.info(f"Initialized distributed: rank={rank}, world_size={dist.get_world_size()}")

    torch_dtype = getattr(torch, args.dtype)

    # Use cache argument from utils.get_args
    enable_cache = args.cache
    cache_config = None
    if enable_cache:
        cache_config = {
            "residual_diff_threshold": args.residual_diff_threshold,
            "Fn_compute_blocks": args.Fn_compute_blocks,
            "Bn_compute_blocks": args.Bn_compute_blocks,
            "max_warmup_steps": args.max_warmup_steps,
            "warmup_interval": args.warmup_interval,
            "max_cached_steps": args.max_cached_steps,
            "max_continuous_cached_steps": args.max_continuous_cached_steps,
        }

    # Model-specific alignment may adjust or fill in the base cache config
    # (called even when caching is disabled; see align_cache_config).
    cache_config = align_cache_config(
        model_path=args.model_path,
        args=args,
        base_cache_config=cache_config,
    )

    # Context-parallel (ulysses/ring) kwargs; TP needs no extra kwargs here.
    parallel_args = {}
    if args.parallel_type in ["ulysses", "ring"]:
        if hasattr(args, "attn") and args.attn is not None:
            parallel_args["attention_backend"] = args.attn
        else:
            parallel_args["attention_backend"] = "native"
        if hasattr(args, "ulysses_anything") and args.ulysses_anything:
            parallel_args["experimental_ulysses_anything"] = True
        if hasattr(args, "ulysses_float8") and args.ulysses_float8:
            parallel_args["experimental_ulysses_float8"] = True
        if hasattr(args, "ulysses_async") and args.ulysses_async:
            parallel_args["experimental_ulysses_async"] = True
    elif args.parallel_type == "tp":
        pass

    parallel_args["parallel_text_encoder"] = args.parallel_text_encoder
    parallel_args["parallel_vae"] = args.parallel_vae

    logger.info("Initializing model manager...")
    model_manager = ModelManager(
        model_path=args.model_path,
        device=args.device or current_platform.device_type,
        generator_device=args.generator_device,
        torch_dtype=torch_dtype,
        enable_cache=enable_cache,
        cache_config=cache_config,
        enable_cpu_offload=args.enable_cpu_offload,
        device_map=args.device_map,
        enable_compile=args.compile,
        parallel_type=args.parallel_type,
        parallel_args=parallel_args,
        attn_backend=args.attn,
        quantize=args.quantize,
        quantize_type=args.quantize_type,
        pipeline_quant_config_path=args.pipeline_quant_config_path,
        lora_path=args.lora_path,
        lora_name=args.lora_name,
        fuse_lora=not args.disable_fuse_lora,
    )

    logger.info("Loading model...")
    model_manager.load_model()
    logger.info("Model loaded successfully!")

    # For TP and CP, we need all ranks to participate in inference
    # We use a simple broadcast mechanism to synchronize requests
    if args.parallel_type in ["tp", "ulysses", "ring"]:
        import torch.distributed as dist

        dist.barrier()
        logger.info(f"Rank {rank}: All ranks ready, starting service...")

        if rank == 0:
            # Rank 0: Start HTTP server and broadcast requests to other ranks
            from cache_dit.serve.tp_worker import TPCoordinator

            coordinator = TPCoordinator(model_manager, rank, dist.get_world_size())
            app = create_app(coordinator)

            logger.info(
                f"Starting distributed server (rank 0, {args.parallel_type}) at http://{args.host}:{args.port}"
            )
            logger.info(f"API docs at http://{args.host}:{args.port}/docs")

            uvicorn.run(
                app,
                host=args.host,
                port=args.port,
                workers=1,  # Must be 1 for distributed
                log_level="info",
            )
        else:
            # Other ranks: Run worker loop to receive and execute requests
            from cache_dit.serve.tp_worker import run_tp_worker

            logger.info(f"Starting distributed worker (rank {rank}, {args.parallel_type})")
            run_tp_worker(model_manager, rank)
    else:
        # Single GPU mode
        if rank == 0:
            app = create_app(model_manager)

            logger.info(f"Starting server at http://{args.host}:{args.port}")
            logger.info(f"API docs at http://{args.host}:{args.port}/docs")

            uvicorn.run(
                app,
                host=args.host,
                port=args.port,
                workers=args.workers,
                log_level="info",
            )
        else:
            # This should not happen in single GPU mode
            logger.warning(f"Rank {rank}: Unexpected rank in single GPU mode")
            import time

            # Park the stray rank instead of exiting, so it does not tear
            # down any shared process state.
            while True:
                time.sleep(1)
+
+
# Entry point for `python -m cache_dit.serve.serve` or direct execution.
if __name__ == "__main__":
    launch_server()
diff --git a/src/cache_dit/serve/tp_worker.py b/src/cache_dit/serve/tp_worker.py
new file mode 100644
index 000000000..708ca6620
--- /dev/null
+++ b/src/cache_dit/serve/tp_worker.py
@@ -0,0 +1,171 @@
+"""
+Tensor Parallelism worker for distributed inference.
+
+This module implements a simple broadcast-based mechanism for TP serving:
+- Rank 0 receives HTTP requests and broadcasts them to all ranks
+- All ranks execute inference synchronously
+- Rank 0 collects and returns the result
+
+Inspired by SGLang's distributed architecture.
+"""
+
+import logging
+import pickle
+import time
+import threading
+
+import torch
+import torch.distributed as dist
+
+from ..platforms import current_platform
+from .model_manager import GenerateRequest, GenerateResponse, ModelManager
+
+logger = logging.getLogger(__name__)
+
+HEARTBEAT_INTERVAL = 300
+HEARTBEAT_SIZE = -1
+
+
class TPCoordinator:
    """
    Coordinator for Tensor Parallelism inference.

    Runs on rank 0 only. Each HTTP request is pickled, length-prefixed, and
    broadcast to every TP worker rank; rank 0 then executes the same request
    locally so all ranks stay in lock-step on the collective stream. A
    background heartbeat broadcasts a sentinel while idle so workers blocked
    in `dist.broadcast` do not hit NCCL collective timeouts.
    """

    def __init__(self, model_manager: ModelManager, rank: int, world_size: int):
        """
        Args:
            model_manager: Loaded model wrapper; `generate()` runs on all ranks.
            rank: This process's rank (expected to be 0 for the coordinator).
            world_size: Total number of distributed ranks.
        """
        self.model_manager = model_manager
        self.rank = rank
        self.world_size = world_size
        self._last_broadcast_time = time.time()
        # Serializes broadcasts: a heartbeat must never interleave with a
        # request broadcast, otherwise workers would mis-parse the stream.
        self._heartbeat_lock = threading.Lock()
        # Event instead of a plain bool so stop() wakes the heartbeat thread
        # immediately rather than after up to HEARTBEAT_INTERVAL seconds.
        self._stop_event = threading.Event()
        self._heartbeat_thread = None
        logger.info(f"TPCoordinator initialized: rank={rank}, world_size={world_size}")
        self._start_heartbeat()

    @property
    def pipe(self):
        """Expose the underlying model_manager's pipe for compatibility."""
        return self.model_manager.pipe

    def get_model_info(self):
        """Get model information from the underlying model manager."""
        return self.model_manager.get_model_info()

    def _start_heartbeat(self):
        """Spawn the daemon thread that keeps idle workers' collectives alive."""

        def heartbeat_loop():
            # Event.wait() doubles as an interruptible sleep: it returns True
            # (and we exit) as soon as stop() sets the event.
            while not self._stop_event.wait(HEARTBEAT_INTERVAL):
                with self._heartbeat_lock:
                    # Only ping if no real request was broadcast recently.
                    if time.time() - self._last_broadcast_time > HEARTBEAT_INTERVAL:
                        try:
                            size_tensor = torch.tensor(
                                [HEARTBEAT_SIZE],
                                dtype=torch.long,
                                device=current_platform.device_type,
                            )
                            dist.broadcast(size_tensor, src=0)
                            self._last_broadcast_time = time.time()
                            logger.debug("Heartbeat sent to workers")
                        except Exception as e:
                            logger.error(f"Heartbeat failed: {e}")

        self._heartbeat_thread = threading.Thread(target=heartbeat_loop, daemon=True)
        self._heartbeat_thread.start()
        logger.info(f"Heartbeat thread started (interval={HEARTBEAT_INTERVAL}s)")

    def stop(self):
        """Signal the heartbeat thread to exit and wait briefly for it."""
        self._stop_event.set()
        if self._heartbeat_thread:
            self._heartbeat_thread.join(timeout=1)

    def generate(self, request: GenerateRequest) -> GenerateResponse:
        """
        Generate images using TP.

        Broadcasts the pickled request (length header first, then a padded
        byte payload) to all ranks, runs inference locally, and returns the
        local result. Must match the receive protocol in `run_tp_worker`.
        """
        with self._heartbeat_lock:
            current_platform.synchronize()

            request_data = pickle.dumps(request)
            request_size = len(request_data)

            # Header broadcast: payload length (or HEARTBEAT_SIZE sentinel).
            size_tensor = torch.tensor(
                [request_size], dtype=torch.long, device=current_platform.device_type
            )
            dist.broadcast(size_tensor, src=0)

            # Pad to a multiple of world_size; the worker computes the same
            # padded size so both ends broadcast equal-sized tensors.
            padded_size = (
                (request_size + self.world_size - 1) // self.world_size
            ) * self.world_size
            request_tensor = torch.zeros(
                padded_size, dtype=torch.uint8, device=current_platform.device_type
            )
            # bytearray() makes the buffer writable, avoiding the
            # torch.frombuffer non-writable-buffer UserWarning.
            request_tensor[:request_size].copy_(
                torch.frombuffer(bytearray(request_data), dtype=torch.uint8)
            )
            dist.broadcast(request_tensor, src=0)

            self._last_broadcast_time = time.time()

            # IMPORTANT: Rank 0 must also deserialize the broadcasted request
            # to ensure all ranks use exactly the same request object
            broadcasted_request_data = request_tensor[:request_size].cpu().numpy().tobytes()
            broadcasted_request = pickle.loads(broadcasted_request_data)

            # All ranks execute inference with the broadcasted request
            response = self.model_manager.generate(broadcasted_request)

        # Rank 0 returns the result
        return response
+
+
def run_tp_worker(model_manager: ModelManager, rank: int):
    """
    Worker loop for TP ranks > 0.

    Blocks on broadcasts from rank 0: first a length header, then — unless the
    header is the heartbeat sentinel — a padded byte payload carrying a pickled
    request, which is executed locally. The padding mirrors TPCoordinator so
    both sides broadcast equal-sized tensors.
    """
    logger.info(f"TP worker {rank} started, waiting for requests...")

    while True:
        try:
            current_platform.synchronize()

            # Receive the length header (HEARTBEAT_SIZE means "keep-alive").
            header = torch.tensor([0], dtype=torch.long, device=current_platform.device_type)
            dist.broadcast(header, src=0)
            payload_len = header.item()

            if payload_len == HEARTBEAT_SIZE:
                logger.debug(f"Rank {rank} received heartbeat")
                continue

            # Round the receive buffer up to a multiple of world_size,
            # matching the coordinator's padding exactly.
            world = dist.get_world_size()
            buffer_len = ((payload_len + world - 1) // world) * world
            payload = torch.zeros(
                buffer_len, dtype=torch.uint8, device=current_platform.device_type
            )
            dist.broadcast(payload, src=0)

            # Strip padding, move to host, and unpickle the request.
            request = pickle.loads(payload[:payload_len].cpu().numpy().tobytes())

            logger.debug(f"Rank {rank} executing inference...")
            _ = model_manager.generate(request)
            logger.debug(f"Rank {rank} inference completed")

        except KeyboardInterrupt:
            logger.info(f"TP worker {rank} shutting down...")
            break
        except RuntimeError as e:
            # NCCL/timeout errors are unrecoverable: tear down and exit.
            if "NCCL" in str(e) or "timeout" in str(e).lower():
                logger.error(f"TP worker {rank} NCCL error: {e}")
                dist.destroy_process_group()
                break
            logger.exception(f"TP worker {rank} runtime error: {type(e).__name__}: {e}")
            time.sleep(0.1)
        except Exception as e:
            logger.exception(f"TP worker {rank} error: {type(e).__name__}: {e}")
            time.sleep(0.1)
diff --git a/src/cache_dit/serve/utils.py b/src/cache_dit/serve/utils.py
new file mode 100644
index 000000000..0143ee610
--- /dev/null
+++ b/src/cache_dit/serve/utils.py
@@ -0,0 +1,54 @@
+from typing import List, Optional, Tuple
+import torch
+from diffusers import DiffusionPipeline
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
def get_text_encoder_from_pipe(
    pipe: "DiffusionPipeline",
) -> Tuple[Optional[torch.nn.Module], Optional[str]]:
    """Pick the text encoder that should be parallelized for this pipeline.

    Preference order:
      1. ``text_encoder_2`` (e.g. the T5 encoder in FluxPipeline / FLUX.1-dev),
         skipped for Hunyuan*/Kandinsky* pipelines.
      2. ``text_encoder_3`` (HiDream pipelines).
      3. ``text_encoder`` (general case).

    Unlike a bare ``hasattr`` probe, attributes that exist but are set to
    ``None`` are skipped, so a pipeline with a null ``text_encoder_2`` slot
    still yields its real ``text_encoder``.

    Returns:
        (encoder_module, attribute_name), or (None, None) when the pipeline
        exposes no usable text encoder.
    """
    pipe_cls_name = pipe.__class__.__name__
    skip_encoder_2 = pipe_cls_name.startswith("Hunyuan") or pipe_cls_name.startswith(
        "Kandinsky"
    )

    candidates = []
    if not skip_encoder_2:
        candidates.append("text_encoder_2")  # Specific for FluxPipeline, FLUX.1-dev
    candidates.append("text_encoder_3")  # HiDream pipeline
    candidates.append("text_encoder")  # General case

    for attr_name in candidates:
        encoder = getattr(pipe, attr_name, None)
        if encoder is not None:
            return encoder, attr_name
    return None, None
+
+
def prepare_extra_parallel_modules(
    pipe: DiffusionPipeline,
    parallel_text_encoder: bool = False,
    parallel_vae: bool = False,
) -> List[torch.nn.Module]:
    """Collect optional pipeline components (text encoder / VAE) to parallelize.

    Missing components are reported via warnings rather than raising, so the
    caller always receives a (possibly empty) module list.
    """
    modules: List[torch.nn.Module] = []

    if parallel_text_encoder:
        encoder, encoder_name = get_text_encoder_from_pipe(pipe)
        if encoder is None:
            logger.warning(
                "parallel_text_encoder is enabled but no text encoder found in the pipeline."
            )
        else:
            modules.append(encoder)
            logger.info(
                f"Added {encoder_name} ({encoder.__class__.__name__}) to extra_parallel_modules"
            )

    if parallel_vae:
        vae = getattr(pipe, "vae", None)
        if vae is None:
            logger.warning("parallel_vae is enabled but no VAE found in the pipeline.")
        else:
            modules.append(vae)
            logger.info(f"Added vae ({vae.__class__.__name__}) to extra_parallel_modules")

    return modules
diff --git a/src/cache_dit/summary.py b/src/cache_dit/summary.py
index 81d006cdc..f3241a518 100644
--- a/src/cache_dit/summary.py
+++ b/src/cache_dit/summary.py
@@ -1,3 +1,4 @@
+import sys
import torch
import dataclasses
@@ -6,13 +7,13 @@
from diffusers import DiffusionPipeline
from typing import Dict, Any, List, Union
-from cache_dit.caching import CacheType
-from cache_dit.caching import BlockAdapter
-from cache_dit.caching import BasicCacheConfig
-from cache_dit.caching import CalibratorConfig
-from cache_dit.caching import FakeDiffusionPipeline
-from cache_dit.parallelism import ParallelismConfig
-from cache_dit.caching import load_options
+from .caching import CacheType
+from .caching import BlockAdapter
+from .caching import BasicCacheConfig
+from .caching import CalibratorConfig
+from .caching import FakeDiffusionPipeline
+from .parallelism import ParallelismConfig
+from .caching import load_options
from cache_dit.logger import init_logger
@@ -269,7 +270,7 @@ def _summary(
cache_options = module._context_kwargs
cache_stats.cache_options = cache_options
if logging:
- print(f"\n🤗Context Options: {cls_name}\n\n{cache_options}")
+ print(f"\n🤗Context Options: {cls_name}\n\n{cache_options}", flush=True)
else:
if logging:
logger.warning(f"Can't find Context Options for: {cls_name}")
@@ -278,7 +279,10 @@ def _summary(
parallelism_config: ParallelismConfig = module._parallelism_config
cache_stats.parallelism_config = parallelism_config
if logging:
- print(f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}")
+ print(
+ f"\n🤖Parallelism Config: {cls_name}\n\n{parallelism_config.strify(True)}",
+ flush=True,
+ )
else:
if logging:
logger.warning(f"Can't find Parallelism Config for: {cls_name}")
@@ -319,63 +323,75 @@ def _summary(
qmax = np.max(diffs_values)
if pruned_ratio is not None:
- print(f"\n⚡️Pruned Blocks and Residual Diffs Statistics: {cls_name}\n")
+ print(f"\n⚡️Pruned Blocks and Residual Diffs Statistics: {cls_name}\n", flush=True)
print(
- "| Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+ "| Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |",
+ flush=True,
)
print(
- "|---------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+ "|---------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|",
+ flush=True,
)
print(
f"| {sum(pruned_blocks):<13} | {round(q0, 3):<9} | {round(q1, 3):<9} "
f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
- f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+ f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |",
+ flush=True,
)
- print("")
+ print("", flush=True)
else:
- print(f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n")
+ print(f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n", flush=True)
print(
- "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+ "| Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |",
+ flush=True,
)
print(
- "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+ "|-------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|",
+ flush=True,
)
print(
f"| {len(cached_steps):<11} | {round(q0, 3):<9} | {round(q1, 3):<9} "
f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
- f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+ f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |",
+ flush=True,
)
- print("")
+ print("", flush=True)
if pruned_ratio is not None:
print(
- f"Dynamic Block Prune Ratio: {round(pruned_ratio * 100, 2)}% ({sum(pruned_blocks)}/{sum(actual_blocks)})\n"
+ f"Dynamic Block Prune Ratio: {round(pruned_ratio * 100, 2)}% ({sum(pruned_blocks)}/{sum(actual_blocks)})\n",
+ flush=True,
)
if details:
if pruned_ratio is not None:
- print(f"📚Pruned Blocks and Residual Diffs Details: {cls_name}\n")
+ print(f"📚Pruned Blocks and Residual Diffs Details: {cls_name}\n", flush=True)
pprint(
f"Pruned Blocks: {len(pruned_blocks)}, {pruned_blocks}",
)
+ sys.stdout.flush()
pprint(
f"Actual Blocks: {len(actual_blocks)}, {actual_blocks}",
)
+ sys.stdout.flush()
pprint(
f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
compact=True,
)
+ sys.stdout.flush()
else:
print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
pprint(
f"Cache Steps: {len(cached_steps)}, {cached_steps}",
)
+ sys.stdout.flush()
pprint(
f"Residual Diffs: {len(residual_diffs)}, {residual_diffs}",
compact=True,
)
+ sys.stdout.flush()
if hasattr(module, "_cfg_cached_steps"):
cfg_cached_steps: list[int] = module._cfg_cached_steps
@@ -412,63 +428,81 @@ def _summary(
qmax = np.max(cfg_diffs_values)
if cfg_pruned_ratio is not None:
- print(f"\n⚡️CFG Pruned Blocks and Residual Diffs Statistics: {cls_name}\n")
+ print(
+ f"\n⚡️CFG Pruned Blocks and Residual Diffs Statistics: {cls_name}\n", flush=True
+ )
print(
- "| CFG Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+ "| CFG Pruned Blocks | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |",
+ flush=True,
)
print(
- "|-------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+ "|-------------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|",
+ flush=True,
)
print(
f"| {sum(cfg_pruned_blocks):<18} | {round(q0, 3):<9} | {round(q1, 3):<9} "
f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
- f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+ f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |",
+ flush=True,
)
- print("")
+ print("", flush=True)
else:
- print(f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n")
+ print(
+ f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n", flush=True
+ )
print(
- "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |"
+ "| CFG Cache Steps | Diffs P00 | Diffs P25 | Diffs P50 | Diffs P75 | Diffs P95 | Diffs Min | Diffs Max |",
+ flush=True,
)
print(
- "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|"
+ "|-----------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|",
+ flush=True,
)
print(
f"| {len(cfg_cached_steps):<15} | {round(q0, 3):<9} | {round(q1, 3):<9} "
f"| {round(q2, 3):<9} | {round(q3, 3):<9} | {round(q4, 3):<9} "
- f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |"
+ f"| {round(qmin, 3):<9} | {round(qmax, 3):<9} |",
+ flush=True,
)
- print("")
+ print("", flush=True)
if cfg_pruned_ratio is not None:
print(
- f"CFG Dynamic Block Prune Ratio: {round(cfg_pruned_ratio * 100, 2)}% ({sum(cfg_pruned_blocks)}/{sum(cfg_actual_blocks)})\n"
+ f"CFG Dynamic Block Prune Ratio: {round(cfg_pruned_ratio * 100, 2)}% ({sum(cfg_pruned_blocks)}/{sum(cfg_actual_blocks)})\n",
+ flush=True,
)
if details:
if cfg_pruned_ratio is not None:
- print(f"📚CFG Pruned Blocks and Residual Diffs Details: {cls_name}\n")
+ print(
+ f"📚CFG Pruned Blocks and Residual Diffs Details: {cls_name}\n", flush=True
+ )
pprint(
f"CFG Pruned Blocks: {len(cfg_pruned_blocks)}, {cfg_pruned_blocks}",
)
+ sys.stdout.flush()
pprint(
f"CFG Actual Blocks: {len(cfg_actual_blocks)}, {cfg_actual_blocks}",
)
+ sys.stdout.flush()
pprint(
f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
compact=True,
)
+ sys.stdout.flush()
else:
print(f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n")
pprint(
f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
)
+ sys.stdout.flush()
pprint(
f"CFG Residual Diffs: {len(cfg_residual_diffs)}, {cfg_residual_diffs}",
compact=True,
)
+ sys.stdout.flush()
return cache_stats
@@ -483,14 +517,14 @@ def supported_matrix() -> str | None:
_pipelines_supported_cache += [
"LongCatVideo", # not in diffusers, but supported
]
- from cache_dit.parallelism.backends.native_diffusers import (
+ from cache_dit.parallelism.transformers.context_parallelism import (
ContextParallelismPlannerRegister,
)
_pipelines_supported_context_parallelism = (
ContextParallelismPlannerRegister.supported_planners()[1]
)
- from cache_dit.parallelism.backends.native_pytorch import (
+ from cache_dit.parallelism.transformers.tensor_parallelism import (
TensorParallelismPlannerRegister,
)
@@ -544,8 +578,8 @@ def supported_matrix() -> str | None:
matrix_str = "\n".join(matrix_lines)
- print("\nSupported Cache and Parallelism Matrix:\n")
- print(matrix_str)
+ print("\nSupported Cache and Parallelism Matrix:\n", flush=True)
+ print(matrix_str, flush=True)
return matrix_str
except Exception:
return None
diff --git a/src/cache_dit/utils.py b/src/cache_dit/utils.py
index 98d085c0b..43ef51ead 100644
--- a/src/cache_dit/utils.py
+++ b/src/cache_dit/utils.py
@@ -6,7 +6,7 @@
import contextlib
from cache_dit.logger import init_logger
-
+from .platforms import current_platform
logger = init_logger(__name__)
@@ -33,27 +33,34 @@ def maybe_empty_cache():
try:
time.sleep(1)
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ current_platform.empty_cache()
+ current_platform.ipc_collect()
time.sleep(1)
gc.collect()
- torch.cuda.empty_cache()
- torch.cuda.ipc_collect()
+ current_platform.empty_cache()
+ current_platform.ipc_collect()
except Exception:
pass
-@torch.compiler.disable
def print_tensor(
x: torch.Tensor,
name: str,
dim: int = 1,
no_dist_shape: bool = True,
- disable: bool = False,
+ disable: bool = True,
):
if disable:
return
+ if x is None:
+ print(f"{name} is None")
+ return
+
+ if not isinstance(x, torch.Tensor):
+ print(f"{name} is not a tensor, type: {type(x)}")
+ return
+
x = x.contiguous()
if torch.distributed.is_initialized():
# all gather hidden_states and check values mean
@@ -66,11 +73,12 @@ def print_tensor(
else:
x_shape = x.shape
- if torch.distributed.get_rank() == 0:
- print(
- f"{name}, mean: {gather_x.float().mean().item()}, "
- f"std: {gather_x.float().std().item()}, shape: {x_shape}"
- )
+ rank = torch.distributed.get_rank()
+ print(
+ f"\nrank: {rank}, {name}, mean: {gather_x.float().mean().item()}, "
+ f"std: {gather_x.float().std().item()}, shape: {x_shape}",
+ flush=True,
+ )
else:
print(
f"{name}, mean: {x.float().mean().item()}, "
diff --git a/tests/.gitignore b/tests/.gitignore
index 3e41d33f6..f10a37476 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -165,3 +165,4 @@ report*.html
.DS_Store
tmp
+data
diff --git a/tests/README.md b/tests/README.md
deleted file mode 100644
index 455bc03a7..000000000
--- a/tests/README.md
+++ /dev/null
@@ -1,63 +0,0 @@
-# Tests
-
-## Taylorseer, Order=2
-
-```bash
-python3 test_taylorseer.py --order 2
-```
-
-
-
-## Taylorseer, Order=4
-
-```bash
-python3 test_taylorseer.py --order 4
-```
-
-
-
-## Metrics
-
-Image Metrics
-
-```bash
-# F1B0 w/o TaylorSeer
-python3 test_metrics.py --img-true data/U0_C0_NONE_R0.08_S0_T24.82s.png --img-test data/U0_C0_DBCACHE_F1B0S1W0T0O2_R0.08_S11_T15.43s.png
-
-data/U0_C0_NONE_R0.08_S0_T24.82s.png vs data/U0_C0_DBCACHE_F1B0S1W0T0O2_R0.08_S11_T15.43s.png, PSNR: 21.240280356949647
-data/U0_C0_NONE_R0.08_S0_T24.82s.png vs data/U0_C0_DBCACHE_F1B0S1W0T0O2_R0.08_S11_T15.43s.png, FID: 136.1958835812449
-
-# F1B0 w/ TaylorSeer, Order=2
-python3 test_metrics.py --img-true data/U0_C0_NONE_R0.08_S0_T24.82s.png --img-test data/U0_C0_DBCACHE_F1B0S1W0T1O2_R0.08_S10_T16.30s.png
-
-data/U0_C0_NONE_R0.08_S0_T24.82s.png vs data/U0_C0_DBCACHE_F1B0S1W0T1O2_R0.08_S10_T16.30s.png, PSNR: 24.68392809867634
-data/U0_C0_NONE_R0.08_S0_T24.82s.png vs data/U0_C0_DBCACHE_F1B0S1W0T1O2_R0.08_S10_T16.30s.png, FID: 75.76806327295184
-```
-
-Video Metrics
-
-```bash
-# F1B0 w/o TaylorSeer
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F1B0W0T0O2.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F1B0W0T0O2.mp4, PSNR: 19.043978283539857
-
-# F1B0 w/ TaylorSeer, Order=2
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F1B0W0T1O2.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F1B0W0T1O2.mp4, PSNR: 19.794726079042302
-
-# F1B0 w/ TaylorSeer, Order=4
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F1B0W0T1O4.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F1B0W0T1O4.mp4, PSNR: 19.779299536586567
-
-# F4B0 w/ TaylorSeer, Order=2
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F4B0W0T1O2.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F4B0W0T1O2.mp4, PSNR: 21.52726487066195
-
-# F8B0 w/ TaylorSeer, Order=2
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F8B0W0T1O2.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F8B0W0T1O2.mp4, PSNR: 27.970811066301014
-
-# F12B0 w/ TaylorSeer, Order=2
-python3 test_metrics.py --video-true data/wan.NONE.mp4 --video-test data/wan.DBCACHE_F12B0W0T1O2.mp4
-data/wan.NONE.mp4 vs data/wan.DBCACHE_F12B0W0T1O2.mp4, PSNR: 33.32353616116749
-```
diff --git a/tests/api/config.yaml b/tests/api/config.yaml
new file mode 100644
index 000000000..094352c2a
--- /dev/null
+++ b/tests/api/config.yaml
@@ -0,0 +1,17 @@
+cache_config:
+ max_warmup_steps: 8
+ warmup_interval: 2
+ max_cached_steps: -1
+ max_continuous_cached_steps: 2
+ Fn_compute_blocks: 1
+ Bn_compute_blocks: 0
+ num_inference_steps: 28
+ steps_computation_mask: fast
+ residual_diff_threshold: 0.12
+ enable_taylorseer: true
+ taylorseer_order: 1
+parallelism_config:
+ ulysses_size: 2
+ parallel_kwargs:
+ attention_backend: native
+ extra_parallel_modules: ["text_encoder", "vae"]
diff --git a/tests/api/test_forward_pattern.py b/tests/api/test_forward_pattern.py
new file mode 100644
index 000000000..6642e9688
--- /dev/null
+++ b/tests/api/test_forward_pattern.py
@@ -0,0 +1,105 @@
+import gc
+import pytest
+import torch
+import cache_dit
+from cache_dit import ForwardPattern, BlockAdapter, DBCacheConfig
+from cache_dit.platforms import current_platform
+from utils import RandPipeline
+
+
+DEVICES = (
+ ["cpu"]
+ if not current_platform.is_accelerator_available()
+ else ["cpu", current_platform.device_type]
+)
+PATTERNS = [
+ ForwardPattern.Pattern_0,
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_2,
+ ForwardPattern.Pattern_3,
+ ForwardPattern.Pattern_4,
+ ForwardPattern.Pattern_5,
+]
+
+DTYPES = (
+ [torch.float32]
+ if not current_platform.is_accelerator_available()
+ else [torch.float32, torch.bfloat16]
+)
+STEPS = [50]
+
+
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("pattern", PATTERNS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("steps", STEPS)
def test_forward_pattern(device, pattern, dtype, steps):
    """Smoke-test DBCache across every forward pattern / device / dtype combo."""
    gc.collect()
    pipe = RandPipeline(pattern=pattern)

    cache_dit.enable_cache(
        BlockAdapter(
            pipe=pipe,
            transformer=pipe.transformer,
            blocks=pipe.transformer.transformer_blocks,
            blocks_name="transformer_blocks",
            forward_pattern=pipe.pattern,
        ),
        cache_config=DBCacheConfig(
            Fn_compute_blocks=1,
            Bn_compute_blocks=0,
            residual_diff_threshold=0.05,
        ),
    )

    batch, seq_len, head_dim = 1, 1024, 64
    # Patterns 0-2 consume encoder hidden states in addition to hidden states.
    needs_encoder_states = pattern in (
        ForwardPattern.Pattern_0,
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_2,
    )

    def _rand_states():
        return torch.normal(
            mean=100.0,
            std=20.0,
            size=(batch, seq_len, head_dim),
            dtype=dtype,
        )

    hidden_states = _rand_states()
    encoder_hidden_states = _rand_states() if needs_encoder_states else None

    if device == current_platform.device_type:
        pipe.to(device)
        hidden_states = hidden_states.to(device)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(device)

    call_kwargs = {"num_inference_steps": steps}
    if needs_encoder_states:
        call_kwargs["encoder_hidden_states"] = encoder_hidden_states
    _ = pipe(hidden_states, **call_kwargs)

    cache_dit.summary(pipe, details=True)
    cache_dit.disable_cache(pipe)

    del pipe
    del hidden_states
    if encoder_hidden_states is not None:
        del encoder_hidden_states
    gc.collect()
diff --git a/tests/api/test_load_configs.py b/tests/api/test_load_configs.py
new file mode 100644
index 000000000..3e0be556e
--- /dev/null
+++ b/tests/api/test_load_configs.py
@@ -0,0 +1,13 @@
+import pytest
+import cache_dit
+
+CONFIG = ["api/config.yaml"]
+
+
@pytest.mark.parametrize("config", CONFIG)
def test_load_configs(config):
    """load_configs should parse the YAML file into the three expected sections."""
    from pathlib import Path

    # Resolve relative to this test file so the test passes regardless of the
    # pytest invocation directory (a bare "api/config.yaml" only works when
    # the working directory is tests/).
    config_path = Path(__file__).parent.parent / config
    configs = cache_dit.load_configs(str(config_path))
    assert "cache_config" in configs
    # NOTE(review): config.yaml only declares cache_config and
    # parallelism_config; this assumes load_configs synthesizes
    # calibrator_config (e.g. from the enable_taylorseer flags) -- confirm
    # against load_configs' contract.
    assert "calibrator_config" in configs
    assert "parallelism_config" in configs
    print("Loaded configs:", configs, flush=True)
diff --git a/tests/api/test_refresh_context.py b/tests/api/test_refresh_context.py
new file mode 100644
index 000000000..579e95f9e
--- /dev/null
+++ b/tests/api/test_refresh_context.py
@@ -0,0 +1,162 @@
+import gc
+import pytest
+import torch
+import cache_dit
+from cache_dit import ForwardPattern, BlockAdapter, DBCacheConfig
+from cache_dit.platforms import current_platform
+from utils import RandPipeline
+
+
+DEVICES = (
+ ["cpu"]
+ if not current_platform.is_accelerator_available()
+ else ["cpu", current_platform.device_type]
+)
+PATTERNS = [
+ ForwardPattern.Pattern_0,
+ ForwardPattern.Pattern_1,
+ ForwardPattern.Pattern_2,
+ ForwardPattern.Pattern_3,
+ ForwardPattern.Pattern_4,
+ ForwardPattern.Pattern_5,
+]
+
+DTYPES = (
+ [torch.float32]
+ if not current_platform.is_accelerator_available()
+ else [torch.float32, torch.bfloat16]
+)
+
+
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("pattern", PATTERNS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_refresh_context(device, pattern, dtype):
    """Verify refresh_context re-arms the cache across several step counts.

    Uses the transformer-only API (BlockAdapter without a pipe), refreshing
    the cache context before each run -- once with num_inference_steps only,
    then with full DBCacheConfig overrides.
    """
    gc.collect()
    pipe = RandPipeline(pattern=pattern)  # type: RandPipeline
    transformer = pipe.transformer

    adapter = cache_dit.enable_cache(
        BlockAdapter(
            transformer=transformer,
            blocks=transformer.transformer_blocks,
            forward_pattern=pipe.pattern,
        ),
        cache_config=DBCacheConfig(
            Fn_compute_blocks=8,
            Bn_compute_blocks=0,
            residual_diff_threshold=0.05,
        ),
    )

    # Transformer only API
    bs, seq_len, headdim = 1, 1024, 64

    def _rand_states():
        return torch.normal(
            mean=100.0,
            std=20.0,
            size=(bs, seq_len, headdim),
            dtype=dtype,
        )

    hidden_states = _rand_states()

    # Patterns 0-2 additionally consume encoder hidden states.
    needs_encoder_states = pattern in (
        ForwardPattern.Pattern_0,
        ForwardPattern.Pattern_1,
        ForwardPattern.Pattern_2,
    )
    encoder_hidden_states = _rand_states() if needs_encoder_states else None

    if device == current_platform.device_type:
        pipe.to(device)
        hidden_states = hidden_states.to(device)
        if encoder_hidden_states is not None:
            encoder_hidden_states = encoder_hidden_states.to(device)

    STEPS = [16, 28, 50]
    for i, steps in enumerate(STEPS):
        if i == 0:
            # Test num_inference_steps only case
            cache_dit.refresh_context(
                transformer,
                num_inference_steps=steps,
                verbose=True,
            )
        else:
            # Refresh cache context with a full config override
            cache_dit.refresh_context(
                transformer,
                cache_config=DBCacheConfig(
                    Fn_compute_blocks=1,
                    Bn_compute_blocks=0,
                    residual_diff_threshold=0.08,
                    num_inference_steps=steps,
                    steps_computation_mask=cache_dit.steps_mask(
                        mask_policy="fast",
                        total_steps=steps,
                    ),
                    steps_computation_policy="dynamic",
                    enable_separate_cfg=False,
                ),
                verbose=True,
            )
        call_kwargs = {"num_inference_steps": steps}
        if needs_encoder_states:
            call_kwargs["encoder_hidden_states"] = encoder_hidden_states
        _ = pipe(hidden_states, **call_kwargs)

    cache_dit.summary(transformer)
    # We have to disable cache before deleting the pipe and adapter
    # using block adapter instance due to the fake pipe we used in
    # transformer only API.
    cache_dit.disable_cache(adapter)

    del pipe
    del adapter
    del hidden_states
    if encoder_hidden_states is not None:
        del encoder_hidden_states
    gc.collect()
diff --git a/tests/api/test_taylorseers.py b/tests/api/test_taylorseers.py
new file mode 100644
index 000000000..e43462887
--- /dev/null
+++ b/tests/api/test_taylorseers.py
@@ -0,0 +1,41 @@
+import pytest
+import numpy as np
+from cache_dit.caching.cache_contexts.calibrators import (
+ TaylorSeerCalibrator,
+)
+
+N_DERIVATIVES = [1, 2, 3]
+MAX_WARMUP_STEPS = [2, 5]
+SKIP_INTERVAL_STEPS = [1, 2]
+
+
@pytest.mark.parametrize("n_derivatives", N_DERIVATIVES)
@pytest.mark.parametrize("max_warmup_steps", MAX_WARMUP_STEPS)
@pytest.mark.parametrize("skip_interval_steps", SKIP_INTERVAL_STEPS)
def test_taylor_seer_calibrator(
    n_derivatives,
    max_warmup_steps,
    skip_interval_steps,
):
    """TaylorSeer should track y = x**2 with a finite approximation error."""
    taylor_seer = TaylorSeerCalibrator(
        n_derivatives=n_derivatives,
        max_warmup_steps=max_warmup_steps,
        skip_interval_steps=skip_interval_steps,
    )

    x_values = np.arange(0, 10, 0.1)

    y_pred = []
    errors = []
    for x in x_values:
        y = x**2
        y_approx = taylor_seer.step(y)
        y_pred.append(y_approx)
        errors.append(abs(y - y_approx))

    mean_error = np.mean(errors)
    print(
        f"Mean approximation error: {mean_error:.5f}, n_derivatives: {n_derivatives}, "
        f"max_warmup_steps: {max_warmup_steps}, skip_interval_steps: {skip_interval_steps}",
        flush=True,
    )
    # The original test printed the metric but asserted nothing, so it could
    # never fail. Pin down the basics: one prediction per sample and a
    # finite, non-negative mean error.
    assert len(y_pred) == len(x_values)
    assert np.isfinite(mean_error)
    assert mean_error >= 0.0
diff --git a/tests/test_forward_pattern.py b/tests/api/utils.py
similarity index 84%
rename from tests/test_forward_pattern.py
rename to tests/api/utils.py
index b23506eeb..2ad391c8a 100644
--- a/tests/test_forward_pattern.py
+++ b/tests/api/utils.py
@@ -1,14 +1,12 @@
import dataclasses
-import argparse
from tqdm import tqdm
import torch
import torch.nn as nn
from typing import Tuple, Union
from diffusers import DiffusionPipeline
-import cache_dit
-from cache_dit import ForwardPattern, BlockAdapter, DBCacheConfig
+from cache_dit import ForwardPattern
RATIO = 0.7
RAND_RATIO = 0.5
@@ -328,72 +326,3 @@ def __call__(
def to(self, *args, **kwargs):
self.transformer.to(*args, **kwargs)
-
-
-def get_args() -> argparse.ArgumentParser:
- parser = argparse.ArgumentParser()
- parser.add_argument("--pattern", type=int, choices=[0, 1, 2, 3, 4, 5], default=0)
- return parser.parse_args()
-
-
-if __name__ == "__main__":
- args = get_args()
- print(args)
- if args.pattern == 0:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_0)
- elif args.pattern == 1:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_1)
- elif args.pattern == 2:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_2)
- elif args.pattern == 3:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_3)
- elif args.pattern == 4:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_4)
- else:
- pipe = RandPipeline(pattern=ForwardPattern.Pattern_5)
-
- pipe.to("cuda")
-
- cache_dit.enable_cache(
- BlockAdapter(
- pipe=pipe,
- transformer=pipe.transformer,
- blocks=pipe.transformer.transformer_blocks,
- blocks_name="transformer_blocks",
- forward_pattern=pipe.pattern,
- ),
- cache_config=DBCacheConfig(
- Fn_compute_blocks=1,
- Bn_compute_blocks=0,
- residual_diff_threshold=0.05,
- ),
- )
- bs, seq_len, headdim = 1, 1024, 1024
-
- hidden_states = torch.normal(
- mean=100.0,
- std=20.0,
- size=(bs, seq_len, headdim),
- dtype=torch.bfloat16,
- ).to("cuda")
-
- encoder_hidden_states = torch.normal(
- mean=100.0,
- std=20.0,
- size=(bs, seq_len, headdim),
- dtype=torch.bfloat16,
- ).to("cuda")
-
- if args.pattern in [0, 1, 2]:
- output = pipe(
- hidden_states,
- encoder_hidden_states=encoder_hidden_states,
- num_inference_steps=50,
- )
- else:
- output = pipe(
- hidden_states,
- num_inference_steps=50,
- )
-
- cache_dit.summary(pipe, details=True)
diff --git a/tests/cache_config.yaml b/tests/cache_config.yaml
deleted file mode 100644
index cd9654b36..000000000
--- a/tests/cache_config.yaml
+++ /dev/null
@@ -1,10 +0,0 @@
-max_warmup_steps: 0
-max_cached_steps: -1
-max_continuous_cached_steps: 2
-Fn_compute_blocks: 1
-Bn_compute_blocks: 0
-residual_diff_threshold: 0.12
-enable_taylorseer: true
-enable_encoder_taylorseer: true
-taylorseer_cache_type: residual
-taylorseer_order: 2
diff --git a/tests/parallelism/test_tp_utils_divisible_attr.py b/tests/parallelism/test_tp_utils_divisible_attr.py
new file mode 100644
index 000000000..0f689fe20
--- /dev/null
+++ b/tests/parallelism/test_tp_utils_divisible_attr.py
@@ -0,0 +1,73 @@
+"""
+Minimal runnable test script (non-pytest style), consistent with other files in `tests/`.
+
+Run:
+ python3 tests/parallelism/test_tp_utils_divisible_attr.py
+"""
+
+from cache_dit.parallelism.transformers.native_pytorch.tensor_parallelism.tp_utils import (
+ shard_divisible_attr,
+)
+
+
+class Dummy:
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+def test_shard_divisible_attr_success_updates_value():
+ attn = Dummy(heads=30)
+ new_heads = shard_divisible_attr(attn, "heads", 5, what="attn", context="test")
+ assert new_heads == 6
+ assert attn.heads == 6
+
+
+def test_shard_divisible_attr_raises_on_not_divisible():
+ attn = Dummy(heads=30)
+ try:
+ shard_divisible_attr(attn, "heads", 4, what="attn", context="test")
+ raise AssertionError("Expected ValueError for non-divisible heads/tp_size, but got none.")
+ except ValueError as e:
+ # should be a clear, startup-time error message
+ msg = str(e)
+ assert "tp_size=4" in msg
+ assert "heads=30" in msg
+
+
+def test_shard_divisible_attr_raises_on_missing_attr():
+ attn = Dummy()
+ try:
+ shard_divisible_attr(attn, "heads", 2, what="attn", context="test")
+ raise AssertionError("Expected AttributeError for missing attr, but got none.")
+ except AttributeError:
+ pass
+
+
+def main():
+ tests = [
+ test_shard_divisible_attr_success_updates_value,
+ test_shard_divisible_attr_raises_on_not_divisible,
+ test_shard_divisible_attr_raises_on_missing_attr,
+ ]
+
+ print("== cache-dit TP utils self-check ==")
+ passed = 0
+ failed = 0
+ for t in tests:
+ name = t.__name__
+ try:
+ t()
+ print(f"[PASS] {name}")
+ passed += 1
+ except Exception as e:
+ print(f"[FAIL] {name}: {type(e).__name__}: {e}")
+ failed += 1
+
+ print(f"Summary: passed={passed}, failed={failed}, total={len(tests)}")
+ if failed != 0:
+ raise SystemExit(1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/pipeline/.gitignore b/tests/serving/.gitignore
similarity index 60%
rename from examples/pipeline/.gitignore
rename to tests/serving/.gitignore
index 04741df5c..ff725431c 100644
--- a/examples/pipeline/.gitignore
+++ b/tests/serving/.gitignore
@@ -1,4 +1,7 @@
-*.gif
*.png
-*.mp4
*.jpg
+*.jpeg
+*.gif
+*.mp4
+tmp
+data
diff --git a/tests/serving/test_concurrent_requests.py b/tests/serving/test_concurrent_requests.py
new file mode 100644
index 000000000..c8fad7d71
--- /dev/null
+++ b/tests/serving/test_concurrent_requests.py
@@ -0,0 +1,48 @@
+import asyncio
+import aiohttp
+import time
+
+
+async def send_request(session, url, request_id):
+ payload = {
+ "prompt": f"test prompt {request_id}",
+ "width": 512,
+ "height": 512,
+ "num_inference_steps": 10,
+ "guidance_scale": 3.5,
+ "seed": request_id,
+ "num_images": 1,
+ }
+
+ start_time = time.time()
+ async with session.post(url, json=payload) as response:
+ elapsed = time.time() - start_time
+ result = await response.json() if response.status == 200 else {}
+ return {
+ "id": request_id,
+ "status": response.status,
+ "elapsed": elapsed,
+ "server_time": result.get("time_cost", 0),
+ }
+
+
+async def main():
+ url = "http://localhost:8000/generate"
+ num_requests = 3
+
+ start_time = time.time()
+ async with aiohttp.ClientSession() as session:
+ tasks = [send_request(session, url, i) for i in range(num_requests)]
+ results = await asyncio.gather(*tasks)
+
+ total_time = time.time() - start_time
+
+ print(f"Total: {total_time:.2f}s")
+ for r in results:
+ print(
+ f"Request {r['id']}: status={r['status']}, elapsed={r['elapsed']:.2f}s, server={r['server_time']:.2f}s"
+ )
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/tests/serving/test_flux2_image_edit_serving.py b/tests/serving/test_flux2_image_edit_serving.py
new file mode 100644
index 000000000..27d81c2d3
--- /dev/null
+++ b/tests/serving/test_flux2_image_edit_serving.py
@@ -0,0 +1,140 @@
+"""
+CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path black-forest-labs/FLUX.2-dev \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --quantize-type float8_wo \
+ --attn _flash_3 \
+ --cache \
+ --compile \
+ --ulysses-anything
+"""
+
+import os
+import requests
+import base64
+from PIL import Image
+from io import BytesIO
+
+
+def call_api(prompt, image_urls=None, name="test", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "width": kwargs.get("width", 1024),
+ "height": kwargs.get("height", 1024),
+ "num_inference_steps": kwargs.get("num_inference_steps", 50),
+ "guidance_scale": kwargs.get("guidance_scale", 4.0),
+ "seed": kwargs.get("seed", 0),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ if image_urls:
+ payload["image_urls"] = image_urls
+
+ response = requests.post(url, json=payload, timeout=300)
+ response.raise_for_status()
+ result = response.json()
+ assert "images" in result and len(result["images"]) > 0, "No images in response"
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["images"][0]
+ assert os.path.exists(filename)
+ img = Image.open(filename)
+        print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+ else:
+ img_data = base64.b64decode(result["images"][0])
+ img = Image.open(BytesIO(img_data))
+
+ filename = f"{name}.png"
+ img.save(filename)
+
+        print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+
+
+def test_single():
+ return call_api(
+ prompt="Put a birthday hat on the dog in the image",
+ image_urls=["https://modelscope.oss-cn-beijing.aliyuncs.com/Dog.png"],
+ name="single_edit",
+ seed=0,
+ )
+
+
+def test_multi():
+ return call_api(
+ prompt="Realistic style, dog chases frisbee",
+ image_urls=[
+ "https://modelscope.oss-cn-beijing.aliyuncs.com/Dog.png",
+ "https://modelscope.oss-cn-beijing.aliyuncs.com/Frisbee.png",
+ ],
+ name="multi_edit",
+ seed=0,
+ )
+
+
+def test_base64():
+ image_url = "https://modelscope.oss-cn-beijing.aliyuncs.com/Dog.png"
+ response = requests.get(image_url, timeout=30)
+ img_base64 = base64.b64encode(response.content).decode("utf-8")
+
+ filename1 = call_api(
+ prompt="Put a birthday hat on the dog", image_urls=[img_base64], name="base64_raw", seed=0
+ )
+
+ data_uri = f"data:image/png;base64,{img_base64}"
+ filename2 = call_api(
+ prompt="Put a birthday hat on the dog", image_urls=[data_uri], name="base64_uri", seed=0
+ )
+
+ return filename1, filename2
+
+
+def test_text():
+ return call_api(
+ prompt="A beautiful landscape with mountains and lakes",
+ name="text_gen",
+ num_inference_steps=28,
+ seed=0,
+ )
+
+
+def test_text_ulysses_bad_resolution_regression():
+ filename = call_api(
+ prompt="A beautiful landscape with mountains and lakes",
+ name="text_gen_724x1080",
+ width=724,
+ height=1080,
+ num_inference_steps=8,
+ seed=0,
+ )
+ return filename
+
+
+def test_text_path_output():
+ return call_api(
+ prompt="A beautiful landscape with mountains and lakes",
+ name="text_gen_path",
+ num_inference_steps=8,
+ output_format="path",
+ output_dir="outputs_test",
+ )
+
+
+if __name__ == "__main__":
+ test_single()
+ test_multi()
+ test_base64()
+ test_text()
+ test_text_ulysses_bad_resolution_regression()
+ test_text_path_output()
diff --git a/tests/serving/test_flux2_turbo_lora_serving.py b/tests/serving/test_flux2_turbo_lora_serving.py
new file mode 100644
index 000000000..49e939c16
--- /dev/null
+++ b/tests/serving/test_flux2_turbo_lora_serving.py
@@ -0,0 +1,104 @@
+"""Test FLUX.2 Turbo LoRA model serving.
+
+Server setup:
+ CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path black-forest-labs/FLUX.2-dev \
+ --lora-path fal/FLUX.2-dev-Turbo \
+ --lora-name flux.2-turbo-lora.safetensors \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --quantize-type float8_wo \
+ --attn _flash_3 \
+ --cache \
+ --compile \
+ --ulysses-anything
+
+This test calls /generate with a custom sigma schedule (TURBO_SIGMAS) for 8-step turbo inference.
+
+Reference LoRA: https://huggingface.co/fal/FLUX.2-dev-Turbo
+Base model: https://huggingface.co/black-forest-labs/FLUX.2-dev
+"""
+
+import os
+import requests
+import base64
+from PIL import Image
+from io import BytesIO
+
+
+# Pre-shifted custom sigmas for 8-step turbo inference
+TURBO_SIGMAS = [1.0, 0.6509, 0.4374, 0.2932, 0.1893, 0.1108, 0.0495, 0.00031]
+
+
+def call_api(prompt, name="flux2_turbo", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "width": kwargs.get("width", 1024),
+ "height": kwargs.get("height", 1024),
+ "num_inference_steps": kwargs.get("num_inference_steps", 8),
+ "guidance_scale": kwargs.get("guidance_scale", 2.5),
+ "sigmas": kwargs.get("sigmas", TURBO_SIGMAS),
+ "seed": kwargs.get("seed", 0),
+ "num_images": kwargs.get("num_images", 1),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ response = requests.post(url, json=payload, timeout=600)
+ response.raise_for_status()
+ result = response.json()
+
+ assert "images" in result and result["images"], "No images in response"
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["images"][0]
+ assert os.path.exists(filename)
+ img = Image.open(filename)
+        print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+
+ img_data = base64.b64decode(result["images"][0])
+ img = Image.open(BytesIO(img_data))
+
+ filename = f"{name}.png"
+ img.save(filename)
+    print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+
+
+def test_flux2_turbo_lora():
+ prompt = (
+ "Industrial product shot of a chrome turbocharger with glowing hot exhaust manifold, "
+ "engraved text 'FLUX.2 [dev] Turbo by fal' on the compressor housing and 'fal' on the turbine wheel, "
+ "gradient heat glow from orange to electric blue , studio lighting with dramatic shadows, "
+ "shallow depth of field, engineering blueprint pattern in background."
+ )
+
+ return call_api(
+ prompt=prompt,
+ name="flux2_turbo_lora",
+ num_inference_steps=8,
+ guidance_scale=2.5,
+ sigmas=TURBO_SIGMAS,
+ width=1024,
+ height=1024,
+ seed=0,
+ )
+
+
+if __name__ == "__main__":
+ print("=" * 80)
+ print("Testing FLUX.2 Turbo LoRA Model Serving")
+ print("=" * 80)
+ test_flux2_turbo_lora()
+ print("=" * 80)
+ print("Done")
+ print("=" * 80)
diff --git a/tests/serving/test_inference_timestamps.py b/tests/serving/test_inference_timestamps.py
new file mode 100644
index 000000000..cdfca5499
--- /dev/null
+++ b/tests/serving/test_inference_timestamps.py
@@ -0,0 +1,44 @@
+import os
+import requests
+
+
+def _call_generate_api(**overrides):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": overrides.get("prompt", "timestamp test prompt"),
+ "width": overrides.get("width", 1024),
+ "height": overrides.get("height", 1024),
+ "num_inference_steps": overrides.get("num_inference_steps", 8),
+ "guidance_scale": overrides.get("guidance_scale", 1.0),
+ "seed": overrides.get("seed", 0),
+ "output_format": "path",
+ }
+
+ response = requests.post(url, json=payload, timeout=600)
+ response.raise_for_status()
+ return response.json()
+
+
+def test_generate_returns_inference_timestamps():
+ data = _call_generate_api()
+
+ assert "inference_start_time" in data
+ assert "inference_end_time" in data
+ assert "time_cost" in data
+
+ print("data: ", data)
+
+ start = data["inference_start_time"]
+ end = data["inference_end_time"]
+ time_cost = data["time_cost"]
+
+ print("start", start)
+ print("end", end)
+ print("time_cost", time_cost)
+
+
+if __name__ == "__main__":
+ test_generate_returns_inference_timestamps()
diff --git a/tests/serving/test_ltx2_image2video.py b/tests/serving/test_ltx2_image2video.py
new file mode 100644
index 000000000..7a265075e
--- /dev/null
+++ b/tests/serving/test_ltx2_image2video.py
@@ -0,0 +1,169 @@
+"""Test LTX-2 Image-to-Video model serving.
+
+Server setup (base model):
+ CACHE_DIT_LTX2_PIPELINE=i2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --parallel-vae \
+ --cache \
+ --ulysses-anything
+
+Server setup (base model, TP4):
+ CACHE_DIT_LTX2_PIPELINE=i2v CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --parallel-type tp \
+ --cache
+
+Server setup (base + LoRA):
+ # NOTE: the LoRA weight filename may differ. Common filenames include:
+ # - pytorch_lora_weights.safetensors
+ # - adapter_model.safetensors
+ CACHE_DIT_LTX2_PIPELINE=i2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --lora-path Lightricks/LTX-2-19b-IC-LoRA-Canny-Control \
+ --lora-name ltx-2-19b-ic-lora-canny-control.safetensors \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --parallel-vae \
+ --cache \
+ --ulysses-anything
+
+Server setup (base + LoRA, TP4):
+ # NOTE: the LoRA weight filename may differ. Common filenames include:
+ # - pytorch_lora_weights.safetensors
+ # - adapter_model.safetensors
+ CACHE_DIT_LTX2_PIPELINE=i2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --lora-path Lightricks/LTX-2-19b-IC-LoRA-Canny-Control \
+ --lora-name ltx-2-19b-ic-lora-canny-control.safetensors \
+ --parallel-type tp \
+ --cache
+
+Usage:
+ # Base model
+ CACHE_DIT_LTX2_MODE=base python -m pytest -q cache-dit/tests/serving/test_ltx2_image2video.py
+
+ # LoRA model
+ CACHE_DIT_LTX2_MODE=lora python -m pytest -q cache-dit/tests/serving/test_ltx2_image2video.py
+"""
+
+import os
+import base64
+import requests
+
+
+def call_api(prompt, image_url, name="ltx2_i2v", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "negative_prompt": kwargs.get("negative_prompt", ""),
+ "image_urls": [image_url],
+ "width": kwargs.get("width", 768),
+ "height": kwargs.get("height", 512),
+ "num_inference_steps": kwargs.get("num_inference_steps", 40),
+ "guidance_scale": kwargs.get("guidance_scale", 4.0),
+ "seed": kwargs.get("seed", 1234),
+ "num_frames": kwargs.get("num_frames", 121),
+ "fps": kwargs.get("fps", 24),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ response = requests.post(url, json=payload, timeout=1800)
+ response.raise_for_status()
+ result = response.json()
+
+ assert (
+ "video" in result and result["video"] is not None
+ ), f"No video in response: keys={list(result.keys())}"
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["video"]
+ assert os.path.exists(filename)
+        print(f"Saved: {filename}")
+ return filename
+
+ video_data = base64.b64decode(result["video"])
+ filename = f"{name}.mp4"
+ with open(filename, "wb") as f:
+ f.write(video_data)
+    print(f"Saved: {filename}")
+ return filename
+
+
+def test_ltx2_image2video():
+ # Align with upstream diffusers LTX2 example.
+ image_url = (
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/"
+ "diffusers/astronaut.jpg"
+ )
+
+ mode = os.environ.get("CACHE_DIT_LTX2_MODE", "base").lower()
+ if mode not in ("base", "lora"):
+ raise ValueError("CACHE_DIT_LTX2_MODE must be 'base' or 'lora'")
+
+ if mode == "base":
+ prompt = (
+ "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling "
+ "apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, "
+ "floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, "
+ "weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, "
+ "the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast "
+ "depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the "
+ "foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate "
+ "low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
+ )
+ negative_prompt = (
+ "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, "
+ "motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
+ )
+ return call_api(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image_url=image_url,
+ name="ltx2_base_i2v",
+ seed=0,
+ width=768,
+ height=512,
+ num_frames=121,
+ fps=24,
+ num_inference_steps=40,
+ guidance_scale=4.0,
+ )
+
+ # LoRA Canny Control variant. The API schema doesn't expose an explicit canny control image field,
+ # so we still provide `image_urls` as the conditioning image and a canny-oriented prompt.
+ prompt = (
+ "Canny edge control style: keep structure of the input image, "
+ "turn it into a cinematic animated sequence with strong edges and clean contours."
+ )
+ negative_prompt = (
+ "worst quality, low quality, jpeg artifacts, blurry, deformed, disfigured, "
+ "extra limbs, extra fingers, text, watermark"
+ )
+ return call_api(
+ prompt=prompt,
+ image_url=image_url,
+ name="ltx2_lora_canny_i2v",
+ negative_prompt=negative_prompt,
+ seed=123,
+ fps=24,
+ )
+
+
+if __name__ == "__main__":
+ print("Testing LTX-2 Image-to-Video Serving...")
+ print(f"CACHE_DIT_LTX2_MODE={os.environ.get('CACHE_DIT_LTX2_MODE', 'base')}")
+ test_ltx2_image2video()
+ print("Done.")
diff --git a/tests/serving/test_ltx2_text2video.py b/tests/serving/test_ltx2_text2video.py
new file mode 100644
index 000000000..12726bb4d
--- /dev/null
+++ b/tests/serving/test_ltx2_text2video.py
@@ -0,0 +1,156 @@
+"""Test LTX-2 Text-to-Video model serving.
+
+Server setup (base model):
+ CACHE_DIT_LTX2_PIPELINE=t2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --parallel-vae \
+ --cache \
+ --ulysses-anything
+
+Server setup (base model, TP4):
+ CACHE_DIT_LTX2_PIPELINE=t2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --parallel-type tp \
+ --cache
+
+Server setup (base + LoRA):
+ # NOTE: the LoRA weight filename may differ. Common filenames include:
+ # - pytorch_lora_weights.safetensors
+ # - adapter_model.safetensors
+ CACHE_DIT_LTX2_PIPELINE=t2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --lora-path Lightricks/LTX-2-19b-IC-LoRA-Canny-Control \
+ --lora-name ltx-2-19b-ic-lora-canny-control.safetensors \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --parallel-vae \
+ --cache \
+ --ulysses-anything
+
+Server setup (base + LoRA, TP4):
+ # NOTE: the LoRA weight filename may differ. Common filenames include:
+ # - pytorch_lora_weights.safetensors
+ # - adapter_model.safetensors
+ CACHE_DIT_LTX2_PIPELINE=t2v CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 \
+ -m cache_dit.serve.serve \
+ --model-path Lightricks/LTX-2 \
+ --lora-path Lightricks/LTX-2-19b-IC-LoRA-Canny-Control \
+ --lora-name ltx-2-19b-ic-lora-canny-control.safetensors \
+ --parallel-type tp \
+ --cache
+
+Usage:
+ # Base model
+ CACHE_DIT_LTX2_MODE=base python -m pytest -q cache-dit/tests/serving/test_ltx2_text2video.py
+
+ # LoRA model
+ CACHE_DIT_LTX2_MODE=lora python -m pytest -q cache-dit/tests/serving/test_ltx2_text2video.py
+"""
+
+import os
+import base64
+import requests
+
+
+def call_api(prompt, name="ltx2_t2v", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "negative_prompt": kwargs.get("negative_prompt", ""),
+ "width": kwargs.get("width", 768),
+ "height": kwargs.get("height", 512),
+ "num_inference_steps": kwargs.get("num_inference_steps", 40),
+ "guidance_scale": kwargs.get("guidance_scale", 4.0),
+ "seed": kwargs.get("seed", 1234),
+ "num_frames": kwargs.get("num_frames", 121),
+ "fps": kwargs.get("fps", 24),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ response = requests.post(url, json=payload, timeout=1800)
+ response.raise_for_status()
+ result = response.json()
+
+ assert (
+ "video" in result and result["video"] is not None
+ ), f"No video in response: keys={list(result.keys())}"
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["video"]
+ assert os.path.exists(filename)
+        print(f"Saved: {filename}")
+ return filename
+
+ video_data = base64.b64decode(result["video"])
+ filename = f"{name}.mp4"
+ with open(filename, "wb") as f:
+ f.write(video_data)
+    print(f"Saved: {filename}")
+ return filename
+
+
+def test_ltx2_text2video():
+ mode = os.environ.get("CACHE_DIT_LTX2_MODE", "base").lower()
+ if mode not in ("base", "lora"):
+ raise ValueError("CACHE_DIT_LTX2_MODE must be 'base' or 'lora'")
+
+ if mode == "base":
+ prompt = (
+ "A cinematic tracking shot through a neon-lit rainy cyberpunk street at night. "
+ "Reflections shimmer on wet asphalt, holographic signs flicker, and steam rises from vents. "
+ "A sleek motorbike glides past the camera in slow motion, droplets scattering in the air. "
+ "Smooth camera motion, natural parallax, ultra-realistic detail, cinematic lighting, film look."
+ )
+ negative_prompt = (
+ "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, "
+ "motion artifacts, bad anatomy, ugly, transition, static, text, watermark."
+ )
+ return call_api(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ name="ltx2_base_t2v",
+ seed=0,
+ width=768,
+ height=512,
+ num_frames=121,
+ fps=24,
+ num_inference_steps=40,
+ guidance_scale=4.0,
+ )
+
+ # LoRA mode: for validation that serving + LoRA wiring works end-to-end.
+ # The same text-only request schema applies.
+ prompt = (
+ "Cinematic animated sequence, strong edges and clean contours, high contrast lighting, "
+ "a robot walks through a foggy corridor, smooth camera dolly in."
+ )
+ negative_prompt = (
+ "worst quality, low quality, jpeg artifacts, blurry, deformed, disfigured, "
+ "extra limbs, extra fingers, text, watermark"
+ )
+ return call_api(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ name="ltx2_lora_t2v",
+ seed=123,
+ fps=24,
+ )
+
+
+if __name__ == "__main__":
+ print("Testing LTX-2 Text-to-Video Serving...")
+ print(f"CACHE_DIT_LTX2_MODE={os.environ.get('CACHE_DIT_LTX2_MODE', 'base')}")
+ test_ltx2_text2video()
+ print("Done.")
diff --git a/tests/serving/test_qwen_image_edit_lora.py b/tests/serving/test_qwen_image_edit_lora.py
new file mode 100644
index 000000000..42953599e
--- /dev/null
+++ b/tests/serving/test_qwen_image_edit_lora.py
@@ -0,0 +1,110 @@
+"""Qwen-Image-Edit + LoRA serving test.
+
+Server (single GPU):
+ CUDA_VISIBLE_DEVICES=0 python -m cache_dit.serve.serve \
+ --model-path Qwen/Qwen-Image-Edit-2511 \
+ --lora-path /home/lmsys/bbuf/qwen-image-lora3 \
+ --lora-name lora3-diffusers.safetensors \
+ --cache
+
+Server (2 GPUs, ulysses2):
+ CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 \
+ -m cache_dit.serve.serve \
+ --model-path Qwen/Qwen-Image-Edit-2511 \
+ --lora-path /home/lmsys/bbuf/qwen-image-lora3 \
+ --lora-name lora3-diffusers.safetensors \
+ --parallel-type ulysses \
+ --parallel-text-encoder \
+ --cache \
+ --ulysses-anything
+
+Run client test:
+ CACHE_DIT_HOST=localhost CACHE_DIT_PORT=8000 python -m pytest -q \
+ cache-dit/tests/serving/test_qwen_image_edit_lora.py
+"""
+
+import base64
+import os
+from io import BytesIO
+
+import requests
+from PIL import Image
+
+
+def call_api(prompt, image_paths, name="qwen_image_edit_lora", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "width": kwargs["width"],
+ "height": kwargs["height"],
+ "num_inference_steps": kwargs.get("num_inference_steps", 30),
+ "guidance_scale": kwargs.get("guidance_scale", 4.0),
+ "seed": kwargs.get("seed", 1),
+ "num_images": kwargs.get("num_images", 1),
+ "image_urls": image_paths,
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ response = requests.post(url, json=payload, timeout=1800)
+ response.raise_for_status()
+ result = response.json()
+
+ assert (
+ "images" in result and result["images"]
+ ), f"No images in response: keys={list(result.keys())}"
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["images"][0]
+ assert os.path.exists(filename)
+ img = Image.open(filename)
+        print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+
+ img_data = base64.b64decode(result["images"][0])
+ img = Image.open(BytesIO(img_data)).convert("RGB")
+ filename = f"{name}.png"
+ img.save(filename)
+    print(f"Saved: {filename} ({img.size[0]}x{img.size[1]})")
+ return filename
+
+
+def test_qwen_image_edit_lora():
+ images_dir = os.path.join(os.path.dirname(__file__), "images")
+ image_path_0 = os.path.join(images_dir, "input_0.png")
+ image_path_1 = os.path.join(images_dir, "input_1.png")
+
+ img0 = Image.open(image_path_0).convert("RGB")
+ width, height = img0.size
+
+ prompt = (
+ "The first image is the original image, and the second image is erasing the corresponding area, "
+ "Please use the following instructions to perform image repair. "
+ "Using the provided image, first erase all texts (no language restrictions, titles, slogan, date, number, "
+ "logo text) from the image, erase all texts from the image, and finally, keep everything else unchanged."
+ )
+
+ filename = call_api(
+ prompt=prompt,
+ image_paths=[image_path_0, image_path_1],
+ name="qwen_image_edit_lora",
+ seed=1,
+ num_inference_steps=30,
+ guidance_scale=4.0,
+ width=width,
+ height=height,
+ )
+
+ out_img = Image.open(filename)
+ assert out_img.size == (width, height)
+ return filename
+
+
+if __name__ == "__main__":
+ test_qwen_image_edit_lora()
diff --git a/tests/serving/test_qwen_image_lightning_serving.py b/tests/serving/test_qwen_image_lightning_serving.py
new file mode 100644
index 000000000..3fe8d3178
--- /dev/null
+++ b/tests/serving/test_qwen_image_lightning_serving.py
@@ -0,0 +1,182 @@
+"""Test Qwen-Image-Lightning LoRA model serving.
+
+This test demonstrates how to use cache-dit serving with LoRA models.
+Qwen-Image-Lightning is a distilled model that generates high-quality images in 4 or 8 steps.
+
+Server setup:
+ CUDA_VISIBLE_DEVICES=7 torchrun --nproc_per_node=1 -m cache_dit.serve.serve \
+ --model-path Qwen/Qwen-Image \
+ --lora-path lightx2v/Qwen-Image-Lightning \
+ --lora-name Qwen-Image-Lightning-8steps-V1.1-bf16.safetensors \
+ --cache
+
+For 4-step model:
+ CUDA_VISIBLE_DEVICES=7 torchrun --nproc_per_node=1 -m cache_dit.serve.serve \
+ --model-path Qwen/Qwen-Image \
+ --lora-path lightx2v/Qwen-Image-Lightning \
+ --lora-name Qwen-Image-Lightning-4steps-V1.1-bf16.safetensors \
+ --cache
+
+Reference: https://huggingface.co/lightx2v/Qwen-Image-Lightning
+"""
+
+import os
+import requests
+import base64
+from PIL import Image
+from io import BytesIO
+
+
+def call_api(prompt, name="test", **kwargs):
+ """Call the serving API to generate an image."""
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "width": kwargs.get("width", 1024),
+ "height": kwargs.get("height", 1024),
+ "num_inference_steps": kwargs.get("num_inference_steps", 8),
+ "guidance_scale": kwargs.get("guidance_scale", 1.0),
+ "seed": kwargs.get("seed", 0),
+ "num_images": kwargs.get("num_images", 1),
+ }
+
+ if "include_stats" in kwargs:
+ payload["include_stats"] = kwargs["include_stats"]
+
+ if "negative_prompt" in kwargs:
+ payload["negative_prompt"] = kwargs["negative_prompt"]
+
+ response = requests.post(url, json=payload, timeout=300)
+ response.raise_for_status()
+ result = response.json()
+
+ if "images" not in result or not result["images"]:
+ print("No images in response")
+ return None
+
+ # Save all generated images
+ filenames = []
+ for idx, img_base64 in enumerate(result["images"]):
+ img_data = base64.b64decode(img_base64)
+ image = Image.open(BytesIO(img_data))
+
+ if kwargs.get("num_images", 1) > 1:
+ filename = f"{name}_{idx}.png"
+ else:
+ filename = f"{name}.png"
+
+ image.save(filename)
+        print(f"Saved: {filename} ({image.size})")
+ filenames.append(filename)
+
+ if "stats" in result and result["stats"]:
+ print(f"Stats: {result['stats']}")
+ if "time_cost" in result:
+ print(f"Time cost: {result['time_cost']:.2f}s")
+
+ return filenames
+
+
+def test_basic_8steps():
+ """Test basic image generation with 8 steps (Lightning-8steps model)."""
+ return call_api(
+ prompt="A beautiful landscape with mountains and a lake, high quality",
+ name="qwen_lightning_8steps",
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ seed=0,
+ )
+
+
+def test_include_stats():
+ """Test include_stats parameter returns stats."""
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+
+ model_info_resp = requests.get(f"http://{host}:{port}/get_model_info", timeout=30)
+ model_info_resp.raise_for_status()
+ enable_cache = bool(model_info_resp.json().get("enable_cache", False))
+
+ result = requests.post(
+ f"http://{host}:{port}/generate",
+ json={
+ "prompt": "A cute puppy playing in the garden",
+ "width": 1024,
+ "height": 1024,
+ "num_inference_steps": 8,
+ "guidance_scale": 1.0,
+ "seed": 456,
+ "num_images": 1,
+ "include_stats": True,
+ },
+ timeout=300,
+ )
+ result.raise_for_status()
+ data = result.json()
+ if enable_cache:
+ assert "stats" in data
+
+
+def test_basic_4steps():
+ """Test basic image generation with 4 steps (Lightning-4steps model)."""
+ return call_api(
+ prompt="A beautiful landscape with mountains and a lake, high quality",
+ name="qwen_lightning_4steps",
+ num_inference_steps=4,
+ guidance_scale=1.0,
+ seed=0,
+ )
+
+
+def test_different_resolution():
+ """Test different image resolution (1536x1024)."""
+ return call_api(
+ prompt="A wide panoramic view of a mountain range at dawn",
+ name="qwen_lightning_landscape",
+ width=1536,
+ height=1024,
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ seed=123,
+ )
+
+
+def test_batch_generation():
+ """Test generating multiple images in one request."""
+ return call_api(
+ prompt="A cute puppy playing in the garden",
+ name="qwen_lightning_batch",
+ num_inference_steps=8,
+ guidance_scale=1.0,
+ num_images=4,
+ seed=456,
+ )
+
+
+if __name__ == "__main__":
+ print("=" * 80)
+ print("Testing Qwen-Image-Lightning LoRA Model Serving")
+ print("=" * 80)
+
+ # Run tests
+ print("\n[1/4] Testing basic 8-step generation...")
+ test_basic_8steps()
+
+ print("\n[1.5/4] Testing include_stats...")
+ test_include_stats()
+
+ print("\n[2/4] Testing basic 4-step generation...")
+ test_basic_4steps()
+
+ print("\n[3/4] Testing different resolution (1536x1024)...")
+ test_different_resolution()
+
+ print("\n[4/4] Testing batch generation (4 images)...")
+ test_batch_generation()
+
+ print("\n" + "=" * 80)
+ print("All tests completed!")
+ print("=" * 80)
diff --git a/tests/serving/test_text_to_image_serving.py b/tests/serving/test_text_to_image_serving.py
new file mode 100644
index 000000000..ff108c904
--- /dev/null
+++ b/tests/serving/test_text_to_image_serving.py
@@ -0,0 +1,25 @@
+import requests
+import base64
+from PIL import Image
+from io import BytesIO
+
+
+def test_text_to_image():
+ response = requests.post(
+ "http://localhost:8000/generate",
+ json={
+ "prompt": "A beautiful sunset over the ocean",
+ "width": 1024,
+ "height": 1024,
+ "num_inference_steps": 50,
+ },
+ )
+
+ img_data = base64.b64decode(response.json()["images"][0])
+ Image.open(BytesIO(img_data)).save("output.png")
+ print("Saved: output.png")
+
+
+if __name__ == "__main__":
+ print("Testing Text-to-Image Serving API...")
+ test_text_to_image()
diff --git a/tests/serving/test_wan_i2v_serving.py b/tests/serving/test_wan_i2v_serving.py
new file mode 100644
index 000000000..0d883aa55
--- /dev/null
+++ b/tests/serving/test_wan_i2v_serving.py
@@ -0,0 +1,142 @@
+import os
+import requests
+import base64
+from PIL import Image
+from io import BytesIO
+
+
+def call_api(prompt, image_url, name="test", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "image_urls": [image_url],
+ "width": kwargs.get("width", 832),
+ "height": kwargs.get("height", 480),
+ "num_inference_steps": kwargs.get("num_inference_steps", 50),
+ "guidance_scale": kwargs.get("guidance_scale", 3.5),
+ "seed": kwargs.get("seed", 1234),
+ "num_frames": kwargs.get("num_frames", 49),
+ "fps": kwargs.get("fps", 16),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ if "negative_prompt" in kwargs:
+ payload["negative_prompt"] = kwargs["negative_prompt"]
+
+ try:
+ response = requests.post(url, json=payload, timeout=600)
+ response.raise_for_status()
+ result = response.json()
+
+ if "video" not in result or result["video"] is None:
+ return None
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["video"]
+ assert os.path.exists(filename)
+ print(f"Saved: {filename}")
+ return filename
+ else:
+ video_data = base64.b64decode(result["video"])
+ filename = f"{name}.mp4"
+
+ with open(filename, "wb") as f:
+ f.write(video_data)
+
+ print(f"Saved: {filename}")
+ return filename
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return None
+
+
+def test_basic():
+ image_url = "https://www.modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B/resolve/master/examples/i2v_input.JPG"
+ prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
+
+ return call_api(prompt=prompt, image_url=image_url, name="wan_i2v_basic")
+
+
+def test_with_negative_prompt():
+ image_url = "https://www.modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B/resolve/master/examples/i2v_input.JPG"
+ prompt = "A white cat on a surfboard at the beach, enjoying the summer vacation"
+ negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
+
+ return call_api(
+ prompt=prompt,
+ image_url=image_url,
+ negative_prompt=negative_prompt,
+ name="wan_i2v_negative",
+ seed=42,
+ )
+
+
+def test_short_video():
+ image_url = "https://www.modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B/resolve/master/examples/i2v_input.JPG"
+ prompt = "A cat on a surfboard, gentle waves in the background"
+
+ return call_api(
+ prompt=prompt,
+ image_url=image_url,
+ name="wan_i2v_short",
+ num_frames=25,
+ num_inference_steps=30,
+ seed=777,
+ )
+
+
+def test_with_base64_image():
+ image_url = "https://www.modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B/resolve/master/examples/i2v_input.JPG"
+ response = requests.get(image_url)
+ img = Image.open(BytesIO(response.content))
+
+ buffered = BytesIO()
+ img.save(buffered, format="JPEG")
+ img_base64 = base64.b64encode(buffered.getvalue()).decode()
+
+ data_uri = f"data:image/jpeg;base64,{img_base64}"
+
+ prompt = "A cat enjoying the beach vacation"
+
+ return call_api(prompt=prompt, image_url=data_uri, name="wan_i2v_base64", seed=555)
+
+
+def test_path_output():
+ image_url = "https://www.modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B/resolve/master/examples/i2v_input.JPG"
+ prompt = "A cat on a surfboard at the beach, enjoying the summer vacation"
+ return call_api(
+ prompt=prompt,
+ image_url=image_url,
+ name="wan_i2v_path",
+ output_format="path",
+ output_dir="outputs_test",
+ seed=123,
+ )
+
+
+if __name__ == "__main__":
+ print("Testing Wan Image-to-Video Serving...")
+ print("\n1. Basic test:")
+ test_basic()
+
+ print("\n2. With negative prompt:")
+ test_with_negative_prompt()
+
+ print("\n3. Short video:")
+ test_short_video()
+
+ print("\n4. Base64 encoded image:")
+ test_with_base64_image()
+
+ print("\n5. Path output:")
+ test_path_output()
+
+ print("\nAll tests completed!")
diff --git a/tests/serving/test_wan_t2v_serving.py b/tests/serving/test_wan_t2v_serving.py
new file mode 100644
index 000000000..15fda77b3
--- /dev/null
+++ b/tests/serving/test_wan_t2v_serving.py
@@ -0,0 +1,123 @@
+import os
+import requests
+import base64
+
+
+def call_api(prompt, name="test", **kwargs):
+ host = os.environ.get("CACHE_DIT_HOST", "localhost")
+ port = int(os.environ.get("CACHE_DIT_PORT", 8000))
+ url = f"http://{host}:{port}/generate"
+
+ payload = {
+ "prompt": prompt,
+ "width": kwargs.get("width", 832),
+ "height": kwargs.get("height", 480),
+ "num_inference_steps": kwargs.get("num_inference_steps", 30),
+ "guidance_scale": kwargs.get("guidance_scale", 5.0),
+ "seed": kwargs.get("seed", 1234),
+ "num_frames": kwargs.get("num_frames", 49),
+ "fps": kwargs.get("fps", 16),
+ }
+
+ if "output_format" in kwargs:
+ payload["output_format"] = kwargs["output_format"]
+ if "output_dir" in kwargs:
+ payload["output_dir"] = kwargs["output_dir"]
+
+ if "negative_prompt" in kwargs:
+ payload["negative_prompt"] = kwargs["negative_prompt"]
+
+ try:
+ response = requests.post(url, json=payload, timeout=600)
+ response.raise_for_status()
+ result = response.json()
+
+ if "video" not in result or result["video"] is None:
+ return None
+
+ if payload.get("output_format", "base64") == "path":
+ filename = result["video"]
+ assert os.path.exists(filename)
+ print(f"Saved: {filename}")
+ return filename
+ else:
+ video_data = base64.b64decode(result["video"])
+ filename = f"{name}.mp4"
+
+ with open(filename, "wb") as f:
+ f.write(video_data)
+
+ print(f"Saved: {filename}")
+ return filename
+
+ except Exception as e:
+ print(f"Error: {e}")
+ return None
+
+
+def test_basic():
+ return call_api(prompt="A cat walks on the grass, realistic", name="wan_t2v_basic")
+
+
+def test_custom_prompt():
+ return call_api(
+ prompt="A beautiful sunset over the ocean with waves crashing on the shore",
+ name="wan_t2v_sunset",
+ seed=0,
+ )
+
+
+def test_path_output():
+ return call_api(
+ prompt="A cat walks on the grass, realistic",
+ name="wan_t2v_path",
+ output_format="path",
+ output_dir="outputs_test",
+ )
+
+
+def test_with_negative_prompt():
+ negative_prompt = (
+ "Bright tones, overexposed, static, blurred details, subtitles, "
+ "style, works, paintings, images, static, overall gray, worst quality, "
+ "low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
+ "poorly drawn hands, poorly drawn faces, deformed, disfigured, "
+ "misshapen limbs, fused fingers, still picture, messy background, "
+ "three legs, many people in the background, walking backwards"
+ )
+
+ return call_api(
+ prompt="A dog running in a park, high quality, realistic",
+ negative_prompt=negative_prompt,
+ name="wan_t2v_negative",
+ seed=999,
+ )
+
+
+def test_short_video():
+ return call_api(
+ prompt="A bird flying in the sky",
+ name="wan_t2v_short",
+ num_frames=25,
+ num_inference_steps=20,
+ seed=777,
+ )
+
+
+def test_different_resolution():
+ return call_api(
+ prompt="A car driving on a highway",
+ name="wan_t2v_resolution",
+ width=1024,
+ height=576,
+ seed=555,
+ )
+
+
+if __name__ == "__main__":
+ test_basic()
+ test_custom_prompt()
+ test_path_output()
+ test_with_negative_prompt()
+ test_short_video()
+ test_different_resolution()
diff --git a/tests/test_cache_loader.py b/tests/test_cache_loader.py
deleted file mode 100644
index b0bf2eda8..000000000
--- a/tests/test_cache_loader.py
+++ /dev/null
@@ -1,7 +0,0 @@
-import cache_dit
-
-cache_options = cache_dit.load_options(
- "cache_config.yaml",
-)
-
-print(f"cache_options from cache_config.yaml:\n {cache_options}")
diff --git a/tests/test_chrono_edit_cp_native.py b/tests/test_chrono_edit_cp_native.py
deleted file mode 100644
index cf14336aa..000000000
--- a/tests/test_chrono_edit_cp_native.py
+++ /dev/null
@@ -1,110 +0,0 @@
-import os
-import time
-import torch
-import numpy as np
-from PIL import Image
-import torch.distributed as dist
-from diffusers import (
- AutoencoderKLWan,
- ChronoEditTransformer3DModel,
- ChronoEditPipeline,
-)
-from diffusers.quantizers import PipelineQuantizationConfig
-from diffusers import ContextParallelConfig
-from diffusers.utils import load_image
-from transformers import CLIPVisionModel
-
-
-dist.init_process_group(backend="nccl")
-rank = dist.get_rank()
-device = torch.device("cuda", rank % torch.cuda.device_count())
-world_size = dist.get_world_size()
-torch.cuda.set_device(device)
-
-model_id = "nvidia/ChronoEdit-14B-Diffusers"
-model_id = os.environ.get("CHRONO_EDIT_DIR", model_id)
-
-image_encoder = CLIPVisionModel.from_pretrained(
- model_id, subfolder="image_encoder", torch_dtype=torch.float32
-)
-vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
-transformer = ChronoEditTransformer3DModel.from_pretrained(
- model_id, subfolder="transformer", torch_dtype=torch.bfloat16
-)
-
-pipe = ChronoEditPipeline.from_pretrained(
- model_id,
- vae=vae,
- image_encoder=image_encoder,
- transformer=transformer,
- torch_dtype=torch.bfloat16,
- quantization_config=(
- PipelineQuantizationConfig(
- quant_backend="bitsandbytes_4bit",
- quant_kwargs={
- "load_in_4bit": True,
- "bnb_4bit_quant_type": "nf4",
- "bnb_4bit_compute_dtype": torch.bfloat16,
- },
- # text_encoder: ~ 6GiB, transformer: ~ 8GiB, total: ~14GiB
- components_to_quantize=["text_encoder", "transformer"],
- )
- ),
-).to(device)
-
-torch.cuda.empty_cache()
-assert isinstance(pipe.vae, AutoencoderKLWan)
-pipe.vae.enable_tiling()
-
-image = load_image("../examples/data/chrono_edit_example.png")
-
-max_area = 720 * 1280
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-
-prompt = (
- "The user wants to transform the image by adding a small, cute mouse sitting inside the floral teacup, enjoying a spa bath. The mouse should appear relaxed and cheerful, with a tiny white bath towel draped over its head like a turban. It should be positioned comfortably in the cup’s liquid, with gentle steam rising around it to blend with the cozy atmosphere. "
- "The mouse’s pose should be natural—perhaps sitting upright with paws resting lightly on the rim or submerged in the tea. The teacup’s floral design, gold trim, and warm lighting must remain unchanged to preserve the original aesthetic. The steam should softly swirl around the mouse, enhancing the spa-like, whimsical mood."
-)
-
-assert isinstance(pipe.transformer, ChronoEditTransformer3DModel)
-pipe.transformer.set_attention_backend("native")
-if world_size > 1:
- pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=world_size))
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe(warmup: bool = False):
- output = pipe(
- image=image,
- prompt=prompt,
- height=height,
- width=width,
- num_frames=5,
- guidance_scale=5.0,
- enable_temporal_reasoning=False,
- num_temporal_reasoning_steps=0,
- num_inference_steps=50 if not warmup else 2,
- generator=torch.Generator("cuda").manual_seed(0),
- ).frames[0]
- output = Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8"))
- return output
-
-
-start = time.time()
-output = run_pipe()
-end = time.time()
-
-if rank == 0:
- time_cost = end - start
- save_path = f"chrono-edit.{world_size}gpus.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- output.save(save_path)
-
-if dist.is_initialized():
- dist.destroy_process_group()
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
deleted file mode 100644
index e603f68f2..000000000
--- a/tests/test_metrics.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import os
-import argparse
-from cache_dit.metrics import compute_psnr
-from cache_dit.metrics import compute_video_psnr
-from cache_dit.metrics import compute_fid # FID
-
-
-def get_args():
- parser = argparse.ArgumentParser(
- description="CacheDiT's Metrics CLI",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "--img-true",
- type=str,
- default=None,
- help="Path to ground truth image",
- )
- parser.add_argument(
- "--img-test",
- type=str,
- default=None,
- help="Path to predicted image",
- )
- parser.add_argument(
- "--video-true",
- type=str,
- default=None,
- help="Path to ground truth video",
- )
- parser.add_argument(
- "--video-test",
- type=str,
- default=None,
- help="Path to predicted video",
- )
- parser.add_argument(
- "--compute-fid",
- "--fid",
- action="store_true",
- default=False,
- help="Compute FID for image",
- )
-
- return parser.parse_args()
-
-
-def main():
- args = get_args()
- print(args)
-
- if args.img_true is not None and args.img_test is not None:
- if any(
- (
- not os.path.exists(args.img_true),
- not os.path.exists(args.img_test),
- )
- ):
- return
- img_psnr, n = compute_psnr(args.img_true, args.img_test)
- print(f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}")
- if args.compute_fid:
- img_fid, n = compute_fid(args.img_true, args.img_test)
- print(f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}")
- if args.video_true is not None and args.video_test is not None:
- if any(
- (
- not os.path.exists(args.video_true),
- not os.path.exists(args.video_test),
- )
- ):
- return
- video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
- print(f"{args.video_true} vs {args.video_test}, Frames: {n}, PSNR: {video_psnr}")
-
-
-if __name__ == "__main__":
- main()
- # python3 test_metrics.py --img-true true.png --img-test test.png
diff --git a/tests/test_patch_functor.py b/tests/test_patch_functor.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/tests/test_qwen_image_cp_native.py b/tests/test_qwen_image_cp_native.py
deleted file mode 100644
index e87b035f8..000000000
--- a/tests/test_qwen_image_cp_native.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-import time
-import torch
-import torch.distributed as dist
-from diffusers import (
- QwenImagePipeline,
- QwenImageTransformer2DModel,
- ContextParallelConfig,
-)
-
-
-def maybe_init_distributed():
- if not dist.is_initialized():
- dist.init_process_group("nccl")
- rank = dist.get_rank()
- device = torch.device("cuda", rank % torch.cuda.device_count())
- torch.cuda.set_device(device)
- return rank, device
-
-
-def maybe_destroy_distributed():
- if dist.is_initialized():
- dist.destroy_process_group()
-
-
-rank, device = maybe_init_distributed()
-
-pipe = QwenImagePipeline.from_pretrained(
- os.environ.get(
- "QWEN_IMAGE_DIR",
- "Qwen/Qwen-Image",
- ),
- torch_dtype=torch.bfloat16,
-)
-
-# NOTE: Enable cpu offload before enabling parallelism will
-# raise shape error after first pipe call, so we enable it after.
-# It seems a bug of diffusers that cpu offload is not fully
-# compatible with context parallelism, visa versa.
-# pipe.enable_model_cpu_offload(device=device)
-
-assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-# pipe.transformer.set_attention_backend("flash")
-pipe.transformer.set_attention_backend("_native_cudnn")
-pipe.transformer.enable_parallelism(
- config=ContextParallelConfig(ulysses_degree=dist.get_world_size())
-)
-
-# NOTE: Enable cpu offload after enabling parallelism
-pipe.enable_model_cpu_offload(device=device)
-
-# assert isinstance(pipe.vae, AutoencoderKLQwenImage)
-# pipe.vae.enable_tiling()
-
-positive_magic = {
- "en": ", Ultra HD, 4K, cinematic composition.", # for english prompt
- "zh": ", 超清,4K,电影级构图.", # for chinese prompt
-}
-
-# Generate image
-prompt = """A coffee shop entrance features a chalkboard sign reading "Qwen Coffee 😊 $2 per cup," with a neon light beside it displaying "通义千问". Next to it hangs a poster showing a beautiful Chinese woman, and beneath the poster is written "π≈3.1415926-53589793-23846264-33832795-02384197". Ultra HD, 4K, cinematic composition"""
-
-# using an empty string if you do not have specific concept to remove
-negative_prompt = " "
-
-pipe.set_progress_bar_config(disable=rank != 0)
-
-
-def run_pipe():
- # do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
- image = pipe(
- prompt=prompt + positive_magic["en"],
- negative_prompt=negative_prompt,
- width=1024,
- height=1024,
- num_inference_steps=50,
- true_cfg_scale=4.0,
- generator=torch.Generator(device="cpu").manual_seed(42),
- ).images[0]
-
- return image
-
-
-# warmup
-_ = run_pipe() # always work
-
-start = time.time()
-image = run_pipe() # raise error here if cpu offload is enabled before parallelism
-end = time.time()
-
-
-if rank == 0:
- time_cost = end - start
- save_path = f"qwen-image.cp{dist.get_world_size()}.png"
- print(f"Time cost: {time_cost:.2f}s")
- print(f"Saving image to {save_path}")
- image.save(save_path)
-
-maybe_destroy_distributed()
diff --git a/tests/test_taylorseer.py b/tests/test_taylorseer.py
deleted file mode 100644
index 02eca9578..000000000
--- a/tests/test_taylorseer.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import argparse
-import numpy as np
-import matplotlib.pyplot as plt
-from cache_dit.caching.cache_contexts.calibrators import (
- TaylorSeerCalibrator,
-)
-
-
-def get_args():
- parser = argparse.ArgumentParser(description="Test TaylorSeer approximation.")
- parser.add_argument(
- "--n_derivatives",
- "--order",
- type=int,
- default=2,
- help="Number of derivatives to approximate.",
- )
- parser.add_argument(
- "--max_warmup_steps",
- "--warmup",
- type=int,
- default=2,
- help="Number of warmup steps before approximation starts.",
- )
- parser.add_argument(
- "--skip_interval_steps",
- type=int,
- default=2,
- help="Interval of steps to skip for approximation.",
- )
- return parser.parse_args()
-
-
-args = get_args()
-
-
-taylor_seer = TaylorSeerCalibrator(
- n_derivatives=args.n_derivatives,
- max_warmup_steps=args.max_warmup_steps,
- skip_interval_steps=args.skip_interval_steps,
-)
-
-x_values = np.arange(0, 10, 0.1)
-y_true = x_values**2
-
-y_pred = []
-errors = []
-for x in x_values:
- y = x**2
- y_approx = taylor_seer.step(y)
- y_pred.append(y_approx)
- errors.append(abs(y - y_approx))
-
-
-save_path = f"taylorseer_approximation_order_{args.n_derivatives}.png"
-plt.figure(figsize=(10, 5))
-plt.subplot(1, 2, 1)
-plt.plot(x_values, y_true, label="True $y=x^2$")
-plt.plot(
- x_values,
- y_pred,
- "--",
- label=f"TaylorSeer Approximation, Order={args.n_derivatives}",
-)
-plt.legend()
-plt.xlabel("x")
-plt.ylabel("y")
-plt.title("TaylorSeer Approximation Test")
-plt.grid()
-
-plt.subplot(1, 2, 2)
-plt.plot(x_values, errors, color="red", label="Absolute Error")
-plt.legend()
-plt.xlabel("x")
-plt.ylabel("Error")
-plt.title("Approximation Error")
-plt.grid()
-
-plt.tight_layout()
-plt.savefig(save_path)
-print(f"Test completed and saved as {save_path}.")
diff --git a/tests/test_wan_2.2_i2v.py b/tests/test_wan_2.2_i2v.py
deleted file mode 100644
index 855d2d832..000000000
--- a/tests/test_wan_2.2_i2v.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import os
-import torch
-import numpy as np
-from diffusers import WanImageToVideoPipeline
-from diffusers.utils import export_to_video, load_image
-
-model_id = model_id = os.environ.get(
- "WAN_2_2_I2V_DIR",
- "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
-)
-
-dtype = torch.bfloat16
-device = "cuda"
-
-pipe = WanImageToVideoPipeline.from_pretrained(
- model_id,
- torch_dtype=dtype,
- # device_map="balanced"
-)
-
-# issue: https://github.com/huggingface/diffusers/issues/12499
-pipe.enable_model_cpu_offload()
-
-image = load_image(
- "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/wan_i2v_input.JPG"
-)
-max_area = 480 * 832
-aspect_ratio = image.height / image.width
-mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
-height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
-width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
-image = image.resize((width, height))
-prompt = "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside."
-
-negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
-generator = torch.Generator(device=device).manual_seed(0)
-output = pipe(
- image=image,
- prompt=prompt,
- negative_prompt=negative_prompt,
- height=height,
- width=width,
- num_frames=81,
- guidance_scale=3.5,
- num_inference_steps=40,
- generator=generator,
-).frames[0]
-export_to_video(output, "i2v_output.mp4", fps=16)