Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="tensor-shape-assert",
version="0.3.1",
version="0.3.2",
description="A simple runtime assert library for tensor-based frameworks.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
25 changes: 17 additions & 8 deletions src/tensor_shape_assert/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,17 @@ def add_function_trace(fn):
if not _trace_enabled:
return

fn_code = inspect.getsourcefile(fn), inspect.getsourcelines(fn)[1]
try:
source_file = inspect.getsourcefile(fn)
source_line = inspect.getsourcelines(fn)[1]
except OSError:
source_file = f"{fn.__module__}.{fn.__qualname__}"
source_line = -1

trace_record = TracedFunctionCall(
function_name=fn.__name__,
file=fn_code[0],
line=fn_code[1],
file=source_file,
line=source_line,
stack_index=len(_trace_stack),
call_index=len(_trace_records)
)
Expand Down Expand Up @@ -99,24 +105,27 @@ def finalize_function_trace():
def start_trace_recording():
global _trace_enabled
_trace_enabled = True
_trace_records.clear()
_trace_stack.clear()

def stop_trace_recording() -> list[TraceRecord]:
global _trace_enabled
_trace_enabled = False
records = _trace_records.copy()
_trace_records.clear()
_trace_stack.clear()
return records

def trace_records_to_string(records: list[TraceRecord]) -> str:
lines = []
cur_stack_size = -1
mentioned_calls = set()

for record in records:
indentation = "| " * record.function.stack_index

if record.function.stack_index > cur_stack_size:
if record.function not in mentioned_calls:
lines.append(f"{indentation}\n{indentation}{record.function}")
cur_stack_size = record.function.stack_index
mentioned_calls.add(record.function)

lines.append(f"{indentation}| {record.assignment}")

return "\n".join(lines)
return "\n".join(lines)
36 changes: 34 additions & 2 deletions src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,7 +1409,7 @@ def test(x: ShapedTensor["n"], info: tuple[int, int]) -> ShapedTensor["n"]:


class TestTraceLogging(unittest.TestCase):
def test_trace_logging_prints_to_stdout(self):
def test_trace_logging_example(self):

@check_tensor_shapes()
def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]:
Expand Down Expand Up @@ -1445,4 +1445,36 @@ def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]
self.assertIn("| | | <return> : (a n) -> shape (3, 2) => {'a': 3, 'b': 4, 'n': 2}", record_str)
self.assertIn("| | <return> : (a) -> shape (3,) => {'a': 3, 'b': 4}", record_str)
self.assertIn("| <return> : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str)
self.assertIn("| <return> : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str)
self.assertIn("| <return> : () -> shape () => {'n': 2, 'a': 3, 'b': 4}", record_str)

def test_tracing_namedtuples(self):
@check_tensor_shapes()
class MyInputTuple(NamedTuple):
p: ShapedTensor["n m"]
q: ShapedTensor["m 1"]

@check_tensor_shapes()
class MyOutputTuple(NamedTuple):
result: ShapedTensor["n"]

@check_tensor_shapes()
def test(x: MyInputTuple) -> MyOutputTuple:
return MyOutputTuple(result=(x.p @ x.q)[:, 0])

start_trace_recording()
test(MyInputTuple(
p=xp.zeros((5, 4)),
q=xp.zeros((4, 1))
))
records = stop_trace_recording()
record_str = trace_records_to_string(records)

self.assertIn(
"__new__ (defined at namedtuple_MyInputTuple.MyInputTuple.__new__:-1), stack index: 0, call index: 0\n"
"| p : (n m) -> shape (5, 4) => {'n': 5, 'm': 4}\n"
"| q : (m 1) -> shape (4, 1) => {'n': 5, 'm': 4}\n"
"| \n"
"| __new__ (defined at namedtuple_MyOutputTuple.MyOutputTuple.__new__:-1), stack index: 1, call index: 2\n"
"| | result : (n) -> shape (5,) => {'n': 5}",
record_str
)
33 changes: 31 additions & 2 deletions tracetest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
check_tensor_shapes, ShapedTensor, ScalarTensor, start_trace_recording,
stop_trace_recording, trace_records_to_string
)
from typing import NamedTuple

if __name__ == "__main__":

@check_tensor_shapes()
class Result(NamedTuple):
mean: ScalarTensor
var: ScalarTensor

@check_tensor_shapes()
def f(x: ShapedTensor["a b n"]) -> ShapedTensor["a n"]:
return x.sum(dim=1)
Expand All @@ -15,9 +22,9 @@ def g(x: ShapedTensor["a b 2"]) -> ShapedTensor["a"]:
return y[:, 0] * y[:, 1]

@check_tensor_shapes()
def h(x: ShapedTensor["a b n"], n: int = 2) -> tuple[ScalarTensor, ScalarTensor]:
def h(x: ShapedTensor["a b n"], n: int = 2) -> Result:
y = g(x)
return y.mean(), y.var()
return Result(mean=y.mean(), var=y.var())

start_trace_recording()
h(torch.randn(3, 4, 2))
Expand All @@ -36,3 +43,25 @@ def rec(x: ShapedTensor["b m m"], n: int) -> tuple[ShapedTensor["b m m"], int]:
rec(torch.randn(2, 3, 3), 10)
records = stop_trace_recording()
print(trace_records_to_string(records))

@check_tensor_shapes()
class MyInputTuple(NamedTuple):
p: ShapedTensor["n m"]
q: ShapedTensor["m 1"]

@check_tensor_shapes()
class MyOutputTuple(NamedTuple):
result: ShapedTensor["n"]

@check_tensor_shapes()
def test(x: MyInputTuple) -> MyOutputTuple:
return MyOutputTuple(result=(x.p @ x.q)[:, 0])

start_trace_recording()
test(MyInputTuple(
p=torch.zeros((5, 4)),
q=torch.zeros((4, 1))
))
records = stop_trace_recording()
record_str = trace_records_to_string(records)
print(record_str)