feat(triton/rope): fused QKV split, QK RMSNorm, RoPE, and paged KV cache #2902
hellozhuo-amd wants to merge 10 commits into main
Conversation
Add `fused_qkv_split_qk_norm_rope_cache` (Triton kernel + Python entrypoint): split flat QKV, apply RMSNorm on Q/K and RoPE (NeoX/GPT-J), apply optional KV scales, and write descaled K/V into paged caches using `slot_mapping`. Support partial rotation aligned with `ref_rope_sbhd_fwd`, optional `rotary_dim_half` vs cos/sin, and gated QKV layouts (interleaved vs blocked). Extend `test_fused_qkv_split_qk_rope` with cache E2E tests, a layout-aware reference, `rotary_dim` parametrization, and clearer separation from the non-cache `fused_qkv_split_qk_rope` tests.

Made-with: Cursor
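For readers unfamiliar with the fused sequence, a minimal unfused PyTorch sketch of the ops the kernel combines may help. All names, shapes, and the signature below are illustrative only, not the actual entrypoint; it assumes a flat `(tokens, qkv_dim)` input, NeoX-style full rotation, and `(num_blocks, block_size, kv_heads, head_dim)` caches:

```python
import torch

def ref_split_norm_rope_cache(qkv, q_w, k_w, cos, sin, key_cache, value_cache,
                              slot_mapping, num_q_heads, num_kv_heads, head_dim,
                              eps=1e-6):
    """Illustrative unfused reference: split -> RMSNorm -> RoPE -> cache write."""
    T = qkv.shape[0]
    # 1) split the flat QKV projection into per-head Q, K, V
    q, k, v = torch.split(
        qkv,
        [num_q_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
        dim=-1,
    )
    q = q.view(T, num_q_heads, head_dim)
    k = k.view(T, num_kv_heads, head_dim)
    v = v.view(T, num_kv_heads, head_dim)

    # 2) RMSNorm on Q and K (per-head, over head_dim)
    def rms_norm(x, w):
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + eps)).to(x.dtype) * w

    q, k = rms_norm(q, q_w), rms_norm(k, k_w)

    # 3) NeoX-style RoPE: pair element i with element i + head_dim / 2
    def rope_neox(x):
        d = x.shape[-1]
        x1, x2 = x[..., : d // 2], x[..., d // 2:]
        c, s = cos[:, None, :], sin[:, None, :]
        return torch.cat([x1 * c - x2 * s, x2 * c + x1 * s], dim=-1)

    q, k = rope_neox(q), rope_neox(k)

    # 4) scatter K/V into the flat (slot-indexed) view of the paged caches,
    #    skipping tokens with a negative slot index
    valid = slot_mapping >= 0
    key_cache.view(-1, num_kv_heads, head_dim)[slot_mapping[valid]] = k[valid]
    value_cache.view(-1, num_kv_heads, head_dim)[slot_mapping[valid]] = v[valid]
    return q
```

The fused kernel does all of the above in one launch, avoiding the intermediate tensors this reference materializes.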
Co-authored-by: Tianxing Wu <Tianxing.Wu@amd.com>
Co-authored-by: Tres Popp <Tres.Popp@amd.com>
Pull request overview
This PR introduces a new fused Triton implementation for attention pre-processing that combines QKV splitting, Q/K RMSNorm, RoPE application (NeoX/GPT-J), optional KV scaling, and writing K/V into a paged KV cache, along with expanded end-to-end tests to validate the new cache path.
Changes:
- Added `fused_qkv_split_qk_norm_rope_cache` Python entrypoint and corresponding Triton kernel.
- Extended `test_fused_qkv_split_qk_rope.py` with a torch reference for the cache path and new cache E2E assertions (including gated QKV layouts and partial rotary dim coverage).
- Refactored test input generation to support gated QKV layouts (`interleaved` vs `blocked`) and added more parameter coverage.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| `op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py` | Adds cache E2E tests + reference implementation and expands parametrization/grid. |
| `aiter/ops/triton/rope/fused_qkv_split_qk_norm_rope_cache.py` | New Python wrapper/entrypoint that validates shapes and launches the Triton kernel. |
| `aiter/ops/triton/_triton_kernels/rope/fused_qkv_split_qk_norm_rope_cache.py` | New fused Triton kernel implementing split + RMSNorm + RoPE + paged KV cache write. |
Comments suppressed due to low confidence (1)
op_tests/triton_tests/rope/test_fused_qkv_split_qk_rope.py:117
- The parametrization grid for this test is extremely large (hundreds to >1k cases when combined), which is likely to cause CI timeouts or long runtimes, especially since each case launches Triton kernels. Consider reducing the cartesian product (e.g., smaller representative sets, splitting into separate fast/slow tests with markers, or using `pytest.mark.parametrize` with fewer values and adding a dedicated stress test under a slow/optional marker).
```python
@pytest.mark.parametrize("B", [1, 4, 8, 16, 32])
@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4, 8, 16])
@pytest.mark.parametrize("KH", [1, 4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX])
@pytest.mark.parametrize("max_embed_positions", [131072])
@pytest.mark.parametrize(
    "nope, nope_first", [(False, False), (True, False), (True, True)]
)
@pytest.mark.parametrize("reuse_freqs_front_part", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_fused_qkv_split_qk_rope(
```
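One way to implement the fast/slow split the comment suggests is a small always-on grid plus an opt-in stress grid behind a marker. A sketch (test names and the `stress` marker are hypothetical; the marker would need registering in `pytest.ini`/`pyproject.toml`, and the bodies here are placeholders for the real kernel checks):

```python
import pytest

# Fast path: a small, representative grid that runs on every CI invocation.
@pytest.mark.parametrize("B", [4])
@pytest.mark.parametrize("D", [64, 128])
def test_rope_fast(B, D):
    assert B > 0 and D % 2 == 0  # placeholder for the real kernel check

# Stress path: the full cartesian product, opt-in via `pytest -m stress`.
@pytest.mark.stress
@pytest.mark.parametrize("B", [1, 4, 8, 16, 32])
@pytest.mark.parametrize("D", [64, 128])
def test_rope_stress(B, D):
    assert B > 0 and D % 2 == 0  # placeholder for the real kernel check
```

Regular CI would then run `pytest -m "not stress"`, while a nightly or label-gated job runs the full grid.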
```python
@pytest.mark.parametrize("B", [1, 4, 8])
@pytest.mark.parametrize("QH_PER_KH", [1, 2, 4])
@pytest.mark.parametrize("KH", [1, 4])
@pytest.mark.parametrize("D", [64, 128])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("rotate_style", [RotateStyle.GPTJ, RotateStyle.NEOX])
@pytest.mark.parametrize("max_embed_positions", [131072])
@pytest.mark.parametrize("reuse_freqs_front_part", [False, True])
@pytest.mark.parametrize("attn_output_gate", [False, True])
@pytest.mark.parametrize("use_kv_scale", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("qkv_layout", ["interleaved", "blocked"])
@pytest.mark.parametrize(
    "rotary_dim",
    [None, 32],
    ids=["full", "32"],
)
```
Done in commit ae3959f. Reduced the grid from 1152 to ~128 cases (~9× reduction) by:
- `B`: `[1, 4, 8]` → `[4]` (single representative batch size)
- `QH_PER_KH`: `[1, 2, 4]` → `[1, 4]` (MHA and GQA; drops the redundant middle value)
- `reuse_freqs_front_part`: `[False, True]` → `[True]` (the `False` path is always skipped when `rotary_dim < D`, so it adds no new coverage over the full-dim run)
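As a sanity check, both counts follow directly from the product of the per-axis value counts transcribed from the decorator lists (the `rotary_dim` axis is excluded here, matching the quoted 1152/128 figures):

```python
from math import prod

# axis value counts, in decorator order: B, QH_PER_KH, KH, D, block_size,
# rotate_style, max_embed_positions, reuse_freqs_front_part,
# attn_output_gate, use_kv_scale, dtype, qkv_layout
before = prod([3, 3, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2])
after = prod([1, 2, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2])
print(before, after)  # 1152 128
```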
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/87aa0311-0913-463f-9021-ea78e35ee61c
Co-authored-by: hellozhuo-amd <225919697+hellozhuo-amd@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/c28dcd1b-edcd-4f01-bc04-f6b29c143537
Co-authored-by: hellozhuo-amd <225919697+hellozhuo-amd@users.noreply.github.com>
… cache write path
Agent-Logs-Url: https://github.com/ROCm/aiter/sessions/e34170cd-4738-422a-b53d-96b60439fe28
Co-authored-by: hellozhuo-amd <225919697+hellozhuo-amd@users.noreply.github.com>
GPT-OSS has RMSNorm before the QKV projection. Are you adding this kernel for another model? Would you mind sharing the integration instance on vLLM/SGLang/ATOM, etc.? Thanks!
```python
valid_slots = slot_mapping[slot_mapping >= 0]
if valid_slots.numel() > 0:
    max_slot = int(valid_slots.max().item())
    assert max_slot < num_blocks * block_size, (
```
Both `valid_slots = slot_mapping[slot_mapping >= 0]` and `max_slot = int(valid_slots.max().item())` create GPU-side work (and `.item()` forces a device-to-host synchronization), which is not the desired way to do this.

For `valid_slots = slot_mapping[slot_mapping >= 0]`, you can remove it and simply perform an early exit in the kernel when the current workgroup is loaded with a negative slot index.

For `max_slot = int(valid_slots.max().item())`, you can also remove it (remove the assertion as well) and perform an early exit in the kernel when the workgroup is loaded with a slot index greater than or equal to `total_num_kv_cache_tokens = num_blocks * block_size`, which can be passed into the kernel as a dynamic variable (`tl.int64`).
The example code would be something like:

```python
@triton.jit
def _fused_qkv_split_qk_norm_rope_cache_kernel(
    ...
    total_num_kv_cache_tokens: tl.int64,
    ...
):
    ...
    slot = tl.load(slot_mapping_ptr + t_idx)  # slot index for this workgroup's token
    # early exit if the slot index is out of bounds
    if slot < 0:
        return
    if slot >= total_num_kv_cache_tokens:
        return
    # continue with the valid slot index
    ...
```
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist