Conversation
Reapply the sliding-window MTP decode, PS reduce, and KV_BLOCK_SIZE=1024 fixes on top of the latest main so the change can be reviewed and merged from a branch with shared history. Made-with: Cursor
Signed-off-by: fsx950223 <fsx950223@outlook.com> Made-with: Cursor
Match the composable_kernel submodule pointer with origin/main so this branch uses the same dependency baseline. Made-with: Cursor
Pull request overview
Reapplies and updates the paged-attention (Gluon/Triton) sliding-window MTP work on top of current main, including KV_BLOCK_SIZE=1024 support and PS (persistent scheduling) decode/reduce handling.
Changes:
- Extend Gluon decode kernels/wrappers to support sliding-window behavior and KV_BLOCK_SIZE=1024, including updated PS scheduling/reduction paths and FlyDSL reduce fallback.
- Update AOT compilation signature to account for the new kernel constexpr parameter(s).
- Adjust the paged-attention decode test harness to exercise sliding-window/PS-related configurations and assembly (persistent scheduling) comparison logic.
Reviewed changes
Copilot reviewed 2 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `aiter/ops/triton/gluon/pa_decode_gluon.py` | Core kernel/wrapper updates for sliding-window MTP, KV_BLOCK_SIZE=1024, PS decode/reduce, plus new helper utilities. |
| `csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py` | AOT signature update to match the updated kernel parameter list. |
| `op_tests/triton_tests/test_pa_decode_gluon.py` | Test harness updates for sliding-window cases and persistent-scheduling (assembly) comparison. |
```python
if max_context_partition_num > 1:
    if exp_sums is None:
        exp_sums = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            device=query.device,
            dtype=aiter.dtypes.fp32,
        )
    if max_logits is None:
        max_logits = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            device=query.device,
            dtype=aiter.dtypes.fp32,
        )
    if temporary_output is None:
        temporary_output = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            head_size,
            device=query.device,
            dtype=query.dtype,
        )
```
`pa_decode_gluon` only allocates `exp_sums` / `max_logits` / `temporary_output` when `max_context_partition_num > 1`, but these tensors are used unconditionally later (e.g., `exp_sums.stride(...)` is called and the tensors are passed into the decode wrapper) even when `max_context_partition_num == 1` / `one_shot` is true. If callers omit these optional buffers, this will raise at runtime. Allocate (or create dummy buffers) whenever the corresponding argument is `None` (and/or whenever `one_shot` is false), not only when `max_context_partition_num > 1`.
Suggested change:

```diff
-if max_context_partition_num > 1:
-    if exp_sums is None:
-        exp_sums = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            device=query.device,
-            dtype=aiter.dtypes.fp32,
-        )
-    if max_logits is None:
-        max_logits = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            device=query.device,
-            dtype=aiter.dtypes.fp32,
-        )
-    if temporary_output is None:
-        temporary_output = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            head_size,
-            device=query.device,
-            dtype=query.dtype,
-        )
+alloc_context_partition_num = max(1, max_context_partition_num)
+if exp_sums is None:
+    exp_sums = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        device=query.device,
+        dtype=aiter.dtypes.fp32,
+    )
+if max_logits is None:
+    max_logits = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        device=query.device,
+        dtype=aiter.dtypes.fp32,
+    )
+if temporary_output is None:
+    temporary_output = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        head_size,
+        device=query.device,
+        dtype=query.dtype,
+    )
```
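The clamp in the suggestion can be exercised in isolation. The toy sketch below (function name and buffer set are illustrative, not the wrapper's real helpers; `aiter.dtypes.fp32` is stood in for by `torch.float32`) shows why `max(1, ...)` keeps the scratch tensors allocated, so later `.stride(...)`/shape accesses stay valid even in the one-shot case:

```python
import torch

def alloc_scratch(batch_size, num_kv_heads, max_context_partition_num,
                  equivalent_query_group_size, device="cpu"):
    # Clamp to at least one partition so the scratch tensors always exist
    # and downstream .stride(...) calls never hit a missing buffer.
    n = max(1, max_context_partition_num)
    exp_sums = torch.empty(
        batch_size, num_kv_heads, n, equivalent_query_group_size,
        dtype=torch.float32, device=device,
    )
    max_logits = torch.empty_like(exp_sums)
    return exp_sums, max_logits

# One-shot case (zero partitions requested) still yields usable buffers.
exp_sums, max_logits = alloc_scratch(2, 4, 0, 8)
```

With the original `if max_context_partition_num > 1` guard, the one-shot call above would have left `exp_sums` as `None` and crashed at the first `.stride(...)` access.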
```diff
@@ -1648,24 +1651,91 @@ def run_pa_gluon_test(
     or True
```
`skip_assembly` is forced to `True` by the trailing `or True`, which makes the new persistent-scheduling assembly comparison path unreachable and silently disables the intended coverage/perf reporting for ASM. Remove the unconditional `or True` (or gate it behind an explicit flag) so the skip logic reflects the actual supported cases.
Suggested change:

```diff
-    or True
```
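One way to act on this comment is to make the skip decision explicit rather than hard-coded. A minimal sketch, assuming a hypothetical `AITER_SKIP_ASM` environment variable and an illustrative unsupported-config predicate (neither is the test harness's real logic):

```python
import os

def should_skip_assembly(sliding_window: int, kv_block_size: int) -> bool:
    """Decide whether to skip the ASM comparison path.

    Replaces an unconditional `or True` with two visible conditions:
    an explicit opt-out flag and a predicate for unsupported configs.
    """
    if os.environ.get("AITER_SKIP_ASM") == "1":
        return True  # explicit opt-out, discoverable in the environment
    # Skip only configurations the ASM kernel genuinely does not support
    # (the thresholds here are placeholders for illustration).
    return sliding_window > 0 or kv_block_size > 256
```

The point of the design is that every skip is attributable to a named condition, so re-enabling ASM coverage later is a one-line change rather than an archaeology exercise.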
```diff
     max_context_partition_num: int,
-    context_partition_size: int,
-    compute_type: torch.dtype,
-    query_scale: torch.Tensor,  # [num_seqs * query_length, num_query_heads, 1] or [1]
-    key_scale: torch.Tensor,  # [num_blocks, num_kv_heads, kv_block_size, 1]
-    value_scale: torch.Tensor,  # [num_blocks, num_kv_heads, kv_block_size, 1]
+    context_partition_size: int = 256,
+    compute_type: torch.dtype = torch.bfloat16,
+    query_scale: torch.Tensor = None,  # [num_seqs * query_length, num_query_heads, 1] or [1]
+    key_scale: torch.Tensor = None,  # [num_blocks, num_kv_heads, kv_block_size, 1]
+    value_scale: torch.Tensor = None,  # [num_blocks, num_kv_heads, kv_block_size, 1]
```
`pa_decode_gluon` now defaults `query_scale`, `key_scale`, and `value_scale` to `None`, but the kernel logic still assumes scales are provided when the corresponding tensors are FP8. Without scales, `QUERY_QUANT_MODE`/`KV_QUANT_MODE` remain -1, which can silently produce incorrect results (and may not match the FP8 cache layout assumptions). Add explicit validation that `query_scale` is present for an FP8 query, and that `key_scale`/`value_scale` are present for FP8 KV caches.
Summary
- Reapplies the `fix_sliding_window_mtp` changes on top of the current `main` so the sliding-window MTP work can be reviewed from a branch with shared history.
- Carries over the `KV_BLOCK_SIZE=1024` path, PS reduce handling, and the related AOT/test updates.
- Keeps the `main` Triton version helper while carrying over the branch-only scheduling and decode changes.

Test plan
- `python3 -m py_compile aiter/ops/triton/gluon/pa_decode_gluon.py csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py op_tests/triton_tests/test_pa_decode_gluon.py`
- `python3 -m pytest op_tests/triton_tests/test_pa_decode_gluon.py -k 'sliding_window and not performance' -q` (not run here: `pytest` is not installed in this environment)

Context
- Ports the same work as #1840 onto the latest `main`.

Made with Cursor