diff --git a/src/ATen/native/xpu/RangeFactories.cpp b/src/ATen/native/xpu/RangeFactories.cpp index cfb538c7b6..131164f11d 100644 --- a/src/ATen/native/xpu/RangeFactories.cpp +++ b/src/ATen/native/xpu/RangeFactories.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -43,23 +44,7 @@ Tensor& arange_out_xpu( out.scalar_type(), "arange_xpu_preprocess", [&]() { - using accscalar_t = at::acc_type_device; - auto xstart = start.to(); - auto xend = end.to(); - auto xstep = step.to(); - - TORCH_CHECK(xstep > 0 || xstep < 0, "step must be nonzero"); - TORCH_CHECK( - std::isfinite(static_cast(xstart)) && - std::isfinite(static_cast(xend)), - "unsupported range: ", - xstart, - " -> ", - xend); - TORCH_CHECK( - ((xstep > 0) && (xend >= xstart)) || - ((xstep < 0) && (xend <= xstart)), - "upper bound and larger bound inconsistent with step sign"); + arange_check_bounds(start, end, step); // we use double precision for (start - end) / step // to compute size_d for consistency across devices. @@ -71,6 +56,11 @@ Tensor& arange_out_xpu( // than double double size_d; if constexpr (std::is_same_v) { + using accscalar_t = at::acc_type_device; + auto xstart = start.to(); + auto xend = end.to(); + auto xstep = step.to(); + TORCH_CHECK_VALUE(xstep != 0, "step must be nonzero"); int64_t sgn = (xstep > 0) - (xstep < 0); size_d = std::ceil((xend - xstart + xstep - sgn) / xstep); } else {