From d013cd3f73a41b1bb61364582ee7479b2cf346d2 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Fri, 17 Apr 2026 20:55:02 +0800 Subject: [PATCH 01/10] Add single workgroup topk kernel for XPU (from CUDA single-block path) SYCL translation of PyTorch CUDA's single-block radix select path. A 1024-thread workgroup processes one slice using RADIX_BITS=4 radix select to find the k-th value, then gathers matching elements. Output is unsorted. Best for large dim (>= 4096). Dispatch updated: dim < 1024 -> original; k <= 16 + large batch -> subgroup kernel; dim >= 4096 -> single workgroup kernel; else -> original. Also fixes NaN handling in SortingRadixSelect.h for half/float/double. 432/432 accuracy tests pass, 324/324 sortedness tests pass. --- src/ATen/native/xpu/sycl/SortingRadixSelect.h | 10 +- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 27 +- .../native/xpu/sycl/TensorTopKSbtopkKernel.h | 11 +- .../xpu/sycl/TensorTopKSingleWgKernel.cpp | 793 ++++++++++++++++++ .../xpu/sycl/TensorTopKSingleWgKernel.h | 23 + 5 files changed, 852 insertions(+), 12 deletions(-) create mode 100644 src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp create mode 100644 src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h diff --git a/src/ATen/native/xpu/sycl/SortingRadixSelect.h b/src/ATen/native/xpu/sycl/SortingRadixSelect.h index 52d18a43fc..f8569c1bdf 100644 --- a/src/ATen/native/xpu/sycl/SortingRadixSelect.h +++ b/src/ATen/native/xpu/sycl/SortingRadixSelect.h @@ -87,7 +87,7 @@ struct TopKTypeConfig { static inline RadixType convert(float v) { RadixType x = *((uint32_t*)&v); RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffffffff; } static inline float deconvert(RadixType v) { @@ -168,7 +168,7 @@ struct TopKTypeConfig { static inline RadixType convert(double v) { RadixType x = *((uint64_t*)&v); RadixType mask = -((x >> 63)) | 0x8000000000000000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; } static inline double deconvert(RadixType v) { @@ -183,12 +183,12 @@ struct TopKTypeConfig { static inline RadixType convert(at::Half v) { RadixType x = *((uint16_t*)&v); - RadixType mask = -((x >> 15)) | 0x8000; - return (x ^ mask); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; } static inline at::Half deconvert(RadixType v) { - RadixType mask = ((v >> 15) - 1) | 0x8000; + RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; return __ushort_as_half(v ^ mask); } }; diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 298cfbc410..e6c550fbdf 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -68,10 +69,13 @@ static bool subgroup_topk_try_launch( #undef SBTOPK_LAUNCH // ================================================================ -// Dispatch: subgroup top-k vs original +// Dispatch: subgroup top-k vs single-workgroup top-k vs original // // - dim < 32: original (need at least SG_SIZE elements) -// - dim >= 32, large batch, k <= 16: subgroup top-k +// - dim >= 32, large batch, k <= 8: subgroup top-k +// - dim >= 4096, any bs: single-workgroup top-k +// - dim >= 4096, k > 8: single-workgroup top-k +// (subgroup only supports k <= 8) // ================================================================ SbtopkResult sbtopk_try_launch( const at::Tensor& self, @@ -86,12 +90,12 @@ SbtopkResult sbtopk_try_launch( return SbtopkResult::FAILED; } - // Subgroup top-k: best for large batch, k<=8. + // Subgroup top-k: best for large batch, k <= 8. // Output is ALREADY SORTED (descending for largest, ascending for smallest). // // Threshold: nsegments >= thread_slots / 4. // Subgroup top-k uses 1 sub-group per slice (reading data once), while - // the original kernel reads data multiple times (~3 radix passes). So + // single-wg/original read data multiple times (~3 radix passes). So // subgroup top-k reaches memory-BW saturation at much lower occupancy. // thread_slots/4 is the conservative cutoff. // @@ -107,6 +111,21 @@ SbtopkResult sbtopk_try_launch( return SbtopkResult::FAILED; } + // Single-workgroup top-k: radix select, good for large + // dim. Output is UNSORTED. + // Use nelements >= 4096 so dim=1024/1025 falls through to original + // (single-wg has regressions at dim=1024 for medium/large batches). + // Single-wg uses int for numSlices internally; reject + // if nsegments overflows int32. + if (nelements >= 4096 && nsegments <= std::numeric_limits::max()) { + if (single_wg_topk_try_launch( + self, nsegments, nelements, k, largest, values, indices)) { + return SbtopkResult::UNSORTED; + } + return SbtopkResult::FAILED; + } + + // Fallback to original for dim=32-4095 or k>8 with small batch return SbtopkResult::FAILED; } diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h index 91c1bfcf81..d224aa3bff 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h @@ -15,10 +15,10 @@ namespace at::native::xpu { // Result of sbtopk_try_launch. -// FAILED - did not run; caller should fall back to original kernel. -// UNSORTED - ran; output contains top-k values but is not sorted. +// FAILED - sbtopk did not run; caller should fall back to original. +// UNSORTED - sbtopk ran; output contains top-k values but is not sorted. // Caller must sort if sorted output is requested. -// SORTED - ran; output is already sorted (descending for largest, +// SORTED - sbtopk ran; output is already sorted (descending for largest, // ascending for smallest). Caller can skip sort. enum class SbtopkResult : int { FAILED = 0, @@ -26,6 +26,11 @@ enum class SbtopkResult : int { SORTED = 2, }; +// Try to run topk using an optimized kernel path. +// +// Dispatches between the subgroup topk kernel (sub-group bitonic merge, +// output SORTED) and the single workgroup topk kernel (radix select, +// output UNSORTED) based on (nsegments, nelements, k). SbtopkResult sbtopk_try_launch( const at::Tensor& self, int64_t nsegments, diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp new file mode 100644 index 0000000000..5d2b4c58da --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -0,0 +1,793 @@ +/* + * Single-workgroup top-k kernel — SYCL translation of CUDA single-block path. + * + * One workgroup (1024 threads) handles one slice. Algorithm: + * 1. radixSelect: iterates MSB→LSB in RADIX_BITS=4 steps to identify + * the k-th largest (or smallest) value. + * 2. gatherTopK: two-pass gather — first collects values strictly + * better than the k-th value, then fills remaining slots with + * values equal to the k-th value. + * + * Output is UNSORTED. Caller applies segmented sort if sorted output needed. + * + * CUDA sources translated 1:1: + * - SortingRadixSelect.cuh: countRadixUsingMask (line 176), findPattern (239) + * - ScanUtils.cuh: inclusiveBinaryPrefixScan (16), exclusiveBinaryPrefixScan + * (64) + * - TensorTopK.cu: gatherTopK (lines 40-182), radixSelect (860) + * + * Key CUDA -> SYCL mappings: + * WARP_BALLOT(pred) -> sycl::ext::oneapi::group_ballot(sg, pred) + * __popc(ballot) -> ballot.count() (or extract_bits + + * __builtin_popcount) getLaneMaskLe() & ballot -> extract_bits + manual + * le_mask + __builtin_popcount getLaneId() -> + * sg.get_local_linear_id() atomicAdd (smem) -> sycl::atomic_ref + * (local_space) + * __syncthreads() -> sycl::group_barrier(item.get_group()) + * doLdg(ptr) -> direct load (no read-only cache hint in SYCL) + * Bitfield::getBitfield -> software shift+mask (no PTX BFE/BFI in SYCL) + * smem[] -> int* from local accessor + * + * withinSliceStride = 1 (input is .contiguous() before calling this kernel). + * Uses SBTOPK_RADIX_BITS=4, SBTOPK_RADIX_SIZE=16, SBTOPK_RADIX_MASK=15 + * (halving radix passes vs the original RADIX_BITS=2). + * + * Future: CUDA multi-block radix path will be added for very large + * (nsegments × nelements) workloads. + */ + +#include +#include +#include +#include +#include +#include + +namespace at { +namespace native { +namespace xpu { + +// Uses RADIX_BITS=4 (16 digits per pass), halving radix passes for fp32. +// Cannot reuse RADIX_BITS/SIZE/MASK from SortingRadixSelect.h (constexpr int, +// can't #undef). +constexpr int SBTOPK_RADIX_BITS = 4; +constexpr int SBTOPK_RADIX_SIZE = 16; // 2 ^ SBTOPK_RADIX_BITS +constexpr int SBTOPK_RADIX_MASK = (SBTOPK_RADIX_SIZE - 1); + +// Block size = 1024 threads, matching CUDA C10_LAUNCH_BOUNDS_1(1024) +constexpr int SBTOPK_BLOCK = 1024; + +// SLM layout: +// [0..63] : used by countRadixUsingMask (smem[0..SBTOPK_RADIX_SIZE-1] for +// counts) +// and by exclusiveIntPrefixScan (smem[0..num_sgs-1] for carries) +// [64..65] : used by findPattern (flag + found index) +// Total: 68 ints = 272 bytes +constexpr int SMEM_INTS = 68; + +template < + typename scalar_t, + int VEC_SIZE = 4, + int ELEMS_PER_THREAD = 32, + int SIMD = 32> +struct SbtopkGatherFunctor { + using RadixT = typename TopKTypeConfig::RadixType; + // CUDA uses sizeof(scalar_t)*8, NOT sizeof(RadixType)*8. + // For fp16: sizeof(Half)=2 -> 16 bits, but sizeof(uint32_t)=4 -> 32 bits. + // Using RadixT would scan garbage upper bits and break Half/BFloat16. + static constexpr int NUM_BITS = sizeof(scalar_t) * 8; + + // ================================================================ + // countRadixUsingMask — per-thread counting + sub-group/work-group reduce + // + // Replaces ballot-based counting. Each thread: + // 1. Loads VEC_SIZE elements per iteration (vectorized) + // 2. Locally increments counts[digit] (pure ALU, no cross-lane) + // 3. After loop: sub-group reduce + lane0 atomicAdd to smem + broadcast + // + // Eliminates all group_ballot calls in counting. + // Result: all threads have identical counts[0..RADIX_SIZE-1]. + // ================================================================ + __attribute__((noinline)) void countRadixUsingMask( + sycl::nd_item<1> item, + sycl::sub_group sg, + int* smem, + int counts[SBTOPK_RADIX_SIZE], + RadixT desired, + RadixT desiredMask, + int digitPos, + const scalar_t* data, + int sliceSize) const { + int lid = item.get_local_id(0); + int block_size = item.get_local_range(0); + int sg_lid = sg.get_local_linear_id(); + +#pragma unroll + for (int i = 0; i < SBTOPK_RADIX_SIZE; ++i) { + counts[i] = 0; + } + if (lid < SBTOPK_RADIX_SIZE) { + smem[lid] = 0; + } + sycl::group_barrier(item.get_group()); + + // Each thread processes VEC_SIZE consecutive elements per iteration. + // Stride = block_size * VEC_SIZE for coalesced access across threads. + using LoadT = memory::aligned_vector; + int stride = block_size * VEC_SIZE; + + // Vectorized main loop — full VEC_SIZE loads + int base = lid * VEC_SIZE; + for (; base + VEC_SIZE <= sliceSize; base += stride) { + scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&data[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT val = TopKTypeConfig::convert(src[v]); + if ((val & desiredMask) == desired) { + RadixT digit = + Bitfield::getBitfield(val, digitPos, SBTOPK_RADIX_BITS); + counts[digit]++; + } + } + } + // Scalar tail — remaining elements + for (int idx = base; idx < sliceSize && idx < base + VEC_SIZE; ++idx) { + RadixT val = TopKTypeConfig::convert(data[idx]); + if ((val & desiredMask) == desired) { + RadixT digit = + Bitfield::getBitfield(val, digitPos, SBTOPK_RADIX_BITS); + counts[digit]++; + } + } + + // Sub-group reduce + lane0 atomicAdd to smem. +#pragma unroll + for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { + int total = sycl::reduce_over_group(sg, counts[j], sycl::plus()); + if (sg_lid == 0) { + sycl::atomic_ref< + int, + sycl::memory_order::relaxed, + sycl::memory_scope::work_group, + sycl::access::address_space::local_space> + ref(smem[j]); + ref.fetch_add(total); + } + } + sycl::group_barrier(item.get_group()); + + // All threads read block-level totals +#pragma unroll + for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { + counts[j] = smem[j]; + } + } + + // ================================================================ + // findPattern (SortingRadixSelect.cuh:239) + // + // Finds the unique value whose convert() matches desired. + // Returns RadixT (converted form) directly — no deconvert needed. + // SYCL uses smem[64]=flag(int), smem[65]=index(int), then + // convert(data[index]). + // ================================================================ + __attribute__((noinline)) RadixT findPattern( + sycl::nd_item<1> item, + int* smem, + const scalar_t* data, + int sliceSize, + RadixT desired, + RadixT desiredMask) const { + int lid = item.get_local_id(0); + int block_size = item.get_local_range(0); + + if (lid == 0) { + smem[64] = 0; // found flag + smem[65] = -1; // found index + } + // Barrier required: init must be visible before any thread enters the loop + sycl::group_barrier(item.get_group()); + + int numIterations = + ((sliceSize + block_size - 1) / block_size) * block_size; + + for (int i = lid; i < numIterations; i += block_size) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? data[i] : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + smem[64] = 1; // flag + smem[65] = i; // index of found value + } + sycl::group_barrier(item.get_group()); + + int found = smem[64]; + int foundIdx = smem[65]; + + if (found != 0) { + return TopKTypeConfig::convert(data[foundIdx]); + } + + // WAR barrier: protect smem writes in next iteration from current reads + sycl::group_barrier(item.get_group()); + } + return static_cast(0); + } + + // ================================================================ + // exclusiveIntPrefixScan — integer exclusive prefix scan + // + // Each thread provides an integer count (0..VEC_SIZE). Returns: + // out: exclusive prefix sum (write offset for this thread) + // carry: total sum across all threads in the work-group + // + // Sub-group level: exclusive_scan_over_group + reduce_over_group + // Cross sub-group: smem serial scan (same pattern as binary version) + // ================================================================ + __attribute__((noinline)) void exclusiveIntPrefixScan( + sycl::nd_item<1> item, + sycl::sub_group sg, + int* smem, + int local_count, + int& out, + int& carry) const { + int sg_lid = sg.get_local_linear_id(); + int sg_id = sg.get_group_linear_id(); + constexpr int num_sgs = SBTOPK_BLOCK / SIMD; + + int sg_inclusive = + sycl::inclusive_scan_over_group(sg, local_count, sycl::plus()); + int sg_exclusive = sg_inclusive - local_count; + + // group_broadcast to get sub-group total (last lane's inclusive value) + int sg_total = sycl::group_broadcast(sg, sg_inclusive, SIMD - 1); + if (sg_lid == SIMD - 1) { + smem[sg_id] = sg_total; + } + sycl::group_barrier(item.get_group()); + + // Thread 0: serial inclusive prefix sum over sub-group totals + if (item.get_local_id(0) == 0) { + int current = 0; + for (int i = 0; i < num_sgs; ++i) { + int v = smem[i]; + smem[i] = v + current; + current += v; + } + } + sycl::group_barrier(item.get_group()); + + int cross_sg_prefix = (sg_id >= 1) ? smem[sg_id - 1] : 0; + out = sg_exclusive + cross_sg_prefix; + carry = smem[num_sgs - 1]; + } + + // ================================================================ + // radixSelect (SortingRadixSelect.cuh:860, non-ROCm path) + // + // Iterates MSB to LSB in RADIX_BITS steps. + // At each step: count digits, scan to find which digit contains k-th. + // found_unique (count==1, kToFind==1): findPattern + return RadixT + // found_non_unique (count>=kToFind): narrow desired/desiredMask, continue + // End: return desired (RadixT, fully determined) + // ================================================================ + __attribute__((noinline)) RadixT radixSelect( + sycl::nd_item<1> item, + sycl::sub_group sg, + int* smem, + const scalar_t* data, + int k, + bool largest, + int sliceSize) const { + int counts[SBTOPK_RADIX_SIZE]; + RadixT desired = 0; + RadixT desiredMask = 0; + int kToFind = k; + + for (int digitPos = NUM_BITS - SBTOPK_RADIX_BITS; digitPos >= 0; + digitPos -= SBTOPK_RADIX_BITS) { + countRadixUsingMask( + item, + sg, + smem, + counts, + desired, + desiredMask, + digitPos, + data, + sliceSize); + + // All threads execute the same scan logic (counts are identical). + // Replicates CUDA found_unique / found_non_unique lambdas exactly. + if (largest) { + for (int i = SBTOPK_RADIX_SIZE - 1; i >= 0; --i) { + int count = counts[i]; + + // found_unique: return from radixSelect + if (count == 1 && kToFind == 1) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + return findPattern( + item, smem, data, sliceSize, desired, desiredMask); + } + + // found_non_unique: break inner loop, continue outer + if (count >= kToFind) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + break; + } + + kToFind -= count; + } + } else { + for (int i = 0; i < SBTOPK_RADIX_SIZE; ++i) { + int count = counts[i]; + + if (count == 1 && kToFind == 1) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + return findPattern( + item, smem, data, sliceSize, desired, desiredMask); + } + + if (count >= kToFind) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + break; + } + + kToFind -= count; + } + } + } + + // No unique result; desired fully determined + return desired; + } + + // ================================================================ + // operator() — gatherTopK (TensorTopK.cu:40-182) + // + // 1. radixSelect to find k-th value + // 2. Gather values strictly > topK (largest) or < topK (!largest) + // 3. Fill remaining with values == topK + // + // Each thread processes ELEMS_PER_THREAD elements per iteration + // (LOADS_PER_ITER × vec loads), then ONE prefix scan per iteration. + // With ELEMS_PER_THREAD=32 and 1024 threads, each iteration covers + // 32K elements, so dim=131072 needs only 4 iterations. + // ================================================================ + void operator()(sycl::nd_item<1> item) const { + int slice = item.get_group_linear_id(); + if (slice >= numSlices_) + return; + + sycl::sub_group sg = item.get_sub_group(); + + // Get raw int* pointer from local accessor + int* smem = + local_mem_.template get_multi_ptr().get(); + + const scalar_t* inputSlice = inputData_ + (int64_t)slice * sliceSize_; + scalar_t* topKSlice = topKData_ + (int64_t)slice * k_; + int64_t* indicesSlice = indicesData_ + (int64_t)slice * k_; + + // Step 1: radixSelect — returns RadixT directly (no deconvert/convert + // round-trip) + RadixT topKConverted = + radixSelect(item, sg, smem, inputSlice, k_, largest_, sliceSize_); + + // Vectorized gather setup + // ELEMS_PER_THREAD: each thread processes this many elements per iteration. + // Multiple vec loads per iteration, then ONE prefix scan. + // With ELEMS_PER_THREAD=32: 4 iterations for dim=131072. + static constexpr int LOADS_PER_ITER = ELEMS_PER_THREAD / VEC_SIZE; + using LoadT = memory::aligned_vector; + int lid = item.get_local_id(0); + + // Each iteration covers SBTOPK_BLOCK * ELEMS_PER_THREAD elements. + int iter_stride = SBTOPK_BLOCK * ELEMS_PER_THREAD; + int numIters = (sliceSize_ + iter_stride - 1) / iter_stride; + + // Step 2: Gather values strictly greater/less than topKValue + int writeIndexStart = 0; + + for (int iter = 0; iter < numIters; ++iter) { + // Each thread loads ELEMS_PER_THREAD elements from LOADS_PER_ITER vec4 + // chunks. Thread layout: consecutive threads handle consecutive VEC_SIZE + // chunks. Thread t handles chunks at offsets: t*VEC_SIZE, + // (t+SBTOPK_BLOCK)*VEC_SIZE, ... + scalar_t vals[ELEMS_PER_THREAD]; + int match_indices[ELEMS_PER_THREAD]; // global index of matching elements + int local_count = 0; + + int iter_base = iter * iter_stride; + +#pragma unroll + for (int L = 0; L < LOADS_PER_ITER; ++L) { + int base = iter_base + L * SBTOPK_BLOCK * VEC_SIZE + lid * VEC_SIZE; + + if (base + VEC_SIZE <= sliceSize_) { + scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&inputSlice[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT cv = TopKTypeConfig::convert(src[v]); + bool match = largest_ ? (cv > topKConverted) : (cv < topKConverted); + if (match) { + vals[local_count] = src[v]; + match_indices[local_count] = base + v; + local_count++; + } + } + } else if (base < sliceSize_) { + for (int v = 0; v < VEC_SIZE && base + v < sliceSize_; ++v) { + scalar_t sv = inputSlice[base + v]; + RadixT cv = TopKTypeConfig::convert(sv); + bool match = largest_ ? (cv > topKConverted) : (cv < topKConverted); + if (match) { + vals[local_count] = sv; + match_indices[local_count] = base + v; + local_count++; + } + } + } + } + + int offset, carry; + exclusiveIntPrefixScan(item, sg, smem, local_count, offset, carry); + + for (int j = 0; j < local_count; ++j) { + int writeIndex = writeIndexStart + offset + j; + if (writeIndex < k_) { + topKSlice[writeIndex] = vals[j]; + indicesSlice[writeIndex] = match_indices[j]; + } + } + writeIndexStart += carry; + } + + // Step 3: Fill remaining with values == topKValue + int topKRemaining = k_ - writeIndexStart; + + for (int iter = 0; iter < numIters; ++iter) { + scalar_t vals[ELEMS_PER_THREAD]; + int match_indices[ELEMS_PER_THREAD]; + int local_count = 0; + + int iter_base = iter * iter_stride; + +#pragma unroll + for (int L = 0; L < LOADS_PER_ITER; ++L) { + int base = iter_base + L * SBTOPK_BLOCK * VEC_SIZE + lid * VEC_SIZE; + + if (base + VEC_SIZE <= sliceSize_) { + scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&inputSlice[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT cv = TopKTypeConfig::convert(src[v]); + if (cv == topKConverted) { + vals[local_count] = src[v]; + match_indices[local_count] = base + v; + local_count++; + } + } + } else if (base < sliceSize_) { + for (int v = 0; v < VEC_SIZE && base + v < sliceSize_; ++v) { + scalar_t sv = inputSlice[base + v]; + RadixT cv = TopKTypeConfig::convert(sv); + if (cv == topKConverted) { + vals[local_count] = sv; + match_indices[local_count] = base + v; + local_count++; + } + } + } + } + + int offset, carry; + exclusiveIntPrefixScan(item, sg, smem, local_count, offset, carry); + + for (int j = 0; j < local_count; ++j) { + if (offset + j < topKRemaining) { + int writeIndex = writeIndexStart + offset + j; + topKSlice[writeIndex] = vals[j]; + indicesSlice[writeIndex] = match_indices[j]; + } + } + + if (carry >= topKRemaining) { + break; + } + topKRemaining -= carry; + writeIndexStart += carry; + } + } + + SbtopkGatherFunctor( + const scalar_t* inputData, + scalar_t* topKData, + int64_t* indicesData, + int numSlices, + int sliceSize, + int k, + bool largest, + sycl::local_accessor local_mem) + : inputData_(inputData), + topKData_(topKData), + indicesData_(indicesData), + numSlices_(numSlices), + sliceSize_(sliceSize), + k_(k), + largest_(largest), + local_mem_(local_mem) {} + + const scalar_t* inputData_; + scalar_t* topKData_; + int64_t* indicesData_; + int numSlices_; + int sliceSize_; + int k_; + bool largest_; + sycl::local_accessor local_mem_; +}; + +// ================================================================ +// Launch function +// ================================================================ +template +static void sbtopk_launch_impl( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + int numSlices, + int sliceSize, + int k, + bool largest) { + namespace syclex = sycl::ext::oneapi::experimental; + + constexpr int SIMD = 32; + using Functor = + SbtopkGatherFunctor; + + syclex::properties kernel_props{syclex::sub_group_size}; + + auto& q = at::xpu::getCurrentSYCLQueue(); + q.submit([&](sycl::handler& cgh) { + sycl::local_accessor local_mem(SMEM_INTS, cgh); + Functor functor( + input, topK, indices, numSlices, sliceSize, k, largest, local_mem); + cgh.parallel_for( + sycl::nd_range<1>( + sycl::range<1>(numSlices * SBTOPK_BLOCK), + sycl::range<1>(SBTOPK_BLOCK)), + kernel_props, + functor); + }); +} + +// Dispatch macro to reduce boilerplate +#define SBTOPK_LAUNCH(V, E) \ + sbtopk_launch_impl( \ + input, topK, indices, numSlices, sliceSize, k, largest) + +template +static void sbtopk_launch_kernel( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + int numSlices, + int sliceSize, + int k, + bool largest) { + // Determine ELEMS_PER_THREAD based on dim: target ~4 iterations + int ept; + if (sliceSize >= 32 * SBTOPK_BLOCK) + ept = 32; + else if (sliceSize >= 16 * SBTOPK_BLOCK) + ept = 16; + else if (sliceSize >= 8 * SBTOPK_BLOCK) + ept = 8; + else if (sliceSize >= 4 * SBTOPK_BLOCK) + ept = 4; + else if (sliceSize >= 2 * SBTOPK_BLOCK) + ept = 2; + else + ept = 1; + + // Determine VEC_SIZE: largest power-of-2 dividing sliceSize, + // capped by type max AND by EPT (vec <= ept required) + constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; + int cap = MAX_VEC < ept ? MAX_VEC : ept; + int vec = 1; + if (cap >= 8 && sliceSize % 8 == 0) + vec = 8; + else if (cap >= 4 && sliceSize % 4 == 0) + vec = 4; + else if (cap >= 2 && sliceSize % 2 == 0) + vec = 2; + + // Dispatch: VEC determines which EPT values are valid (EPT >= VEC, EPT % VEC + // == 0) + if constexpr (MAX_VEC == 8) { + // 16-bit types: VEC can be 8, 4, 2, 1 + if (vec == 8) { + switch (ept) { + case 8: + SBTOPK_LAUNCH(8, 8); + return; + case 16: + SBTOPK_LAUNCH(8, 16); + return; + default: + SBTOPK_LAUNCH(8, 32); + return; + } + } else if (vec == 4) { + switch (ept) { + case 4: + SBTOPK_LAUNCH(4, 4); + return; + case 8: + SBTOPK_LAUNCH(4, 8); + return; + case 16: + SBTOPK_LAUNCH(4, 16); + return; + default: + SBTOPK_LAUNCH(4, 32); + return; + } + } else if (vec == 2) { + switch (ept) { + case 2: + SBTOPK_LAUNCH(2, 2); + return; + case 4: + SBTOPK_LAUNCH(2, 4); + return; + case 8: + SBTOPK_LAUNCH(2, 8); + return; + case 16: + SBTOPK_LAUNCH(2, 16); + return; + default: + SBTOPK_LAUNCH(2, 32); + return; + } + } else { + switch (ept) { + case 1: + SBTOPK_LAUNCH(1, 1); + return; + case 2: + SBTOPK_LAUNCH(1, 2); + return; + case 4: + SBTOPK_LAUNCH(1, 4); + return; + case 8: + SBTOPK_LAUNCH(1, 8); + return; + case 16: + SBTOPK_LAUNCH(1, 16); + return; + default: + SBTOPK_LAUNCH(1, 32); + return; + } + } + } else { + // 32-bit types: VEC can be 4, 2, 1 + if (vec >= 4) { + switch (ept) { + case 4: + SBTOPK_LAUNCH(4, 4); + return; + case 8: + SBTOPK_LAUNCH(4, 8); + return; + case 16: + SBTOPK_LAUNCH(4, 16); + return; + default: + SBTOPK_LAUNCH(4, 32); + return; + } + } else if (vec == 2) { + switch (ept) { + case 2: + SBTOPK_LAUNCH(2, 2); + return; + case 4: + SBTOPK_LAUNCH(2, 4); + return; + case 8: + SBTOPK_LAUNCH(2, 8); + return; + case 16: + SBTOPK_LAUNCH(2, 16); + return; + default: + SBTOPK_LAUNCH(2, 32); + return; + } + } else { + switch (ept) { + case 1: + SBTOPK_LAUNCH(1, 1); + return; + case 2: + SBTOPK_LAUNCH(1, 2); + return; + case 4: + SBTOPK_LAUNCH(1, 4); + return; + case 8: + SBTOPK_LAUNCH(1, 8); + return; + case 16: + SBTOPK_LAUNCH(1, 16); + return; + default: + SBTOPK_LAUNCH(1, 32); + return; + } + } + } +} + +#undef SBTOPK_LAUNCH + +bool single_wg_topk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices) { + // Only handle cases where single-workgroup topk is beneficial: + // large dim, small k, contiguous last-dim + if (k > 256) { + return false; + } + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "single_wg_topk_xpu", + [&]() { + sbtopk_launch_kernel( + static_cast(self.const_data_ptr()), + static_cast(values.data_ptr()), + static_cast(indices.data_ptr()), + static_cast(nsegments), + static_cast(nelements), + static_cast(k), + largest); + }); + + return true; +} + +} // namespace xpu +} // namespace native +} // namespace at diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h new file mode 100644 index 0000000000..a4c260c4b0 --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +// Single-workgroup top-k kernel (translated from CUDA single-block path). +// One workgroup per slice: radix select to find the k-th value, then gather. +// Good for large dim (≥4096), any batch size. Output is UNSORTED. +bool single_wg_topk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices); + +} // namespace xpu +} // namespace native +} // namespace at From 8ce0a72888368d60df1239ea6dbc378c60865f44 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 18 Apr 2026 14:27:56 +0800 Subject: [PATCH 02/10] Pin grf_size<128> on single-wg kernel to prevent compiler from switching to GRF 256 --- src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 5d2b4c58da..67deffa258 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -41,6 +41,7 @@ #include #include #include +#include #include namespace at { @@ -560,12 +561,14 @@ static void sbtopk_launch_impl( int k, bool largest) { namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; constexpr int SIMD = 32; using Functor = SbtopkGatherFunctor; - syclex::properties kernel_props{syclex::sub_group_size}; + syclex::properties kernel_props{ + syclex::sub_group_size, intelex::grf_size<128>}; auto& q = at::xpu::getCurrentSYCLQueue(); q.submit([&](sycl::handler& cgh) { From b22c740a12f41a0cabc58f9d0f1858544117c7b0 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 18 Apr 2026 14:55:06 +0800 Subject: [PATCH 03/10] Use IndexT for shared memory and element indices in single-wg kernel Add IndexT template parameter to SbtopkGatherFunctor so that shared memory, histogram counts, and element indices use the correct index type. Dispatch IndexT as int when nsegments*nelements <= INT_MAX, int64_t otherwise. Remove the nsegments <= INT_MAX guard in the caller since the kernel now handles both cases internally. --- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 4 +- .../xpu/sycl/TensorTopKSingleWgKernel.cpp | 181 ++++++++++-------- 2 files changed, 107 insertions(+), 78 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index e6c550fbdf..bd833a41f5 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -115,9 +115,7 @@ SbtopkResult sbtopk_try_launch( // dim. Output is UNSORTED. // Use nelements >= 4096 so dim=1024/1025 falls through to original // (single-wg has regressions at dim=1024 for medium/large batches). - // Single-wg uses int for numSlices internally; reject - // if nsegments overflows int32. - if (nelements >= 4096 && nsegments <= std::numeric_limits::max()) { + if (nelements >= 4096) { if (single_wg_topk_try_launch( self, nsegments, nelements, k, largest, values, indices)) { return SbtopkResult::UNSORTED; diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 67deffa258..91e0cfb28e 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -26,7 +26,7 @@ * __syncthreads() -> sycl::group_barrier(item.get_group()) * doLdg(ptr) -> direct load (no read-only cache hint in SYCL) * Bitfield::getBitfield -> software shift+mask (no PTX BFE/BFI in SYCL) - * smem[] -> int* from local accessor + * smem[] -> IndexT* from local accessor * * withinSliceStride = 1 (input is .contiguous() before calling this kernel). * Uses SBTOPK_RADIX_BITS=4, SBTOPK_RADIX_SIZE=16, SBTOPK_RADIX_MASK=15 @@ -63,14 +63,15 @@ constexpr int SBTOPK_BLOCK = 1024; // counts) // and by exclusiveIntPrefixScan (smem[0..num_sgs-1] for carries) // [64..65] : used by findPattern (flag + found index) -// Total: 68 ints = 272 bytes -constexpr int SMEM_INTS = 68; +// Total: 68 elements (IndexT-sized) +constexpr int SMEM_ELEMS = 68; template < typename scalar_t, int VEC_SIZE = 4, int ELEMS_PER_THREAD = 32, - int SIMD = 32> + int SIMD = 32, + typename IndexT = int> struct SbtopkGatherFunctor { using RadixT = typename TopKTypeConfig::RadixType; // CUDA uses sizeof(scalar_t)*8, NOT sizeof(RadixType)*8. @@ -92,13 +93,13 @@ struct SbtopkGatherFunctor { __attribute__((noinline)) void countRadixUsingMask( sycl::nd_item<1> item, sycl::sub_group sg, - int* smem, - int counts[SBTOPK_RADIX_SIZE], + IndexT* smem, + IndexT counts[SBTOPK_RADIX_SIZE], RadixT desired, RadixT desiredMask, int digitPos, const scalar_t* data, - int sliceSize) const { + IndexT sliceSize) const { int lid = item.get_local_id(0); int block_size = item.get_local_range(0); int sg_lid = sg.get_local_linear_id(); @@ -115,10 +116,10 @@ struct SbtopkGatherFunctor { // Each thread processes VEC_SIZE consecutive elements per iteration. // Stride = block_size * VEC_SIZE for coalesced access across threads. using LoadT = memory::aligned_vector; - int stride = block_size * VEC_SIZE; + IndexT stride = static_cast(block_size) * VEC_SIZE; // Vectorized main loop — full VEC_SIZE loads - int base = lid * VEC_SIZE; + IndexT base = static_cast(lid) * VEC_SIZE; for (; base + VEC_SIZE <= sliceSize; base += stride) { scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = @@ -134,7 +135,7 @@ struct SbtopkGatherFunctor { } } // Scalar tail — remaining elements - for (int idx = base; idx < sliceSize && idx < base + VEC_SIZE; ++idx) { + for (IndexT idx = base; idx < sliceSize && idx < base + VEC_SIZE; ++idx) { RadixT val = TopKTypeConfig::convert(data[idx]); if ((val & desiredMask) == desired) { RadixT digit = @@ -146,10 +147,11 @@ struct SbtopkGatherFunctor { // Sub-group reduce + lane0 atomicAdd to smem. #pragma unroll for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { - int total = sycl::reduce_over_group(sg, counts[j], sycl::plus()); + IndexT total = + sycl::reduce_over_group(sg, counts[j], sycl::plus()); if (sg_lid == 0) { sycl::atomic_ref< - int, + IndexT, sycl::memory_order::relaxed, sycl::memory_scope::work_group, sycl::access::address_space::local_space> @@ -171,14 +173,13 @@ struct SbtopkGatherFunctor { // // Finds the unique value whose convert() matches desired. // Returns RadixT (converted form) directly — no deconvert needed. - // SYCL uses smem[64]=flag(int), smem[65]=index(int), then - // convert(data[index]). + // SYCL uses smem[64]=flag, smem[65]=index, then convert(data[index]). // ================================================================ __attribute__((noinline)) RadixT findPattern( sycl::nd_item<1> item, - int* smem, + IndexT* smem, const scalar_t* data, - int sliceSize, + IndexT sliceSize, RadixT desired, RadixT desiredMask) const { int lid = item.get_local_id(0); @@ -186,15 +187,15 @@ struct SbtopkGatherFunctor { if (lid == 0) { smem[64] = 0; // found flag - smem[65] = -1; // found index + smem[65] = static_cast(-1); // found index } // Barrier required: init must be visible before any thread enters the loop sycl::group_barrier(item.get_group()); - int numIterations = + IndexT numIterations = ((sliceSize + block_size - 1) / block_size) * block_size; - for (int i = lid; i < numIterations; i += block_size) { + for (IndexT i = lid; i < numIterations; i += block_size) { bool inRange = (i < sliceSize); scalar_t v = inRange ? data[i] : static_cast(0); @@ -205,8 +206,8 @@ struct SbtopkGatherFunctor { } sycl::group_barrier(item.get_group()); - int found = smem[64]; - int foundIdx = smem[65]; + IndexT found = smem[64]; + IndexT foundIdx = smem[65]; if (found != 0) { return TopKTypeConfig::convert(data[foundIdx]); @@ -221,17 +222,20 @@ struct SbtopkGatherFunctor { // ================================================================ // exclusiveIntPrefixScan — integer exclusive prefix scan // - // Each thread provides an integer count (0..VEC_SIZE). Returns: + // Each thread provides an integer count (0..ELEMS_PER_THREAD). Returns: // out: exclusive prefix sum (write offset for this thread) // carry: total sum across all threads in the work-group // - // Sub-group level: exclusive_scan_over_group + reduce_over_group + // Values are bounded by SBTOPK_BLOCK * ELEMS_PER_THREAD (always fits int), + // but smem is IndexT* so we cast at the interface. + // + // Sub-group level: inclusive_scan_over_group + group_broadcast // Cross sub-group: smem serial scan (same pattern as binary version) // ================================================================ __attribute__((noinline)) void exclusiveIntPrefixScan( sycl::nd_item<1> item, sycl::sub_group sg, - int* smem, + IndexT* smem, int local_count, int& out, int& carry) const { @@ -246,24 +250,24 @@ struct SbtopkGatherFunctor { // group_broadcast to get sub-group total (last lane's inclusive value) int sg_total = sycl::group_broadcast(sg, sg_inclusive, SIMD - 1); if (sg_lid == SIMD - 1) { - smem[sg_id] = sg_total; + smem[sg_id] = static_cast(sg_total); } sycl::group_barrier(item.get_group()); // Thread 0: serial inclusive prefix sum over sub-group totals if (item.get_local_id(0) == 0) { - int current = 0; + IndexT current = 0; for (int i = 0; i < num_sgs; ++i) { - int v = smem[i]; + IndexT v = smem[i]; smem[i] = v + current; current += v; } } sycl::group_barrier(item.get_group()); - int cross_sg_prefix = (sg_id >= 1) ? smem[sg_id - 1] : 0; + int cross_sg_prefix = (sg_id >= 1) ? static_cast(smem[sg_id - 1]) : 0; out = sg_exclusive + cross_sg_prefix; - carry = smem[num_sgs - 1]; + carry = static_cast(smem[num_sgs - 1]); } // ================================================================ @@ -278,12 +282,12 @@ struct SbtopkGatherFunctor { __attribute__((noinline)) RadixT radixSelect( sycl::nd_item<1> item, sycl::sub_group sg, - int* smem, + IndexT* smem, const scalar_t* data, int k, bool largest, - int sliceSize) const { - int counts[SBTOPK_RADIX_SIZE]; + IndexT sliceSize) const { + IndexT counts[SBTOPK_RADIX_SIZE]; RadixT desired = 0; RadixT desiredMask = 0; int kToFind = k; @@ -305,7 +309,7 @@ struct SbtopkGatherFunctor { // Replicates CUDA found_unique / found_non_unique lambdas exactly. if (largest) { for (int i = SBTOPK_RADIX_SIZE - 1; i >= 0; --i) { - int count = counts[i]; + IndexT count = counts[i]; // found_unique: return from radixSelect if (count == 1 && kToFind == 1) { @@ -326,11 +330,12 @@ struct SbtopkGatherFunctor { break; } - kToFind -= count; + // count < kToFind here, so count fits in int + kToFind -= static_cast(count); } } else { for (int i = 0; i < SBTOPK_RADIX_SIZE; ++i) { - int count = counts[i]; + IndexT count = counts[i]; if (count == 1 && kToFind == 1) { desired = Bitfield::setBitfield( @@ -349,7 +354,7 @@ struct SbtopkGatherFunctor { break; } - kToFind -= count; + kToFind -= static_cast(count); } } } @@ -371,19 +376,20 @@ struct SbtopkGatherFunctor { // 32K elements, so dim=131072 needs only 4 iterations. // ================================================================ void operator()(sycl::nd_item<1> item) const { - int slice = item.get_group_linear_id(); + IndexT slice = static_cast(item.get_group_linear_id()); if (slice >= numSlices_) return; sycl::sub_group sg = item.get_sub_group(); - // Get raw int* pointer from local accessor - int* smem = + // Get raw IndexT* pointer from local accessor + IndexT* smem = local_mem_.template get_multi_ptr().get(); - const scalar_t* inputSlice = inputData_ + (int64_t)slice * sliceSize_; - scalar_t* topKSlice = topKData_ + (int64_t)slice * k_; - int64_t* indicesSlice = indicesData_ + (int64_t)slice * k_; + const scalar_t* inputSlice = + inputData_ + static_cast(slice) * sliceSize_; + scalar_t* topKSlice = topKData_ + static_cast(slice) * k_; + int64_t* indicesSlice = indicesData_ + static_cast(slice) * k_; // Step 1: radixSelect — returns RadixT directly (no deconvert/convert // round-trip) @@ -399,8 +405,9 @@ struct SbtopkGatherFunctor { int lid = item.get_local_id(0); // Each iteration covers SBTOPK_BLOCK * ELEMS_PER_THREAD elements. - int iter_stride = SBTOPK_BLOCK * ELEMS_PER_THREAD; - int numIters = (sliceSize_ + iter_stride - 1) / iter_stride; + IndexT iter_stride = static_cast(SBTOPK_BLOCK) * ELEMS_PER_THREAD; + int numIters = + static_cast((sliceSize_ + iter_stride - 1) / iter_stride); // Step 2: Gather values strictly greater/less than topKValue int writeIndexStart = 0; @@ -411,14 +418,16 @@ struct SbtopkGatherFunctor { // chunks. Thread t handles chunks at offsets: t*VEC_SIZE, // (t+SBTOPK_BLOCK)*VEC_SIZE, ... scalar_t vals[ELEMS_PER_THREAD]; - int match_indices[ELEMS_PER_THREAD]; // global index of matching elements + IndexT match_indices[ELEMS_PER_THREAD]; int local_count = 0; - int iter_base = iter * iter_stride; + IndexT iter_base = static_cast(iter) * iter_stride; #pragma unroll for (int L = 0; L < LOADS_PER_ITER; ++L) { - int base = iter_base + L * SBTOPK_BLOCK * VEC_SIZE + lid * VEC_SIZE; + IndexT base = iter_base + + static_cast(L) * SBTOPK_BLOCK * VEC_SIZE + + static_cast(lid) * VEC_SIZE; if (base + VEC_SIZE <= sliceSize_) { scalar_t src[VEC_SIZE]; @@ -466,14 +475,16 @@ struct SbtopkGatherFunctor { for (int iter = 0; iter < numIters; ++iter) { scalar_t vals[ELEMS_PER_THREAD]; - int match_indices[ELEMS_PER_THREAD]; + IndexT match_indices[ELEMS_PER_THREAD]; int local_count = 0; - int iter_base = iter * iter_stride; + IndexT iter_base = static_cast(iter) * iter_stride; #pragma unroll for (int L = 0; L < LOADS_PER_ITER; ++L) { - int base = iter_base + L * SBTOPK_BLOCK * VEC_SIZE + lid * VEC_SIZE; + IndexT base = iter_base + + static_cast(L) * SBTOPK_BLOCK * VEC_SIZE + + static_cast(lid) * VEC_SIZE; if (base + VEC_SIZE <= sliceSize_) { scalar_t src[VEC_SIZE]; @@ -524,11 +535,11 @@ struct SbtopkGatherFunctor { const scalar_t* inputData, scalar_t* topKData, int64_t* indicesData, - int numSlices, - int sliceSize, + IndexT numSlices, + IndexT sliceSize, int k, bool largest, - sycl::local_accessor local_mem) + sycl::local_accessor local_mem) : inputData_(inputData), topKData_(topKData), indicesData_(indicesData), @@ -541,23 +552,27 @@ struct SbtopkGatherFunctor { const scalar_t* inputData_; scalar_t* topKData_; int64_t* indicesData_; - int numSlices_; - int sliceSize_; + IndexT numSlices_; + IndexT sliceSize_; int k_; bool largest_; - sycl::local_accessor local_mem_; + sycl::local_accessor local_mem_; }; // ================================================================ // Launch function // ================================================================ -template +template < + typename scalar_t, + int VEC_SIZE, + int ELEMS_PER_THREAD, + typename IndexT> static void sbtopk_launch_impl( const scalar_t* input, scalar_t* topK, int64_t* indices, - int numSlices, - int sliceSize, + IndexT numSlices, + IndexT sliceSize, int k, bool largest) { namespace syclex = sycl::ext::oneapi::experimental; @@ -565,19 +580,19 @@ static void sbtopk_launch_impl( constexpr int SIMD = 32; using Functor = - SbtopkGatherFunctor; + SbtopkGatherFunctor; syclex::properties kernel_props{ syclex::sub_group_size, intelex::grf_size<128>}; auto& q = at::xpu::getCurrentSYCLQueue(); q.submit([&](sycl::handler& cgh) { - sycl::local_accessor local_mem(SMEM_INTS, cgh); + sycl::local_accessor local_mem(SMEM_ELEMS, cgh); Functor functor( input, topK, indices, numSlices, sliceSize, k, largest, local_mem); cgh.parallel_for( sycl::nd_range<1>( - sycl::range<1>(numSlices * SBTOPK_BLOCK), + sycl::range<1>(static_cast(numSlices) * SBTOPK_BLOCK), sycl::range<1>(SBTOPK_BLOCK)), kernel_props, functor); @@ -585,17 +600,17 @@ static void sbtopk_launch_impl( } // Dispatch macro to reduce boilerplate -#define SBTOPK_LAUNCH(V, E) \ - sbtopk_launch_impl( \ +#define SBTOPK_LAUNCH(V, E) \ + sbtopk_launch_impl( \ input, topK, indices, numSlices, sliceSize, k, largest) -template +template static void sbtopk_launch_kernel( const scalar_t* input, scalar_t* topK, int64_t* indices, - int numSlices, - int sliceSize, + IndexT numSlices, + IndexT sliceSize, int k, bool largest) { // Determine ELEMS_PER_THREAD based on dim: target ~4 iterations @@ -778,14 +793,30 @@ bool single_wg_topk_try_launch( self.scalar_type(), "single_wg_topk_xpu", [&]() { - sbtopk_launch_kernel( - static_cast(self.const_data_ptr()), - static_cast(values.data_ptr()), - static_cast(indices.data_ptr()), - static_cast(nsegments), - static_cast(nelements), - static_cast(k), - largest); + const auto* input = static_cast(self.const_data_ptr()); + auto* topK = static_cast(values.data_ptr()); + auto* idx = static_cast(indices.data_ptr()); + + if (nsegments * nelements <= + static_cast(std::numeric_limits::max())) { + sbtopk_launch_kernel( + input, + topK, + idx, + static_cast(nsegments), + static_cast(nelements), + static_cast(k), + largest); + } else { + sbtopk_launch_kernel( + input, + topK, + idx, + nsegments, + nelements, + static_cast(k), + largest); + } }); return true; From 9de8c41851297a43de08ecadc5ed8f70c1b77327 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sat, 18 Apr 2026 21:22:09 +0800 Subject: [PATCH 04/10] Address review: alignas on vectorized loads, named SLM offset constants - Add alignas(alignof(LoadT)) on all scalar_t src[VEC_SIZE] arrays used for vectorized loads (3 occurrences) to ensure proper alignment - Replace magic numbers smem[64]/smem[65] with named constants SMEM_FOUND_FLAG / SMEM_FOUND_IDX for clarity and maintainability --- .../xpu/sycl/TensorTopKSingleWgKernel.cpp | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 91e0cfb28e..b8fc2c39ba 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -64,6 +64,8 @@ constexpr int SBTOPK_BLOCK = 1024; // and by exclusiveIntPrefixScan (smem[0..num_sgs-1] for carries) // [64..65] : used by findPattern (flag + found index) // Total: 68 elements (IndexT-sized) +constexpr int SMEM_FOUND_FLAG = 64; +constexpr int SMEM_FOUND_IDX = 65; constexpr int SMEM_ELEMS = 68; template < @@ -121,7 +123,7 @@ struct SbtopkGatherFunctor { // Vectorized main loop — full VEC_SIZE loads IndexT base = static_cast(lid) * VEC_SIZE; for (; base + VEC_SIZE <= sliceSize; base += stride) { - scalar_t src[VEC_SIZE]; + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = *reinterpret_cast(&data[base]); #pragma unroll @@ -173,7 +175,8 @@ struct SbtopkGatherFunctor { // // Finds the unique value whose convert() matches desired. // Returns RadixT (converted form) directly — no deconvert needed. - // SYCL uses smem[64]=flag, smem[65]=index, then convert(data[index]). + // SYCL uses smem[SMEM_FOUND_FLAG]=flag, smem[SMEM_FOUND_IDX]=index, + // then convert(data[index]). // ================================================================ __attribute__((noinline)) RadixT findPattern( sycl::nd_item<1> item, @@ -186,8 +189,8 @@ struct SbtopkGatherFunctor { int block_size = item.get_local_range(0); if (lid == 0) { - smem[64] = 0; // found flag - smem[65] = static_cast(-1); // found index + smem[SMEM_FOUND_FLAG] = 0; + smem[SMEM_FOUND_IDX] = static_cast(-1); } // Barrier required: init must be visible before any thread enters the loop sycl::group_barrier(item.get_group()); @@ -201,13 +204,13 @@ struct SbtopkGatherFunctor { if (inRange && ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { - smem[64] = 1; // flag - smem[65] = i; // index of found value + smem[SMEM_FOUND_FLAG] = 1; + smem[SMEM_FOUND_IDX] = i; } sycl::group_barrier(item.get_group()); - IndexT found = smem[64]; - IndexT foundIdx = smem[65]; + IndexT found = smem[SMEM_FOUND_FLAG]; + IndexT foundIdx = smem[SMEM_FOUND_IDX]; if (found != 0) { return TopKTypeConfig::convert(data[foundIdx]); @@ -430,7 +433,7 @@ struct SbtopkGatherFunctor { static_cast(lid) * VEC_SIZE; if (base + VEC_SIZE <= sliceSize_) { - scalar_t src[VEC_SIZE]; + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = *reinterpret_cast(&inputSlice[base]); #pragma unroll @@ -487,7 +490,7 @@ struct SbtopkGatherFunctor { static_cast(lid) * VEC_SIZE; if (base + VEC_SIZE <= sliceSize_) { - scalar_t src[VEC_SIZE]; + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; *reinterpret_cast(&src) = *reinterpret_cast(&inputSlice[base]); #pragma unroll From 89da72605278b96f5f8b0c46e8bd67c3ddbe8436 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Sun, 26 Apr 2026 16:56:29 +0800 Subject: [PATCH 05/10] Use overflow-safe check for IndexT dispatch instead of nsegments*nelements --- src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index b8fc2c39ba..3a380c6f66 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -800,8 +800,12 @@ bool single_wg_topk_try_launch( auto* topK = static_cast(values.data_ptr()); auto* idx = static_cast(indices.data_ptr()); - if (nsegments * nelements <= - static_cast(std::numeric_limits::max())) { + if (nsegments <= + static_cast(std::numeric_limits::max()) && + nelements <= + static_cast(std::numeric_limits::max()) && + nsegments <= static_cast(std::numeric_limits::max()) / + (nelements > 0 ? nelements : 1)) { sbtopk_launch_kernel( input, topK, From 7ac220f0f06b08cfabedf48b5dce4843e83983f8 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 14 May 2026 14:02:28 +0800 Subject: [PATCH 06/10] Address review: rename internal funcs, use C10_NOINLINE and sycl_kernel_submit, simplify EPT with PowerOf2Floor, remove unnecessary include --- .../xpu/sycl/TensorTopKSingleWgKernel.cpp | 178 ++++++++---------- 1 file changed, 83 insertions(+), 95 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 3a380c6f66..c75d60422c 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -11,29 +11,24 @@ * Output is UNSORTED. Caller applies segmented sort if sorted output needed. * * CUDA sources translated 1:1: - * - SortingRadixSelect.cuh: countRadixUsingMask (line 176), findPattern (239) - * - ScanUtils.cuh: inclusiveBinaryPrefixScan (16), exclusiveBinaryPrefixScan - * (64) - * - TensorTopK.cu: gatherTopK (lines 40-182), radixSelect (860) + * - SortingRadixSelect.cuh: countRadixUsingMask, findPattern + * - ScanUtils.cuh: inclusiveBinaryPrefixScan, exclusiveBinaryPrefixScan + * - TensorTopK.cu: gatherTopK, radixSelect * * Key CUDA -> SYCL mappings: - * WARP_BALLOT(pred) -> sycl::ext::oneapi::group_ballot(sg, pred) - * __popc(ballot) -> ballot.count() (or extract_bits + - * __builtin_popcount) getLaneMaskLe() & ballot -> extract_bits + manual - * le_mask + __builtin_popcount getLaneId() -> - * sg.get_local_linear_id() atomicAdd (smem) -> sycl::atomic_ref - * (local_space) - * __syncthreads() -> sycl::group_barrier(item.get_group()) - * doLdg(ptr) -> direct load (no read-only cache hint in SYCL) - * Bitfield::getBitfield -> software shift+mask (no PTX BFE/BFI in SYCL) - * smem[] -> IndexT* from local accessor + * WARP_BALLOT(pred) -> sycl::ext::oneapi::group_ballot(sg, pred) + * __popc(ballot) -> ballot.count() + * getLaneMaskLe() & ballot -> extract_bits + manual le_mask + popcount + * getLaneId() -> sg.get_local_linear_id() + * atomicAdd (smem) -> sycl::atomic_ref (local_space) + * __syncthreads() -> sycl::group_barrier(item.get_group()) + * doLdg(ptr) -> direct load (no read-only cache hint in SYCL) + * Bitfield::getBitfield -> software shift+mask (no PTX BFE/BFI in SYCL) + * smem[] -> IndexT* from local accessor * * withinSliceStride = 1 (input is .contiguous() before calling this kernel). * Uses SBTOPK_RADIX_BITS=4, SBTOPK_RADIX_SIZE=16, SBTOPK_RADIX_MASK=15 * (halving radix passes vs the original RADIX_BITS=2). - * - * Future: CUDA multi-block radix path will be added for very large - * (nsegments × nelements) workloads. */ #include @@ -41,6 +36,8 @@ #include #include #include +#include +#include #include #include @@ -74,7 +71,7 @@ template < int ELEMS_PER_THREAD = 32, int SIMD = 32, typename IndexT = int> -struct SbtopkGatherFunctor { +struct SbtopkGatherFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { using RadixT = typename TopKTypeConfig::RadixType; // CUDA uses sizeof(scalar_t)*8, NOT sizeof(RadixType)*8. // For fp16: sizeof(Half)=2 -> 16 bits, but sizeof(uint32_t)=4 -> 32 bits. @@ -92,7 +89,7 @@ struct SbtopkGatherFunctor { // Eliminates all group_ballot calls in counting. // Result: all threads have identical counts[0..RADIX_SIZE-1]. // ================================================================ - __attribute__((noinline)) void countRadixUsingMask( + C10_NOINLINE void countRadixUsingMask( sycl::nd_item<1> item, sycl::sub_group sg, IndexT* smem, @@ -178,7 +175,7 @@ struct SbtopkGatherFunctor { // SYCL uses smem[SMEM_FOUND_FLAG]=flag, smem[SMEM_FOUND_IDX]=index, // then convert(data[index]). // ================================================================ - __attribute__((noinline)) RadixT findPattern( + C10_NOINLINE RadixT findPattern( sycl::nd_item<1> item, IndexT* smem, const scalar_t* data, @@ -235,7 +232,7 @@ struct SbtopkGatherFunctor { // Sub-group level: inclusive_scan_over_group + group_broadcast // Cross sub-group: smem serial scan (same pattern as binary version) // ================================================================ - __attribute__((noinline)) void exclusiveIntPrefixScan( + C10_NOINLINE void exclusiveIntPrefixScan( sycl::nd_item<1> item, sycl::sub_group sg, IndexT* smem, @@ -282,7 +279,7 @@ struct SbtopkGatherFunctor { // found_non_unique (count>=kToFind): narrow desired/desiredMask, continue // End: return desired (RadixT, fully determined) // ================================================================ - __attribute__((noinline)) RadixT radixSelect( + C10_NOINLINE RadixT radixSelect( sycl::nd_item<1> item, sycl::sub_group sg, IndexT* smem, @@ -541,16 +538,18 @@ struct SbtopkGatherFunctor { IndexT numSlices, IndexT sliceSize, int k, - bool largest, - sycl::local_accessor local_mem) + bool largest) : inputData_(inputData), topKData_(topKData), indicesData_(indicesData), numSlices_(numSlices), sliceSize_(sliceSize), k_(k), - largest_(largest), - local_mem_(local_mem) {} + largest_(largest) {} + + void sycl_ker_config_convention(sycl::handler& cgh) { + local_mem_ = sycl::local_accessor(SMEM_ELEMS, cgh); + } const scalar_t* inputData_; scalar_t* topKData_; @@ -570,7 +569,7 @@ template < int VEC_SIZE, int ELEMS_PER_THREAD, typename IndexT> -static void sbtopk_launch_impl( +static void single_wg_launch_impl( const scalar_t* input, scalar_t* topK, int64_t* indices, @@ -585,30 +584,23 @@ static void sbtopk_launch_impl( using Functor = SbtopkGatherFunctor; - syclex::properties kernel_props{ - syclex::sub_group_size, intelex::grf_size<128>}; - - auto& q = at::xpu::getCurrentSYCLQueue(); - q.submit([&](sycl::handler& cgh) { - sycl::local_accessor local_mem(SMEM_ELEMS, cgh); - Functor functor( - input, topK, indices, numSlices, sliceSize, k, largest, local_mem); - cgh.parallel_for( - sycl::nd_range<1>( - sycl::range<1>(static_cast(numSlices) * SBTOPK_BLOCK), - sycl::range<1>(SBTOPK_BLOCK)), - kernel_props, - functor); - }); + Functor functor(input, topK, indices, numSlices, sliceSize, k, largest); + + sycl_kernel_submit( + static_cast(numSlices) * SBTOPK_BLOCK, + static_cast(SBTOPK_BLOCK), + at::xpu::getCurrentSYCLQueue(), + syclex::properties{syclex::sub_group_size, intelex::grf_size<128>}, + functor); } // Dispatch macro to reduce boilerplate -#define SBTOPK_LAUNCH(V, E) \ - sbtopk_launch_impl( \ +#define SINGLE_WG_LAUNCH(V, E) \ + single_wg_launch_impl( \ input, topK, indices, numSlices, sliceSize, k, largest) template -static void sbtopk_launch_kernel( +static void single_wg_launch_kernel( const scalar_t* input, scalar_t* topK, int64_t* indices, @@ -616,20 +608,16 @@ static void sbtopk_launch_kernel( IndexT sliceSize, int k, bool largest) { - // Determine ELEMS_PER_THREAD based on dim: target ~4 iterations - int ept; - if (sliceSize >= 32 * SBTOPK_BLOCK) - ept = 32; - else if (sliceSize >= 16 * SBTOPK_BLOCK) - ept = 16; - else if (sliceSize >= 8 * SBTOPK_BLOCK) - ept = 8; - else if (sliceSize >= 4 * SBTOPK_BLOCK) - ept = 4; - else if (sliceSize >= 2 * SBTOPK_BLOCK) - ept = 2; - else - ept = 1; + // ELEMS_PER_THREAD: largest power-of-2 such that + // ept * SBTOPK_BLOCK <= sliceSize, capped at 32. + // With ept=32 and 1024 threads, each iteration covers 32K elements, + // keeping the number of gatherTopK iterations small (~4 for dim=131072). + int64_t ratio = sliceSize / SBTOPK_BLOCK; + int ept = ratio >= 1 ? std::min( + 32, + static_cast(c10::llvm::PowerOf2Floor( + static_cast(ratio)))) + : 1; // Determine VEC_SIZE: largest power-of-2 dividing sliceSize, // capped by type max AND by EPT (vec <= ept required) @@ -650,67 +638,67 @@ static void sbtopk_launch_kernel( if (vec == 8) { switch (ept) { case 8: - SBTOPK_LAUNCH(8, 8); + SINGLE_WG_LAUNCH(8, 8); return; case 16: - SBTOPK_LAUNCH(8, 16); + SINGLE_WG_LAUNCH(8, 16); return; default: - SBTOPK_LAUNCH(8, 32); + SINGLE_WG_LAUNCH(8, 32); return; } } else if (vec == 4) { switch (ept) { case 4: - SBTOPK_LAUNCH(4, 4); + SINGLE_WG_LAUNCH(4, 4); return; case 8: - SBTOPK_LAUNCH(4, 8); + SINGLE_WG_LAUNCH(4, 8); return; case 16: - SBTOPK_LAUNCH(4, 16); + SINGLE_WG_LAUNCH(4, 16); return; default: - SBTOPK_LAUNCH(4, 32); + SINGLE_WG_LAUNCH(4, 32); return; } } else if (vec == 2) { switch (ept) { case 2: - SBTOPK_LAUNCH(2, 2); + SINGLE_WG_LAUNCH(2, 2); return; case 4: - SBTOPK_LAUNCH(2, 4); + SINGLE_WG_LAUNCH(2, 4); return; case 8: - SBTOPK_LAUNCH(2, 8); + SINGLE_WG_LAUNCH(2, 8); return; case 16: - SBTOPK_LAUNCH(2, 16); + SINGLE_WG_LAUNCH(2, 16); return; default: - SBTOPK_LAUNCH(2, 32); + SINGLE_WG_LAUNCH(2, 32); return; } } else { switch (ept) { case 1: - SBTOPK_LAUNCH(1, 1); + SINGLE_WG_LAUNCH(1, 1); return; case 2: - SBTOPK_LAUNCH(1, 2); + SINGLE_WG_LAUNCH(1, 2); return; case 4: - SBTOPK_LAUNCH(1, 4); + SINGLE_WG_LAUNCH(1, 4); return; case 8: - SBTOPK_LAUNCH(1, 8); + SINGLE_WG_LAUNCH(1, 8); return; case 16: - SBTOPK_LAUNCH(1, 16); + SINGLE_WG_LAUNCH(1, 16); return; default: - SBTOPK_LAUNCH(1, 32); + SINGLE_WG_LAUNCH(1, 32); return; } } @@ -719,62 +707,62 @@ static void sbtopk_launch_kernel( if (vec >= 4) { switch (ept) { case 4: - SBTOPK_LAUNCH(4, 4); + SINGLE_WG_LAUNCH(4, 4); return; case 8: - SBTOPK_LAUNCH(4, 8); + SINGLE_WG_LAUNCH(4, 8); return; case 16: - SBTOPK_LAUNCH(4, 16); + SINGLE_WG_LAUNCH(4, 16); return; default: - SBTOPK_LAUNCH(4, 32); + SINGLE_WG_LAUNCH(4, 32); return; } } else if (vec == 2) { switch (ept) { case 2: - SBTOPK_LAUNCH(2, 2); + SINGLE_WG_LAUNCH(2, 2); return; case 4: - SBTOPK_LAUNCH(2, 4); + SINGLE_WG_LAUNCH(2, 4); return; case 8: - SBTOPK_LAUNCH(2, 8); + SINGLE_WG_LAUNCH(2, 8); return; case 16: - SBTOPK_LAUNCH(2, 16); + SINGLE_WG_LAUNCH(2, 16); return; default: - SBTOPK_LAUNCH(2, 32); + SINGLE_WG_LAUNCH(2, 32); return; } } else { switch (ept) { case 1: - SBTOPK_LAUNCH(1, 1); + SINGLE_WG_LAUNCH(1, 1); return; case 2: - SBTOPK_LAUNCH(1, 2); + SINGLE_WG_LAUNCH(1, 2); return; case 4: - SBTOPK_LAUNCH(1, 4); + SINGLE_WG_LAUNCH(1, 4); return; case 8: - SBTOPK_LAUNCH(1, 8); + SINGLE_WG_LAUNCH(1, 8); return; case 16: - SBTOPK_LAUNCH(1, 16); + SINGLE_WG_LAUNCH(1, 16); return; default: - SBTOPK_LAUNCH(1, 32); + SINGLE_WG_LAUNCH(1, 32); return; } } } } -#undef SBTOPK_LAUNCH +#undef SINGLE_WG_LAUNCH bool single_wg_topk_try_launch( const at::Tensor& self, @@ -806,7 +794,7 @@ bool single_wg_topk_try_launch( static_cast(std::numeric_limits::max()) && nsegments <= static_cast(std::numeric_limits::max()) / (nelements > 0 ? nelements : 1)) { - sbtopk_launch_kernel( + single_wg_launch_kernel( input, topK, idx, @@ -815,7 +803,7 @@ bool single_wg_topk_try_launch( static_cast(k), largest); } else { - sbtopk_launch_kernel( + single_wg_launch_kernel( input, topK, idx, From fcd240508095d51acbe0db16be48ba36cfda6c6c Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 14 May 2026 14:41:30 +0800 Subject: [PATCH 07/10] Add input pointer alignment check for vectorized loads in single-wg kernel --- .../native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index c75d60422c..a0da5a9bf6 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -620,15 +620,21 @@ static void single_wg_launch_kernel( : 1; // Determine VEC_SIZE: largest power-of-2 dividing sliceSize, - // capped by type max AND by EPT (vec <= ept required) + // capped by type max AND by EPT (vec <= ept required). + // Also check input pointer alignment — a non-zero storage offset + // can break alignment even for contiguous tensors. constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; int cap = MAX_VEC < ept ? MAX_VEC : ept; + auto input_align = reinterpret_cast(input); + auto aligned = [&](int v) { + return input_align % (sizeof(scalar_t) * v) == 0; + }; int vec = 1; - if (cap >= 8 && sliceSize % 8 == 0) + if (cap >= 8 && sliceSize % 8 == 0 && aligned(8)) vec = 8; - else if (cap >= 4 && sliceSize % 4 == 0) + else if (cap >= 4 && sliceSize % 4 == 0 && aligned(4)) vec = 4; - else if (cap >= 2 && sliceSize % 2 == 0) + else if (cap >= 2 && sliceSize % 2 == 0 && aligned(2)) vec = 2; // Dispatch: VEC determines which EPT values are valid (EPT >= VEC, EPT % VEC From 820d6dcadadfacf515e96ae911f73ba59243308d Mon Sep 17 00:00:00 2001 From: jianyizh Date: Thu, 14 May 2026 14:45:39 +0800 Subject: [PATCH 08/10] Update src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp Co-authored-by: Yu, Guangye --- src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index a0da5a9bf6..225730f5c7 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -160,7 +160,7 @@ struct SbtopkGatherFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { } sycl::group_barrier(item.get_group()); - // All threads read block-level totals + // All threads read workgroup-level totals #pragma unroll for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { counts[j] = smem[j]; From 27e8ed2bd4109f81a017793376357c8fe1cc1b7e Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 14 May 2026 15:33:50 +0800 Subject: [PATCH 09/10] Address review: use SYCL terminology, at::ceil_div, and static_assert for SLM layout safety --- src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 225730f5c7..6c9084c52f 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -193,7 +193,7 @@ struct SbtopkGatherFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { sycl::group_barrier(item.get_group()); IndexT numIterations = - ((sliceSize + block_size - 1) / block_size) * block_size; + at::ceil_div(sliceSize, static_cast(block_size)) * block_size; for (IndexT i = lid; i < numIterations; i += block_size) { bool inRange = (i < sliceSize); @@ -242,6 +242,9 @@ struct SbtopkGatherFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { int sg_lid = sg.get_local_linear_id(); int sg_id = sg.get_group_linear_id(); constexpr int num_sgs = SBTOPK_BLOCK / SIMD; + static_assert( + num_sgs <= SMEM_FOUND_FLAG, + "num_sgs exceeds SMEM_FOUND_FLAG; SLM layout collision"); int sg_inclusive = sycl::inclusive_scan_over_group(sg, local_count, sycl::plus()); From ff99c422724e01d335430696b9296466d6767458 Mon Sep 17 00:00:00 2001 From: "Zhang, Jianyi" Date: Thu, 14 May 2026 15:54:49 +0800 Subject: [PATCH 10/10] Use nested namespace, remove redundant overflow check --- .../native/xpu/sycl/TensorTopKSingleWgKernel.cpp | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp index 6c9084c52f..70b5895204 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -41,9 +41,7 @@ #include #include -namespace at { -namespace native { -namespace xpu { +namespace at::native::xpu { // Uses RADIX_BITS=4 (16 digits per pass), halving radix passes for fp32. // Cannot reuse RADIX_BITS/SIZE/MASK from SortingRadixSelect.h (constexpr int, @@ -800,9 +798,7 @@ bool single_wg_topk_try_launch( if (nsegments <= static_cast(std::numeric_limits::max()) && nelements <= - static_cast(std::numeric_limits::max()) && - nsegments <= static_cast(std::numeric_limits::max()) / - (nelements > 0 ? nelements : 1)) { + static_cast(std::numeric_limits::max())) { single_wg_launch_kernel( input, topK, @@ -826,6 +822,4 @@ bool single_wg_topk_try_launch( return true; } -} // namespace xpu -} // namespace native -} // namespace at +} // namespace at::native::xpu