From 485e532eee741710845313a1729c3d832509e154 Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Mon, 2 Feb 2026 05:35:11 +0000 Subject: [PATCH 1/2] added first version + tests for variable scope --- src/tensor_shape_assert/utils.py | 9 ++ src/tensor_shape_assert/wrapper.py | 61 ++++++++++++-- src/tensor_shape_assert_test.py | 127 +++++++++++++++++++++++++++++ 3 files changed, 188 insertions(+), 9 deletions(-) diff --git a/src/tensor_shape_assert/utils.py b/src/tensor_shape_assert/utils.py index 14eff0b..d224877 100644 --- a/src/tensor_shape_assert/utils.py +++ b/src/tensor_shape_assert/utils.py @@ -30,3 +30,12 @@ def check_if_dtype_matches(obj, kind, bits): raise TensorShapeAssertError( f"Dtype '{obj.dtype}' does not support bit checking." ) + +def is_typing_namedtuple_instance(x) -> bool: + t = type(x) + return ( + isinstance(x, tuple) + and hasattr(t, "_fields") + and isinstance(getattr(t, "_fields", None), tuple) + and hasattr(t, "__annotations__") # typing.NamedTuple gives annotations + ) \ No newline at end of file diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index 993f08c..2b7f0b5 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -4,7 +4,10 @@ import warnings from contextlib import contextmanager -from .utils import TensorShapeAssertError, check_if_dtype_matches +from .utils import ( + TensorShapeAssertError, + check_if_dtype_matches +) from .descriptor import ( descriptor_to_variables, split_to_descriptor_items, @@ -211,6 +214,7 @@ def check_tensor_shapes( ints_to_variables: bool = True, experimental_enable_autogen_constraints: bool = False, check_mode: CheckMode | None = None, + include_outer_variables: bool | None = None ): """ Enables tensor checking for the decorated function. @@ -238,6 +242,12 @@ def check_tensor_shapes( check_mode : CheckMode, optional The check mode to use for this function. If not specified, the global check mode is used. See ``set_global_check_mode`` for details. + include_outer_variables : bool, optional + If ``True``, variables defined in outer functions wrapped with + ``check_tensor_shapes`` will also be considered when checking the + current function. This allows to define variables in outer functions + and use them in inner functions. Default is ``False`` for functions, + ``True`` for NamedTuple instances. """ if constraints is None: @@ -299,23 +309,56 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): bindings.apply_defaults() bound_arguments = dict(bindings.arguments) - # check input type hints + # get variables... + + variables: VariablesType = dict() + + # ...from outer function + + if include_outer_variables and len(_current_variables_stack) > 0: + # include variables from outer function + variables.update(_current_variables_stack[-1]) + + add_assignment_trace( + name="", + annotation="", + shape=(), + assignments=variables + ) + + # ...from ints if ints_to_variables: - variables: VariablesType = { + int_variables: VariablesType = { k: v for k, v in bound_arguments.items() if type(v) is int } - if len(variables) > 0: + # check for collisions with outer variables + + for k in int_variables.keys(): + if k in variables and variables[k] != int_variables[k]: + raise VariableConstraintError( + f"Cannot assign integer parameter '{k}' to " + f"shape variable as it is already defined " + f"in the outer function with a different " + f"value ({variables[k]} != " + f"{int_variables[k]})." + ) + + # add to variables + + variables.update(int_variables) + + if len(int_variables) > 0: add_assignment_trace( name="", - annotation="int", + annotation="", shape=(), - assignments=variables - ) - else: - variables: VariablesType = dict() + assignments=int_variables + ) + + # run checks for inputs for key, parameter in signature.parameters.items(): try: diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 0f21c32..b54c424 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1478,3 +1478,130 @@ def test(x: MyInputTuple) -> MyOutputTuple: "| | result : (n) -> shape (5,) => {'n': 5}", record_str ) + + +class TestKeepingOuterVariables(unittest.TestCase): + def test_inner_function_sees_collision_if_enabled(self): + + @check_tensor_shapes(include_outer_variables=True) + def inner(y: ShapedTensor["a 2"]): + return y + + @check_tensor_shapes() + def outer(x: ShapedTensor["a b 2"]): + return inner(x[0]) + + # this works, because a == b + outer(xp.zeros((3, 3, 2))) + + # this shouldn't work + with self.assertRaises(TensorShapeAssertError): + outer(xp.zeros((3, 4, 2))) + + + def test_inner_function_does_not_raise_if_disabled(self): + + @check_tensor_shapes() + def inner(y: ShapedTensor["a 2"]): + return y + + @check_tensor_shapes() + def outer(x: ShapedTensor["a b 2"]): + return inner(x[0]) + + # this works, because a == b + outer(xp.zeros((3, 3, 2))) + + # this also works, because inner can't see 'a' + outer(xp.zeros((3, 4, 2))) + + + def test_inner_function_can_see_outer_variable_if_enabled(self): + + @check_tensor_shapes(include_outer_variables=True) + def inner(y: ShapedTensor["a 2"]): + # this should return the value of 'b' from outer + self.assertEqual(get_shape_variables("b")[0], 4) + return y + + @check_tensor_shapes() + def outer(x: ShapedTensor["a b 2"]): + return inner(x[:, 0]) + + outer(xp.zeros((3, 4, 2))) + + def test_inner_function_can_not_see_outer_variable_if_disabled(self): + + @check_tensor_shapes() + def inner(y: ShapedTensor["a 2"]): + # this should return None, as 'b' is not visible here + self.assertIsNone(get_shape_variables("b")[0]) + return y + + @check_tensor_shapes() + def outer(x: ShapedTensor["a b 2"]): + return inner(x[:, 0]) + + outer(xp.zeros((3, 4, 2))) + + def test_inner_function_can_not_see_more_than_one_outer_function(self): + + @check_tensor_shapes(include_outer_variables=True) + def inner(z: ShapedTensor["2"]): + self.assertEqual(get_shape_variables("b")[0], 4) + self.assertIsNone(get_shape_variables("a")[0]) + return z + + @check_tensor_shapes() + def middle(y: ShapedTensor["b 2"]): + self.assertIsNone(get_shape_variables("a")[0]) + return inner(y[0]) + + @check_tensor_shapes() + def outer(x: ShapedTensor["a b 2"]): + return middle(x[0]) + + outer(xp.zeros((3, 4, 2))) + + def test_namedtuple_as_return_type_if_enabled(self): + @check_tensor_shapes(include_outer_variables=True) + class MyTupleEnabled(NamedTuple): + p: ShapedTensor["a 2"] + q: ShapedTensor["b 2"] + + @check_tensor_shapes() + def test(x: ShapedTensor["a b 2"]) -> MyTupleEnabled: + return MyTupleEnabled( + p=x[0], + q=x[:, 0] + ) + + # works, because a == b + test(xp.zeros((3, 3, 2))) + + # doesn't work, because a != b + with self.assertRaises(TensorShapeAssertError): + test(xp.zeros((3, 4, 2))) + + + def test_namedtuple_as_return_type_if_disabled(self): + @check_tensor_shapes(include_outer_variables=False) + class MyTupleDisabled(NamedTuple): + p: ShapedTensor["a 2"] + q: ShapedTensor["b 2"] + + @check_tensor_shapes() + def test(x: ShapedTensor["a b 2"]) -> MyTupleDisabled: + return MyTupleDisabled( + p=x[0], + q=x[:, 0] + ) + + # works, because a == b + test(xp.zeros((3, 3, 2))) + + # should also work, because outer variables are not visible in the + # return type check + test(xp.zeros((3, 4, 2))) + + \ No newline at end of file From 7a18d0ad980cdd539f600989fa9ea714a999f65f Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Mon, 2 Feb 2026 15:48:32 +0000 Subject: [PATCH 2/2] fixed tests to run with ndonnx --- src/tensor_shape_assert_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index b54c424..39e0a3f 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1432,7 +1432,7 @@ def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor] self.assertIn("h (defined at", record_str) self.assertIn(", stack index: 0, call index: 0", record_str) - self.assertIn("| : (int) -> shape () => {'n': 2}", record_str) + self.assertIn("| : () -> shape () => {'n': 2}", record_str) self.assertIn("| x : (a b n) -> shape (3, 4, 2) => {'n': 2, 'a': 3, 'b': 4}", record_str) self.assertIn("| ", record_str) self.assertIn("| g (defined at", record_str) @@ -1489,7 +1489,7 @@ def inner(y: ShapedTensor["a 2"]): @check_tensor_shapes() def outer(x: ShapedTensor["a b 2"]): - return inner(x[0]) + return inner(x[0, ...]) # this works, because a == b outer(xp.zeros((3, 3, 2))) @@ -1507,7 +1507,7 @@ def inner(y: ShapedTensor["a 2"]): @check_tensor_shapes() def outer(x: ShapedTensor["a b 2"]): - return inner(x[0]) + return inner(x[0, ...]) # this works, because a == b outer(xp.zeros((3, 3, 2))) @@ -1526,7 +1526,7 @@ def inner(y: ShapedTensor["a 2"]): @check_tensor_shapes() def outer(x: ShapedTensor["a b 2"]): - return inner(x[:, 0]) + return inner(x[:, 0, :]) outer(xp.zeros((3, 4, 2))) @@ -1540,7 +1540,7 @@ def inner(y: ShapedTensor["a 2"]): @check_tensor_shapes() def outer(x: ShapedTensor["a b 2"]): - return inner(x[:, 0]) + return inner(x[:, 0, :]) outer(xp.zeros((3, 4, 2))) @@ -1555,12 +1555,12 @@ def inner(z: ShapedTensor["2"]): @check_tensor_shapes() def middle(y: ShapedTensor["b 2"]): self.assertIsNone(get_shape_variables("a")[0]) - return inner(y[0]) + return inner(y[0, :]) @check_tensor_shapes() def outer(x: ShapedTensor["a b 2"]): - return middle(x[0]) - + return middle(x[0, ...]) + outer(xp.zeros((3, 4, 2))) def test_namedtuple_as_return_type_if_enabled(self): @@ -1572,8 +1572,8 @@ class MyTupleEnabled(NamedTuple): @check_tensor_shapes() def test(x: ShapedTensor["a b 2"]) -> MyTupleEnabled: return MyTupleEnabled( - p=x[0], - q=x[:, 0] + p=x[0, ...], + q=x[:, 0, :] ) # works, because a == b @@ -1593,8 +1593,8 @@ class MyTupleDisabled(NamedTuple): @check_tensor_shapes() def test(x: ShapedTensor["a b 2"]) -> MyTupleDisabled: return MyTupleDisabled( - p=x[0], - q=x[:, 0] + p=x[0, ...], + q=x[:, 0, :] ) # works, because a == b