
Use new DType API with backcompat path and fix result_type()#374

Open
seberg wants to merge 3 commits into jax-ml:main from seberg:try-new-dtype-hack

seberg commented Apr 20, 2026

Contributor

gh-360 is a little hard to actually pull off. In NumPy 2.5+ (with a backport vendored so that it compiles on older NumPy versions), it is now allowed to create a "new-style DType" that is still a legacy dtype.
That means we inherit a few existing quirks, but mostly it means that ml_dtypes can start using the new API conveniently without much concern for backwards-compatibility issues.

The current state is:

  • result_type() now does a better job/the right thing (yes, this adds a bunch of code).
  • There is a small annoyance: arrays (at least on NumPy <2.5) print as array(..., dtype=dtype('bfloat16')) rather than the short array(..., dtype=bfloat16). (I am restoring the short form in NumPy for 2.5+, although we need a better solution.)
  • NumPy wants the copy (within-DType) cast defined, so I did that. It probably speeds up strided copies, but it is mainly a requirement.

Next steps and things this unlocks:

  • We can implement np.finfo() for NumPy 2.5+ (given one backport PR from me in NumPy).
  • I think it would be nice to move all casts over, but it won't do much besides speeding up strided casts quite a lot.
  • There may be some other new API now accessible, although I can't think of anything super interesting right now. But of course, if we add things in NumPy, it'll now be easier to use them quickly (similar to the finfo case).

I agree that the long list of dtypes is a bit annoying for CommonDType. Once NumPy has a better DType hierarchy, we could easily allow creating your own base class there, which would simplify this (although it may not speed things up in practice).
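To make the "long list versus base class" point concrete, here is a minimal pure-Python sketch. It is a toy model only: the classes are plain Python stand-ins for DType classes, every name in it is hypothetical (none of it is NumPy or ml_dtypes API), and the Float32 promotion target is assumed purely for illustration. The one thing it mirrors from NumPy's common-DType protocol is returning NotImplemented to defer to the other DType.

```python
# Toy model only: plain Python classes standing in for DType classes.
# None of these names are NumPy or ml_dtypes API.


class DType:
    """Hypothetical root of the toy hierarchy."""


class Float32(DType):
    pass


class MLFloat(DType):
    """Hypothetical shared base class for the small float DTypes."""


class BFloat16(MLFloat):
    pass


class Float8E4M3(MLFloat):
    pass


class Float8E5M2(MLFloat):
    pass


# Style 1: every DType has to know an explicit list of its siblings.
KNOWN_SIBLINGS = (BFloat16, Float8E4M3, Float8E5M2)


def common_dtype_by_list(cls, other):
    if other is cls:
        return cls
    if other in KNOWN_SIBLINGS:
        return Float32  # assumed promotion target, for illustration only
    return NotImplemented  # defer to the other DType


# Style 2: a single subclass check replaces the list.
def common_dtype_by_base(cls, other):
    if other is cls:
        return cls
    if issubclass(other, MLFloat):
        return Float32  # same assumed target as above
    return NotImplemented  # defer to the other DType
```

Both functions give the same answers for the toy classes; the second one simply does not need to be updated when another sibling DType is added, which is the simplification a shared base class would buy.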

(Heavy use of Claude to spit out code, but of course with absolute design micro-managing in many relevant parts -- but I'll need to go through it once more myself.)

@seberg seberg force-pushed the try-new-dtype-hack branch 2 times, most recently from de3a5a8 to eff09d7 Compare April 20, 2026 12:58

seberg commented Apr 21, 2026

Contributor Author

@hawkinsp just in case you have a quick thought here. I tried to rewrite things to just use the new API but keep it a "legacy" dtype to some degree, so that there should be no real regressions while it still works fine all the way back to NumPy 2 (the only regression I noticed is that arrays print a bit less nicely).

However, this uses PyType_Ready. NumPy predated PyType_FromMetaclass and it still implements e.g. tp_new, making its use incompatible.
So that is a bigger downside with this approach: adopting the Python stable API may be hard or indefinitely deferred, because one would need to hack around this for old NumPy versions.
(I am pretty sure it is possible to hack around it; NumPy effectively does so, and I think pybind11 probably does too. But it may be pretty ugly...)

The alternative I can currently think of is to just allow PyArrayInitDTypeMeta_FromSpec to amend the current legacy dtype.
That is less forward-looking, and I also liked how this approach backported: the amending pattern is nice for new NumPy versions but, I expect, backports worse (mainly because of cast definition patterns).

@hawkinsp

Collaborator

I haven't looked yet, but

> However, this uses PyType_Ready. NumPy predated PyType_FromMetaclass and it still implements e.g. tp_new, making its use incompatible.
> So that is a bigger downside with this approach: adopting the Python stable API may be hard or indefinitely deferred, because one would need to hack around this for old NumPy versions.

It's not the end of the world to have to build ml_dtypes per Python version: it's a small enough package. I had previously abandoned trying to use the limited dtype API for similar reasons (#195). And eventually when the oldest supported NumPy ages off our support matrix, we can switch.


seberg commented Apr 22, 2026

Contributor Author

OK, cool, then I think I'll pursue this. We need a better way to transition a package like ml_dtypes, and I think this is viable.

Long term for the stable API: I suspect the right thing will be to have a new DType creation function that creates the full heap type for you based on the spec.
(That way, even if we need somewhat crazy things, they can live in NumPy, i.e. something like PyArrayDTypeMeta_FromSpecs(module, type_slots, dtype_slots). But that'll be a NumPy 2.6 thing at best. I am also very curious about the Python stable API developments around this. Back in the day, I stole their ideas, but they seem to have improved on them quite a lot!)

seberg changed the title from "EXPLORATORY: Add a path that allows using new DType API" to "Use new-style DType API with backcompat path and fix result_type()" Jun 5, 2026
seberg changed the title from "Use new-style DType API with backcompat path and fix result_type()" to "Use new DType API with backcompat path and fix result_type()" Jun 5, 2026
This uses the new-style API for NumPy 2.0+.  NumPy 2.5 ships some
backport compatibility hacks that we vendor here to compile on
older NumPy versions as well.

This requires some churn, the biggest part being setting up a
within-DType casting implementation.

However, it allows optionally using new API, the biggest benefit being
that `result_type()` can now do the right thing.

The one downside is that `dtype=` does not print as nicely
on NumPy <2.5.
@seberg seberg force-pushed the try-new-dtype-hack branch from c67581e to 800f815 Compare June 5, 2026 10:48

seberg commented Jun 5, 2026

Contributor Author

I made a big pass on this cleaning things up and hopefully fixing some issues (byte-swapping and an incorrect result_type() branching).

This should now actually be OK, although it does require NumPy 2.0+.

With this being merged in NumPy and working out fine here, this is now actually ready. I'll note again the one little annoyance I have found so far (hopefully the only one): when printing arrays, we now print dtype=dtype('bfloat16') in long form.
For NumPy 2.5, I'll try to just preserve the current behavior (even if not perfect for byte-swapping), then we can see.
(In theory we could monkey-patch NumPy if we wanted to.)

Unlike gh-360, this doesn't make it obvious to just not set a character, for example, but it does ensure that old-style code paths are taken, so that actual regressions are unlikely (while it seems it'll be a long game of whack-a-mole with gh-360).

@seberg seberg marked this pull request as ready for review June 5, 2026 10:52

seberg commented Jun 5, 2026

Contributor Author

This also now fixes gh-301, which I suspect is the magic thing that might make my CuPy "test everything" attempt feasible (to the point that I am considering whether we should just hack that flag if this PR isn't so easy).

(Never mind, older NumPy will still not support it nicely, of course. If we want to improve this without the flag, I think the solution might be to check whether isfinite is defined for the dtype or not.)
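
One way such an isfinite check could look from Python is sketched below. This is an illustration, not what the PR or CuPy does: the helper name isfinite_usable is made up, and the probe only tells whether np.isfinite accepts the dtype at all. NumPy may get there through a safe cast rather than a loop registered for the dtype itself, so "usable" is weaker than "defined for the dtype".

```python
import numpy as np


def isfinite_usable(dtype):
    """Return True if np.isfinite accepts arrays of `dtype` (hypothetical helper)."""
    try:
        # An empty array still goes through ufunc type resolution,
        # so no element is actually evaluated.
        np.isfinite(np.empty(0, dtype=dtype))
    except TypeError:
        # NumPy raises a TypeError subclass when no loop can be found.
        return False
    return True
```

With plain NumPy dtypes, isfinite_usable(np.float32) is true while isfinite_usable("U1") is false, since np.isfinite has no loop for strings.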


2 participants