[Perf] Optimizes loads in gfx950 fp16 decode kernel#313
Conversation
bfd6e6a to
cd88449
Compare
| async_copy.wait_group(0) | ||
| v = program.shared_load_v(v_smem) | ||
| acc = program.compute_pv(p, v, acc) | ||
| with gl.amd.warp_pipeline_stage("issue", priority=1): |
There was a problem hiding this comment.
I have mostly seen warp_pipeline use for ping-pong i.e 2 wave per simd, num_warp=8. This kernel is num_warp=1, can you explain a little how warp_pipeline benefit for num_warp=1 as well?
There was a problem hiding this comment.
Yeah, this is currently num_warps=1 (still investigating perf at >=2). What I found is that even for a single wave, warp_pipeline_stage emits s_setprio + sched_barrier, which help instruction scheduling
| physical_page = program.load_page(start_n + cfg.BLOCK_N) | ||
|
|
||
| with gl.amd.warp_pipeline_stage("qk_softmax", priority=0): | ||
| async_copy.wait_group(1) |
There was a problem hiding this comment.
Do we know if async_copy.wait_group(1) lowers to s_barrier under the hood? asking because if yes we may mess up the ping-pong pattern
There was a problem hiding this comment.
it lowers to s_waitcnt vmcnt(...) at the backend, and I don't think have ping-pong here anyway
Peel the first page-table load into a prologue and prefetch the next page index inside the loop, and mask the page-table lookup so out-of-range pages safely read 0. Split the issue / QK+softmax / PV phases under warp_pipeline_stage priority hints so the single wave issues K/V loads ahead of the dependent compute. A single K/V LDS buffer is kept to preserve occupancy on this bandwidth-bound kernel. The KV cache read keeps async_copy.global_load_to_shared: its 64-bit per-thread addressing avoids the 4 GiB buffer-descriptor window that buffer_load_to_shared is limited to, which a paged KV pool can exceed at production scale. Correctness verified against the fp32 paged reference (non-sliding + sliding). Co-authored-by: Cursor <cursoragent@cursor.com> Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
cd88449 to
5766aba
Compare
warp_pipeline_stagesplit into issue / qk_softmax / pv with priority hints.