Skip to content

[Perf] Optimizes loads in gfx950 fp16 decode kernel#313

Open
Yu-Zhewen wants to merge 1 commit into
lightseekorg:mainfrom
Yu-Zhewen:perf/decode-paged-load-gfx950
Open

[Perf] Optimizes loads in gfx950 fp16 decode kernel#313
Yu-Zhewen wants to merge 1 commit into
lightseekorg:mainfrom
Yu-Zhewen:perf/decode-paged-load-gfx950

Conversation

@Yu-Zhewen
Copy link
Copy Markdown

@Yu-Zhewen Yu-Zhewen commented May 29, 2026

  • peeled + masked page-table prefetch.
  • warp_pipeline_stage split into issue / qk_softmax / pv with priority hints.

@Yu-Zhewen Yu-Zhewen requested a review from a team as a code owner May 29, 2026 15:34
@Yu-Zhewen Yu-Zhewen force-pushed the perf/decode-paged-load-gfx950 branch from bfd6e6a to cd88449 Compare May 29, 2026 18:05
async_copy.wait_group(0)
v = program.shared_load_v(v_smem)
acc = program.compute_pv(p, v, acc)
with gl.amd.warp_pipeline_stage("issue", priority=1):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have mostly seen warp_pipeline use for ping-pong i.e 2 wave per simd, num_warp=8. This kernel is num_warp=1, can you explain a little how warp_pipeline benefit for num_warp=1 as well?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this is currently num_warps=1 (still investigating perf at >=2). What I found is that even for a single wave, warp_pipeline_stage emits s_setprio + sched_barrier, which help instruction scheduling

physical_page = program.load_page(start_n + cfg.BLOCK_N)

with gl.amd.warp_pipeline_stage("qk_softmax", priority=0):
async_copy.wait_group(1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know if async_copy.wait_group(1) lowers to s_barrier under the hood? asking because if yes we may mess up the ping-pong pattern

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it lowers to s_waitcnt vmcnt(...) at the backend, and I don't think have ping-pong here anyway

Copy link
Copy Markdown

@raikonenfnu raikonenfnu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good stuff Zhewen! have a couple questions :)

Peel the first page-table load into a prologue and prefetch the next page
index inside the loop, and mask the page-table lookup so out-of-range pages
safely read 0. Split the issue / QK+softmax / PV phases under
warp_pipeline_stage priority hints so the single wave issues K/V loads ahead
of the dependent compute. A single K/V LDS buffer is kept to preserve
occupancy on this bandwidth-bound kernel.

The KV cache read keeps async_copy.global_load_to_shared: its 64-bit
per-thread addressing avoids the 4 GiB buffer-descriptor window that
buffer_load_to_shared is limited to, which a paged KV pool can exceed at
production scale.

Correctness verified against the fp32 paged reference (non-sliding + sliding).

Co-authored-by: Cursor <cursoragent@cursor.com>
Signed-off-by: Yu-Zhewen <zhewenyu@amd.com>
@Yu-Zhewen Yu-Zhewen force-pushed the perf/decode-paged-load-gfx950 branch from cd88449 to 5766aba Compare June 1, 2026 14:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants