diff --git a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp index d13c26f9a3..833c9503df 100644 --- a/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp +++ b/src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp @@ -17,6 +17,7 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN #include #include #include +#include #include #include @@ -25,16 +26,16 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN namespace at::native::xpu { -template +template struct UpsampleNearest3dKernelFunctor { void operator()(sycl::nd_item<1> item) const { - int dst_idx = item.get_global_linear_id(); + index_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -46,7 +47,7 @@ struct UpsampleNearest3dKernelFunctor { int dst_x = dst_idx % dst_dim_w_; int src_x = index_op_(width_scale_, dst_x, src_dim_w_); - int src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + + index_t src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ + src_y * src_dim_w_ + src_x; for (int b = 0; b < dim_b_; b++) { output_[dst_idx] = input_[src_idx]; @@ -101,10 +102,10 @@ struct UpsampleNearest3dKernelFunctor { index_op_t index_op_; }; -template +template void upsample_nearest3d_out_template( const scalar_t* input, - unsigned int n, + int64_t n, size_t dim_b, size_t dim_c, size_t src_dim_d, @@ -119,7 +120,7 @@ void upsample_nearest3d_out_template( float width_scale, index_op_t index_op) { auto& queue = at::xpu::getCurrentSYCLQueue(); - auto kfn = UpsampleNearest3dKernelFunctor( + auto kfn = UpsampleNearest3dKernelFunctor( input, dim_b, dim_c, @@ -136,7 +137,7 @@ void upsample_nearest3d_out_template( index_op); auto work_group_size = syclMaxWorkGroupSize(kfn); int64_t work_group_num = - at::ceil_div((unsigned int)n, (unsigned int)work_group_size); + at::ceil_div(n, static_cast(work_group_size)); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -170,8 +171,7 @@ void upsample_nearest3d_kernel( int input_width = input_.size(4); Tensor input = input_.contiguous(); - unsigned int n = output.numel() / nbatch; - TORCH_CHECK(output.numel() <= std::numeric_limits::max()); + int64_t n = output.numel() / nbatch; AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, @@ -188,57 +188,67 @@ void upsample_nearest3d_kernel( compute_scales_value(scales_h, input_height, output_height); const float width_scale = compute_scales_value(scales_w, input_width, output_width); - if (is_exact) { - upsample_nearest3d_out_template( - idata, - n, - nbatch, - channels, - input_depth, - input_height, - input_width, - output_depth, - output_height, - output_width, - odata, - depth_scale, - height_scale, - width_scale, - NearestExactIndexOp()); - } else { - upsample_nearest3d_out_template( - idata, - n, - nbatch, - channels, - input_depth, - input_height, - input_width, - output_depth, - output_height, - output_width, - odata, - depth_scale, - height_scale, - width_scale, - NearestIndexOp()); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(output) ? ScalarType::Int + : ScalarType::Long, + "upsample_nearest3d_xpu_index", + [&] { + if (is_exact) { + upsample_nearest3d_out_template( + idata, + n, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + odata, + depth_scale, + height_scale, + width_scale, + NearestExactIndexOp()); + } else { + upsample_nearest3d_out_template( + idata, + n, + nbatch, + channels, + input_depth, + input_height, + input_width, + output_depth, + output_height, + output_width, + odata, + depth_scale, + height_scale, + width_scale, + NearestIndexOp()); + } + }); }); if (!output.is_contiguous()) { output.copy_(output_c); } } -template +template < + typename scalar_t, + typename accscalar_t, + typename index_t, + typename index_bw_op_t> struct UpsampleNearest3dBackwardFunctor { void operator()(sycl::nd_item<1> item) const { - int dst_idx = item.get_global_linear_id(); + index_t dst_idx = item.get_global_linear_id(); if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_) return; - int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; - int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; + index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_; + index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_; int c = (dst_idx / (dst_c_stride)) % dim_c_; @@ -259,7 +269,7 @@ struct UpsampleNearest3dBackwardFunctor { for (int z = src_z; z < src_z_up; z++) { for (int y = src_y; y < src_y_up; y++) { for (int x = src_x; x < src_x_up; x++) { - int src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + + index_t src_idx = b * dim_c_ * src_c_stride + c * src_c_stride + z * src_dim_h_ * src_dim_w_ + y * src_dim_w_ + x; grad += grad_o_[src_idx]; } @@ -315,10 +325,14 @@ struct UpsampleNearest3dBackwardFunctor { index_bw_op_t index_bw_op_; }; -template +template < + typename scalar_t, + typename accscalar_t, + typename index_t, + typename index_bw_op_t> void upsample_nearest3d_backward_template( const scalar_t* grad_o, - unsigned int n, + int64_t n, size_t dim_b, size_t dim_c, size_t src_dim_d, @@ -333,24 +347,28 @@ void upsample_nearest3d_backward_template( float width_scale, index_bw_op_t index_bw_op) { auto& queue = at::xpu::getCurrentSYCLQueue(); - auto kfn = - UpsampleNearest3dBackwardFunctor( - grad_o, - dim_b, - dim_c, - src_dim_d, - src_dim_h, - src_dim_w, - dst_dim_d, - dst_dim_h, - dst_dim_w, - grad_i, - depth_scale, - height_scale, - width_scale, - index_bw_op); + auto kfn = UpsampleNearest3dBackwardFunctor< + scalar_t, + accscalar_t, + index_t, + index_bw_op_t>( + grad_o, + dim_b, + dim_c, + src_dim_d, + src_dim_h, + src_dim_w, + dst_dim_d, + dst_dim_h, + dst_dim_w, + grad_i, + depth_scale, + height_scale, + width_scale, + index_bw_op); auto work_group_size = syclMaxWorkGroupSize(kfn); - int64_t work_group_num = at::ceil_div(n, (unsigned int)work_group_size); + int64_t work_group_num = + at::ceil_div(n, static_cast(work_group_size)); sycl_kernel_submit( work_group_num * work_group_size, work_group_size, queue, kfn); } @@ -383,9 +401,7 @@ void upsample_nearest3d_backward_kernel( int input_width = input_size[4]; Tensor grad_output = grad_output_.contiguous(); - unsigned int n = grad_input.numel() / nbatch; - TORCH_CHECK(grad_input.numel() <= std::numeric_limits::max()); - TORCH_CHECK(grad_output.numel() <= std::numeric_limits::max()); + int64_t n = grad_input.numel() / nbatch; AT_DISPATCH_FLOATING_TYPES_AND3( ScalarType::Half, ScalarType::BFloat16, @@ -404,41 +420,53 @@ void upsample_nearest3d_backward_kernel( scales_h, output_height, input_height); float width_scale = compute_scales_value_backwards( scales_w, output_width, input_width); - if (is_exact) { - upsample_nearest3d_backward_template( - odata, - n, - nbatch, - channels, - output_depth, - output_height, - output_width, - input_depth, - input_height, - input_width, - idata, - depth_scale, - height_scale, - width_scale, - NearestExactBwIndexOp()); - } else { - upsample_nearest3d_backward_template( - odata, - n, - nbatch, - channels, - output_depth, - output_height, - output_width, - input_depth, - input_height, - input_width, - idata, - depth_scale, - height_scale, - width_scale, - NearestBwIndexOp()); - } + AT_DISPATCH_INDEX_TYPES( + at::native::canUse32BitIndexMath(grad_input) ? ScalarType::Int + : ScalarType::Long, + "upsample_nearest3d_backward_xpu_index", + [&] { + if (is_exact) { + upsample_nearest3d_backward_template< + scalar_t, + accscalar_t, + index_t>( + odata, + n, + nbatch, + channels, + output_depth, + output_height, + output_width, + input_depth, + input_height, + input_width, + idata, + depth_scale, + height_scale, + width_scale, + NearestExactBwIndexOp()); + } else { + upsample_nearest3d_backward_template< + scalar_t, + accscalar_t, + index_t>( + odata, + n, + nbatch, + channels, + output_depth, + output_height, + output_width, + input_depth, + input_height, + input_width, + idata, + depth_scale, + height_scale, + width_scale, + NearestBwIndexOp()); + } + }); }); }