Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,31 @@ jobs:
- name: Run tests
run: |
pytest -n auto
build-numpy25:
name: Python 3.12 with NumPy 2.5 (pre-release)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
submodules: true
persist-credentials: false
- name: Set up Python 3.12
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: "3.12"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install scikit-build-core cmake wheel
python -m pip install -U --pre "numpy==2.5.*" \
-i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple
python -c "import numpy; print(f'{numpy.__version__=}')"
- name: Build ml_dtypes
run: |
python -m pip install .[dev] --no-build-isolation
- name: Run tests
run: |
pytest -n auto
build-oldest-numpy:
name: Python 3.10 with oldest supported numpy
runs-on: ubuntu-latest
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Drop support for Python 3.9, which reached end-of-life in October 2025.
* Drop support for NumPy < 1.24.
* `arr.real` and `arr.imag` now return correct results for `bcomplex32` and
`complex32` arrays on NumPy 2.5+ ([#355](https://github.com/jax-ml/ml_dtypes/issues/355)).

## [0.5.4] - 2025-11-17

Expand Down
19 changes: 19 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
"uint1",
"uint2",
"uint4",
"real",
"imag",
]

from ml_dtypes._finfo import finfo
Expand Down Expand Up @@ -84,6 +86,23 @@
bcomplex32: type[_np.generic]
complex32: type[_np.generic]

# Augment the C++ extension's terse docstring with a clearer class summary.
bcomplex32.__doc__ = (
"complex<bfloat16>: a 4-byte complex number pairing two bfloat16\n"
"halves (real + imaginary), exposed as an ml_dtypes extension dtype.\n\n"
"WARNING: NumPy does not natively understand this custom complex dtype.\n"
"On NumPy <2.5, arr.real / arr.imag are SILENTLY wrong; use\n"
"ml_dtypes.real() / ml_dtypes.imag() instead (NumPy 2.5+ fixes them).\n\n"
"These complex-aware builtins also do NOT recognize this dtype on ANY\n"
"NumPy version -- cast to np.complex64 first, or use the workaround:\n"
" np.vdot(a,b) -> np.dot(np.conjugate(a), b)\n"
" np.linalg.norm(a) -> np.linalg.norm(a.astype(np.complex64))\n"
" np.iscomplex(a) -> ml_dtypes.imag(a) != 0\n"
" np.angle(a) -> np.arctan2(ml_dtypes.imag(a), ml_dtypes.real(a))\n"
" np.linalg.det/inv -> cast to np.complex64 first (else they raise)\n"
"np.abs, conjugate, arithmetic, reductions, np.dot/inner/outer, casts OK."
)


def real(x: _np.ndarray) -> _np.ndarray:
"""Return the real part of a complex array.
Expand Down
126 changes: 125 additions & 1 deletion ml_dtypes/_src/custom_complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ limitations under the License.
#include <limits> // NOLINT
#include <locale> // NOLINT
#include <memory> // NOLINT
#include <sstream> // NOLINT
#include <sstream> // NOLINT
#include <type_traits> // NOLINT (std::is_standard_layout_v)
#include <vector> // NOLINT
// Place `<locale>` before <Python.h> to avoid a build failure in macOS.
#include <Python.h>
Expand Down Expand Up @@ -904,6 +905,125 @@ bool RegisterComplexUFuncs(PyObject* numpy) {
return ok;
}


// Identical to the NumPy code, we could actually do without the loop.
// (Supports full ufunc path, although this is currently unexposed in NumPy.)
template <bool real_part>
static NPY_CASTING
complex_to_real_resolve_descriptors(
PyObject *NPY_UNUSED(self),
PyArray_DTypeMeta *const dtypes[2],
PyArray_Descr *const given_descrs[2],
PyArray_Descr *loop_descrs[2],
npy_intp *view_offset)
{
Py_INCREF(given_descrs[0]);
loop_descrs[0] = given_descrs[0];
Py_INCREF(dtypes[1]->singleton);
loop_descrs[1] = dtypes[1]->singleton;

if (PyDataType_ISBYTESWAPPED(loop_descrs[0])) {
Py_SETREF(
loop_descrs[1], PyArray_DescrNewByteorder(loop_descrs[1], NPY_SWAP));
if (loop_descrs[1] == NULL) {
Py_DECREF(loop_descrs[0]);
return _NPY_ERROR_OCCURRED_IN_CAST;
}
}
if constexpr (real_part) {
*view_offset = 0;
}
else {
*view_offset = PyDataType_ELSIZE(loop_descrs[1]);
}
return NPY_NO_CASTING;
}


/* We shouldn't normally use it, but define a simple loop anyway. */
template <typename T, bool real_part>
static int extract_complex_part_loop(
PyArrayMethod_Context *context, char *const data[],
npy_intp const dimensions[], npy_intp const strides[],
NpyAuxData *NPY_UNUSED(auxdata))
{
using real_type = typename T::value_type;
// Both complex_to_real_resolve_descriptors' view_offset (0 for real,
// sizeof(real_type) for imag) and this loop's reinterpret-based extraction
// assume a complex element is exactly two consecutive real_type halves with
// no padding. Guard that invariant at compile time per instantiated type.
static_assert(sizeof(T) == 2 * sizeof(real_type),
"complex element must be exactly two value_types, no padding");
static_assert(std::is_standard_layout_v<T>,
"complex element must be standard-layout so reinterpret-based "
"real/imag extraction is well-defined");
static_assert(alignof(T) == alignof(real_type),
"complex alignment must match its half-type alignment");
npy_intp N = dimensions[0];
char *in = data[0];
char *out = data[1];
npy_intp istride = strides[0];
npy_intp ostride = strides[1];

if constexpr (!real_part) {
in += sizeof(real_type);
}

while (N--) {
real_type value = *reinterpret_cast<real_type *>(in);
*reinterpret_cast<real_type *>(out) = value;
in += istride;
out += ostride;
}
return 0;
}

template <typename T, bool real_part>
int RegisterRealImag(PyArray_DTypeMeta* complex_dtype) {
using real_type = typename T::value_type;
Safe_PyObjectPtr real_descr = make_safe(
(PyObject*)PyArray_DescrFromType(TypeDescriptor<real_type>::Dtype()));
if (!real_descr) {
return -1;
}

PyType_Slot meth_slots[] = {
{NPY_METH_resolve_descriptors, (void*)&complex_to_real_resolve_descriptors<real_part>},
{NPY_METH_strided_loop, (void*)&extract_complex_part_loop<T, real_part>},
{0, nullptr},
};
PyArray_DTypeMeta* dtypes[2] = {complex_dtype, NPY_DTYPE(real_descr.get())};
PyArrayMethod_Spec meth_spec;
meth_spec.name = "generic_real_imag_loop";
meth_spec.flags = NPY_METH_NO_FLOATINGPOINT_ERRORS;
meth_spec.nin = 1;
meth_spec.nout = 1;
meth_spec.dtypes = dtypes;
meth_spec.slots = meth_slots;
meth_spec.casting = NPY_NO_CASTING;
constexpr const char* ufunc_name = real_part ? "real" : "imag";
PyUFunc_LoopSlot loop_slots[] = {
{ufunc_name, &meth_spec},
{nullptr, nullptr},
};
return PyUFunc_AddLoopsFromSpecs(loop_slots);
}


template <typename T>
int RegisterRealAndImag(PyArray_DTypeMeta* complex_dtype) {
// real()/imag() helpers landed in NumPy 2.5 (NPY_2_5_API_VERSION == 0x16).
// Gate registration on the runtime API version so older NumPy skips it.
if (PyArray_RUNTIME_VERSION < 0x16) {
return 0;
}
if (RegisterRealImag<T, true>(complex_dtype) < 0) {
return -1;
}
return RegisterRealImag<T, false>(complex_dtype);
}


template <typename T>
bool RegisterComplexDtype(PyObject* numpy) {
// bases must be a tuple for Python 3.9 and earlier. Change to just pass
Expand Down Expand Up @@ -963,6 +1083,10 @@ bool RegisterComplexDtype(PyObject* numpy) {
CustomComplexType<T>::npy_descr =
PyArray_DescrFromType(TypeDescriptor<T>::npy_type);

if (RegisterRealAndImag<T>(NPY_DTYPE(CustomComplexType<T>::npy_descr)) < 0) {
return false;
}

Safe_PyObjectPtr typeDict_obj =
make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
if (!typeDict_obj) return false;
Expand Down
22 changes: 22 additions & 0 deletions ml_dtypes/_src/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@ limitations under the License.
#include "numpy/arrayobject.h"
#include "numpy/arrayscalars.h"
#include "numpy/ufuncobject.h"
#include "numpy/dtype_api.h"


#ifndef PyUFunc_AddLoopsFromSpecs
// Backport `PyUFunc_AddLoopsFromSpecs` for conditional use if we are
// on NumPy 2.5+ at runtime. (function available with 2.4, but not imag/real)
#if NPY_API_VERSION < 0x15
typedef struct {
const char *name;
PyArrayMethod_Spec *spec;
} PyUFunc_LoopSlot; // defined starting NumPy 2.4+
#endif

static inline int PyUFunc_AddLoopsFromSpecs(PyUFunc_LoopSlot *loop_specs) {
if (PyArray_RUNTIME_VERSION < 0x15) {
return 0; // no-op as function is not available.
}
return (*(int (*)(PyUFunc_LoopSlot *))PyUFunc_API[47])(loop_specs);
}

#endif


namespace ml_dtypes {

Expand Down
21 changes: 19 additions & 2 deletions ml_dtypes/tests/custom_complex_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,21 @@ def test_real_imag_arrays(sctype):
np.testing.assert_array_equal(imag_part, [2.0, 4.0])


@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
@pytest.mark.xfail(
np.lib.NumpyVersion(np.__version__) < "2.5.0.dev0",
reason="2.5 introduced real and imag helpers."
)
def test_real_imag_arrays_numpy25(sctype):
# Test ml_dtypes.real() and ml_dtypes.imag() helpers (NumPy 2.5+ path).
arr = np.array([1 + 2j, 3 + 4j], dtype=sctype)
real_part = ml_dtypes.real(arr)
imag_part = ml_dtypes.imag(arr)
expected_dtype = ml_dtypes.finfo(sctype).dtype # the real one
assert real_part.dtype == imag_part.dtype == expected_dtype
np.testing.assert_array_equal(real_part, [1.0, 3.0])
np.testing.assert_array_equal(imag_part, [2.0, 4.0])

@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
@pytest.mark.parametrize("value", COMPLEX_VALUES)
def test_str_repr(sctype, value):
Expand Down Expand Up @@ -232,8 +247,10 @@ def test_cast_from_float(sctype, from_dtype):
x = np.array([1.0, 2.0, 3.0], dtype=from_dtype)
y = x.astype(sctype)
assert y.dtype == sctype
np.testing.assert_array_equal(ml_dtypes.real(y).astype(np.float32), x)
np.testing.assert_array_equal(ml_dtypes.imag(y).astype(np.float32), 0.0)
real_part = ml_dtypes.real(y).astype(np.float32)
imag_part = ml_dtypes.imag(y).astype(np.float32)
np.testing.assert_array_equal(real_part, x)
np.testing.assert_array_equal(imag_part, 0.0)


@pytest.mark.parametrize("sctype", COMPLEX_SCTYPES)
Expand Down
Loading