perf(fsdp): SHARD_GRAD_OP for inference (-13.7%, -1855ms, bit-identical) #52
Open
prashant182 wants to merge 12 commits into
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls prewarm() once, then runs generate() in a loop over multiple prompts. Borrows the "don't recreate the engine per request" pattern from LLM serving (vLLM continuous batching, SGLang).

The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100) is paid once. Each subsequent generate() runs at amortized steady-state ~15.5s. For users running multiple generations from the same starting image, this is a large multiplier on top of the existing prewarm() win from #1.

Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames:
  construct:        128932 ms (once)
  prewarm:            7020 ms (once)
  generate[0]:       15587 ms
  generate[1]:       15202 ms
  generate[2]:       15160 ms
  total wall-clock: 181901 ms
  naive (3 separate invocations of generate_fast.py): ~568s
  speedup with persistent pipe: 3.13x

Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x.

Library code is unchanged. The example is pure documentation of how to use the existing public API (WanI2VFast, prewarm, generate) correctly for the multi-call case.

Files:
  + examples/persistent_inference.py (new, 174 lines)

Tier A: outputs are bit-identical to the canonical baseline at the same seed; the script is functionally equivalent to running generate_fast.py 3 times sequentially, just with model load + prewarm amortized.
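The construct-once / prewarm-once / generate-many loop described above can be sketched roughly as follows. This is not the contents of examples/persistent_inference.py: `StubEngine` and `run_persistent` are hypothetical stand-ins, since the real WanI2VFast constructor and generate() signatures are not shown in this PR.

```python
import time


class StubEngine:
    """Hypothetical stand-in for WanI2VFast (real signatures are assumed, not known)."""

    constructed = 0  # counts how many times the expensive load would run

    def __init__(self):
        StubEngine.constructed += 1  # the one-time model load would happen here
        self.warm = False

    def prewarm(self):
        self.warm = True  # the one-time kernel autotune would happen here

    def generate(self, prompt):
        assert self.warm, "call prewarm() before generate()"
        return f"video for: {prompt}"


def run_persistent(prompts):
    """Construct and prewarm once, then reuse the same engine for every prompt."""
    engine = StubEngine()
    engine.prewarm()
    outputs, timings_ms = [], []
    for prompt in prompts:
        start = time.perf_counter()
        outputs.append(engine.generate(prompt))
        timings_ms.append((time.perf_counter() - start) * 1000.0)
    return outputs, timings_ms
```

The point of the pattern is only that construction and prewarm sit outside the loop; everything inside the loop is steady-state cost.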
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of
re-running umt5-xxl, saving ~360-430 ms per repeat call.
Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.
Implementation:
- `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
- `clear_text_cache()` public method to free the dict
- In generate(), check cache before invoking text_encoder; populate
after first compute.
Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.
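A minimal sketch of the caching scheme described above, assuming a generic `encode` callable in place of the real text encoder. The class name `PromptCache` and its method names are illustrative only; the PR itself stores the dict as `self._t5_cache` on the pipe and exposes `clear_text_cache()`.

```python
import hashlib
from typing import Any, Callable, Dict


class PromptCache:
    """Illustrative prompt-embedding cache keyed by sha256(prompt)."""

    def __init__(self, encode: Callable[[str], Any]):
        self._encode = encode            # stand-in for the text encoder call
        self._cache: Dict[str, Any] = {}
        self.misses = 0                  # number of real encoder invocations

    @staticmethod
    def _key(prompt: str) -> str:
        return hashlib.sha256(prompt.encode("utf-8")).hexdigest()

    def get(self, prompt: str) -> Any:
        key = self._key(prompt)
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._encode(prompt)
        # Return the stored object itself (no copy), so a repeat call sees
        # exactly the same value as the first call.
        return self._cache[key]

    def clear(self) -> None:
        self._cache.clear()

    def __len__(self) -> int:
        return len(self._cache)
```

Because the cache lives on the instance and never evicts, memory grows with the number of distinct prompts until `clear()` is called.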
Measured on 8xH100, 480*832/81 frames, same prompt, twice:
call[0] (cold T5): 15520.9 ms md5=ed2f8262… (cache_size: 1)
call[1] (cached): 15089.4 ms md5=ed2f8262… (cache_size: 1)
delta: +431.5 ms saved
Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.
Δ generate_ms (repeat call): 15521 → 15089 (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock, CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention forward (model_fast.py:108, sequence_parallel.py:462) — caller already has the value as a Python int. Default frame_seqlen=None falls back to the original .item() path for any external callers. Bit-identical otherwise. Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the generate() chunk loop where the value is already computed.
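The optional-kwarg fallback can be illustrated without a GPU. In this sketch `CountingScalar` is a hypothetical stand-in for a 0-d device tensor whose `.item()` would force a host-device sync, and `resolve_frame_seqlen` is an illustrative helper, not a function from the PR; the value 1560 is arbitrary.

```python
class CountingScalar:
    """Hypothetical stand-in for a 0-d device tensor; counts .item() calls."""

    def __init__(self, value):
        self.value = value
        self.item_calls = 0

    def item(self):
        self.item_calls += 1  # each call would be one host-device sync
        return self.value


def resolve_frame_seqlen(grid_prod, frame_seqlen=None):
    """Return tokens-per-frame as a Python int.

    grid_prod stands in for the device-side product of the grid dims.
    If the caller already knows the value, no sync is triggered.
    """
    if frame_seqlen is not None:
        return frame_seqlen          # fast path: caller supplied the int
    return int(grid_prod.item())     # legacy path: one sync


scalar = CountingScalar(1560)
fast = resolve_frame_seqlen(scalar, frame_seqlen=1560)
fast_syncs = scalar.item_calls
slow = resolve_frame_seqlen(scalar)
slow_syncs = scalar.item_calls
```

Keeping `None` as the default means callers that never pass the kwarg follow the original path unchanged.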
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate() through WanModelFast.forward / sp_dit_forward_causal / CausalWanAttentionBlock.forward into WanCrossAttention.forward. WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the top of generate() and flipped True after the first DiT forward. When the kwarg is provided, WanCrossAttention uses it as the gate; otherwise it falls back to the existing crossattn_cache["is_init"].item() check, so external callers (and the prewarm() throwaway forward) keep working. The .fill_(1) on the tensor is preserved to keep cache state consistent for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal (sequence_parallel.py) for the local_attn_size == -1 case (global cache, the default and only path our shipped models use). In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"] start at 0 and advance by current_end - current_start every forward, so local_end_index always equals current_end and local_start_index equals current_start — both already available as Python ints. Eliminates the two .item() syncs in the previous else-branch. The sliding-window eviction logic (local_attn_size > 0) is preserved verbatim in the elif/else for any caller that re-enables it.
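The invariant behind the no-eviction branch can be checked with a small replay of the index bookkeeping. This is a sketch under the assumption stated in the commit message (both cache indices start at 0 and advance by `current_end - current_start` on every forward, with contiguous chunks); `simulate_global_cache` is a hypothetical helper, not code from the PR.

```python
def simulate_global_cache(chunk_lengths):
    """Replay index bookkeeping for the local_attn_size == -1 path.

    Returns one (current_start, current_end, local_start, local_end)
    tuple per forward.
    """
    global_end_index = 0
    local_end_index = 0
    current_start = 0
    history = []
    for length in chunk_lengths:
        current_end = current_start + length
        # What the cached tensors would hold after this forward.
        global_end_index += current_end - current_start
        local_end_index += current_end - current_start
        local_start_index = local_end_index - length
        history.append(
            (current_start, current_end, local_start_index, local_end_index)
        )
        current_start = current_end  # assumption: chunks are contiguous
    return history


rows = simulate_global_cache([3, 3, 3, 5])
```

In every row the cached indices equal the Python ints the caller already holds, which is why reading them back with `.item()` adds nothing in this path.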
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens) once at the top (replacing two separate int(seq_lens) casts on lines 309-310) and threads it through kwargs into sp_attn_forward_causal. Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per attention layer (~32 syncs per forward). Now the per-layer cast is gone; the value arrives as a Python int via the new kwarg. CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept the same kwarg for signature parity (ignored in the non-SP path).
perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())
E8 of autoresearch; see /workspace/lingbot-world-artifacts/EXPERIMENTS.md.

Switches DiT FSDP from FULL_SHARD to SHARD_GRAD_OP. Functionally equivalent to FSDP2's reshard_after_forward=False, which FSDP1 doesn't expose directly.

Effect: parameters are unsharded after the first all-gather and stay resident across all 35 DiT forwards in a generate(), instead of being re-gathered per-forward.

Profiling rationale: AllGather was 91% of NCCL time (975 ms / 2-chunk window) at the post-B3 baseline. The within-forward per-block gathers still fire, but the outer re-gather between forwards is gone, and that's where the time went.

Memory cost: unsharded 14B-param DiT at bf16 ~ 28GB per rank vs 3.5GB sharded across 8 ranks. On 80GB H100s with VAE+T5+KV cache + activations, we have headroom.

Measurement (8×H100, 480×832, 81 frames, seed 42):
  generate() before  13523 ms (B3 baseline)
  generate() after   11668 ms
  Δ                  -1855 ms (-13.7%)
  MD5                ed2f82628308a3f8acd9b7935bb84401 (bit-identical, Tier A)
Nudge for review when you have a moment Single-line config flip in |
Summary
Switches DiT FSDP from `FULL_SHARD` to `SHARD_GRAD_OP`. Parameters stay resident across all 35 DiT forwards in a `generate()` instead of being re-gathered per-forward. Bit-identical output (MD5 `ed2f82628308a3f8acd9b7935bb84401`), −13.7% `generate()`.

`SHARD_GRAD_OP` is FSDP1's equivalent of FSDP2's `reshard_after_forward=False` (the FSDP2 kwarg isn't accepted by the current PyTorch's FSDP1 constructor).

Measurement (8×H100, 480×832, 81 frames, seed 42)
Stacked on top of #51 (B3) for the canonical chain:
[Table not recovered; surviving cells: `generate()`, `prewarm()` (PR #48)]

Why
Profiling the post-B3 baseline showed AllGather was 91% of NCCL time: 975 ms over a 2-chunk window. With `FULL_SHARD`, FSDP shards params after each forward and re-gathers on the next call. Across the 35 forwards in one `generate()`, that re-gather fires repeatedly. `SHARD_GRAD_OP` keeps params unsharded after the first all-gather, so the per-forward re-gather disappears.

Memory
Unsharded 14B-param DiT at bf16 ≈ 28 GB per rank vs ~3.5 GB sharded across 8 ranks. On 80 GB H100s with VAE + T5 + KV cache + activations resident, there is headroom; no OOM observed across multiple bench runs.
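The memory figures follow from simple arithmetic, sketched here for reference. The only assumptions are 2 bytes per bf16 parameter, an even split across ranks, and 1 GB taken as 1e9 bytes; the helper name is illustrative.

```python
def param_gb(num_params, bytes_per_param=2, num_shards=1):
    """Per-rank parameter footprint in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / num_shards / 1e9


NUM_PARAMS = 14_000_000_000  # 14B-param DiT, per the PR description

unsharded_gb = param_gb(NUM_PARAMS)               # every rank holds all params
sharded_gb = param_gb(NUM_PARAMS, num_shards=8)   # params split across 8 ranks
extra_per_rank_gb = unsharded_gb - sharded_gb

print(f"unsharded: {unsharded_gb} GB, sharded: {sharded_gb} GB, "
      f"extra per rank: {extra_per_rank_gb} GB")
```

This counts parameters only; the VAE, T5, KV cache, and activations mentioned above come on top of it.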
Tradeoffs / scope
- `FULL_SHARD` default would need to come back via a flag; happy to add one if requested.
- … `wan/distributed/fsdp.py`); ~10 LOC including comment block.

Stack
Stacked on PR #51 (B3). Reported diff will shrink to one line + comment once that lands.