perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())#51
Conversation
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls prewarm() once, then runs generate() in a loop over multiple prompts. Borrows the "don't recreate the engine per request" pattern from LLM serving (vLLM continuous batching, SGLang). The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100) is paid once. Each subsequent generate() runs at amortized steady-state ~15.5s. For users running multiple generations from the same starting image, this is a large multiplier on top of the existing prewarm() win from #1. Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames: construct: 128932 ms (once) prewarm: 7020 ms (once) generate[0]: 15587 ms generate[1]: 15202 ms generate[2]: 15160 ms total wall-clock: 181901 ms naive (3 separate invocations of generate_fast.py): ~568s speedup with persistent pipe: 3.13x Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x. Library code is unchanged. The example is pure documentation of how to use the existing public API (WanI2VFast, prewarm, generate) correctly for the multi-call case. Files: + examples/persistent_inference.py (new, 174 lines) Tier A: outputs are bit-identical to the canonical baseline at the same seed; the script is functionally equivalent to running generate_fast.py 3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.
Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.
Implementation:
- `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
- `clear_text_cache()` public method to free the dict
- In generate(), check cache before invoking text_encoder; populate
after first compute.
Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.
Measured on 8xH100, 480*832/81 frames, same prompt, twice:
call[0] (cold T5): 15520.9 ms md5=ed2f8262… (cache_size: 1)
call[1] (cached): 15089.4 ms md5=ed2f8262… (cache_size: 1)
delta: +431.5 ms saved
Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.
Δ generate_ms (repeat call): 15521 → 15089 (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock, CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention forward (model_fast.py:108, sequence_parallel.py:462) — caller already has the value as a Python int. Default frame_seqlen=None falls back to the original .item() path for any external callers. Bit-identical otherwise. Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate() through WanModelFast.forward / sp_dit_forward_causal / CausalWanAttentionBlock.forward into WanCrossAttention.forward. WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the top of generate() and flipped True after the first DiT forward. When the kwarg is provided, WanCrossAttention uses it as the gate; otherwise it falls back to the existing crossattn_cache["is_init"].item() check, so external callers (and the prewarm() throwaway forward) keep working. The .fill_(1) on the tensor is preserved to keep cache state consistent for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal (sequence_parallel.py) for the local_attn_size == -1 case (global cache, the default and only path our shipped models use). In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"] start at 0 and advance by current_end - current_start every forward, so local_end_index always equals current_end and local_start_index equals current_start — both already available as Python ints. Eliminates the two .item() syncs in the previous else-branch. The sliding-window eviction logic (local_attn_size > 0) is preserved verbatim in the elif/else for any caller that re-enables it.
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens) once at the top (replacing two separate int(seq_lens) casts on lines 309-310) and threads it through kwargs into sp_attn_forward_causal. Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per attention layer (~32 syncs per forward). Now the per-layer cast is gone; the value arrives as a Python int via the new kwarg. CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept the same kwarg for signature parity (ignored in the non-SP path).
|
@JingyeChen Can you please review this? on a BS = 1 to my surprise this is bottlenecked on CPU rather than GPU. I want to try and saturate the HBM and also try to run longer sequence lengths without perf regression. I think you will find that this is faster than previous version by many seconds. |
|
Profiler trace uploaded for inspection: https://github.com/prashant182/lingbot-world/releases/tag/b3-profile-trace Two files:
The trace is the post-B3 baseline (chunks 2-3 of one canonical generate() at 8×H100 / 480×832 / 81 frames / seed 42). Searching the timeline for |
|
I have BS = 2 working, to have more than 1 user on the same box. Once above 3 merges in I'll send that one in. It's relatively small #54 . Saturates the entire 8xH100 box fully, Basically 25 FPS end to end streaming latency from GPU to screen altogether. |
|
Thanks @prashant182. The eliminate of .item() operation indeed boosts the efficiency. We have checked and merged this. |
|
@prashant182 By the way, how do you get the profile_digest.txt file? It is quite useful and I am curious to learn about it :D |
|
Sure thing! It's torch.profiler Chrome trace JSON parsed by a small Python script that buckets events by category (kernel, NCCL, CPU op, NVTX range) and prints the top by time. NVTX ranges are monkey-patched around T5 / VAE / DiT forwards. Setup is in profile_inference.py in the release. |
Summary
Eliminates ~5600
.item()/int()syncs from the DiT forward by threading four precomputed values as kwargs and adding a no-eviction fast-path for the defaultlocal_attn_size = -1case. Each removed call was acudaStreamSynchronizepessimizing pipelining;aten::itemwas the #1 CPU op in the profile.Bit-identical output — MD5
ed2f82628308a3f8acd9b7935bb84401preserved.Measurement (8× H100, 480×832, 81 frames, seed 42)
End-to-end optimization chain on the same canonical config:
generate()prewarm()(PR #48)B3 alone, isolating the sync-elimination win:
aten::itemcalls (2-chunk window)aten::itemtotal timeed2f8262…ed2f8262…What changed
Four hot sync sites, all replaceable without any math change:
frame_seqlenkwarg —math.prod(grid_sizes[0][1:]).item()is computed once ingenerate()and threaded through. DefaultNonefalls back to the old.item().cross_attn_first_callbool —crossattn_cache["is_init"].item() == 0replaced with a pipe-level Python bool reset pergenerate(). The.fill_(1)is kept for caller-observable cache state.local_attn_size == -1fast-path — when the cache is global (the only path the shipped checkpoints use),local_end_index == current_endalways, so the two.item()cache-index reads are skipped. Sliding-window eviction (local_attn_size > 0) is preserved verbatim underelif.seq_lens_intkwarg —int(seq_lens)lifted out ofsp_attn_forward_causal(was running 32×/forward) into one cast at the top ofsp_dit_forward_causal.All new kwargs default to
None; external callers are unaffected.Validation
aten::itemundertorch.profiler; baseline emits expected sync events.torch.profilerconfirms the 93% call-count reduction; wall-clock strictly improved.Stack & out of scope
Sits on top of PR #48 (prewarm); reported diff shrinks once that lands.
Not in this PR (separate follow-ups):
torch.compile(current_startslice-indexing still triggers per-value recompiles — needsindex_copy_); sliding-window eviction refactor (dead code for shipped configs); FSDPreshard_after_forward=False.