Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/tensor_shape_assert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,12 @@ def check_if_dtype_matches(obj, kind, bits):
raise TensorShapeAssertError(
f"Dtype '{obj.dtype}' does not support bit checking."
)

def is_typing_namedtuple_instance(x) -> bool:
t = type(x)
return (
isinstance(x, tuple)
and hasattr(t, "_fields")
and isinstance(getattr(t, "_fields", None), tuple)
and hasattr(t, "__annotations__") # typing.NamedTuple gives annotations
)
61 changes: 52 additions & 9 deletions src/tensor_shape_assert/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import warnings
from contextlib import contextmanager

from .utils import TensorShapeAssertError, check_if_dtype_matches
from .utils import (
TensorShapeAssertError,
check_if_dtype_matches
)
from .descriptor import (
descriptor_to_variables,
split_to_descriptor_items,
Expand Down Expand Up @@ -211,6 +214,7 @@ def check_tensor_shapes(
ints_to_variables: bool = True,
experimental_enable_autogen_constraints: bool = False,
check_mode: CheckMode | None = None,
include_outer_variables: bool | None = None
):
"""
Enables tensor checking for the decorated function.
Expand Down Expand Up @@ -238,6 +242,12 @@ def check_tensor_shapes(
check_mode : CheckMode, optional
The check mode to use for this function. If not specified, the global
check mode is used. See ``set_global_check_mode`` for details.
include_outer_variables : bool, optional
If ``True``, variables defined in outer functions wrapped with
``check_tensor_shapes`` will also be considered when checking the
current function. This allows to define variables in outer functions
and use them in inner functions. Default is ``False`` for functions,
``True`` for NamedTuple instances.
"""

if constraints is None:
Expand Down Expand Up @@ -299,23 +309,56 @@ def _check_tensor_shapes_wrapper(*args, **kwargs):
bindings.apply_defaults()
bound_arguments = dict(bindings.arguments)

# check input type hints
# get variables...

variables: VariablesType = dict()

# ...from outer function

if include_outer_variables and len(_current_variables_stack) > 0:
# include variables from outer function
variables.update(_current_variables_stack[-1])

add_assignment_trace(
name="<outer variables>",
annotation="",
shape=(),
assignments=variables
)

# ...from ints

if ints_to_variables:
variables: VariablesType = {
int_variables: VariablesType = {
k: v for k, v in bound_arguments.items()
if type(v) is int
}

if len(variables) > 0:
# check for collisions with outer variables

for k in int_variables.keys():
if k in variables and variables[k] != int_variables[k]:
raise VariableConstraintError(
f"Cannot assign integer parameter '{k}' to "
f"shape variable as it is already defined "
f"in the outer function with a different "
f"value ({variables[k]} != "
f"{int_variables[k]})."
)

# add to variables

variables.update(int_variables)

if len(int_variables) > 0:
add_assignment_trace(
name="<int variables>",
annotation="int",
annotation="",
shape=(),
assignments=variables
)
else:
variables: VariablesType = dict()
assignments=int_variables
)

# run checks for inputs

for key, parameter in signature.parameters.items():
try:
Expand Down
129 changes: 128 additions & 1 deletion src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,7 +1432,7 @@ def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]

self.assertIn("h (defined at", record_str)
self.assertIn(", stack index: 0, call index: 0", record_str)
self.assertIn("| <int variables> : (int) -> shape () => {'n': 2}", record_str)
self.assertIn("| <int variables> : () -> shape () => {'n': 2}", record_str)
self.assertIn("| x : (a b n) -> shape (3, 4, 2) => {'n': 2, 'a': 3, 'b': 4}", record_str)
self.assertIn("| ", record_str)
self.assertIn("| g (defined at", record_str)
Expand Down Expand Up @@ -1478,3 +1478,130 @@ def test(x: MyInputTuple) -> MyOutputTuple:
"| | result : (n) -> shape (5,) => {'n': 5}",
record_str
)


class TestKeepingOuterVariables(unittest.TestCase):
def test_inner_function_sees_collision_if_enabled(self):

@check_tensor_shapes(include_outer_variables=True)
def inner(y: ShapedTensor["a 2"]):
return y

@check_tensor_shapes()
def outer(x: ShapedTensor["a b 2"]):
return inner(x[0, ...])

# this works, because a == b
outer(xp.zeros((3, 3, 2)))

# this shouldn't work
with self.assertRaises(TensorShapeAssertError):
outer(xp.zeros((3, 4, 2)))


def test_inner_function_does_not_raise_if_disabled(self):

@check_tensor_shapes()
def inner(y: ShapedTensor["a 2"]):
return y

@check_tensor_shapes()
def outer(x: ShapedTensor["a b 2"]):
return inner(x[0, ...])

# this works, because a == b
outer(xp.zeros((3, 3, 2)))

# this also works, because inner can't see 'a'
outer(xp.zeros((3, 4, 2)))


def test_inner_function_can_see_outer_variable_if_enabled(self):

@check_tensor_shapes(include_outer_variables=True)
def inner(y: ShapedTensor["a 2"]):
# this should return the value of 'b' from outer
self.assertEqual(get_shape_variables("b")[0], 4)
return y

@check_tensor_shapes()
def outer(x: ShapedTensor["a b 2"]):
return inner(x[:, 0, :])

outer(xp.zeros((3, 4, 2)))

def test_inner_function_can_not_see_outer_variable_if_disabled(self):

@check_tensor_shapes()
def inner(y: ShapedTensor["a 2"]):
# this should return None, as 'b' is not visible here
self.assertIsNone(get_shape_variables("b")[0])
return y

@check_tensor_shapes()
def outer(x: ShapedTensor["a b 2"]):
return inner(x[:, 0, :])

outer(xp.zeros((3, 4, 2)))

def test_inner_function_can_not_see_more_than_one_outer_function(self):

@check_tensor_shapes(include_outer_variables=True)
def inner(z: ShapedTensor["2"]):
self.assertEqual(get_shape_variables("b")[0], 4)
self.assertIsNone(get_shape_variables("a")[0])
return z

@check_tensor_shapes()
def middle(y: ShapedTensor["b 2"]):
self.assertIsNone(get_shape_variables("a")[0])
return inner(y[0, :])

@check_tensor_shapes()
def outer(x: ShapedTensor["a b 2"]):
return middle(x[0, ...])

outer(xp.zeros((3, 4, 2)))

def test_namedtuple_as_return_type_if_enabled(self):
@check_tensor_shapes(include_outer_variables=True)
class MyTupleEnabled(NamedTuple):
p: ShapedTensor["a 2"]
q: ShapedTensor["b 2"]

@check_tensor_shapes()
def test(x: ShapedTensor["a b 2"]) -> MyTupleEnabled:
return MyTupleEnabled(
p=x[0, ...],
q=x[:, 0, :]
)

# works, because a == b
test(xp.zeros((3, 3, 2)))

# doesn't work, because a != b
with self.assertRaises(TensorShapeAssertError):
test(xp.zeros((3, 4, 2)))


def test_namedtuple_as_return_type_if_disabled(self):
@check_tensor_shapes(include_outer_variables=False)
class MyTupleDisabled(NamedTuple):
p: ShapedTensor["a 2"]
q: ShapedTensor["b 2"]

@check_tensor_shapes()
def test(x: ShapedTensor["a b 2"]) -> MyTupleDisabled:
return MyTupleDisabled(
p=x[0, ...],
q=x[:, 0, :]
)

# works, because a == b
test(xp.zeros((3, 3, 2)))

# should also work, because outer variables are not visible in the
# return type check
test(xp.zeros((3, 4, 2)))