From 974842471125d44e542ad1101e8cd0a589e16339 Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Wed, 11 Feb 2026 14:06:09 +0000 Subject: [PATCH 1/2] added flag to suppress warning about unions --- src/tensor_shape_assert/wrapper.py | 27 +++++++++++++++++++-------- src/tensor_shape_assert_test.py | 17 ++++++++++++++++- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index 9c91ce1..afa2773 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -31,7 +31,7 @@ -def unroll_iterable_annotation(annotation, obj): +def unroll_iterable_annotation(annotation, obj, disable_union_warning: bool): if isinstance(annotation, (ShapeDescriptor, OptionalShapeDescriptor)): yield annotation, obj @@ -70,9 +70,9 @@ def unroll_iterable_annotation(annotation, obj): ) for sub_ann, sub_obj in zip(sub_annotations, obj): - yield from unroll_iterable_annotation(sub_ann, sub_obj) + yield from unroll_iterable_annotation(sub_ann, sub_obj, disable_union_warning) - elif isinstance(annotation, types.UnionType): + elif isinstance(annotation, types.UnionType) and not disable_union_warning: for arg in annotation.__args__: if isinstance(arg, types.GenericAlias): warnings.warn(RuntimeWarning( @@ -88,8 +88,14 @@ def unroll_iterable_annotation(annotation, obj): )) -def check_iterable(annotation: Any, obj: Any, variables: VariablesType, name: str) -> VariablesType: - for descriptor, obj in unroll_iterable_annotation(annotation, obj): +def check_iterable( + annotation: Any, + obj: Any, + variables: VariablesType, + name: str, + disable_union_warning: bool +) -> VariablesType: + for descriptor, obj in unroll_iterable_annotation(annotation, obj, disable_union_warning): # skip if its optional and obj is None if isinstance(descriptor, OptionalShapeDescriptor) and obj is None: @@ -214,7 +220,8 @@ def check_tensor_shapes( ints_to_variables: bool = True, experimental_enable_autogen_constraints: bool = False, check_mode: CheckMode | None = None, - include_outer_variables: bool | None = None + include_outer_variables: bool | None = None, + disable_union_warning: bool = False ): """ Enables tensor checking for the decorated function. @@ -248,6 +255,8 @@ def check_tensor_shapes( current function. This allows to define variables in outer functions and use them in inner functions. Default is ``False`` for functions, ``True`` for NamedTuple instances. + disable_union_warning : bool, optional + If ``True``, the warning about limited support for union types is disabled. """ if constraints is None: @@ -366,7 +375,8 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): annotation=parameter.annotation, obj=bound_arguments[key], variables=variables, - name=key + name=key, + disable_union_warning=disable_union_warning ) except TensorShapeAssertError as e: @@ -409,7 +419,8 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): annotation=signature.return_annotation, obj=return_value, variables=variables, - name="" + name="", + disable_union_warning=disable_union_warning ) except TensorShapeAssertError as e: # wrap exception to provide location info (output) diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 6f4b805..eb4a06e 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -864,7 +864,22 @@ def test(x: ShapedTensor["a 1"] | None) -> tuple[ShapedTensor["1"]] | None: return (xp.zeros(1), ) with self.assertWarns(RuntimeWarning): - test(xp.zeros((2, 1))) + test(xp.zeros((2, 1))) + + def test_dont_warn_optional_output_tuple_if_disabled(self): + @check_tensor_shapes(disable_union_warning=True) + def test(x: ShapedTensor["a 1"] | None) -> tuple[ShapedTensor["1"]] | None: + if x is not None: + return x[0, :] + else: + return (xp.zeros(1), ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + test(xp.zeros((2, 1))) # should not trigger warning + self.assertEqual(len(w), 0) + + class TestVariableConstraints(unittest.TestCase): def test_lambda_constraints(self): From 9f611657b905b72b5fc9dfc281ed044fca89203f Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Wed, 11 Feb 2026 14:09:15 +0000 Subject: [PATCH 2/2] fixed typing issue --- src/tensor_shape_assert/wrapper.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index afa2773..fc97dc7 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -37,7 +37,7 @@ def unroll_iterable_annotation(annotation, obj, disable_union_warning: bool): elif isinstance(annotation, types.GenericAlias): # try to infer how annotation maps to iterable - sub_annotations = None + sub_annotations: list[Any] | tuple[Any] | None = None if annotation.__origin__ == tuple: sub_annotations = annotation.__args__ @@ -72,9 +72,9 @@ def unroll_iterable_annotation(annotation, obj, disable_union_warning: bool): for sub_ann, sub_obj in zip(sub_annotations, obj): yield from unroll_iterable_annotation(sub_ann, sub_obj, disable_union_warning) - elif isinstance(annotation, types.UnionType) and not disable_union_warning: + elif isinstance(annotation, types.UnionType): for arg in annotation.__args__: - if isinstance(arg, types.GenericAlias): + if isinstance(arg, types.GenericAlias) and not disable_union_warning: warnings.warn(RuntimeWarning( "You used a union type in a function to be checked by " "tensor_shape_assert. check_tensor_shapes currently does "