From d965293cc4150e20717bb564892ee975f08e8cba Mon Sep 17 00:00:00 2001 From: Martin Ahrnbom Date: Sat, 12 Jun 2021 18:27:47 +0200 Subject: [PATCH] Update pytorch_builder.py Fixed shapes for PyTorch plots. See https://github.com/waleedka/hiddenlayer/issues/83 for more info. --- hiddenlayer/pytorch_builder.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..39eeccc 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -53,13 +53,7 @@ def get_shape(torch_node): # https://discuss.pytorch.org/t/node-output-shape-from-trace-graph/24351/2 # TODO: find a better way to extract output shape # TODO: Assuming the node has one output. Update if we encounter a multi-output node. - m = re.match(r".*Float\(([\d\s\,]+)\).*", str(next(torch_node.outputs()))) - if m: - shape = m.group(1) - shape = shape.split(",") - shape = tuple(map(int, shape)) - else: - shape = None + shape = torch_node.output().type().sizes() return shape