diff --git a/setup.py b/setup.py index 31a7b15..d74fd14 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="tensor-shape-assert", - version="0.3.3", + version="0.3.4", description="A simple runtime assert library for tensor-based frameworks.", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index 2b7f0b5..9c91ce1 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -527,7 +527,11 @@ def assert_shape_here(obj_or_shape: Any, descriptor: str) -> None: """ # skip if check is disabled - if _global_check_mode == "never": + if _global_check_mode in ('once', 'never'): + warnings.warn(CheckDisabledWarning( + "Global check mode is set to 'once' or 'never'. Calls to " + "``assert_shape_here`` will be skipped." + )) return check_if_context_is_available() diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 39e0a3f..6f4b805 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1120,20 +1120,7 @@ def test(x: ShapedTensor["m n 2"]) -> ShapedTensor["2"]: test(xp.zeros((4, 3, 3))) set_global_check_mode('always') - - def test_never_global_ignores_assert_shape_here(self): - set_global_check_mode('never') - - @check_tensor_shapes() - def test(x) -> ShapedTensor["2"]: - assert_shape_here(x, "m n 2") - return x.sum(axis=(0, 1)) - test(xp.zeros((4, 3, 2))) - test(xp.zeros((4, 3, 1))) - test(xp.zeros((4, 3, 3))) - - set_global_check_mode('always') def test_local_always_overrides_global_never(self): set_global_check_mode('never') @@ -1227,31 +1214,58 @@ def test_invalid_global_check_mode_raises(self): with self.assertRaises(TensorShapeAssertError): set_global_check_mode("invalid_mode") - def test_assert_shape_here_respects_global_check_mode(self): + def test_assert_shape_here_respects_global_check_mode_never(self): set_global_check_mode('never') @check_tensor_shapes() - def test(x) -> ShapedTensor["2"]: + def test(x): assert_shape_here(x, "m n 2") return x.sum(axis=(0, 1)) - test(xp.zeros((4, 3, 2))) - test(xp.zeros((4, 3, 1))) - test(xp.zeros((4, 3, 3))) + with self.assertWarns(CheckDisabledWarning): + test(xp.zeros((4, 3, 2))) + with self.assertWarns(CheckDisabledWarning): + test(xp.zeros((4, 3, 1))) + with self.assertWarns(CheckDisabledWarning): + test(xp.zeros((4, 3, 3))) + + set_global_check_mode('always') + + + def test_assert_shape_here_respects_global_check_mode_once(self): + + set_global_check_mode('once') + + @check_tensor_shapes() + def test2(x): + assert_shape_here(x, "m n 2") + return x.sum(axis=(0, 1)) + + with self.assertWarns(CheckDisabledWarning): + test2(xp.zeros((4, 3, 1))) + with self.assertWarns(CheckDisabledWarning): + test2(xp.zeros((4, 3, 2))) + with self.assertWarns(CheckDisabledWarning): + test2(xp.zeros((4, 3, 3))) + + set_global_check_mode('always') + + + def test_assert_shape_here_respects_local_check_mode_always(self): set_global_check_mode('always') @check_tensor_shapes() - def test2(x) -> ShapedTensor["2"]: + def test3(x): assert_shape_here(x, "m n 2") return x.sum(axis=(0, 1)) - test2(xp.zeros((4, 3, 2))) + test3(xp.zeros((4, 3, 2))) with self.assertRaises(TensorShapeAssertError): - test2(xp.zeros((4, 3, 1))) + test3(xp.zeros((4, 3, 1))) with self.assertRaises(TensorShapeAssertError): - test2(xp.zeros((4, 3, 3))) + test3(xp.zeros((4, 3, 3))) class TestScalarValues(unittest.TestCase): def test_scalar_inputs(self):