Skip to content

perf(fsdp): SHARD_GRAD_OP for inference (-13.7%, -1855ms, bit-identical)#52

Open
prashant182 wants to merge 12 commits into
Robbyant:mainfrom
prashant182:feat/e8-fsdp-no-reshard
Open

perf(fsdp): SHARD_GRAD_OP for inference (-13.7%, -1855ms, bit-identical)#52
prashant182 wants to merge 12 commits into
Robbyant:mainfrom
prashant182:feat/e8-fsdp-no-reshard

Conversation

@prashant182
Copy link
Copy Markdown
Contributor

Summary

Switches DiT FSDP from FULL_SHARD to SHARD_GRAD_OP. Parameters stay resident across all 35 DiT forwards in a generate() instead of being re-gathered per-forward. Bit-identical output (MD5 ed2f82628308a3f8acd9b7935bb84401), −13.7% generate().

SHARD_GRAD_OP is FSDP1's equivalent of FSDP2's reshard_after_forward=False (the FSDP2 kwarg isn't accepted by the current PyTorch's FSDP1 constructor).

Measurement (8×H100, 480×832, 81 frames, seed 42)

Stacked on top of #51 (B3) for the canonical chain:

Stage generate() vs unopt vs prior
Unoptimized 22146 ms
+ prewarm() (PR #48) 15564 ms −30%
+ T5 cache (PR #50) 15134 ms −32% −2.8%
+ B3 (PR #51) 13523 ms −39% −10.6%
+ this PR 11668 ms −47% −13.7%

Why

Profiling the post-B3 baseline showed AllGather was 91% of NCCL time — 975 ms over a 2-chunk window. With FULL_SHARD, FSDP shards params after each forward and re-gathers on the next call. Across the 35 forwards in one generate(), that re-gather fires repeatedly. SHARD_GRAD_OP keeps params unsharded after the first all-gather — the per-forward re-gather disappears.

Memory

Unsharded 14B-param DiT at bf16 ≈ 28 GB per rank vs ~3.5 GB sharded across 8 ranks. On 80 GB H100s with VAE + T5 + KV cache + activations resident, there is headroom; no OOM observed across multiple bench runs.

Tradeoffs / scope

  • Purely an inference-mode optimization. If anyone runs this code in a training context, the FULL_SHARD default would need to come back via a flag — happy to add one if requested.
  • Bit-identical (Tier A): MD5 match verified end-to-end at the locked baseline.
  • One file changed (wan/distributed/fsdp.py); ~10 LOC including comment block.

Stack

Stacked on PR #51 (B3). Reported diff will shrink to one line + comment once that lands.

prashant182 and others added 12 commits May 17, 2026 17:35
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls
prewarm() once, then runs generate() in a loop over multiple prompts.
Borrows the "don't recreate the engine per request" pattern from LLM
serving (vLLM continuous batching, SGLang).

The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100)
is paid once. Each subsequent generate() runs at amortized steady-state
~15.5s. For users running multiple generations from the same starting
image, this is a large multiplier on top of the existing prewarm() win
from #1.

Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames:

  construct:                 128932 ms (once)
  prewarm:                     7020 ms (once)
  generate[0]:                15587 ms
  generate[1]:                15202 ms
  generate[2]:                15160 ms
  total wall-clock:          181901 ms

  naive (3 separate invocations of generate_fast.py):  ~568s
  speedup with persistent pipe:                         3.13x

Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x.

Library code is unchanged. The example is pure documentation of how to
use the existing public API (WanI2VFast, prewarm, generate) correctly
for the multi-call case.

Files:
  + examples/persistent_inference.py  (new, 174 lines)

Tier A: outputs are bit-identical to the canonical baseline at the same
seed; the script is functionally equivalent to running generate_fast.py
3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.

Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.

Implementation:
  - `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
  - `clear_text_cache()` public method to free the dict
  - In generate(), check cache before invoking text_encoder; populate
    after first compute.

Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.

Measured on 8xH100, 480*832/81 frames, same prompt, twice:

  call[0] (cold T5):  15520.9 ms  md5=ed2f8262… (cache_size: 1)
  call[1] (cached):   15089.4 ms  md5=ed2f8262… (cache_size: 1)
  delta:              +431.5 ms saved

Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.

Δ generate_ms (repeat call): 15521 → 15089  (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen
kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock,
CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips
the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention
forward (model_fast.py:108, sequence_parallel.py:462) — caller already has
the value as a Python int.

Default frame_seqlen=None falls back to the original .item() path for any
external callers. Bit-identical otherwise.

Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the
generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate()
through WanModelFast.forward / sp_dit_forward_causal /
CausalWanAttentionBlock.forward into WanCrossAttention.forward.

WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the
top of generate() and flipped True after the first DiT forward. When the
kwarg is provided, WanCrossAttention uses it as the gate; otherwise it
falls back to the existing crossattn_cache["is_init"].item() check, so
external callers (and the prewarm() throwaway forward) keep working.

The .fill_(1) on the tensor is preserved to keep cache state consistent
for any caller still relying on it.
Step 2.3 of B3. Adds a no-eviction branch in both
CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal
(sequence_parallel.py) for the local_attn_size == -1 case (global cache,
the default and only path our shipped models use).

In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"]
start at 0 and advance by current_end - current_start every forward, so
local_end_index always equals current_end and local_start_index equals
current_start — both already available as Python ints. Eliminates the
two .item() syncs in the previous else-branch.

The sliding-window eviction logic (local_attn_size > 0) is preserved
verbatim in the elif/else for any caller that re-enables it.
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens)
once at the top (replacing two separate int(seq_lens) casts on lines 309-310)
and threads it through kwargs into sp_attn_forward_causal.

Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per
attention layer (~32 syncs per forward). Now the per-layer cast is gone;
the value arrives as a Python int via the new kwarg.

CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept
the same kwarg for signature parity (ignored in the non-SP path).
perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())
E8 of autoresearch — see /workspace/lingbot-world-artifacts/EXPERIMENTS.md.

Switches DiT FSDP from FULL_SHARD to SHARD_GRAD_OP. Functionally equivalent
to FSDP2's reshard_after_forward=False, which FSDP1 doesn't expose
directly. Effect: parameters are unsharded after the first all-gather
and stay resident across all 35 DiT forwards in a generate(), instead
of being re-gathered per-forward.

Profiling rationale: AllGather was 91% of NCCL time (975 ms / 2-chunk
window) at the post-B3 baseline. The within-forward per-block gathers
still fire, but the outer re-gather between forwards is gone — that's
where the time went.

Memory cost: unsharded 14B-param DiT at bf16 ~ 28GB per rank vs 3.5GB
sharded across 8 ranks. On 80GB H100s with VAE+T5+KV cache + activations,
we have headroom.

Measurement (8×H100, 480×832, 81 frames, seed 42):

   generate()  before  13523 ms (B3 baseline)
   generate()  after   11668 ms
   Δ                  -1855 ms (-13.7%)
   MD5                ed2f82628308a3f8acd9b7935bb84401 (bit-identical, Tier A)
@prashant182
Copy link
Copy Markdown
Contributor Author

prashant182 commented May 21, 2026

Nudge for review when you have a moment Single-line config flip in wan/distributed/fsdp.py (FULL_SHARDSHARD_GRAD_OP). Profile showed AllGather was 91% of NCCL time; this keeps params resident across the 35 forwards per generate(). −13.7%, bit-identical (MD5 match). Lowest-risk PR in the stack: one file, ten lines including comments, independent of #53 and #54.

cc @Robbyant @JingyeChen @qiuyu96

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant