Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions models/inception_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,18 @@
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from timm.models.registry import register_model
from torch.nn.init import _calculate_fan_in_and_fan_out
import math
import warnings
from timm.models.layers.helpers import to_2tuple
try:
from timm.models import register_model
except:
from timm.models.registry import register_model

try:
from timm.layers.helpers import to_2tuple
except:
from timm.models.layers.helpers import to_2tuple


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -361,7 +368,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule

self.patch_embed = FirstPatchEmbed(in_chans=in_chans, embed_dim=embed_dims[0])
self.num_patches1 = num_patches = img_size // 4
self.num_patches1 = num_patches = 224 // 4
self.pos_embed1 = nn.Parameter(torch.zeros(1, num_patches, num_patches, embed_dims[0]))
self.blocks1 = nn.Sequential(*[
Block(
Expand Down