Is it possible to pass a model a torch.Tensor and return a torch.Tensor? #1848

@billbrod

I am trying to figure out how to use brainscore's models with plenoptic. Plenoptic requires a model that accepts a 4d torch.Tensor of images (shape batch, channel, height, width) and returns a 3d or 4d torch.Tensor of images (in a way that torch knows how to autodiff). I cannot figure out how to interact with brainscore's models in that way. Is there a way to do so?

Here is what I've tried:

import brainscore_vision
import torch
model = brainscore_vision.load_model("alexnet_training_seed_01")
img = torch.rand(1,3,256,256)
rep = model.activations_model.get_activations(img, ["features.0"])["features.0"]

rep is then a numpy array (the shape is fine), not a torch tensor. Looking at model_helpers/activations/pytorch.py, it looks like this is because _tensor_to_numpy is registered as a hook -- is it possible to disable this hook?
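
For comparison, here is a minimal sketch of the behavior I'm after, written in plain PyTorch rather than against brainscore's API: a forward hook that stores the layer output as-is instead of converting it to numpy. `HookedFeatureModel` and `ToyNet` are hypothetical names used only for illustration, and the sketch assumes the underlying `nn.Module` is directly accessible, which I have not confirmed for brainscore's wrappers.

```python
import torch
from torch import nn


class HookedFeatureModel(nn.Module):
    """Return one intermediate layer's output as a torch.Tensor (hypothetical helper)."""

    def __init__(self, model: nn.Module, layer_name: str):
        super().__init__()
        self.model = model
        self._captured = None
        # get_submodule resolves dotted names such as "features.0"
        layer = model.get_submodule(layer_name)
        layer.register_forward_hook(self._store_output)

    def _store_output(self, module, inputs, output):
        # Keep the tensor untouched: no .detach(), no .cpu().numpy()
        self._captured = output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self._captured = None
        self.model(x)  # the hook fires during this call
        return self._captured


class ToyNet(nn.Module):
    """Toy stand-in for a real network, laid out so that "features.0" exists."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.ReLU())

    def forward(self, x):
        return self.features(x).mean(dim=(1, 2, 3))


wrapped = HookedFeatureModel(ToyNet().eval(), "features.0")
img = torch.rand(1, 3, 32, 32, requires_grad=True)
rep = wrapped(img)    # 4d torch.Tensor, still attached to the autograd graph
rep.sum().backward()  # gradients flow back to img
```

Because the hook never detaches the output, `rep` stays differentiable with respect to `img`, which is what plenoptic needs.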

As an example of what I'm talking about, here's how one would prepare a model from timm to use with plenoptic, with an optional image transform and specific layers (using torchvision's feature extractor):

from typing import Callable, Optional

import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torch import nn
from torchvision.models.feature_extraction import create_feature_extractor

class IntermediateOutputResnet(nn.Module):
    def __init__(self, model: nn.Module, return_node: str, transform: Optional[Callable] = None):
        super().__init__()
        self.return_node = return_node
        self.extractor = create_feature_extractor(model, return_nodes=[return_node])
        self.model = model
        self.transform = transform

    def forward(self, x):
        if self.transform is not None:
            x = self.transform(x)
        return self.extractor(x)[self.return_node]


model = timm.create_model("hf-hub:nateraw/resnet50-oxford-iiit-pet", pretrained=True)
model.eval()
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
test_model = IntermediateOutputResnet(model, "layer2", transform)

test_model(torch.rand(1,3,256,256))
# returns a 4d tensor
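
The requirement described at the top (4d tensor in, 3d or 4d tensor out, differentiable) can be checked with a short helper. This is only a sketch: `check_tensor_in_tensor_out` is a hypothetical name, not part of plenoptic or brainscore, and the toy model stands in for a real one so the snippet runs without downloading weights.

```python
import torch
from torch import nn


def check_tensor_in_tensor_out(model: nn.Module, image_shape=(1, 3, 64, 64)) -> torch.Size:
    """Raise AssertionError if the model breaks the tensor-in/tensor-out requirement."""
    img = torch.rand(*image_shape, requires_grad=True)
    out = model(img)
    assert isinstance(out, torch.Tensor), f"expected torch.Tensor, got {type(out)}"
    assert out.ndim in (3, 4), f"expected 3d or 4d output, got {out.ndim}d"
    # A detached output (e.g. one rebuilt from a numpy array) fails here
    assert out.requires_grad, "output is not attached to the autograd graph"
    out.sum().backward()
    assert img.grad is not None, "no gradient reached the input"
    return out.shape


# Toy stand-in for a real model
toy_model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=5), nn.ReLU()).eval()
out_shape = check_tensor_in_tensor_out(toy_model)
```

Any wrapper that passes this check, whether built around timm or brainscore, should be usable in the way described above.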
