
Pa gluon swa mtp opt #2914

Open
Bernard-Liu wants to merge 8 commits into main from pa_gluon_swa_mtp_opt

Conversation

@Bernard-Liu
Contributor

Summary

  • Reapply the fix_sliding_window_mtp changes on top of the current main so the sliding-window MTP work can be reviewed from a branch with shared history.
  • Restore the paged attention decode updates for sliding-window MTP, including the KV_BLOCK_SIZE=1024 path, PS reduce handling, and the related AOT/test updates.
  • Keep the current main Triton version helper while carrying over the branch-only scheduling and decode changes.

Test plan

  • python3 -m py_compile aiter/ops/triton/gluon/pa_decode_gluon.py csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py op_tests/triton_tests/test_pa_decode_gluon.py
  • python3 -m pytest op_tests/triton_tests/test_pa_decode_gluon.py -k 'sliding_window and not performance' -q (not run here: pytest is not installed in this environment)

Context

  • Supersedes the previously closed #1840 by porting the same work onto the latest main.

Made with Cursor

fsx950223 and others added 7 commits April 25, 2026 07:46
  • Reapply the sliding-window MTP decode, PS reduce, and KV_BLOCK_SIZE=1024 fixes on top of the latest main so the change can be reviewed and merged from a branch with shared history.
  • Match the composable_kernel submodule pointer with origin/main so this branch uses the same dependency baseline.

Made-with: Cursor
Signed-off-by: fsx950223 <fsx950223@outlook.com>
Bernard-Liu requested review from a team and Copilot April 25, 2026 08:06
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label            Tests
ci:triton-300x   Run an additional Triton test job on MI300X in PRs; the main branch always runs both MI35X and MI300X
ci:sglang        SGLang integration tests
ci:atom          ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm          vLLM benchmark
ci:all           All of the above

Add labels via the sidebar or gh pr edit 2914 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

Reapplies and updates the paged-attention (Gluon/Triton) sliding-window MTP work on top of current main, including KV_BLOCK_SIZE=1024 support and PS (persistent scheduling) decode/reduce handling.

Changes:

  • Extend Gluon decode kernels/wrappers to support sliding-window behavior and KV_BLOCK_SIZE=1024, including updated PS scheduling/reduction paths and FlyDSL reduce fallback.
  • Update AOT compilation signature to account for the new kernel constexpr parameter(s).
  • Adjust the paged-attention decode test harness to exercise sliding-window/PS-related configurations and assembly (persistent scheduling) comparison logic.

Reviewed changes

Copilot reviewed 2 out of 3 changed files in this pull request and generated 3 comments.

  • aiter/ops/triton/gluon/pa_decode_gluon.py: Core kernel/wrapper updates for sliding-window MTP, KV_BLOCK_SIZE=1024, PS decode/reduce, plus new helper utilities.
  • csrc/cpp_itfs/pa_gluon_aot/pa_decode_gluon_aot.py: AOT signature update to match the updated kernel parameter list.
  • op_tests/triton_tests/test_pa_decode_gluon.py: Test harness updates for sliding-window cases and persistent-scheduling (assembly) comparison.


Comment on lines +5417 to +5445
if max_context_partition_num > 1:
    if exp_sums is None:
        exp_sums = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            device=query.device,
            dtype=aiter.dtypes.fp32,
        )
    if max_logits is None:
        max_logits = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            device=query.device,
            dtype=aiter.dtypes.fp32,
        )
    if temporary_output is None:
        temporary_output = torch.empty(
            batch_size,
            num_kv_heads,
            max_context_partition_num,
            equivalent_query_group_size,
            head_size,
            device=query.device,
            dtype=query.dtype,
        )

Copilot AI Apr 25, 2026


pa_decode_gluon only allocates exp_sums / max_logits / temporary_output when max_context_partition_num > 1, but these tensors are used unconditionally later (for example, exp_sums.stride(...) is queried and the buffers are passed into the decode wrapper) even when max_context_partition_num == 1 and one_shot is true. If callers omit these optional buffers, the call raises at runtime. Allocate them (or create dummy buffers) whenever the corresponding argument is None, and/or whenever one_shot is false, not only when max_context_partition_num > 1.

Suggested change

-if max_context_partition_num > 1:
-    if exp_sums is None:
-        exp_sums = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            device=query.device,
-            dtype=aiter.dtypes.fp32,
-        )
-    if max_logits is None:
-        max_logits = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            device=query.device,
-            dtype=aiter.dtypes.fp32,
-        )
-    if temporary_output is None:
-        temporary_output = torch.empty(
-            batch_size,
-            num_kv_heads,
-            max_context_partition_num,
-            equivalent_query_group_size,
-            head_size,
-            device=query.device,
-            dtype=query.dtype,
-        )
+alloc_context_partition_num = max(1, max_context_partition_num)
+if exp_sums is None:
+    exp_sums = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        device=query.device,
+        dtype=aiter.dtypes.fp32,
+    )
+if max_logits is None:
+    max_logits = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        device=query.device,
+        dtype=aiter.dtypes.fp32,
+    )
+if temporary_output is None:
+    temporary_output = torch.empty(
+        batch_size,
+        num_kv_heads,
+        alloc_context_partition_num,
+        equivalent_query_group_size,
+        head_size,
+        device=query.device,
+        dtype=query.dtype,
+    )

@@ -1648,24 +1651,91 @@ def run_pa_gluon_test(
or True

Copilot AI Apr 25, 2026


skip_assembly is forced to True due to the trailing or True, which makes the new persistent-scheduling assembly comparison path unreachable and silently disables the intended coverage/perf reporting for ASM. Remove the unconditional or True (or gate it behind an explicit flag) so the skip logic reflects the actual supported cases.

Suggested change

-        or True
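
For the "gate it behind an explicit flag" option, a minimal sketch of what the gating could look like; the AITER_SKIP_ASM_COMPARE variable and the ps_supported predicate are hypothetical illustrations, not names from this PR:

import os

def should_skip_assembly(ps_supported: bool) -> bool:
    # Hypothetical opt-out: skip the ASM comparison only when explicitly
    # requested, instead of hard-coding `or True` in the test harness.
    force_skip = os.environ.get("AITER_SKIP_ASM_COMPARE", "0") == "1"
    return not ps_supported or force_skip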

Comment on lines 5224 to +5229

     max_context_partition_num: int,
-    context_partition_size: int,
-    compute_type: torch.dtype,
-    query_scale: torch.Tensor,  # [num_seqs * query_length, num_query_heads, 1] or [1]
-    key_scale: torch.Tensor,  # [num_blocks, num_kv_heads, kv_block_size, 1]
-    value_scale: torch.Tensor,  # [num_blocks, num_kv_heads, kv_block_size, 1]
+    context_partition_size: int = 256,
+    compute_type: torch.dtype = torch.bfloat16,
+    query_scale: torch.Tensor = None,  # [num_seqs * query_length, num_query_heads, 1] or [1]
+    key_scale: torch.Tensor = None,  # [num_blocks, num_kv_heads, kv_block_size, 1]
+    value_scale: torch.Tensor = None,  # [num_blocks, num_kv_heads, kv_block_size, 1]

Copilot AI Apr 25, 2026


pa_decode_gluon now defaults query_scale, key_scale, and value_scale to None, but the kernel logic still assumes scales are provided when the corresponding tensors are FP8. Without scales, QUERY_QUANT_MODE/KV_QUANT_MODE remain -1, which can silently produce incorrect results (and may not match the FP8 cache layout assumptions). Add explicit validation that query_scale is present for FP8 query, and key_scale/value_scale are present for FP8 KV caches.
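
A minimal sketch of that validation, assuming hypothetical names for the cache arguments (key_cache / value_cache) and treating torch's float8 dtypes as the FP8 cases:

import torch

def _is_fp8(t: torch.Tensor) -> bool:
    # Hypothetical helper: treat torch float8 dtypes as quantized inputs.
    return t.dtype in (torch.float8_e4m3fnuz, torch.float8_e5m2)

def _validate_scales(query, key_cache, value_cache, query_scale, key_scale, value_scale):
    # Raise early instead of letting QUERY_QUANT_MODE / KV_QUANT_MODE
    # silently stay at -1 for quantized inputs.
    if _is_fp8(query) and query_scale is None:
        raise ValueError("query_scale is required when query is FP8")
    if (_is_fp8(key_cache) or _is_fp8(value_cache)) and (
        key_scale is None or value_scale is None
    ):
        raise ValueError("key_scale and value_scale are required for FP8 KV caches")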
