Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 136 additions & 108 deletions src/ATen/native/xpu/sycl/UpSampleNearest3dKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN
#include <ATen/TensorUtils.h>
#include <ATen/ceil_div.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/CanUse32BitIndexMath.h>
#include <ATen/ops/empty.h>

#include <ATen/native/xpu/UpSample.h>
Expand All @@ -25,16 +26,16 @@ DISABLE_RETURN_TYPE_WARNING_BEGIN

namespace at::native::xpu {

template <typename scalar_t, typename index_op_t>
template <typename scalar_t, typename index_t, typename index_op_t>
struct UpsampleNearest3dKernelFunctor {
void operator()(sycl::nd_item<1> item) const {
int dst_idx = item.get_global_linear_id();
index_t dst_idx = item.get_global_linear_id();

if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_)
return;

int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_;
int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_;
index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_;
index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_;

int c = (dst_idx / (dst_c_stride)) % dim_c_;

Expand All @@ -46,7 +47,7 @@ struct UpsampleNearest3dKernelFunctor {
int dst_x = dst_idx % dst_dim_w_;
int src_x = index_op_(width_scale_, dst_x, src_dim_w_);

int src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ +
index_t src_idx = c * src_c_stride + src_z * src_dim_h_ * src_dim_w_ +
src_y * src_dim_w_ + src_x;
for (int b = 0; b < dim_b_; b++) {
output_[dst_idx] = input_[src_idx];
Expand Down Expand Up @@ -101,10 +102,10 @@ struct UpsampleNearest3dKernelFunctor {
index_op_t index_op_;
};

template <typename scalar_t, typename index_op_t>
template <typename scalar_t, typename index_t, typename index_op_t>
void upsample_nearest3d_out_template(
const scalar_t* input,
unsigned int n,
int64_t n,
size_t dim_b,
size_t dim_c,
size_t src_dim_d,
Expand All @@ -119,7 +120,7 @@ void upsample_nearest3d_out_template(
float width_scale,
index_op_t index_op) {
auto& queue = at::xpu::getCurrentSYCLQueue();
auto kfn = UpsampleNearest3dKernelFunctor<scalar_t, index_op_t>(
auto kfn = UpsampleNearest3dKernelFunctor<scalar_t, index_t, index_op_t>(
input,
dim_b,
dim_c,
Expand All @@ -136,7 +137,7 @@ void upsample_nearest3d_out_template(
index_op);
auto work_group_size = syclMaxWorkGroupSize(kfn);
int64_t work_group_num =
at::ceil_div((unsigned int)n, (unsigned int)work_group_size);
at::ceil_div(n, static_cast<int64_t>(work_group_size));
sycl_kernel_submit(
work_group_num * work_group_size, work_group_size, queue, kfn);
}
Expand Down Expand Up @@ -170,8 +171,7 @@ void upsample_nearest3d_kernel(
int input_width = input_.size(4);

Tensor input = input_.contiguous();
unsigned int n = output.numel() / nbatch;
TORCH_CHECK(output.numel() <= std::numeric_limits<int32_t>::max());
int64_t n = output.numel() / nbatch;
AT_DISPATCH_FLOATING_TYPES_AND3(
ScalarType::Half,
ScalarType::BFloat16,
Expand All @@ -188,57 +188,67 @@ void upsample_nearest3d_kernel(
compute_scales_value<float>(scales_h, input_height, output_height);
const float width_scale =
compute_scales_value<float>(scales_w, input_width, output_width);
if (is_exact) {
upsample_nearest3d_out_template<scalar_t>(
idata,
n,
nbatch,
channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
odata,
depth_scale,
height_scale,
width_scale,
NearestExactIndexOp());
} else {
upsample_nearest3d_out_template<scalar_t>(
idata,
n,
nbatch,
channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
odata,
depth_scale,
height_scale,
width_scale,
NearestIndexOp());
}
AT_DISPATCH_INDEX_TYPES(
at::native::canUse32BitIndexMath(output) ? ScalarType::Int
: ScalarType::Long,
"upsample_nearest3d_xpu_index",
[&] {
if (is_exact) {
upsample_nearest3d_out_template<scalar_t, index_t>(
idata,
n,
nbatch,
channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
odata,
depth_scale,
height_scale,
width_scale,
NearestExactIndexOp());
} else {
upsample_nearest3d_out_template<scalar_t, index_t>(
idata,
n,
nbatch,
channels,
input_depth,
input_height,
input_width,
output_depth,
output_height,
output_width,
odata,
depth_scale,
height_scale,
width_scale,
NearestIndexOp());
}
});
});
if (!output.is_contiguous()) {
output.copy_(output_c);
}
}

template <typename scalar_t, typename accscalar_t, typename index_bw_op_t>
template <
typename scalar_t,
typename accscalar_t,
typename index_t,
typename index_bw_op_t>
struct UpsampleNearest3dBackwardFunctor {
void operator()(sycl::nd_item<1> item) const {
int dst_idx = item.get_global_linear_id();
index_t dst_idx = item.get_global_linear_id();

if (dst_idx >= dim_c_ * dst_dim_d_ * dst_dim_h_ * dst_dim_w_)
return;

int dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_;
int src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_;
index_t dst_c_stride = dst_dim_d_ * dst_dim_h_ * dst_dim_w_;
index_t src_c_stride = src_dim_d_ * src_dim_h_ * src_dim_w_;

int c = (dst_idx / (dst_c_stride)) % dim_c_;

Expand All @@ -259,7 +269,7 @@ struct UpsampleNearest3dBackwardFunctor {
for (int z = src_z; z < src_z_up; z++) {
for (int y = src_y; y < src_y_up; y++) {
for (int x = src_x; x < src_x_up; x++) {
int src_idx = b * dim_c_ * src_c_stride + c * src_c_stride +
index_t src_idx = b * dim_c_ * src_c_stride + c * src_c_stride +
z * src_dim_h_ * src_dim_w_ + y * src_dim_w_ + x;
grad += grad_o_[src_idx];
}
Expand Down Expand Up @@ -315,10 +325,14 @@ struct UpsampleNearest3dBackwardFunctor {
index_bw_op_t index_bw_op_;
};

template <typename scalar_t, typename accscalar_t, typename index_bw_op_t>
template <
typename scalar_t,
typename accscalar_t,
typename index_t,
typename index_bw_op_t>
void upsample_nearest3d_backward_template(
const scalar_t* grad_o,
unsigned int n,
int64_t n,
size_t dim_b,
size_t dim_c,
size_t src_dim_d,
Expand All @@ -333,24 +347,28 @@ void upsample_nearest3d_backward_template(
float width_scale,
index_bw_op_t index_bw_op) {
auto& queue = at::xpu::getCurrentSYCLQueue();
auto kfn =
UpsampleNearest3dBackwardFunctor<scalar_t, accscalar_t, index_bw_op_t>(
grad_o,
dim_b,
dim_c,
src_dim_d,
src_dim_h,
src_dim_w,
dst_dim_d,
dst_dim_h,
dst_dim_w,
grad_i,
depth_scale,
height_scale,
width_scale,
index_bw_op);
auto kfn = UpsampleNearest3dBackwardFunctor<
scalar_t,
accscalar_t,
index_t,
index_bw_op_t>(
grad_o,
dim_b,
dim_c,
src_dim_d,
src_dim_h,
src_dim_w,
dst_dim_d,
dst_dim_h,
dst_dim_w,
grad_i,
depth_scale,
height_scale,
width_scale,
index_bw_op);
auto work_group_size = syclMaxWorkGroupSize(kfn);
int64_t work_group_num = at::ceil_div(n, (unsigned int)work_group_size);
int64_t work_group_num =
at::ceil_div(n, static_cast<int64_t>(work_group_size));
sycl_kernel_submit(
work_group_num * work_group_size, work_group_size, queue, kfn);
}
Expand Down Expand Up @@ -383,9 +401,7 @@ void upsample_nearest3d_backward_kernel(
int input_width = input_size[4];

Tensor grad_output = grad_output_.contiguous();
unsigned int n = grad_input.numel() / nbatch;
TORCH_CHECK(grad_input.numel() <= std::numeric_limits<int32_t>::max());
TORCH_CHECK(grad_output.numel() <= std::numeric_limits<int32_t>::max());
int64_t n = grad_input.numel() / nbatch;
AT_DISPATCH_FLOATING_TYPES_AND3(
ScalarType::Half,
ScalarType::BFloat16,
Expand All @@ -404,41 +420,53 @@ void upsample_nearest3d_backward_kernel(
scales_h, output_height, input_height);
float width_scale = compute_scales_value_backwards<float>(
scales_w, output_width, input_width);
if (is_exact) {
upsample_nearest3d_backward_template<scalar_t, accscalar_t>(
odata,
n,
nbatch,
channels,
output_depth,
output_height,
output_width,
input_depth,
input_height,
input_width,
idata,
depth_scale,
height_scale,
width_scale,
NearestExactBwIndexOp());
} else {
upsample_nearest3d_backward_template<scalar_t, accscalar_t>(
odata,
n,
nbatch,
channels,
output_depth,
output_height,
output_width,
input_depth,
input_height,
input_width,
idata,
depth_scale,
height_scale,
width_scale,
NearestBwIndexOp());
}
AT_DISPATCH_INDEX_TYPES(
at::native::canUse32BitIndexMath(grad_input) ? ScalarType::Int
: ScalarType::Long,
"upsample_nearest3d_backward_xpu_index",
[&] {
if (is_exact) {
upsample_nearest3d_backward_template<
scalar_t,
accscalar_t,
index_t>(
odata,
n,
nbatch,
channels,
output_depth,
output_height,
output_width,
input_depth,
input_height,
input_width,
idata,
depth_scale,
height_scale,
width_scale,
NearestExactBwIndexOp());
} else {
upsample_nearest3d_backward_template<
scalar_t,
accscalar_t,
index_t>(
odata,
n,
nbatch,
channels,
output_depth,
output_height,
output_width,
input_depth,
input_height,
input_width,
idata,
depth_scale,
height_scale,
width_scale,
NearestBwIndexOp());
}
});
});
}

Expand Down
Loading