NaN comparison fails in np.testing.assert_equal #301

@justinchuby

The test utility np.testing.assert_equal treats NaN values as equal to each other. However, this is not the case for some ml_dtypes arrays:

import ml_dtypes
import numpy as np

# This will succeed 
fp32_array = np.array(np.nan, dtype=np.float32)
np.testing.assert_equal(fp32_array, fp32_array)

# This will fail
array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
np.testing.assert_equal(array, array)

with the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "numpy/testing/_private/utils.py", line 371, in assert_equal
    return assert_array_equal(actual, desired, err_msg, verbose,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "numpy/testing/_private/utils.py", line 1051, in assert_array_equal
    assert_array_compare(operator.__eq__, actual, desired, err_msg=err_msg,
  File "numpy/testing/_private/utils.py", line 916, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 1 (100%)
Max absolute difference among violations: nan
Max relative difference among violations: nan
 ACTUAL: array(nan, dtype=bfloat16)
 DESIRED: array(nan, dtype=bfloat16)
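
A possible workaround in the meantime is to upcast both arrays to a native NumPy float type before comparing, so that the NaN handling shown for float32 above applies. The sketch below is untested against every ml_dtypes type. The helper name assert_equal_with_nan is hypothetical, and it assumes that every value of the input dtype is exactly representable in float64 (true for bfloat16) and that a cast to float64 is available for that dtype.

```python
import numpy as np

try:
    import ml_dtypes
except ImportError:
    ml_dtypes = None  # the bfloat16 demo at the bottom is skipped without ml_dtypes


def assert_equal_with_nan(actual, desired):
    """Hypothetical helper: compare arrays, treating NaNs in the same positions as equal.

    Assumption: the input dtype can be cast to float64 without changing any value.
    """
    actual = np.asarray(actual)
    desired = np.asarray(desired)
    # Upcast so the comparison runs on a native NumPy float dtype.
    np.testing.assert_equal(
        actual.astype(np.float64), desired.astype(np.float64)
    )


if ml_dtypes is not None:
    array = np.array(np.nan, dtype=ml_dtypes.bfloat16)
    assert_equal_with_nan(array, array)
```

Note that the upcast hides the original dtype in the error message, so a failure will report float64 values rather than bfloat16.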
