From 10cef1807d892036d9644e30d791ee254a0e8ae1 Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Thu, 29 Jan 2026 10:44:50 +0000 Subject: [PATCH] add rudimentary tracing for debugging --- noxfile.py | 2 - setup.py | 2 +- src/tensor_shape_assert/__init__.py | 18 +++- src/tensor_shape_assert/descriptor.py | 2 + src/tensor_shape_assert/trace.py | 122 +++++++++++++++++++++++ src/tensor_shape_assert/types.py | 10 ++ src/tensor_shape_assert/wrapper.py | 90 +++++++++++++---- src/tensor_shape_assert_test.py | 72 ++++++++++--- src/tensor_shape_assert_typesafe_test.py | 4 - tracetest.py | 38 +++++++ 10 files changed, 313 insertions(+), 47 deletions(-) create mode 100644 src/tensor_shape_assert/trace.py create mode 100644 tracetest.py diff --git a/noxfile.py b/noxfile.py index ed760c2..009b7f5 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,6 +1,4 @@ import nox -import sys -# v = sys.version.split(" ")[0] # Reuse environments to speed things up locally (optional) nox.options.reuse_venv = "yes" diff --git a/setup.py b/setup.py index b394d6e..b10c172 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="tensor-shape-assert", - version="0.3.0", + version="0.3.1", description="A simple runtime assert library for tensor-based frameworks.", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/tensor_shape_assert/__init__.py b/src/tensor_shape_assert/__init__.py index 434d47c..b2486ed 100644 --- a/src/tensor_shape_assert/__init__.py +++ b/src/tensor_shape_assert/__init__.py @@ -1,6 +1,18 @@ from .wrapper import ( - check_tensor_shapes, get_shape_variables, assert_shape_here, + check_tensor_shapes, + get_shape_variables, + assert_shape_here, set_global_check_mode ) -from .types import ShapedTensor, ShapedTorchLiteral, ShapedNumpyLiteral, ShapedLiteral -from .types import ScalarTensor # type: ignore \ No newline at end of file +from .types import ( + ShapedTensor, + ShapedTorchLiteral, + ShapedNumpyLiteral, + ShapedLiteral +) 
+from .types import ScalarTensor # type: ignore +from .trace import ( + start_trace_recording, + stop_trace_recording, + trace_records_to_string +) \ No newline at end of file diff --git a/src/tensor_shape_assert/descriptor.py b/src/tensor_shape_assert/descriptor.py index 930165e..482a689 100644 --- a/src/tensor_shape_assert/descriptor.py +++ b/src/tensor_shape_assert/descriptor.py @@ -150,6 +150,8 @@ def descriptor_to_variables(shape_descriptor, shape, variables=None): ) if resulting_value is not None and not isinstance(desc_item, int): + if not isinstance(resulting_value, int): + resulting_value = tuple(resulting_value) variables[desc_item] = resulting_value return variables diff --git a/src/tensor_shape_assert/trace.py b/src/tensor_shape_assert/trace.py new file mode 100644 index 0000000..3220ac2 --- /dev/null +++ b/src/tensor_shape_assert/trace.py @@ -0,0 +1,122 @@ +import inspect + +from typing import NamedTuple +from .types import VariablesType + + +class TracedVariableAssignment(NamedTuple): + name: str | None + annotation: str | None + shape: tuple[int, ...] 
+ assignments: VariablesType + + def __str__(self) -> str: + return ( + f"{self.name} : ({self.annotation}) -> shape {self.shape} => {self.assignments}" + ) + +class TracedFunctionCall(NamedTuple): + function_name: str | None + file: str | None + line: int + stack_index: int + call_index: int + + def __str__(self) -> str: + return ( + f"{self.function_name} (defined at {self.file}:{self.line}), " + f"stack index: {self.stack_index}, call index: {self.call_index}" + ) + +class TraceRecord(NamedTuple): + # function metadata + function: TracedFunctionCall + assignment: TracedVariableAssignment + + +_trace_stack: list[TracedFunctionCall] = [] +_trace_records: list[TraceRecord] = [] +_trace_enabled: bool = False + +def add_function_trace(fn): + global _trace_enabled + if not _trace_enabled: + return + + fn_code = inspect.getsourcefile(fn), inspect.getsourcelines(fn)[1] + trace_record = TracedFunctionCall( + function_name=fn.__name__, + file=fn_code[0], + line=fn_code[1], + stack_index=len(_trace_stack), + call_index=len(_trace_records) + ) + _trace_stack.append(trace_record) + +def add_assignment_trace( + name: str | None, + annotation: str | None, + shape: tuple[int, ...], + assignments: VariablesType + ): + global _trace_enabled + if not _trace_enabled: + return + + if len(_trace_stack) == 0: + raise RuntimeError( + "Internal error: Tried to add assignment trace without an active " + "function trace." + ) + + function_trace = _trace_stack[-1] + _trace_records.append( + TraceRecord( + function=function_trace, + assignment=TracedVariableAssignment( + name=name, + annotation=annotation, + shape=shape, + assignments=assignments.copy() + ) + ) + ) + +def finalize_function_trace(): + global _trace_enabled + if not _trace_enabled: + return + + if len(_trace_stack) == 0: + raise RuntimeError( + "Internal error: Tried to finalize function trace without an active " + "function trace." 
+ ) + + _trace_stack.pop() + + +def start_trace_recording(): + global _trace_enabled + _trace_enabled = True + +def stop_trace_recording() -> list[TraceRecord]: + global _trace_enabled + _trace_enabled = False + records = _trace_records.copy() + _trace_records.clear() + return records + +def trace_records_to_string(records: list[TraceRecord]) -> str: + lines = [] + cur_stack_size = -1 + for record in records: + indentation = "| " * record.function.stack_index + + if record.function.stack_index > cur_stack_size: + lines.append(f"{indentation}\n{indentation}{record.function}") + cur_stack_size = record.function.stack_index + + lines.append(f"{indentation}| {record.assignment}") + + return "\n".join(lines) \ No newline at end of file diff --git a/src/tensor_shape_assert/types.py b/src/tensor_shape_assert/types.py index 6ce315f..07d5e26 100644 --- a/src/tensor_shape_assert/types.py +++ b/src/tensor_shape_assert/types.py @@ -19,6 +19,8 @@ clean_up_descriptor ) +VariablesType = dict[str, tuple[int] | int] + # define str subclasses to identify shape descriptors _NAME_TO_KIND = { @@ -144,6 +146,13 @@ def __class_getitem__(cls, key): type_params=(T, S) ) + # TODO: this can be made more useful by using a library-specific scalar type + ScalarTensor = TypeAliasType( + 'ScalarTensor', + ShapedLiteral[float, Literal[""]], + type_params=() + ) + # torch try: @@ -167,6 +176,7 @@ def __class_getitem__(cls, key): ) except ImportError: pass + else: ShapedLiteral = ShapedTensor diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index 358c981..993f08c 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -1,7 +1,8 @@ import types import inspect -from typing import Any, Callable, Literal, get_args +from typing import Any, Callable, Literal, get_args, NamedTuple import warnings +from contextlib import contextmanager from .utils import TensorShapeAssertError, check_if_dtype_matches from .descriptor import ( @@ -10,7 
+11,7 @@ clean_up_descriptor ) -from .types import ShapeDescriptor, OptionalShapeDescriptor +from .types import ShapeDescriptor, OptionalShapeDescriptor, VariablesType from .errors import ( AnnotationMatchingError, DtypeConstraintError, @@ -19,6 +20,14 @@ NoVariableContextExistsError ) +from .trace import ( + add_function_trace, + add_assignment_trace, + finalize_function_trace, +) + + + def unroll_iterable_annotation(annotation, obj): if isinstance(annotation, (ShapeDescriptor, OptionalShapeDescriptor)): yield annotation, obj @@ -76,7 +85,7 @@ def unroll_iterable_annotation(annotation, obj): )) -def check_iterable(annotation, obj, variables): +def check_iterable(annotation: Any, obj: Any, variables: VariablesType, name: str) -> VariablesType: for descriptor, obj in unroll_iterable_annotation(annotation, obj): # skip if its optional and obj is None @@ -103,17 +112,26 @@ def check_iterable(annotation, obj, variables): raise NotImplementedError( "Device checks are not implemented yet." ) + + add_assignment_trace( + name=name, + annotation=str(descriptor), + shape=tuple(obj.shape), + assignments=variables + ) return variables # define a module level stack for currently declared variables -_current_variables_stack = [] +_current_variables_stack: list[VariablesType] = [] # module level check mode CheckMode = Literal["always", "once", "never"] _global_check_mode: CheckMode = "always" _checked_functions: set = set() + + def assert_valid_check_mode(mode: CheckMode): if mode is not None and mode not in get_args(CheckMode): raise TensorShapeAssertError(f"Invalid check mode: '{mode}'") @@ -134,21 +152,32 @@ def set_global_check_mode(mode: CheckMode): _global_check_mode = mode + + + +# def print_if_trace(msg: str, has_vars: bool = False): +# if _trace_mode == "enabled": +# # infer stack size by looking at current variables stack +# stack_size = len(_current_variables_stack) +# if not has_vars: +# stack_size += 1 +# print(">", " " * 4 * stack_size, msg) + def 
run_expression_constraint( expression: str, - variables: dict[str, int] + variables: dict[str, tuple[int] | int] ) -> bool: if "=" not in expression: assert False if "==" not in expression: expression = expression.replace("=", "==") - exec_globals = {'__builtins__': {}, **variables} + exec_globals: dict[str, Any] = {'__builtins__': {}, **variables} return eval(expression, exec_globals) def check_constraints( - constraints: list[Callable[[dict[str, int]], bool] | str], - variables: dict[str, int], + constraints: list[Callable[[VariablesType], bool] | str], + variables: VariablesType, skip_on_error: bool ): for i, constraint_fn in enumerate(constraints): @@ -178,7 +207,7 @@ def check_constraints( def check_tensor_shapes( fn_or_cls = None, *, - constraints: list[str | Callable[[dict[str, int]], bool]] | None = None, + constraints: list[str | Callable[[VariablesType], bool]] | None = None, ints_to_variables: bool = True, experimental_enable_autogen_constraints: bool = False, check_mode: CheckMode | None = None, @@ -225,6 +254,8 @@ def _make_check_tensor_shapes_wrapper(fn): def _check_tensor_shapes_wrapper(*args, **kwargs): + add_function_trace(fn) + # get check mode _check_mode = _global_check_mode if check_mode is None else check_mode @@ -260,7 +291,6 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): )) return fn(*args, **kwargs) - _checked_functions.add(signature) # bind parameters @@ -272,17 +302,30 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): # check input type hints if ints_to_variables: - variables = {k: v for k, v in bound_arguments.items() if type(v) is int} + variables: VariablesType = { + k: v for k, v in bound_arguments.items() + if type(v) is int + } + + if len(variables) > 0: + add_assignment_trace( + name="", + annotation="int", + shape=(), + assignments=variables + ) else: - variables = dict() + variables: VariablesType = dict() for key, parameter in signature.parameters.items(): try: variables = check_iterable( 
annotation=parameter.annotation, obj=bound_arguments[key], - variables=variables + variables=variables, + name=key ) + except TensorShapeAssertError as e: # wrap exception to provide location info (input) raise TensorShapeAssertError( @@ -318,10 +361,12 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): # check outputs try: - check_iterable( + # TODO check if it's wrong that we update variables here again + variables = check_iterable( annotation=signature.return_annotation, obj=return_value, - variables=variables + variables=variables, + name="" ) except TensorShapeAssertError as e: # wrap exception to provide location info (output) @@ -342,9 +387,10 @@ def _check_tensor_shapes_wrapper(*args, **kwargs): temp_variables[temp_replace_name] = temp_replace_val temp_constraint = f"{temp_replace_name} == {k}" check_constraints([temp_constraint], temp_variables, skip_on_error=False) - + # remove vars from stack _current_variables_stack.pop() + finalize_function_trace() # return @@ -389,7 +435,7 @@ check_if_context_is_available(): "here." ) -def get_shape_variables(names: str) -> tuple[int, ...]: +def get_shape_variables(names: str) -> tuple[tuple[int] | int | None, ...]: """ Returns the inferred values of the tensor shape variables of the innermost function wrapped with check_tensor_shapes. @@ -402,7 +448,7 @@ def get_shape_variables(names: str) -> tuple[int, ...]: Returns ------- - tuple[int] + tuple[int | tuple[int] | None, ...] A tuple of integers representing the inferred values of the variables given in ``names``.
""" @@ -415,10 +461,10 @@ def get_shape_variables(names: str) -> tuple[int, ...]: else: var_names = (*front, mdd, *back) - values = tuple(_current_variables_stack[-1].get(name, None) for name in var_names) - if len(values) == 1: - return values[0] - return values + return tuple( + _current_variables_stack[-1].get(str(name), None) + for name in var_names + ) def assert_shape_here(obj_or_shape: Any, descriptor: str) -> None: """ diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index a74a59d..7743443 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -12,21 +12,25 @@ from multiprocessing import Queue from typing import Callable, NamedTuple, TYPE_CHECKING -from tensor_shape_assert.errors import ( - MalformedDescriptorError, - UnionTypeUnsupportedError, - CheckDisabledWarning, -) -from tensor_shape_assert.types import ( - ShapedTensor, - ScalarTensor, -) -from tensor_shape_assert.wrapper import ( +# import these from public module +from tensor_shape_assert import ( check_tensor_shapes, get_shape_variables, assert_shape_here, set_global_check_mode, + ShapedTensor, + ScalarTensor, + start_trace_recording, + stop_trace_recording, + trace_records_to_string +) + +from tensor_shape_assert.errors import ( + MalformedDescriptorError, + UnionTypeUnsupportedError, + CheckDisabledWarning, ) + from tensor_shape_assert.utils import TensorShapeAssertError from tensor_shape_assert.wrapper import ( NoVariableContextExistsError, @@ -625,7 +629,7 @@ def test_error_on_no_context(self): def test_unknown_variable_is_none(self): @check_tensor_shapes() def test(x: ShapedTensor["a b"]) -> ShapedTensor["a"]: - c = get_shape_variables("c") + c = get_shape_variables("c")[0] self.assertIsNone(c) return x.sum(axis=1) @@ -643,7 +647,7 @@ def test(x: ShapedTensor["a b"]) -> ShapedTensor["a"]: def test_state_does_not_collect_ints(self): @check_tensor_shapes() def test(x: ShapedTensor["a 2"]): - k = get_shape_variables("2") + k = 
get_shape_variables("2")[0] self.assertIsNone(k) return x @@ -652,7 +656,7 @@ def test(x: ShapedTensor["a 2"]): def test_state_has_batch_dimension(self): @check_tensor_shapes() def test1(x: ShapedTensor["... 4"]): - batch = get_shape_variables("...") + batch = get_shape_variables("...")[0] self.assertTupleEqual(batch, (1, 2, 3)) return x @@ -660,7 +664,7 @@ def test1(x: ShapedTensor["... 4"]): @check_tensor_shapes() def test2(x: ShapedTensor["4 ..."]): - batch = get_shape_variables("...") + batch = get_shape_variables("...")[0] self.assertTupleEqual(batch, (3, 2, 1)) return x @@ -668,7 +672,7 @@ def test2(x: ShapedTensor["4 ..."]): @check_tensor_shapes() def test3(x: ShapedTensor["1 ... 4"]): - batch = get_shape_variables("...") + batch = get_shape_variables("...")[0] self.assertTupleEqual(batch, (2, 3)) return x @@ -1404,3 +1408,41 @@ def test(x: ShapedTensor["n"], info: tuple[int, int]) -> ShapedTensor["n"]: test(x=xp.zeros((10, 5)), info=(42, 3.14)) +class TestTraceLogging(unittest.TestCase): + def test_trace_logging_prints_to_stdout(self): + + @check_tensor_shapes() + def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]: + return xp.sum(x, axis=1) + + @check_tensor_shapes() + def g(x: ShapedTensor["a b 2"]) -> ShapedTensor["a"]: + y = f(x) + return y[:, 0] * y[:, 1] + + @check_tensor_shapes() + def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]: + y = g(x) + return xp.mean(y), xp.var(y) + + start_trace_recording() + h(xp.zeros((3, 4, 2))) + records = stop_trace_recording() + record_str = trace_records_to_string(records) + + self.assertIn("h (defined at", record_str) + self.assertIn(", stack index: 0, call index: 0", record_str) + self.assertIn("| : (int) -> shape () => {'n': 2}", record_str) + self.assertIn("| x : (a b n) -> shape (3, 4, 2) => {'n': 2, 'a': 3, 'b': 4}", record_str) + self.assertIn("| ", record_str) + self.assertIn("| g (defined at", record_str) + self.assertIn(", stack index: 1, call index: 2", record_str) + 
self.assertIn("| | x : (a b 2) -> shape (3, 4, 2) => {'a': 3, 'b': 4}", record_str) + self.assertIn("| | ", record_str) + self.assertIn("| | f (defined at", record_str) + self.assertIn(", stack index: 2, call index: 3", record_str) + self.assertIn("| | | x : (a b n) -> shape (3, 4, 2) => {'a': 3, 'b': 4, 'n': 2}", record_str) + self.assertIn("| | | : (a n) -> shape (3, 2) => {'a': 3, 'b': 4, 'n': 2}", record_str) + self.assertIn("| | : (a) -> shape (3,) => {'a': 3, 'b': 4}", record_str) + self.assertIn("| : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str) + self.assertIn("| : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str) \ No newline at end of file diff --git a/src/tensor_shape_assert_typesafe_test.py b/src/tensor_shape_assert_typesafe_test.py index 22aacea..652f0eb 100644 --- a/src/tensor_shape_assert_typesafe_test.py +++ b/src/tensor_shape_assert_typesafe_test.py @@ -22,7 +22,6 @@ def test_literal_annotation_torch(self): if lib != "torch": self.skipTest("Skipping torch-specific test") - from typing_extensions import Literal as L import torch @check_tensor_shapes() @@ -39,7 +38,6 @@ def test_literal_annotation_torch_alias(self): if lib != "torch": self.skipTest("Skipping torch-specific test") - from typing_extensions import Literal as L import torch @check_tensor_shapes() @@ -56,7 +54,6 @@ def test_literal_annotation_numpy(self): if lib != "numpy": self.skipTest("Skipping numpy-specific test") - from typing_extensions import Literal as L import numpy as np @check_tensor_shapes() @@ -74,7 +71,6 @@ def test_literal_annotation_numpy_alias(self): if lib != "numpy": self.skipTest("Skipping numpy-specific test") - from typing_extensions import Literal as L import numpy as np @check_tensor_shapes() diff --git a/tracetest.py b/tracetest.py new file mode 100644 index 0000000..4c2b96c --- /dev/null +++ b/tracetest.py @@ -0,0 +1,38 @@ +import torch +from src.tensor_shape_assert import ( + check_tensor_shapes, ShapedTensor, ScalarTensor, 
start_trace_recording, + stop_trace_recording, trace_records_to_string +) + +if __name__ == "__main__": + @check_tensor_shapes() + def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]: + return x.sum(dim=1) + + @check_tensor_shapes() + def g(x: ShapedTensor["a b 2"]) -> ShapedTensor["a"]: + y = f(x) + return y[:, 0] * y[:, 1] + + @check_tensor_shapes() + def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]: + y = g(x) + return y.mean(), y.var() + + start_trace_recording() + h(torch.randn(3, 4, 2)) + + records = stop_trace_recording() + print(trace_records_to_string(records)) + + @check_tensor_shapes() + def rec(x: ShapedTensor["b m m"], n: int) -> tuple[ShapedTensor["b m m"], int]: + if n == 0: + return x, n + else: + return rec(x @ x, n - 1) + + start_trace_recording() + rec(torch.randn(2, 3, 3), 10) + records = stop_trace_recording() + print(trace_records_to_string(records))