Skip to content

perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())#51

Merged
JingyeChen merged 9 commits into
Robbyant:mainfrom
prashant182:feat/b3-eliminate-item-calls
May 20, 2026
Merged

perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())#51
JingyeChen merged 9 commits into
Robbyant:mainfrom
prashant182:feat/b3-eliminate-item-calls

Conversation

@prashant182
Copy link
Copy Markdown
Contributor

@prashant182 prashant182 commented May 19, 2026

Summary

Eliminates ~5600 .item() / int() syncs from the DiT forward by threading four precomputed values as kwargs and adding a no-eviction fast-path for the default local_attn_size = -1 case. Each removed call was a cudaStreamSynchronize pessimizing pipelining; aten::item was the #1 CPU op in the profile.

Bit-identical output — MD5 ed2f82628308a3f8acd9b7935bb84401 preserved.

Measurement (8× H100, 480×832, 81 frames, seed 42)

End-to-end optimization chain on the same canonical config:

Stage generate() vs prior vs unoptimized
Unoptimized baseline 22146 ms
+ prewarm() (PR #48) 15564 ms −30% −30%
+ T5 prompt cache (PR #50, repeat call) 15134 ms −2.8% −32%
+ this PR (B3) 13523 ms −10.6% −39%

B3 alone, isolating the sync-elimination win:

Metric Pre-B3 Post-B3 Δ
aten::item calls (2-chunk window) 6044 434 −93%
aten::item total time 506 ms 1.3 ms −99.7%
Output MD5 ed2f8262… ed2f8262… identical

What changed

Four hot sync sites, all replaceable without any math change:

  1. frame_seqlen kwargmath.prod(grid_sizes[0][1:]).item() is computed once in generate() and threaded through. Default None falls back to the old .item().
  2. cross_attn_first_call boolcrossattn_cache["is_init"].item() == 0 replaced with a pipe-level Python bool reset per generate(). The .fill_(1) is kept for caller-observable cache state.
  3. local_attn_size == -1 fast-path — when the cache is global (the only path the shipped checkpoints use), local_end_index == current_end always, so the two .item() cache-index reads are skipped. Sliding-window eviction (local_attn_size > 0) is preserved verbatim under elif.
  4. seq_lens_int kwargint(seq_lens) lifted out of sp_attn_forward_causal (was running 32×/forward) into one cast at the top of sp_dit_forward_causal.

All new kwargs default to None; external callers are unaffected.

Validation

  • Isolated test (5 cases, tiny dims, no FSDP/SP): single-call + 7-chunk bit-identity; fast path emits 0 aten::item under torch.profiler; baseline emits expected sync events.
  • End-to-end gates: bench MD5 matches the locked baseline; torch.profiler confirms the 93% call-count reduction; wall-clock strictly improved.

Stack & out of scope

Sits on top of PR #48 (prewarm); reported diff shrinks once that lands.

Not in this PR (separate follow-ups): torch.compile (current_start slice-indexing still triggers per-value recompiles — needs index_copy_); sliding-window eviction refactor (dead code for shipped configs); FSDP reshard_after_forward=False.

prashant182 and others added 9 commits May 17, 2026 17:35
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls
prewarm() once, then runs generate() in a loop over multiple prompts.
Borrows the "don't recreate the engine per request" pattern from LLM
serving (vLLM continuous batching, SGLang).

The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100)
is paid once. Each subsequent generate() runs at amortized steady-state
~15.5s. For users running multiple generations from the same starting
image, this is a large multiplier on top of the existing prewarm() win
from #1.

Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames:

  construct:                 128932 ms (once)
  prewarm:                     7020 ms (once)
  generate[0]:                15587 ms
  generate[1]:                15202 ms
  generate[2]:                15160 ms
  total wall-clock:          181901 ms

  naive (3 separate invocations of generate_fast.py):  ~568s
  speedup with persistent pipe:                         3.13x

Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x.

Library code is unchanged. The example is pure documentation of how to
use the existing public API (WanI2VFast, prewarm, generate) correctly
for the multi-call case.

Files:
  + examples/persistent_inference.py  (new, 174 lines)

Tier A: outputs are bit-identical to the canonical baseline at the same
seed; the script is functionally equivalent to running generate_fast.py
3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.

Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.

Implementation:
  - `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
  - `clear_text_cache()` public method to free the dict
  - In generate(), check cache before invoking text_encoder; populate
    after first compute.

Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.

Measured on 8xH100, 480*832/81 frames, same prompt, twice:

  call[0] (cold T5):  15520.9 ms  md5=ed2f8262… (cache_size: 1)
  call[1] (cached):   15089.4 ms  md5=ed2f8262… (cache_size: 1)
  delta:              +431.5 ms saved

Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.

Δ generate_ms (repeat call): 15521 → 15089  (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen
kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock,
CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips
the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention
forward (model_fast.py:108, sequence_parallel.py:462) — caller already has
the value as a Python int.

Default frame_seqlen=None falls back to the original .item() path for any
external callers. Bit-identical otherwise.

Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the
generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate()
through WanModelFast.forward / sp_dit_forward_causal /
CausalWanAttentionBlock.forward into WanCrossAttention.forward.

WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the
top of generate() and flipped True after the first DiT forward. When the
kwarg is provided, WanCrossAttention uses it as the gate; otherwise it
falls back to the existing crossattn_cache["is_init"].item() check, so
external callers (and the prewarm() throwaway forward) keep working.

The .fill_(1) on the tensor is preserved to keep cache state consistent
for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both
CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal
(sequence_parallel.py) for the local_attn_size == -1 case (global cache,
the default and only path our shipped models use).

In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"]
start at 0 and advance by current_end - current_start every forward, so
local_end_index always equals current_end and local_start_index equals
current_start — both already available as Python ints. Eliminates the
two .item() syncs in the previous else-branch.

The sliding-window eviction logic (local_attn_size > 0) is preserved
verbatim in the elif/else for any caller that re-enables it.
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens)
once at the top (replacing two separate int(seq_lens) casts on lines 309-310)
and threads it through kwargs into sp_attn_forward_causal.

Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per
attention layer (~32 syncs per forward). Now the per-layer cast is gone;
the value arrives as a Python int via the new kwarg.

CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept
the same kwarg for signature parity (ignored in the non-SP path).
@ppatel418
Copy link
Copy Markdown

@JingyeChen Can you please review this? on a BS = 1 to my surprise this is bottlenecked on CPU rather than GPU. I want to try and saturate the HBM and also try to run longer sequence lengths without perf regression. I think you will find that this is faster than previous version by many seconds.

@prashant182
Copy link
Copy Markdown
Contributor Author

prashant182 commented May 20, 2026

Profiler trace uploaded for inspection:

https://github.com/prashant182/lingbot-world/releases/tag/b3-profile-trace

Two files:

  • torch_trace.json.gz (16 MB) — gunzip → drag-drop into https://ui.perfetto.dev to view the full timeline (CPU + CUDA + NCCL).
  • profile_digest.txt — top operations table, NVTX phase breakdown, and a quick how-to-view if you don't want to download the trace itself.

The trace is the post-B3 baseline (chunks 2-3 of one canonical generate() at 8×H100 / 480×832 / 81 frames / seed 42). Searching the timeline for aten::item will show all remaining hits are in VAE / T5 / setup code, with zero hits inside the DiT block.forward ranges win

cc @JingyeChen @Robbyant

@prashant182
Copy link
Copy Markdown
Contributor Author

prashant182 commented May 20, 2026

I have BS = 2 working, to have more than 1 user on the same box. Once above 3 merges in I'll send that one in. It's relatively small #54 . Saturates the entire 8xH100 box fully, Basically 25 FPS end to end streaming latency from GPU to screen altogether.

@JingyeChen
Copy link
Copy Markdown
Collaborator

Thanks @prashant182. The eliminate of .item() operation indeed boosts the efficiency. We have checked and merged this.

@JingyeChen
Copy link
Copy Markdown
Collaborator

@prashant182 By the way, how do you get the profile_digest.txt file? It is quite useful and I am curious to learn about it :D

@prashant182
Copy link
Copy Markdown
Contributor Author

Sure thing! It's torch.profiler Chrome trace JSON parsed by a small Python script that buckets events by category (kernel, NCCL, CPU op, NVTX range) and prints the top by time. NVTX ranges are monkey-patched around T5 / VAE / DiT forwards. Setup is in profile_inference.py in the release.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants