forked from qianqianwang68/caps
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_features_v2.py
More file actions
executable file
·86 lines (72 loc) · 3.07 KB
/
extract_features_v2.py
File metadata and controls
executable file
·86 lines (72 loc) · 3.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
from torch.utils.data import Dataset
import os
import numpy as np
import cv2
import skimage.io as io
import torchvision.transforms as transforms
import config
import math
from tqdm import tqdm
from CAPS.caps_model_v2 import CAPSModel
class HPatchDataset(Dataset):
    """Dataset yielding normalized images and SIFT keypoints for feature extraction.

    ``imdir`` is expected to contain one sub-directory per scene, each holding
    the images ``1.ppm`` .. ``6.ppm``. Every item is a dict with:

    * ``'im'``    -- the image as a normalized float tensor, padded on the
      bottom/right so that height and width are multiples of ``div_num``;
    * ``'coord'`` -- a float32 tensor of shape ``(N, 2)`` with the ``(x, y)``
      positions of the SIFT keypoints detected on the un-padded image
      (``N`` may be 0);
    * ``'imf'``   -- the path of the image file.
    """

    # Spatial size is padded up to a multiple of this value.
    # NOTE(review): presumably dictated by the downsampling of the network
    # that consumes the images -- confirm against CAPSModel.
    div_num = 16

    def __init__(self, imdir):
        """Collect the image paths of every scene directory under ``imdir``."""
        # Per-channel mean/std normalization (the widely used ImageNet statistics).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225)),
        ])
        self.imfs = self._list_image_files(imdir)

    @staticmethod
    def _list_image_files(imdir):
        """Return ``<scene>/1.ppm`` .. ``<scene>/6.ppm`` for each scene directory.

        Scenes are visited in sorted order so the dataset order is
        reproducible, and entries of ``imdir`` that are not directories are
        skipped instead of being turned into non-existent image paths.
        """
        imfs = []
        for name in sorted(os.listdir(imdir)):
            scene_dir = os.path.join(imdir, name)
            if not os.path.isdir(scene_dir):
                continue
            # Format only the file name: formatting the joined path would
            # misinterpret any '{' or '}' occurring in the directory part.
            imfs.extend(os.path.join(scene_dir, '{}.ppm'.format(ind))
                        for ind in range(1, 7))
        return imfs

    @staticmethod
    def _pad_to_multiple(im_tensor, div_num):
        """Reflection-pad a ``(C, H, W)`` tensor on the bottom/right.

        The result has height and width rounded up to the next multiple of
        ``div_num``. The input tensor is returned unchanged when both sizes
        are already multiples.
        """
        _, h, w = im_tensor.shape
        # (-x) % d is the distance from x up to the next multiple of d.
        bot_pad = -h % div_num
        rig_pad = -w % div_num
        if bot_pad == 0 and rig_pad == 0:
            return im_tensor
        pad = torch.nn.ReflectionPad2d((0, rig_pad, 0, bot_pad))
        # Pad through a temporary batch axis: older PyTorch releases only
        # accept batched (4-D) input for 2-D reflection padding.
        return pad(im_tensor.unsqueeze(0)).squeeze(0)

    @staticmethod
    def _keypoints_to_coord(kpts):
        """Convert keypoints (objects with a ``.pt`` pair) to an ``(N, 2)`` tensor.

        The explicit reshape keeps the ``(N, 2)`` layout even when no keypoint
        was detected, where a plain ``np.array([])`` would have shape ``(0,)``.
        """
        pts = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts], dtype=np.float32)
        return torch.from_numpy(pts.reshape(-1, 2))

    def __getitem__(self, item):
        """Load image ``item``, normalize and pad it, and detect SIFT keypoints."""
        imf = self.imfs[item]
        im = io.imread(imf)
        im_tensor = self._pad_to_multiple(self.transform(im), self.div_num)

        # Keypoints are detected on the original (un-padded) image; padding is
        # only added on the bottom/right, so the coordinates stay valid for
        # the padded tensor.
        sift = cv2.SIFT_create()
        gray = cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)
        coord = self._keypoints_to_coord(sift.detect(gray))

        return {'im': im_tensor, 'coord': coord, 'imf': imf}

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.imfs)
if __name__ == '__main__':
    # Example driver: extract descriptors at SIFT keypoints for every image of
    # an HPatches-style directory and store them next to the keypoints.
    args = config.get_args()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    # batch_size=1: each image has its own size and number of keypoints.
    loader = torch.utils.data.DataLoader(
        HPatchDataset(args.extract_img_dir),
        batch_size=1,
        shuffle=False,
        num_workers=args.workers,
    )
    model = CAPSModel(args)

    out_root = args.extract_out_dir
    os.makedirs(out_root, exist_ok=True)

    # Inference only, so gradient tracking is disabled for the whole loop.
    with torch.no_grad():
        for batch in tqdm(loader):
            image = batch['im'].to(device)
            image_path = batch['imf'][0]
            coords = batch['coord'].to(device)

            features = model.extract_features(image, coords)
            # Drop the leading batch axis (always 1 here) before saving.
            descriptors = features.squeeze(0).detach().cpu().numpy()
            keypoints = coords.cpu().numpy().squeeze(0)

            # Mirror the scene directory of the input image under out_root.
            scene_name = os.path.basename(os.path.dirname(image_path))
            scene_out = os.path.join(out_root, scene_name)
            os.makedirs(scene_out, exist_ok=True)

            image_name = os.path.basename(image_path)
            target = os.path.join(scene_out, "{}.caps_effiv3".format(image_name))
            # Writing through an open file handle keeps the exact file name.
            with open(target, 'wb') as fh:
                np.savez(fh, keypoints=keypoints, scores=[], descriptors=descriptors)