diff --git a/models/trans_utils/transformer.py b/models/trans_utils/transformer.py
index 0d2a438..f4933a4 100644
--- a/models/trans_utils/transformer.py
+++ b/models/trans_utils/transformer.py
@@ -47,19 +47,18 @@ def _reset_parameters(self):
     def forward(self, src, mask, query_embed, pos_embed):
         # flatten NxCxHxW to HWxNxC
         bs, c, h, w = src.shape
-        src = src.flatten(2).permute(0, 2, 1)#(2, 0, 1)
-        pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
+        # src = src.flatten(2).permute(0, 2, 1) #(2, 0, 1)
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
         query_embed = query_embed.unsqueeze(1).repeat(bs, 1, 1)
-        mask = mask.flatten(1).permute(1,0)
-        #print('mask', mask.shape)
+        if mask is not None:
+            mask = mask.flatten(1)
 
 
         tgt = torch.zeros_like(query_embed)
         memory = self.encoder(src, src_key_padding_mask=mask, pos=None)
         hs = memory # for fast inference
-        #hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
-        #                  pos=pos_embed, query_pos=query_embed)
-        return hs.transpose(1, 2), memory.permute(0, 2, 1).view(bs, c, h, w)
+        return hs.transpose(1, 2), memory.permute(1, 2, 0).reshape(bs, c, h, w)
 
 
 class TransformerEncoder(nn.Module):