From abfadebdb60bbea596f4c3a7196b3aa7234e3fcd Mon Sep 17 00:00:00 2001 From: "Cui, Yifeng" Date: Tue, 19 May 2026 05:44:20 +0000 Subject: [PATCH] Revert "Add subgroup topk kernel for XPU (part1 of #3369) (#3371)" This reverts commit 8eaa591f664d7c934ae98bbf316ddf6e27990fd4. --- src/ATen/native/xpu/sycl/TensorTopKKernel.cpp | 31 +- .../xpu/sycl/TensorTopKSbtopkKernel.cpp | 486 ------------------ .../native/xpu/sycl/TensorTopKSbtopkKernel.h | 42 -- 3 files changed, 11 insertions(+), 548 deletions(-) delete mode 100644 src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp delete mode 100644 src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h diff --git a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp index b2ed7949c9..3042f685fd 100644 --- a/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp +++ b/src/ATen/native/xpu/sycl/TensorTopKKernel.cpp @@ -16,7 +16,6 @@ #include #include -#include namespace at { namespace native { @@ -114,25 +113,17 @@ void topk_kernel( const scalar_t* self_ptr = self_.const_data_ptr(); scalar_t* values_ptr = values_.data_ptr(); int64_t* indices_ptr = indices_.data_ptr(); - - SbtopkResult sbtopk_result = sbtopk_try_launch( - self_, nsegments, nelements, k, largest, values_, indices_); - if (sbtopk_result == SbtopkResult::FAILED) { - segmented_group_select_pairs( - self_ptr, - values_ptr, - nullptr, - (int64_t*)indices_ptr, - nsegments, - nelements, - k, - largest); - } - - // Only sort if the user asked for sorted output AND the optimized - // kernel did not already produce a sorted result. The subgroup topk - // kernel returns SORTED; the original kernel returns FAILED. - if (sorted && sbtopk_result != SbtopkResult::SORTED) { + segmented_group_select_pairs( + self_ptr, + values_ptr, + nullptr, + (int64_t*)indices_ptr, + nsegments, + nelements, + k, + largest); + + if (sorted) { segmented_sort_pairs( values_ptr, values_ptr, diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp deleted file mode 100644 index 9dc97ea4f2..0000000000 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp +++ /dev/null @@ -1,486 +0,0 @@ -/* - * Subgroup top-k kernel for optimized topk on XPU. - * - * Each sub-group (32 lanes) handles one slice entirely in registers. - * Zero SLM, zero barriers. - * - * Algorithm: - * Phase 1: Each lane scans nelements/32 elements, maintains a sorted top-k - * buffer via insertion sort (fully unrolled, no branches on - * direction thanks to compile-time Largest template param). - * Phase 2: 5 levels of pairwise bitonic merge via sub-group shuffles - * to combine 32 per-lane buffers into one global top-k. - * Phase 3: Lane 0 writes k results. Output is already sorted. - * - * Dispatch: k <= 16 and enough segments (large batch) and nelements >= 32 - * routes to subgroup top-k; otherwise falls back to original. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at { -namespace native { -namespace xpu { - -namespace syclex = sycl::ext::oneapi::experimental; -namespace intelex = sycl::ext::intel::experimental; - -static constexpr int SG_SIZE = 32; -// Number of pairwise merge levels in Phase 2 = log2(SG_SIZE). -static constexpr int SG_MERGE_LEVELS = 5; -static_assert( - (1 << SG_MERGE_LEVELS) == SG_SIZE, - "SG_MERGE_LEVELS must equal log2(SG_SIZE)"); - -// ================================================================ -// SubgroupTopKFunctor -// -// K: compile-time max top-k (must be >= runtime k) -// VEC_SIZE: vectorized load width -// Largest: compile-time direction flag. Eliminates per-element branches -// on largest_ that otherwise pessimize the tight insert/merge loops. -// IndexT: int32 when numSlices <= INT_MAX (common case), int64_t -// otherwise. int32 avoids 64-bit arithmetic on slice indices, -// reducing register pressure. -// ================================================================ -template < - typename scalar_t, - int K, - int VEC_SIZE, - bool Largest, - typename IndexT = int> -struct SubgroupTopKFunctor { - // Insert val into a K-sorted buffer. For Largest=true the buffer is sorted - // descending (top_vals[0] is max); for Largest=false it is sorted ascending - // (top_vals[0] is min). The comparator `better(a, b)` means "a should sit - // above b in the buffer" — i.e. strictly greater for largest, strictly less - // for smallest. - // - // Fully unrolled, no early break — SIMD-friendly. - inline void insert( - scalar_t* top_vals, - IndexT* top_idx, - int count, - scalar_t val, - IndexT idx) const { - // Threshold is at the bottom of the buffer (top_vals[K-1]). - if constexpr (Largest) { - if (count >= K && !(val > top_vals[K - 1])) - return; - } else { - if (count >= K && !(val < top_vals[K - 1])) - return; - } - bool inserted = false; -#pragma unroll - for (int i = K - 1; i >= 0; --i) { - // When count < K the buffer is partially filled: positions [0, count) - // hold real values while [count, K) still contain sentinels. - // Guard "i <= count" ensures we only stop at position i when - // top_vals[i-1] is a real entry. Without this guard, an input - // value equal to the sentinel (e.g. all -inf for largest=true) - // would always stop at position K-1, overwriting it repeatedly - // instead of filling lower positions. - bool stop; - if constexpr (Largest) { - stop = (i == 0) || (i <= count && !(val > top_vals[i - 1])); - } else { - stop = (i == 0) || (i <= count && !(val < top_vals[i - 1])); - } - if (!inserted && stop) { - top_vals[i] = val; - top_idx[i] = idx; - inserted = true; - } else if (!inserted) { - top_vals[i] = top_vals[i - 1]; - top_idx[i] = top_idx[i - 1]; - } - } - } - - // Bitonic merge: A[K] and B[K] are both sorted in the "better" direction. - // Step 1: A[i] = better(A[i], B[K-1-i]) — produces bitonic sequence. - // Step 2: bitonic sort restores the sorted-by-better order on A. - inline void bitonic_merge( - scalar_t* A, - IndexT* A_idx, - const scalar_t* B, - const IndexT* B_idx) const { - // Step 1: compare with reversed partner -#pragma unroll - for (int i = 0; i < K; ++i) { - scalar_t bv = B[K - 1 - i]; - IndexT bi = B_idx[K - 1 - i]; - bool take; - if constexpr (Largest) { - take = bv > A[i]; - } else { - take = bv < A[i]; - } - if (take) { - A[i] = bv; - A_idx[i] = bi; - } - } - // Step 2: bitonic sort — standard bitonic merge network. - // - // After step 1, A[0..K-1] is bitonic (first decreasing then increasing, - // or vice versa). At stride = K/2 we compare A[i] with A[i + K/2] - // for i in [0, K/2) and swap so the "better" value goes to the low - // half. This guarantees: - // (a) every element in A[0..K/2-1] >= every element in A[K/2..K-1], - // (b) each half is itself bitonic (splitting a bitonic sequence at - // the midpoint with min/max produces two bitonic subsequences). - // Recurse with stride K/4, K/8, ..., 1 and each sub-piece halves - // again, until every piece has length 1 — the array is sorted. - // - // j = i ^ stride pairs each element with its partner at distance - // `stride`. The guard j > i ensures each pair is processed once. - // - // Example for K = 16: - // stride 8: (0,8) (1,9) (2,10) ... (7,15) — 8 pairs - // stride 4: (0,4) (1,5) (2,6) (3,7) — two groups of 4 - // (8,12) (9,13) (10,14) (11,15) - // stride 2: (0,2) (1,3) (4,6) (5,7) ... — four groups of 2 - // stride 1: (0,1) (2,3) (4,5) ... (14,15) — 8 adjacent pairs -#pragma unroll - for (int stride = K / 2; stride >= 1; stride >>= 1) { -#pragma unroll - for (int i = 0; i < K; ++i) { - int j = i ^ stride; - bool swap; - if constexpr (Largest) { - swap = (j > i) && (A[i] < A[j]); - } else { - swap = (j > i) && (A[i] > A[j]); - } - if (swap) { - scalar_t tv = A[i]; - A[i] = A[j]; - A[j] = tv; - IndexT ti = A_idx[i]; - A_idx[i] = A_idx[j]; - A_idx[j] = ti; - } - } - } - } - - void operator()(sycl::nd_item<1> item) const { - sycl::sub_group sg = item.get_sub_group(); - int sg_lid = sg.get_local_linear_id(); - - // Each sub-group handles one slice - int sgs_per_wg = item.get_local_range(0) / SG_SIZE; - IndexT slice = - static_cast(item.get_group_linear_id()) * sgs_per_wg + - sg.get_group_linear_id(); - if (slice >= numSlices_) - return; - - const scalar_t* inputSlice = - inputData_ + static_cast(slice) * sliceSize_; - scalar_t* topKSlice = topKData_ + static_cast(slice) * k_; - int64_t* indicesSlice = indicesData_ + static_cast(slice) * k_; - - // Initialize sorted top-K buffer - scalar_t top_vals[K]; - IndexT top_idx_local[K]; - scalar_t init_val; - if constexpr (Largest) { - if constexpr (std::numeric_limits::has_infinity) { - init_val = -std::numeric_limits::infinity(); - } else { - init_val = std::numeric_limits::lowest(); - } - } else { - if constexpr (std::numeric_limits::has_infinity) { - init_val = std::numeric_limits::infinity(); - } else { - init_val = std::numeric_limits::max(); - } - } -#pragma unroll - for (int i = 0; i < K; ++i) { - top_vals[i] = init_val; - top_idx_local[i] = -1; - } - int count = 0; - - // ---- Phase 1: scan data with vec loads ---- - using LoadT = memory::aligned_vector; - int stride = SG_SIZE * VEC_SIZE; - - int64_t base; - for (base = sg_lid * VEC_SIZE; base + VEC_SIZE <= sliceSize_; - base += stride) { - alignas(alignof(LoadT)) scalar_t src[VEC_SIZE]; - *reinterpret_cast(&src) = - *reinterpret_cast(&inputSlice[base]); -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - insert( - top_vals, - top_idx_local, - count, - src[v], - static_cast(base + v)); - if (count < K) - count++; - } - } - // Scalar tail - for (IndexT idx = static_cast(base); - idx < sliceSize_ && idx < base + VEC_SIZE; - ++idx) { - scalar_t val = inputSlice[idx]; - insert(top_vals, top_idx_local, count, val, idx); - if (count < K) - count++; - } - - // ---- Phase 2: sub-group bitonic merge ---- -#pragma unroll - for (int d = 0; d < SG_MERGE_LEVELS; ++d) { - int partner = sg_lid ^ (1 << d); - - scalar_t partner_vals[K]; - IndexT partner_idx[K]; -#pragma unroll - for (int i = 0; i < K; ++i) { - partner_vals[i] = sycl::select_from_group(sg, top_vals[i], partner); - partner_idx[i] = sycl::select_from_group(sg, top_idx_local[i], partner); - } - - bitonic_merge(top_vals, top_idx_local, partner_vals, partner_idx); - } - - // ---- Phase 3: lane 0 writes output ---- - if (sg_lid == 0) { - for (int i = 0; i < k_; ++i) { - topKSlice[i] = top_vals[i]; - indicesSlice[i] = static_cast(top_idx_local[i]); - } - } - } - - SubgroupTopKFunctor( - const scalar_t* inputData, - scalar_t* topKData, - int64_t* indicesData, - IndexT numSlices, - int64_t sliceSize, - int k) - : inputData_(inputData), - topKData_(topKData), - indicesData_(indicesData), - numSlices_(numSlices), - sliceSize_(sliceSize), - k_(k) {} - - const scalar_t* inputData_; - scalar_t* topKData_; - int64_t* indicesData_; - IndexT numSlices_; - int64_t sliceSize_; - int k_; -}; - -// ================================================================ -// Launch helpers -// ================================================================ -template -static void sbtopk_launch_impl( - const scalar_t* input, - scalar_t* topK, - int64_t* indices, - IndexT numSlices, - int64_t sliceSize, - int k) { - constexpr int WG_SIZE = 256; // 8 sub-groups per work-group - constexpr int SGS_PER_WG = WG_SIZE / SG_SIZE; - auto num_wgs = at::ceil_div( - static_cast(numSlices), static_cast(SGS_PER_WG)); - - SubgroupTopKFunctor functor( - input, topK, indices, numSlices, sliceSize, k); - - sycl_kernel_submit( - num_wgs * WG_SIZE, - WG_SIZE, - at::xpu::getCurrentSYCLQueue(), - syclex::properties{ - syclex::sub_group_size, intelex::grf_size<128>}, - functor); -} - -// Vec-size dispatch: picks the largest VEC_SIZE compatible with -// (dtype, sliceSize). -template -static void sbtopk_launch_vec_dispatch( - const scalar_t* input, - scalar_t* topK, - int64_t* indices, - IndexT numSlices, - int64_t sliceSize, - int k) { - // Max VEC_SIZE for this dtype - constexpr int MAX_VEC = sizeof(scalar_t) <= 2 ? 8 : 4; - - // Pick largest VEC_SIZE such that: - // 1. SG_SIZE * VEC_SIZE <= sliceSize (all threads get at - // least one full vector) - // 2. sliceSize % VEC_SIZE == 0 (slice boundaries are aligned) - // 3. input pointer is aligned to sizeof(scalar_t) * VEC_SIZE - // (usually guaranteed by PyTorch allocators, but a non-zero - // storage offset can break alignment) - auto input_align = reinterpret_cast(input); - auto can_use_vec = [&](int vec) { - return MAX_VEC >= vec && sliceSize % vec == 0 && - SG_SIZE * vec <= sliceSize && - input_align % (sizeof(scalar_t) * vec) == 0; - }; - if (can_use_vec(8)) { - sbtopk_launch_impl( - input, topK, indices, numSlices, sliceSize, k); - } else if (can_use_vec(4)) { - sbtopk_launch_impl( - input, topK, indices, numSlices, sliceSize, k); - } else if (can_use_vec(2)) { - sbtopk_launch_impl( - input, topK, indices, numSlices, sliceSize, k); - } else { - sbtopk_launch_impl( - input, topK, indices, numSlices, sliceSize, k); - } -} - -template -static void sbtopk_launch_kernel( - const scalar_t* input, - scalar_t* topK, - int64_t* indices, - int64_t numSlices, - int64_t sliceSize, - int k, - bool largest) { - // Dispatch on (K, Largest, IndexT). - // K: smallest power-of-two >= k. Smaller K means fewer unrolled - // iterations in insert/merge, less register pressure, zero spills. - // K=1 is a valid special case (top-1 = max/min element). - // Largest: compile-time direction eliminates per-element branches. - // IndexT: int32 when numSlices fits (common), int64 otherwise. - - // Select K: round up k to next power of two, clamped to [1, 16]. - int K_sel = std::min( - static_cast(c10::llvm::PowerOf2Ceil(static_cast(k))), 16); - -#define SBTOPK_DISPATCH_K(K_VAL, LARGEST, INDEX_T, NUM_SLICES) \ - sbtopk_launch_vec_dispatch( \ - input, topK, indices, NUM_SLICES, sliceSize, k) - -#define SBTOPK_DISPATCH_LARGEST(K_VAL, INDEX_T, NUM_SLICES) \ - if (largest) { \ - SBTOPK_DISPATCH_K(K_VAL, true, INDEX_T, NUM_SLICES); \ - } else { \ - SBTOPK_DISPATCH_K(K_VAL, false, INDEX_T, NUM_SLICES); \ - } - -#define SBTOPK_DISPATCH_INDEX(K_VAL) \ - if (numSlices <= std::numeric_limits::max()) { \ - int numSlices32 = static_cast(numSlices); \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int, numSlices32); \ - } else { \ - SBTOPK_DISPATCH_LARGEST(K_VAL, int64_t, numSlices); \ - } - - switch (K_sel) { - case 1: - SBTOPK_DISPATCH_INDEX(1); - break; - case 2: - SBTOPK_DISPATCH_INDEX(2); - break; - case 4: - SBTOPK_DISPATCH_INDEX(4); - break; - case 8: - SBTOPK_DISPATCH_INDEX(8); - break; - default: - SBTOPK_DISPATCH_INDEX(16); - break; - } - -#undef SBTOPK_DISPATCH_INDEX -#undef SBTOPK_DISPATCH_LARGEST -#undef SBTOPK_DISPATCH_K -} - -// ================================================================ -// Dispatch: subgroup top-k vs original -// -// - dim < 32: original (need at least SG_SIZE elements) -// - dim >= 32, large batch, k <= 16: subgroup top-k -// ================================================================ -SbtopkResult sbtopk_try_launch( - const at::Tensor& self, - int64_t nsegments, - int64_t nelements, - int64_t k, - bool largest, - const at::Tensor& values, - const at::Tensor& indices) { - // Subgroup kernel needs at least SG_SIZE (32) elements per slice - if (nelements < 32) { - return SbtopkResult::FAILED; - } - - // Subgroup top-k: best for large batch, k<=16. - // Output is ALREADY SORTED (descending for largest, ascending for smallest). - // - // Threshold: nsegments >= thread_slots / 4. - // Subgroup top-k uses 1 sub-group per slice (reading data once), while - // the original kernel reads data multiple times (~3 radix passes). So - // subgroup top-k reaches memory-BW saturation at much lower occupancy. - // thread_slots/4 is the conservative cutoff. - // - // On B580: thread_slots = 160 EU * 8 HW threads = 1280, threshold = 320. - int64_t thread_slots = - ::xpu::sycl::syclGpuEuCount() * ::xpu::sycl::syclGpuHWThreadsPerEU(); - int64_t sg_threshold = thread_slots / 4; - if (k <= 16 && nsegments >= sg_threshold) { - AT_DISPATCH_ALL_TYPES_AND2( - at::ScalarType::Half, - at::ScalarType::BFloat16, - self.scalar_type(), - "subgroup_topk_xpu", - [&]() { - sbtopk_launch_kernel( - static_cast(self.const_data_ptr()), - static_cast(values.data_ptr()), - static_cast(indices.data_ptr()), - nsegments, - nelements, - static_cast(k), - largest); - }); - return SbtopkResult::SORTED; - } - - return SbtopkResult::FAILED; -} - -} // namespace xpu -} // namespace native -} // namespace at diff --git a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h b/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h deleted file mode 100644 index 031fa64be5..0000000000 --- a/src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2020-2026 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - */ - -#pragma once - -#include - -namespace at { -namespace native { -namespace xpu { - -// Result of sbtopk_try_launch. -// FAILED - did not run; caller should fall back to original kernel. -// UNSORTED - ran; output contains top-k values but is not sorted. -// Caller must sort if sorted output is requested. -// SORTED - ran; output is already sorted (descending for largest, -// ascending for smallest). Caller can skip sort. -enum class SbtopkResult : int { - FAILED = 0, - UNSORTED = 1, - SORTED = 2, -}; - -SbtopkResult sbtopk_try_launch( - const at::Tensor& self, - int64_t nsegments, - int64_t nelements, - int64_t k, - bool largest, - const at::Tensor& values, - const at::Tensor& indices); - -} // namespace xpu -} // namespace native -} // namespace at