diff --git a/instanttensor/_impl.py b/instanttensor/_impl.py index 01bab8f..245b33e 100644 --- a/instanttensor/_impl.py +++ b/instanttensor/_impl.py @@ -554,7 +554,7 @@ def tensors(self) -> Generator[tuple[str, torch.Tensor], None, None]: tensor_size = get_tensor_size(shape, torch_dtype) dl_tensor = instanttensor._C.get_dl_tensor(self.loader_handle, tensor_index, tensor_size) # always returns int8 tensor tensor_int8 = torch.from_dlpack(dl_tensor) - tensor = tensor_int8.view(torch_dtype).view(*shape) + tensor = tensor_int8.view(torch_dtype).view(torch.Size(shape)) if tensor.data_ptr() % tensor.element_size() != 0: raise ValueError(f"Tensor {name} address {tensor.data_ptr():#x} is not aligned to dtype {torch_dtype} size {tensor.element_size()}B")