feat(batch): enable B>1 DiT forward in SP path (1.42× at B=2, B=1 bit-identical) #54

Open
prashant182 wants to merge 16 commits into Robbyant:main from prashant182:feat/e3-batch2-prototype

@prashant182
Contributor

Summary

Three minimal changes to sp_dit_forward_causal that unlock batched multi-user inference. Each is a no-op at B=1; together they let B>1 run correctly through the same path.

B=1 bit-identical (MD5 ed2f82628308a3f8acd9b7935bb84401). B=2 throughput: 1.42× (single-chunk forward, two distinct users batched).

What changed

| Site | Fix | Why no-op at B=1 |
| --- | --- | --- |
| `assert len(x) == 1` | Removed. | Assertion was always true at B=1; surrounding code already handles a list. |
| `torch.cat(c2ws, dim=1)` | `dim=0`, so each user keeps its own camera conditioning. | One-element list: both dims give the same tensor. |
| `t.expand(B, S)` | `t.unsqueeze(1).expand(B, S)`, which works for any B. | At B=1 the broadcast result is identical. |
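
The second and third rows can be checked in isolation. The snippet below is a minimal sketch, not the PR's code: the tensor shapes and the values of `S` are hypothetical, and only the `torch.cat` and `expand` behaviour is being illustrated.

```python
import torch

B1, B2, S = 1, 2, 5  # hypothetical batch sizes and sequence length

# Camera conditioning: a one-element list concatenates to the same tensor
# along any dim, so switching dim=1 to dim=0 changes nothing at B=1.
c2ws = [torch.randn(1, 4, 3)]  # hypothetical conditioning shape
assert torch.equal(torch.cat(c2ws, dim=0), torch.cat(c2ws, dim=1))

# With two users, dim=0 stacks along the batch axis, while dim=1 would
# merge both users into a single batch-1 sequence.
c2ws_two = [torch.randn(1, 4, 3), torch.randn(1, 4, 3)]
assert torch.cat(c2ws_two, dim=0).shape == (2, 4, 3)
assert torch.cat(c2ws_two, dim=1).shape == (1, 8, 3)

# Time embedding: both forms agree when t has a single element.
t1 = torch.tensor([0.5])
assert torch.equal(t1.expand(B1, S), t1.unsqueeze(1).expand(B1, S))

# With t of shape [B] and B > 1, the plain expand raises because the
# trailing dimension (B) cannot be broadcast to S.
t2 = torch.tensor([0.25, 0.75])
try:
    t2.expand(B2, S)
    expand_failed = False
except RuntimeError:
    expand_failed = True

# unsqueeze(1) makes the shape [B, 1], which broadcasts to [B, S];
# each row repeats that user's timestep.
t2_bs = t2.unsqueeze(1).expand(B2, S)
```

Note that if `S` happened to equal `B`, the plain `expand` would not raise but would broadcast along the wrong axis, which is another reason to prefer the explicit `unsqueeze(1)`.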

Verification

B=1 bench:    MD5 ed2f82628308a3f8acd9b7935bb84401 (locked baseline)
              generate() 11690 ms (within noise of the prior 11668 ms baseline)

B=2 probe:    user A match (atol=3e-2): PASS
              user B match (atol=3e-2): PASS
              throughput 1.42× (272 ms for 2 users vs 388 ms sequential)

Tolerance is set at bf16 attention precision (flash-attn reduces matmuls in different orders for B=1 vs B=2 → tiny numerical drift, max abs diff 1.8e-2 / mean 2.8e-3 across ~1M elements).
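
The probe script itself is not shown here, so the following is only a sketch of what an absolute-tolerance comparison of this kind might look like. The helper names `diff_stats` and `matches` are hypothetical, and the sample values are made up; the point is that floating-point addition is not associative, so a changed reduction order legitimately produces small differences that a tolerance check has to absorb.

```python
def diff_stats(out, ref):
    """Return (max_abs_diff, mean_abs_diff) for two equal-length sequences."""
    diffs = [abs(a - b) for a, b in zip(out, ref)]
    return max(diffs), sum(diffs) / len(diffs)


def matches(out, ref, atol=3e-2):
    """Absolute-tolerance check only (no relative term), as an assumption."""
    max_abs, _ = diff_stats(out, ref)
    return max_abs <= atol


# Reordering a sum changes the rounding, even in float64.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
assert left != right

# Made-up reference and batched outputs for one user.
ref = [0.50, -1.25, 2.00]
batched = [0.51, -1.26, 2.00]
max_abs, mean_abs = diff_stats(batched, ref)
```

Under this reading, the reported max abs diff of 1.8e-2 sits below the 3e-2 threshold, which is why both users pass.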

Scope

Lifts the inner DiT forward only. The higher-level generate() loop still constructs B=1 inputs, so production paths are unchanged. Multi-user serving infrastructure (per-user session, continuous-batching scheduler, PagedAttention for heterogeneous timelines) is separate follow-on work.

Stack

Stacks on #51, #52, #53. Reported diff shrinks once those land.

prashant182 and others added 16 commits May 17, 2026 17:35
Add WanI2VFast.prewarm() to hide kernel-autotune tax (~30% generate() speedup, bit-identical)
Adds an example script that constructs WanI2VFast() once, calls
prewarm() once, then runs generate() in a loop over multiple prompts.
Borrows the "don't recreate the engine per request" pattern from LLM
serving (vLLM continuous batching, SGLang).

The cold-start cost (~129s model load + ~7s prewarm = ~136s on 8xH100)
is paid once. Each subsequent generate() runs at amortized steady-state
~15.5s. For users running multiple generations from the same starting
image, this is a large multiplier on top of the existing prewarm() win
from #1.

Measured on 8xH100, 3 prompts back-to-back, 480*832 / 81 frames:

  construct:                 128932 ms (once)
  prewarm:                     7020 ms (once)
  generate[0]:                15587 ms
  generate[1]:                15202 ms
  generate[2]:                15160 ms
  total wall-clock:          181901 ms

  naive (3 separate invocations of generate_fast.py):  ~568s
  speedup with persistent pipe:                         3.13x

Speedup grows with more prompts: at 10 calls ≈5x, at 100 calls ≈7x.
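
A simple way to reason about the amortization is a linear cost model built from the measurements above. This is a sketch under one explicit assumption: that the naive cost scales linearly at one third of the ~568 s figure per call. Extrapolations to larger call counts depend entirely on that assumption and on how the naive path actually behaves across invocations.

```python
# One-time costs and steady-state cost, taken from the measurements above.
CONSTRUCT_S = 128.932
PREWARM_S = 7.020
GENERATE_S = (15.587 + 15.202 + 15.160) / 3  # mean of the three calls

# Assumption: each naive invocation costs the same, ~568 s / 3.
NAIVE_PER_CALL_S = 568.0 / 3


def persistent_total(n_calls: int) -> float:
    """Wall-clock seconds for n calls with one construct and one prewarm."""
    return CONSTRUCT_S + PREWARM_S + n_calls * GENERATE_S


def speedup(n_calls: int) -> float:
    """Naive total divided by persistent total under the linear model."""
    return n_calls * NAIVE_PER_CALL_S / persistent_total(n_calls)


# Upper bound on the speedup as the one-time costs are fully amortized.
asymptote = NAIVE_PER_CALL_S / GENERATE_S
```

At three calls the model reproduces the measured total of about 181.9 s and a speedup of roughly 3.1x.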

Library code is unchanged. The example is pure documentation of how to
use the existing public API (WanI2VFast, prewarm, generate) correctly
for the multi-call case.

Files:
  + examples/persistent_inference.py  (new, 174 lines)

Tier A: outputs are bit-identical to the canonical baseline at the same
seed; the script is functionally equivalent to running generate_fast.py
3 times sequentially, just with model load + prewarm amortized.
Adds a per-pipe-instance cache of T5 encoder outputs keyed by
sha256(prompt). Same-prompt re-encodes hit the dict instead of re-
running umt5-xxl, saving ~360-430 ms per repeat call.

Pattern borrowed from SGLang's RadixAttention: when the same prompt
appears across calls, the encoder output is identical, so cache and
return the prior tensor.

Implementation:
  - `self._t5_cache: dict[str, list[Tensor]]` initialized in __init__
  - `clear_text_cache()` public method to free the dict
  - In generate(), check cache before invoking text_encoder; populate
    after first compute.

Bit-identical: cached tensor IS the prior call's tensor (no copy, no
quantization). MD5 preserved across cached vs uncached calls.
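
The caching pattern described above can be sketched independently of the model. In this illustration the class name `PromptCache` and the `fake_encoder` stand-in are hypothetical; the real cache stores the encoder's tensors on the pipeline instance, whereas this sketch stores plain lists.

```python
import hashlib
from typing import Callable, Dict, List


class PromptCache:
    """Per-instance cache of encoder outputs keyed by sha256(prompt)."""

    def __init__(self, encode: Callable[[str], List[float]]) -> None:
        self._encode = encode
        self._cache: Dict[str, List[float]] = {}

    def get(self, prompt: str) -> List[float]:
        key = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        if key not in self._cache:
            # First time this prompt is seen: run the encoder once.
            self._cache[key] = self._encode(prompt)
        # Return the stored object itself (no copy), so repeat calls
        # see exactly the same values as the first call.
        return self._cache[key]

    def clear(self) -> None:
        """Free all cached outputs."""
        self._cache.clear()


calls = []


def fake_encoder(prompt: str) -> List[float]:
    # Stand-in for the text encoder; records how often it is invoked.
    calls.append(prompt)
    return [float(len(prompt))]


cache = PromptCache(fake_encoder)
first = cache.get("a cat on a beach")
second = cache.get("a cat on a beach")
```

Returning the stored object rather than a copy is what makes the bit-identical claim straightforward, at the cost that callers must not mutate the result in place.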

Measured on 8xH100, 480*832/81 frames, same prompt, twice:

  call[0] (cold T5):  15520.9 ms  md5=ed2f8262… (cache_size: 1)
  call[1] (cached):   15089.4 ms  md5=ed2f8262… (cache_size: 1)
  delta:              +431.5 ms saved

Bigger when paired with examples/persistent_inference.py (same-prompt
loop). Effectively zero when prompts vary every call.

Δ generate_ms (repeat call): 15521 → 15089  (-3%)
MD5 unchanged.
Tier: A (bit-identical).
docs(examples): add persistent_inference.py for multi-call amortization (3.13x speedup over naive)
feat: add T5 prompt-embedding cache (~430ms/repeat-call, bit-identical)
Step 2.1 of B3 (eliminate .item() syncs). Adds an optional frame_seqlen
kwarg to WanModelFast.forward, sp_dit_forward_causal, CausalWanAttentionBlock,
CausalWanSelfAttention, and sp_attn_forward_causal. When provided, skips
the math.prod(grid_sizes[0][1:]).item() sync at the start of each attention
forward (model_fast.py:108, sequence_parallel.py:462) — caller already has
the value as a Python int.

Default frame_seqlen=None falls back to the original .item() path for any
external callers. Bit-identical otherwise.
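
The optional-kwarg pattern looks roughly like the sketch below. The function name `resolve_frame_seqlen` and the `(F, H, W)` grid layout are assumptions for illustration; in the real code the fallback reads a tensor and calls `.item()`, which is the host-device sync being avoided.

```python
import math
from typing import Optional, Sequence


def resolve_frame_seqlen(
    grid_size: Sequence[int],
    frame_seqlen: Optional[int] = None,
) -> int:
    """Return tokens per frame, preferring the caller-supplied Python int."""
    if frame_seqlen is not None:
        # Fast path: the caller already knows the value, so nothing is
        # read back from the device.
        return frame_seqlen
    # Fallback for external callers: product of the dims after the first
    # one, assumed here to be (H, W) of an (F, H, W) grid.
    return math.prod(grid_size[1:])


grid = (3, 30, 52)  # hypothetical (F, H, W) grid
from_fallback = resolve_frame_seqlen(grid)
from_kwarg = resolve_frame_seqlen(grid, frame_seqlen=1560)
```

Defaulting the kwarg to `None` keeps existing call sites working unchanged while letting the pipeline opt in to the sync-free path.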

Pipeline (image2video_fast.py) passes frame_seqlen from prewarm() and the
generate() chunk loop where the value is already computed.
Step 2.2 of B3. Threads cross_attn_first_call kwarg from WanI2VFast.generate()
through WanModelFast.forward / sp_dit_forward_causal /
CausalWanAttentionBlock.forward into WanCrossAttention.forward.

WanI2VFast tracks _cross_attn_initialized as a Python bool, reset at the
top of generate() and flipped True after the first DiT forward. When the
kwarg is provided, WanCrossAttention uses it as the gate; otherwise it
falls back to the existing crossattn_cache["is_init"].item() check, so
external callers (and the prewarm() throwaway forward) keep working.

The .fill_(1) on the tensor is preserved to keep cache state consistent
for any caller still relying on it.
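
As a rough illustration of the gating logic, the sketch below uses a plain dict in place of the cache tensor. The names `is_first_call` and `PipelineSketch` are hypothetical, and creating a fresh cache per `generate()` call is an assumption made to keep the example self-contained.

```python
from typing import List, Optional


def is_first_call(cache: dict, cross_attn_first_call: Optional[bool] = None) -> bool:
    """Use the Python bool when given; otherwise fall back to the cache flag."""
    if cross_attn_first_call is not None:
        return cross_attn_first_call
    # Fallback path: in the real code this is a tensor read with .item().
    return not cache["is_init"]


class PipelineSketch:
    def __init__(self) -> None:
        self._cross_attn_initialized = False

    def generate(self, num_forwards: int) -> List[bool]:
        # Reset at the top of generate(), as described above.
        self._cross_attn_initialized = False
        cache = {"is_init": False}  # assumption: fresh cache per call
        seen = []
        for _ in range(num_forwards):
            first = is_first_call(
                cache, cross_attn_first_call=not self._cross_attn_initialized
            )
            seen.append(first)
            # Keep the cache flag consistent for callers that still read it.
            cache["is_init"] = True
            self._cross_attn_initialized = True
        return seen


pipe = PipelineSketch()
flags = pipe.generate(3)
```

Only the first forward of each `generate()` call is treated as the initializing one, and the flag never has to be read back from the device.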
Step 2.3 of B3. Adds a no-eviction branch in both
CausalWanSelfAttention.forward (model_fast.py) and sp_attn_forward_causal
(sequence_parallel.py) for the local_attn_size == -1 case (global cache,
the default and only path our shipped models use).

In that path, both kv_cache["global_end_index"] and kv_cache["local_end_index"]
start at 0 and advance by current_end - current_start every forward, so
local_end_index always equals current_end and local_start_index equals
current_start — both already available as Python ints. Eliminates the
two .item() syncs in the previous else-branch.

The sliding-window eviction logic (local_attn_size > 0) is preserved
verbatim in the elif/else for any caller that re-enables it.
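
The invariant that justifies the no-eviction branch can be checked with a small simulation. This is a sketch that assumes chunks are written contiguously (each chunk starts where the previous one ended); the function name and chunk length are hypothetical.

```python
def simulate_global_cache(chunk_len: int, num_chunks: int):
    """Track the cache counters over successive forwards with no eviction."""
    global_end = 0
    local_end = 0
    history = []
    for i in range(num_chunks):
        current_start = i * chunk_len
        current_end = current_start + chunk_len
        advance = current_end - current_start
        # Both counters start at 0 and advance by the chunk length.
        global_end += advance
        local_end += advance
        local_start = local_end - advance
        history.append((current_start, current_end, local_start, local_end))
    return history


history = simulate_global_cache(chunk_len=4680, num_chunks=7)
```

Because the counters always equal values the caller already holds as Python ints, reading them back from the cache tensors adds a sync without adding information.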
Step 2.4 of B3. In sp_dit_forward_causal, computes seq_lens_int = int(seq_lens)
once at the top (replacing two separate int(seq_lens) casts on lines 309-310)
and threads it through kwargs into sp_attn_forward_causal.

Previously, sp_attn_forward_causal did `seq_lens_int = int(seq_lens)` per
attention layer (~32 syncs per forward). Now the per-layer cast is gone;
the value arrives as a Python int via the new kwarg.

CausalWanSelfAttention.forward and CausalWanAttentionBlock.forward accept
the same kwarg for signature parity (ignored in the non-SP path).
perf(dit): eliminate ~5600 .item() syncs from DiT forward (-13% generate())
E1 of the autoresearch sequence — see /workspace/lingbot-world-artifacts/EXPERIMENTS.md.

Replaces `kv_cache["k"][:, current_start:current_end] = roped_key` with
`kv_cache["k"].index_copy_(1, kv_write_index, roped_key)` in the
local_attn_size == -1 fast-path. kv_write_index is a [seq_lens]-shape
tensor built once per chunk in generate() via torch.arange and threaded
through as a kwarg.

Mirrored in both the non-SP (CausalWanSelfAttention.forward) and SP
(sp_attn_forward_causal) paths.
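
The equivalence of the two write styles can be demonstrated on a toy cache. The dimensions below are hypothetical, and the cache layout `[batch, tokens, heads, head_dim]` is an assumption chosen to match the `dim=1` write shown above.

```python
import torch

# Hypothetical sizes: batch, max cached tokens, heads, head dim, chunk length.
B, L, H, D = 2, 12, 2, 4
S = 3
current_start, current_end = 6, 9

roped_key = torch.randn(B, S, H, D)
cache_slice = torch.zeros(B, L, H, D)
cache_index = torch.zeros(B, L, H, D)

# Original style: slice assignment driven by Python ints.
cache_slice[:, current_start:current_end] = roped_key

# New style: the write positions live in an index tensor built once per
# chunk, then every layer reuses it.
kv_write_index = torch.arange(current_start, current_end)
cache_index.index_copy_(1, kv_write_index, roped_key)

same = torch.equal(cache_slice, cache_index)
```

Both users in the batch write the same token positions of their own slab, and positions outside the index are left untouched.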

Why
- Multi-tenant prerequisite: at B>1 different users can be batched into
  one forward, each writing the same positions of their own KV slab.
  Slice-assign with a Python-int range works at B=1 but doesn't
  generalize cleanly; index_copy_ does.
- Graph capture: removes one of two Python-int slice sources keeping the
  DiT forward out of CUDA Graphs. (The cache READ slice
  `cache["k"][:, :local_end_index]` is still Python-int — that's the next
  experiment, E2.)

Verification
- Isolated test (test_e1_index_copy.py): B=1 and B=2, 7-chunk sequences,
  output and cache state bit-equal to slice path.
- End-to-end bench MD5 ed2f82628308a3f8acd9b7935bb84401 (locked).
- generate() 13414 ms — within noise of the post-B3 baseline (13523 ms).
  This commit ships no perf gain on its own; it's the enabler.

Defaults preserved: every new kwarg defaults to None; the slice path
remains the eager fallback for callers that don't pass kv_write_index.
E8 of autoresearch — see /workspace/lingbot-world-artifacts/EXPERIMENTS.md.

Switches DiT FSDP from FULL_SHARD to SHARD_GRAD_OP. Functionally equivalent
to FSDP2's reshard_after_forward=False, which FSDP1 doesn't expose
directly. Effect: parameters are unsharded after the first all-gather
and stay resident across all 35 DiT forwards in a generate(), instead
of being re-gathered per-forward.

Profiling rationale: AllGather was 91% of NCCL time (975 ms / 2-chunk
window) at the post-B3 baseline. The within-forward per-block gathers
still fire, but the outer re-gather between forwards is gone — that's
where the time went.

Memory cost: unsharded 14B-param DiT at bf16 ~ 28GB per rank vs 3.5GB
sharded across 8 ranks. On 80GB H100s with VAE+T5+KV cache + activations,
we have headroom.
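
The memory figures follow from simple arithmetic, sketched here with decimal gigabytes (1 GB = 1e9 bytes) as an assumption about the units used.

```python
PARAMS = 14e9          # 14B parameters
BYTES_PER_PARAM = 2    # bf16
RANKS = 8

# Full parameter set resident on every rank.
unsharded_gb = PARAMS * BYTES_PER_PARAM / 1e9

# Parameters split evenly across ranks.
sharded_gb = unsharded_gb / RANKS

# Additional per-rank memory paid for keeping parameters unsharded.
extra_gb = unsharded_gb - sharded_gb
```

The trade is therefore about 24.5 GB of extra parameter memory per rank in exchange for skipping the re-gather between forwards.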

Measurement (8×H100, 480×832, 81 frames, seed 42):

   generate()  before  13523 ms (B3 baseline)
   generate()  after   11668 ms
   Δ                  -1855 ms (-13.7%)
   MD5                ed2f82628308a3f8acd9b7935bb84401 (bit-identical, Tier A)
perf(fsdp): SHARD_GRAD_OP for inference (-13.7%, -1855ms)
refactor(kv): index_copy_ for KV-cache writes (multi-tenant primitive)
Three minimal changes to sp_dit_forward_causal that lift the B=1 lock-in
without changing B=1 semantics. Validates the multi-tenant batching
direction (see memory: realtime gameplay server).

Changes
-------
1. Drop `assert len(x) == 1`. The assertion was the only hard B=1 lock;
   the surrounding code already handles a list-of-tensors flow.
2. Camera-conditioning cat: dim=1 → dim=0. The original `torch.cat(...,
   dim=1)` was designed for "multiple videos merged into one batch=1
   sequence." For multi-user, each user must keep its own conditioning,
   so cat along batch dim. At B=1 (a single-element list) `cat(dim=0)`
   and `cat(dim=1)` are identical — no-op for the legacy path.
3. Time-embed broadcast: `t.expand(B, S)` only works when t is shape [1]
   (singleton expandable along dim 0). For B>1 t is shape [B] and expand
   raises. `t.unsqueeze(1).expand(B, S)` works for any B and is
   mathematically identical at B=1.

Verification
------------
- End-to-end bench at B=1: MD5 ed2f82628308a3f8acd9b7935bb84401
  matches the locked baseline. generate() 11690 ms (within noise of
  post-E8 baseline 11668 ms).
- B=2 prototype (probe_e3_batch2_correctness.py): two distinct seeded
  users produce per-batch outputs matching their B=1 references within
  bf16 precision (max abs diff 1.8e-2, mean 2.8e-3 over ~1M elements
  per user). 1.42× throughput speedup at B=2 (per-user 136 ms vs
  198 ms at B=1).

Scope
-----
- Homogeneous batching (same shape across users) only. Heterogeneous
  batching needs PagedAttention; tracked separately.
- Higher-level generate() loop still constructs B=1 inputs; this PR
  unlocks the inner forward path but does not change generate().
- Multi-user serving infrastructure (per-user session, continuous
  batching scheduler) is a separate workstream.
@prashant182
Contributor Author

prashant182 commented May 21, 2026

Nudge for review when you have a moment. Three minimal changes in sp_dit_forward_causal lift the B=1 lock. Each is a no-op at B=1 (verified: MD5 unchanged on the locked baseline), and together they let B=2 run correctly: 1.42× throughput at B=2, with per-batch outputs matching their B=1 references within bf16 precision. This is the multi-tenant primitive working end to end on top of #53.

cc @Robbyant @JingyeChen

@prashant182
Contributor Author

Do you have a Slack or another real-time communication channel that could shorten the contribution loop? If so, let me know and I can join.

@prashant182
Contributor Author

prashant182 commented May 22, 2026

Hi, checking in to see whether you could review these PRs, please. @qiuyu96
