Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/ATen/native/xpu/sycl/SortingRadixSelect.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ struct TopKTypeConfig<float> {
// Converts a float into the unsigned radix key used by radix select.
//
// The raw IEEE-754 bits are XOR-ed with a mask: all bits are flipped when
// the sign bit is set, otherwise only the sign bit is flipped. NaN inputs
// are mapped to the all-ones key 0xffffffff.
static inline RadixType convert(float v) {
  RadixType x = *((uint32_t*)&v);
  RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
  // v == v is false only when v is NaN.
  return (v == v) ? (x ^ mask) : 0xffffffff;
}
Comment thread
jianyizh marked this conversation as resolved.

static inline float deconvert(RadixType v) {
Expand Down Expand Up @@ -168,7 +168,7 @@ struct TopKTypeConfig<double> {
// Converts a double into the unsigned radix key used by radix select.
//
// The raw IEEE-754 bits are XOR-ed with a mask: -(x >> 63) is all ones when
// the sign bit is set (flip every bit) and zero otherwise, and OR-ing in the
// top bit guarantees the sign bit is always flipped. NaN inputs are mapped
// to the all-ones key 0xffffffffffffffff.
static inline RadixType convert(double v) {
  RadixType x = *((uint64_t*)&v);
  RadixType mask = -((x >> 63)) | 0x8000000000000000;
  // v == v is false only when v is NaN.
  return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
}

static inline double deconvert(RadixType v) {
Expand All @@ -183,12 +183,12 @@ struct TopKTypeConfig<at::Half> {

// Converts an at::Half into the unsigned radix key used by radix select.
//
// The 16 raw bits are XOR-ed with a mask confined to the low 16 bits:
// 0x0000ffff when the sign bit (bit 15) is set, otherwise 0x00008000.
// Inputs for which v == v is false (presumably NaN; this relies on
// at::Half's comparison operator) are mapped to the key 0xffff.
static inline RadixType convert(at::Half v) {
  RadixType x = *((uint16_t*)&v);
  // Keep the mask within 16 bits so the key never has bits set above bit 15.
  RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
  return (v == v) ? (x ^ mask) : 0xffff;
}

// Converts a radix key back into an at::Half.
//
// Undoes the XOR applied by convert(): a key with bit 15 set has only that
// bit flipped back, while a key with bit 15 clear has its low 16 bits
// flipped. The result is reinterpreted via __ushort_as_half.
static inline at::Half deconvert(RadixType v) {
  // Keep the mask within 16 bits, mirroring the mask used in convert().
  RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
  return __ushort_as_half(v ^ mask);
}
};
Expand Down
25 changes: 21 additions & 4 deletions src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h>
#include <ATen/native/xpu/sycl/TensorTopKSbtopkKernelImpl.h>
#include <ATen/native/xpu/sycl/TensorTopKSingleWgKernel.h>
#include <c10/util/llvmMathExtras.h>
Comment thread
jianyizh marked this conversation as resolved.
#include <comm/DeviceProperties.h>

Expand Down Expand Up @@ -68,10 +69,13 @@ static bool subgroup_topk_try_launch(
#undef SBTOPK_LAUNCH

// ================================================================
// Dispatch: subgroup top-k vs single-workgroup top-k vs original
//
// - dim < 32: original (need at least SG_SIZE elements)
// - dim >= 32, large batch, k <= 8: subgroup top-k
// - dim >= 4096, any bs: single-workgroup top-k
// - dim >= 4096, k > 8: single-workgroup top-k
//   (subgroup only supports k <= 8)
// ================================================================
SbtopkResult sbtopk_try_launch(
const at::Tensor& self,
Expand All @@ -86,12 +90,12 @@ SbtopkResult sbtopk_try_launch(
return SbtopkResult::FAILED;
}

// Subgroup top-k: best for large batch, k <= 8.
// Output is ALREADY SORTED (descending for largest, ascending for smallest).
//
// Threshold: nsegments >= thread_slots / 4.
// Subgroup top-k uses 1 sub-group per slice (reading data once), while
// single-wg/original read data multiple times (~3 radix passes). So
// subgroup top-k reaches memory-BW saturation at much lower occupancy.
// thread_slots/4 is the conservative cutoff.
//
Expand All @@ -107,6 +111,19 @@ SbtopkResult sbtopk_try_launch(
return SbtopkResult::FAILED;
}

// Single-workgroup top-k: radix select, good for large
// dim. Output is UNSORTED.
// Use nelements >= 4096 so dim=1024/1025 falls through to original
// (single-wg has regressions at dim=1024 for medium/large batches).
if (nelements >= 4096) {
if (single_wg_topk_try_launch(
self, nsegments, nelements, k, largest, values, indices)) {
return SbtopkResult::UNSORTED;
Comment thread
jianyizh marked this conversation as resolved.
}
return SbtopkResult::FAILED;
}

// Fallback to original: reached only when no path above applied
// (in particular nelements < 4096, since the single-wg branch always returns).
return SbtopkResult::FAILED;
}

Expand Down
11 changes: 8 additions & 3 deletions src/ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,22 @@
namespace at::native::xpu {

// Result of sbtopk_try_launch.
//   FAILED   - sbtopk did not run; caller should fall back to original.
//   UNSORTED - sbtopk ran; output contains top-k values but is not sorted.
//              Caller must sort if sorted output is requested.
//   SORTED   - sbtopk ran; output is already sorted (descending for largest,
//              ascending for smallest). Caller can skip sort.
enum class SbtopkResult : int {
FAILED = 0,
UNSORTED = 1,
SORTED = 2,
};

// Try to run topk using an optimized kernel path.
//
// Dispatches between the subgroup topk kernel (sub-group bitonic merge,
// output SORTED) and the single workgroup topk kernel (radix select,
// output UNSORTED) based on (nsegments, nelements, k).
SbtopkResult sbtopk_try_launch(
const at::Tensor& self,
int64_t nsegments,
Expand Down
Loading
Loading