Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="tensor-shape-assert",
version="0.3.3",
version="0.3.4",
description="A simple runtime assert library for tensor-based frameworks.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
6 changes: 5 additions & 1 deletion src/tensor_shape_assert/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,11 @@ def assert_shape_here(obj_or_shape: Any, descriptor: str) -> None:
"""

# skip if check is disabled
if _global_check_mode == "never":
if _global_check_mode in ('once', 'never'):
warnings.warn(CheckDisabledWarning(
"Global check mode is set to 'once' or 'never'. Calls to "
"``assert_shape_here`` will be skipped."
))
return

check_if_context_is_available()
Expand Down
58 changes: 36 additions & 22 deletions src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,20 +1120,7 @@ def test(x: ShapedTensor["m n 2"]) -> ShapedTensor["2"]:
test(xp.zeros((4, 3, 3)))

set_global_check_mode('always')

def test_never_global_ignores_assert_shape_here(self):
set_global_check_mode('never')

@check_tensor_shapes()
def test(x) -> ShapedTensor["2"]:
assert_shape_here(x, "m n 2")
return x.sum(axis=(0, 1))

test(xp.zeros((4, 3, 2)))
test(xp.zeros((4, 3, 1)))
test(xp.zeros((4, 3, 3)))

set_global_check_mode('always')

def test_local_always_overrides_global_never(self):
set_global_check_mode('never')
Expand Down Expand Up @@ -1227,31 +1214,58 @@ def test_invalid_global_check_mode_raises(self):
with self.assertRaises(TensorShapeAssertError):
set_global_check_mode("invalid_mode")

def test_assert_shape_here_respects_global_check_mode(self):
def test_assert_shape_here_respects_global_check_mode_never(self):
set_global_check_mode('never')

@check_tensor_shapes()
def test(x) -> ShapedTensor["2"]:
def test(x):
assert_shape_here(x, "m n 2")
return x.sum(axis=(0, 1))

test(xp.zeros((4, 3, 2)))
test(xp.zeros((4, 3, 1)))
test(xp.zeros((4, 3, 3)))
with self.assertWarns(CheckDisabledWarning):
test(xp.zeros((4, 3, 2)))
with self.assertWarns(CheckDisabledWarning):
test(xp.zeros((4, 3, 1)))
with self.assertWarns(CheckDisabledWarning):
test(xp.zeros((4, 3, 3)))

set_global_check_mode('always')


def test_assert_shape_here_respects_global_check_mode_once(self):

set_global_check_mode('once')

@check_tensor_shapes()
def test2(x):
assert_shape_here(x, "m n 2")
return x.sum(axis=(0, 1))

with self.assertWarns(CheckDisabledWarning):
test2(xp.zeros((4, 3, 1)))
with self.assertWarns(CheckDisabledWarning):
test2(xp.zeros((4, 3, 2)))
with self.assertWarns(CheckDisabledWarning):
test2(xp.zeros((4, 3, 3)))

set_global_check_mode('always')


def test_assert_shape_here_respects_local_check_mode_always(self):

set_global_check_mode('always')

@check_tensor_shapes()
def test2(x) -> ShapedTensor["2"]:
def test3(x):
assert_shape_here(x, "m n 2")
return x.sum(axis=(0, 1))

test2(xp.zeros((4, 3, 2)))
test3(xp.zeros((4, 3, 2)))

with self.assertRaises(TensorShapeAssertError):
test2(xp.zeros((4, 3, 1)))
test3(xp.zeros((4, 3, 1)))
with self.assertRaises(TensorShapeAssertError):
test2(xp.zeros((4, 3, 3)))
test3(xp.zeros((4, 3, 3)))

class TestScalarValues(unittest.TestCase):
def test_scalar_inputs(self):
Expand Down