diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..39eeccc 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -53,13 +53,7 @@ def get_shape(torch_node): # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 # TODO: find a better way to extract output shape # TODO: Assuming the node has one output. Update if we encounter a multi-output node. - m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) - if m: - shape = m.group(1) - shape = shape.split(",") - shape = tuple(map(int, shape)) - else: - shape = None + shape = torch_node.output().type().sizes() return shape