Skip to content
Merged
31 changes: 20 additions & 11 deletions src/ATen/native/xpu/sycl/TensorTopKKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <ATen/native/xpu/sycl/SortingKernels.h>

#include <ATen/native/xpu/sycl/TensorTopKKernel.h>
#include <ATen/native/xpu/sycl/TensorTopKSbtopkKernel.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -113,17 +114,25 @@ void topk_kernel(
const scalar_t* self_ptr = self_.const_data_ptr<scalar_t>();
scalar_t* values_ptr = values_.data_ptr<scalar_t>();
int64_t* indices_ptr = indices_.data_ptr<int64_t>();
segmented_group_select_pairs<scalar_t, int64_t>(
self_ptr,
values_ptr,
nullptr,
(int64_t*)indices_ptr,
nsegments,
nelements,
k,
largest);

if (sorted) {

SbtopkResult sbtopk_result = sbtopk_try_launch(
self_, nsegments, nelements, k, largest, values_, indices_);
if (sbtopk_result == SbtopkResult::FAILED) {
Comment thread
jianyizh marked this conversation as resolved.
Comment thread
jianyizh marked this conversation as resolved.
segmented_group_select_pairs<scalar_t, int64_t>(
self_ptr,
values_ptr,
nullptr,
(int64_t*)indices_ptr,
nsegments,
nelements,
k,
largest);
Comment thread
jianyizh marked this conversation as resolved.
}

// Only sort if the user asked for sorted output AND the optimized
// kernel did not already produce a sorted result. The subgroup topk
// kernel returns SORTED; the original kernel returns FAILED.
if (sorted && sbtopk_result != SbtopkResult::SORTED) {
segmented_sort_pairs<scalar_t, int64_t>(
values_ptr,
values_ptr,
Expand Down
Loading
Loading