-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathStitcher.py
More file actions
308 lines (253 loc) · 12.5 KB
/
Stitcher.py
File metadata and controls
308 lines (253 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
__author__ = "Amos Decker"
__date__ = "Summer 2019"
"""
A working python stitcher for ir images. Uses visible light images to find the keypoints and do all of the adjustments,
but then swaps out the visible light images for ir (or mx) images.
stitch_fast() is a modified version of these:
https://raw.githubusercontent.com/opencv/opencv/master/samples/python/stitching_detailed.py
https://raw.githubusercontent.com/opencv/opencv/master/samples/cpp/stitching_detailed.cpp
stitch() relies on a modified opencv that can be found here: https://drive.google.com/drive/folders/1m4QldXdOSNnNFJIXDueIMOLjRZhhL6Ji?usp=sharing
"""
import numpy as np
import cv2
import time
import os
from datetime import datetime
def stitch_fast(data, use_kaze=False):
    """Stitch every image set in ``data`` into a panorama and write it to disk.

    Registration (features, matching, camera estimation, wave correction) is
    done ONCE using the FIRST image list in ``data`` (the visible-light set);
    the resulting warps are then reused for every other set (ir/mx), so all
    output panoramas are pixel-aligned.  Adapted from opencv's
    stitching_detailed.py sample.

    waaay faster than the other stitch() method. The results are slightly
    worse, but for one set of 45 images it took 33 seconds using this
    stitch_fast() compared to 607 seconds using stitch().

    :param data: list of (output_filename, list_of_images) tuples; data[0]
                 must be the visible-light set used for registration, and all
                 sets must have the same length and image dimensions
    :param use_kaze: use the KAZE feature detector instead of ORB
    """
    # --- tuning knobs (see stitching_detailed.py for the full option menu) ---
    use_gpu = False
    work_megapix = -1      # <0: register at full resolution
    seam_megapix = 0.1     # downscaled resolution used only for seam finding
    compose_megapix = -1   # <0: compose at full resolution
    wave_correct = "horiz"  # "vert"
    warp_type = "cylindrical"  # "spherical" #"mercator" #"cylindrical"
    match_conf = 0.3
    blend_type = "multiband"  # feather # multiband #for no blending at all put any other string, like "no"
    blend_strength = 5

    if use_kaze:
        finder = cv2.KAZE.create()
    else:
        finder = cv2.ORB.create()

    seam_work_aspect = 1
    features = []
    images = []
    print("getting image features and scaling images...")
    work_scale = -1
    seam_scale = -1
    for i in range(len(data[0][1])):
        full_img = data[0][1][i]
        if work_megapix < 0:
            img = full_img
            work_scale = 1
        else:
            if work_scale == -1:  # if it hasn't been set yet
                work_scale = min(1.0, np.sqrt(work_megapix * 1e6 / (full_img.shape[0] * full_img.shape[1])))
            img = cv2.resize(src=full_img, dsize=None, fx=work_scale, fy=work_scale,
                             interpolation=cv2.INTER_LINEAR_EXACT)
        if seam_scale == -1:  # if it hasn't been set yet
            seam_scale = min(1.0, np.sqrt(seam_megapix * 1e6 / (full_img.shape[0] * full_img.shape[1])))
            seam_work_aspect = seam_scale / work_scale
        features.append(cv2.detail.computeImageFeatures2(finder, img))  # gets image features
        images.append(cv2.resize(src=full_img, dsize=None, fx=seam_scale, fy=seam_scale, interpolation=cv2.INTER_LINEAR_EXACT))

    print("getting matches info...")
    matcher = cv2.detail.BestOf2NearestMatcher_create(use_gpu, match_conf)
    # setting the matching mask makes it a lot faster because it tells it the order of images:
    # https://software.intel.com/sites/default/files/Fast%20Panorama%20Stitching.pdf
    match_mask = np.zeros((len(features), len(features)), np.uint8)
    for i in range(len(data[0][1]) - 1):
        match_mask[i, i + 1] = 1  # each image is only matched against its successor
    matches_info = matcher.apply2(features, match_mask)
    matcher.collectGarbage()
    num_images = len(data[0][1])

    # get camera params
    print("finding camera params...")
    estimator = cv2.detail_HomographyBasedEstimator()
    b, cameras = estimator.apply(features, matches_info, None)
    if not b:
        print("Homography estimation failed.")
        exit()
    for cam in cameras:
        cam.R = cam.R.astype(np.float32)

    # adjust camera params
    print("adjusting camera params...")
    adjuster = cv2.detail_BundleAdjusterRay()
    adjuster.setConfThresh(1)
    b, cameras = adjuster.apply(features, matches_info, cameras)
    if not b:
        print("Camera parameters adjusting failed.")
        exit()

    # get warped image scale -- the median focal length over all cameras
    print("getting warped image scale...")
    focals = []
    for cam in cameras:
        focals.append(cam.focal)
    # BUGFIX: was `sorted(focals)` with the return value discarded, so the
    # "median" below was taken from an UNSORTED list.  The upstream opencv
    # sample sorts in place; do the same.
    focals.sort()
    if len(focals) % 2 == 1:
        warped_image_scale = focals[len(focals) // 2]
    else:
        warped_image_scale = (focals[len(focals) // 2] + focals[len(focals) // 2 - 1]) / 2

    # wave correct. see section 5 of this paper: http://matthewalunbrown.com/papers/ijcv2007.pdf
    print("wave correction...")
    rmats = []
    for cam in cameras:
        rmats.append(np.copy(cam.R))
    if wave_correct == "horiz":
        rmats = cv2.detail.waveCorrect(rmats, cv2.detail.WAVE_CORRECT_HORIZ)
    elif wave_correct == "vert":
        rmats = cv2.detail.waveCorrect(rmats, cv2.detail.WAVE_CORRECT_VERT)
    for i in range(len(cameras)):
        cameras[i].R = rmats[i]

    masks_warped = []
    images_warped = []
    masks = []
    # create masks (all-255: every pixel of every image is usable)
    for i in range(num_images):
        um = cv2.UMat(255 * np.ones((images[i].shape[0], images[i].shape[1]), np.uint8))
        masks.append(um)

    # warp images and masks
    print("warping...")
    warper = cv2.PyRotationWarper(warp_type, warped_image_scale * seam_work_aspect)
    print()
    corners = []
    for i in range(num_images):
        K = cameras[i].K().astype(np.float32)
        # rescale intrinsics from the registration resolution to seam resolution
        K[0, 0] *= seam_work_aspect
        K[0, 2] *= seam_work_aspect
        K[1, 1] *= seam_work_aspect
        K[1, 2] *= seam_work_aspect
        corner, image_wp = warper.warp(images[i], K, cameras[i].R, cv2.INTER_LINEAR, cv2.BORDER_REFLECT)
        images_warped.append(image_wp)
        corners.append(corner)
        p, mask_wp = warper.warp(masks[i], K, cameras[i].R, cv2.INTER_NEAREST, cv2.BORDER_CONSTANT)
        masks_warped.append(mask_wp.get())

    # convert type (SeamFinder wants float32 images)
    images_warped_f = []
    for img in images_warped:
        imgf = img.astype(np.float32)
        images_warped_f.append(imgf)

    # blends each type of image and saves them
    for res_name, imgs in data:
        # compensate for exposure -- NOTE it doesn't do this
        # but see https://docs.opencv.org/4.1.0/d2/d37/classcv_1_1detail_1_1ExposureCompensator.html for options
        compensator = cv2.detail.ExposureCompensator_createDefault(cv2.detail.ExposureCompensator_NO)
        compensator.feed(corners=corners, images=images_warped, masks=masks_warped)

        # find seams in the images -- NOTE just as with exposure this doesn't actually do anything
        # but there are other possibilities here: https://docs.opencv.org/4.1.0/d7/d09/classcv_1_1detail_1_1SeamFinder.html#aaefc003adf1ebec13867ad9203096f6fa55b2503305e94168c0b36c4531f288d7
        seam_finder = cv2.detail.SeamFinder_createDefault(cv2.detail.SeamFinder_NO)
        seam_finder.find(images_warped_f, corners, masks_warped)

        sizes = []
        blender = None
        compose_scale = -1
        for i in range(num_images):
            full_img = imgs[i]
            if compose_scale == -1:  # if it hasn't been set yet (first image of this set)
                corners = []
                if compose_megapix > 0:
                    compose_scale = min(1.0, np.sqrt(compose_megapix * 1e6 / (full_img.shape[0] * full_img.shape[1])))
                else:
                    compose_scale = 1
                # update warper and camera intrinsics for the compose resolution
                compose_work_aspect = compose_scale / work_scale
                warped_image_scale *= compose_work_aspect
                warper = cv2.PyRotationWarper(warp_type, warped_image_scale)
                for c in range(len(data[0][1])):
                    cameras[c].focal *= compose_work_aspect
                    cameras[c].ppx *= compose_work_aspect
                    cameras[c].ppy *= compose_work_aspect
                    sz = (data[0][1][c].shape[1] * compose_scale, data[0][1][c].shape[0] * compose_scale)
                    K = cameras[c].K().astype(np.float32)
                    roi = warper.warpRoi(sz, K, cameras[c].R)
                    corners.append(roi[0:2])
                    sizes.append(roi[2:4])
            if abs(compose_scale - 1) > 1e-1:
                img = cv2.resize(src=full_img, dsize=None, fx=compose_scale, fy=compose_scale,
                                 interpolation=cv2.INTER_LINEAR_EXACT)
            else:
                img = full_img
            K = cameras[i].K().astype(np.float32)
            corner, image_warped = warper.warp(img, K, cameras[i].R, cv2.INTER_LINEAR, cv2.BORDER_REFLECT)
            mask = 255 * np.ones((img.shape[0], img.shape[1]), np.uint8)
            p, mask_warped = warper.warp(mask, K, cameras[i].R, cv2.INTER_NEAREST, cv2.BORDER_CONSTANT)
            compensator.apply(i, corners[i], image_warped, mask_warped)
            image_warped_s = image_warped.astype(np.int16)  # blender wants int16
            dilated_mask = cv2.dilate(masks_warped[i], None)
            seam_mask = cv2.resize(dilated_mask, (mask_warped.shape[1], mask_warped.shape[0]), 0, 0,
                                   cv2.INTER_LINEAR_EXACT)
            mask_warped = cv2.bitwise_and(seam_mask, mask_warped)

            # setup blender -- this sets up the part that combines the images by laying them on top of each other
            if blender is None:
                blender = cv2.detail.Blender_createDefault(cv2.detail.Blender_NO)
                dst_sz = cv2.detail.resultRoi(corners=corners, sizes=sizes)
                blend_width = np.sqrt(dst_sz[2] * dst_sz[3]) * blend_strength / 100
                if blend_width < 1:
                    print("no blend")
                    blender = cv2.detail.Blender_createDefault(cv2.detail.Blender_NO)
                elif blend_type == "multiband":  # I think this is generally better
                    print(blend_type)
                    blender = cv2.detail_MultiBandBlender()
                elif blend_type == "feather":  # mixes images at borders
                    print(blend_type)
                    blender = cv2.detail_FeatherBlender()
                    blender.setSharpness(1.0 / blend_width)
                blender.prepare(dst_sz)
            blender.feed(image_warped_s, mask_warped, corners[i])

        result = None
        result_mask = None
        print("blending..." + res_name)
        result, result_mask = blender.blend(result, result_mask)
        print("SIZE:", result.shape)
        cv2.imwrite(res_name, result)
def stitch(data, use_kaze=False):
    """Stitch the vl/ir/mx image sets in ``data`` using cv2.Stitcher.

    this works ONLY WITH the modified opencv c++ code. stitcher.stitch()
    takes vl images as first param and then creates a panorama using the
    images from the second param. So if you want a vl pano, do
    stitcher.stitch(vl_images, vl_images) and if you want an ir pano do
    this: stitcher.stitch(vl_images, ir_images).

    :param data: list of three (output_filename, list_of_images) tuples,
                 ordered vl, ir, mx; data[0] drives the registration
    :param use_kaze: use the KAZE feature detector instead of the default ORB
    """
    print("stitching...")
    stitcher = cv2.Stitcher_create()
    if use_kaze:
        # sometimes does a better job, but can take longer. Alternative is ORB
        stitcher.setFeaturesFinder(cv2.KAZE.create())
    stitcher.setPanoConfidenceThresh(1.0)

    # only match each image with its neighbor -- big speedup for ordered shots
    count = len(data[0][1])
    match_mask = np.zeros((count, count), np.uint8)
    for idx in range(count - 1):
        match_mask[idx, idx + 1] = 1
    stitcher.setMatchingMask(match_mask)

    # visible-light pano first: this run computes the registration
    print("vl...")
    status, stitched_vl = stitcher.stitch(data[0][1], data[0][1])
    if status == 0:
        print("SIZE:", stitched_vl.shape)
        cv2.imwrite(data[0][0], stitched_vl)

    # reuse that registration to compose the ir and mx panoramas
    for label, slot in (("ir...", 1), ("mx...", 2)):
        print(label)
        status, pano = stitcher.composePanorama(data[slot][1])
        if status == 0:
            cv2.imwrite(data[slot][0], pano)
if __name__ == "__main__":
    num_imgs = 45

    # grab all the pano folders (subdirectories of cwd whose names start with "pano")
    # NOTE: renamed from `dir`, which shadowed the builtin
    pano_dirs = [folder for folder in next(os.walk('.'))[1] if folder[:4] == "pano"]
    print(pano_dirs)

    start, end = 0, 1
    print(start, end)
    for d in range(start, end):
        kaze = True
        directory = pano_dirs[d]
        # renamed from `start`, which confusingly shadowed the range bound above
        t0 = time.time()
        print("\n\n---------------", d)
        print("KAZE:", kaze)
        # NOTE(review): hard-coded UTC-4 offset for the local timestamp -- confirm
        print(datetime.utcfromtimestamp(t0 - 4 * 3600).strftime('%Y-%m-%d %H:%M:%S'))
        print(directory)

        # file names are zero-padded to two digits (vl00.png ... vl44.png);
        # {:02d} replaces the old `if i > 9 else` padding dance, same names
        vl_im = [cv2.imread(directory + "/vl{:02d}.png".format(i)) for i in range(num_imgs)]
        ir_im = [cv2.imread(directory + "/ir{:02d}.png".format(i)) for i in range(num_imgs)]
        mx_im = [cv2.imread(directory + "/mx{:02d}.png".format(i)) for i in range(num_imgs)]

        types = ["vl", "ir", "mx"]
        data = [("output/{0}-{1}.png".format(directory, types[0]), vl_im),
                ("output/{0}-{1}.png".format(directory, types[1]), ir_im),
                ("output/{0}-{1}.png".format(directory, types[2]), mx_im)]

        # stitch(data, use_kaze=kaze)  # slow, requires modified opencv library
        stitch_fast(data, use_kaze=kaze)  # fast, uses opencv-python from pip
        print("total time (secs):", (time.time() - t0))