diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 45f27d31..dbffa929 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -123,7 +123,7 @@ jobs: - name: Build ml_dtypes run: | python -m pip install .[dev] - python -m pip install numpy==1.24.0 # keep in sync with oldest numpy version in pyproject.toml + python -m pip install numpy==2.0.0 # keep in sync with oldest numpy version in pyproject.toml - name: Run tests run: | pytest -n auto diff --git a/ml_dtypes/_src/custom_complex.h b/ml_dtypes/_src/custom_complex.h index c640aee7..7d25997b 100644 --- a/ml_dtypes/_src/custom_complex.h +++ b/ml_dtypes/_src/custom_complex.h @@ -35,8 +35,9 @@ limitations under the License. #include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_common.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/complex_types.h" #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build @@ -69,6 +70,9 @@ struct CustomComplexType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -79,6 +83,14 @@ template PyArray_DescrProto CustomComplexType::npy_descr_proto; template PyArray_Descr* CustomComplexType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomComplexType::dtype_meta = {}; + +// True if `meta` is one of our custom complex DTypes. +inline bool IsCustomComplexDType(const PyArray_DTypeMeta* meta) { + return meta == &CustomComplexType::dtype_meta || + meta == &CustomComplexType::dtype_meta; +} // Representation of a Python custom float object. 
template @@ -904,6 +916,96 @@ bool RegisterComplexUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomComplex types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyCustomComplex_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomComplex_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyCustomComplex_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomComplex_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyCustomComplex_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyCustomComplex_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +template +static PyArray_DTypeMeta* NPyCustomComplex_CommonDType( + PyArray_DTypeMeta* cls, PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType || + other == &PyArray_PyComplexDType) { + Py_INCREF(cls); + return cls; + } + + switch (other->type_num) { + // bool, ints, half, float: wrap in the smallest complex that holds both. + // Our custom complex types all fit in cfloat. 
+ case NPY_BOOL: + case NPY_BYTE: case NPY_SHORT: case NPY_INT: + case NPY_LONG: case NPY_LONGLONG: + case NPY_UBYTE: case NPY_USHORT: case NPY_UINT: + case NPY_ULONG: case NPY_ULONGLONG: + case NPY_HALF: case NPY_FLOAT: + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + case NPY_DOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CDoubleDType)); + return &PyArray_CDoubleDType; + case NPY_LONGDOUBLE: + Py_INCREF(reinterpret_cast(&PyArray_CLongDoubleDType)); + return &PyArray_CLongDoubleDType; + // Built-in complex: our types are smaller, return other. + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + + default: + break; + } + + // ---- Our own custom DTypes ---- + // Custom float or custom int: all fit in cfloat alongside our complex. + if (IsCustomFloatDType(other) || IsCustomIntDType(other)) { + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Another custom complex: both fit in cfloat. + if (IsCustomComplexDType(other)) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + Py_INCREF(reinterpret_cast(&PyArray_CFloatDType)); + return &PyArray_CFloatDType; + } + + // Unknown user type: return NotImplemented. + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterComplexDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -928,7 +1030,7 @@ bool RegisterComplexDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). 
PyArray_ArrFuncs& arr_funcs = CustomComplexType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomComplex_GetItem; @@ -937,29 +1039,44 @@ bool RegisterComplexDtype(PyObject* numpy) { arr_funcs.copyswapn = NPyCustomComplex_CopySwapN; arr_funcs.copyswap = NPyCustomComplex_CopySwap; arr_funcs.nonzero = NPyCustomComplex_NonZero; - arr_funcs.fill = nullptr; // NPyCustomComplex_Fill; + arr_funcs.fill = nullptr; arr_funcs.dotfunc = NPyCustomComplex_DotFunc; arr_funcs.compare = NPyCustomComplex_CompareFunc; - arr_funcs.argmax = nullptr; // NumPy defines them, but it's shaky + arr_funcs.argmax = nullptr; arr_funcs.argmin = nullptr; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy descriptor proto referenced by the DType spec below. 
PyArray_DescrProto& descr_proto = CustomComplexType::npy_descr_proto; descr_proto = GetCustomComplexDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + PyArray_DTypeMeta& dm = CustomComplexType::dtype_meta; + if (!InitDTypeMeta(&dm, TypeDescriptor::kTypeName)) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyCustomComplex_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomComplex_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomComplex_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomComplex_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomComplex_CommonDType)}, + {0, nullptr}}; + if (InitDTypeFromSlots(&dm, reinterpret_cast(type), + dtype_slots) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. CustomComplexType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/custom_float.h b/ml_dtypes/_src/custom_float.h index bf2568a7..16aae33e 100644 --- a/ml_dtypes/_src/custom_float.h +++ b/ml_dtypes/_src/custom_float.h @@ -35,8 +35,9 @@ limitations under the License. 
#include #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_common.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #undef copysign // TODO(ddunleavy): temporary fix for Windows bazel build // Possible this has to do with numpy.h being included before @@ -66,6 +67,10 @@ struct CustomFloatType { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. Zero-initialized; fields are filled in + // at registration time before PyType_Ready is called. + static PyArray_DTypeMeta dtype_meta; }; template @@ -76,6 +81,24 @@ template PyArray_DescrProto CustomFloatType::npy_descr_proto; template PyArray_Descr* CustomFloatType::npy_descr = nullptr; +template +PyArray_DTypeMeta CustomFloatType::dtype_meta = {}; + +// True if `meta` is one of our custom floating-point DTypes. +inline bool IsCustomFloatDType(const PyArray_DTypeMeta* meta) { + return meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta || + meta == &CustomFloatType::dtype_meta; +} // Representation of a Python custom float object. 
template @@ -841,6 +864,135 @@ bool RegisterFloatUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for CustomFloat types +// --------------------------------------------------------------------------- + +// New-style getitem: (PyArray_Descr*, char*) -> PyObject* +template +static PyObject* NPyCustomFloat_NewStyleGetItem(PyArray_Descr* /*descr*/, + char* data) { + return NPyCustomFloat_GetItem(data, /*arr=*/nullptr); +} + +// New-style setitem: (PyArray_Descr*, PyObject*, char*) -> int +template +static int NPyCustomFloat_NewStyleSetItem(PyArray_Descr* /*descr*/, + PyObject* item, char* data) { + return NPyCustomFloat_SetItem(item, data, /*arr=*/nullptr); +} + +// ensure_canonical: for a non-parametric dtype just return self. +template +static PyArray_Descr* NPyCustomFloat_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +// default_descr: return the singleton. +// This avoids use_new_as_default calling dm() -> arraydescr_new, which fails +// for legacy-flagged DTypes because the legacy-check branch errors out. +template +static PyArray_Descr* NPyCustomFloat_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +// True if every value of Src is exactly representable in Dst: Dst must have +// at least as many mantissa bits (precision) and at least as much exponent +// range as Src. +template +static constexpr bool CustomFloatSafeTo() { + return std::numeric_limits::digits >= std::numeric_limits::digits && + std::numeric_limits::max_exponent >= + std::numeric_limits::max_exponent; +} + +template +static PyArray_DTypeMeta* NPyCustomFloat_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars defer to the concrete type. 
(should add complex here) + if (other == &PyArray_PyLongDType || other == &PyArray_PyFloatDType) { + Py_INCREF(cls); + return cls; + } + + constexpr bool is_bfloat16 = std::is_same_v; + + if (PyTypeNum_ISINTEGER(other->type_num)) { + if (is_bfloat16 && (other->type_num == NPY_BYTE || other->type_num == NPY_UBYTE)) { + Py_INCREF(cls); + return cls; + } + /* Our precision is irrelevant, the integer one is higher. */ + return PyArray_CommonDType(&PyArray_PyFloatDType, other); + } + + switch (other->type_num) { + case NPY_BOOL: + Py_INCREF(cls); + return cls; + case NPY_HALF: + if (is_bfloat16) { + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); + return &PyArray_FloatDType; + } + [[fallthrough]]; + case NPY_FLOAT: case NPY_DOUBLE: case NPY_LONGDOUBLE: + [[fallthrough]]; + case NPY_CFLOAT: case NPY_CDOUBLE: case NPY_CLONGDOUBLE: + Py_INCREF(other); + return other; + default: + break; + } + + // ---- Our own custom DTypes ---- + // Another custom float: use compile-time safe-cast predicate to pick the + // wider type; fall back to float32 when neither contains the other. + // T is known at compile time so all CustomFloatSafeTo calls fold away. +#define TRY_CUSTOM_FLOAT(OtherT) \ + if (other == &CustomFloatType::dtype_meta) { \ + if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(other); return other; \ + } else if constexpr (CustomFloatSafeTo()) { \ + Py_INCREF(cls); return cls; \ + } else { \ + Py_INCREF(reinterpret_cast(&PyArray_FloatDType)); \ + return &PyArray_FloatDType; \ + } \ + } + TRY_CUSTOM_FLOAT(bfloat16) + TRY_CUSTOM_FLOAT(float8_e3m4) + TRY_CUSTOM_FLOAT(float8_e4m3) + TRY_CUSTOM_FLOAT(float8_e4m3b11fnuz) + TRY_CUSTOM_FLOAT(float8_e4m3fn) + TRY_CUSTOM_FLOAT(float8_e4m3fnuz) + TRY_CUSTOM_FLOAT(float8_e5m2) + TRY_CUSTOM_FLOAT(float8_e5m2fnuz) + TRY_CUSTOM_FLOAT(float6_e2m3fn) + TRY_CUSTOM_FLOAT(float6_e3m2fn) + TRY_CUSTOM_FLOAT(float4_e2m1fn) + TRY_CUSTOM_FLOAT(float8_e8m0fnu) +#undef TRY_CUSTOM_FLOAT + + // Custom int: float dominates. 
NPyIntN_CommonDType returns NotImplemented + // for user types it can't see, so we handle this side explicitly. + if (IsCustomIntDType(other)) { + Py_INCREF(cls); + return cls; + } + + // Custom complex or unknown user type: swapping will work (NPyCustomComplex + // handles complex+float and returns the appropriate complex result). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterFloatDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -865,7 +1017,7 @@ bool RegisterFloatDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = CustomFloatType::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyCustomFloat_GetItem; @@ -880,23 +1032,40 @@ bool RegisterFloatDtype(PyObject* numpy) { arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc; arr_funcs.argmin = NPyCustomFloat_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy descriptor proto referenced by the DType spec below. 
PyArray_DescrProto& descr_proto = CustomFloatType::npy_descr_proto; descr_proto = GetCustomFloatDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; + + PyArray_DTypeMeta& dm = CustomFloatType::dtype_meta; + if (!InitDTypeMeta(&dm, TypeDescriptor::kTypeName)) { + return false; + } - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyCustomFloat_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyCustomFloat_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyCustomFloat_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyCustomFloat_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyCustomFloat_CommonDType)}, + {0, nullptr}}; + if (InitDTypeFromSlots(&dm, reinterpret_cast(type), + dtype_slots) < 0) { return false; } + TypeDescriptor::npy_type = dm.type_num; - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. + // The singleton is owned by dm; grab a borrowed reference for npy_descr. + // PyArray_DescrFromType returns a new reference — intentionally leaked. CustomFloatType::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/dtype_common.h b/ml_dtypes/_src/dtype_common.h new file mode 100644 index 00000000..304439c7 --- /dev/null +++ b/ml_dtypes/_src/dtype_common.h @@ -0,0 +1,231 @@ +/* Copyright 2025 The ml_dtypes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* + * Shared machinery for defining new style array methods (casts/ufuncs) and + * creating the new style DTypes related to PyArrayInitDTypeMeta_FromSpec. + */ + +#ifndef ML_DTYPES_DTYPE_COMMON_H_ +#define ML_DTYPES_DTYPE_COMMON_H_ + +// clang-format off +#include "ml_dtypes/_src/numpy.h" // NOLINT (must be first) +// clang-format on + +#include +#include +#include "numpy/arrayobject.h" +#include "numpy/dtype_api.h" + +#include "common.h" + +namespace ml_dtypes { + +// --------------------------------------------------------------------------- +// Metaclass setup +// --------------------------------------------------------------------------- + +// NumPy < 2.5 forces every DType metaclass to define tp_repr / tp_str, so we +// provide real functions that forward to the base PyArrayDescr_Type behaviour. +inline PyObject* DTypeRepr(PyObject* self) { + return PyArrayDescr_Type.tp_repr(self); +} +inline PyObject* DTypeStr(PyObject* self) { + return PyArrayDescr_Type.tp_str(self); +} + +// Initializes the common fields of a new-style DType metaclass object and runs +// PyType_Ready. All our DTypes derive from PyArrayDescr_Type and are plain +// value types, so the only thing that varies is the name. 
+inline bool InitDTypeMeta(PyArray_DTypeMeta* dm, const char* name) { + auto* tp = reinterpret_cast(dm); + tp->tp_name = name; + tp->tp_base = &PyArrayDescr_Type; + tp->tp_flags = Py_TPFLAGS_DEFAULT; + tp->tp_basicsize = sizeof(_PyArray_LegacyDescr); + tp->tp_repr = DTypeRepr; + tp->tp_str = DTypeStr; + return PyType_Ready(tp) >= 0; +} + +// --------------------------------------------------------------------------- +// Within-dtype cast (copy / byte swap) +// --------------------------------------------------------------------------- + +// Identity used for the within-dtype copy. The cast loops below are templated +// on the operation applied to each element so they can be reused when porting +// the remaining casts off the legacy ArrFuncs path; for the copy it is inlined +// away to a plain load/store. +template +struct CopyOp { + T operator()(const T& x) const { return x; } +}; + +template +static int StridedUnaryLoop(PyArrayMethod_Context* /*context*/, + char* const data[], + const npy_intp dimensions[], + const npy_intp strides[], + NpyAuxData* /*auxdata*/) { + const npy_intp n = dimensions[0]; + const char* in = data[0]; + char* out = data[1]; + npy_intp stride_in, stride_out; + if constexpr (contiguous) { + stride_in = sizeof(In); + stride_out = sizeof(Out); + } else { + stride_in = strides[0]; + stride_out = strides[1]; + } + + Op op; + for (npy_intp i = 0; i < n; ++i) { + *reinterpret_cast(out) = op(*reinterpret_cast(in)); + in += stride_in; + out += stride_out; + } + return 0; +} + +// Unaligned within-dtype copy: a plain memcpy per element handles any +// alignment and stride. This also handles byte-swapping if needed. 
+template +static int UnalignedStridedCopyLoop(PyArrayMethod_Context* /*context*/, + char* const data[], + const npy_intp dimensions[], + const npy_intp strides[], + NpyAuxData* /*auxdata*/) { + const npy_intp n = dimensions[0]; + const char* in = data[0]; + char* out = data[1]; + for (npy_intp i = 0; i < n; ++i) { + std::memcpy(out, in, sizeof(T)); + if constexpr (swap) { + if constexpr (is_complex_v) { + static_assert(sizeof(T) == 4); // currently only have 32bit complex + ByteSwap16(out); + ByteSwap16(out + 2); + } + else if constexpr (sizeof(T) == 2) { + ByteSwap16(out); + } else if constexpr (sizeof(T) == 4) { + ByteSwap32(out); + } + else { + // static assert needs to depend on T, so check sizeof(T) is single byte. + static_assert(sizeof(T) == 1); + } + } + in += strides[0]; + out += strides[1]; + } + return 0; +} + +// resolve_descriptors for the within-dtype cast which doesn't do much +// since given dtypes are also the loop ones. Does indicate view and casting +// safety. (The default resolve_descriptors may normalize the byte order.) +static NPY_CASTING WithinDTypeCastResolve( + struct PyArrayMethodObject_tag* /*method*/, + PyArray_DTypeMeta* const* dtypes, PyArray_Descr* const* given_descrs, + PyArray_Descr** loop_descrs, npy_intp* view_offset) { + Py_INCREF(given_descrs[0]); + loop_descrs[0] = given_descrs[0]; + if (given_descrs[1] != nullptr) { + Py_INCREF(given_descrs[1]); + loop_descrs[1] = given_descrs[1]; + } else { + Py_INCREF(dtypes[1]->singleton); + loop_descrs[1] = dtypes[1]->singleton; + } + if (PyDataType_ISNOTSWAPPED(loop_descrs[0]) == + PyDataType_ISNOTSWAPPED(loop_descrs[1])) { + *view_offset = 0; + return NPY_NO_CASTING; + } + return NPY_EQUIV_CASTING; +} + +// get_loop for the within-dtype cast. Selects the appropriate specialization +// based on byte order, alignment and contiguity. 
+template +static int WithinDTypeCastGetLoop(PyArrayMethod_Context* context, int aligned, + int /*move_references*/, + const npy_intp* strides, + PyArrayMethod_StridedLoop** out_loop, + NpyAuxData** out_transferdata, + NPY_ARRAYMETHOD_FLAGS* flags) { + PyArray_Descr* const* descrs = context->descriptors; + *flags = NPY_METH_NO_FLOATINGPOINT_ERRORS; + *out_transferdata = nullptr; + const npy_intp elsize = static_cast(sizeof(T)); + + if (PyDataType_ISNOTSWAPPED(descrs[0]) != PyDataType_ISNOTSWAPPED(descrs[1])) { + *out_loop = UnalignedStridedCopyLoop; + } + else if (!aligned) { + *out_loop = UnalignedStridedCopyLoop; + } + else if (strides[0] == elsize && strides[1] == elsize) { + *out_loop = StridedUnaryLoop, T, T, true>; + } else { + *out_loop = StridedUnaryLoop, T, T, false>; + } + return 0; +} + +// --------------------------------------------------------------------------- +// Registration +// --------------------------------------------------------------------------- + +// Registers a fixed-size DType from its slots, wiring up the within-dtype cast +// shared by all our dtypes. Returns 0 on success and -1 on failure. +template +inline int InitDTypeFromSlots(PyArray_DTypeMeta* dm, PyTypeObject* scalar_type, + PyType_Slot* slots) { + PyArray_DTypeMeta* self_cast_dtypes[2] = {nullptr, nullptr}; + PyType_Slot self_cast_slots[] = { + {NPY_METH_resolve_descriptors, + reinterpret_cast(WithinDTypeCastResolve)}, + {NPY_METH_get_loop, + reinterpret_cast(WithinDTypeCastGetLoop)}, + {0, nullptr}}; + PyArrayMethod_Spec self_cast_spec; + self_cast_spec.name = "within_dtype_cast"; + self_cast_spec.nin = 1; + self_cast_spec.nout = 1; + self_cast_spec.casting = NPY_NO_CASTING; + self_cast_spec.flags = static_cast( + NPY_METH_SUPPORTS_UNALIGNED | NPY_METH_NO_FLOATINGPOINT_ERRORS); + self_cast_spec.dtypes = self_cast_dtypes; + self_cast_spec.slots = self_cast_slots; + // TODO(seberg): It would be good to define all other casts here as well. 
+ PyArrayMethod_Spec* casts[] = {&self_cast_spec, nullptr}; + + PyArrayDTypeMeta_Spec dtype_spec; + dtype_spec.typeobj = scalar_type; + dtype_spec.flags = NPY_DT_NUMERIC; + dtype_spec.casts = casts; + dtype_spec.slots = slots; + dtype_spec.baseclass = nullptr; + + return PyArrayInitDTypeMeta_FromSpec(dm, &dtype_spec); +} + +} // namespace ml_dtypes + +#endif // ML_DTYPES_DTYPE_COMMON_H_ diff --git a/ml_dtypes/_src/intn_numpy.h b/ml_dtypes/_src/intn_numpy.h index 8e32a63c..461fefea 100644 --- a/ml_dtypes/_src/intn_numpy.h +++ b/ml_dtypes/_src/intn_numpy.h @@ -25,8 +25,9 @@ limitations under the License. // clang-format on #include "Eigen/Core" -#include "ml_dtypes/_src/common.h" // NOLINT -#include "ml_dtypes/_src/ufuncs.h" // NOLINT +#include "ml_dtypes/_src/common.h" // NOLINT +#include "ml_dtypes/_src/dtype_common.h" // NOLINT +#include "ml_dtypes/_src/ufuncs.h" // NOLINT #include "ml_dtypes/include/intn.h" #if NPY_ABI_VERSION < 0x02000000 @@ -56,6 +57,9 @@ struct IntNTypeDescriptor { static PyArray_ArrFuncs arr_funcs; static PyArray_DescrProto npy_descr_proto; static PyArray_Descr* npy_descr; + + // New-style DType metaclass object. + static PyArray_DTypeMeta dtype_meta; }; template @@ -66,6 +70,18 @@ template PyArray_DescrProto IntNTypeDescriptor::npy_descr_proto; template PyArray_Descr* IntNTypeDescriptor::npy_descr = nullptr; +template +PyArray_DTypeMeta IntNTypeDescriptor::dtype_meta = {}; + +// True if `meta` is one of our custom integer DTypes. +inline bool IsCustomIntDType(const PyArray_DTypeMeta* meta) { + return meta == &IntNTypeDescriptor::dtype_meta || + meta == &IntNTypeDescriptor::dtype_meta || + meta == &IntNTypeDescriptor::dtype_meta || + meta == &IntNTypeDescriptor::dtype_meta || + meta == &IntNTypeDescriptor::dtype_meta || + meta == &IntNTypeDescriptor::dtype_meta; +} // Representation of a Python custom integer object. 
template @@ -774,6 +790,84 @@ bool RegisterIntNUFuncs(PyObject* numpy) { return ok; } +// --------------------------------------------------------------------------- +// New-style DType slot functions for IntN types +// --------------------------------------------------------------------------- + +template +static PyObject* NPyIntN_NewStyleGetItem(PyArray_Descr* /*descr*/, char* data) { + return NPyIntN_GetItem(data, /*arr=*/nullptr); +} + +template +static int NPyIntN_NewStyleSetItem(PyArray_Descr* /*descr*/, PyObject* item, + char* data) { + return NPyIntN_SetItem(item, data, /*arr=*/nullptr); +} + +template +static PyArray_Descr* NPyIntN_EnsureCanonical(PyArray_Descr* self) { + Py_INCREF(self); + return self; +} + +template +static PyArray_Descr* NPyIntN_DefaultDescr(PyArray_DTypeMeta* cls) { + Py_INCREF(cls->singleton); + return cls->singleton; +} + +template +static PyArray_DTypeMeta* NPyIntN_CommonDType(PyArray_DTypeMeta* cls, + PyArray_DTypeMeta* other) { + if (cls == other) { + Py_INCREF(cls); + return cls; + } + // Python abstract scalars (weak promotion). A Python int defers to our + // integer type. For a Python float/complex we must NOT collapse to a + // concrete float/complex here: returning the abstract DType unchanged keeps + // the scalar weak, so NumPy resolves it correctly in context. Alone it + // becomes the default (float64 / complex128), e.g. result_type(int2, 1.0) == + // float64; alongside a concrete float it stays weak, e.g. + // result_type(int2, 1.0, float32) == float32 (in any argument order). + if (other == &PyArray_PyLongDType) { + Py_INCREF(cls); + return cls; + } + if (other == &PyArray_PyFloatDType || other == &PyArray_PyComplexDType) { + Py_INCREF(other); + return other; + } + + // Our intN types are smaller than every NumPy built-in except bool. 
+ if (other->type_num == NPY_BOOL) { + Py_INCREF(cls); + return cls; + } + if (!PyTypeNum_ISUSERDEF(other->type_num)) { + Py_INCREF(other); + return other; + } + + // ---- Our own custom DTypes ---- + // Another custom int: lower type_num defers (swapping will work). + if (IsCustomIntDType(other)) { + if (cls->type_num < other->type_num) { + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); + } + // No cross-custom-int safe casts are registered; int16 contains all. + Py_INCREF(reinterpret_cast(&PyArray_Int16DType)); + return &PyArray_Int16DType; + } + + // Custom float or custom complex: swapping will work (NPyCustomFloat handles + // float+int, NPyCustomComplex handles complex+int). + Py_INCREF(Py_NotImplemented); + return reinterpret_cast(Py_NotImplemented); +} + template bool RegisterIntNDtype(PyObject* numpy) { // bases must be a tuple for Python 3.9 and earlier. Change to just pass @@ -799,7 +893,7 @@ bool RegisterIntNDtype(PyObject* numpy) { return false; } - // Initializes the NumPy descriptor. + // Initializes the NumPy ArrFuncs (used by legacy code paths after the swap). PyArray_ArrFuncs& arr_funcs = IntNTypeDescriptor::arr_funcs; PyArray_InitArrFuncs(&arr_funcs); arr_funcs.getitem = NPyIntN_GetItem; @@ -814,22 +908,38 @@ bool RegisterIntNDtype(PyObject* numpy) { arr_funcs.argmax = NPyIntN_ArgMaxFunc; arr_funcs.argmin = NPyIntN_ArgMinFunc; - // This is messy, but that's because the NumPy 2.0 API transition is messy. - // Before 2.0, NumPy assumes we'll keep the descriptor passed in to - // RegisterDataType alive, because it stores its pointer. - // After 2.0, the proto and descriptor types diverge, and NumPy allocates - // and manages the lifetime of the descriptor itself. + // Prepare the legacy descriptor proto referenced by the DType spec below. 
PyArray_DescrProto& descr_proto = IntNTypeDescriptor::npy_descr_proto; descr_proto = GetIntNDescrProto(); Py_SET_TYPE(&descr_proto, &PyArrayDescr_Type); descr_proto.typeobj = reinterpret_cast(type); + descr_proto.f = &arr_funcs; - TypeDescriptor::npy_type = PyArray_RegisterDataType(&descr_proto); - if (TypeDescriptor::npy_type < 0) { + PyArray_DTypeMeta& dm = IntNTypeDescriptor::dtype_meta; + if (!InitDTypeMeta(&dm, TypeDescriptor::kTypeName)) { return false; } - // TODO(phawkins): We intentionally leak the pointer to the descriptor. - // Implement a better module destructor to handle this. + + PyType_Slot dtype_slots[] = { + {NPY_DT_legacy_descriptor_proto, + reinterpret_cast(&descr_proto)}, + {NPY_DT_getitem, + reinterpret_cast(NPyIntN_NewStyleGetItem)}, + {NPY_DT_setitem, + reinterpret_cast(NPyIntN_NewStyleSetItem)}, + {NPY_DT_ensure_canonical, + reinterpret_cast(NPyIntN_EnsureCanonical)}, + {NPY_DT_default_descr, + reinterpret_cast(NPyIntN_DefaultDescr)}, + {NPY_DT_common_dtype, + reinterpret_cast(NPyIntN_CommonDType)}, + {0, nullptr}}; + if (InitDTypeFromSlots(&dm, reinterpret_cast(type), + dtype_slots) < 0) { + return false; + } + TypeDescriptor::npy_type = dm.type_num; + IntNTypeDescriptor::npy_descr = PyArray_DescrFromType(TypeDescriptor::npy_type); diff --git a/ml_dtypes/_src/npy_2_compat_new_dtypes.h b/ml_dtypes/_src/npy_2_compat_new_dtypes.h new file mode 100644 index 00000000..ce5c952e --- /dev/null +++ b/ml_dtypes/_src/npy_2_compat_new_dtypes.h @@ -0,0 +1,188 @@ +#ifndef ML_DTYPES__NPY_2_COMPAT_H_ +#define ML_DTYPES__NPY_2_COMPAT_H_ + +// This file vendors parts of npy_2_compat.h from NumPy 2.5 needed to compile +// with NumPy 2.0-2.4. +// When Python 3.11 is dropped one could instead depend on 2.5+ at build time. + +#if NPY_API_VERSION < 0x00000016 + +/* + * Backport of `NPY_DT_legacy_descriptor_proto` (and ABI fix for slot IDs). + * This backport allows dtypes that are currently implemented as legacy + * (i.e. 
have a kind, char, a character code, and only the byte-order parameter) + * to work with only minor changes on NumPy 2.0+ but use any part of the new + * DType API they want to. + * This also will allow us to deprecate the weirder parts of it, i.e. cast + * registration. + * (Possibly the only remaining change may be poor `dtype=` printing in + * arrays, which can be worked around.) + */ +/* + * `NPY_2_4_API_VERSION` and `NPY_2_5_API_VERSION` may not be defined when + * this header is vendored alongside an older `numpyconfig.h`. Provide + * fallback definitions so the rest of the backport can use named constants. + */ + #ifndef NPY_2_4_API_VERSION + #define NPY_2_4_API_VERSION 0x00000015 + #endif + #ifndef NPY_2_5_API_VERSION + #define NPY_2_5_API_VERSION 0x00000016 + #endif + + #if NPY_TARGET_VERSION < NPY_2_5_API_VERSION \ + && NPY_TARGET_VERSION >= NPY_2_0_API_VERSION + + #ifndef NPY_DT_legacy_descriptor_proto + #define NPY_DT_legacy_descriptor_proto ((1 << 11) - 1) + #endif + + #define _PyArrayInitDTypeMeta_FromSpec \ + (*(int (*)(PyArray_DTypeMeta *, PyArrayDTypeMeta_Spec *))PyArray_API[362]) + #undef PyArrayInitDTypeMeta_FromSpec + + static inline int PyArrayInitDTypeMeta_FromSpec( + PyArray_DTypeMeta *DType, PyArrayDTypeMeta_Spec *spec) + { + PyArray_DescrProto *proto = NULL; + if (spec->slots != NULL && spec->slots[0].slot == NPY_DT_legacy_descriptor_proto) { + proto = (PyArray_DescrProto *)spec->slots[0].pfunc; + } + + #if NPY_TARGET_VERSION < NPY_2_4_API_VERSION + /* + * In NumPy 2.4 the slot IDs ABI was accidentally changed, so we translate + * them even if `NPY_DT_legacy_descriptor_proto` is unused. The translation + * is idempotent. + */ + PyType_Slot *slot = spec->slots; + int bad_offset = (PyArray_RUNTIME_VERSION >= NPY_2_4_API_VERSION) + ? (1 << 10) : (1 << 11); + int good_offset = (PyArray_RUNTIME_VERSION >= NPY_2_4_API_VERSION) + ? 
(1 << 11) : (1 << 10); + while (slot->slot != 0 || slot->pfunc != NULL) { + if (slot->slot >= bad_offset && slot->slot < bad_offset + 30) { + slot->slot += good_offset - bad_offset; + } + slot++; + } + #endif + + if (proto == NULL || PyArray_RUNTIME_VERSION >= NPY_2_5_API_VERSION) { + return _PyArrayInitDTypeMeta_FromSpec(DType, spec); + } + + #if defined(Py_LIMITED_API) + PyErr_SetString(PyExc_RuntimeError, + "NPY_DT_legacy_descriptor_proto backport not supported in Python limited API"); + return -1; + #else + + /* + * Step 1: Register old-style with a garbage typeobj so that + * _PyArray_MapPyTypeToDType does NOT add the auto-DTypeMeta to the + * pytype-to-DType dict (it bails out on NPY_DT_is_legacy for non-generic + * types), regardless of whether the real scalar subclasses np.generic. + */ + PyArray_DescrProto new_proto = *proto; + new_proto.typeobj = &PyBaseObject_Type; + int typenum = PyArray_RegisterDataType(&new_proto); + if (typenum < 0) { + return -1; + } + + /* + * Step 2: Initialise the user's DType with new-style slots and casts. + * type_num stays at -1 / 0 for now; we fix it in step 3. + */ + PyArrayDTypeMeta_Spec new_spec = *spec; + new_spec.slots = &spec->slots[1]; /* skip proto slot */ + if (_PyArrayInitDTypeMeta_FromSpec(DType, &new_spec) < 0) { + return -1; + } + + /* + * Step 3: Steal the singleton descriptor and type_num from the legacy + * registration. Point the descriptor's Python type at the user's DType + * and fix up its typeobj field (which we temporarily set to + * PyBaseObject_Type in step 1). + */ + PyArray_Descr *descr = PyArray_DescrFromType(typenum); + if (descr == NULL) { + return -1; + } + + /* Save the auto-DTypeMeta so we can decref it after the swap. */ + PyObject *old_meta = (PyObject *)Py_TYPE(descr); + + DType->type_num = typenum; + /* PyArray_DescrFromType returns a new reference; transfer ownership. 
*/ + DType->singleton = descr; + /* + * Set the legacy flag (bit 0 == _NPY_DT_LEGACY_FLAG) so NumPy uses + * legacy code paths (copyswap, ArrFuncs, etc.) where the new-style API + * doesn't cover them yet. + */ + DType->flags |= 1; + + /* Re-type the descriptor so it belongs to the user's DType class. */ + Py_INCREF(DType); + Py_SET_TYPE(descr, (PyTypeObject *)(DType)); + Py_DECREF(old_meta); + + /* + * Fix the descriptor's scalar-type field (it was set to + * PyBaseObject_Type in step 1 by PyArray_RegisterDataType copying + * proto->typeobj). + */ + Py_INCREF(proto->typeobj); + Py_XDECREF(descr->typeobj); + descr->typeobj = proto->typeobj; + + /* + * Initialize legacy ArrFuncs from the descriptor prototype. + */ + if (proto->f != NULL) { + PyArray_ArrFuncs *f = _PyDataType_GetArrFuncs(descr); + /* + * Preserve ArrFuncs that were explicitly set via the new API slots + * (step 2), and fill missing ones from the legacy prototype. + * getitem/setitem always come from the legacy descriptor path. 
+ */ + if (proto->f->getitem != NULL) { + f->getitem = proto->f->getitem; + } + if (proto->f->setitem != NULL) { + f->setitem = proto->f->setitem; + } + #define NPY_PROTO_FILL_IF_NULL(FIELD) \ + if (f->FIELD == NULL) { \ + f->FIELD = proto->f->FIELD; \ + } + NPY_PROTO_FILL_IF_NULL(copyswap); + NPY_PROTO_FILL_IF_NULL(copyswapn); + NPY_PROTO_FILL_IF_NULL(compare); + NPY_PROTO_FILL_IF_NULL(argmax); + NPY_PROTO_FILL_IF_NULL(dotfunc); + NPY_PROTO_FILL_IF_NULL(scanfunc); + NPY_PROTO_FILL_IF_NULL(fromstr); + NPY_PROTO_FILL_IF_NULL(nonzero); + NPY_PROTO_FILL_IF_NULL(fill); + NPY_PROTO_FILL_IF_NULL(fillwithscalar); + NPY_PROTO_FILL_IF_NULL(scalarkind); + NPY_PROTO_FILL_IF_NULL(argmin); + #undef NPY_PROTO_FILL_IF_NULL + for (int i = 0; i < NPY_NSORTS; i++) { + f->sort[i] = proto->f->sort[i]; + f->argsort[i] = proto->f->argsort[i]; + } + } + #endif /* Py_LIMITED_API */ + return 0; + } + #endif + + +#endif + +#endif // ML_DTYPES__NPY_2_COMPAT_H_ diff --git a/ml_dtypes/_src/numpy.h b/ml_dtypes/_src/numpy.h index 8b55e4d9..0de805fe 100644 --- a/ml_dtypes/_src/numpy.h +++ b/ml_dtypes/_src/numpy.h @@ -22,6 +22,8 @@ limitations under the License. // Disallow Numpy 1.7 deprecated symbols. #define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#define NPY_TARGET_VERSION NPY_2_0_API_VERSION + // We import_array in the ml_dtypes init function only. #define PY_ARRAY_UNIQUE_SYMBOL _ml_dtypes_numpy_api @@ -38,6 +40,9 @@ limitations under the License. #include "numpy/arrayscalars.h" #include "numpy/ufuncobject.h" +// Needed to compile with NumPy < 2.5 (does nothing with newer NumPy) +#include "npy_2_compat_new_dtypes.h" + namespace ml_dtypes { // Import numpy. 
This wrapper function exists so that the diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index 410fccce..1ee8ee10 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -266,6 +266,7 @@ def testPickleable(self, float_type): x = np.arange(10, dtype=float_type) serialized = pickle.dumps(x) x_out = pickle.loads(serialized) + # With NumPy 2.5+ we could rely on NaNs working and not cast to float32 self.assertEqual(x_out.dtype, x.dtype) np.testing.assert_array_equal(x_out.astype("float32"), x.astype("float32")) @@ -656,6 +657,27 @@ def testByteSwap(self, float_type): # 8-bit types should be unchanged self.assertEqual(original_bytes, swapped.tobytes()) + def testAstypeByteSwapped(self, float_type): + """Casting to a byte-swapped dtype goes through the within-dtype cast.""" + dt = np.dtype(float_type) + swapped_dt = dt.newbyteorder("S") + # The swapped dtype is still the same custom type, just a different order. + self.assertIs(swapped_dt.type, float_type) + + # With NumPy 2.5+ we could rely on NaNs working and use a larger range + arr = np.arange(1, 31).astype(float_type) + swapped = arr.astype(swapped_dt) + self.assertIs(swapped.dtype.type, float_type) + + # Casting preserves the logical values regardless of byte order. + np.testing.assert_array_equal(swapped.astype(float_type), arr) + + if dt.itemsize > 1: + # The stored bytes really are swapped for multi-byte types. + self.assertEqual(swapped.tobytes(), arr.byteswap().tobytes()) + else: + self.assertEqual(swapped.tobytes(), arr.tobytes()) + BinaryOp = collections.namedtuple("BinaryOp", ["op"]) diff --git a/ml_dtypes/tests/result_type_test.py b/ml_dtypes/tests/result_type_test.py new file mode 100644 index 00000000..8804ab39 --- /dev/null +++ b/ml_dtypes/tests/result_type_test.py @@ -0,0 +1,294 @@ +# Copyright 2026 The ml_dtypes Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for np.result_type() across ml_dtypes custom DTypes.""" + +import itertools + +import ml_dtypes +import numpy as np +import pytest + +# Short aliases for readability in parametrize lists +bf16 = ml_dtypes.bfloat16 +f4 = ml_dtypes.float4_e2m1fn +f6_e2m3 = ml_dtypes.float6_e2m3fn +f6_e3m2 = ml_dtypes.float6_e3m2fn +f8_e3m4 = ml_dtypes.float8_e3m4 +f8_e4m3 = ml_dtypes.float8_e4m3 +f8_e4m3fn = ml_dtypes.float8_e4m3fn +f8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz +f8_e4m3b11 = ml_dtypes.float8_e4m3b11fnuz +f8_e5m2 = ml_dtypes.float8_e5m2 +f8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +f8_e8m0 = ml_dtypes.float8_e8m0fnu +bc32 = ml_dtypes.bcomplex32 +c32 = ml_dtypes.complex32 +i1, i2, i4 = ml_dtypes.int1, ml_dtypes.int2, ml_dtypes.int4 +u1, u2, u4 = ml_dtypes.uint1, ml_dtypes.uint2, ml_dtypes.uint4 + +ALL_CUSTOM_FLOATS = [bf16, f4, f6_e2m3, f6_e3m2, + f8_e3m4, f8_e4m3, f8_e4m3fn, f8_e4m3fnuz, + f8_e4m3b11, f8_e5m2, f8_e5m2fnuz, f8_e8m0] +ALL_INTN = [i1, i2, i4, u1, u2, u4] +ALL_CUSTOM_COMPLEX = [bc32, c32] + + +def rt(a, b): + return np.result_type(a, b) + + +# --------------------------------------------------------------------------- +# Custom float vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom float always wins ---- + (bf16, np.bool_, bf16), + (f8_e4m3fn, np.bool_, f8_e4m3fn), + (f4, np.bool_, f4), + 
# ---- floats: pick the wider ---- + (f4, np.float16, np.float16), # f4 fits in float16 + (f8_e4m3fn, np.float16, np.float16), # float8 fits in float16 + (f8_e5m2, np.float16, np.float16), # float8 fits in float16 + (bf16, np.float16, np.float32), # incomparable → float32 + (f8_e4m3fn, np.float32, np.float32), # all custom floats fit in float32 + (bf16, np.float32, np.float32), + (bf16, np.float64, np.float64), + (f8_e4m3fn, np.float64, np.float64), + # ---- integers: PyArray_CommonDType decides ---- + (bf16, np.int8, bf16), # bfloat16 has 8 sig bits, int8 needs 7 → bf16 wins + (bf16, np.int16, np.float64), # bfloat16 has 8 sig bits, int16 needs 15 → float64 + (f8_e4m3fn, np.int8, np.float64), # float8 can't represent all int8 values + (f8_e4m3fn, np.int32, np.float64), # float8 can't represent all int32 values + # ---- complex: other always wins ---- + (bf16, np.complex64, np.complex64), + (f8_e4m3fn, np.complex64, np.complex64), + (bf16, np.complex128, np.complex128), + (f8_e4m3fn, np.complex128, np.complex128), +]) +def test_custom_float_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom float vs custom float +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (bf16, bf16, bf16), + (f8_e4m3fn, f8_e4m3fn, f8_e4m3fn), + (f4, f4, f4), + # ---- narrower fits safely into wider ---- + (f4, f6_e2m3, f6_e2m3), # f4 ⊂ f6_e2m3 (more exp + mantissa) + (f4, f8_e4m3fn, f8_e4m3fn), # f4 fits in every float8+ + (f4, bf16, bf16), # f4 fits in bfloat16 + (f8_e4m3fn, bf16, bf16), # float8 fits in bfloat16 + (f8_e5m2, bf16, bf16), # float8 fits in bfloat16 + (f8_e3m4, bf16, bf16), # float8 fits in bfloat16 + # ---- incomparable: one has more exp, other more mantissa → float32 ---- + (bf16, f8_e5m2, bf16), # f8_e5m2 
fits in bf16 (bf16 > in all dims) + (f8_e4m3fn, f8_e5m2, np.float32), # e4m3 has more mantissa, e5m2 has more exp + (f8_e4m3fn, f8_e4m3fnuz, f8_e4m3fn), # same digits/max_exp → numeric_limits match; fn wins + (f6_e2m3, f6_e3m2, np.float32), # one has more mantissa, other more exp +]) +def test_custom_float_vs_custom_float(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom float vs custom int (float always dominates) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("float_t, int_t", [ + (bf16, i4), + (bf16, u4), + (bf16, i1), + (f8_e4m3fn, i4), + (f8_e4m3fn, u4), + (f8_e5m2, i2), + (f4, i1), +]) +def test_custom_float_beats_custom_int(float_t, int_t): + assert rt(float_t, int_t) == np.dtype(float_t) + assert rt(int_t, float_t) == np.dtype(float_t) # symmetric + + +# --------------------------------------------------------------------------- +# Custom int vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool: custom int always wins ---- + (i4, np.bool_, i4), + (u4, np.bool_, u4), + (i1, np.bool_, i1), + # ---- all other NumPy types: return other (intN is always smaller) ---- + (i4, np.int8, np.int8), + (i4, np.int16, np.int16), + (i4, np.int32, np.int32), + (i4, np.uint8, np.uint8), + (u4, np.int8, np.int8), + (i2, np.int8, np.int8), + (i4, np.float16, np.float16), + (i4, np.float32, np.float32), + (i4, np.float64, np.float64), + (i4, np.complex64, np.complex64), + (i4, np.complex128, np.complex128), + (u4, np.float32, np.float32), +]) +def test_custom_int_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# 
--------------------------------------------------------------------------- +# Custom int vs custom int +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- same type ---- + (i4, i4, i4), + (u4, u4, u4), + # ---- mixed sign: neither fits the other → int16 ---- + (i4, u4, np.int16), + (i2, u2, np.int16), + (i1, u1, np.int16), + # ---- same sign, different width → int16 fallback ---- + (i2, i4, np.int16), + (u2, u4, np.int16), + (i1, i4, np.int16), +]) +def test_custom_int_vs_custom_int(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom complex vs NumPy built-in types +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- bool + integers: wrap in cfloat ---- + (bc32, np.bool_, np.complex64), + (bc32, np.int8, np.complex64), + (bc32, np.int32, np.complex64), + (c32, np.bool_, np.complex64), + (c32, np.int8, np.complex64), + # ---- floats ≤ float32: wrap in cfloat ---- + (bc32, np.float16, np.complex64), + (bc32, np.float32, np.complex64), + (c32, np.float16, np.complex64), + (c32, np.float32, np.complex64), + # ---- float64+: need cdouble ---- + (bc32, np.float64, np.complex128), + (bc32, np.longdouble, np.clongdouble), + (c32, np.float64, np.complex128), + # ---- built-in complex: other always wins ---- + (bc32, np.complex64, np.complex64), + (bc32, np.complex128, np.complex128), + (c32, np.complex64, np.complex64), + (c32, np.complex128, np.complex128), +]) +def test_custom_complex_vs_numpy(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Custom complex vs custom float / custom int +# 
--------------------------------------------------------------------------- + +@pytest.mark.parametrize("a, b, expected", [ + # ---- custom floats: all fit in cfloat alongside our complex ---- + (bc32, bf16, np.complex64), + (bc32, f8_e4m3fn, np.complex64), + (bc32, f8_e5m2, np.complex64), + (bc32, f4, np.complex64), + (c32, bf16, np.complex64), + (c32, f8_e4m3fn, np.complex64), + # ---- custom ints: all tiny, fit in cfloat ---- + (bc32, i4, np.complex64), + (bc32, u4, np.complex64), + (bc32, i1, np.complex64), + (c32, i4, np.complex64), + # ---- two custom complex types ---- + (bc32, c32, np.complex64), + (bc32, bc32, bc32), + (c32, c32, c32), +]) +def test_custom_complex_vs_custom(a, b, expected): + assert rt(a, b) == np.dtype(expected) + assert rt(b, a) == np.dtype(expected) # must be symmetric + + +# --------------------------------------------------------------------------- +# Python scalars: 0, 0.0, 0.0j (abstract types) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dtype, scalar, expected", [ + # ---- custom floats dominate Python int and Python float ---- + (bf16, 0, bf16), + (bf16, 0.0, bf16), + (f8_e4m3fn, 0, f8_e4m3fn), + (f8_e4m3fn, 0.0, f8_e4m3fn), + (f4, 0, f4), + (f4, 0.0, f4), + # ---- custom float + Python complex → cfloat ---- + (bf16, 0.0j, np.complex64), + (f8_e4m3fn, 0.0j, np.complex64), + (f4, 0.0j, np.complex64), + # ---- custom ints: a Python int defers to the int dtype, but a Python + # float/complex crosses the integer kind and promotes to the default + # float64 / complex128 (matching NumPy's built-in integers) ---- + (i4, 0, i4), + (i4, 0.0, np.float64), + (i4, 0.0j, np.complex128), + (u4, 0, u4), + (u4, 0.0, np.float64), + (u4, 0.0j, np.complex128), + (i1, 0, i1), + (i1, 0.0, np.float64), + (i2, 0.0j, np.complex128), + # ---- custom complex dominates all Python scalars ---- + (bc32, 0, bc32), + (bc32, 0.0, bc32), + (bc32, 0.0j, bc32), + (c32, 0, c32), + (c32, 0.0, c32), + 
(c32, 0.0j, c32), +]) +def test_python_scalars(dtype, scalar, expected): + assert rt(dtype, scalar) == np.dtype(expected) + + +@pytest.mark.parametrize("int_t", ALL_INTN) +@pytest.mark.parametrize("scalar, concrete, expected", [ + (1.0, np.float16, np.float16), + (1.0, np.float32, np.float32), + (1.0, np.float64, np.float64), + (1.0, np.complex64, np.complex64), + (1.0j, np.float16, np.complex64), + (1.0j, np.float32, np.complex64), + (1.0j, np.float64, np.complex128), + (1.0j, np.complex64, np.complex64), +]) +def test_weak_scalar_stays_weak(int_t, scalar, concrete, expected): + # Sanity check that no matter the order, the concrete precision wins + # (i.e. promoting int + pyfloat -> pyfloat). + for args in itertools.permutations([int_t, scalar, concrete]): + assert np.result_type(*args) == np.dtype(expected), args diff --git a/pyproject.toml b/pyproject.toml index ef599aef..61762c35 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,8 +17,7 @@ keywords = [] # pip dependencies of the project dependencies = [ # Ensure numpy release supports Python version. - "numpy>=1.24.0", - "numpy>=1.26.0; python_version>='3.12'", + "numpy>=2.0.0", "numpy>=2.1.0; python_version>='3.13'", "numpy>=2.3.0; python_version>='3.14'", ]