From 74dcf38df9086d299f4909db135f5b30030b9bdc Mon Sep 17 00:00:00 2001 From: Peng Date: Fri, 24 Apr 2026 23:09:17 +0000 Subject: [PATCH 1/4] =?UTF-8?q?ci(release):=20cherry-pick=20#2875=20?= =?UTF-8?q?=E2=80=94=20torch=5Fpin=20+=20torch=5Findex=5Furl=20workflow=20?= =?UTF-8?q?inputs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Brings the manylinux + torch ABI pin workflow controls onto release/v0.1.13 so v0.1.13 release wheel builds can dispatch torch_pin=2.10.0 directly, matching the post2 build path. Cherry-pick of: 7c4cc6c41c9b669900831a976fd737e6daa91dbe (#2875) --- .github/workflows/aiter-release.yaml | 41 ++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/.github/workflows/aiter-release.yaml b/.github/workflows/aiter-release.yaml index 1807fbf440..0255edb264 100644 --- a/.github/workflows/aiter-release.yaml +++ b/.github/workflows/aiter-release.yaml @@ -67,6 +67,16 @@ on: description: 'Use pytorch/manylinux2_28-builder ROCm image (AlmaLinux 8 + devtoolset, glibc 2.28). Produces wheels ABI-compatible with vLLM/Ubuntu 22 containers.' type: boolean default: false + torch_pin: + description: 'Optional torch version pin for the manylinux build (e.g. 2.10.0+rocm7.1). Empty = latest available for the detected ROCm flavor.' + type: string + required: false + default: '' + torch_index_url: + description: 'Optional override for the torch wheel index URL. Empty = auto-derive from the manylinux builder image tag (https://download.pytorch.org/whl/rocmX.Y).' + type: string + required: false + default: '' workflow_call: inputs: release_type: @@ -111,6 +121,16 @@ on: type: boolean required: false default: false + torch_pin: + description: 'Optional torch version pin for the manylinux build (e.g. 2.10.0+rocm7.1). Empty = latest.' + type: string + required: false + default: '' + torch_index_url: + description: 'Optional torch index URL override. Empty = auto-derive from builder image tag.' 
+ type: string + required: false + default: '' outputs: wheel_names: description: 'Space-separated list of built wheel filenames' @@ -145,6 +165,8 @@ jobs: RELEASE_TYPE: ${{ inputs.release_type || github.event.inputs.release_type }} ADD_DATE_STAMP: ${{ inputs.add_date_stamp || github.event.inputs.add_date_stamp }} USE_MANYLINUX: ${{ inputs.use_manylinux || github.event.inputs.use_manylinux || (startsWith(matrix.docker_image, 'pytorch/manylinux') && 'true') || 'false' }} + TORCH_PIN: ${{ inputs.torch_pin || github.event.inputs.torch_pin }} + TORCH_INDEX_URL: ${{ inputs.torch_index_url || github.event.inputs.torch_index_url }} steps: - name: Checkout aiter repo @@ -301,17 +323,32 @@ jobs: IMG="${BUILD_DOCKER_IMAGE}" ROCM_TAG="${IMG##*:}" # rocm7.2 / rocm7.1 / rocm7.0 ROCM_NUM="${ROCM_TAG#rocm}" # 7.2 - TORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_NUM}" + # Allow caller to override the torch wheel index (e.g. pin a release + # to a specific ROCm flavor's PyTorch ABI). Defaults preserve the + # legacy auto-derived behavior. + if [ -n "${TORCH_INDEX_URL}" ]; then + TORCH_INDEX="${TORCH_INDEX_URL}" + else + TORCH_INDEX="https://download.pytorch.org/whl/rocm${ROCM_NUM}" + fi + # Optional torch version pin (e.g. 2.10.0+rocm7.1). Empty = latest. + if [ -n "${TORCH_PIN}" ]; then + TORCH_SPEC="torch==${TORCH_PIN}" + else + TORCH_SPEC="torch" + fi echo "Torch index: ${TORCH_INDEX}" + echo "Torch spec: ${TORCH_SPEC}" docker exec \ -w /workspace \ -e PYBIN="${PYBIN}" \ -e TORCH_INDEX="${TORCH_INDEX}" \ + -e TORCH_SPEC="${TORCH_SPEC}" \ aiter_build_${{ matrix.python_version }} \ bash -c ' set -e ${PYBIN}/pip install --upgrade --timeout=60 --retries=10 pip - ${PYBIN}/pip install --timeout=60 --retries=10 --index-url "${TORCH_INDEX}" torch + ${PYBIN}/pip install --timeout=60 --retries=10 --index-url "${TORCH_INDEX}" "${TORCH_SPEC}" # flydsl publishes only manylinux_2_35 wheels which cannot install # on AlmaLinux 8 (glibc 2.28). 
FlyDSL AOT pre-compilation in # setup.py is wrapped in try/except and is skipped gracefully when From 54688d611fe0f648f0a096777e8a86bf2f9dbf28 Mon Sep 17 00:00:00 2001 From: Jeff Huang Date: Fri, 24 Apr 2026 20:57:10 +0800 Subject: [PATCH 2/4] fix(fmha): support >4GB KV cache in batch prefill via runtime dispatch (#2893) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test: expand test_batch_prefill_large_kvcache for >4GB KV cache overflow Rewrite test_batch_prefill_large_kvcache to validate the per-tile SRD rebase fix for >4GB KV caches across all page sizes, dtypes, and attention configurations: - Add page_size=1 and 16 (page_size < kN0, exercises rebase path) - Add GQA (16, 8) in addition to MHA (8, 8) - Add causal masking with CK-compatible attn_mask for SDPA reference - Use full KV cache (4.5GB) with pages spanning the overflow boundary - Use torch SDPA as reference (memory-efficient backend, no score matrix materialization) - Add scatter_pages parameter (False only; True for future global_load_lds flat addressing) - Add GPU memory check to skip configs that exceed HBM capacity Test matrix: 24 cases (3 page_sizes × 2 dtypes × 2 causal × 2 GQA × 1 scatter) * test: add GPU sync after CK kernel in large_kvcache test Add torch.cuda.synchronize() after CK kernel launch in test_batch_prefill_large_kvcache to ensure all async GPU work completes before memory is freed between tests. Without this sync, repeated allocate/free cycles of large KV cache buffers (~20GB) with mixed dtype (bf16→fp8) can trigger GPU page faults when the HIP memory allocator reuses virtual addresses that are still referenced by pending async GPU work. The fault manifests as VM_L2_PROTECTION_FAULT at address 0x0 (NULL), causing GPU reset and kernel soft lockup. * feat(fmha): runtime dispatch for >4GB KV cache in batch prefill Add use_64bit_load to batch prefill traits and runtime overflow detection. 
When page_block_size < 128 and max_page_byte_offset > INT32_MAX, dispatch to the flat 64-bit load kernel variant for correctness. Also add vectorized KV layout coverage to test_batch_prefill_large_kvcache. * fix: remove unused k_vector_size variable in large_kvcache test * fix(mha): improve batch_prefill TORCH_CHECK error message for >4GB KV cache Include page_size, num_pages, and dtype in the error message when kernel dispatch fails. Add hint about CDNA3+ GPU requirement when KV cache exceeds 4GB with page_size < 128. * test: update scatter_pages comment in large_kvcache test The comment incorrectly stated scatter_pages=True was "expected to FAIL". This is no longer true — the flat 64-bit load path handles scattered pages correctly. Update to describe the test's purpose instead. * fix(mha): widen batch_prefill 64-bit threshold to total KV bytes The previous check used (num_total_pages - 1) * batch_stride * element_size which measures the last-page base offset, missing within-page offsets and producing an off-by-one at exactly INT32_MAX (the largest representable SRD voffset). Switch to total KV cache footprint (num_total_pages * batch_stride * element_size > INT32_MAX) so within-page reads on the last page are covered, and drop the redundant num_total_pages > 1 guard since single-page configs trivially fit in 32 bits. Also unify wording: 4GB → 2GB (INT32_MAX byte offset for SRD voffset), matching CK's TwoGB convention. The actual hardware bound has always been 2GB; the prior comments were imprecise. Found during batch prefill template dispatch review. * docs(mha): unify >2GB wording in batch_prefill error and test The 4GB number in the TORCH_CHECK error message and the test comment was imprecise — the actual SRD voffset bound is 2GB (INT32_MAX). Update both to match the threshold check and CK's TwoGB convention. Found during batch prefill template dispatch review. 
* refactor(mha): drop wrapper-side use_64bit_load; let CK dispatcher decide The wrapper hardcoded kN0_min = 128 to compute the >2GB KV cache predicate, which leaked CK tile config into aiter and would silently break if a new arm with bn0 != 128 were added. The CK auto-generated dispatcher now decides per-arm using its own compile-time bn0 and per-dtype kElementBytes, so the wrapper just forwards args. Remove the `use_64bit_load` runtime field from `mha_batch_prefill_traits`, the parameter from `get_mha_batch_prefill_traits()`, and the entire predicate computation block from the dispatcher call site. Bumps CK submodule to pull in the matching codegen change. * chore(mha): bump CK + update wrapper wording for kUseGlobalLoad rename Bumps 3rdparty/composable_kernel to dd8d293ea (refactor(fmha): batch prefill review polish — assert helper + setter guards) which builds on the prior 99a3ca9af kUseGlobalLoad rename. Wrapper-side updates to match: * csrc/cpp_itfs/mha_fwd_batch_prefill.cu: rename "64-bit-load" wording in the per-arm dispatcher comment to "kUseGlobalLoad" so the wrapper comment matches the CK-side identifier. Also drops the trailing `false /* skip_min_seqlen_q */` argument from the get_mha_batch_prefill_traits call to match the upstream CK API signature change. * csrc/py_itfs_ck/mha_batch_prefill_kernels.cu: change the >2GB error message from "page_size < 128" to "page_size < kN0" so the diagnostic tracks the tile-size constant rather than a magic number. 
* op_tests/test_batch_prefill.py (test_batch_prefill_large_kvcache): three documentation enhancements with no behavior change — - explain why qo_len caps at 128 (causal) / 1024 (non-causal): the causal cap is a math-backend cliff for the SDPA reference, not a kernel limit; - explain that the +256 padding on kv_page_indices is a batch_prefill ABI requirement (kernel may speculatively read up to bn0=256 entries past the last valid page index); - expand the torch.cuda.synchronize comment to call out the misattribution failure mode and GPU-reset cascade risk. * test(fmha): parametrize test_batch_prefill_large_kvcache over batch_size {1, 4} Adds multi-batch coverage to the >2GB KV cache regression test. The previous single-batch coverage left the kernel's per-sequence SRD rebase path unexercised: with cu_seqlens_q=[0, qo_len] and kv_indptr= [0, num_blocks], the kernel never walks the indptr to reposition K/V SRDs across batch boundaries. After the kUseGlobalLoad rename and the new positive static_assert(kUseGlobalLoad_) calls in update_physical_pages and set_page_stride_elements, we want a regression that catches any boundary-crossing SRD bug -- the failure mode no single-batch test can detect (one batch correct, others wrong). batch_size=4 partitions the >2GB page pool across 4 sequences (last sequence absorbs the remainder), exercising 3 cross-batch SRD transitions. The SDPA reference is computed per-batch and concatenated; per-iteration free + empty_cache keeps peak memory at one batch's worth. Verified on: - gfx950 (smci355-gfx950, MI355X): 160 passed, 32 skipped - gfx942 (smc300x-clt, MI308X): 160 passed, 32 skipped Skips are the existing vectorized + page_size=1 incompatibility (3D tensor layout), now 16 per batch_size value. 
--------- Co-authored-by: Xin Huang --- 3rdparty/composable_kernel | 2 +- csrc/cpp_itfs/mha_fwd_batch_prefill.cu | 8 +- csrc/py_itfs_ck/mha_batch_prefill_kernels.cu | 8 +- op_tests/test_batch_prefill.py | 358 ++++++++++++++----- 4 files changed, 287 insertions(+), 89 deletions(-) diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 5348b577ed..fdf4bb7fcc 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 5348b577ed7a5d88d350d88fd720b882176466ae +Subproject commit fdf4bb7fcc984811cef48ce817d89aac064b984a diff --git a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu index 2c5da43ef2..7994e7b2d9 100644 --- a/csrc/cpp_itfs/mha_fwd_batch_prefill.cu +++ b/csrc/cpp_itfs/mha_fwd_batch_prefill.cu @@ -47,7 +47,13 @@ float mha_batch_prefill(mha_batch_prefill_args args, int head_size_q = args.hdim_q; int head_size_v = args.hdim_v; bool has_dropout = args.p_drop > 0.f; - auto traits = get_mha_batch_prefill_traits(head_size_q, + + // The kUseGlobalLoad decision (>2GB KV cache → use `global_load_lds_*` + // instead of SRD `buffer_load_*`) is made per-arm inside the auto-generated + // dispatcher in fmha_batch_prefill_api.cpp, where each arm knows its own + // compile-time bn0 and dtype element size. The wrapper just forwards args; + // no runtime trait field for it. + auto traits = get_mha_batch_prefill_traits(head_size_q, head_size_v, q_dtype_str, is_group_mode, diff --git a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu index cd8dd1c531..15a4878ed9 100644 --- a/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu +++ b/csrc/py_itfs_ck/mha_batch_prefill_kernels.cu @@ -817,7 +817,13 @@ mha_batch_prefill(at::Tensor& q, // [total_q, hq, d] has_lse, qscale_type, false); - TORCH_CHECK(t >= 0, "invalid argument for batch_prefill"); + TORCH_CHECK(t >= 0, + "invalid argument for batch_prefill: no matching kernel found. 
" + "page_size=", args.page_block_size, + ", num_pages=", args.num_total_pages, + ", dtype=", dtype_str, + ". If KV cache exceeds 2GB (INT32_MAX byte offset) with page_size < kN0, " + "CDNA3+ GPU (MI300/MI350) is required."); } else { diff --git a/op_tests/test_batch_prefill.py b/op_tests/test_batch_prefill.py index ab99206988..ee5489fec4 100644 --- a/op_tests/test_batch_prefill.py +++ b/op_tests/test_batch_prefill.py @@ -1705,133 +1705,319 @@ def reference_attention_kv_blockscale( return output.to(torch.bfloat16) -@pytest.mark.parametrize( - "num_blocks,page_size", - [ - (5000, 1024), # ~10GB KV cache - (10000, 1024), # ~20GB KV cache - ], -) -@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("kv_cache_size_gb", [4.5]) +@pytest.mark.parametrize("page_size", [1, 16, 1024]) +@pytest.mark.parametrize("num_qo_heads,num_kv_heads", [(8, 8), (16, 8)]) @pytest.mark.parametrize("head_dim", [128]) -@pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("input_dtype", ["bf16", "fp8"]) +# scatter_pages=True: adjacent logical tokens map to physically distant pages, +# stress-testing the paged KV cache addressing when pages span large physical distances. +@pytest.mark.parametrize("scatter_pages", [False, True]) +@pytest.mark.parametrize("kv_layout", ["linear", "vectorized"]) def test_batch_prefill_large_kvcache( - num_blocks, + batch_size, + kv_cache_size_gb, page_size, + num_qo_heads, num_kv_heads, head_dim, causal, input_dtype, + scatter_pages, + kv_layout, ): """ Test that batch prefill produces correct results with large KV caches - whose element offsets exceed the INT32_MAX boundary (~4GB for bf16). + whose element offsets exceed the INT32_MAX boundary. + + Uses the full KV cache for attention with pages spanning the overflow + boundary, and compares kernel output against SDPA reference. 
+ For page_size < kN0 (128), this validates the per-tile SRD rebase path. + + Args: + batch_size: Number of sequences. >1 partitions the >2GB page pool + across batches, exercising the per-sequence SRD rebase path. + scatter_pages: If True, interleave page indices so adjacent logical + tokens map to physically distant pages (stress-tests rebase). + kv_layout: "linear" or "vectorized" KV cache memory layout. """ + # page_size=1 only supports linear layout (3D tensor) + if page_size == 1 and kv_layout == "vectorized": + pytest.skip("page_size=1 does not support vectorized layout") + torch.manual_seed(42) + torch.cuda.empty_cache() is_fp8 = input_dtype == "fp8" dtype = torch.bfloat16 - num_qo_heads = num_kv_heads # MHA (no GQA) for simplicity - stride_per_page = page_size * num_kv_heads * head_dim # elements per page block + # Compute num_blocks from target KV cache size + elem_size = 1 if is_fp8 else 2 # fp8=1 byte, bf16=2 bytes + elements_per_block = page_size * num_kv_heads * head_dim + target_bytes = int(kv_cache_size_gb * 1024**3) + num_blocks = target_bytes // (elements_per_block * elem_size) + + # Verify this config triggers overflow + stride_per_page = elements_per_block max_offset = (num_blocks - 1) * stride_per_page INT32_MAX = 2**31 - 1 - if max_offset <= INT32_MAX: pytest.skip( f"max_offset {max_offset} doesn't exceed INT32_MAX, not an overflow test" ) - # Check available GPU memory -- skip if not enough + # Check available GPU memory free_mem = torch.cuda.mem_get_info()[0] - elem_size = 1 if is_fp8 else 2 # fp8=1 byte, bf16=2 bytes - required_mem = 2 * num_blocks * page_size * num_kv_heads * head_dim * elem_size - if free_mem < required_mem * 1.1: # 10% headroom + # Per-batch page partition: uniform split, remainder absorbed by the last + # sequence to keep all kv_indptr deltas > 0 (zero-length sequences would be + # skipped by the kernel's per-batch dispatch and hide any rebase bug). 
+ blocks_per_seq = [num_blocks // batch_size] * batch_size + blocks_per_seq[-1] += num_blocks % batch_size + kv_lens_per_seq = [bps * page_size for bps in blocks_per_seq] + max_kv_len_per_seq = max(kv_lens_per_seq) + # Causal with attn_mask forces SDPA math backend which materializes + # [H_q, qo_len, kv_len] score + mask tensors. Magnitudes empirically chosen: + # non-causal: 1024 -- flash backend, no full score matrix, headroom is large + # causal: 128 -- math backend cliff: 3x [H_q, qo, kv] fp32 buffers must + # fit alongside K/V cache (kv_len up to ~5GB at this scale) + # qo_len is per-batch; total qo tokens = batch_size * qo_len. + qo_len = min(128, max_kv_len_per_seq) if causal else min(1024, max_kv_len_per_seq) + total_qo_len = batch_size * qo_len + # SDPA causal with attn_mask forces math backend: expanded mask + score matrix + # + softmax intermediates, each [1, H_q, qo, kv_per_batch] fp32. ~3x overhead. + # The per-batch SDPA loop allocates one batch's worth at a time (kv_len + # divided by batch_size), then frees before the next iteration. 
+ sdpa_causal_mem = ( + 3 * num_qo_heads * qo_len * max_kv_len_per_seq * 4 if causal else 0 + ) + # GQA expands K/V from H_kv to H_q heads for SDPA reference + gqa_ratio = num_qo_heads // num_kv_heads + # Sequential pages reuse K/V directly; scattered need a gathered copy + gathered_mem = 2 * num_blocks * elements_per_block * 2 if scatter_pages else 0 + required_mem = ( + 2 * num_blocks * elements_per_block * 2 # K/V bf16 + + 2 * num_blocks * elements_per_block * elem_size # kernel K/V (fp8 or bf16) + + gathered_mem + + 2 * num_blocks * elements_per_block * 2 * (gqa_ratio - 1) # GQA K/V expansion + + sdpa_causal_mem + ) + if free_mem < required_mem * 1.1: pytest.skip( f"Not enough GPU memory: need {required_mem / 1e9:.1f}GB, " f"have {free_mem / 1e9:.1f}GB" ) - # Allocate KV caches in linear layout: [num_blocks, page_size, num_kv_heads, head_dim] - k_cache_bf16 = torch.randn( - num_blocks, page_size, num_kv_heads, head_dim, device="cuda", dtype=dtype - ) - v_cache_bf16 = torch.randn( - num_blocks, page_size, num_kv_heads, head_dim, device="cuda", dtype=dtype + # Allocate KV caches in bf16 + # page_size=1 uses 3D linear layout [num_tokens, num_kv_heads, head_dim] + # page_size>1 uses 4D paged layout [num_blocks, page_size, num_kv_heads, head_dim] + if page_size == 1: + kv_shape = (num_blocks, num_kv_heads, head_dim) + else: + kv_shape = (num_blocks, page_size, num_kv_heads, head_dim) + + k_cache_bf16 = torch.randn(*kv_shape, device="cuda", dtype=dtype) + if scatter_pages: + # Use page-dependent V values to detect address wrapping bugs. + # With random V, wrong addresses read statistically similar data -> false pass. + # With V[page] ∝ page_index, wrapped addresses (low pages) give ~0 instead of + # the correct ~1 for high pages, making the error detectable.
+ page_vals = ( + torch.arange(num_blocks, device="cuda", dtype=torch.float32) / num_blocks + ) + if page_size == 1: + v_cache_bf16 = page_vals.view(-1, 1, 1).expand(*kv_shape).to(dtype) + else: + v_cache_bf16 = page_vals.view(-1, 1, 1, 1).expand(*kv_shape).to(dtype) + else: + v_cache_bf16 = torch.randn(*kv_shape, device="cuda", dtype=dtype) + + # Query: flat [total_qo_len, H_q, D] layout matching mha_batch_prefill_func + # input contract. Per-batch slices recovered via cu_seqlens_q in the loop below. + q_bf16 = torch.randn( + total_qo_len, num_qo_heads, head_dim, device="cuda", dtype=dtype ) - if is_fp8: - k_cache, k_descale = per_tensor_quant(k_cache_bf16, quant_dtype=dtypes.fp8) - v_cache, v_descale = per_tensor_quant(v_cache_bf16, quant_dtype=dtypes.fp8) + # Page indices: since the buffer exceeds INT32_MAX elements, these pages + # naturally span the overflow boundary. + overflow_page = INT32_MAX // stride_per_page + + if scatter_pages: + # Interleave: [0, N-1, 1, N-2, 2, N-3, ...] so adjacent logical tokens + # map to physically distant pages (low <-> high, spanning >2GB gap). + lo = torch.arange(0, num_blocks, 2, dtype=torch.int32) + hi = torch.arange(num_blocks - 1, -1, -2, dtype=torch.int32) + page_indices = torch.zeros(num_blocks, dtype=torch.int32) + page_indices[0::2] = lo[: (num_blocks + 1) // 2] + page_indices[1::2] = hi[: num_blocks // 2] else: - k_cache = k_cache_bf16 - v_cache = v_cache_bf16 + # Sequential: [0, 1, 2, ..., N-1] + page_indices = torch.arange(num_blocks, dtype=torch.int32) - # Test pages that span the overflow boundary - qo_len = 1 - kv_len = page_size # one full page + # --- Step 1: Compute SDPA reference FIRST (while bf16 data is alive) --- + # Per-batch loop: each iteration gathers its slice of pages, runs SDPA, + # and frees intermediates before the next batch. Keeps peak memory at + # one batch's worth (vs. materializing the full multi-batch score tensor). 
+ o_ref_list = [] + page_offset = 0 + for b in range(batch_size): + n_blocks_b = blocks_per_seq[b] + page_slice_b = page_indices[page_offset : page_offset + n_blocks_b] + page_offset += n_blocks_b + kv_len_b = kv_lens_per_seq[b] + + # Always gather: even sequential pages need a per-batch slice to keep + # the multi-batch SDPA references aligned with the kernel's per-batch + # SRD rebase. (For batch_size=1 + sequential, this is just an alias + # of the full cache via the index slice.) + if page_size == 1: + k_ref_b = k_cache_bf16[page_slice_b.long()] + v_ref_b = v_cache_bf16[page_slice_b.long()] + else: + k_ref_b = k_cache_bf16[page_slice_b.long()].reshape( + -1, num_kv_heads, head_dim + ) + v_ref_b = v_cache_bf16[page_slice_b.long()].reshape( + -1, num_kv_heads, head_dim + ) - q_bf16 = torch.randn(qo_len, num_qo_heads, head_dim, device="cuda", dtype=dtype) - if is_fp8: - q, q_descale = per_tensor_quant(q_bf16, quant_dtype=dtypes.fp8) - else: - q = q_bf16 - cu_seqlens_q = torch.tensor([0, qo_len], device="cuda", dtype=torch.int32) + q_b = q_bf16[b * qo_len : (b + 1) * qo_len] - # Test at several page indices: before, at, and after the overflow boundary - overflow_page = INT32_MAX // stride_per_page - test_pages = [ - 0, - overflow_page - 1, - overflow_page, - overflow_page + 1, - num_blocks - 1, - ] - test_pages = [p for p in test_pages if 0 <= p < num_blocks] - # Remove duplicates while preserving order - test_pages = list(dict.fromkeys(test_pages)) - - threshold = 0.055 if is_fp8 else 0.01 - - for page_idx in test_pages: - offset = page_idx * stride_per_page - label = "OVERFLOW" if offset > INT32_MAX else "safe" - - kv_indptr = torch.tensor([0, 1], device="cuda", dtype=torch.int32) - kv_page_indices = torch.tensor([page_idx], device="cuda", dtype=torch.int32) - kv_last_page_lens = torch.tensor([page_size], device="cuda", dtype=torch.int32) - - extra_kwargs = {} - if is_fp8: - extra_kwargs = dict( - q_descale=q_descale, k_descale=k_descale, v_descale=v_descale + # 
SDPA expects [batch, heads, seq, dim] + q_sdpa = q_b.unsqueeze(0).transpose(1, 2) + k_sdpa = k_ref_b.unsqueeze(0).transpose(1, 2) + v_sdpa = v_ref_b.unsqueeze(0).transpose(1, 2) + del k_ref_b, v_ref_b + + # GQA: expand K/V heads manually via repeat_interleave rather than + # passing enable_gqa=True, because enable_gqa=True combined with a + # causal attn_mask forces SDPA math backend and OOMs for large kv_len. + if num_qo_heads != num_kv_heads: + ratio = num_qo_heads // num_kv_heads + k_sdpa = k_sdpa.repeat_interleave(ratio, dim=1) + v_sdpa = v_sdpa.repeat_interleave(ratio, dim=1) + + sdpa_kwargs = {} + if causal: + # CK batch prefill causal: Q is at the END of the KV context. + # Q[i] can see K[j] where j <= (kv_len_b - qo_len) + i. + offset = kv_len_b - qo_len + row_idx = torch.arange(qo_len, device="cuda").unsqueeze(1) + col_idx = torch.arange(kv_len_b, device="cuda").unsqueeze(0) + sdpa_kwargs["attn_mask"] = col_idx <= (offset + row_idx) + + o_b = ( + torch.nn.functional.scaled_dot_product_attention( + q_sdpa, k_sdpa, v_sdpa, **sdpa_kwargs ) + .squeeze(0) + .transpose(0, 1) + ) + o_ref_list.append(o_b) + del q_sdpa, k_sdpa, v_sdpa, sdpa_kwargs + torch.cuda.empty_cache() - result = aiter.mha_batch_prefill_func( - q, - k_cache, - v_cache, - cu_seqlens_q, - kv_indptr, - kv_page_indices, - qo_len, - kv_len, - causal=causal, - kv_last_page_lens=kv_last_page_lens, - **extra_kwargs, + o_ref = torch.cat(o_ref_list, dim=0) + del o_ref_list + torch.cuda.empty_cache() + + # --- Step 2: Prepare kernel inputs (quantize for FP8, free bf16 after) --- + if is_fp8: + k_cache_kernel, k_descale = per_tensor_quant( + k_cache_bf16, quant_dtype=dtypes.fp8 + ) + v_cache_kernel, v_descale = per_tensor_quant( + v_cache_bf16, quant_dtype=dtypes.fp8 + ) + q_kernel, q_descale = per_tensor_quant(q_bf16, quant_dtype=dtypes.fp8) + del k_cache_bf16, v_cache_bf16, q_bf16 + torch.cuda.empty_cache() + else: + k_cache_kernel = k_cache_bf16 + v_cache_kernel = v_cache_bf16 + q_kernel = q_bf16 + + #
Apply vectorized layout transformation if needed + if kv_layout == "vectorized" and page_size > 1: + kv_vector_size = 16 // k_cache_kernel.element_size() + k_cache_kernel, v_cache_kernel = apply_kv_layout( + k_cache_kernel, + v_cache_kernel, + num_kv_heads, + head_dim, + page_size, + kv_vector_size, + "vectorized", ) - out = result[0] if isinstance(result, (list, tuple)) else result - # Reference: direct attention on the original bf16 data - k_page = k_cache_bf16[page_idx] # [page_size, num_kv_heads, head_dim] - v_page = v_cache_bf16[page_idx] - o_ref = ref_masked_attention(q_bf16, k_page, v_page, causal=causal) + # Multi-batch indptrs: cu_seqlens_q is the cumulative qo offset per batch + # (uniform qo_len), kv_indptr is the cumulative page count per batch. + cu_seqlens_q = torch.tensor( + [0] + [(i + 1) * qo_len for i in range(batch_size)], + device="cuda", + dtype=torch.int32, + ) + kv_indptr = torch.tensor( + [0] + list(itertools.accumulate(blocks_per_seq)), + device="cuda", + dtype=torch.int32, + ) + # +256 padding is a batch_prefill ABI requirement: the kernel may speculatively + # read up to 256 entries past the last valid page index (one bn0=256 tile worth) + # before the bounds check kicks in. Padding with 0 keeps reads in-bounds; the + # values are masked out by causal/length logic and never affect the output. 
+ kv_page_indices = torch.nn.functional.pad(page_indices, (0, 256), value=0).to( + "cuda" + ) + kv_last_page_lens = torch.tensor( + [page_size] * batch_size, device="cuda", dtype=torch.int32 + ) + + # --- Step 3: Run CK kernel --- + extra_kwargs = {} + if is_fp8: + extra_kwargs = dict( + q_descale=q_descale, k_descale=k_descale, v_descale=v_descale + ) - max_diff = (out - o_ref).abs().max().item() - assert max_diff < threshold, ( - f"[{input_dtype}] page {page_idx} (offset={offset}, {label}): " - f"max_diff={max_diff} exceeds threshold {threshold}" + result = aiter.mha_batch_prefill_func( + q_kernel, + k_cache_kernel, + v_cache_kernel, + cu_seqlens_q, + kv_indptr, + kv_page_indices, + qo_len, + max_kv_len_per_seq, + causal=causal, + kv_last_page_lens=kv_last_page_lens, + **extra_kwargs, + ) + # Synchronize immediately to catch async GPU faults from CK kernel before + # they cascade. Without this sync, an async fault can surface inside the + # next test's torch.cuda.empty_cache() (or any other CUDA call), causing + # the failure to be misattributed to that unrelated test -- and on bad + # faults the cascade can trigger a GPU reset that wipes out subsequent + # test results too. 
+ torch.cuda.synchronize() + out = result[0] if isinstance(result, (list, tuple)) else result + + # Compare kernel output vs SDPA reference + if is_fp8: + verify_fp8_output(out, o_ref, threshold=0.055) + else: + rtol, atol = get_tolerances(dtype) + torch.testing.assert_close( + out, + o_ref, + rtol=rtol, + atol=atol, + msg=lambda msg: ( + f"[{input_dtype}] batch_size={batch_size} " + f"page_size={page_size} num_pages={num_blocks} " + f"(overflow at page {overflow_page}): {msg}" + ), ) From 33b7a4c62e695f9a588ef370d3eab7a7dc0f9538 Mon Sep 17 00:00:00 2001 From: la <46212055+junhaha666@users.noreply.github.com> Date: Fri, 24 Apr 2026 20:59:46 +0800 Subject: [PATCH 3/4] Fix top_k_per_row_prefill err when batched_token_num > 4096 (#2901) --- csrc/kernels/topk_per_row_kernels.cu | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/csrc/kernels/topk_per_row_kernels.cu b/csrc/kernels/topk_per_row_kernels.cu index 89331c52df..6edf377ca8 100644 --- a/csrc/kernels/topk_per_row_kernels.cu +++ b/csrc/kernels/topk_per_row_kernels.cu @@ -1435,6 +1435,18 @@ __global__ void radix_topk_one_block_kernel(T const* in, return; } + // Long-row path: kernel internally treats in[0..row_len) as the valid + // window. Shift `in` (and `in_idx`) up by `rowStart` so that the radix + // pipeline reads the actual valid columns rather than the masked-out + // [0, rowStart) prefix that fp8_mqa_logits fills with -inf. Internal + // indices i are then relative to rowStart; we add rowStart back to + // out_idx at the end of this branch to get absolute column indices. + in += rowStart; + if(in_idx) + { + in_idx += rowStart; + } + const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); @@ -1522,6 +1534,23 @@ __global__ void radix_topk_one_block_kernel(T const* in, break; } } + + // Long-row path was using rowStart-relative indices inside the radix + // pipeline (because we shifted `in` by rowStart above).
Translate them + // back to absolute column indices for downstream consumers. Sentinels + // (-1, written when fewer than k valid candidates exist) are preserved. + if(rowStart > 0) + { + __syncthreads(); + for(int i = threadIdx.x; i < k; i += BlockSize) + { + IdxT v = out_idx[i]; + if(v >= 0) + { + out_idx[i] = v + rowStart; + } + } + } } inline size_t calc_aligned_size(std::vector const& sizes) From 930c94120459bb352e1d7c68349b331b06397280 Mon Sep 17 00:00:00 2001 From: XiaobingZhang Date: Fri, 24 Apr 2026 22:57:00 +0800 Subject: [PATCH 4/4] revert gptoss tuned config (#2904) --------- Co-authored-by: zhuyuhua-v --- .../model_configs/gptoss_bf16_tuned_gemm.csv | 114 +++++++++--------- 1 file changed, 57 insertions(+), 57 deletions(-) diff --git a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv index 5587bed4c7..1deaa684f9 100644 --- a/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv +++ b/aiter/configs/model_configs/gptoss_bf16_tuned_gemm.csv @@ -1,58 +1,58 @@ gfx,cu_num,M,N,K,bias,dtype,outdtype,scaleAB,bpreshuffle,libtype,solidx,splitK,us,kernelName,err_ratio,tflops,bw -gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6262,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,0.16,160.67 -gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,9,4.6019,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.32,162.83 -gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7202,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,0.62,161.29 -gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,1.25,166.89 -gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,15,4.7651,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0273,2.48,174.93 
-gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,42,15,4.6463,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0237,5.08,200.11 -gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,39,15,4.7651,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0262,7.43,215.33 -gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,188,15,4.7406,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0243,9.95,236.74 -gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,239,15,4.8623,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0253,12.13,250.61 -gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,242,15,5.3055,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp2_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0247,13.34,247.82 -gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,240,15,5.2015,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k15_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0246,15.88,271.28 -gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,205,9,5.6343,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0177,16.75,267.53 
-gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,370,9,6.5478,flydsl_gemm2_abf16_wbf16_bf16_t32x64x64_split_k9_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.018,28.83,347.81 -gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,8.5347,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0141,1.73,1729.0 -gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,243,6,7.904,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k6_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0152,3.73,1868.34 -gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,7.9144,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0266,7.45,1868.63 -gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,245,15,8.2005,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k15_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0251,14.39,1808.75 -gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,268,9,8.614,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,27.39,1732.03 -gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0478,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0099,46.96,1502.2 -gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.5933,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,66.81,1441.27 
-gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7011,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.19,1443.02 -gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.7861,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,160.14,1369.26 -gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3912,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.26,1139.02 -gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,6.0968,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0212,1.93,1936.48 -gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,388,16,8.3638,flydsl_gemm2_abf16_wbf16_bf16_t16x192x64_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0316,2.82,2822.51 -gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,6.0853,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0224,3.88,1941.76 -gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,139,16,8.2195,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur16_gfx950,0.0269,5.74,2873.76 -gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,140,16,6.1308,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0224,7.7,1930.56 
-gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,138,16,8.2605,flydsl_gemm2_abf16_wbf16_bf16_t16x192x128_split_k16_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe2_ur8_gfx950,0.0292,11.42,2862.87 -gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,771,8,6.726,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_wpe4_ur8_gfx950,0.0161,14.03,1765.59 -gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,792,8,8.6182,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsTrue_b_preshuffleFalse_c_to_ldsFalse_small_m_ur16_gfx950,0.0195,21.9,2750.53 -gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,750,4,6.9944,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0103,26.98,1709.11 -gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,798,8,9.0841,flydsl_gemm2_abf16_wbf16_bf16_t16x64x256_split_k8_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0205,41.55,2621.74 -gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,145,4,8.0743,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0104,46.75,1500.05 -gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,164,4,10.4391,flydsl_gemm2_abf16_wbf16_bf16_t32x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0132,72.32,2302.83 
-gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,353,4,9.5624,flydsl_gemm2_abf16_wbf16_bf16_t64x64x256_split_k4_block_m_warp1_block_n_warp4_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_gfx950,0.0103,59.21,1283.11 -gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1652,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.09,1994.43 -gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,10.045,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0066,75.16,1237.16 -gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6691,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0136,119.18,1932.73 -gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,12.8662,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,146.7,1920.47 -gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.1474,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,172.27,1896.37 -gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.299,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,152.75,1454.16 -gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.5282,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,130.98,1132.7 -gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.0438,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,177.18,1489.04 -gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.42,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,195.84,928.64 -gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,10.4898,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0182,2.81,2812.94 
-gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,247,9,9.6817,flydsl_gemm2_abf16_wbf16_bf16_t16x64x160_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0194,6.09,3049.38 -gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,10.2067,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0201,11.56,2895.67 -gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.2498,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0189,20.97,2632.86 -gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,flydsl,271,9,11.9689,flydsl_gemm2_abf16_wbf16_bf16_t16x64x64_split_k9_block_m_warp1_block_n_warp2_async_copyTrue_b_to_ldsFalse_b_preshuffleFalse_c_to_ldsFalse_small_m_gfx950,0.0191,39.42,2485.37 -gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2241,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,77.2,2454.43 -gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,12.9649,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,109.19,2333.93 -gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.6686,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,138.09,2232.5 -gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,13.957,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,169.04,2204.71 -gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.3466,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,197.34,2162.69 
-gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3332,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,170.85,1618.11 -gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.5488,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,193.1,1613.36 +gfx950,256,1,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9558,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0234,0.15,149.99 +gfx950,256,2,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.9466,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0195,0.3,151.48 +gfx950,256,4,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9687,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0156,0.59,153.23 +gfx950,256,8,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,14,4.9927,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0176,1.18,157.31 +gfx950,256,16,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,5.031,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0171,2.34,165.68 +gfx950,256,32,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,13,4.6354,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0203,5.09,200.59 +gfx950,256,48,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,14,5.2547,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0212,6.73,195.26 +gfx950,256,64,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,13,5.3561,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.021,8.81,209.54 +gfx950,256,80,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,13,5.6419,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0218,10.45,215.98 +gfx950,256,96,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,9,5.7166,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0163,12.38,230.0 
+gfx950,256,112,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.9183,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0165,13.95,238.43 +gfx950,256,128,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,9,5.97,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0172,15.81,252.48 +gfx950,256,256,128,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,6,7.0187,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0126,26.89,324.47 +gfx950,256,1,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.6772,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0105,1.52,1524.87 +gfx950,256,2,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8371,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0107,3.0,1501.19 +gfx950,256,4,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.8551,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0118,5.98,1500.66 +gfx950,256,8,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.9035,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0101,11.91,1497.72 +gfx950,256,16,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.0897,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,23.38,1478.7 +gfx950,256,32,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,10.2307,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.01,46.12,1475.34 +gfx950,256,48,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.7334,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0073,65.94,1422.46 +gfx950,256,64,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,10.6753,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0074,88.4,1446.51 +gfx950,256,128,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,11.6504,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0074,162.01,1385.21 
+gfx950,256,256,2560,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,2,15.3742,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0041,245.53,1140.28 +gfx950,256,1,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.0917,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0083,1.3,1298.58 +gfx950,256,1,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.0871,auto,0.0,2.34,2340.31 +gfx950,256,2,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2607,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0056,2.55,1275.95 +gfx950,256,2,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,9.6663,auto,0.0,4.88,2443.63 +gfx950,256,4,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.1797,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0091,5.14,1289.36 +gfx950,256,4,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,triton,0,0,10.1637,auto,0.0,9.29,2326.79 +gfx950,256,8,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,4,9.2315,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0092,10.22,1286.39 +gfx950,256,8,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.4653,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0132,16.46,2067.51 +gfx950,256,16,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.3253,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0062,20.24,1281.91 +gfx950,256,16,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.591,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0136,32.57,2054.71 +gfx950,256,32,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,9.2671,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0065,40.73,1306.98 +gfx950,256,32,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,5,11.7821,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0137,64.08,2040.33 
+gfx950,256,48,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,4,9.8145,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.009,57.69,1250.15 +gfx950,256,48,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,5,12.1519,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0136,93.19,1996.61 +gfx950,256,64,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,4,10.1075,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.009,74.69,1229.51 +gfx950,256,64,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,5,12.6689,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0137,119.19,1932.76 +gfx950,256,80,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,5,13.1447,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0137,143.59,1879.78 +gfx950,256,96,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,5,13.5787,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0136,166.8,1836.14 +gfx950,256,112,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.4374,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,151.54,1442.62 +gfx950,256,128,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,11.3635,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0035,132.88,1149.12 +gfx950,256,128,2880,4096,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,2,17.3233,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0048,174.33,1465.01 +gfx950,256,256,2880,2048,True,torch.bfloat16,torch.bfloat16,False,False,asm,6,1,15.5453,_ZN5aiter37bf16gemm_fp32bf16_tn_64x64_pf3_splitkE,0.0,194.26,921.15 +gfx950,256,1,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.5513,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,2.55,2554.45 +gfx950,256,2,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.4061,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,5.17,2588.37 
+gfx950,256,4,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.0536,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0075,9.79,2451.98 +gfx950,256,8,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,11.96,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,19.73,2476.52 +gfx950,256,16,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.2044,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0077,38.66,2437.42 +gfx950,256,32,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,1,3,12.3457,_ZN5aiter39bf16gemm_fp32bf16_tn_32x64_splitk_cleanE,0.0078,76.44,2430.26 +gfx950,256,48,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,3,3,13.0203,_ZN5aiter39bf16gemm_fp32bf16_tn_48x64_splitk_cleanE,0.0076,108.72,2324.0 +gfx950,256,64,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,5,3,13.4435,_ZN5aiter39bf16gemm_fp32bf16_tn_64x64_splitk_cleanE,0.0077,140.4,2269.89 +gfx950,256,80,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,7,3,14.2374,_ZN5aiter39bf16gemm_fp32bf16_tn_80x64_splitk_cleanE,0.0076,165.71,2161.29 +gfx950,256,96,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,8,3,14.4747,_ZN5aiter39bf16gemm_fp32bf16_tn_96x64_splitk_cleanE,0.0076,195.59,2143.55 +gfx950,256,112,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.1801,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,172.21,1631.02 +gfx950,256,128,5120,2880,True,torch.bfloat16,torch.bfloat16,False,False,asm,4,1,19.3026,_ZN5aiter37bf16gemm_fp32bf16_tn_48x64_pf3_splitkE,0.0,195.56,1633.94