
feat(triton/rope): fused QKV split, QK RMSNorm, RoPE, and paged KV cache#2902

Open
hellozhuo-amd wants to merge 10 commits into main from qkv-norm-rope-cache

Conversation

@hellozhuo-amd hellozhuo-amd commented Apr 24, 2026

Add fused_qkv_split_qk_norm_rope_cache (Triton kernel + Python entrypoint): split flat QKV, RMSNorm on Q/K, RoPE (NeoX/GPT-J), optional KV scales, and write descaled K/V into paged caches using slot_mapping.

Support partial rotation aligned with ref_rope_sbhd_fwd, optional rotary_dim_half vs cos/sin, and gated QKV layouts (interleaved vs blocked).

Extend test_fused_qkv_split_qk_rope with cache E2E tests, layout-aware reference, rotary_dim parametrization, and clearer separation from the non-cache fused_qkv_split_qk_rope tests.
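
As a reference for the fused semantics, here is a minimal pure-Python sketch of the per-head Q/K path (RMSNorm, then NeoX-style rotation); the function names and the eps/theta defaults are illustrative, not values taken from the kernel:

```python
import math

def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: x / sqrt(mean(x^2) + eps), scaled by a learned weight
    ms = sum(v * v for v in x) / len(x)
    inv = 1.0 / math.sqrt(ms + eps)
    return [v * inv * w for v, w in zip(x, weight)]

def rope_neox(x, pos, theta=10000.0):
    # NeoX-style rotation: element i pairs with element i + d/2
    d, half = len(x), len(x) // 2
    out = list(x)
    for i in range(half):
        freq = theta ** (-2.0 * i / d)
        c, s = math.cos(pos * freq), math.sin(pos * freq)
        out[i] = x[i] * c - x[i + half] * s
        out[i + half] = x[i] * s + x[i + half] * c
    return out

# one head: normalize, then rotate at the token's position
q = rms_norm([0.5, -1.0, 2.0, 0.25], [1.0, 1.0, 1.0, 1.0])
q_rot = rope_neox(q, pos=3)
```

The kernel fuses this per-head math with the QKV split and the paged-cache write into a single launch, so the normalized/rotated values never round-trip through global memory.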

Co-authored-by: Tianxing Wu Tianxing.Wu@amd.com

Co-authored-by: Tres Popp Tres.Popp@amd.com

Co-authored-by: Cursor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@hellozhuo-amd hellozhuo-amd requested review from a team and Copilot April 24, 2026 08:38
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-300x  Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang       SGLang integration tests
ci:atom         ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm         vLLM benchmark
ci:all          All of the above

Add labels via the sidebar or gh pr edit 2902 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

This PR introduces a new fused Triton implementation for attention pre-processing that combines QKV splitting, Q/K RMSNorm, RoPE application (NeoX/GPT-J), optional KV scaling, and writing K/V into a paged KV cache, along with expanded end-to-end tests to validate the new cache path.

Changes:

  • Added fused_qkv_split_qk_norm_rope_cache Python entrypoint and corresponding Triton kernel.
  • Extended test_fused_qkv_split_qk_rope.py with a torch reference for the cache path and new cache E2E assertions (including gated QKV layouts and partial rotary dim coverage).
  • Refactored test input generation to support gated QKV layouts (interleaved vs blocked) and added more parameter coverage.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

  • op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py: Adds cache E2E tests + reference implementation and expands parametrization/grid.
  • aiter/ops/triton/rope/fused_qkv_split_qk_norm_rope_cache.py: New Python wrapper/entrypoint that validates shapes and launches the Triton kernel.
  • aiter/ops/triton/_triton_kernels/rope/fused_qkv_split_qk_norm_rope_cache.py: New fused Triton kernel implementing split + RMSNorm + RoPE + paged KV cache write.
Comments suppressed due to low confidence (1)

op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py:117

  • The parametrization grid for this test is extremely large (hundreds to >1k cases when combined), which is likely to cause CI timeouts / long runtimes, especially since each case launches Triton kernels. Consider reducing the cartesian product (e.g., smaller representative sets, split into separate fast/slow tests with markers, or use pytest.mark.parametrize with fewer values and add a dedicated stress test under a slow/optional marker).
@pytest.mark.parametrize("B", [1, 4, 8, 16, 32])
@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("KH", [1, 4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX])
@pytest.mark.parametrize("max_embed_positions", [131072])
@pytest.mark.parametrize(
    "nope, nope_first", [(False, False), (True, False), (True, True)]
)
@pytest.mark.parametrize("reuse_freqs_front_part", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_qkv_split_qk_rope(
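
Stacked parametrize decorators multiply into a full cartesian product, which is why this grid gets large; a quick sketch of the count (axes copied from the decorators above, with single-valued axes omitted since they each contribute a factor of 1):

```python
import itertools

# value lists mirroring the parametrize decorators above; axes such as
# max_embed_positions and dtype have one value and are left out
axes = [
    [1, 4, 8, 16, 32],                              # B
    [1, 2, 4, 8, 16],                               # QH_PER_KH
    [1, 4],                                         # KH
    [64, 128],                                      # D
    ["gptj", "neox"],                               # rotate_style
    [(False, False), (True, False), (True, True)],  # (nope, nope_first)
    [False, True],                                  # reuse_freqs_front_part
]
n_cases = len(list(itertools.product(*axes)))
print(n_cases)  # 5 * 5 * 2 * 2 * 2 * 3 * 2 = 1200
```

Each of these cases launches Triton kernels, which is the basis for the runtime concern raised above.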


Comment on lines +301 to +317
@pytest.mark.parametrize("B", [1, 4, 8])
@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4])
@pytest.mark.parametrize("KH", [1, 4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX])
@pytest.mark.parametrize("max_embed_positions", [131072])
@pytest.mark.parametrize("reuse_freqs_front_part", [False, True])
@pytest.mark.parametrize("attn_output_gate", [False, True])
@pytest.mark.parametrize("use_kv_scale", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("qkv_layout", ["interleaved", "blocked"])
@pytest.mark.parametrize(
    "rotary_dim",
    [None, 32],
    ids=["full", "32"],
)
Author


@copilot apply changes based on this feedback

Contributor


Done in commit ae3959f. Reduced the grid from 1152 to ~128 cases (~9× reduction) by:

  • B: [1, 4, 8] → [4] (single representative batch size)
  • QH_PER_KH: [1, 2, 4] → [1, 4] (MHA and GQA; drops the redundant middle value)
  • reuse_freqs_front_part: [False, True] → [True] (the False path is always skipped when rotary_dim < D, so it adds no new coverage over the full-dim run)

hellozhuo-amd and others added 3 commits April 24, 2026 11:53
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/c28dcd1b-edcd-4f01-bc04-f6b29c143537

Co-authored-by: hellozhuo-amd <225919697+hellozhuo-amd@users.noreply.github.com>
… cache write path

Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e34170cd-4738-422a-b53d-96b60439fe28

Co-authored-by: hellozhuo-amd <225919697+hellozhuo-amd@users.noreply.github.com>
@k50112113
Contributor

@hellozhuo-amd

GPT-OSS has RMSNorm before the QKV projection:
https://github.com/ROCm/vllm/blob/355_wip/vllm/model_executor/models/gpt_oss.py#L157

Are you adding this kernel for another model? Would you mind sharing the integration instance on vLLM/SGLang/ATOM, ... etc?

Thanks,
Shao-Chun

valid_slots = slot_mapping[slot_mapping >= 0]
if valid_slots.numel() > 0:
    max_slot = int(valid_slots.max().item())
    assert max_slot < num_blocks * block_size, (
Contributor


Both valid_slots = slot_mapping[slot_mapping >= 0] and max_slot = int(valid_slots.max().item()) create GPU-side workload (and the .item() call forces a device-to-host sync), which is not the desired way to do this.

For valid_slots = slot_mapping[slot_mapping >= 0], you can remove it entirely and simply perform an early exit from the kernel if the current workgroup is loaded with a negative slot index.

For max_slot = int(valid_slots.max().item()), you can also remove it (along with the assertion) and perform an early exit from the kernel if the workgroup is loaded with a slot index greater than or equal to total_num_kv_cache_tokens = num_blocks * block_size, which can be passed into the kernel as a dynamic variable (tl.int64).

The example code would be something like:

@triton.jit
def _fused_qkv_split_qk_norm_rope_cache_kernel(
    ...
    total_num_kv_cache_tokens: tl.int64,
    ...
):
    ...
    # one token per program: load this token's slot index as a scalar
    slot = tl.load(slot_mapping_ptr + tl.program_id(0))

    # early exit if the slot index is out of bounds
    if slot < 0 or slot >= total_num_kv_cache_tokens:
        return

    # continue with the valid slot index
    ...
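
To make the bound concrete: a flat slot index addresses the paged cache as a (block, offset) pair. A pure-Python sketch of that mapping, with negative (padding) and out-of-range slots skipped the way the early exit above does (the function name is illustrative):

```python
def paged_coords(slot, block_size, num_blocks):
    # map a flat slot index to (block, offset) in the paged KV cache;
    # padding slots (-1) and out-of-range slots return None, mirroring
    # the kernel's early exit instead of a host-side assert
    if slot < 0 or slot >= num_blocks * block_size:
        return None
    return slot // block_size, slot % block_size
```

For example, with block_size=16 and num_blocks=4, slot 17 lands at block 1, offset 1, while slot -1 (padding) and slot 64 (out of range) are simply skipped.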
