From faebb77f03725194027c8bc40f9eb937c144cdcf Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Fri, 17 Apr 2026 16:55:38 +0800 Subject: [PATCH 1/9] Add subgroup topk kernel for XPU Add an optimized topk kernel path where each 32-lane sub-group processes one slice entirely in registers via insertion sort + bitonic merge. Zero SLM, zero barriers. Output is already sorted. Constraints: k <= 16, large enough batch (nsegments >= HW_threads/4). Compile-time template dispatch on largest (direction) and IndexT (int32/int64). Kernel isolated in a separate translation unit to avoid SYCL compiler interference with the original kernel. 432/432 accuracy tests pass, 324/324 sortedness tests pass. --- src/ATen/native/xpu/sycl/TensorTopKKernel.cpp | 32 +- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 367 ++++++++++++++++++ .../native/xpu/sycl/TensorTopKSbtopkKernel.h | 54 +++ 3 files changed, 442 insertions(+), 11 deletions(-) create mode 100644 src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp create mode 100644 src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h diff --git a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp index 3042f685fd..9a108d51a9 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp @@ -16,6 +16,7 @@ #include #include +#include namespace at { namespace native { @@ -113,17 +114,26 @@ void topk_kernel( const scalar_t* self_ptr = self_.const_data_ptr(); scalar_t* values_ptr = values_.data_ptr(); int64_t* indices_ptr = indices_.data_ptr(); - segmented_group_select_pairs( - self_ptr, - values_ptr, - nullptr, - (int64_t*)indices_ptr, - nsegments, - nelements, - k, - largest); - - if (sorted) { + + SbtopkResult sbtopk_result = sbtopk_try_launch( + self_, nsegments, nelements, k, largest, values_, indices_); + if (sbtopk_result == SbtopkResult::FAILED) { + segmented_group_select_pairs( + self_ptr, + values_ptr, + nullptr, + (int64_t*)indices_ptr, + nsegments, + nelements, + k, + largest); + } + + // Only sort if the user asked for sorted output AND sbtopk did not + // already produce a sorted result. The subgroup topk kernel returns + // SORTED; the single workgroup kernel and the original radix select + // return UNSORTED (or FAILED). + if (sorted && sbtopk_result != SbtopkResult::SORTED) { segmented_sort_pairs( values_ptr, values_ptr, diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp new file mode 100644 index 0000000000..c14a2a19ef --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -0,0 +1,367 @@ +/* + * Subgroup top-k kernel for optimized topk on XPU. + * + * Each sub-group (32 lanes) handles one slice entirely in registers. + * Zero SLM, zero barriers. + * + * Algorithm: + * Phase 1: Each lane scans dim/32 elements, maintains a sorted top-k + * buffer via insertion sort (fully unrolled, no branches on + * direction thanks to compile-time Largest template param). + * Phase 2: 5 levels of pairwise bitonic merge via sub-group shuffles + * to combine 32 per-lane buffers into one global top-k. + * Phase 3: Lane 0 writes k results. Output is already sorted. + * + * Dispatch: k <= 16 and enough segments (large batch) and dim >= 1024 + * routes to subgroup top-k; otherwise falls back to original. + */ + +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +static constexpr int SG_SIZE = 32; + +// ================================================================ +// SubgroupTopKFunctor +// +// K: compile-time max top-k (must be >= runtime k) +// VEC_SIZE: vectorized load width +// Largest: compile-time direction flag. Eliminates per-element branches +// on largest_ that otherwise pessimize the tight insert/merge loops. +// IndexT: int32 for nsegments <= INT_MAX (common), int64_t for huge batch. +// Mirrors CUDA's canUse32BitIndexMath dispatch. Only numSlices_ and +// the slice variable need IndexT; sliceSize_ and k_ stay int because +// nelements is already checked <= INT_MAX. +// ================================================================ +template +struct SubgroupTopKFunctor { + + // Insert val into a K-sorted buffer. For Largest=true the buffer is sorted + // descending (top_vals[0] is max); for Largest=false it is sorted ascending + // (top_vals[0] is min). The comparator `better(a, b)` means "a should sit + // above b in the buffer" — i.e. strictly greater for largest, strictly less + // for smallest. + // + // Fully unrolled, no early break — SIMD-friendly. + inline void insert( + scalar_t* top_vals, int* top_idx, int count, + scalar_t val, int idx) const { + // Threshold is at the bottom of the buffer (top_vals[K-1]). + if constexpr (Largest) { + if (count >= K && !(val > top_vals[K - 1])) return; + } else { + if (count >= K && !(val < top_vals[K - 1])) return; + } + bool inserted = false; +#pragma unroll + for (int i = K - 1; i >= 0; --i) { + bool stop; + if constexpr (Largest) { + stop = (i == 0) || !(val > top_vals[i - 1]); + } else { + stop = (i == 0) || !(val < top_vals[i - 1]); + } + if (!inserted && stop) { + top_vals[i] = val; + top_idx[i] = idx; + inserted = true; + } else if (!inserted) { + top_vals[i] = top_vals[i - 1]; + top_idx[i] = top_idx[i - 1]; + } + } + } + + // Bitonic merge: A[K] and B[K] are both sorted in the "better" direction. + // Step 1: A[i] = better(A[i], B[K-1-i]) — produces bitonic sequence. + // Step 2: bitonic sort restores the sorted-by-better order on A. + inline void bitonic_merge( + scalar_t* A, int* A_idx, + const scalar_t* B, const int* B_idx) const { + // Step 1: compare with reversed partner +#pragma unroll + for (int i = 0; i < K; ++i) { + scalar_t bv = B[K - 1 - i]; + int bi = B_idx[K - 1 - i]; + bool take; + if constexpr (Largest) { + take = bv > A[i]; + } else { + take = bv < A[i]; + } + if (take) { + A[i] = bv; + A_idx[i] = bi; + } + } + // Step 2: bitonic sort in the "better" direction +#pragma unroll + for (int stride = K / 2; stride >= 1; stride >>= 1) { +#pragma unroll + for (int i = 0; i < K; ++i) { + int j = i ^ stride; + bool swap; + if constexpr (Largest) { + swap = (j > i) && (A[i] < A[j]); + } else { + swap = (j > i) && (A[i] > A[j]); + } + if (swap) { + scalar_t tv = A[i]; A[i] = A[j]; A[j] = tv; + int ti = A_idx[i]; A_idx[i] = A_idx[j]; A_idx[j] = ti; + } + } + } + } + + [[sycl::reqd_sub_group_size(32)]] + void operator()(sycl::nd_item<1> item) const { + sycl::sub_group sg = item.get_sub_group(); + int sg_lid = sg.get_local_linear_id(); + + // Each sub-group handles one slice + int sgs_per_wg = item.get_local_range(0) / SG_SIZE; + IndexT slice = static_cast(item.get_group_linear_id()) * sgs_per_wg + + sg.get_group_linear_id(); + if (slice >= numSlices_) return; + + const scalar_t* inputSlice = inputData_ + static_cast(slice) * sliceSize_; + scalar_t* topKSlice = topKData_ + static_cast(slice) * k_; + int64_t* indicesSlice = indicesData_ + static_cast(slice) * k_; + + // Initialize sorted top-K buffer + scalar_t top_vals[K]; + int top_idx_local[K]; + scalar_t init_val; + if constexpr (Largest) { + init_val = -std::numeric_limits::infinity(); + } else { + init_val = std::numeric_limits::infinity(); + } +#pragma unroll + for (int i = 0; i < K; ++i) { + top_vals[i] = init_val; + top_idx_local[i] = -1; + } + int count = 0; + + // ---- Phase 1: scan data with vec loads ---- + using LoadT = memory::aligned_vector; + int stride = SG_SIZE * VEC_SIZE; + + int base; + for (base = sg_lid * VEC_SIZE; base + VEC_SIZE <= sliceSize_; base += stride) { + scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&inputSlice[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + insert(top_vals, top_idx_local, count, src[v], base + v); + if (count < K) count++; + } + } + // Scalar tail + for (int idx = base; idx < sliceSize_ && idx < base + VEC_SIZE; ++idx) { + scalar_t val = inputSlice[idx]; + insert(top_vals, top_idx_local, count, val, idx); + if (count < K) count++; + } + + // ---- Phase 2: sub-group bitonic merge (5 levels for sg_size=32) ---- +#pragma unroll + for (int d = 0; d < 5; ++d) { + int partner = sg_lid ^ (1 << d); + + scalar_t partner_vals[K]; + int partner_idx[K]; +#pragma unroll + for (int i = 0; i < K; ++i) { + partner_vals[i] = sycl::select_from_group(sg, top_vals[i], partner); + partner_idx[i] = sycl::select_from_group(sg, top_idx_local[i], partner); + } + + bitonic_merge(top_vals, top_idx_local, partner_vals, partner_idx); + } + + // ---- Phase 3: lane 0 writes output ---- + if (sg_lid == 0) { + for (int i = 0; i < k_; ++i) { + topKSlice[i] = top_vals[i]; + indicesSlice[i] = static_cast(top_idx_local[i]); + } + } + } + + SubgroupTopKFunctor( + const scalar_t* inputData, + scalar_t* topKData, + int64_t* indicesData, + IndexT numSlices, + int sliceSize, + int k) + : inputData_(inputData), + topKData_(topKData), + indicesData_(indicesData), + numSlices_(numSlices), + sliceSize_(sliceSize), + k_(k) {} + + const scalar_t* inputData_; + scalar_t* topKData_; + int64_t* indicesData_; + IndexT numSlices_; + int sliceSize_; + int k_; +}; + +// ================================================================ +// Launch helpers +// ================================================================ +template +static void sbtopk_launch_impl( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + IndexT numSlices, + int sliceSize, + int k) { + constexpr int WG_SIZE = 256; // 8 sub-groups per work-group + constexpr int SGS_PER_WG = WG_SIZE / SG_SIZE; + auto num_wgs = (static_cast(numSlices) + SGS_PER_WG - 1) / SGS_PER_WG; + + SubgroupTopKFunctor functor( + input, topK, indices, numSlices, sliceSize, k); + + sycl_kernel_submit( + sycl::range<1>(num_wgs * WG_SIZE), + sycl::range<1>(WG_SIZE), + at::xpu::getCurrentSYCLQueue(), + functor); +} + +// Vec-size dispatch: picks the largest VEC_SIZE compatible with (dtype, sliceSize). +template +static void sbtopk_launch_vec_dispatch( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + IndexT numSlices, + int sliceSize, + int k) { + // Max VEC_SIZE for this dtype + constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; + + // Pick largest VEC_SIZE such that: + // 1. SG_SIZE * VEC_SIZE <= sliceSize (all threads get at least one full vector) + // 2. sliceSize % VEC_SIZE == 0 (slice boundaries are aligned) + if (MAX_VEC >= 8 && sliceSize % 8 == 0 && SG_SIZE * 8 <= sliceSize) { + sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + } else if (MAX_VEC >= 4 && sliceSize % 4 == 0 && SG_SIZE * 4 <= sliceSize) { + sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + } else if (sliceSize % 2 == 0 && SG_SIZE * 2 <= sliceSize) { + sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + } else { + sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + } +} + +template +static void sbtopk_launch_kernel( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + int64_t numSlices, + int sliceSize, + int k, + bool largest) { + constexpr int K = 16; + // Dispatch on (largest, IndexT) at the outermost level: + // - largest: so tight insert/merge loops have no runtime direction branch + // - IndexT: int32 when nsegments fits (common), int64 for huge batch. + // Mirrors CUDA's canUse32BitIndexMath. int32 avoids 64-bit slice + // arithmetic and reduces register pressure. + if (numSlices <= std::numeric_limits::max()) { + int numSlices32 = static_cast(numSlices); + if (largest) { + sbtopk_launch_vec_dispatch( + input, topK, indices, numSlices32, sliceSize, k); + } else { + sbtopk_launch_vec_dispatch( + input, topK, indices, numSlices32, sliceSize, k); + } + } else { + if (largest) { + sbtopk_launch_vec_dispatch( + input, topK, indices, numSlices, sliceSize, k); + } else { + sbtopk_launch_vec_dispatch( + input, topK, indices, numSlices, sliceSize, k); + } + } +} + +// ================================================================ +// Dispatch: subgroup top-k vs original +// +// From benchmark on B580 (3 dtypes x 6 bs x 4 dims x 2 align x 3 k): +// - dim < 1024: original wins (kernel launch overhead dominates) +// - dim >= 1024, bs >= ~320 segments, k <= 16: subgroup top-k wins +// ================================================================ +SbtopkResult sbtopk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices) { + // Not beneficial for small dim + if (nelements < 1024) { + return SbtopkResult::FAILED; + } + + // Subgroup top-k: best for large batch, k<=16. + // Output is ALREADY SORTED (descending for largest, ascending for smallest). + // + // Threshold: nsegments >= thread_slots / 4. + // Subgroup top-k uses 1 sub-group per slice (reading data once), while + // the original kernel reads data multiple times (~3 radix passes). So + // subgroup top-k reaches memory-BW saturation at much lower occupancy. + // thread_slots/4 is the conservative cutoff. + // + // On B580: thread_slots = 160 EU * 8 HW threads = 1280, threshold = 320. + int64_t thread_slots = ::xpu::sycl::syclGpuEuCount() * ::xpu::sycl::syclGpuHWThreadsPerEU(); + int64_t sg_threshold = thread_slots / 4; + if (k <= 16 && nsegments >= sg_threshold) { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "subgroup_topk_xpu", + [&]() { + sbtopk_launch_kernel( + static_cast(self.const_data_ptr()), + static_cast(values.data_ptr()), + static_cast(indices.data_ptr()), + nsegments, + static_cast(nelements), + static_cast(k), + largest); + }); + return SbtopkResult::SORTED; + } + + return SbtopkResult::FAILED; +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h new file mode 100644 index 0000000000..713be5a7f7 --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h @@ -0,0 +1,54 @@ +/* + * Copyright 2020-2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + */ + +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +// Result of sbtopk_try_launch. +// FAILED - did not run; caller should fall back to original kernel. +// UNSORTED - ran; output contains top-k values but is not sorted. +// Caller must sort if sorted output is requested. +// SORTED - ran; output is already sorted (descending for largest, +// ascending for smallest). Caller can skip sort. +enum class SbtopkResult : int { + FAILED = 0, + UNSORTED = 1, + SORTED = 2, +}; + +// Try to run topk using the subgroup topk kernel. +// +// This function is compiled in a separate translation unit +// (TensorTopKSbtopkKernel.cpp) to isolate the kernel's template +// instantiations from the original topk kernel. The SYCL compiler's global +// optimization decisions are affected by the total set of templates in a +// compilation unit; keeping them separate prevents regressing the original +// kernel's performance on small-dim cases where the optimized path is not +// even used. +// +// Currently dispatches to the subgroup topk kernel (sub-group bitonic merge, +// output SORTED) when k <= 16 and batch size is large enough. +TORCH_XPU_API SbtopkResult sbtopk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices); + +} // namespace xpu +} // namespace native +} // namespace at From c970ea52f403e23c10ad26e8f833f5e21aaf54be Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Fri, 17 Apr 2026 20:48:46 +0800 Subject: [PATCH 2/9] Address review: kernel properties, integer safety, comment cleanup - Add kernel properties (sub_group_size<32>, grf_size<128>) to launch for explicit sub-group size and smaller GRF (better occupancy) - Fix std::numeric_limits::infinity() for integer dtypes: use lowest()/max() when has_infinity is false - Add #include - Clarify insert() idx param is within-slice (int, bounded by sliceSize) - Shorten header comment for sbtopk_try_launch - Fix TensorTopKKernel.cpp comment (remove single-wg kernel reference) 432/432 accuracy, 324/324 sortedness pass. --- src/ATen/native/xpu/sycl/TensorTopKKernel.cpp | 7 ++-- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 35 ++++++++++++++----- .../native/xpu/sycl/TensorTopKSbtopkKernel.h | 14 ++------ 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp index 9a108d51a9..b2ed7949c9 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp @@ -129,10 +129,9 @@ void topk_kernel( largest); } - // Only sort if the user asked for sorted output AND sbtopk did not - // already produce a sorted result. The subgroup topk kernel returns - // SORTED; the single workgroup kernel and the original radix select - // return UNSORTED (or FAILED). + // Only sort if the user asked for sorted output AND the optimized + // kernel did not already produce a sorted result. The subgroup topk + // kernel returns SORTED; the original kernel returns FAILED. if (sorted && sbtopk_result != SbtopkResult::SORTED) { segmented_sort_pairs( values_ptr, diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index c14a2a19ef..6876fadde8 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -18,15 +18,19 @@ #include #include +#include +#include #include #include -#include #include namespace at { namespace native { namespace xpu { +namespace syclex = sycl::ext::oneapi::experimental; +namespace intelex = sycl::ext::intel::experimental; + static constexpr int SG_SIZE = 32; // ================================================================ @@ -50,6 +54,9 @@ struct SubgroupTopKFunctor { // above b in the buffer" — i.e. strictly greater for largest, strictly less // for smallest. // + // idx is the element position within a slice (0..sliceSize-1), always int + // because sliceSize <= INT_MAX. IndexT is only for the slice count. + // // Fully unrolled, no early break — SIMD-friendly. inline void insert( scalar_t* top_vals, int* top_idx, int count, @@ -142,9 +149,17 @@ struct SubgroupTopKFunctor { int top_idx_local[K]; scalar_t init_val; if constexpr (Largest) { - init_val = -std::numeric_limits::infinity(); + if constexpr (std::numeric_limits::has_infinity) { + init_val = -std::numeric_limits::infinity(); + } else { + init_val = std::numeric_limits::lowest(); + } } else { - init_val = std::numeric_limits::infinity(); + if constexpr (std::numeric_limits::has_infinity) { + init_val = std::numeric_limits::infinity(); + } else { + init_val = std::numeric_limits::max(); + } } #pragma unroll for (int i = 0; i < K; ++i) { @@ -240,11 +255,15 @@ static void sbtopk_launch_impl( SubgroupTopKFunctor functor( input, topK, indices, numSlices, sliceSize, k); - sycl_kernel_submit( - sycl::range<1>(num_wgs * WG_SIZE), - sycl::range<1>(WG_SIZE), - at::xpu::getCurrentSYCLQueue(), - functor); + auto q = at::xpu::getCurrentSYCLQueue(); + q.submit([&](sycl::handler& cgh) { + cgh.parallel_for( + sycl::nd_range<1>(num_wgs * WG_SIZE, WG_SIZE), + syclex::properties{ + syclex::sub_group_size, + intelex::grf_size<128>}, + functor); + }); } // Vec-size dispatch: picks the largest VEC_SIZE compatible with (dtype, sliceSize). diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h index 713be5a7f7..ac241a3f43 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h @@ -28,18 +28,8 @@ enum class SbtopkResult : int { SORTED = 2, }; -// Try to run topk using the subgroup topk kernel. -// -// This function is compiled in a separate translation unit -// (TensorTopKSbtopkKernel.cpp) to isolate the kernel's template -// instantiations from the original topk kernel. The SYCL compiler's global -// optimization decisions are affected by the total set of templates in a -// compilation unit; keeping them separate prevents regressing the original -// kernel's performance on small-dim cases where the optimized path is not -// even used. -// -// Currently dispatches to the subgroup topk kernel (sub-group bitonic merge, -// output SORTED) when k <= 16 and batch size is large enough. +// Try to run topk using the subgroup topk kernel (separate TU to avoid +// SYCL compiler interference with the original kernel's codegen). TORCH_XPU_API SbtopkResult sbtopk_try_launch( const at::Tensor& self, int64_t nsegments, From d4daf78f076a77352e0d6931e956b24e43def506 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Fri, 17 Apr 2026 21:20:26 +0800 Subject: [PATCH 3/9] Address review: use IndexT for element indices, add bitonic sort comment, drop TORCH_XPU_API --- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 159 ++++++++++++------ .../native/xpu/sycl/TensorTopKSbtopkKernel.h | 8 +- 2 files changed, 108 insertions(+), 59 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 6876fadde8..0bc7670d14 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -18,11 +18,11 @@ #include #include -#include -#include #include -#include #include +#include +#include +#include namespace at { namespace native { @@ -40,32 +40,38 @@ static constexpr int SG_SIZE = 32; // VEC_SIZE: vectorized load width // Largest: compile-time direction flag. Eliminates per-element branches // on largest_ that otherwise pessimize the tight insert/merge loops. -// IndexT: int32 for nsegments <= INT_MAX (common), int64_t for huge batch. -// Mirrors CUDA's canUse32BitIndexMath dispatch. Only numSlices_ and -// the slice variable need IndexT; sliceSize_ and k_ stay int because -// nelements is already checked <= INT_MAX. +// IndexT: int32 when total elements (nsegments * nelements) <= INT_MAX +// (common case), int64_t otherwise. Mirrors CUDA's +// canUse32BitIndexMath. int32 avoids 64-bit arithmetic on slice +// indices and element indices, reducing register pressure. // ================================================================ -template +template < + typename scalar_t, + int K, + int VEC_SIZE, + bool Largest, + typename IndexT = int> struct SubgroupTopKFunctor { - // Insert val into a K-sorted buffer. For Largest=true the buffer is sorted // descending (top_vals[0] is max); for Largest=false it is sorted ascending // (top_vals[0] is min). The comparator `better(a, b)` means "a should sit // above b in the buffer" — i.e. strictly greater for largest, strictly less // for smallest. // - // idx is the element position within a slice (0..sliceSize-1), always int - // because sliceSize <= INT_MAX. IndexT is only for the slice count. - // // Fully unrolled, no early break — SIMD-friendly. inline void insert( - scalar_t* top_vals, int* top_idx, int count, - scalar_t val, int idx) const { + scalar_t* top_vals, + IndexT* top_idx, + int count, + scalar_t val, + IndexT idx) const { // Threshold is at the bottom of the buffer (top_vals[K-1]). if constexpr (Largest) { - if (count >= K && !(val > top_vals[K - 1])) return; + if (count >= K && !(val > top_vals[K - 1])) + return; } else { - if (count >= K && !(val < top_vals[K - 1])) return; + if (count >= K && !(val < top_vals[K - 1])) + return; } bool inserted = false; #pragma unroll @@ -91,13 +97,15 @@ struct SubgroupTopKFunctor { // Step 1: A[i] = better(A[i], B[K-1-i]) — produces bitonic sequence. // Step 2: bitonic sort restores the sorted-by-better order on A. inline void bitonic_merge( - scalar_t* A, int* A_idx, - const scalar_t* B, const int* B_idx) const { + scalar_t* A, + IndexT* A_idx, + const scalar_t* B, + const IndexT* B_idx) const { // Step 1: compare with reversed partner #pragma unroll for (int i = 0; i < K; ++i) { scalar_t bv = B[K - 1 - i]; - int bi = B_idx[K - 1 - i]; + IndexT bi = B_idx[K - 1 - i]; bool take; if constexpr (Largest) { take = bv > A[i]; @@ -109,7 +117,27 @@ struct SubgroupTopKFunctor { A_idx[i] = bi; } } - // Step 2: bitonic sort in the "better" direction + // Step 2: bitonic sort — standard bitonic merge network. + // + // After step 1, A[0..K-1] is bitonic (first decreasing then increasing, + // or vice versa). At stride = K/2 we compare A[i] with A[i + K/2] + // for i in [0, K/2) and swap so the "better" value goes to the low + // half. This guarantees: + // (a) every element in A[0..K/2-1] >= every element in A[K/2..K-1], + // (b) each half is itself bitonic (splitting a bitonic sequence at + // the midpoint with min/max produces two bitonic subsequences). + // Recurse with stride K/4, K/8, ..., 1 and each sub-piece halves + // again, until every piece has length 1 — the array is sorted. + // + // j = i ^ stride pairs each element with its partner at distance + // `stride`. The guard j > i ensures each pair is processed once. + // + // Example for K = 16: + // stride 8: (0,8) (1,9) (2,10) ... (7,15) — 8 pairs + // stride 4: (0,4) (1,5) (2,6) (3,7) — two groups of 4 + // (8,12) (9,13) (10,14) (11,15) + // stride 2: (0,2) (1,3) (4,6) (5,7) ... — four groups of 2 + // stride 1: (0,1) (2,3) (4,5) ... (14,15) — 8 adjacent pairs #pragma unroll for (int stride = K / 2; stride >= 1; stride >>= 1) { #pragma unroll @@ -122,31 +150,37 @@ struct SubgroupTopKFunctor { swap = (j > i) && (A[i] > A[j]); } if (swap) { - scalar_t tv = A[i]; A[i] = A[j]; A[j] = tv; - int ti = A_idx[i]; A_idx[i] = A_idx[j]; A_idx[j] = ti; + scalar_t tv = A[i]; + A[i] = A[j]; + A[j] = tv; + IndexT ti = A_idx[i]; + A_idx[i] = A_idx[j]; + A_idx[j] = ti; } } } } - [[sycl::reqd_sub_group_size(32)]] void operator()(sycl::nd_item<1> item) const { sycl::sub_group sg = item.get_sub_group(); int sg_lid = sg.get_local_linear_id(); // Each sub-group handles one slice int sgs_per_wg = item.get_local_range(0) / SG_SIZE; - IndexT slice = static_cast(item.get_group_linear_id()) * sgs_per_wg - + sg.get_group_linear_id(); - if (slice >= numSlices_) return; + IndexT slice = + static_cast(item.get_group_linear_id()) * sgs_per_wg + + sg.get_group_linear_id(); + if (slice >= numSlices_) + return; - const scalar_t* inputSlice = inputData_ + static_cast(slice) * sliceSize_; + const scalar_t* inputSlice = + inputData_ + static_cast(slice) * sliceSize_; scalar_t* topKSlice = topKData_ + static_cast(slice) * k_; int64_t* indicesSlice = indicesData_ + static_cast(slice) * k_; // Initialize sorted top-K buffer scalar_t top_vals[K]; - int top_idx_local[K]; + IndexT top_idx_local[K]; scalar_t init_val; if constexpr (Largest) { if constexpr (std::numeric_limits::has_infinity) { @@ -172,22 +206,32 @@ struct SubgroupTopKFunctor { using LoadT = memory::aligned_vector; int stride = SG_SIZE * VEC_SIZE; - int base; - for (base = sg_lid * VEC_SIZE; base + VEC_SIZE <= sliceSize_; base += stride) { + int64_t base; + for (base = sg_lid * VEC_SIZE; base + VEC_SIZE <= sliceSize_; + base += stride) { scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = *reinterpret_cast(&inputSlice[base]); #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { - insert(top_vals, top_idx_local, count, src[v], base + v); - if (count < K) count++; + insert( + top_vals, + top_idx_local, + count, + src[v], + static_cast(base + v)); + if (count < K) + count++; } } // Scalar tail - for (int idx = base; idx < sliceSize_ && idx < base + VEC_SIZE; ++idx) { + for (IndexT idx = static_cast(base); + idx < sliceSize_ && idx < base + VEC_SIZE; + ++idx) { scalar_t val = inputSlice[idx]; insert(top_vals, top_idx_local, count, val, idx); - if (count < K) count++; + if (count < K) + count++; } // ---- Phase 2: sub-group bitonic merge (5 levels for sg_size=32) ---- @@ -196,7 +240,7 @@ struct SubgroupTopKFunctor { int partner = sg_lid ^ (1 << d); scalar_t partner_vals[K]; - int partner_idx[K]; + IndexT partner_idx[K]; #pragma unroll for (int i = 0; i < K; ++i) { partner_vals[i] = sycl::select_from_group(sg, top_vals[i], partner); @@ -220,7 +264,7 @@ struct SubgroupTopKFunctor { scalar_t* topKData, int64_t* indicesData, IndexT numSlices, - int sliceSize, + int64_t sliceSize, int k) : inputData_(inputData), topKData_(topKData), @@ -233,7 +277,7 @@ struct SubgroupTopKFunctor { scalar_t* topKData_; int64_t* indicesData_; IndexT numSlices_; - int sliceSize_; + int64_t sliceSize_; int k_; }; @@ -246,11 +290,12 @@ static void sbtopk_launch_impl( scalar_t* topK, int64_t* indices, IndexT numSlices, - int sliceSize, + int64_t sliceSize, int k) { constexpr int WG_SIZE = 256; // 8 sub-groups per work-group constexpr int SGS_PER_WG = WG_SIZE / SG_SIZE; - auto num_wgs = (static_cast(numSlices) + SGS_PER_WG - 1) / SGS_PER_WG; + auto num_wgs = + (static_cast(numSlices) + SGS_PER_WG - 1) / SGS_PER_WG; SubgroupTopKFunctor functor( input, topK, indices, numSlices, sliceSize, k); @@ -260,35 +305,40 @@ static void sbtopk_launch_impl( cgh.parallel_for( sycl::nd_range<1>(num_wgs * WG_SIZE, WG_SIZE), syclex::properties{ - syclex::sub_group_size, - intelex::grf_size<128>}, + syclex::sub_group_size, intelex::grf_size<128>}, functor); }); } -// Vec-size dispatch: picks the largest VEC_SIZE compatible with (dtype, sliceSize). +// Vec-size dispatch: picks the largest VEC_SIZE compatible with +// (dtype, sliceSize). template static void sbtopk_launch_vec_dispatch( const scalar_t* input, scalar_t* topK, int64_t* indices, IndexT numSlices, - int sliceSize, + int64_t sliceSize, int k) { // Max VEC_SIZE for this dtype constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; // Pick largest VEC_SIZE such that: - // 1. SG_SIZE * VEC_SIZE <= sliceSize (all threads get at least one full vector) + // 1. SG_SIZE * VEC_SIZE <= sliceSize (all threads get at + // least one full vector) // 2. sliceSize % VEC_SIZE == 0 (slice boundaries are aligned) if (MAX_VEC >= 8 && sliceSize % 8 == 0 && SG_SIZE * 8 <= sliceSize) { - sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + sbtopk_launch_impl( + input, topK, indices, numSlices, sliceSize, k); } else if (MAX_VEC >= 4 && sliceSize % 4 == 0 && SG_SIZE * 4 <= sliceSize) { - sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + sbtopk_launch_impl( + input, topK, indices, numSlices, sliceSize, k); } else if (sliceSize % 2 == 0 && SG_SIZE * 2 <= sliceSize) { - sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + sbtopk_launch_impl( + input, topK, indices, numSlices, sliceSize, k); } else { - sbtopk_launch_impl(input, topK, indices, numSlices, sliceSize, k); + sbtopk_launch_impl( + input, topK, indices, numSlices, sliceSize, k); } } @@ -298,16 +348,16 @@ static void sbtopk_launch_kernel( scalar_t* topK, int64_t* indices, int64_t numSlices, - int sliceSize, + int64_t sliceSize, int k, bool largest) { constexpr int K = 16; // Dispatch on (largest, IndexT) at the outermost level: // - largest: so tight insert/merge loops have no runtime direction branch - // - IndexT: int32 when nsegments fits (common), int64 for huge batch. - // Mirrors CUDA's canUse32BitIndexMath. int32 avoids 64-bit slice + // - IndexT: int32 when total elements fit (common), int64 otherwise. + // Mirrors CUDA's canUse32BitIndexMath. int32 avoids 64-bit // arithmetic and reduces register pressure. - if (numSlices <= std::numeric_limits::max()) { + if (numSlices * sliceSize <= std::numeric_limits::max()) { int numSlices32 = static_cast(numSlices); if (largest) { sbtopk_launch_vec_dispatch( @@ -357,7 +407,8 @@ SbtopkResult sbtopk_try_launch( // thread_slots/4 is the conservative cutoff. // // On B580: thread_slots = 160 EU * 8 HW threads = 1280, threshold = 320. - int64_t thread_slots = ::xpu::sycl::syclGpuEuCount() * ::xpu::sycl::syclGpuHWThreadsPerEU(); + int64_t thread_slots = + ::xpu::sycl::syclGpuEuCount() * ::xpu::sycl::syclGpuHWThreadsPerEU(); int64_t sg_threshold = thread_slots / 4; if (k <= 16 && nsegments >= sg_threshold) { AT_DISPATCH_ALL_TYPES_AND2( @@ -371,7 +422,7 @@ SbtopkResult sbtopk_try_launch( static_cast(values.data_ptr()), static_cast(indices.data_ptr()), nsegments, - static_cast(nelements), + nelements, static_cast(k), largest); }); diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h index ac241a3f43..031fa64be5 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h @@ -23,14 +23,12 @@ namespace xpu { // SORTED - ran; output is already sorted (descending for largest, // ascending for smallest). Caller can skip sort. enum class SbtopkResult : int { - FAILED = 0, + FAILED = 0, UNSORTED = 1, - SORTED = 2, + SORTED = 2, }; -// Try to run topk using the subgroup topk kernel (separate TU to avoid -// SYCL compiler interference with the original kernel's codegen). -TORCH_XPU_API SbtopkResult sbtopk_try_launch( +SbtopkResult sbtopk_try_launch( const at::Tensor& self, int64_t nsegments, int64_t nelements, From 8c3f054fc8c104e569f6b11e28384c2aa28307db Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 18 Apr 2026 14:16:04 +0800 Subject: [PATCH 4/9] Remove benchmark source details from dispatch comment --- src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 0bc7670d14..5378cfe6e8 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -380,9 +380,8 @@ static void sbtopk_launch_kernel( // ================================================================ // Dispatch: subgroup top-k vs original // -// From benchmark on B580 (3 dtypes x 6 bs x 4 dims x 2 align x 3 k): -// - dim < 1024: original wins (kernel launch overhead dominates) -// - dim >= 1024, bs >= ~320 segments, k <= 16: subgroup top-k wins +// - dim < 1024: original (kernel launch overhead dominates) +// - dim >= 1024, large batch, k <= 16: subgroup top-k // ================================================================ SbtopkResult sbtopk_try_launch( const at::Tensor& self, From 6aa63c95dfc999eeef22aa64ed33bea2c9dcd657 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 18 Apr 2026 15:16:32 +0800 Subject: [PATCH 5/9] Fix sentinel-value insert bug, add alignas and pointer alignment check - insert(): add count-aware stop condition so input values equal to the sentinel (e.g. all -inf for largest=true) fill the buffer correctly instead of repeatedly overwriting position K-1 - Add alignas(alignof(LoadT)) on local vectorized-load array - Add pointer alignment check in vec dispatch to safely fall back to scalar loads when input has a non-aligned storage offset --- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 5378cfe6e8..52aed664d2 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace at { @@ -76,11 +77,18 @@ struct SubgroupTopKFunctor { bool inserted = false; #pragma unroll for (int i = K - 1; i >= 0; --i) { + // When count < K the buffer is partially filled: positions [0, count) + // hold real values while [count, K) still contain sentinels. + // Guard "i <= count" ensures we only stop at position i when + // top_vals[i-1] is a real entry. Without this guard, an input + // value equal to the sentinel (e.g. all -inf for largest=true) + // would always stop at position K-1, overwriting it repeatedly + // instead of filling lower positions. bool stop; if constexpr (Largest) { - stop = (i == 0) || !(val > top_vals[i - 1]); + stop = (i == 0) || (i <= count && !(val > top_vals[i - 1])); } else { - stop = (i == 0) || !(val < top_vals[i - 1]); + stop = (i == 0) || (i <= count && !(val < top_vals[i - 1])); } if (!inserted && stop) { top_vals[i] = val; @@ -209,7 +217,7 @@ struct SubgroupTopKFunctor { int64_t base; for (base = sg_lid * VEC_SIZE; base + VEC_SIZE <= sliceSize_; base += stride) { - scalar_t src[VEC_SIZE]; + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = *reinterpret_cast(&inputSlice[base]); #pragma unroll @@ -327,13 +335,22 @@ static void sbtopk_launch_vec_dispatch( // 1. SG_SIZE * VEC_SIZE <= sliceSize (all threads get at // least one full vector) // 2. sliceSize % VEC_SIZE == 0 (slice boundaries are aligned) - if (MAX_VEC >= 8 && sliceSize % 8 == 0 && SG_SIZE * 8 <= sliceSize) { + // 3. input pointer is aligned to sizeof(scalar_t) * VEC_SIZE + // (usually guaranteed by PyTorch allocators, but a non-zero + // storage offset can break alignment) + auto input_align = reinterpret_cast(input); + if (MAX_VEC >= 8 && sliceSize % 8 == 0 && SG_SIZE * 8 <= sliceSize && + input_align % (sizeof(scalar_t) * 8) == 0) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); - } else if (MAX_VEC >= 4 && sliceSize % 4 == 0 && SG_SIZE * 4 <= sliceSize) { + } else if ( + MAX_VEC >= 4 && sliceSize % 4 == 0 && SG_SIZE * 4 <= sliceSize && + input_align % (sizeof(scalar_t) * 4) == 0) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); - } else if (sliceSize % 2 == 0 && SG_SIZE * 2 <= sliceSize) { + } else if ( + sliceSize % 2 == 0 && SG_SIZE * 2 <= sliceSize && + input_align % (sizeof(scalar_t) * 2) == 0) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); } else { From ad407ac4b002a27108ac92f847bc692943a3406c Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Mon, 20 Apr 2026 13:18:06 +0800 Subject: [PATCH 6/9] Lower subgroup topk dispatch threshold from dim>=1024 to dim>=32 Benchmarks show subgroup kernel is 2-4x faster than original even for small dims (32-512) when batch size is large. The previous dim>=1024 guard was overly conservative. The only hard requirement is dim>=SG_SIZE (32) so each lane gets at least one element. --- src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 52aed664d2..22307fbadd 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -12,7 +12,7 @@ * to combine 32 per-lane buffers into one global top-k. * Phase 3: Lane 0 writes k results. Output is already sorted. * - * Dispatch: k <= 16 and enough segments (large batch) and dim >= 1024 + * Dispatch: k <= 16 and enough segments (large batch) and dim >= 32 * routes to subgroup top-k; otherwise falls back to original. */ @@ -397,8 +397,8 @@ static void sbtopk_launch_kernel( // ================================================================ // Dispatch: subgroup top-k vs original // -// - dim < 1024: original (kernel launch overhead dominates) -// - dim >= 1024, large batch, k <= 16: subgroup top-k +// - dim < 32: original (need at least SG_SIZE elements) +// - dim >= 32, large batch, k <= 16: subgroup top-k // ================================================================ SbtopkResult sbtopk_try_launch( const at::Tensor& self, @@ -408,8 +408,8 @@ SbtopkResult sbtopk_try_launch( bool largest, const at::Tensor& values, const at::Tensor& indices) { - // Not beneficial for small dim - if (nelements < 1024) { + // Subgroup kernel needs at least SG_SIZE (32) elements per slice + if (nelements < 32) { return SbtopkResult::FAILED; } From 3e536e040cf6ee56decf1c0bc6d31cc3718e38fd Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Tue, 21 Apr 2026 09:02:13 +0800 Subject: [PATCH 7/9] Dispatch on smallest power-of-two K to reduce register pressure Select K from {1, 2, 4, 8, 16} based on runtime k (round up to next power of two). Smaller K means fewer unrolled iterations in insert/merge/shuffle loops, dramatically reducing register pressure. K<=8 eliminates all register spills on B580 (GRF 128) across fp32, fp16, and bf16. For k=4 this gives 3-11x speedup over the previous fixed K=16 path; k=16 takes the same K=16 template as before (no regression). --- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 82 +++++++++++++------ 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 22307fbadd..b256166052 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -368,30 +368,66 @@ static void sbtopk_launch_kernel( int64_t sliceSize, int k, bool largest) { - constexpr int K = 16; - // Dispatch on (largest, IndexT) at the outermost level: - // - largest: so tight insert/merge loops have no runtime direction branch - // - IndexT: int32 when total elements fit (common), int64 otherwise. - // Mirrors CUDA's canUse32BitIndexMath. int32 avoids 64-bit - // arithmetic and reduces register pressure. - if (numSlices * sliceSize <= std::numeric_limits::max()) { - int numSlices32 = static_cast(numSlices); - if (largest) { - sbtopk_launch_vec_dispatch( - input, topK, indices, numSlices32, sliceSize, k); - } else { - sbtopk_launch_vec_dispatch( - input, topK, indices, numSlices32, sliceSize, k); - } - } else { - if (largest) { - sbtopk_launch_vec_dispatch( - input, topK, indices, numSlices, sliceSize, k); - } else { - sbtopk_launch_vec_dispatch( - input, topK, indices, numSlices, sliceSize, k); - } + // Dispatch on (K, Largest, IndexT). + // K: smallest power-of-two >= k. Smaller K means fewer unrolled + // iterations in insert/merge, less register pressure, zero spills. + // K=1 is a valid special case (top-1 = max/min element). + // Largest: compile-time direction eliminates per-element branches. + // IndexT: int32 when total elements fit (common), int64 otherwise. + + // Select K: round up k to next power of two, clamped to [1, 16]. + int K_sel; + if (k <= 1) + K_sel = 1; + else if (k <= 2) + K_sel = 2; + else if (k <= 4) + K_sel = 4; + else if (k <= 8) + K_sel = 8; + else + K_sel = 16; + +#define SBTOPK_DISPATCH_K(K_VAL, LARGEST, INDEX_T, NUM_SLICES) \ + sbtopk_launch_vec_dispatch( \ + input, topK, indices, NUM_SLICES, sliceSize, k) + +#define SBTOPK_DISPATCH_LARGEST(K_VAL, INDEX_T, NUM_SLICES) \ + if (largest) { \ + SBTOPK_DISPATCH_K(K_VAL, true, INDEX_T, NUM_SLICES); \ + } else { \ + SBTOPK_DISPATCH_K(K_VAL, false, INDEX_T, NUM_SLICES); \ + } + +#define SBTOPK_DISPATCH_INDEX(K_VAL) \ + if (numSlices * sliceSize <= std::numeric_limits::max()) { \ + int numSlices32 = static_cast(numSlices); \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ + } else { \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ } + + switch (K_sel) { + case 1: + SBTOPK_DISPATCH_INDEX(1); + break; + case 2: + SBTOPK_DISPATCH_INDEX(2); + break; + case 4: + SBTOPK_DISPATCH_INDEX(4); + break; + case 8: + SBTOPK_DISPATCH_INDEX(8); + break; + default: + SBTOPK_DISPATCH_INDEX(16); + break; + } + +#undef SBTOPK_DISPATCH_INDEX +#undef SBTOPK_DISPATCH_LARGEST +#undef SBTOPK_DISPATCH_K } // ================================================================ From bde1bdb9c393f5f939179ddbb88d166fdcdc2e2f Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sun, 26 Apr 2026 16:56:55 +0800 Subject: [PATCH 8/9] Use overflow-safe check for IndexT dispatch instead of nsegments*nelements --- .../native/xpu/sycl/TensorTopKSbtopkKernel.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index b256166052..dfd151a5b8 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -399,12 +399,15 @@ static void sbtopk_launch_kernel( SBTOPK_DISPATCH_K(K_VAL, false, INDEX_T, NUM_SLICES); \ } -#define SBTOPK_DISPATCH_INDEX(K_VAL) \ - if (numSlices * sliceSize <= std::numeric_limits::max()) { \ - int numSlices32 = static_cast(numSlices); \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ - } else { \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ +#define SBTOPK_DISPATCH_INDEX(K_VAL) \ + if (numSlices <= std::numeric_limits::max() && \ + sliceSize <= std::numeric_limits::max() && \ + numSlices <= \ + std::numeric_limits::max() / (sliceSize > 0 ? sliceSize : 1)) { \ + int numSlices32 = static_cast(numSlices); \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ + } else { \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ } switch (K_sel) { From 3a7de357bbff121d5f67ae9f16ec3b9fe50fce1b Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Mon, 27 Apr 2026 20:39:18 +0800 Subject: [PATCH 9/9] Address review: use PowerOf2Ceil, sycl_kernel_submit, simplify dispatch - Replace K_sel if-else chain with c10::llvm::PowerOf2Ceil + std::min - Replace q.submit with sycl_kernel_submit + kernel properties - Add sycl_kernel_submit overloads accepting properties to SYCLHelpers.h - Simplify SBTOPK_DISPATCH_INDEX: only check numSlices <= INT_MAX (IndexT is only used for slice indices, not cross-slice global indices) - Add SG_MERGE_LEVELS constexpr + static_assert, replace magic number 5 - Refactor vec dispatch with can_use_vec lambda - Update IndexT comment to reflect simplified dispatch condition --- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 88 +++++++++---------- src/comm/SYCLHelpers.h | 78 ++++++++++++++++ 2 files changed, 119 insertions(+), 47 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index dfd151a5b8..9dc97ea4f2 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -5,22 +5,25 @@ * Zero SLM, zero barriers. * * Algorithm: - * Phase 1: Each lane scans dim/32 elements, maintains a sorted top-k + * Phase 1: Each lane scans nelements/32 elements, maintains a sorted top-k * buffer via insertion sort (fully unrolled, no branches on * direction thanks to compile-time Largest template param). * Phase 2: 5 levels of pairwise bitonic merge via sub-group shuffles * to combine 32 per-lane buffers into one global top-k. * Phase 3: Lane 0 writes k results. Output is already sorted. * - * Dispatch: k <= 16 and enough segments (large batch) and dim >= 32 + * Dispatch: k <= 16 and enough segments (large batch) and nelements >= 32 * routes to subgroup top-k; otherwise falls back to original. */ #include #include +#include #include #include +#include #include +#include #include #include #include @@ -33,6 +36,11 @@ namespace syclex = sycl::ext::oneapi::experimental; namespace intelex = sycl::ext::intel::experimental; static constexpr int SG_SIZE = 32; +// Number of pairwise merge levels in Phase 2 = log2(SG_SIZE). +static constexpr int SG_MERGE_LEVELS = 5; +static_assert( + (1 << SG_MERGE_LEVELS) == SG_SIZE, + "SG_MERGE_LEVELS must equal log2(SG_SIZE)"); // ================================================================ // SubgroupTopKFunctor @@ -41,10 +49,9 @@ static constexpr int SG_SIZE = 32; // VEC_SIZE: vectorized load width // Largest: compile-time direction flag. Eliminates per-element branches // on largest_ that otherwise pessimize the tight insert/merge loops. -// IndexT: int32 when total elements (nsegments * nelements) <= INT_MAX -// (common case), int64_t otherwise. Mirrors CUDA's -// canUse32BitIndexMath. int32 avoids 64-bit arithmetic on slice -// indices and element indices, reducing register pressure. +// IndexT: int32 when numSlices <= INT_MAX (common case), int64_t +// otherwise. int32 avoids 64-bit arithmetic on slice indices, +// reducing register pressure. // ================================================================ template < typename scalar_t, @@ -242,9 +249,9 @@ struct SubgroupTopKFunctor { count++; } - // ---- Phase 2: sub-group bitonic merge (5 levels for sg_size=32) ---- + // ---- Phase 2: sub-group bitonic merge ---- #pragma unroll - for (int d = 0; d < 5; ++d) { + for (int d = 0; d < SG_MERGE_LEVELS; ++d) { int partner = sg_lid ^ (1 << d); scalar_t partner_vals[K]; @@ -302,20 +309,19 @@ static void sbtopk_launch_impl( int k) { constexpr int WG_SIZE = 256; // 8 sub-groups per work-group constexpr int SGS_PER_WG = WG_SIZE / SG_SIZE; - auto num_wgs = - (static_cast(numSlices) + SGS_PER_WG - 1) / SGS_PER_WG; + auto num_wgs = at::ceil_div( + static_cast(numSlices), static_cast(SGS_PER_WG)); SubgroupTopKFunctor functor( input, topK, indices, numSlices, sliceSize, k); - auto q = at::xpu::getCurrentSYCLQueue(); - q.submit([&](sycl::handler& cgh) { - cgh.parallel_for( - sycl::nd_range<1>(num_wgs * WG_SIZE, WG_SIZE), - syclex::properties{ - syclex::sub_group_size, intelex::grf_size<128>}, - functor); - }); + sycl_kernel_submit( + num_wgs * WG_SIZE, + WG_SIZE, + at::xpu::getCurrentSYCLQueue(), + syclex::properties{ + syclex::sub_group_size, intelex::grf_size<128>}, + functor); } // Vec-size dispatch: picks the largest VEC_SIZE compatible with @@ -339,18 +345,18 @@ static void sbtopk_launch_vec_dispatch( // (usually guaranteed by PyTorch allocators, but a non-zero // storage offset can break alignment) auto input_align = reinterpret_cast(input); - if (MAX_VEC >= 8 && sliceSize % 8 == 0 && SG_SIZE * 8 <= sliceSize && - input_align % (sizeof(scalar_t) * 8) == 0) { + auto can_use_vec = [&](int vec) { + return MAX_VEC >= vec && sliceSize % vec == 0 && + SG_SIZE * vec <= sliceSize && + input_align % (sizeof(scalar_t) * vec) == 0; + }; + if (can_use_vec(8)) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); - } else if ( - MAX_VEC >= 4 && sliceSize % 4 == 0 && SG_SIZE * 4 <= sliceSize && - input_align % (sizeof(scalar_t) * 4) == 0) { + } else if (can_use_vec(4)) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); - } else if ( - sliceSize % 2 == 0 && SG_SIZE * 2 <= sliceSize && - input_align % (sizeof(scalar_t) * 2) == 0) { + } else if (can_use_vec(2)) { sbtopk_launch_impl( input, topK, indices, numSlices, sliceSize, k); } else { @@ -373,20 +379,11 @@ static void sbtopk_launch_kernel( // iterations in insert/merge, less register pressure, zero spills. // K=1 is a valid special case (top-1 = max/min element). // Largest: compile-time direction eliminates per-element branches. - // IndexT: int32 when total elements fit (common), int64 otherwise. + // IndexT: int32 when numSlices fits (common), int64 otherwise. // Select K: round up k to next power of two, clamped to [1, 16]. - int K_sel; - if (k <= 1) - K_sel = 1; - else if (k <= 2) - K_sel = 2; - else if (k <= 4) - K_sel = 4; - else if (k <= 8) - K_sel = 8; - else - K_sel = 16; + int K_sel = std::min( + static_cast(c10::llvm::PowerOf2Ceil(static_cast(k))), 16); #define SBTOPK_DISPATCH_K(K_VAL, LARGEST, INDEX_T, NUM_SLICES) \ sbtopk_launch_vec_dispatch( \ @@ -399,15 +396,12 @@ static void sbtopk_launch_kernel( SBTOPK_DISPATCH_K(K_VAL, false, INDEX_T, NUM_SLICES); \ } -#define SBTOPK_DISPATCH_INDEX(K_VAL) \ - if (numSlices <= std::numeric_limits::max() && \ - sliceSize <= std::numeric_limits::max() && \ - numSlices <= \ - std::numeric_limits::max() / (sliceSize > 0 ? sliceSize : 1)) { \ - int numSlices32 = static_cast(numSlices); \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ - } else { \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ +#define SBTOPK_DISPATCH_INDEX(K_VAL) \ + if (numSlices <= std::numeric_limits::max()) { \ + int numSlices32 = static_cast(numSlices); \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ + } else { \ + SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ } switch (K_sel) { diff --git a/src/comm/SYCLHelpers.h b/src/comm/SYCLHelpers.h index 6904037cbf..dd56981c68 100644 --- a/src/comm/SYCLHelpers.h +++ b/src/comm/SYCLHelpers.h @@ -148,6 +148,84 @@ sycl_kernel_submit( q.submit(cgf); } +// Overloads accepting kernel properties (e.g., sub_group_size, grf_size). + +template +static inline typename std::enable_if< + std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit( + ::sycl::range global_range, + ::sycl::range local_range, + ::sycl::queue q, + Props properties, + ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) { + ker.sycl_ker_config_convention(cgh); + cgh.parallel_for( + ::sycl::nd_range(global_range, local_range), properties, ker); + }; + q.submit(cgf); +} + +template +static inline typename std::enable_if< + !std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit( + ::sycl::range global_range, + ::sycl::range local_range, + ::sycl::queue q, + Props properties, + ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) { + cgh.parallel_for( + ::sycl::nd_range(global_range, local_range), properties, ker); + }; + q.submit(cgf); +} + +template +static inline typename std::enable_if< + std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit( + int64_t global_range, + int64_t local_range, + ::sycl::queue q, + Props properties, + ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) { + ker.sycl_ker_config_convention(cgh); + cgh.parallel_for( + ::sycl::nd_range<1>( + ::sycl::range<1>(global_range), ::sycl::range<1>(local_range)), + properties, + ker); + }; + q.submit(cgf); +} + +template +static inline typename std::enable_if< + !std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, ker_t>, + void>::type +sycl_kernel_submit( + int64_t global_range, + int64_t local_range, + ::sycl::queue q, + Props properties, + ker_t ker) { + auto cgf = [&](::sycl::handler& cgh) { + cgh.parallel_for( + ::sycl::nd_range<1>( + ::sycl::range<1>(global_range), ::sycl::range<1>(local_range)), + properties, + ker); + }; + q.submit(cgf); +} + #ifdef __SYCL_DEVICE_ONLY__ #define SYCL_KERNEL_STRING(var, str) \ static const __attribute__((opencl_constant)) char var[] = str