Merged
3 changes: 2 additions & 1 deletion .gitignore
@@ -4,4 +4,5 @@ build
**/.nox
**/.pytest_cache
.vscode
**/.mypy_cache
benchmark_results.csv
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1 @@
repos: []
107 changes: 61 additions & 46 deletions README.md
@@ -7,9 +7,45 @@ A minimal utility library that
* supports any array class that exposes a ``shape`` property,
* and a lot more!

## Installation

Currently, the package can only be installed directly from the repository with
```bash
pip install git+https://github.com/leifvan/tensor-shape-assert
```

## Usage

Decorate functions with ``@check_tensor_shapes()``, and any parameter annotated with ``ShapedTensor[<desc>]`` will be checked dynamically for the correct shape. A shape descriptor is a *string* of space-separated length descriptions, one per dimension. The return value can also be annotated in the same way.

### Simple example

```python
import torch
from tensor_shape_assert import check_tensor_shapes, ShapedTensor

@check_tensor_shapes()
def my_simple_func(
x: ShapedTensor["a b 3"],
y: ShapedTensor["b 2"]
) -> ShapedTensor["a"]:

z = x[:, :, :2] + y[None]
return (z[:, :, 0] * z[:, :, 1]).sum(dim=1)
```

Calling it like this
```python
my_simple_func(torch.zeros(5, 4, 3), y=torch.zeros(4, 2)) # works
```
passes the test, because ``a=5`` and ``b=4`` match both the input and output annotations.

For
```python
my_simple_func(torch.zeros(5, 4, 3), y=torch.zeros(4, 3)) # fails
```
the test fails, because `y` is expected to have length 2 in the second dimension.
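Conceptually, the check matches each dimension against its descriptor entry and binds a shape variable the first time it appears. The following is a minimal sketch of that idea (illustrative only, not the library's actual implementation):

```python
# Minimal sketch of descriptor matching: fixed sizes must match exactly,
# variables are bound on first use and must stay consistent afterwards.
def match_shape(shape, descriptor, bindings=None):
    """Return variable bindings if ``shape`` matches ``descriptor``, else None."""
    bindings = {} if bindings is None else dict(bindings)
    dims = descriptor.split()
    if len(shape) != len(dims):
        return None  # wrong number of dimensions
    for size, dim in zip(shape, dims):
        if dim.isdigit():
            if size != int(dim):
                return None  # fixed size mismatch
        elif bindings.setdefault(dim, size) != size:
            return None  # variable already bound to a different size
    return bindings

bindings = match_shape((5, 4, 3), "a b 3")  # binds a=5, b=4
match_shape((4, 2), "b 2", bindings)        # consistent: b is already 4
match_shape((4, 3), "b 2", bindings)        # inconsistent: second dim must be 2
```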

### Integers

* Sizes can be defined explicitly as an integer, e.g. ``"5 3"`` (only arrays of shape ``(5, 3)`` are valid)
@@ -44,51 +80,11 @@ There are convenience functions that access the current states of the shape variables

You can even go one step further and check tensors inside the wrapped function directly with ``assert_shape_here(x, <desc>)``, which runs a check on the object or shape ``x`` against the descriptor and adds previously unseen variables from the descriptor to the state inside the wrapped function. This way you can check the output of the function against tensor shapes that only appear in the body of the function.
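The in-body check can be pictured as a helper that validates a shape against a descriptor and extends the current set of bindings. This is a rough sketch under the assumption that the bindings are a plain dict; the names are illustrative, not the library's internals:

```python
# Illustrative sketch of in-body checking; the real assert_shape_here
# additionally looks up the wrapped function's variable state for you.
def check_and_bind(shape, descriptor, bindings):
    for size, dim in zip(shape, descriptor.split()):
        if dim.isdigit():
            assert size == int(dim), f"expected {dim}, got {size}"
        else:
            bound = bindings.setdefault(dim, size)  # first use binds the variable
            assert bound == size, f"{dim} is {bound} but got {size}"

bindings = {"n": 4}
check_and_bind((4, 7), "n m", bindings)  # n matches, m is newly bound
assert bindings["m"] == 7
```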

## Compatibility

While the examples in this README use PyTorch, *tensor-shape-assert* requires very minimal functionality and is compatible with any array class that has a ``shape`` attribute, which includes popular frameworks such as NumPy, TensorFlow, JAX and, more generally, frameworks that conform to the [Python array API standard](https://data-apis.org/array-api/latest/).
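Because the requirement is just a ``shape`` attribute, even a hand-rolled stand-in satisfies the interface — a minimal sketch with a made-up class (``FakeArray`` is hypothetical, not part of any framework):

```python
# Any object exposing a ``shape`` tuple provides everything a shape
# check needs to read; no other array functionality is required.
class FakeArray:
    def __init__(self, *shape):
        self.shape = shape

def shape_of(x):
    """Read the shape the same way for NumPy, torch, or any duck-typed array."""
    return tuple(x.shape)

assert shape_of(FakeArray(5, 3)) == (5, 3)
```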

---
## More Examples

The complex example additionally contains tuple and optional annotations.
```python
@@ -186,6 +182,25 @@ def my_func(x: ShapedTensor["n k"], k: int):
my_func(torch.zeros(10, 2), k=2) # works
my_func(torch.zeros(10, 2), k=3) # works
```
---
### Type safety (new in version 0.3)

If you are using static type checkers like MyPy, you can use the more verbose but type-safe literal syntax. For this, you also have to specify the array type. You can either do this manually with ``ShapedLiteral`` or use the predefined aliases ``ShapedTorchLiteral`` and ``ShapedNumpyLiteral``. The example below shows both options for PyTorch.

```python
from typing import Literal as L

@check_tensor_shapes()
def my_simple_func(
x: ShapedTorchLiteral[L["a b 3"]],
y: ShapedTorchLiteral[L["b 2"]]
) -> ShapedLiteral[torch.Tensor, L["a"]]:

z = x[:, :, :2] + y[None]
return (z[:, :, 0] * z[:, :, 1]).sum(dim=1)
```

Another benefit of using the typed version is that tooltips in VS Code are more helpful, as they can pass through the ``Literal`` string. This way you can check the annotated shape without having to open the file containing the annotated code.
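The reason the shape string survives for tooling is that a ``Literal`` annotation keeps it recoverable at runtime via ``typing.get_args`` — a small demonstration, independent of the library:

```python
from typing import Literal, get_args, get_origin

desc = Literal["a b 3"]

# a Literal can be detected and its string payload extracted at runtime
assert get_origin(desc) is Literal
assert " ".join(get_args(desc)) == "a b 3"
```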

## Known bugs
* [ ] ``get_shape_variables`` does not work if checks are disabled. This should be possible, but emit a performance warning recommending not to use this feature in performance-critical applications.
@@ -208,19 +223,19 @@ reraise
* [x] donnx
* [x] sparse
* ~~[ ] cupy~~ (we leave this out for now because it requires CUDA)
* [x] check compatibility with static type checkers
* [ ] rewrite README to give a cleaner overview over the features
* [ ] support union of shape descriptors (but this might break the current simplicity)
* [ ] benchmark speed to understand impact in tight loops
* [x] compatibility for torch.compile (or at least auto-disable check)
* ~~[ ] device annotation~~ (device definition not standardized in Python array API 2024.12, see [this section of the specifications](https://data-apis.org/array-api/2024.12/design_topics/device_support.html#device-support))
* [ ] add variable names for dtype
* [ ] add more helpful message when parameter / output is not the expected type
* [x] (maybe instead of the one before) catch all exceptions inside the
wrapper and reraise with additional info about the exception location
* [x] improve hints for static type checking (currently it assumes either torch
or the object just having a ``.shape`` attribute)
* [ ] come up with a way to allow union of shape descriptors
* [ ] make get_shape_variables work in check modes "never" and "once" without
performance overhead
* [ ] work on performance overhead in general

Binary file added assets/benchmark_overhead_additional_args_0.png
Binary file added assets/benchmark_overhead_additional_args_10.png
Binary file added assets/benchmark_overhead_additional_args_100.png
1 change: 0 additions & 1 deletion noxfile.py
@@ -41,4 +41,3 @@ def typecheck(session):
session.install("numpy", "torch")
session.install("mypy")
session.run("mypy", "src")

2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@

setup(
name="tensor-shape-assert",
version="0.2.6",
version="0.3.0",
description="A simple runtime assert library for tensor-based frameworks.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
137 changes: 86 additions & 51 deletions speedtest.py
@@ -1,61 +1,96 @@
from src.tensor_shape_assert import ShapedTensor, check_tensor_shapes, set_global_check_mode
import torch
from src.tensor_shape_assert import ShapedTensor, check_tensor_shapes
from test_utils import get_library_by_name, NAME_LIBRARY_MAP
from time import time
from typing import NamedTuple
from tqdm import tqdm
from tabulate import tabulate


def benchmark(f, num_runs, num_additonal_args):
start_time = time()

add = [torch.zeros(10, 5) for _ in range(num_additonal_args)]

for _ in range(num_runs):
f(
torch.zeros((10, 20, 30)),
torch.zeros((20, 2)),
*add
)
return (time() - start_time) / num_runs

def func(x: ShapedTensor["a b c"], y: ShapedTensor["b d"], *args: tuple[ShapedTensor["10 5"], ...]) -> ShapedTensor["a d"]:
k = sum(args)
z = x[..., None] * y[None, :, None] # a b c d
return z.sum(dim=(1, 2)) # a d

if __name__ == "__main__":

num_runs = 10000
for num_add_args in (0, 10, 100):
def benchmark(f, num_additonal_args, xp):

func_always = check_tensor_shapes(check_mode="always")(func)
func_once = check_tensor_shapes(check_mode="once")(func)
func_never = check_tensor_shapes(check_mode="never")(func)
# prepare args
x = xp.zeros((10, 20, 30))
y = xp.zeros((20, 2))
add = [xp.zeros((10, 5)) for _ in range(num_additonal_args)]

# full shape checking
# some warmup
for _ in range(2 ** 7):
f(x, y, *add)

duration_with = benchmark(func_always, num_runs=num_runs, num_additonal_args=num_add_args)

# check once

set_global_check_mode("once")

duration_global_check_once = benchmark(func_once, num_runs=num_runs, num_additonal_args=num_add_args)

# disabled global checking

set_global_check_mode("never")

duration_global_check_never = benchmark(func_never, num_runs=num_runs, num_additonal_args=num_add_args)
# actual benchmark
start_time = time()
for i in tqdm(range(2 ** 16)):
f(x, y, *add)

# no annotations
if i % 500 == 0:
if time() - start_time > 5:
break
return (time() - start_time) / (i + 1) # type: ignore

duration_not_annotated = benchmark(func, num_runs=num_runs, num_additonal_args=num_add_args)
def func(
x: ShapedTensor["a b c"],
y: ShapedTensor["b d"],
*args: tuple[ShapedTensor["10 5"], ...]
) -> ShapedTensor["a d"]:
z = x[:, :, :, None] * y[None, :, None]
return z.sum(axis=(1, 2)) # type: ignore

# print results and compute percentage

print(f"\nBenchmarking with {num_add_args} additional arguments:")
print(f"Duration without annotations: {duration_not_annotated*1000:.4f} ms, 100.00%")
print(f"Duration with global check never: {duration_global_check_never*1000:.4f} ms, {duration_global_check_never/duration_not_annotated*100:5.2f}%")
print(f"Duration with global check once: {duration_global_check_once*1000:.4f} ms, {duration_global_check_once/duration_not_annotated*100:5.2f}%")
print(f"Duration with shape checking: {duration_with*1000:.4f} ms, {duration_with/duration_not_annotated*100:5.2f}%")
print("-" * 50)
if __name__ == "__main__":
func_always = check_tensor_shapes(check_mode="always")(func)
func_once = check_tensor_shapes(check_mode="once")(func)
func_never = check_tensor_shapes(check_mode="never")(func)

results = []

def add_results(lib, num_add_args, mode, duration):

if mode == "not_annotated":
duration_not_annotated = duration
else:
duration_not_annotated = None
for r in results:
if r["library"] == lib and r["additional args"] == num_add_args and r["check mode"] == "not_annotated":
duration_not_annotated = r["duration (ms)"] / 1000
break

results.append({
"library": lib,
"additional args": num_add_args,
"check mode": mode,
"duration (ms)": duration * 1000,
"overhead (ms)": (duration - duration_not_annotated) * 1000,
"relative (%)": duration / duration_not_annotated
})

for lib in NAME_LIBRARY_MAP.keys():
try:
xp = get_library_by_name(lib)
except ModuleNotFoundError:
print("skipping", lib, ", not installed")
continue

try:
for num_add_args in (0, 10, 100):

add_results(lib, num_add_args, "not_annotated", benchmark(func, num_additonal_args=num_add_args, xp=xp))
add_results(lib, num_add_args, "never", benchmark(func_never, num_additonal_args=num_add_args, xp=xp))
add_results(lib, num_add_args, "once", benchmark(func_once, num_additonal_args=num_add_args, xp=xp))
add_results(lib, num_add_args, "always", benchmark(func_always, num_additonal_args=num_add_args, xp=xp))

# if lib == "torch":
# import torch
# add_results("torch-compile", num_add_args, "not_annotated", benchmark(torch.compile(func), num_additonal_args=num_add_args, xp=xp))
# add_results("torch-compile", num_add_args, "never", benchmark(torch.compile(func_never), num_additonal_args=num_add_args, xp=xp))
# add_results("torch-compile", num_add_args, "once", benchmark(torch.compile(func_once), num_additonal_args=num_add_args, xp=xp))
# add_results("torch-compile", num_add_args, "always", benchmark(torch.compile(func_always), num_additonal_args=num_add_args, xp=xp))

print(tabulate(results, tablefmt="github", floatfmt=(None, None, None, ".5f", ".4e", ".2%"))) # type: ignore
except Exception as e:
print("Error benchmarking", lib, ":", e)

# write results to csv file

with open("benchmark_results.csv", "w") as f:
f.write("library,additional args,check mode,duration (ms),relative (%),overhead (ms)\n")
for r in results:
f.write(f"{r['library']},{r['additional args']},{r['check mode']},{r['duration (ms)']:.5f},{r['relative (%)']:.2%},{r['overhead (ms)']:.10f}\n")
149 changes: 149 additions & 0 deletions speedtest_plot.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/tensor_shape_assert/__init__.py
@@ -2,5 +2,5 @@
check_tensor_shapes, get_shape_variables, assert_shape_here,
set_global_check_mode
)
from .types import ShapedTensor, ShapedTorchLiteral, ShapedNumpyLiteral
from .types import ShapedTensor, ShapedTorchLiteral, ShapedNumpyLiteral, ShapedLiteral
from .types import ScalarTensor # type: ignore
13 changes: 11 additions & 2 deletions src/tensor_shape_assert/types.py
@@ -118,6 +118,15 @@ def __init__(self, *args, **kwargs):
)

def __class_getitem__(cls, key):
# check if it is a tuple annotation
if isinstance(key, tuple):
if len(key) != 2:
raise TypeError(
"ShapedTensor can only be parameterized with a single "
"shape descriptor string or a tuple of (type, shape)."
)
key = key[1]

# check if it is a literal
if get_origin(key) is Literal:
key = " ".join(get_args(key))
@@ -138,7 +147,7 @@ def __class_getitem__(cls, key):
# torch

try:
from array_api_compat import torch
import torch
ShapedTorchLiteral = TypeAliasType(
'ShapedTorchLiteral',
ShapedLiteral[torch.Tensor, S],
@@ -150,7 +159,7 @@ def __class_getitem__(cls, key):
# numpy

try:
from array_api_compat import numpy
import numpy
ShapedNumpyLiteral = TypeAliasType(
'ShapedNumpyLiteral',
ShapedLiteral[numpy.ndarray, S],