From a74988318c274349d755d89d91e886d9199528a2 Mon Sep 17 00:00:00 2001 From: DeanTMaxim Date: Sun, 14 Jun 2026 17:57:21 +0800 Subject: [PATCH 1/2] Complete bcomplex32 real/imag on NumPy 2.5+ and warn on NumPy <2.5 Takes over #375 per issue #355. Registers real/imag as ufunc loops on the np.real/np.imag ufuncs via PyUFunc_AddLoopsFromSpecs (the NumPy 2.4+ convenience API), gated on PyArray_RUNTIME_VERSION >= 0x16 (NPY_2_5_API_VERSION) so NumPy <2.5 skips registration. On NumPy 2.5+ this makes arr.real / arr.imag return correct results for bcomplex32 and complex32, which NumPy otherwise cannot recognize as complex (their dtype kind is 'W', not 'c'). The registered loop returns NPY_NO_CASTING with a view_offset (0 for the real half, elsize for the imaginary half), so the result is a zero-copy view into the input buffer, with a reinterpret-based strided loop as the fallback. static_asserts guard the [real, imag] no-padding layout that the view/extraction relies on. PyUFunc_AddLoopsFromSpecs is backported in numpy.h (slot 47) so a single wheel runs across NumPy versions. On NumPy <2.5 the native arr.real/arr.imag are silently wrong, so the ml_dtypes.real/imag Python helpers (correct on all versions) emit a RuntimeWarning steering users toward the helpers or an upgrade. Builds as C++17: the PyArrayMethod_Spec is constructed by field assignment rather than designated initializers, so no build-standard change is needed. Adds a NumPy 2.5 pre-release job to the CI matrix. Several other complex-aware builtins (np.iscomplex, np.vdot, np.linalg.norm, np.angle, np.linalg.det/inv) do not recognize these custom complex dtypes on any NumPy version, because NumPy keys complex-ness off the builtin type-number range rather than the dtype kind; bcomplex32.__doc__ documents these with one-line workarounds. Tested locally against NumPy 2.5.0rc1. --- .github/workflows/test.yml | 25 +++++ CHANGELOG.md | 4 + ml_dtypes/__init__.py | 52 ++++++++++ ml_dtypes/_src/custom_complex.h | 126 ++++++++++++++++++++++++- ml_dtypes/_src/numpy.h | 22 +++++ ml_dtypes/tests/custom_complex_test.py | 83 +++++++++++++--- 6 files changed, 297 insertions(+), 15 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 45f27d31..43b2bb19 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -104,6 +104,31 @@ jobs: - name: Run tests run: | pytest -n auto + build-numpy25: + name: Python 3.12 with NumPy 2.5 (pre-release) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + submodules: true + persist-credentials: false + - name: Set up Python 3.12 + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + with: + python-version: "3.12" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install scikit-build-core cmake wheel + python -m pip install -U --pre "numpy==2.5.*" \ + -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple + python -c "import numpy; print(f'{numpy.__version__=}')" + - name: Build ml_dtypes + run: | + python -m pip install .[dev] --no-build-isolation + - name: Run tests + run: | + pytest -n auto build-oldest-numpy: name: Python 3.10 with oldest supported numpy runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index fa2488f8..b38ccf27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,10 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Drop support for Python 3.9, which reached end-of-life in October 2025. * Drop support for NumPy < 1.24. +* `arr.real` and `arr.imag` now return correct results for `bcomplex32` and + `complex32` arrays on NumPy 2.5+ ([#355](https://github.com/jax-ml/ml_dtypes/issues/355)). +* `ml_dtypes.real` and `ml_dtypes.imag` now emit a `RuntimeWarning` on + NumPy <2.5, where `arr.real`/`arr.imag` are silently incorrect. ## [0.5.4] - 2025-11-17 diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 125d2847..cb4b3c16 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -37,8 +37,12 @@ "uint1", "uint2", "uint4", + "real", + "imag", ] +import warnings as _warnings + from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo from ml_dtypes._ml_dtypes_ext import bcomplex32 @@ -84,6 +88,46 @@ bcomplex32: type[_np.generic] complex32: type[_np.generic] +# Augment the C++ extension's terse docstring with a clearer class summary. +bcomplex32.__doc__ = ( + "complex: a 4-byte complex number pairing two bfloat16\n" + "halves (real + imaginary), exposed as an ml_dtypes extension dtype.\n\n" + "WARNING: NumPy does not natively understand this custom complex dtype.\n" + "On NumPy <2.5, arr.real / arr.imag are SILENTLY wrong; use\n" + "ml_dtypes.real() / ml_dtypes.imag() instead (NumPy 2.5+ fixes them).\n\n" + "These complex-aware builtins also do NOT recognize this dtype on ANY\n" + "NumPy version -- cast to np.complex64 first, or use the workaround:\n" + " np.vdot(a,b) -> np.dot(np.conjugate(a), b)\n" + " np.linalg.norm(a) -> np.linalg.norm(a.astype(np.complex64))\n" + " np.iscomplex(a) -> ml_dtypes.imag(a) != 0\n" + " np.angle(a) -> np.arctan2(ml_dtypes.imag(a), ml_dtypes.real(a))\n" + " np.linalg.det/inv -> cast to np.complex64 first (else they raise)\n" + "np.abs, conjugate, arithmetic, reductions, np.dot/inner/outer, casts OK." +) + + +def _warn_old_numpy(fn_name: str) -> None: + """Emit a RuntimeWarning on NumPy <2.5. + + On NumPy <2.5, arr.real / arr.imag return silently incorrect results for + ml_dtypes complex arrays (bcomplex32, complex32). This helper itself is + correct; the warning steers users away from arr.real/arr.imag and toward + upgrading to NumPy 2.5+. + + Args: + fn_name: The public function name (``"real"`` or ``"imag"``) to include + in the warning message so callers can identify which helper fired it. + """ + if _np.lib.NumpyVersion(_np.__version__) < "2.5.0.dev0": + _warnings.warn( + f"NumPy <2.5 miscomputes arr.real/arr.imag for ml_dtypes complex " + f"arrays; this ml_dtypes.{fn_name}() call is correct, but prefer " + "upgrading to NumPy 2.5+.", + RuntimeWarning, + # 1 = _warn_old_numpy, 2 = real/imag, 3 = user's call site + stacklevel=3, + ) + def real(x: _np.ndarray) -> _np.ndarray: """Return the real part of a complex array. @@ -92,12 +136,16 @@ def real(x: _np.ndarray) -> _np.ndarray: bcomplex32 or complex32. NumPy cannot correctly understand that these are complex dtypes as of NumPy 2.4 at least. + On NumPy <2.5, a ``RuntimeWarning`` is emitted because the legacy path + may produce incorrect results for ml_dtypes custom complex types. + Args: x: The input array. Returns: The real part of the input array. """ + _warn_old_numpy("real") if isinstance(x, _np.ndarray): # Use a view. We add an axes to ensure it is contiguous. if x.dtype.type is bcomplex32: @@ -116,12 +164,16 @@ def imag(x: _np.ndarray) -> _np.ndarray: bcomplex32 or complex32. NumPy cannot correctly understand that these are complex dtypes as of NumPy 2.4 at least. + On NumPy <2.5, a ``RuntimeWarning`` is emitted because the legacy path + may produce incorrect results for ml_dtypes custom complex types. + Args: x: The input array. Returns: The imaginary part of the input array. """ + _warn_old_numpy("imag") if isinstance(x, _np.ndarray): # Use a view. We add an axes to ensure it is contiguous. if x.dtype.type is bcomplex32: diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index c640aee7..94e569be 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -29,7 +29,8 @@ limitations under the License. #include // NOLINT #include // NOLINT #include // NOLINT -#include // NOLINT +#include // NOLINT +#include // NOLINT (std::is_standard_layout_v) #include // NOLINT // Place `` before to avoid a build failure in macOS. #include @@ -904,6 +905,125 @@ bool RegisterComplexUFuncs(PyObject* numpy) { return ok; } + +// Identical to the NumPy code, we could actually do without the loop. +// (Supports full ufunc path, although this is currently unexposed in NumPy.) +template +static NPY_CASTING +complex_to_real_resolve_descriptors( + PyObject *NPY_UNUSED(self), + PyArray_DTypeMeta *const dtypes[2], + PyArray_Descr *const given_descrs[2], + PyArray_Descr *loop_descrs[2], + npy_intp *view_offset) +{ + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + Py_INCREF(dtypes[1]->singleton); + loop_descrs[1] = dtypes[1]->singleton; + + if (PyDataType_ISBYTESWAPPED(loop_descrs[0])) { + Py_SETREF( + loop_descrs[1], PyArray_DescrNewByteorder(loop_descrs[1], NPY_SWAP)); + if (loop_descrs[1] == NULL) { + Py_DECREF(loop_descrs[0]); + return _NPY_ERROR_OCCURRED_IN_CAST; + } + } + if constexpr (real_part) { + *view_offset = 0; + } + else { + *view_offset = PyDataType_ELSIZE(loop_descrs[1]); + } + return NPY_NO_CASTING; +} + + +/* We shouldn't normally use it, but define a simple loop anyway. */ +template +static int extract_complex_part_loop( + PyArrayMethod_Context *context, char *const data[], + npy_intp const dimensions[], npy_intp const strides[], + NpyAuxData *NPY_UNUSED(auxdata)) +{ + using real_type = typename T::value_type; + // Both complex_to_real_resolve_descriptors' view_offset (0 for real, + // sizeof(real_type) for imag) and this loop's reinterpret-based extraction + // assume a complex element is exactly two consecutive real_type halves with + // no padding. Guard that invariant at compile time per instantiated type. + static_assert(sizeof(T) == 2 * sizeof(real_type), + "complex element must be exactly two value_types, no padding"); + static_assert(std::is_standard_layout_v, + "complex element must be standard-layout so reinterpret-based " + "real/imag extraction is well-defined"); + static_assert(alignof(T) == alignof(real_type), + "complex alignment must match its half-type alignment"); + npy_intp N = dimensions[0]; + char *in = data[0]; + char *out = data[1]; + npy_intp istride = strides[0]; + npy_intp ostride = strides[1]; + + if constexpr (!real_part) { + in += sizeof(real_type); + } + + while (N--) { + real_type value = *reinterpret_cast(in); + *reinterpret_cast(out) = value; + in += istride; + out += ostride; + } + return 0; +} + +template +int RegisterRealImag(PyArray_DTypeMeta* complex_dtype) { + using real_type = typename T::value_type; + Safe_PyObjectPtr real_descr = make_safe( + (PyObject*)PyArray_DescrFromType(TypeDescriptor::Dtype())); + if (!real_descr) { + return -1; + } + + PyType_Slot meth_slots[] = { + {NPY_METH_resolve_descriptors, (void*)&complex_to_real_resolve_descriptors}, + {NPY_METH_strided_loop, (void*)&extract_complex_part_loop}, + {0, nullptr}, + }; + PyArray_DTypeMeta* dtypes[2] = {complex_dtype, NPY_DTYPE(real_descr.get())}; + PyArrayMethod_Spec meth_spec; + meth_spec.name = "generic_real_imag_loop"; + meth_spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + meth_spec.nin = 1; + meth_spec.nout = 1; + meth_spec.dtypes = dtypes; + meth_spec.slots = meth_slots; + meth_spec.casting = NPY_NO_CASTING; + constexpr const char* ufunc_name = real_part ? "real" : "imag"; + PyUFunc_LoopSlot loop_slots[] = { + {ufunc_name, &meth_spec}, + {nullptr, nullptr}, + }; + return PyUFunc_AddLoopsFromSpecs(loop_slots); +} + + +template +int RegisterRealAndImag(PyArray_DTypeMeta* complex_dtype) { + // real()/imag() helpers landed in NumPy 2.5 (NPY_2_5_API_VERSION == 0x16). + // Gate registration on the runtime API version so older NumPy skips it. + if (PyArray_RUNTIME_VERSION < 0x16) { + return 0; + } + if (RegisterRealImag(complex_dtype) < 0) { + return -1; + } + return RegisterRealImag(complex_dtype); +} + + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -963,6 +1083,10 @@ bool RegisterComplexDtype(PyObject* numpy) { CustomComplexType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); + if (RegisterRealAndImag(NPY_DTYPE(CustomComplexType::npy_descr)) < 0) { + return false; + } + Safe_PyObjectPtr typeDict_obj = make_safe(PyObject_GetAttrString(numpy, "sctypeDict")); if (!typeDict_obj) return false; diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h index 8b55e4d9..e1f0ba7a 100644 --- a/ml_dtypes/_src/numpy.h +++ b/ml_dtypes/_src/numpy.h @@ -37,6 +37,28 @@ limitations under the License. #include "numpy/arrayobject.h" #include "numpy/arrayscalars.h" #include "numpy/ufuncobject.h" +#include "numpy/dtype_api.h" + + +#ifndef PyUFunc_AddLoopsFromSpecs +// Backport `PyUFunc_AddLoopsFromSpecs` for conditional use if we are +// on NumPy 2.5+ at runtime. (function available with 2.4, but not imag/real) +#if NPY_API_VERSION < 0x15 +typedef struct { + const char *name; + PyArrayMethod_Spec *spec; +} PyUFunc_LoopSlot; // defined starting NumPy 2.4+ +#endif + +static inline int PyUFunc_AddLoopsFromSpecs(PyUFunc_LoopSlot *loop_specs) { + if (PyArray_RUNTIME_VERSION < 0x15) { + return 0; // no-op as function is not available. + } + return (*(int (*)(PyUFunc_LoopSlot *))PyUFunc_API[47])(loop_specs); +} + +#endif + namespace ml_dtypes { diff --git a/ml_dtypes/tests/custom_complex_test.py b/ml_dtypes/tests/custom_complex_test.py index ddf76ff9..26ae20bb 100644 --- a/ml_dtypes/tests/custom_complex_test.py +++ b/ml_dtypes/tests/custom_complex_test.py @@ -17,6 +17,7 @@ import operator import pickle import sys +import warnings import ml_dtypes import numpy as np @@ -121,14 +122,36 @@ def test_real_imag_scalars(sctype): def test_real_imag_arrays(sctype): # Test ml_dtypes.real() and ml_dtypes.imag() helpers. arr = np.array([1 + 2j, 3 + 4j], dtype=sctype) - real_part = ml_dtypes.real(arr) - imag_part = ml_dtypes.imag(arr) + # Suppress the NumPy <2.5 RuntimeWarning so pytest.ini's + # filterwarnings=error doesn't promote it to an exception. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + real_part = ml_dtypes.real(arr) + imag_part = ml_dtypes.imag(arr) expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one assert real_part.dtype == imag_part.dtype == expected_dtype np.testing.assert_array_equal(real_part, [1.0, 3.0]) np.testing.assert_array_equal(imag_part, [2.0, 4.0]) +@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES) +@pytest.mark.xfail( + np.lib.NumpyVersion(np.__version__) < "2.5.0.dev0", + reason="2.5 introduced real and imag helpers." +) +def test_real_imag_arrays_numpy25(sctype): + # Test ml_dtypes.real() and ml_dtypes.imag() helpers (NumPy 2.5+ path). + arr = np.array([1 + 2j, 3 + 4j], dtype=sctype) + # Suppress the NumPy <2.5 RuntimeWarning (test runs on 2.5+ but be safe). + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + real_part = ml_dtypes.real(arr) + imag_part = ml_dtypes.imag(arr) + expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one + assert real_part.dtype == imag_part.dtype == expected_dtype + np.testing.assert_array_equal(real_part, [1.0, 3.0]) + np.testing.assert_array_equal(imag_part, [2.0, 4.0]) + @pytest.mark.parametrize("sctype", COMPLEX_SCTYPES) @pytest.mark.parametrize("value", COMPLEX_VALUES) def test_str_repr(sctype, value): @@ -232,8 +255,13 @@ def test_cast_from_float(sctype, from_dtype): x = np.array([1.0, 2.0, 3.0], dtype=from_dtype) y = x.astype(sctype) assert y.dtype == sctype - np.testing.assert_array_equal(ml_dtypes.real(y).astype(np.float32), x) - np.testing.assert_array_equal(ml_dtypes.imag(y).astype(np.float32), 0.0) + # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real/imag. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + real_part = ml_dtypes.real(y).astype(np.float32) + imag_part = ml_dtypes.imag(y).astype(np.float32) + np.testing.assert_array_equal(real_part, x) + np.testing.assert_array_equal(imag_part, 0.0) @pytest.mark.parametrize("sctype", COMPLEX_SCTYPES) @@ -381,15 +409,18 @@ def test_unary_ufuncs(sctype, ufunc): if sys.platform == "win32": mismatch = np.zeros(len(x), dtype=bool) - if ufunc == np.cos: - # cos(1+infj) returns inf+infj instead of inf-infj - mismatch = (ml_dtypes.real(x) == 1.0) & (ml_dtypes.imag(x) == np.inf) - elif ufunc == np.sinh: - # sinh(+/-inf+0j) returns +/-inf+infj instead of +/-inf+0j - mismatch = np.isinf(ml_dtypes.real(x)) & (ml_dtypes.imag(x) == 0.0) - elif ufunc == np.cosh: - # cosh(-inf+0j) signs 0j wrong - mismatch = (ml_dtypes.real(x) == -np.inf) & (ml_dtypes.imag(x) == 0.0) + # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real/imag. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + if ufunc == np.cos: + # cos(1+infj) returns inf+infj instead of inf-infj + mismatch = (ml_dtypes.real(x) == 1.0) & (ml_dtypes.imag(x) == np.inf) + elif ufunc == np.sinh: + # sinh(+/-inf+0j) returns +/-inf+infj instead of +/-inf+0j + mismatch = np.isinf(ml_dtypes.real(x)) & (ml_dtypes.imag(x) == 0.0) + elif ufunc == np.cosh: + # cosh(-inf+0j) signs 0j wrong + mismatch = (ml_dtypes.real(x) == -np.inf) & (ml_dtypes.imag(x) == 0.0) expected = expected[~mismatch] result = result[~mismatch] @@ -442,7 +473,10 @@ def test_binary_ufuncs(sctype, ufunc): if ufunc == np.power: # TODO(seberg): std::power deals poorly with some values, drop for now. - x = x[(ml_dtypes.real(x) != 0) & np.isfinite(x)] + # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) + x = x[(ml_dtypes.real(x) != 0) & np.isfinite(x)] y = x[:, np.newaxis] @@ -472,3 +506,24 @@ def test_dot_product(sctype): result = np.dot(x, y) expected = np.dot(x.astype(np.complex64), y.astype(np.complex64)) np.testing.assert_allclose(complex(result), complex(expected), rtol=1e-2) + + +@pytest.mark.parametrize( + "fn_name,fn", [("real", ml_dtypes.real), ("imag", ml_dtypes.imag)] +) +def test_real_imag_warning_on_old_numpy(fn_name, fn, monkeypatch): + """On NumPy <2.5, ml_dtypes.real/imag must emit RuntimeWarning naming the function.""" + # Force the version check to think we're on old NumPy + monkeypatch.setattr(np, "__version__", "2.4.0") + arr = np.array([1 + 2j, 3 + 4j], dtype=ml_dtypes.bcomplex32) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + fn(arr) + assert len(w) == 1 + assert issubclass(w[0].category, RuntimeWarning) + msg = str(w[0].message) + assert f"ml_dtypes.{fn_name}" in msg, ( + f"warning should name ml_dtypes.{fn_name}, got: {msg}" + ) + assert "miscomputes" in msg + assert "upgrading to NumPy" in msg From 1223c258f136467ac400bf6a58478937fcbd5da1 Mon Sep 17 00:00:00 2001 From: DeanTMaxim Date: Sun, 14 Jun 2026 18:56:08 +0800 Subject: [PATCH 2/2] Drop the NumPy <2.5 helper warning per review The warning fired inside ml_dtypes.real/imag, which are already correct on all NumPy versions, so it only reached users of the safe helpers -- not the arr.real/arr.imag users who hit the actual silent-wrong on <2.5. Drop it and rely on the docstring; a module-level __getattr__ guard would be the correct placement but needs the types exposed lazily to actually fire. Removes _warn_old_numpy, its calls, the warning test, the now-unneeded RuntimeWarning suppressions in the other tests, and the CHANGELOG entry. --- CHANGELOG.md | 2 - ml_dtypes/__init__.py | 33 ------------ ml_dtypes/tests/custom_complex_test.py | 70 ++++++-------------------- 3 files changed, 16 insertions(+), 89 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b38ccf27..ac90b30a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,8 +27,6 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Drop support for NumPy < 1.24. * `arr.real` and `arr.imag` now return correct results for `bcomplex32` and `complex32` arrays on NumPy 2.5+ ([#355](https://github.com/jax-ml/ml_dtypes/issues/355)). -* `ml_dtypes.real` and `ml_dtypes.imag` now emit a `RuntimeWarning` on - NumPy <2.5, where `arr.real`/`arr.imag` are silently incorrect. ## [0.5.4] - 2025-11-17 diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index cb4b3c16..968c4bac 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -41,8 +41,6 @@ "imag", ] -import warnings as _warnings - from ml_dtypes._finfo import finfo from ml_dtypes._iinfo import iinfo from ml_dtypes._ml_dtypes_ext import bcomplex32 @@ -106,29 +104,6 @@ ) -def _warn_old_numpy(fn_name: str) -> None: - """Emit a RuntimeWarning on NumPy <2.5. - - On NumPy <2.5, arr.real / arr.imag return silently incorrect results for - ml_dtypes complex arrays (bcomplex32, complex32). This helper itself is - correct; the warning steers users away from arr.real/arr.imag and toward - upgrading to NumPy 2.5+. - - Args: - fn_name: The public function name (``"real"`` or ``"imag"``) to include - in the warning message so callers can identify which helper fired it. - """ - if _np.lib.NumpyVersion(_np.__version__) < "2.5.0.dev0": - _warnings.warn( - f"NumPy <2.5 miscomputes arr.real/arr.imag for ml_dtypes complex " - f"arrays; this ml_dtypes.{fn_name}() call is correct, but prefer " - "upgrading to NumPy 2.5+.", - RuntimeWarning, - # 1 = _warn_old_numpy, 2 = real/imag, 3 = user's call site - stacklevel=3, - ) - - def real(x: _np.ndarray) -> _np.ndarray: """Return the real part of a complex array. @@ -136,16 +111,12 @@ def real(x: _np.ndarray) -> _np.ndarray: bcomplex32 or complex32. NumPy cannot correctly understand that these are complex dtypes as of NumPy 2.4 at least. - On NumPy <2.5, a ``RuntimeWarning`` is emitted because the legacy path - may produce incorrect results for ml_dtypes custom complex types. - Args: x: The input array. Returns: The real part of the input array. """ - _warn_old_numpy("real") if isinstance(x, _np.ndarray): # Use a view. We add an axes to ensure it is contiguous. if x.dtype.type is bcomplex32: @@ -164,16 +135,12 @@ def imag(x: _np.ndarray) -> _np.ndarray: bcomplex32 or complex32. NumPy cannot correctly understand that these are complex dtypes as of NumPy 2.4 at least. - On NumPy <2.5, a ``RuntimeWarning`` is emitted because the legacy path - may produce incorrect results for ml_dtypes custom complex types. - Args: x: The input array. Returns: The imaginary part of the input array. """ - _warn_old_numpy("imag") if isinstance(x, _np.ndarray): # Use a view. We add an axes to ensure it is contiguous. if x.dtype.type is bcomplex32: diff --git a/ml_dtypes/tests/custom_complex_test.py b/ml_dtypes/tests/custom_complex_test.py index 26ae20bb..f97c93b0 100644 --- a/ml_dtypes/tests/custom_complex_test.py +++ b/ml_dtypes/tests/custom_complex_test.py @@ -17,7 +17,6 @@ import operator import pickle import sys -import warnings import ml_dtypes import numpy as np @@ -122,12 +121,8 @@ def test_real_imag_scalars(sctype): def test_real_imag_arrays(sctype): # Test ml_dtypes.real() and ml_dtypes.imag() helpers. arr = np.array([1 + 2j, 3 + 4j], dtype=sctype) - # Suppress the NumPy <2.5 RuntimeWarning so pytest.ini's - # filterwarnings=error doesn't promote it to an exception. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - real_part = ml_dtypes.real(arr) - imag_part = ml_dtypes.imag(arr) + real_part = ml_dtypes.real(arr) + imag_part = ml_dtypes.imag(arr) expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one assert real_part.dtype == imag_part.dtype == expected_dtype np.testing.assert_array_equal(real_part, [1.0, 3.0]) @@ -142,11 +137,8 @@ def test_real_imag_arrays(sctype): def test_real_imag_arrays_numpy25(sctype): # Test ml_dtypes.real() and ml_dtypes.imag() helpers (NumPy 2.5+ path). arr = np.array([1 + 2j, 3 + 4j], dtype=sctype) - # Suppress the NumPy <2.5 RuntimeWarning (test runs on 2.5+ but be safe). - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - real_part = ml_dtypes.real(arr) - imag_part = ml_dtypes.imag(arr) + real_part = ml_dtypes.real(arr) + imag_part = ml_dtypes.imag(arr) expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one assert real_part.dtype == imag_part.dtype == expected_dtype np.testing.assert_array_equal(real_part, [1.0, 3.0]) @@ -255,11 +247,8 @@ def test_cast_from_float(sctype, from_dtype): x = np.array([1.0, 2.0, 3.0], dtype=from_dtype) y = x.astype(sctype) assert y.dtype == sctype - # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real/imag. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - real_part = ml_dtypes.real(y).astype(np.float32) - imag_part = ml_dtypes.imag(y).astype(np.float32) + real_part = ml_dtypes.real(y).astype(np.float32) + imag_part = ml_dtypes.imag(y).astype(np.float32) np.testing.assert_array_equal(real_part, x) np.testing.assert_array_equal(imag_part, 0.0) @@ -409,18 +398,15 @@ def test_unary_ufuncs(sctype, ufunc): if sys.platform == "win32": mismatch = np.zeros(len(x), dtype=bool) - # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real/imag. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - if ufunc == np.cos: - # cos(1+infj) returns inf+infj instead of inf-infj - mismatch = (ml_dtypes.real(x) == 1.0) & (ml_dtypes.imag(x) == np.inf) - elif ufunc == np.sinh: - # sinh(+/-inf+0j) returns +/-inf+infj instead of +/-inf+0j - mismatch = np.isinf(ml_dtypes.real(x)) & (ml_dtypes.imag(x) == 0.0) - elif ufunc == np.cosh: - # cosh(-inf+0j) signs 0j wrong - mismatch = (ml_dtypes.real(x) == -np.inf) & (ml_dtypes.imag(x) == 0.0) + if ufunc == np.cos: + # cos(1+infj) returns inf+infj instead of inf-infj + mismatch = (ml_dtypes.real(x) == 1.0) & (ml_dtypes.imag(x) == np.inf) + elif ufunc == np.sinh: + # sinh(+/-inf+0j) returns +/-inf+infj instead of +/-inf+0j + mismatch = np.isinf(ml_dtypes.real(x)) & (ml_dtypes.imag(x) == 0.0) + elif ufunc == np.cosh: + # cosh(-inf+0j) signs 0j wrong + mismatch = (ml_dtypes.real(x) == -np.inf) & (ml_dtypes.imag(x) == 0.0) expected = expected[~mismatch] result = result[~mismatch] @@ -473,10 +459,7 @@ def test_binary_ufuncs(sctype, ufunc): if ufunc == np.power: # TODO(seberg): std::power deals poorly with some values, drop for now. - # Suppress the NumPy <2.5 RuntimeWarning from ml_dtypes.real. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RuntimeWarning) - x = x[(ml_dtypes.real(x) != 0) & np.isfinite(x)] + x = x[(ml_dtypes.real(x) != 0) & np.isfinite(x)] y = x[:, np.newaxis] @@ -506,24 +489,3 @@ def test_dot_product(sctype): result = np.dot(x, y) expected = np.dot(x.astype(np.complex64), y.astype(np.complex64)) np.testing.assert_allclose(complex(result), complex(expected), rtol=1e-2) - - -@pytest.mark.parametrize( - "fn_name,fn", [("real", ml_dtypes.real), ("imag", ml_dtypes.imag)] -) -def test_real_imag_warning_on_old_numpy(fn_name, fn, monkeypatch): - """On NumPy <2.5, ml_dtypes.real/imag must emit RuntimeWarning naming the function.""" - # Force the version check to think we're on old NumPy - monkeypatch.setattr(np, "__version__", "2.4.0") - arr = np.array([1 + 2j, 3 + 4j], dtype=ml_dtypes.bcomplex32) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - fn(arr) - assert len(w) == 1 - assert issubclass(w[0].category, RuntimeWarning) - msg = str(w[0].message) - assert f"ml_dtypes.{fn_name}" in msg, ( - f"warning should name ml_dtypes.{fn_name}, got: {msg}" - ) - assert "miscomputes" in msg - assert "upgrading to NumPy" in msg