diff --git a/setup.py b/setup.py index b10c172..13e9d18 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="tensor-shape-assert", - version="0.3.1", + version="0.3.2", description="A simple runtime assert library for tensor-based frameworks.", long_description=long_description, long_description_content_type="text/markdown", diff --git a/src/tensor_shape_assert/trace.py b/src/tensor_shape_assert/trace.py index 3220ac2..ac2193c 100644 --- a/src/tensor_shape_assert/trace.py +++ b/src/tensor_shape_assert/trace.py @@ -43,11 +43,17 @@ def add_function_trace(fn): if not _trace_enabled: return - fn_code = inspect.getsourcefile(fn), inspect.getsourcelines(fn)[1] + try: + source_file = inspect.getsourcefile(fn) + source_line = inspect.getsourcelines(fn)[1] + except OSError: + source_file = f"{fn.__module__}.{fn.__qualname__}" + source_line = -1 + trace_record = TracedFunctionCall( function_name=fn.__name__, - file=fn_code[0], - line=fn_code[1], + file=source_file, + line=source_line, stack_index=len(_trace_stack), call_index=len(_trace_records) ) @@ -99,24 +105,27 @@ def finalize_function_trace(): def start_trace_recording(): global _trace_enabled _trace_enabled = True + _trace_records.clear() + _trace_stack.clear() def stop_trace_recording() -> list[TraceRecord]: global _trace_enabled _trace_enabled = False records = _trace_records.copy() _trace_records.clear() + _trace_stack.clear() return records def trace_records_to_string(records: list[TraceRecord]) -> str: lines = [] - cur_stack_size = -1 + mentioned_calls = set() + for record in records: indentation = "| " * record.function.stack_index - if record.function.stack_index > cur_stack_size: + if record.function not in mentioned_calls: lines.append(f"{indentation}\n{indentation}{record.function}") - cur_stack_size = record.function.stack_index + mentioned_calls.add(record.function) lines.append(f"{indentation}| {record.assignment}") - - return "\n".join(lines) \ No newline at end of file + return "\n".join(lines) diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 7743443..0f21c32 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1409,7 +1409,7 @@ def test(x: ShapedTensor["n"], info: tuple[int, int]) -> ShapedTensor["n"]: class TestTraceLogging(unittest.TestCase): - def test_trace_logging_prints_to_stdout(self): + def test_trace_logging_example(self): @check_tensor_shapes() def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]: @@ -1445,4 +1445,36 @@ def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor] self.assertIn("| | | : (a n) -> shape (3, 2) => {'a': 3, 'b': 4, 'n': 2}", record_str) self.assertIn("| | : (a) -> shape (3,) => {'a': 3, 'b': 4}", record_str) self.assertIn("| : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str) - self.assertIn("| : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str) \ No newline at end of file + self.assertIn("| : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str) + + def test_tracing_namedtuples(self): + @check_tensor_shapes() + class MyInputTuple(NamedTuple): + p: ShapedTensor["n m"] + q: ShapedTensor["m 1"] + + @check_tensor_shapes() + class MyOutputTuple(NamedTuple): + result: ShapedTensor["n"] + + @check_tensor_shapes() + def test(x: MyInputTuple) -> MyOutputTuple: + return MyOutputTuple(result=(x.p @ x.q)[:, 0]) + + start_trace_recording() + test(MyInputTuple( + p=xp.zeros((5, 4)), + q=xp.zeros((4, 1)) + )) + records = stop_trace_recording() + record_str = trace_records_to_string(records) + + self.assertIn( + "__new__ (defined at namedtuple_MyInputTuple.MyInputTuple.__new__:-1), stack index: 0, call index: 0\n" + "| p : (n m) -> shape (5, 4) => {'n': 5, 'm': 4}\n" + "| q : (m 1) -> shape (4, 1) => {'n': 5, 'm': 4}\n" + "| \n" + "| __new__ (defined at namedtuple_MyOutputTuple.MyOutputTuple.__new__:-1), stack index: 1, call index: 2\n" + "| | result : (n) -> shape (5,) => {'n': 5}", + record_str + ) diff --git a/tracetest.py b/tracetest.py index 4c2b96c..f074a2b 100644 --- a/tracetest.py +++ b/tracetest.py @@ -3,8 +3,15 @@ check_tensor_shapes, ShapedTensor, ScalarTensor, start_trace_recording, stop_trace_recording, trace_records_to_string ) +from typing import NamedTuple if __name__ == "__main__": + + @check_tensor_shapes() + class Result(NamedTuple): + mean: ScalarTensor + var: ScalarTensor + @check_tensor_shapes() def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]: return x.sum(dim=1) @@ -15,9 +22,9 @@ def g(x: ShapedTensor["a b 2"]) -> ShapedTensor["a"]: return y[:, 0] * y[:, 1] @check_tensor_shapes() - def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]: + def h(x: ShapedTensor["a b n"], n: int = 2) -> Result: y = g(x) - return y.mean(), y.var() + return Result(mean=y.mean(), var=y.var()) start_trace_recording() h(torch.randn(3, 4, 2)) @@ -36,3 +43,25 @@ def rec(x: ShapedTensor["b m m"], n: int) -> tuple[ShapedTensor["b m m"], int]: rec(torch.randn(2, 3, 3), 10) records = stop_trace_recording() print(trace_records_to_string(records)) + + @check_tensor_shapes() + class MyInputTuple(NamedTuple): + p: ShapedTensor["n m"] + q: ShapedTensor["m 1"] + + @check_tensor_shapes() + class MyOutputTuple(NamedTuple): + result: ShapedTensor["n"] + + @check_tensor_shapes() + def test(x: MyInputTuple) -> MyOutputTuple: + return MyOutputTuple(result=(x.p @ x.q)[:, 0]) + + start_trace_recording() + test(MyInputTuple( + p=torch.zeros((5, 4)), + q=torch.zeros((4, 1)) + )) + records = stop_trace_recording() + record_str = trace_records_to_string(records) + print(record_str) \ No newline at end of file