Skip to content

Fixes a bug where custom dtypes (floats, ints, complexes) raised TypeError when compared (==, !=) against incompatible types like strings or None.#380

Merged
copybara-service[bot] merged 1 commit into
mainfrom
test_930538602
Jun 11, 2026

Conversation

@copybara-service

Copy link
Copy Markdown

Fixes a bug where custom dtypes (floats, ints, complexes) raised TypeError when compared (==, !=) against incompatible types like strings or None.

Before this change, comparing a custom dtype to an incompatible type crashed:

>>> import ml_dtypes
>>> ml_dtypes.bfloat16(1.0) == "param"
TypeError: ufunc 'equal' not supported for the input types, and the inputs could not be safely coerced...

Now it safely returns False (or True for !=):

>>> import ml_dtypes
>>> ml_dtypes.bfloat16(1.0) == "param"
False

The C++ rich comparison now returns Py_NotImplemented early for strings and non-sequence incompatible types. This allows Python's standard identity fallback to work safely instead of crashing, while still allowing NumPy to handle valid sequence comparisons (e.g. lists). This isn't ideal, but I'm not sure it's possible to do better until we migrate to NumPy 2.0 dtypes.

I believe this to be the root cause of this JAX CI flake https://github.com/jax-ml/jax/actions/runs/27343723608/job/80786816925 which occurs when a bfloat16 scalar and a string happen to collide in a dict and end up getting compared.

…Error when compared (==, !=) against incompatible types like strings or None.

Before this change, comparing a custom dtype to an incompatible type crashed:
```python
>>> import ml_dtypes
>>> ml_dtypes.bfloat16(1.0) == "param"
TypeError: ufunc 'equal' not supported for the input types, and the inputs could not be safely coerced...

```
Now it safely returns False (or True for !=):
```python
>>> import ml_dtypes
>>> ml_dtypes.bfloat16(1.0) == "param"
False
```

The C++ rich comparison now returns Py_NotImplemented early for strings and non-sequence incompatible types. This allows Python's standard identity fallback to work safely instead of crashing, while still allowing NumPy to handle valid sequence comparisons (e.g. lists). This isn't ideal, but I'm not sure it's possible to do better until we migrate to NumPy 2.0 dtypes.

I believe this to be the root cause of this JAX CI flake https://github.com/jax-ml/jax/actions/runs/27343723608/job/80786816925 which occurs when a bfloat16 scalar and a string happen to collide in a dict and end up getting compared.

PiperOrigin-RevId: 930641853
@copybara-service copybara-service Bot merged commit dcb22b3 into main Jun 11, 2026
@copybara-service copybara-service Bot deleted the test_930538602 branch June 11, 2026 18:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant