diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 45f27d31..43b2bb19 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -104,6 +104,31 @@ jobs: - name: Run tests run: | pytest -n auto + build-numpy25: + name: Python 3.12 with NumPy 2.5 (pre-release) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + submodules: true + persist-credentials: false + - name: Set up Python 3.12 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install scikit-build-core cmake wheel + python -m pip install -U --pre "numpy==2.5.*" \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto build-oldest-numpy: name: Python 3.10 with oldest supported numpy runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index fa2488f8..ac90b30a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Drop support for Python 3.9, which reached end-of-life in October 2025. * Drop support for NumPy < 1.24. +* `arr.real` and `arr.imag` now return correct results for `bcomplex32` and + `complex32` arrays on NumPy 2.5+ ([#355](https://github.com/jax-ml/ml_dtypes/issues/355)). ## [0.5.4] - 2025-11-17 diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 125d2847..968c4bac 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -37,6 +37,8 @@ "uint1", "uint2", "uint4", + "real", + "imag", ] from ml_dtypes._finfo import finfo @@ -84,6 +86,23 @@ bcomplex32: type[_np.generic] complex32: type[_np.generic] +# Augment the C++ extension's terse docstring with a clearer class summary. +bcomplex32.__doc__ = ( + "complex: a 4-byte complex number pairing two bfloat16\n" + "halves (real + imaginary), exposed as an ml_dtypes extension dtype.\n\n" + "WARNING: NumPy does not natively understand this custom complex dtype.\n" + "On NumPy <2.5, arr.real / arr.imag are SILENTLY wrong; use\n" + "ml_dtypes.real() / ml_dtypes.imag() instead (NumPy 2.5+ fixes them).\n\n" + "These complex-aware builtins also do NOT recognize this dtype on ANY\n" + "NumPy version -- cast to np.complex64 first, or use the workaround:\n" + " np.vdot(a,b) -> np.dot(np.conjugate(a), b)\n" + " np.linalg.norm(a) -> np.linalg.norm(a.astype(np.complex64))\n" + " np.iscomplex(a) -> ml_dtypes.imag(a) != 0\n" + " np.angle(a) -> np.arctan2(ml_dtypes.imag(a), ml_dtypes.real(a))\n" + " np.linalg.det/inv -> cast to np.complex64 first (else they raise)\n" + "np.abs, conjugate, arithmetic, reductions, np.dot/inner/outer, casts OK." +) + def real(x: _np.ndarray) -> _np.ndarray: """Return the real part of a complex array. diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index c640aee7..94e569be 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -29,7 +29,8 @@ limitations under the License. #include // NOLINT #include // NOLINT #include // NOLINT -#include // NOLINT +#include // NOLINT +#include // NOLINT (std::is_standard_layout_v) #include // NOLINT // Place `` before to avoid a build failure in macOS. #include @@ -904,6 +905,125 @@ bool RegisterComplexUFuncs(PyObject* numpy) { return ok; } + +// Identical to the NumPy code, we could actually do without the loop. +// (Supports full ufunc path, although this is currently unexposed in NumPy.) +template +static NPY_CASTING +complex_to_real_resolve_descriptors( + PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *const dtypes[2], + PyArray_Descr *const given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *view_offset) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + Py_INCREF(dtypes[1]->singleton); + loop_descrs[1] = dtypes[1]->singleton; + + if (PyDataType_ISBYTESWAPPED(loop_descrs[0])) { + Py_SETREF( + loop_descrs[1], PyArray_DescrNewByteorder(loop_descrs[1], NPY_SWAP)); + if (loop_descrs[1] == NULL) { + Py_DECREF(loop_descrs[0]); + return _NPY_ERROR_OCCURRED_IN_CAST; + } + } + if constexpr (real_part) { + *view_offset = 0; + } + else { + *view_offset = PyDataType_ELSIZE(loop_descrs[1]); + } + return NPY_NO_CASTING; +} + + +/* We shouldn't normally use it, but define a simple loop anyway. */ +template +static int extract_complex_part_loop( + PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) +{ + using real_type = typename T::value_type; + // Both complex_to_real_resolve_descriptors' view_offset (0 for real, + // sizeof(real_type) for imag) and this loop's reinterpret-based extraction + // assume a complex element is exactly two consecutive real_type halves with + // no padding. Guard that invariant at compile time per instantiated type. + static_assert(sizeof(T) == 2 * sizeof(real_type), + "complex element must be exactly two value_types, no padding"); + static_assert(std::is_standard_layout_v, + "complex element must be standard-layout so reinterpret-based " + "real/imag extraction is well-defined"); + static_assert(alignof(T) == alignof(real_type), + "complex alignment must match its half-type alignment"); + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp istride = strides[0]; + npy_intp ostride = strides[1]; + + if constexpr (!real_part) { + in += sizeof(real_type); + } + + while (N--) { + real_type value = *reinterpret_cast(in); + *reinterpret_cast(out) = value; + in += istride; + out += ostride; + } + return 0; +} + +template +int RegisterRealImag(PyArray_DTypeMeta* complex_dtype) { + using real_type = typename T::value_type; + Safe_PyObjectPtr real_descr = make_safe( + (PyObject*)PyArray_DescrFromType(TypeDescriptor::Dtype())); + if (!real_descr) { + return -1; + } + + PyType_Slot meth_slots[] = { + {NPY_METH_resolve_descriptors, (void*)&complex_to_real_resolve_descriptors}, + {NPY_METH_strided_loop, (void*)&extract_complex_part_loop}, + {0, nullptr}, + }; + PyArray_DTypeMeta* dtypes[2] = {complex_dtype, NPY_DTYPE(real_descr.get())}; + PyArrayMethod_Spec meth_spec; + meth_spec.name = "generic_real_imag_loop"; + meth_spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + meth_spec.nin = 1; + meth_spec.nout = 1; + meth_spec.dtypes = dtypes; + meth_spec.slots = meth_slots; + meth_spec.casting = NPY_NO_CASTING; + constexpr const char* ufunc_name = real_part ? "real" : "imag"; + PyUFunc_LoopSlot loop_slots[] = { + {ufunc_name, &meth_spec}, + {nullptr, nullptr}, + }; + return PyUFunc_AddLoopsFromSpecs(loop_slots); +} + + +template +int RegisterRealAndImag(PyArray_DTypeMeta* complex_dtype) { + // real()/imag() helpers landed in NumPy 2.5 (NPY_2_5_API_VERSION == 0x16). + // Gate registration on the runtime API version so older NumPy skips it. + if (PyArray_RUNTIME_VERSION < 0x16) { + return 0; + } + if (RegisterRealImag(complex_dtype) < 0) { + return -1; + } + return RegisterRealImag(complex_dtype); +} + + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -963,6 +1083,10 @@ bool RegisterComplexDtype(PyObject* numpy) { CustomComplexType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); + if (RegisterRealAndImag(NPY_DTYPE(CustomComplexType::npy_descr)) < 0) { + return false; + } + Safe_PyObjectPtr typeDict_obj = make_safe(PyObject_GetAttrString(numpy, "sctypeDict")); if (!typeDict_obj) return false; diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h index 8b55e4d9..e1f0ba7a 100644 --- a/ml_dtypes/_src/numpy.h +++ b/ml_dtypes/_src/numpy.h @@ -37,6 +37,28 @@ limitations under the License. #include "numpy/arrayobject.h" #include "numpy/arrayscalars.h" #include "numpy/ufuncobject.h" +#include "numpy/dtype_api.h" + + +#ifndef PyUFunc_AddLoopsFromSpecs +// Backport `PyUFunc_AddLoopsFromSpecs` for conditional use if we are +// on NumPy 2.5+ at runtime. (function available with 2.4, but not imag/real) +#if NPY_API_VERSION < 0x15 +typedef struct { + const char *name; + PyArrayMethod_Spec *spec; +} PyUFunc_LoopSlot; // defined starting NumPy 2.4+ +#endif + +static inline int PyUFunc_AddLoopsFromSpecs(PyUFunc_LoopSlot *loop_specs) { + if (PyArray_RUNTIME_VERSION < 0x15) { + return 0; // no-op as function is not available. + } + return (*(int (*)(PyUFunc_LoopSlot *))PyUFunc_API[47])(loop_specs); +} + +#endif + namespace ml_dtypes { diff --git a/ml_dtypes/tests/custom_complex_test.py b/ml_dtypes/tests/custom_complex_test.py index ddf76ff9..f97c93b0 100644 --- a/ml_dtypes/tests/custom_complex_test.py +++ b/ml_dtypes/tests/custom_complex_test.py @@ -129,6 +129,21 @@ def test_real_imag_arrays(sctype): np.testing.assert_array_equal(imag_part, [2.0, 4.0]) +@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES) +@pytest.mark.xfail( + np.lib.NumpyVersion(np.__version__) < "2.5.0.dev0", + reason="2.5 introduced real and imag helpers." +) +def test_real_imag_arrays_numpy25(sctype): + # Test ml_dtypes.real() and ml_dtypes.imag() helpers (NumPy 2.5+ path). + arr = np.array([1 + 2j, 3 + 4j], dtype=sctype) + real_part = ml_dtypes.real(arr) + imag_part = ml_dtypes.imag(arr) + expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one + assert real_part.dtype == imag_part.dtype == expected_dtype + np.testing.assert_array_equal(real_part, [1.0, 3.0]) + np.testing.assert_array_equal(imag_part, [2.0, 4.0]) + @pytest.mark.parametrize("sctype", COMPLEX_SCTYPES) @pytest.mark.parametrize("value", COMPLEX_VALUES) def test_str_repr(sctype, value): @@ -232,8 +247,10 @@ def test_cast_from_float(sctype, from_dtype): x = np.array([1.0, 2.0, 3.0], dtype=from_dtype) y = x.astype(sctype) assert y.dtype == sctype - np.testing.assert_array_equal(ml_dtypes.real(y).astype(np.float32), x) - np.testing.assert_array_equal(ml_dtypes.imag(y).astype(np.float32), 0.0) + real_part = ml_dtypes.real(y).astype(np.float32) + imag_part = ml_dtypes.imag(y).astype(np.float32) + np.testing.assert_array_equal(real_part, x) + np.testing.assert_array_equal(imag_part, 0.0) @pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)