From f6565953fa0a444b6cdd5de9fa3e3598512c405f Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Fri, 22 May 2026 10:23:46 +0300 Subject: [PATCH 1/5] Int64 support for UpsampleNearest3d --- .../xpu/sycl/UpSampleNearest3dKernels.cpp | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index d13c26f9a3..06bd40ba49 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -28,13 +28,13 @@ namespace at::native::xpu { template struct UpsampleNearest3dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - int dst_idx = item.get_global_linear_id(); + int64_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + int64_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + int64_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -46,7 +46,7 @@ struct UpsampleNearest3dKernelFunctor { int dst_x = dst_idx % dst_dim_w_; int src_x = index_op_(width_scale_, dst_x, src_dim_w_); - int src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + + int64_t src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + src_y * src_dim_w_ + src_x; for (int b = 0; b < dim_b_; b++) { output_[dst_idx] = input_[src_idx]; @@ -104,7 +104,7 @@ struct UpsampleNearest3dKernelFunctor { template void upsample_nearest3d_out_template( const scalar_t* input, - unsigned int n, + int64_t n, size_t dim_b, size_t dim_c, size_t src_dim_d, @@ -135,8 +135,7 @@ void upsample_nearest3d_out_template( width_scale, index_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = - at::ceil_div((unsigned int)n, (unsigned int)work_group_size); + int64_t work_group_num = at::ceil_div(n, (int64_t)work_group_size); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -170,8 +169,11 @@ void upsample_nearest3d_kernel( int input_width = input_.size(4); Tensor input = input_.contiguous(); - unsigned int n = output.numel() / nbatch; - TORCH_CHECK(output.numel() <= std::numeric_limits::max()); + int64_t n = output.numel() / nbatch; + TORCH_CHECK( + output.numel() <= std::numeric_limits::max(), + "upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", + output.sizes()); AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, @@ -232,13 +234,13 @@ void upsample_nearest3d_kernel( template struct UpsampleNearest3dBackwardFunctor { void operator()(sycl::nd_item<1> item) const { - int dst_idx = item.get_global_linear_id(); + int64_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + int64_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + int64_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -259,7 +261,7 @@ struct UpsampleNearest3dBackwardFunctor { for (int z = src_z; z < src_z_up; z++) { for (int y = src_y; y < src_y_up; y++) { for (int x = src_x; x < src_x_up; x++) { - int src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + + int64_t src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + z * src_dim_h_ * src_dim_w_ + y * src_dim_w_ + x; grad += grad_o_[src_idx]; } @@ -318,7 +320,7 @@ struct UpsampleNearest3dBackwardFunctor { template void upsample_nearest3d_backward_template( const scalar_t* grad_o, - unsigned int n, + int64_t n, size_t dim_b, size_t dim_c, size_t src_dim_d, @@ -350,7 +352,7 @@ void upsample_nearest3d_backward_template( width_scale, index_bw_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, (unsigned int)work_group_size); + int64_t work_group_num = at::ceil_div(n, (int64_t)work_group_size); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -383,9 +385,15 @@ void upsample_nearest3d_backward_kernel( int input_width = input_size[4]; Tensor grad_output = grad_output_.contiguous(); - unsigned int n = grad_input.numel() / nbatch; - TORCH_CHECK(grad_input.numel() <= std::numeric_limits::max()); - TORCH_CHECK(grad_output.numel() <= std::numeric_limits::max()); + int64_t n = grad_input.numel() / nbatch; + TORCH_CHECK( + grad_input.numel() <= std::numeric_limits::max(), + "upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", + grad_input.sizes()); + TORCH_CHECK( + grad_output.numel() <= std::numeric_limits::max(), + "upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", + grad_output.sizes()); AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, From 06d5d6962a3a8f1cee88d0ed95b4db8d5f5cc756 Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Fri, 22 May 2026 10:03:36 +0200 Subject: [PATCH 2/5] remove redundant cast Co-authored-by: Slawomir Siwek --- src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index 06bd40ba49..cf4f5b5521 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -135,7 +135,7 @@ void upsample_nearest3d_out_template( width_scale, index_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, (int64_t)work_group_size); + int64_t work_group_num = at::ceil_div(n, work_group_size); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } From cb34db6e867037ced881dcb90af1cc8951287050 Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Fri, 22 May 2026 10:09:28 +0200 Subject: [PATCH 3/5] ditto Co-authored-by: Slawomir Siwek --- src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index cf4f5b5521..8e9e55121e 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -352,7 +352,7 @@ void upsample_nearest3d_backward_template( width_scale, index_bw_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, (int64_t)work_group_size); + int64_t work_group_num = at::ceil_div(n, work_group_size); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } From 12bc3241ab9707bad5c0deed22987043863593ef Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Fri, 22 May 2026 16:09:22 +0300 Subject: [PATCH 4/5] Switch to dispatch based approach --- .../xpu/sycl/UpSampleNearest3dKernels.cpp | 207 +++++++++--------- 1 file changed, 109 insertions(+), 98 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index 8e9e55121e..b73b68e1ba 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -17,6 +17,7 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN #include #include #include +#include #include #include @@ -25,16 +26,16 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN namespace at::native::xpu { -template +template struct UpsampleNearest3dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - int64_t dst_idx = item.get_global_linear_id(); + index_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int64_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int64_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -46,7 +47,7 @@ struct UpsampleNearest3dKernelFunctor { int dst_x = dst_idx % dst_dim_w_; int src_x = index_op_(width_scale_, dst_x, src_dim_w_); - int64_t src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + + index_t src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + src_y * src_dim_w_ + src_x; for (int b = 0; b < dim_b_; b++) { output_[dst_idx] = input_[src_idx]; @@ -101,7 +102,7 @@ struct UpsampleNearest3dKernelFunctor { index_op_t index_op_; }; -template +template void upsample_nearest3d_out_template( const scalar_t* input, int64_t n, @@ -119,7 +120,7 @@ void upsample_nearest3d_out_template( float width_scale, index_op_t index_op) { auto& queue = at::xpu::getCurrentSYCLQueue(); - auto kfn = UpsampleNearest3dKernelFunctor( + auto kfn = UpsampleNearest3dKernelFunctor( input, dim_b, dim_c, @@ -135,7 +136,8 @@ void upsample_nearest3d_out_template( width_scale, index_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, work_group_size); + int64_t work_group_num = + at::ceil_div(n, static_cast(work_group_size)); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -170,10 +172,6 @@ void upsample_nearest3d_kernel( Tensor input = input_.contiguous(); int64_t n = output.numel() / nbatch; - TORCH_CHECK( - output.numel() <= std::numeric_limits::max(), - "upsample_nearest3d only supports output tensors with less than INT64_MAX elements, but got ", - output.sizes()); AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, @@ -190,57 +188,64 @@ void upsample_nearest3d_kernel( compute_scales_value(scales_h, input_height, output_height); const float width_scale = compute_scales_value(scales_w, input_width, output_width); - if (is_exact) { - upsample_nearest3d_out_template( - idata, - n, - nbatch, - channels, - input_depth, - input_height, - input_width, - output_depth, - output_height, - output_width, - odata, - depth_scale, - height_scale, - width_scale, - NearestExactIndexOp()); - } else { - upsample_nearest3d_out_template( - idata, - n, - nbatch, - channels, - input_depth, - input_height, - input_width, - output_depth, - output_height, - output_width, - odata, - depth_scale, - height_scale, - width_scale, - NearestIndexOp()); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(output) + ? ScalarType::Int + : ScalarType::Long, + "upsample_nearest3d_xpu_index", + [&] { + if (is_exact) { + upsample_nearest3d_out_template( + idata, + n, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + odata, + depth_scale, + height_scale, + width_scale, + NearestExactIndexOp()); + } else { + upsample_nearest3d_out_template( + idata, + n, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + odata, + depth_scale, + height_scale, + width_scale, + NearestIndexOp()); + } + }); }); if (!output.is_contiguous()) { output.copy_(output_c); } } -template +template struct UpsampleNearest3dBackwardFunctor { void operator()(sycl::nd_item<1> item) const { - int64_t dst_idx = item.get_global_linear_id(); + index_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int64_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int64_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -261,7 +266,7 @@ struct UpsampleNearest3dBackwardFunctor { for (int z = src_z; z < src_z_up; z++) { for (int y = src_y; y < src_y_up; y++) { for (int x = src_x; x < src_x_up; x++) { - int64_t src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + + index_t src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + z * src_dim_h_ * src_dim_w_ + y * src_dim_w_ + x; grad += grad_o_[src_idx]; } @@ -317,7 +322,7 @@ struct UpsampleNearest3dBackwardFunctor { index_bw_op_t index_bw_op_; }; -template +template void upsample_nearest3d_backward_template( const scalar_t* grad_o, int64_t n, @@ -336,7 +341,7 @@ void upsample_nearest3d_backward_template( index_bw_op_t index_bw_op) { auto& queue = at::xpu::getCurrentSYCLQueue(); auto kfn = - UpsampleNearest3dBackwardFunctor( + UpsampleNearest3dBackwardFunctor( grad_o, dim_b, dim_c, @@ -352,7 +357,8 @@ void upsample_nearest3d_backward_template( width_scale, index_bw_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, work_group_size); + int64_t work_group_num = + at::ceil_div(n, static_cast(work_group_size)); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -386,14 +392,6 @@ void upsample_nearest3d_backward_kernel( Tensor grad_output = grad_output_.contiguous(); int64_t n = grad_input.numel() / nbatch; - TORCH_CHECK( - grad_input.numel() <= std::numeric_limits::max(), - "upsample_nearest3d_backward only supports input tensors with less than INT64_MAX elements, but got ", - grad_input.sizes()); - TORCH_CHECK( - grad_output.numel() <= std::numeric_limits::max(), - "upsample_nearest3d_backward only supports output tensors with less than INT64_MAX elements, but got ", - grad_output.sizes()); AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, @@ -412,41 +410,54 @@ void upsample_nearest3d_backward_kernel( scales_h, output_height, input_height); float width_scale = compute_scales_value_backwards( scales_w, output_width, input_width); - if (is_exact) { - upsample_nearest3d_backward_template( - odata, - n, - nbatch, - channels, - output_depth, - output_height, - output_width, - input_depth, - input_height, - input_width, - idata, - depth_scale, - height_scale, - width_scale, - NearestExactBwIndexOp()); - } else { - upsample_nearest3d_backward_template( - odata, - n, - nbatch, - channels, - output_depth, - output_height, - output_width, - input_depth, - input_height, - input_width, - idata, - depth_scale, - height_scale, - width_scale, - NearestBwIndexOp()); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(grad_input) + ? ScalarType::Int + : ScalarType::Long, + "upsample_nearest3d_backward_xpu_index", + [&] { + if (is_exact) { + upsample_nearest3d_backward_template< + scalar_t, + accscalar_t, + index_t>( + odata, + n, + nbatch, + channels, + output_depth, + output_height, + output_width, + input_depth, + input_height, + input_width, + idata, + depth_scale, + height_scale, + width_scale, + NearestExactBwIndexOp()); + } else { + upsample_nearest3d_backward_template< + scalar_t, + accscalar_t, + index_t>( + odata, + n, + nbatch, + channels, + output_depth, + output_height, + output_width, + input_depth, + input_height, + input_width, + idata, + depth_scale, + height_scale, + width_scale, + NearestBwIndexOp()); + } + }); }); } From c43c8806cef939b966849cf7d5ef50ccc1bda74f Mon Sep 17 00:00:00 2001 From: Konrad Drozd Date: Fri, 22 May 2026 16:20:51 +0300 Subject: [PATCH 5/5] Lint --- .../xpu/sycl/UpSampleNearest3dKernels.cpp | 57 +++++++++++-------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index b73b68e1ba..833c9503df 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -189,9 +189,8 @@ void upsample_nearest3d_kernel( const float width_scale = compute_scales_value(scales_w, input_width, output_width); AT_DISPATCH_INDEX_TYPES( - at::native::canUse32BitIndexMath(output) - ? ScalarType::Int - : ScalarType::Long, + at::native::canUse32BitIndexMath(output) ? ScalarType::Int + : ScalarType::Long, "upsample_nearest3d_xpu_index", [&] { if (is_exact) { @@ -236,7 +235,11 @@ void upsample_nearest3d_kernel( } } -template +template < + typename scalar_t, + typename accscalar_t, + typename index_t, + typename index_bw_op_t> struct UpsampleNearest3dBackwardFunctor { void operator()(sycl::nd_item<1> item) const { index_t dst_idx = item.get_global_linear_id(); @@ -322,7 +325,11 @@ struct UpsampleNearest3dBackwardFunctor { index_bw_op_t index_bw_op_; }; -template +template < + typename scalar_t, + typename accscalar_t, + typename index_t, + typename index_bw_op_t> void upsample_nearest3d_backward_template( const scalar_t* grad_o, int64_t n, @@ -340,22 +347,25 @@ void upsample_nearest3d_backward_template( float width_scale, index_bw_op_t index_bw_op) { auto& queue = at::xpu::getCurrentSYCLQueue(); - auto kfn = - UpsampleNearest3dBackwardFunctor( - grad_o, - dim_b, - dim_c, - src_dim_d, - src_dim_h, - src_dim_w, - dst_dim_d, - dst_dim_h, - dst_dim_w, - grad_i, - depth_scale, - height_scale, - width_scale, - index_bw_op); + auto kfn = UpsampleNearest3dBackwardFunctor< + scalar_t, + accscalar_t, + index_t, + index_bw_op_t>( + grad_o, + dim_b, + dim_c, + src_dim_d, + src_dim_h, + src_dim_w, + dst_dim_d, + dst_dim_h, + dst_dim_w, + grad_i, + depth_scale, + height_scale, + width_scale, + index_bw_op); auto work_group_size = syclMaxWorkGroupSize(kfn); int64_t work_group_num = at::ceil_div(n, static_cast(work_group_size)); @@ -411,9 +421,8 @@ void upsample_nearest3d_backward_kernel( float width_scale = compute_scales_value_backwards( scales_w, output_width, input_width); AT_DISPATCH_INDEX_TYPES( - at::native::canUse32BitIndexMath(grad_input) - ? ScalarType::Int - : ScalarType::Long, + at::native::canUse32BitIndexMath(grad_input) ? ScalarType::Int + : ScalarType::Long, "upsample_nearest3d_backward_xpu_index", [&] { if (is_exact) {