diff --git a/src/ATen/native/xpu/sycl/SortingRadixSelect.h b/src/ATen/native/xpu/sycl/SortingRadixSelect.h index 52d18a43fc..f8569c1bdf 100644 --- a/src/ATen/native/xpu/sycl/SortingRadixSelect.h +++ b/src/ATen/native/xpu/sycl/SortingRadixSelect.h @@ -87,7 +87,7 @@ struct TopKTypeConfig { static inline RadixType convert(float v) { RadixType x = *((uint32_t*)&v); RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffffffff; } static inline float deconvert(RadixType v) { @@ -168,7 +168,7 @@ struct TopKTypeConfig { static inline RadixType convert(double v) { RadixType x = *((uint64_t*)&v); RadixType mask = -((x >> 63)) | 0x8000000000000000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; } static inline double deconvert(RadixType v) { @@ -183,12 +183,12 @@ struct TopKTypeConfig { static inline RadixType convert(at::Half v) { RadixType x = *((uint16_t*)&v); - RadixType mask = -((x >> 15)) | 0x8000; - return (x ^ mask); + RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; + return (v == v) ? (x ^ mask) : 0xffff; } static inline at::Half deconvert(RadixType v) { - RadixType mask = ((v >> 15) - 1) | 0x8000; + RadixType mask = (v & 0x00008000) ? 
0x00008000 : 0x0000ffff; return __ushort_as_half(v ^ mask); } }; diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp index 298cfbc410..bd833a41f5 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -68,10 +69,13 @@ static bool subgroup_topk_try_launch( #undef SBTOPK_LAUNCH // ================================================================ -// Dispatch: subgroup top-k vs original +// Dispatch: subgroup top-k vs single-workgroup top-k vs original // // - dim < 32: original (need at least SG_SIZE elements) -// - dim >= 32, large batch, k <= 16: subgroup top-k +// - dim >= 32, large batch, k <= 8: subgroup top-k +// - dim >= 4096, any bs: single-workgroup top-k +// - dim >= 4096, k > 8: single-workgroup top-k +// (subgroup only supports k <= 8) // ================================================================ SbtopkResult sbtopk_try_launch( const at::Tensor& self, @@ -86,12 +90,12 @@ SbtopkResult sbtopk_try_launch( return SbtopkResult::FAILED; } - // Subgroup top-k: best for large batch, k<=8. + // Subgroup top-k: best for large batch, k <= 8. // Output is ALREADY SORTED (descending for largest, ascending for smallest). // // Threshold: nsegments >= thread_slots / 4. // Subgroup top-k uses 1 sub-group per slice (reading data once), while - // the original kernel reads data multiple times (~3 radix passes). So + // single-wg/original read data multiple times (~3 radix passes). So // subgroup top-k reaches memory-BW saturation at much lower occupancy. // thread_slots/4 is the conservative cutoff. // @@ -107,6 +111,19 @@ SbtopkResult sbtopk_try_launch( return SbtopkResult::FAILED; } + // Single-workgroup top-k: radix select, good for large + // dim. Output is UNSORTED. 
+ // Use nelements >= 4096 so dim=1024/1025 falls through to original + // (single-wg has regressions at dim=1024 for medium/large batches). + if (nelements >= 4096) { + if (single_wg_topk_try_launch( + self, nsegments, nelements, k, largest, values, indices)) { + return SbtopkResult::UNSORTED; + } + return SbtopkResult::FAILED; + } + + // Fallback to original for dim=32-4095 or k>8 with small batch return SbtopkResult::FAILED; } diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h index 91c1bfcf81..d224aa3bff 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h +++ b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h @@ -15,10 +15,10 @@ namespace at::native::xpu { // Result of sbtopk_try_launch. -// FAILED - did not run; caller should fall back to original kernel. -// UNSORTED - ran; output contains top-k values but is not sorted. +// FAILED - sbtopk did not run; caller should fall back to original. +// UNSORTED - sbtopk ran; output contains top-k values but is not sorted. // Caller must sort if sorted output is requested. -// SORTED - ran; output is already sorted (descending for largest, +// SORTED - sbtopk ran; output is already sorted (descending for largest, // ascending for smallest). Caller can skip sort. enum class SbtopkResult : int { FAILED = 0, @@ -26,6 +26,11 @@ enum class SbtopkResult : int { SORTED = 2, }; +// Try to run topk using an optimized kernel path. +// +// Dispatches between the subgroup topk kernel (sub-group bitonic merge, +// output SORTED) and the single workgroup topk kernel (radix select, +// output UNSORTED) based on (nsegments, nelements, k). 
SbtopkResult sbtopk_try_launch( const at::Tensor& self, int64_t nsegments, diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp new file mode 100644 index 0000000000..70b5895204 --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.cpp @@ -0,0 +1,825 @@ +/* + * Single-workgroup top-k kernel — SYCL translation of CUDA single-block path. + * + * One workgroup (1024 threads) handles one slice. Algorithm: + * 1. radixSelect: iterates MSB→LSB in steps of SBTOPK_RADIX_BITS=4 bits to + * identify the k-th largest (or smallest) value. + * 2. gatherTopK: two-pass gather — first collects values strictly + * better than the k-th value, then fills remaining slots with + * values equal to the k-th value. + * + * Output is UNSORTED. Caller applies segmented sort if sorted output needed. + * + * CUDA sources this kernel is adapted from (counting/scan are reworked): + * - SortingRadixSelect.cuh: countRadixUsingMask, findPattern + * - ScanUtils.cuh: inclusiveBinaryPrefixScan, exclusiveBinaryPrefixScan + * - TensorTopK.cu: gatherTopK, radixSelect + * + * Key CUDA -> SYCL mappings (ballot rows: reference only, unused here): + * WARP_BALLOT(pred) -> sycl::ext::oneapi::group_ballot(sg, pred) + * __popc(ballot) -> ballot.count() + * getLaneMaskLe() & ballot -> extract_bits + manual le_mask + popcount + * getLaneId() -> sg.get_local_linear_id() + * atomicAdd (smem) -> sycl::atomic_ref (local_space) + * __syncthreads() -> sycl::group_barrier(item.get_group()) + * doLdg(ptr) -> direct load (no read-only cache hint in SYCL) + * Bitfield::getBitfield -> software shift+mask (no PTX BFE/BFI in SYCL) + * smem[] -> IndexT* from local accessor + * + * withinSliceStride = 1 (input is .contiguous() before calling this kernel). + * Uses SBTOPK_RADIX_BITS=4, SBTOPK_RADIX_SIZE=16, SBTOPK_RADIX_MASK=15 + * (halving radix passes vs the original RADIX_BITS=2). 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native::xpu { + +// Uses RADIX_BITS=4 (16 digits per pass), halving radix passes for fp32. +// Cannot reuse RADIX_BITS/SIZE/MASK from SortingRadixSelect.h (constexpr int, +// can't #undef). +constexpr int SBTOPK_RADIX_BITS = 4; +constexpr int SBTOPK_RADIX_SIZE = 16; // 2 ^ SBTOPK_RADIX_BITS +constexpr int SBTOPK_RADIX_MASK = (SBTOPK_RADIX_SIZE - 1); + +// Block size = 1024 threads, matching CUDA C10_LAUNCH_BOUNDS_1(1024) +constexpr int SBTOPK_BLOCK = 1024; + +// SLM layout: +// [0..63] : used by countRadixUsingMask (smem[0..SBTOPK_RADIX_SIZE-1] for +// counts) +// and by exclusiveIntPrefixScan (smem[0..num_sgs-1] for carries) +// [64..65] : used by findPattern (flag + found index) +// Total: 68 elements (IndexT-sized) +constexpr int SMEM_FOUND_FLAG = 64; +constexpr int SMEM_FOUND_IDX = 65; +constexpr int SMEM_ELEMS = 68; + +template < + typename scalar_t, + int VEC_SIZE = 4, + int ELEMS_PER_THREAD = 32, + int SIMD = 32, + typename IndexT = int> +struct SbtopkGatherFunctor : public __SYCL_KER_CONFIG_CONVENTION__ { + using RadixT = typename TopKTypeConfig::RadixType; + // CUDA uses sizeof(scalar_t)*8, NOT sizeof(RadixType)*8. + // For fp16: sizeof(Half)=2 -> 16 bits, but sizeof(uint32_t)=4 -> 32 bits. + // Using RadixT would scan garbage upper bits and break Half/BFloat16. + static constexpr int NUM_BITS = sizeof(scalar_t) * 8; + + // ================================================================ + // countRadixUsingMask — per-thread counting + sub-group/work-group reduce + // + // Replaces ballot-based counting. Each thread: + // 1. Loads VEC_SIZE elements per iteration (vectorized) + // 2. Locally increments counts[digit] (pure ALU, no cross-lane) + // 3. After loop: sub-group reduce + lane0 atomicAdd to smem + broadcast + // + // Eliminates all group_ballot calls in counting. + // Result: all threads have identical counts[0..RADIX_SIZE-1]. 
+ // ================================================================ + C10_NOINLINE void countRadixUsingMask( + sycl::nd_item<1> item, + sycl::sub_group sg, + IndexT* smem, + IndexT counts[SBTOPK_RADIX_SIZE], + RadixT desired, + RadixT desiredMask, + int digitPos, + const scalar_t* data, + IndexT sliceSize) const { + int lid = item.get_local_id(0); + int block_size = item.get_local_range(0); + int sg_lid = sg.get_local_linear_id(); + +#pragma unroll + for (int i = 0; i < SBTOPK_RADIX_SIZE; ++i) { + counts[i] = 0; + } + if (lid < SBTOPK_RADIX_SIZE) { + smem[lid] = 0; + } + sycl::group_barrier(item.get_group()); + + // Each thread processes VEC_SIZE consecutive elements per iteration. + // Stride = block_size * VEC_SIZE for coalesced access across threads. + using LoadT = memory::aligned_vector; + IndexT stride = static_cast(block_size) * VEC_SIZE; + + // Vectorized main loop — full VEC_SIZE loads + IndexT base = static_cast(lid) * VEC_SIZE; + for (; base + VEC_SIZE <= sliceSize; base += stride) { + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&data[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT val = TopKTypeConfig::convert(src[v]); + if ((val & desiredMask) == desired) { + RadixT digit = + Bitfield::getBitfield(val, digitPos, SBTOPK_RADIX_BITS); + counts[digit]++; + } + } + } + // Scalar tail — remaining elements + for (IndexT idx = base; idx < sliceSize && idx < base + VEC_SIZE; ++idx) { + RadixT val = TopKTypeConfig::convert(data[idx]); + if ((val & desiredMask) == desired) { + RadixT digit = + Bitfield::getBitfield(val, digitPos, SBTOPK_RADIX_BITS); + counts[digit]++; + } + } + + // Sub-group reduce + lane0 atomicAdd to smem. 
+#pragma unroll + for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { + IndexT total = + sycl::reduce_over_group(sg, counts[j], sycl::plus()); + if (sg_lid == 0) { + sycl::atomic_ref< + IndexT, + sycl::memory_order::relaxed, + sycl::memory_scope::work_group, + sycl::access::address_space::local_space> + ref(smem[j]); + ref.fetch_add(total); + } + } + sycl::group_barrier(item.get_group()); + + // All threads read workgroup-level totals +#pragma unroll + for (int j = 0; j < SBTOPK_RADIX_SIZE; ++j) { + counts[j] = smem[j]; + } + } + + // ================================================================ + // findPattern (SortingRadixSelect.cuh:239) + // + // Finds the unique value whose convert() matches desired. + // Returns RadixT (converted form) directly — no deconvert needed. + // SYCL uses smem[SMEM_FOUND_FLAG]=flag, smem[SMEM_FOUND_IDX]=index, + // then convert(data[index]). + // ================================================================ + C10_NOINLINE RadixT findPattern( + sycl::nd_item<1> item, + IndexT* smem, + const scalar_t* data, + IndexT sliceSize, + RadixT desired, + RadixT desiredMask) const { + int lid = item.get_local_id(0); + int block_size = item.get_local_range(0); + + if (lid == 0) { + smem[SMEM_FOUND_FLAG] = 0; + smem[SMEM_FOUND_IDX] = static_cast(-1); + } + // Barrier required: init must be visible before any thread enters the loop + sycl::group_barrier(item.get_group()); + + IndexT numIterations = + at::ceil_div(sliceSize, static_cast(block_size)) * block_size; + + for (IndexT i = lid; i < numIterations; i += block_size) { + bool inRange = (i < sliceSize); + scalar_t v = inRange ? 
data[i] : static_cast(0); + + if (inRange && + ((TopKTypeConfig::convert(v) & desiredMask) == desired)) { + smem[SMEM_FOUND_FLAG] = 1; + smem[SMEM_FOUND_IDX] = i; + } + sycl::group_barrier(item.get_group()); + + IndexT found = smem[SMEM_FOUND_FLAG]; + IndexT foundIdx = smem[SMEM_FOUND_IDX]; + + if (found != 0) { + return TopKTypeConfig::convert(data[foundIdx]); + } + + // WAR barrier: protect smem writes in next iteration from current reads + sycl::group_barrier(item.get_group()); + } + return static_cast(0); + } + + // ================================================================ + // exclusiveIntPrefixScan — integer exclusive prefix scan + // + // Each thread provides an integer count (0..ELEMS_PER_THREAD). Returns: + // out: exclusive prefix sum (write offset for this thread) + // carry: total sum across all threads in the work-group + // + // Values are bounded by SBTOPK_BLOCK * ELEMS_PER_THREAD (always fits int), + // but smem is IndexT* so we cast at the interface. + // + // Sub-group level: inclusive_scan_over_group + group_broadcast + // Cross sub-group: smem serial scan (same pattern as binary version) + // ================================================================ + C10_NOINLINE void exclusiveIntPrefixScan( + sycl::nd_item<1> item, + sycl::sub_group sg, + IndexT* smem, + int local_count, + int& out, + int& carry) const { + int sg_lid = sg.get_local_linear_id(); + int sg_id = sg.get_group_linear_id(); + constexpr int num_sgs = SBTOPK_BLOCK / SIMD; + static_assert( + num_sgs <= SMEM_FOUND_FLAG, + "num_sgs exceeds SMEM_FOUND_FLAG; SLM layout collision"); + + int sg_inclusive = + sycl::inclusive_scan_over_group(sg, local_count, sycl::plus()); + int sg_exclusive = sg_inclusive - local_count; + + // group_broadcast to get sub-group total (last lane's inclusive value) + int sg_total = sycl::group_broadcast(sg, sg_inclusive, SIMD - 1); + if (sg_lid == SIMD - 1) { + smem[sg_id] = static_cast(sg_total); + } + sycl::group_barrier(item.get_group()); + + // 
Thread 0: serial inclusive prefix sum over sub-group totals + if (item.get_local_id(0) == 0) { + IndexT current = 0; + for (int i = 0; i < num_sgs; ++i) { + IndexT v = smem[i]; + smem[i] = v + current; + current += v; + } + } + sycl::group_barrier(item.get_group()); + + int cross_sg_prefix = (sg_id >= 1) ? static_cast(smem[sg_id - 1]) : 0; + out = sg_exclusive + cross_sg_prefix; + carry = static_cast(smem[num_sgs - 1]); + } + + // ================================================================ + // radixSelect (SortingRadixSelect.cuh:860, non-ROCm path) + // + // Iterates MSB to LSB in RADIX_BITS steps. + // At each step: count digits, scan to find which digit contains k-th. + // found_unique (count==1, kToFind==1): findPattern + return RadixT + // found_non_unique (count>=kToFind): narrow desired/desiredMask, continue + // End: return desired (RadixT, fully determined) + // ================================================================ + C10_NOINLINE RadixT radixSelect( + sycl::nd_item<1> item, + sycl::sub_group sg, + IndexT* smem, + const scalar_t* data, + int k, + bool largest, + IndexT sliceSize) const { + IndexT counts[SBTOPK_RADIX_SIZE]; + RadixT desired = 0; + RadixT desiredMask = 0; + int kToFind = k; + + for (int digitPos = NUM_BITS - SBTOPK_RADIX_BITS; digitPos >= 0; + digitPos -= SBTOPK_RADIX_BITS) { + countRadixUsingMask( + item, + sg, + smem, + counts, + desired, + desiredMask, + digitPos, + data, + sliceSize); + + // All threads execute the same scan logic (counts are identical). + // Replicates CUDA found_unique / found_non_unique lambdas exactly. 
+ if (largest) { + for (int i = SBTOPK_RADIX_SIZE - 1; i >= 0; --i) { + IndexT count = counts[i]; + + // found_unique: return from radixSelect + if (count == 1 && kToFind == 1) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + return findPattern( + item, smem, data, sliceSize, desired, desiredMask); + } + + // found_non_unique: break inner loop, continue outer + if (count >= kToFind) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + break; + } + + // count < kToFind here, so count fits in int + kToFind -= static_cast(count); + } + } else { + for (int i = 0; i < SBTOPK_RADIX_SIZE; ++i) { + IndexT count = counts[i]; + + if (count == 1 && kToFind == 1) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + return findPattern( + item, smem, data, sliceSize, desired, desiredMask); + } + + if (count >= kToFind) { + desired = Bitfield::setBitfield( + desired, i, digitPos, SBTOPK_RADIX_BITS); + desiredMask = Bitfield::setBitfield( + desiredMask, SBTOPK_RADIX_MASK, digitPos, SBTOPK_RADIX_BITS); + break; + } + + kToFind -= static_cast(count); + } + } + } + + // No unique result; desired fully determined + return desired; + } + + // ================================================================ + // operator() — gatherTopK (TensorTopK.cu:40-182) + // + // 1. radixSelect to find k-th value + // 2. Gather values strictly > topK (largest) or < topK (!largest) + // 3. Fill remaining with values == topK + // + // Each thread processes ELEMS_PER_THREAD elements per iteration + // (LOADS_PER_ITER × vec loads), then ONE prefix scan per iteration. 
+ // With ELEMS_PER_THREAD=32 and 1024 threads, each iteration covers + // 32K elements, so dim=131072 needs only 4 iterations. + // ================================================================ + void operator()(sycl::nd_item<1> item) const { + IndexT slice = static_cast(item.get_group_linear_id()); + if (slice >= numSlices_) + return; + + sycl::sub_group sg = item.get_sub_group(); + + // Get raw IndexT* pointer from local accessor + IndexT* smem = + local_mem_.template get_multi_ptr().get(); + + const scalar_t* inputSlice = + inputData_ + static_cast(slice) * sliceSize_; + scalar_t* topKSlice = topKData_ + static_cast(slice) * k_; + int64_t* indicesSlice = indicesData_ + static_cast(slice) * k_; + + // Step 1: radixSelect — returns RadixT directly (no deconvert/convert + // round-trip) + RadixT topKConverted = + radixSelect(item, sg, smem, inputSlice, k_, largest_, sliceSize_); + + // Vectorized gather setup + // ELEMS_PER_THREAD: each thread processes this many elements per iteration. + // Multiple vec loads per iteration, then ONE prefix scan. + // With ELEMS_PER_THREAD=32: 4 iterations for dim=131072. + static constexpr int LOADS_PER_ITER = ELEMS_PER_THREAD / VEC_SIZE; + using LoadT = memory::aligned_vector; + int lid = item.get_local_id(0); + + // Each iteration covers SBTOPK_BLOCK * ELEMS_PER_THREAD elements. + IndexT iter_stride = static_cast(SBTOPK_BLOCK) * ELEMS_PER_THREAD; + int numIters = + static_cast((sliceSize_ + iter_stride - 1) / iter_stride); + + // Step 2: Gather values strictly greater/less than topKValue + int writeIndexStart = 0; + + for (int iter = 0; iter < numIters; ++iter) { + // Each thread loads ELEMS_PER_THREAD elements from LOADS_PER_ITER vec4 + // chunks. Thread layout: consecutive threads handle consecutive VEC_SIZE + // chunks. Thread t handles chunks at offsets: t*VEC_SIZE, + // (t+SBTOPK_BLOCK)*VEC_SIZE, ... 
+ scalar_t vals[ELEMS_PER_THREAD]; + IndexT match_indices[ELEMS_PER_THREAD]; + int local_count = 0; + + IndexT iter_base = static_cast(iter) * iter_stride; + +#pragma unroll + for (int L = 0; L < LOADS_PER_ITER; ++L) { + IndexT base = iter_base + + static_cast(L) * SBTOPK_BLOCK * VEC_SIZE + + static_cast(lid) * VEC_SIZE; + + if (base + VEC_SIZE <= sliceSize_) { + alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&inputSlice[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT cv = TopKTypeConfig::convert(src[v]); + bool match = largest_ ? (cv > topKConverted) : (cv < topKConverted); + if (match) { + vals[local_count] = src[v]; + match_indices[local_count] = base + v; + local_count++; + } + } + } else if (base < sliceSize_) { + for (int v = 0; v < VEC_SIZE && base + v < sliceSize_; ++v) { + scalar_t sv = inputSlice[base + v]; + RadixT cv = TopKTypeConfig::convert(sv); + bool match = largest_ ? (cv > topKConverted) : (cv < topKConverted); + if (match) { + vals[local_count] = sv; + match_indices[local_count] = base + v; + local_count++; + } + } + } + } + + int offset, carry; + exclusiveIntPrefixScan(item, sg, smem, local_count, offset, carry); + + for (int j = 0; j < local_count; ++j) { + int writeIndex = writeIndexStart + offset + j; + if (writeIndex < k_) { + topKSlice[writeIndex] = vals[j]; + indicesSlice[writeIndex] = match_indices[j]; + } + } + writeIndexStart += carry; + } + + // Step 3: Fill remaining with values == topKValue + int topKRemaining = k_ - writeIndexStart; + + for (int iter = 0; iter < numIters; ++iter) { + scalar_t vals[ELEMS_PER_THREAD]; + IndexT match_indices[ELEMS_PER_THREAD]; + int local_count = 0; + + IndexT iter_base = static_cast(iter) * iter_stride; + +#pragma unroll + for (int L = 0; L < LOADS_PER_ITER; ++L) { + IndexT base = iter_base + + static_cast(L) * SBTOPK_BLOCK * VEC_SIZE + + static_cast(lid) * VEC_SIZE; + + if (base + VEC_SIZE <= sliceSize_) { + 
alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; + *reinterpret_cast(&src) = + *reinterpret_cast(&inputSlice[base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + RadixT cv = TopKTypeConfig::convert(src[v]); + if (cv == topKConverted) { + vals[local_count] = src[v]; + match_indices[local_count] = base + v; + local_count++; + } + } + } else if (base < sliceSize_) { + for (int v = 0; v < VEC_SIZE && base + v < sliceSize_; ++v) { + scalar_t sv = inputSlice[base + v]; + RadixT cv = TopKTypeConfig::convert(sv); + if (cv == topKConverted) { + vals[local_count] = sv; + match_indices[local_count] = base + v; + local_count++; + } + } + } + } + + int offset, carry; + exclusiveIntPrefixScan(item, sg, smem, local_count, offset, carry); + + for (int j = 0; j < local_count; ++j) { + if (offset + j < topKRemaining) { + int writeIndex = writeIndexStart + offset + j; + topKSlice[writeIndex] = vals[j]; + indicesSlice[writeIndex] = match_indices[j]; + } + } + + if (carry >= topKRemaining) { + break; + } + topKRemaining -= carry; + writeIndexStart += carry; + } + } + + SbtopkGatherFunctor( + const scalar_t* inputData, + scalar_t* topKData, + int64_t* indicesData, + IndexT numSlices, + IndexT sliceSize, + int k, + bool largest) + : inputData_(inputData), + topKData_(topKData), + indicesData_(indicesData), + numSlices_(numSlices), + sliceSize_(sliceSize), + k_(k), + largest_(largest) {} + + void sycl_ker_config_convention(sycl::handler& cgh) { + local_mem_ = sycl::local_accessor(SMEM_ELEMS, cgh); + } + + const scalar_t* inputData_; + scalar_t* topKData_; + int64_t* indicesData_; + IndexT numSlices_; + IndexT sliceSize_; + int k_; + bool largest_; + sycl::local_accessor local_mem_; +}; + +// ================================================================ +// Launch function +// ================================================================ +template < + typename scalar_t, + int VEC_SIZE, + int ELEMS_PER_THREAD, + typename IndexT> +static void single_wg_launch_impl( + const 
scalar_t* input, + scalar_t* topK, + int64_t* indices, + IndexT numSlices, + IndexT sliceSize, + int k, + bool largest) { + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + constexpr int SIMD = 32; + using Functor = + SbtopkGatherFunctor; + + Functor functor(input, topK, indices, numSlices, sliceSize, k, largest); + + sycl_kernel_submit( + static_cast(numSlices) * SBTOPK_BLOCK, + static_cast(SBTOPK_BLOCK), + at::xpu::getCurrentSYCLQueue(), + syclex::properties{syclex::sub_group_size, intelex::grf_size<128>}, + functor); +} + +// Dispatch macro to reduce boilerplate +#define SINGLE_WG_LAUNCH(V, E) \ + single_wg_launch_impl( \ + input, topK, indices, numSlices, sliceSize, k, largest) + +template +static void single_wg_launch_kernel( + const scalar_t* input, + scalar_t* topK, + int64_t* indices, + IndexT numSlices, + IndexT sliceSize, + int k, + bool largest) { + // ELEMS_PER_THREAD: largest power-of-2 such that + // ept * SBTOPK_BLOCK <= sliceSize, capped at 32. + // With ept=32 and 1024 threads, each iteration covers 32K elements, + // keeping the number of gatherTopK iterations small (~4 for dim=131072). + int64_t ratio = sliceSize / SBTOPK_BLOCK; + int ept = ratio >= 1 ? std::min( + 32, + static_cast(c10::llvm::PowerOf2Floor( + static_cast(ratio)))) + : 1; + + // Determine VEC_SIZE: largest power-of-2 dividing sliceSize, + // capped by type max AND by EPT (vec <= ept required). + // Also check input pointer alignment — a non-zero storage offset + // can break alignment even for contiguous tensors. + constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; + int cap = MAX_VEC < ept ? 
MAX_VEC : ept; + auto input_align = reinterpret_cast(input); + auto aligned = [&](int v) { + return input_align % (sizeof(scalar_t) * v) == 0; + }; + int vec = 1; + if (cap >= 8 && sliceSize % 8 == 0 && aligned(8)) + vec = 8; + else if (cap >= 4 && sliceSize % 4 == 0 && aligned(4)) + vec = 4; + else if (cap >= 2 && sliceSize % 2 == 0 && aligned(2)) + vec = 2; + + // Dispatch: VEC determines which EPT values are valid (EPT >= VEC, EPT % VEC + // == 0) + if constexpr (MAX_VEC == 8) { + // 16-bit types: VEC can be 8, 4, 2, 1 + if (vec == 8) { + switch (ept) { + case 8: + SINGLE_WG_LAUNCH(8, 8); + return; + case 16: + SINGLE_WG_LAUNCH(8, 16); + return; + default: + SINGLE_WG_LAUNCH(8, 32); + return; + } + } else if (vec == 4) { + switch (ept) { + case 4: + SINGLE_WG_LAUNCH(4, 4); + return; + case 8: + SINGLE_WG_LAUNCH(4, 8); + return; + case 16: + SINGLE_WG_LAUNCH(4, 16); + return; + default: + SINGLE_WG_LAUNCH(4, 32); + return; + } + } else if (vec == 2) { + switch (ept) { + case 2: + SINGLE_WG_LAUNCH(2, 2); + return; + case 4: + SINGLE_WG_LAUNCH(2, 4); + return; + case 8: + SINGLE_WG_LAUNCH(2, 8); + return; + case 16: + SINGLE_WG_LAUNCH(2, 16); + return; + default: + SINGLE_WG_LAUNCH(2, 32); + return; + } + } else { + switch (ept) { + case 1: + SINGLE_WG_LAUNCH(1, 1); + return; + case 2: + SINGLE_WG_LAUNCH(1, 2); + return; + case 4: + SINGLE_WG_LAUNCH(1, 4); + return; + case 8: + SINGLE_WG_LAUNCH(1, 8); + return; + case 16: + SINGLE_WG_LAUNCH(1, 16); + return; + default: + SINGLE_WG_LAUNCH(1, 32); + return; + } + } + } else { + // 32-bit types: VEC can be 4, 2, 1 + if (vec >= 4) { + switch (ept) { + case 4: + SINGLE_WG_LAUNCH(4, 4); + return; + case 8: + SINGLE_WG_LAUNCH(4, 8); + return; + case 16: + SINGLE_WG_LAUNCH(4, 16); + return; + default: + SINGLE_WG_LAUNCH(4, 32); + return; + } + } else if (vec == 2) { + switch (ept) { + case 2: + SINGLE_WG_LAUNCH(2, 2); + return; + case 4: + SINGLE_WG_LAUNCH(2, 4); + return; + case 8: + SINGLE_WG_LAUNCH(2, 8); + return; + 
case 16: + SINGLE_WG_LAUNCH(2, 16); + return; + default: + SINGLE_WG_LAUNCH(2, 32); + return; + } + } else { + switch (ept) { + case 1: + SINGLE_WG_LAUNCH(1, 1); + return; + case 2: + SINGLE_WG_LAUNCH(1, 2); + return; + case 4: + SINGLE_WG_LAUNCH(1, 4); + return; + case 8: + SINGLE_WG_LAUNCH(1, 8); + return; + case 16: + SINGLE_WG_LAUNCH(1, 16); + return; + default: + SINGLE_WG_LAUNCH(1, 32); + return; + } + } + } +} + +#undef SINGLE_WG_LAUNCH + +bool single_wg_topk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices) { + // Only handle cases where single-workgroup topk is beneficial: + // large dim, small k, contiguous last-dim + if (k > 256) { + return false; + } + + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + self.scalar_type(), + "single_wg_topk_xpu", + [&]() { + const auto* input = static_cast(self.const_data_ptr()); + auto* topK = static_cast(values.data_ptr()); + auto* idx = static_cast(indices.data_ptr()); + + if (nsegments <= + static_cast(std::numeric_limits::max()) && + nelements <= + static_cast(std::numeric_limits::max())) { + single_wg_launch_kernel( + input, + topK, + idx, + static_cast(nsegments), + static_cast(nelements), + static_cast(k), + largest); + } else { + single_wg_launch_kernel( + input, + topK, + idx, + nsegments, + nelements, + static_cast(k), + largest); + } + }); + + return true; +} + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h new file mode 100644 index 0000000000..a4c260c4b0 --- /dev/null +++ b/src/ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace at { +namespace native { +namespace xpu { + +// Single-workgroup top-k kernel (translated from CUDA single-block path). 
+// One workgroup per slice: radix select to find the k-th value, then gather. +// Good for large dim (≥4096), any batch size. Output is UNSORTED. +bool single_wg_topk_try_launch( + const at::Tensor& self, + int64_t nsegments, + int64_t nelements, + int64_t k, + bool largest, + const at::Tensor& values, + const at::Tensor& indices); + +} // namespace xpu +} // namespace native +} // namespace at