diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 70bb4243f6..4815645dd4 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -124,6 +124,14 @@ cdef object _as_zero_dim_ndarray(object usm_ary): return view +cdef inline void _check_0d_scalar_conversion(object usm_ary) except *: + "Raise TypeError if array cannot be converted to a Python scalar" + if (usm_ary.ndim != 0): + raise TypeError( + "only 0-dimensional arrays can be converted to Python scalars" + ) + + cdef int _copy_writable(int lhs_flags, int rhs_flags): "Copy the WRITABLE flag to lhs_flags from rhs_flags" return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE) @@ -1147,6 +1155,7 @@ cdef class usm_ndarray: def __float__(self): if self.size == 1: + _check_0d_scalar_conversion(self) view = _as_zero_dim_ndarray(self) return view.__float__() @@ -1156,6 +1165,7 @@ cdef class usm_ndarray: def __complex__(self): if self.size == 1: + _check_0d_scalar_conversion(self) view = _as_zero_dim_ndarray(self) return view.__complex__() @@ -1165,6 +1175,7 @@ cdef class usm_ndarray: def __int__(self): if self.size == 1: + _check_0d_scalar_conversion(self) view = _as_zero_dim_ndarray(self) return view.__int__() diff --git a/dpctl/tests/test_usm_ndarray_ctor.py b/dpctl/tests/test_usm_ndarray_ctor.py index df55dcfc48..3ce94b9c62 100644 --- a/dpctl/tests/test_usm_ndarray_ctor.py +++ b/dpctl/tests/test_usm_ndarray_ctor.py @@ -20,6 +20,7 @@ import numpy as np import pytest +from numpy.testing import assert_raises_regex import dpctl import dpctl.memory as dpm @@ -282,34 +283,60 @@ def test_properties(dt): V.mT -@pytest.mark.parametrize("func", [bool, float, int, complex]) @pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) @pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) -def test_copy_scalar_with_func(func, shape, dtype): - try: - X = dpt.usm_ndarray(shape, dtype=dtype) - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - Y = np.arange(1, X.size + 1, dtype=dtype) - X.usm_data.copy_from_host(Y.view("|u1")) - Y.shape = tuple() - assert func(X) == func(Y) +class TestCopyScalar: + def test_copy_bool_scalar_with_func(self, shape, dtype): + try: + X = dpt.usm_ndarray(shape, dtype=dtype) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() + assert bool(X) == bool(Y) + @pytest.mark.parametrize("func", [float, int, complex]) + def test_copy_numeric_scalar_with_func(self, func, shape, dtype): + try: + X = dpt.usm_ndarray(shape, dtype=dtype) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() + # Non-0D numeric arrays must not be convertible to Python scalars + if len(shape) != 0: + assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X) + else: + # 0D arrays are allowed to convert + assert func(X) == func(Y) + + def test_copy_bool_scalar_with_method(self, shape, dtype): + try: + X = dpt.usm_ndarray(shape, dtype=dtype) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() + assert getattr(X, "__bool__")() == getattr(Y, "__bool__")() -@pytest.mark.parametrize( - "method", ["__bool__", "__float__", "__int__", "__complex__"] -) -@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)]) -@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"]) -def test_copy_scalar_with_method(method, shape, dtype): - try: - X = dpt.usm_ndarray(shape, dtype=dtype) - except dpctl.SyclDeviceCreationError: - pytest.skip("No SYCL devices available") - Y = np.arange(1, X.size + 1, dtype=dtype) - X.usm_data.copy_from_host(Y.view("|u1")) - Y.shape = tuple() - assert getattr(X, method)() == getattr(Y, method)() + @pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"]) + def test_copy_numeric_scalar_with_method(self, method, shape, dtype): + try: + X = dpt.usm_ndarray(shape, dtype=dtype) + except dpctl.SyclDeviceCreationError: + pytest.skip("No SYCL devices available") + Y = np.arange(1, X.size + 1, dtype=dtype) + X.usm_data.copy_from_host(Y.view("|u1")) + Y.shape = tuple() + if len(shape) != 0: + assert_raises_regex( + TypeError, "only 0-dimensional arrays", getattr(X, method) + ) + else: + assert getattr(X, method)() == getattr(Y, method)() @pytest.mark.parametrize("func", [bool, float, int, complex])