Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
return view


cdef inline void _check_0d_scalar_conversion(object usm_ary) except *:
"Raise TypeError if array cannot be converted to a Python scalar"
if (usm_ary.ndim != 0):
raise TypeError(
"only 0-dimensional arrays can be converted to Python scalars"
)


cdef int _copy_writable(int lhs_flags, int rhs_flags):
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
Expand Down Expand Up @@ -1147,6 +1155,7 @@ cdef class usm_ndarray:

def __float__(self):
if self.size == 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need to check the size on that case?

_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__float__()

Expand All @@ -1156,6 +1165,7 @@ cdef class usm_ndarray:

def __complex__(self):
if self.size == 1:
_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__complex__()

Expand All @@ -1165,6 +1175,7 @@ cdef class usm_ndarray:

def __int__(self):
if self.size == 1:
_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__int__()

Expand Down
75 changes: 51 additions & 24 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import pytest
from numpy.testing import assert_raises_regex

import dpctl
import dpctl.memory as dpm
Expand Down Expand Up @@ -282,34 +283,60 @@ def test_properties(dt):
V.mT


@pytest.mark.parametrize("func", [bool, float, int, complex])
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_func(func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert func(X) == func(Y)
class TestCopyScalar:
def test_copy_bool_scalar_with_func(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert bool(X) == bool(Y)

@pytest.mark.parametrize("func", [float, int, complex])
def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
# Non-0D numeric arrays must not be convertible to Python scalars
if len(shape) != 0:
assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X)
else:
# 0D arrays are allowed to convert
assert func(X) == func(Y)

def test_copy_bool_scalar_with_method(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In-place modification of ndarray.shape is a pending deprecation based on NumPy changelog.
It is going to be deprecated in NumPy 2.5 release.

assert getattr(X, "__bool__")() == getattr(Y, "__bool__")()

@pytest.mark.parametrize(
"method", ["__bool__", "__float__", "__int__", "__complex__"]
)
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_method(method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert getattr(X, method)() == getattr(Y, method)()
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
def test_copy_numeric_scalar_with_method(self, method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
if len(shape) != 0:
assert_raises_regex(
TypeError, "only 0-dimensional arrays", getattr(X, method)
)
else:
assert getattr(X, method)() == getattr(Y, method)()


@pytest.mark.parametrize("func", [bool, float, int, complex])
Expand Down
Loading