Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/tensor_shape_assert/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@



def unroll_iterable_annotation(annotation, obj):
def unroll_iterable_annotation(annotation, obj, disable_union_warning: bool):
if isinstance(annotation, (ShapeDescriptor, OptionalShapeDescriptor)):
yield annotation, obj

elif isinstance(annotation, types.GenericAlias):
# try to infer how annotation maps to iterable
sub_annotations = None
sub_annotations: list[Any] | tuple[Any] | None = None

if annotation.__origin__ == tuple:
sub_annotations = annotation.__args__
Expand Down Expand Up @@ -70,11 +70,11 @@ def unroll_iterable_annotation(annotation, obj):
)

for sub_ann, sub_obj in zip(sub_annotations, obj):
yield from unroll_iterable_annotation(sub_ann, sub_obj)
yield from unroll_iterable_annotation(sub_ann, sub_obj, disable_union_warning)

elif isinstance(annotation, types.UnionType):
for arg in annotation.__args__:
if isinstance(arg, types.GenericAlias):
if isinstance(arg, types.GenericAlias) and not disable_union_warning:
warnings.warn(RuntimeWarning(
"You used a union type in a function to be checked by "
"tensor_shape_assert. check_tensor_shapes currently does "
Expand All @@ -88,8 +88,14 @@ def unroll_iterable_annotation(annotation, obj):
))


def check_iterable(annotation: Any, obj: Any, variables: VariablesType, name: str) -> VariablesType:
for descriptor, obj in unroll_iterable_annotation(annotation, obj):
def check_iterable(
annotation: Any,
obj: Any,
variables: VariablesType,
name: str,
disable_union_warning: bool
) -> VariablesType:
for descriptor, obj in unroll_iterable_annotation(annotation, obj, disable_union_warning):

# skip if its optional and obj is None
if isinstance(descriptor, OptionalShapeDescriptor) and obj is None:
Expand Down Expand Up @@ -214,7 +220,8 @@ def check_tensor_shapes(
ints_to_variables: bool = True,
experimental_enable_autogen_constraints: bool = False,
check_mode: CheckMode | None = None,
include_outer_variables: bool | None = None
include_outer_variables: bool | None = None,
disable_union_warning: bool = False
):
"""
Enables tensor checking for the decorated function.
Expand Down Expand Up @@ -248,6 +255,8 @@ def check_tensor_shapes(
current function. This allows to define variables in outer functions
and use them in inner functions. Default is ``False`` for functions,
``True`` for NamedTuple instances.
disable_union_warning : bool, optional
If ``True``, the warning about limited support for union types is disabled.
"""

if constraints is None:
Expand Down Expand Up @@ -366,7 +375,8 @@ def _check_tensor_shapes_wrapper(*args, **kwargs):
annotation=parameter.annotation,
obj=bound_arguments[key],
variables=variables,
name=key
name=key,
disable_union_warning=disable_union_warning
)

except TensorShapeAssertError as e:
Expand Down Expand Up @@ -409,7 +419,8 @@ def _check_tensor_shapes_wrapper(*args, **kwargs):
annotation=signature.return_annotation,
obj=return_value,
variables=variables,
name="<return>"
name="<return>",
disable_union_warning=disable_union_warning
)
except TensorShapeAssertError as e:
# wrap exception to provide location info (output)
Expand Down
17 changes: 16 additions & 1 deletion src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,22 @@ def test(x: ShapedTensor["a 1"] | None) -> tuple[ShapedTensor["1"]] | None:
return (xp.zeros(1), )

with self.assertWarns(RuntimeWarning):
test(xp.zeros((2, 1)))
test(xp.zeros((2, 1)))

def test_dont_warn_optional_output_tuple_if_disabled(self):
@check_tensor_shapes(disable_union_warning=True)
def test(x: ShapedTensor["a 1"] | None) -> tuple[ShapedTensor["1"]] | None:
if x is not None:
return x[0, :]
else:
return (xp.zeros(1), )

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
test(xp.zeros((2, 1))) # should not trigger warning
self.assertEqual(len(w), 0)



class TestVariableConstraints(unittest.TestCase):
def test_lambda_constraints(self):
Expand Down