diff --git a/.gitignore b/.gitignore
index 8fa68fe..f68a4af 100644
--- a/.gitignore
+++ b/.gitignore
@@ -47,3 +47,5 @@ all_sequences/
ckpts/
logs/
# configs/
+
+*.mp4
diff --git a/README.md b/README.md
index c6534e6..4e0290c 100755
--- a/README.md
+++ b/README.md
@@ -1,7 +1,13 @@
-# CoDeF: Content Deformation Fields for Temporally Consistent Video Processing
-
+# HeyGen ❤️ Superwoman
+
+The slides for the project: https://docs.google.com/presentation/d/1y4qP2dALviAZm_D4BU_udy2IHE-1A6nwD4zS3RAqfvc/edit?usp=sharing
+
+The configuration file for HeyGen avatar stylization: `configs/heygen/base.yaml`
+
+---------
+## CoDeF: Content Deformation Fields for Temporally Consistent Video Processing
[Hao Ouyang](https://ken-ouyang.github.io/)\*, [Qiuyu Wang](https://github.com/qiuyu96/)\*, [Yuxi Xiao](https://henry123-boy.github.io/)\*, [Qingyan Bai](https://scholar.google.com/citations?user=xUMjxi4AAAAJ&hl=en), [Juntao Zhang](https://github.com/JordanZh), [Kecheng Zheng](https://scholar.google.com/citations?user=hMDQifQAAAAJ), [Xiaowei Zhou](https://xzhou.me/),
[Qifeng Chen](https://cqf.io/)†, [Yujun Shen](https://shenyujun.github.io/)† (*equal contribution, †corresponding author)
diff --git a/configs/heygen/base.yaml b/configs/heygen/base.yaml
new file mode 100755
index 0000000..531bb04
--- /dev/null
+++ b/configs/heygen/base.yaml
@@ -0,0 +1,26 @@
+mask_dir: null
+flow_dir: null
+
+img_wh: [1080, 1080]
+canonical_wh: [1080, 1080]
+
+lr: 0.001
+bg_loss: 0.0
+
+ref_idx: null # 0
+
+N_xyz_w: [8, 8]
+flow_loss: 1
+flow_step: -1
+self_bg: False
+
+deform_hash: True
+vid_hash: True
+
+num_steps: 200000
+decay_step: [2500, 5000, 7500]
+annealed_begin_step: 4000
+annealed_step: 4000
+save_model_iters: 2000
+
+fps: 25
diff --git a/data_preprocessing/RAFT/core/corr.py b/data_preprocessing/RAFT/core/corr.py
index cffcbc8..42a7817 100755
--- a/data_preprocessing/RAFT/core/corr.py
+++ b/data_preprocessing/RAFT/core/corr.py
@@ -19,10 +19,10 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
corr = CorrBlock.corr(fmap1, fmap2)
batch, h1, w1, dim, h2, w2 = corr.shape
- corr = corr.reshape(batch*h1*w1, dim, h2, w2)
-
+ corr = corr.reshape(batch * h1 * w1, dim, h2, w2)
+
self.corr_pyramid.append(corr)
- for i in range(self.num_levels-1):
+ for i in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)
@@ -34,12 +34,12 @@ def __call__(self, coords):
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
- dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
- dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
+ dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
+ dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)
- centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
- delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+ centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corr = bilinear_sampler(corr, coords_lvl)
@@ -52,12 +52,12 @@ def __call__(self, coords):
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
- fmap1 = fmap1.view(batch, dim, ht*wd)
- fmap2 = fmap2.view(batch, dim, ht*wd)
-
- corr = torch.matmul(fmap1.transpose(1,2), fmap2)
+ fmap1 = fmap1.view(batch, dim, ht * wd)
+ fmap2 = fmap2.view(batch, dim, ht * wd)
+
+ corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
- return corr / torch.sqrt(torch.tensor(dim).float())
+ return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
@@ -83,7 +83,7 @@ def __call__(self, coords):
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
- corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+ (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))
corr = torch.stack(corr_list, dim=1)
diff --git a/data_preprocessing/RAFT/core/datasets.py b/data_preprocessing/RAFT/core/datasets.py
index 3411fda..2513967 100755
--- a/data_preprocessing/RAFT/core/datasets.py
+++ b/data_preprocessing/RAFT/core/datasets.py
@@ -66,8 +66,8 @@ def __getitem__(self, index):
# grayscale images
if len(img1.shape) == 2:
- img1 = np.tile(img1[...,None], (1, 1, 3))
- img2 = np.tile(img2[...,None], (1, 1, 3))
+ img1 = np.tile(img1[..., None], (1, 1, 3))
+ img2 = np.tile(img2[..., None], (1, 1, 3))
else:
img1 = img1[..., :3]
img2 = img2[..., :3]
@@ -89,147 +89,203 @@ def __getitem__(self, index):
return img1, img2, flow, valid.float()
-
def __rmul__(self, v):
self.flow_list = v * self.flow_list
self.image_list = v * self.image_list
return self
-
+
def __len__(self):
return len(self.image_list)
-
+
class MpiSintel(FlowDataset):
- def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+ def __init__(
+ self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean"
+ ):
super(MpiSintel, self).__init__(aug_params)
- flow_root = osp.join(root, split, 'flow')
+ flow_root = osp.join(root, split, "flow")
image_root = osp.join(root, split, dstype)
- if split == 'test':
+ if split == "test":
self.is_test = True
for scene in os.listdir(image_root):
- image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
- for i in range(len(image_list)-1):
- self.image_list += [ [image_list[i], image_list[i+1]] ]
- self.extra_info += [ (scene, i) ] # scene and frame_id
+ image_list = sorted(glob(osp.join(image_root, scene, "*.png")))
+ for i in range(len(image_list) - 1):
+ self.image_list += [[image_list[i], image_list[i + 1]]]
+ self.extra_info += [(scene, i)] # scene and frame_id
- if split != 'test':
- self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+ if split != "test":
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo")))
class FlyingChairs(FlowDataset):
- def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+ def __init__(
+ self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data"
+ ):
super(FlyingChairs, self).__init__(aug_params)
- images = sorted(glob(osp.join(root, '*.ppm')))
- flows = sorted(glob(osp.join(root, '*.flo')))
- assert (len(images)//2 == len(flows))
+ images = sorted(glob(osp.join(root, "*.ppm")))
+ flows = sorted(glob(osp.join(root, "*.flo")))
+ assert len(images) // 2 == len(flows)
- split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+ split_list = np.loadtxt("chairs_split.txt", dtype=np.int32)
for i in range(len(flows)):
xid = split_list[i]
- if (split=='training' and xid==1) or (split=='validation' and xid==2):
- self.flow_list += [ flows[i] ]
- self.image_list += [ [images[2*i], images[2*i+1]] ]
+ if (split == "training" and xid == 1) or (
+ split == "validation" and xid == 2
+ ):
+ self.flow_list += [flows[i]]
+ self.image_list += [[images[2 * i], images[2 * i + 1]]]
class FlyingThings3D(FlowDataset):
- def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+ def __init__(
+ self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass"
+ ):
super(FlyingThings3D, self).__init__(aug_params)
- for cam in ['left']:
- for direction in ['into_future', 'into_past']:
- image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
+ for cam in ["left"]:
+ for direction in ["into_future", "into_past"]:
+ image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*")))
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
- flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
+ flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*")))
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
for idir, fdir in zip(image_dirs, flow_dirs):
- images = sorted(glob(osp.join(idir, '*.png')) )
- flows = sorted(glob(osp.join(fdir, '*.pfm')) )
- for i in range(len(flows)-1):
- if direction == 'into_future':
- self.image_list += [ [images[i], images[i+1]] ]
- self.flow_list += [ flows[i] ]
- elif direction == 'into_past':
- self.image_list += [ [images[i+1], images[i]] ]
- self.flow_list += [ flows[i+1] ]
-
+ images = sorted(glob(osp.join(idir, "*.png")))
+ flows = sorted(glob(osp.join(fdir, "*.pfm")))
+ for i in range(len(flows) - 1):
+ if direction == "into_future":
+ self.image_list += [[images[i], images[i + 1]]]
+ self.flow_list += [flows[i]]
+ elif direction == "into_past":
+ self.image_list += [[images[i + 1], images[i]]]
+ self.flow_list += [flows[i + 1]]
+
class KITTI(FlowDataset):
- def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+ def __init__(self, aug_params=None, split="training", root="datasets/KITTI"):
super(KITTI, self).__init__(aug_params, sparse=True)
- if split == 'testing':
+ if split == "testing":
self.is_test = True
root = osp.join(root, split)
- images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
- images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+ images1 = sorted(glob(osp.join(root, "image_2/*_10.png")))
+ images2 = sorted(glob(osp.join(root, "image_2/*_11.png")))
for img1, img2 in zip(images1, images2):
- frame_id = img1.split('/')[-1]
- self.extra_info += [ [frame_id] ]
- self.image_list += [ [img1, img2] ]
+ frame_id = img1.split("/")[-1]
+ self.extra_info += [[frame_id]]
+ self.image_list += [[img1, img2]]
- if split == 'training':
- self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+ if split == "training":
+ self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png")))
class HD1K(FlowDataset):
- def __init__(self, aug_params=None, root='datasets/HD1k'):
+ def __init__(self, aug_params=None, root="datasets/HD1k"):
super(HD1K, self).__init__(aug_params, sparse=True)
seq_ix = 0
while 1:
- flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
- images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+ flows = sorted(
+ glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix))
+ )
+ images = sorted(
+ glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix))
+ )
if len(flows) == 0:
break
- for i in range(len(flows)-1):
+ for i in range(len(flows) - 1):
self.flow_list += [flows[i]]
- self.image_list += [ [images[i], images[i+1]] ]
+ self.image_list += [[images[i], images[i + 1]]]
seq_ix += 1
-def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
- """ Create the data loader for the corresponding trainign set """
-
- if args.stage == 'chairs':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
- train_dataset = FlyingChairs(aug_params, split='training')
-
- elif args.stage == 'things':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
- clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
- final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"):
+ """Create the data loader for the corresponding trainign set"""
+
+ if args.stage == "chairs":
+ aug_params = {
+ "crop_size": args.image_size,
+ "min_scale": -0.1,
+ "max_scale": 1.0,
+ "do_flip": True,
+ }
+ train_dataset = FlyingChairs(aug_params, split="training")
+
+ elif args.stage == "things":
+ aug_params = {
+ "crop_size": args.image_size,
+ "min_scale": -0.4,
+ "max_scale": 0.8,
+ "do_flip": True,
+ }
+ clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass")
+ final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass")
train_dataset = clean_dataset + final_dataset
- elif args.stage == 'sintel':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
- things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
- sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
- sintel_final = MpiSintel(aug_params, split='training', dstype='final')
-
- if TRAIN_DS == 'C+T+K+S+H':
- kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
- hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
- train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
-
- elif TRAIN_DS == 'C+T+K/S':
- train_dataset = 100*sintel_clean + 100*sintel_final + things
-
- elif args.stage == 'kitti':
- aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
- train_dataset = KITTI(aug_params, split='training')
-
- train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
- pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
-
- print('Training with %d image pairs' % len(train_dataset))
+ elif args.stage == "sintel":
+ aug_params = {
+ "crop_size": args.image_size,
+ "min_scale": -0.2,
+ "max_scale": 0.6,
+ "do_flip": True,
+ }
+ things = FlyingThings3D(aug_params, dstype="frames_cleanpass")
+ sintel_clean = MpiSintel(aug_params, split="training", dstype="clean")
+ sintel_final = MpiSintel(aug_params, split="training", dstype="final")
+
+ if TRAIN_DS == "C+T+K+S+H":
+ kitti = KITTI(
+ {
+ "crop_size": args.image_size,
+ "min_scale": -0.3,
+ "max_scale": 0.5,
+ "do_flip": True,
+ }
+ )
+ hd1k = HD1K(
+ {
+ "crop_size": args.image_size,
+ "min_scale": -0.5,
+ "max_scale": 0.2,
+ "do_flip": True,
+ }
+ )
+ train_dataset = (
+ 100 * sintel_clean
+ + 100 * sintel_final
+ + 200 * kitti
+ + 5 * hd1k
+ + things
+ )
+
+ elif TRAIN_DS == "C+T+K/S":
+ train_dataset = 100 * sintel_clean + 100 * sintel_final + things
+
+ elif args.stage == "kitti":
+ aug_params = {
+ "crop_size": args.image_size,
+ "min_scale": -0.2,
+ "max_scale": 0.4,
+ "do_flip": False,
+ }
+ train_dataset = KITTI(aug_params, split="training")
+
+ train_loader = data.DataLoader(
+ train_dataset,
+ batch_size=args.batch_size,
+ pin_memory=False,
+ shuffle=True,
+ num_workers=4,
+ drop_last=True,
+ )
+
+ print("Training with %d image pairs" % len(train_dataset))
return train_loader
-
diff --git a/data_preprocessing/RAFT/core/extractor.py b/data_preprocessing/RAFT/core/extractor.py
index 9a9c759..4215b79 100755
--- a/data_preprocessing/RAFT/core/extractor.py
+++ b/data_preprocessing/RAFT/core/extractor.py
@@ -4,34 +4,36 @@
class ResidualBlock(nn.Module):
- def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+
+ self.conv1 = nn.Conv2d(
+ in_planes, planes, kernel_size=3, padding=1, stride=stride
+ )
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
- if norm_fn == 'group':
+ if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-
- elif norm_fn == 'batch':
+
+ elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
-
- elif norm_fn == 'instance':
+
+ elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
- elif norm_fn == 'none':
+ elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
@@ -39,11 +41,11 @@ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
if stride == 1:
self.downsample = None
-
- else:
- self.downsample = nn.Sequential(
- nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+ )
def forward(self, x):
y = x
@@ -53,43 +55,44 @@ def forward(self, x):
if self.downsample is not None:
x = self.downsample(x)
- return self.relu(x+y)
-
+ return self.relu(x + y)
class BottleneckBlock(nn.Module):
- def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(BottleneckBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
- self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
- self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+
+ self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(
+ planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride
+ )
+ self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
- if norm_fn == 'group':
- self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
- self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4)
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-
- elif norm_fn == 'batch':
- self.norm1 = nn.BatchNorm2d(planes//4)
- self.norm2 = nn.BatchNorm2d(planes//4)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes // 4)
+ self.norm2 = nn.BatchNorm2d(planes // 4)
self.norm3 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm4 = nn.BatchNorm2d(planes)
-
- elif norm_fn == 'instance':
- self.norm1 = nn.InstanceNorm2d(planes//4)
- self.norm2 = nn.InstanceNorm2d(planes//4)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes // 4)
+ self.norm2 = nn.InstanceNorm2d(planes // 4)
self.norm3 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm4 = nn.InstanceNorm2d(planes)
- elif norm_fn == 'none':
+ elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
self.norm3 = nn.Sequential()
@@ -98,11 +101,11 @@ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
if stride == 1:
self.downsample = None
-
- else:
- self.downsample = nn.Sequential(
- nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4
+ )
def forward(self, x):
y = x
@@ -113,30 +116,31 @@ def forward(self, x):
if self.downsample is not None:
x = self.downsample(x)
- return self.relu(x+y)
+ return self.relu(x + y)
+
class BasicEncoder(nn.Module):
- def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(BasicEncoder, self).__init__()
self.norm_fn = norm_fn
- if self.norm_fn == 'group':
+ if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
-
- elif self.norm_fn == 'batch':
+
+ elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(64)
- elif self.norm_fn == 'instance':
+ elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(64)
- elif self.norm_fn == 'none':
+ elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 64
- self.layer1 = self._make_layer(64, stride=1)
+ self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
@@ -149,7 +153,7 @@ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
for m in self.modules():
if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -160,11 +164,10 @@ def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
-
+
self.in_planes = dim
return nn.Sequential(*layers)
-
def forward(self, x):
# if input is list, combine batch dimension
@@ -193,39 +196,39 @@ def forward(self, x):
class SmallEncoder(nn.Module):
- def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0):
super(SmallEncoder, self).__init__()
self.norm_fn = norm_fn
- if self.norm_fn == 'group':
+ if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
-
- elif self.norm_fn == 'batch':
+
+ elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(32)
- elif self.norm_fn == 'instance':
+ elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(32)
- elif self.norm_fn == 'none':
+ elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
self.relu1 = nn.ReLU(inplace=True)
self.in_planes = 32
- self.layer1 = self._make_layer(32, stride=1)
+ self.layer1 = self._make_layer(32, stride=1)
self.layer2 = self._make_layer(64, stride=2)
self.layer3 = self._make_layer(96, stride=2)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
-
+
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
@@ -236,11 +239,10 @@ def _make_layer(self, dim, stride=1):
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
-
+
self.in_planes = dim
return nn.Sequential(*layers)
-
def forward(self, x):
# if input is list, combine batch dimension
diff --git a/data_preprocessing/RAFT/core/raft.py b/data_preprocessing/RAFT/core/raft.py
index 652b81a..72ed994 100755
--- a/data_preprocessing/RAFT/core/raft.py
+++ b/data_preprocessing/RAFT/core/raft.py
@@ -15,8 +15,10 @@
class autocast:
def __init__(self, enabled):
pass
+
def __enter__(self):
pass
+
def __exit__(self, *args):
pass
@@ -31,28 +33,36 @@ def __init__(self, args):
self.context_dim = cdim = 64
args.corr_levels = 4
args.corr_radius = 3
-
+
else:
self.hidden_dim = hdim = 128
self.context_dim = cdim = 128
args.corr_levels = 4
args.corr_radius = 4
- if 'dropout' not in self.args:
+ if "dropout" not in self.args:
self.args.dropout = 0
- if 'alternate_corr' not in self.args:
+ if "alternate_corr" not in self.args:
self.args.alternate_corr = False
# feature network, context network, and update block
if args.small:
- self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
- self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
+ self.fnet = SmallEncoder(
+ output_dim=128, norm_fn="instance", dropout=args.dropout
+ )
+ self.cnet = SmallEncoder(
+ output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout
+ )
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
else:
- self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
- self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
+ self.fnet = BasicEncoder(
+ output_dim=256, norm_fn="instance", dropout=args.dropout
+ )
+ self.cnet = BasicEncoder(
+ output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout
+ )
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
def freeze_bn(self):
@@ -61,30 +71,31 @@ def freeze_bn(self):
m.eval()
def initialize_flow(self, img):
- """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ """Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
N, C, H, W = img.shape
- coords0 = coords_grid(N, H//8, W//8, device=img.device)
- coords1 = coords_grid(N, H//8, W//8, device=img.device)
+ coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
+ coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
# optical flow computed as difference: flow = coords1 - coords0
return coords0, coords1
def upsample_flow(self, flow, mask):
- """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+ """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination"""
N, _, H, W = flow.shape
mask = mask.view(N, 1, 9, 8, 8, H, W)
mask = torch.softmax(mask, dim=2)
- up_flow = F.unfold(8 * flow, [3,3], padding=1)
+ up_flow = F.unfold(8 * flow, [3, 3], padding=1)
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
up_flow = torch.sum(mask * up_flow, dim=2)
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
- return up_flow.reshape(N, 2, 8*H, 8*W)
-
+ return up_flow.reshape(N, 2, 8 * H, 8 * W)
- def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
- """ Estimate optical flow between pair of frames """
+ def forward(
+ self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False
+ ):
+ """Estimate optical flow between pair of frames"""
image1 = 2 * (image1 / 255.0) - 1.0
image2 = 2 * (image2 / 255.0) - 1.0
@@ -97,8 +108,8 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
# run the feature network
with autocast(enabled=self.args.mixed_precision):
- fmap1, fmap2 = self.fnet([image1, image2])
-
+ fmap1, fmap2 = self.fnet([image1, image2])
+
fmap1 = fmap1.float()
fmap2 = fmap2.float()
if self.args.alternate_corr:
@@ -121,7 +132,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
flow_predictions = []
for itr in range(iters):
coords1 = coords1.detach()
- corr = corr_fn(coords1) # index correlation volume
+ corr = corr_fn(coords1) # index correlation volume
flow = coords1 - coords0
with autocast(enabled=self.args.mixed_precision):
@@ -135,10 +146,10 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_
flow_up = upflow8(coords1 - coords0)
else:
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
-
+
flow_predictions.append(flow_up)
if test_mode:
return coords1 - coords0, flow_up
-
+
return flow_predictions
diff --git a/data_preprocessing/RAFT/core/update.py b/data_preprocessing/RAFT/core/update.py
index f940497..ced6df0 100755
--- a/data_preprocessing/RAFT/core/update.py
+++ b/data_preprocessing/RAFT/core/update.py
@@ -13,56 +13,70 @@ def __init__(self, input_dim=128, hidden_dim=256):
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
+
class ConvGRU(nn.Module):
- def __init__(self, hidden_dim=128, input_dim=192+128):
+ def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(ConvGRU, self).__init__()
- self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
- self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
- self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+ self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
+ self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
+ self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1)
def forward(self, h, x):
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz(hx))
r = torch.sigmoid(self.convr(hx))
- q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+ q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
- h = (1-z) * h + z * q
+ h = (1 - z) * h + z * q
return h
+
class SepConvGRU(nn.Module):
- def __init__(self, hidden_dim=128, input_dim=192+128):
+ def __init__(self, hidden_dim=128, input_dim=192 + 128):
super(SepConvGRU, self).__init__()
- self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
- self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
- self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
-
- self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
- self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
- self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
-
+ self.convz1 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+ )
+ self.convr1 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+ )
+ self.convq1 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2)
+ )
+
+ self.convz2 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+ )
+ self.convr2 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+ )
+ self.convq2 = nn.Conv2d(
+ hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0)
+ )
def forward(self, h, x):
# horizontal
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz1(hx))
r = torch.sigmoid(self.convr1(hx))
- q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
- h = (1-z) * h + z * q
+ q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
# vertical
hx = torch.cat([h, x], dim=1)
z = torch.sigmoid(self.convz2(hx))
r = torch.sigmoid(self.convr2(hx))
- q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
- h = (1-z) * h + z * q
+ q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
+ h = (1 - z) * h + z * q
return h
+
class SmallMotionEncoder(nn.Module):
def __init__(self, args):
super(SmallMotionEncoder, self).__init__()
- cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
@@ -76,15 +90,16 @@ def forward(self, flow, corr):
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
+
class BasicMotionEncoder(nn.Module):
def __init__(self, args):
super(BasicMotionEncoder, self).__init__()
- cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
- self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
+ self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)
def forward(self, flow, corr):
cor = F.relu(self.convc1(corr))
@@ -96,11 +111,12 @@ def forward(self, flow, corr):
out = F.relu(self.conv(cor_flo))
return torch.cat([out, flow], dim=1)
+
class SmallUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=96):
super(SmallUpdateBlock, self).__init__()
self.encoder = SmallMotionEncoder(args)
- self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+ self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
def forward(self, net, inp, corr, flow):
@@ -111,18 +127,20 @@ def forward(self, net, inp, corr, flow):
return net, None, delta_flow
+
class BasicUpdateBlock(nn.Module):
def __init__(self, args, hidden_dim=128, input_dim=128):
super(BasicUpdateBlock, self).__init__()
self.args = args
self.encoder = BasicMotionEncoder(args)
- self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2d(128, 256, 3, padding=1),
nn.ReLU(inplace=True),
- nn.Conv2d(256, 64*9, 1, padding=0))
+ nn.Conv2d(256, 64 * 9, 1, padding=0),
+ )
def forward(self, net, inp, corr, flow, upsample=True):
motion_features = self.encoder(flow, corr)
@@ -132,8 +150,5 @@ def forward(self, net, inp, corr, flow, upsample=True):
delta_flow = self.flow_head(net)
# scale mask to balence gradients
- mask = .25 * self.mask(net)
+ mask = 0.25 * self.mask(net)
return net, mask, delta_flow
-
-
-
diff --git a/data_preprocessing/RAFT/core/utils/augmentor.py b/data_preprocessing/RAFT/core/utils/augmentor.py
index e81c4f2..ae675df 100755
--- a/data_preprocessing/RAFT/core/utils/augmentor.py
+++ b/data_preprocessing/RAFT/core/utils/augmentor.py
@@ -4,6 +4,7 @@
from PIL import Image
import cv2
+
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
@@ -14,7 +15,7 @@
class FlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
-
+
# spatial augmentation params
self.crop_size = crop_size
self.min_scale = min_scale
@@ -29,12 +30,14 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
self.v_flip_prob = 0.1
# photometric augmentation params
- self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+ self.photo_aug = ColorJitter(
+ brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14
+ )
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
def color_transform(self, img1, img2):
- """ Photometric augmentation """
+ """Photometric augmentation"""
# asymmetric
if np.random.rand() < self.asymmetric_color_aug_prob:
@@ -44,13 +47,15 @@ def color_transform(self, img1, img2):
# symmetric
else:
image_stack = np.concatenate([img1, img2], axis=0)
- image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ image_stack = np.array(
+ self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
+ )
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
def eraser_transform(self, img1, img2, bounds=[50, 100]):
- """ Occlusion augmentation """
+ """Occlusion augmentation"""
ht, wd = img1.shape[:2]
if np.random.rand() < self.eraser_aug_prob:
@@ -60,7 +65,7 @@ def eraser_transform(self, img1, img2, bounds=[50, 100]):
y0 = np.random.randint(0, ht)
dx = np.random.randint(bounds[0], bounds[1])
dy = np.random.randint(bounds[0], bounds[1])
- img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+ img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return img1, img2
@@ -68,8 +73,8 @@ def spatial_transform(self, img1, img2, flow):
# randomly sample scale
ht, wd = img1.shape[:2]
min_scale = np.maximum(
- (self.crop_size[0] + 8) / float(ht),
- (self.crop_size[1] + 8) / float(wd))
+ (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd)
+ )
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = scale
@@ -77,34 +82,40 @@ def spatial_transform(self, img1, img2, flow):
if np.random.rand() < self.stretch_prob:
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
-
+
scale_x = np.clip(scale_x, min_scale, None)
scale_y = np.clip(scale_y, min_scale, None)
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
- img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
- img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
- flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img1 = cv2.resize(
+ img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
+ )
+ img2 = cv2.resize(
+ img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
+ )
+ flow = cv2.resize(
+ flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
+ )
flow = flow * [scale_x, scale_y]
if self.do_flip:
- if np.random.rand() < self.h_flip_prob: # h-flip
+ if np.random.rand() < self.h_flip_prob: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
- if np.random.rand() < self.v_flip_prob: # v-flip
+ if np.random.rand() < self.v_flip_prob: # v-flip
img1 = img1[::-1, :]
img2 = img2[::-1, :]
flow = flow[::-1, :] * [1.0, -1.0]
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
-
- img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
- img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
- flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+ img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
return img1, img2, flow
@@ -119,6 +130,7 @@ def __call__(self, img1, img2, flow):
return img1, img2, flow
+
class SparseFlowAugmentor:
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
# spatial augmentation params
@@ -135,13 +147,17 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
self.v_flip_prob = 0.1
# photometric augmentation params
- self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+ self.photo_aug = ColorJitter(
+ brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14
+ )
self.asymmetric_color_aug_prob = 0.2
self.eraser_aug_prob = 0.5
-
+
def color_transform(self, img1, img2):
image_stack = np.concatenate([img1, img2], axis=0)
- image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ image_stack = np.array(
+ self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8
+ )
img1, img2 = np.split(image_stack, 2, axis=0)
return img1, img2
@@ -154,7 +170,7 @@ def eraser_transform(self, img1, img2):
y0 = np.random.randint(0, ht)
dx = np.random.randint(50, 100)
dy = np.random.randint(50, 100)
- img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+ img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color
return img1, img2
@@ -167,8 +183,8 @@ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
flow = flow.reshape(-1, 2).astype(np.float32)
valid = valid.reshape(-1).astype(np.float32)
- coords0 = coords[valid>=1]
- flow0 = flow[valid>=1]
+ coords0 = coords[valid >= 1]
+ flow0 = flow[valid >= 1]
ht1 = int(round(ht * fy))
wd1 = int(round(wd * fx))
@@ -176,8 +192,8 @@ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
coords1 = coords0 * [fx, fy]
flow1 = flow0 * [fx, fy]
- xx = np.round(coords1[:,0]).astype(np.int32)
- yy = np.round(coords1[:,1]).astype(np.int32)
+ xx = np.round(coords1[:, 0]).astype(np.int32)
+ yy = np.round(coords1[:, 1]).astype(np.int32)
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
xx = xx[v]
@@ -197,8 +213,8 @@ def spatial_transform(self, img1, img2, flow, valid):
ht, wd = img1.shape[:2]
min_scale = np.maximum(
- (self.crop_size[0] + 1) / float(ht),
- (self.crop_size[1] + 1) / float(wd))
+ (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd)
+ )
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
scale_x = np.clip(scale, min_scale, None)
@@ -206,12 +222,18 @@ def spatial_transform(self, img1, img2, flow, valid):
if np.random.rand() < self.spatial_aug_prob:
# rescale the images
- img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
- img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
- flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+ img1 = cv2.resize(
+ img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
+ )
+ img2 = cv2.resize(
+ img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR
+ )
+ flow, valid = self.resize_sparse_flow_map(
+ flow, valid, fx=scale_x, fy=scale_y
+ )
if self.do_flip:
- if np.random.rand() < 0.5: # h-flip
+ if np.random.rand() < 0.5: # h-flip
img1 = img1[:, ::-1]
img2 = img2[:, ::-1]
flow = flow[:, ::-1] * [-1.0, 1.0]
@@ -226,13 +248,12 @@ def spatial_transform(self, img1, img2, flow, valid):
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
- img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
- img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
- flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
- valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
return img1, img2, flow, valid
-
def __call__(self, img1, img2, flow, valid):
img1, img2 = self.color_transform(img1, img2)
img1, img2 = self.eraser_transform(img1, img2)
diff --git a/data_preprocessing/RAFT/core/utils/flow_viz.py b/data_preprocessing/RAFT/core/utils/flow_viz.py
index dcee65e..fec0836 100755
--- a/data_preprocessing/RAFT/core/utils/flow_viz.py
+++ b/data_preprocessing/RAFT/core/utils/flow_viz.py
@@ -17,6 +17,7 @@
import numpy as np
+
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
@@ -43,27 +44,27 @@ def make_colorwheel():
# RY
colorwheel[0:RY, 0] = 255
- colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
- col = col+RY
+ colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
+ col = col + RY
# YG
- colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
- colorwheel[col:col+YG, 1] = 255
- col = col+YG
+ colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
+ colorwheel[col : col + YG, 1] = 255
+ col = col + YG
# GC
- colorwheel[col:col+GC, 1] = 255
- colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
- col = col+GC
+ colorwheel[col : col + GC, 1] = 255
+ colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
+ col = col + GC
# CB
- colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
- colorwheel[col:col+CB, 2] = 255
- col = col+CB
+ colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
+ colorwheel[col : col + CB, 2] = 255
+ col = col + CB
# BM
- colorwheel[col:col+BM, 2] = 255
- colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
- col = col+BM
+ colorwheel[col : col + BM, 2] = 255
+ colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
+ col = col + BM
# MR
- colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
- colorwheel[col:col+MR, 0] = 255
+ colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
+ colorwheel[col : col + MR, 0] = 255
return colorwheel
@@ -86,23 +87,23 @@ def flow_uv_to_colors(u, v, convert_to_bgr=False):
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
- a = np.arctan2(-v, -u)/np.pi
- fk = (a+1) / 2*(ncols-1)
+ a = np.arctan2(-v, -u) / np.pi
+ fk = (a + 1) / 2 * (ncols - 1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
- tmp = colorwheel[:,i]
+ tmp = colorwheel[:, i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
- col = (1-f)*col0 + f*col1
- idx = (rad <= 1)
- col[idx] = 1 - rad[idx] * (1-col[idx])
- col[~idx] = col[~idx] * 0.75 # out of range
+ col = (1 - f) * col0 + f * col1
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
- ch_idx = 2-i if convert_to_bgr else i
- flow_image[:,:,ch_idx] = np.floor(255 * col)
+ ch_idx = 2 - i if convert_to_bgr else i
+ flow_image[:, :, ch_idx] = np.floor(255 * col)
return flow_image
@@ -118,15 +119,15 @@ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
- assert flow_uv.ndim == 3, 'input flow must have three dimensions'
- assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ assert flow_uv.ndim == 3, "input flow must have three dimensions"
+ assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
- u = flow_uv[:,:,0]
- v = flow_uv[:,:,1]
+ u = flow_uv[:, :, 0]
+ v = flow_uv[:, :, 1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
- return flow_uv_to_colors(u, v, convert_to_bgr)
\ No newline at end of file
+ return flow_uv_to_colors(u, v, convert_to_bgr)
diff --git a/data_preprocessing/RAFT/core/utils/frame_utils.py b/data_preprocessing/RAFT/core/utils/frame_utils.py
index 6c49113..bee554e 100755
--- a/data_preprocessing/RAFT/core/utils/frame_utils.py
+++ b/data_preprocessing/RAFT/core/utils/frame_utils.py
@@ -4,34 +4,37 @@
import re
import cv2
+
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
TAG_CHAR = np.array([202021.25], np.float32)
+
def readFlow(fn):
- """ Read .flo file in Middlebury format"""
+ """Read .flo file in Middlebury format"""
# Code adapted from:
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
# print 'fn = %s'%(fn)
- with open(fn, 'rb') as f:
+ with open(fn, "rb") as f:
magic = np.fromfile(f, np.float32, count=1)
if 202021.25 != magic:
- print('Magic number incorrect. Invalid .flo file')
+ print("Magic number incorrect. Invalid .flo file")
return None
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
# print 'Reading %d x %d flo file\n' % (w, h)
- data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+ data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
# Reshape data into 3D array (columns, rows, bands)
# The reshape here is for visualization, the original code is (w,h,2)
return np.resize(data, (int(h), int(w), 2))
+
def readPFM(file):
- file = open(file, 'rb')
+ file = open(file, "rb")
color = None
width = None
@@ -40,36 +43,37 @@ def readPFM(file):
endian = None
header = file.readline().rstrip()
- if header == b'PF':
+ if header == b"PF":
color = True
- elif header == b'Pf':
+ elif header == b"Pf":
color = False
else:
- raise Exception('Not a PFM file.')
+ raise Exception("Not a PFM file.")
- dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline())
if dim_match:
width, height = map(int, dim_match.groups())
else:
- raise Exception('Malformed PFM header.')
+ raise Exception("Malformed PFM header.")
scale = float(file.readline().rstrip())
- if scale < 0: # little-endian
- endian = '<'
+ if scale < 0: # little-endian
+ endian = "<"
scale = -scale
else:
- endian = '>' # big-endian
+ endian = ">" # big-endian
- data = np.fromfile(file, endian + 'f')
+ data = np.fromfile(file, endian + "f")
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
return data
-def writeFlow(filename,uv,v=None):
- """ Write optical flow to file.
-
+
+def writeFlow(filename, uv, v=None):
+ """Write optical flow to file.
+
If v is None, uv is assumed to contain both u and v channels,
stacked in depth.
Original code by Deqing Sun, adapted from Daniel Scharstein.
@@ -77,35 +81,36 @@ def writeFlow(filename,uv,v=None):
nBands = 2
if v is None:
- assert(uv.ndim == 3)
- assert(uv.shape[2] == 2)
- u = uv[:,:,0]
- v = uv[:,:,1]
+ assert uv.ndim == 3
+ assert uv.shape[2] == 2
+ u = uv[:, :, 0]
+ v = uv[:, :, 1]
else:
u = uv
- assert(u.shape == v.shape)
- height,width = u.shape
- f = open(filename,'wb')
+ assert u.shape == v.shape
+ height, width = u.shape
+ f = open(filename, "wb")
# write the header
f.write(TAG_CHAR)
np.array(width).astype(np.int32).tofile(f)
np.array(height).astype(np.int32).tofile(f)
# arrange into matrix form
- tmp = np.zeros((height, width*nBands))
- tmp[:,np.arange(width)*2] = u
- tmp[:,np.arange(width)*2 + 1] = v
+ tmp = np.zeros((height, width * nBands))
+ tmp[:, np.arange(width) * 2] = u
+ tmp[:, np.arange(width) * 2 + 1] = v
tmp.astype(np.float32).tofile(f)
f.close()
def readFlowKITTI(filename):
- flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
- flow = flow[:,:,::-1].astype(np.float32)
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
+ flow = flow[:, :, ::-1].astype(np.float32)
flow, valid = flow[:, :, :2], flow[:, :, 2]
flow = (flow - 2**15) / 64.0
return flow, valid
+
def readDispKITTI(filename):
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
valid = disp > 0.0
@@ -118,20 +123,20 @@ def writeFlowKITTI(filename, uv):
valid = np.ones([uv.shape[0], uv.shape[1], 1])
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
cv2.imwrite(filename, uv[..., ::-1])
-
+
def read_gen(file_name, pil=False):
ext = splitext(file_name)[-1]
- if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg":
return Image.open(file_name)
- elif ext == '.bin' or ext == '.raw':
+ elif ext == ".bin" or ext == ".raw":
return np.load(file_name)
- elif ext == '.flo':
+ elif ext == ".flo":
return readFlow(file_name).astype(np.float32)
- elif ext == '.pfm':
+ elif ext == ".pfm":
flow = readPFM(file_name).astype(np.float32)
if len(flow.shape) == 2:
return flow
else:
return flow[:, :, :-1]
- return []
\ No newline at end of file
+ return []
diff --git a/data_preprocessing/RAFT/core/utils/utils.py b/data_preprocessing/RAFT/core/utils/utils.py
index 741ccfe..781dc55 100755
--- a/data_preprocessing/RAFT/core/utils/utils.py
+++ b/data_preprocessing/RAFT/core/utils/utils.py
@@ -5,23 +5,30 @@
class InputPadder:
- """ Pads images such that dimensions are divisible by 8 """
- def __init__(self, dims, mode='sintel'):
+ """Pads images such that dimensions are divisible by 8"""
+
+ def __init__(self, dims, mode="sintel"):
self.ht, self.wd = dims[-2:]
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
- if mode == 'sintel':
- self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+ if mode == "sintel":
+ self._pad = [
+ pad_wd // 2,
+ pad_wd - pad_wd // 2,
+ pad_ht // 2,
+ pad_ht - pad_ht // 2,
+ ]
else:
- self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+ self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
def pad(self, *inputs):
- return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+ return [F.pad(x, self._pad, mode="replicate") for x in inputs]
- def unpad(self,x):
+ def unpad(self, x):
ht, wd = x.shape[-2:]
- c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
- return x[..., c[0]:c[1], c[2]:c[3]]
+ c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+ return x[..., c[0] : c[1], c[2] : c[3]]
+
def forward_interpolate(flow):
flow = flow.detach().cpu().numpy()
@@ -32,7 +39,7 @@ def forward_interpolate(flow):
x1 = x0 + dx
y1 = y0 + dy
-
+
x1 = x1.reshape(-1)
y1 = y1.reshape(-1)
dx = dx.reshape(-1)
@@ -45,21 +52,23 @@ def forward_interpolate(flow):
dy = dy[valid]
flow_x = interpolate.griddata(
- (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+ (x1, y1), dx, (x0, y0), method="nearest", fill_value=0
+ )
flow_y = interpolate.griddata(
- (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+ (x1, y1), dy, (x0, y0), method="nearest", fill_value=0
+ )
flow = np.stack([flow_x, flow_y], axis=0)
return torch.from_numpy(flow).float()
-def bilinear_sampler(img, coords, mode='bilinear', mask=False):
- """ Wrapper for grid_sample, uses pixel coordinates """
+def bilinear_sampler(img, coords, mode="bilinear", mask=False):
+ """Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
- xgrid, ygrid = coords.split([1,1], dim=-1)
- xgrid = 2*xgrid/(W-1) - 1
- ygrid = 2*ygrid/(H-1) - 1
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
+ xgrid = 2 * xgrid / (W - 1) - 1
+ ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
@@ -72,11 +81,13 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
def coords_grid(batch, ht, wd, device):
- coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device))
+ coords = torch.meshgrid(
+ torch.arange(ht, device=device), torch.arange(wd, device=device)
+ )
coords = torch.stack(coords[::-1], dim=0).float()
return coords[None].repeat(batch, 1, 1, 1)
-def upflow8(flow, mode='bilinear'):
+def upflow8(flow, mode="bilinear"):
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
- return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
diff --git a/data_preprocessing/RAFT/demo.py b/data_preprocessing/RAFT/demo.py
index b6df6bc..9a447b9 100755
--- a/data_preprocessing/RAFT/demo.py
+++ b/data_preprocessing/RAFT/demo.py
@@ -1,5 +1,6 @@
import sys
-sys.path.append('core')
+
+sys.path.append("core")
import argparse
import os
@@ -14,7 +15,7 @@
from utils import flow_viz
from utils.utils import InputPadder
-DEVICE = 'cuda'
+DEVICE = "cuda"
def load_image(imfile):
@@ -23,15 +24,15 @@ def load_image(imfile):
return img[None].to(DEVICE)
-def viz(img, flo,img_name=None):
- img = img[0].permute(1,2,0).cpu().numpy()
- flo = flo[0].permute(1,2,0).cpu().numpy()
+def viz(img, flo, img_name=None):
+ img = img[0].permute(1, 2, 0).cpu().numpy()
+ flo = flo[0].permute(1, 2, 0).cpu().numpy()
# map flow to rgb image
flo = flow_viz.flow_to_image(flo)
img_flo = np.concatenate([img, flo], axis=0)
- cv2.imwrite(f'{img_name}', img_flo[:, :, [2,1,0]])
+ cv2.imwrite(f"{img_name}", img_flo[:, :, [2, 1, 0]])
def demo(args):
@@ -44,66 +45,76 @@ def demo(args):
os.makedirs(args.outdir, exist_ok=True)
os.makedirs(args.outdir_conf, exist_ok=True)
with torch.no_grad():
- images = glob.glob(os.path.join(args.path, '*.png')) + \
- glob.glob(os.path.join(args.path, '*.jpg'))
+ images = glob.glob(os.path.join(args.path, "*.png")) + glob.glob(
+ os.path.join(args.path, "*.jpg")
+ )
images = sorted(images)
- i=0
+ i = 0
for imfile1, imfile2 in zip(images[:-1], images[1:]):
image1 = load_image(imfile1)
image2 = load_image(imfile2)
if args.if_mask:
- mk_file1=imfile1.split("/")
- mk_file1[-2]=f"{args.name}_masks"
- mk_file1='/'.join(mk_file1)
- mk_file2=imfile2.split("/")
- mk_file2[-2]=f"{args.name}_masks"
- mk_file2='/'.join(mk_file2)
- mask1=cv2.imread(mk_file1.replace('jpg','png')
- ,0)
- mask2=cv2.imread(mk_file2.replace('jpg','png'),
- 0)
- mask1=torch.from_numpy(mask1).to(DEVICE).float()
- mask2=torch.from_numpy(mask2).to(DEVICE).float()
- mask1[mask1>0]=1
- mask2[mask2>0]=1
- image1*=mask1
- image2*=mask2
+ mk_file1 = imfile1.split("/")
+ mk_file1[-2] = f"{args.name}_masks"
+ mk_file1 = "/".join(mk_file1)
+ mk_file2 = imfile2.split("/")
+ mk_file2[-2] = f"{args.name}_masks"
+ mk_file2 = "/".join(mk_file2)
+ mask1 = cv2.imread(mk_file1.replace("jpg", "png"), 0)
+ mask2 = cv2.imread(mk_file2.replace("jpg", "png"), 0)
+ mask1 = torch.from_numpy(mask1).to(DEVICE).float()
+ mask2 = torch.from_numpy(mask2).to(DEVICE).float()
+ mask1[mask1 > 0] = 1
+ mask2[mask2 > 0] = 1
+ image1 *= mask1
+ image2 *= mask2
padder = InputPadder(image1.shape)
image1, image2 = padder.pad(image1, image2)
if args.if_mask:
- mask1,mask2=padder.pad(mask1.unsqueeze(0).unsqueeze(0),
- mask2.unsqueeze(0).unsqueeze(0))
- mask1=mask1.squeeze()
- mask2=mask2.squeeze()
+ mask1, mask2 = padder.pad(
+ mask1.unsqueeze(0).unsqueeze(0), mask2.unsqueeze(0).unsqueeze(0)
+ )
+ mask1 = mask1.squeeze()
+ mask2 = mask2.squeeze()
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
flow_low_, flow_up_ = model(image2, image1, iters=20, test_mode=True)
flow_1to2 = flow_up.clone()
flow_2to1 = flow_up_.clone()
- _,_,H,W=image1.shape
+ _, _, H, W = image1.shape
x = torch.linspace(0, 1, W)
y = torch.linspace(0, 1, H)
- grid_x,grid_y=torch.meshgrid(x,y)
- grid=torch.stack([grid_x,grid_y],dim=0).to(DEVICE)
- grid=grid.permute(0,2,1)
- grid[0]*=W
- grid[1]*=H
+ grid_x, grid_y = torch.meshgrid(x, y)
+ grid = torch.stack([grid_x, grid_y], dim=0).to(DEVICE)
+ grid = grid.permute(0, 2, 1)
+ grid[0] *= W
+ grid[1] *= H
if args.if_mask:
- flow_up[:,:,mask1.long()==0]=10000
- grid_=grid+flow_up.squeeze()
-
- grid_norm=grid_.clone()
- grid_norm[0,...]=2*grid_norm[0,...]/(W-1)-1
- grid_norm[1,...]=2*grid_norm[1,...]/(H-1)-1
-
- flow_bilinear_=F.grid_sample(flow_up_,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros')
-
- rgb_bilinear_=F.grid_sample(image2,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros')
- rgb_np=rgb_bilinear_.squeeze().permute(1,2,0).cpu().numpy()[:, :, ::-1]
- cv2.imwrite(f'{args.outdir}/warped.png',rgb_np)
+ flow_up[:, :, mask1.long() == 0] = 10000
+ grid_ = grid + flow_up.squeeze()
+
+ grid_norm = grid_.clone()
+ grid_norm[0, ...] = 2 * grid_norm[0, ...] / (W - 1) - 1
+ grid_norm[1, ...] = 2 * grid_norm[1, ...] / (H - 1) - 1
+
+ flow_bilinear_ = F.grid_sample(
+ flow_up_,
+ grid_norm.unsqueeze(0).permute(0, 2, 3, 1),
+ mode="bilinear",
+ padding_mode="zeros",
+ )
+
+ rgb_bilinear_ = F.grid_sample(
+ image2,
+ grid_norm.unsqueeze(0).permute(0, 2, 3, 1),
+ mode="bilinear",
+ padding_mode="zeros",
+ )
+ rgb_np = rgb_bilinear_.squeeze().permute(1, 2, 0).cpu().numpy()[:, :, ::-1]
+ cv2.imwrite(f"{args.outdir}/warped.png", rgb_np)
if args.confidence:
### Calculate confidence map using cycle consistency.
@@ -118,62 +129,92 @@ def demo(args):
norm_grid_2to1 = grid_2to1.clone()
norm_grid_2to1[0, ...] = 2 * norm_grid_2to1[0, ...] / (W - 1) - 1
norm_grid_2to1[1, ...] = 2 * norm_grid_2to1[1, ...] / (H - 1) - 1
- warped_image2 = F.grid_sample(image1, norm_grid_2to1.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros')
+ warped_image2 = F.grid_sample(
+ image1,
+ norm_grid_2to1.unsqueeze(0).permute(0, 2, 3, 1),
+ mode="bilinear",
+ padding_mode="zeros",
+ )
grid_1to2 = grid + flow_1to2.squeeze()
norm_grid_1to2 = grid_1to2.clone()
norm_grid_1to2[0, ...] = 2 * norm_grid_1to2[0, ...] / (W - 1) - 1
norm_grid_1to2[1, ...] = 2 * norm_grid_1to2[1, ...] / (H - 1) - 1
- warped_image1 = F.grid_sample(warped_image2, norm_grid_1to2.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros')
+ warped_image1 = F.grid_sample(
+ warped_image2,
+ norm_grid_1to2.unsqueeze(0).permute(0, 2, 3, 1),
+ mode="bilinear",
+ padding_mode="zeros",
+ )
error = torch.abs(image1 - warped_image1)
confidence_map = torch.mean(error, dim=1, keepdim=True)
confidence_map[confidence_map < args.thres] = 1
confidence_map[confidence_map >= args.thres] = 0
- grid_bck=grid+flow_up.squeeze()+flow_bilinear_.squeeze()
- res=grid-grid_bck
- res=torch.norm(res,dim=0)
- mk=(res<10)&(flow_up.norm(dim=1).squeeze()>5)
+ grid_bck = grid + flow_up.squeeze() + flow_bilinear_.squeeze()
+ res = grid - grid_bck
+ res = torch.norm(res, dim=0)
+ mk = (res < 10) & (flow_up.norm(dim=1).squeeze() > 5)
- pts_src=grid[:,mk]
+ pts_src = grid[:, mk]
- pts_dst=(grid[:,mk]+flow_up.squeeze()[:,mk])
+ pts_dst = grid[:, mk] + flow_up.squeeze()[:, mk]
- pts_src=pts_src.permute(1,0).cpu().numpy()
- pts_dst=pts_dst.permute(1,0).cpu().numpy()
- indx=torch.randperm(pts_src.shape[0])[:30]
+ pts_src = pts_src.permute(1, 0).cpu().numpy()
+ pts_dst = pts_dst.permute(1, 0).cpu().numpy()
+ indx = torch.randperm(pts_src.shape[0])[:30]
# use cv2 to draw the matches in image1 and image2
- img_new=np.zeros((H,W*2,3),dtype=np.uint8)
- img_new[:,:W,:]=image1[0].permute(1,2,0).cpu().numpy()
- img_new[:,W:,:]=image2[0].permute(1,2,0).cpu().numpy()
+ img_new = np.zeros((H, W * 2, 3), dtype=np.uint8)
+ img_new[:, :W, :] = image1[0].permute(1, 2, 0).cpu().numpy()
+ img_new[:, W:, :] = image2[0].permute(1, 2, 0).cpu().numpy()
for j in indx:
- cv2.line(img_new,(int(pts_src[j,0]),int(pts_src[j,1])),(int(pts_dst[j,0])+W,int(pts_dst[j,1])),(0,255,0),1)
+ cv2.line(
+ img_new,
+ (int(pts_src[j, 0]), int(pts_src[j, 1])),
+ (int(pts_dst[j, 0]) + W, int(pts_dst[j, 1])),
+ (0, 255, 0),
+ 1,
+ )
- cv2.imwrite(f'{args.outdir}/matches.png',img_new)
+ cv2.imwrite(f"{args.outdir}/matches.png", img_new)
- np.save(f'{args.outdir}/{i:06d}.npy', flow_up.cpu().numpy())
+ np.save(f"{args.outdir}/{i:06d}.npy", flow_up.cpu().numpy())
if args.confidence:
- np.save(f'{args.outdir_conf}/{i:06d}_c.npy', confidence_map.cpu().numpy())
+ np.save(
+ f"{args.outdir_conf}/{i:06d}_c.npy", confidence_map.cpu().numpy()
+ )
i += 1
- viz(image1, flow_up,f'{args.outdir}/flow_up{i:03d}.png')
+ viz(image1, flow_up, f"{args.outdir}/flow_up{i:03d}.png")
-if __name__ == '__main__':
+if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument('--model', help="restore checkpoint")
- parser.add_argument('--path', help="dataset for evaluation")
- parser.add_argument('--outdir',help="directory for the ouput the result")
- parser.add_argument('--small', action='store_true', help='use small model')
- parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
- parser.add_argument('--if_mask', action='store_true', help='if using the image mask to mask the color img')
- parser.add_argument('--confidence', action='store_true', help='if saving the confidence map')
- parser.add_argument('--discrete', action='store_true', help='if saving the confidence map in discrete')
- parser.add_argument('--thres', default=4, help='Threshold value for confidence map')
- parser.add_argument('--outdir_conf', help="directory to save flow confidence")
- parser.add_argument('--name', help="the name of a sequence")
+ parser.add_argument("--model", help="restore checkpoint")
+ parser.add_argument("--path", help="dataset for evaluation")
+ parser.add_argument("--outdir", help="directory for the ouput the result")
+ parser.add_argument("--small", action="store_true", help="use small model")
+ parser.add_argument(
+ "--mixed_precision", action="store_true", help="use mixed precision"
+ )
+ parser.add_argument(
+ "--if_mask",
+ action="store_true",
+ help="if using the image mask to mask the color img",
+ )
+ parser.add_argument(
+ "--confidence", action="store_true", help="if saving the confidence map"
+ )
+ parser.add_argument(
+ "--discrete",
+ action="store_true",
+ help="if saving the confidence map in discrete",
+ )
+ parser.add_argument("--thres", default=4, help="Threshold value for confidence map")
+ parser.add_argument("--outdir_conf", help="directory to save flow confidence")
+ parser.add_argument("--name", help="the name of a sequence")
args = parser.parse_args()
demo(args)
diff --git a/data_preprocessing/RAFT/run_raft.sh b/data_preprocessing/RAFT/run_raft.sh
index 164bffd..be789a1 100755
--- a/data_preprocessing/RAFT/run_raft.sh
+++ b/data_preprocessing/RAFT/run_raft.sh
@@ -1,6 +1,6 @@
-NAME=beauty_1
-ROOT_DIR=/home/xxx/code/CoDeF/all_sequences
-CODE_DIR=/home/xxx/code/CoDeF/data_preprocessing/RAFT
+NAME=heygen
+ROOT_DIR=/root/CoDeF/all_sequences
+CODE_DIR=/root/CoDeF/data_preprocessing/RAFT
IMG_DIR=$ROOT_DIR/${NAME}/${NAME}
FLOW_DIR=$ROOT_DIR/${NAME}/${NAME}_flow
diff --git a/data_preprocessing/preproc_mask.py b/data_preprocessing/preproc_mask.py
index d78a38a..3cacb15 100755
--- a/data_preprocessing/preproc_mask.py
+++ b/data_preprocessing/preproc_mask.py
@@ -4,17 +4,17 @@
from glob import glob
from tqdm import tqdm
-root_dir = '/home/xxx/code/CoDeF/all_sequences'
-name = 'beauty_1'
+root_dir = "/root/CoDeF/all_sequences"
+name = "heygen"
-msk_folder = f'{root_dir}/{name}/{name}_masks'
-img_folder = f'{root_dir}/{name}/{name}'
-frg_mask_folder = f'{root_dir}/{name}/{name}_masks_0'
-bkg_mask_folder = f'{root_dir}/{name}/{name}_masks_1'
+msk_folder = f"{root_dir}/{name}/{name}_masks"
+img_folder = f"{root_dir}/{name}/{name}"
+frg_mask_folder = f"{root_dir}/{name}/{name}_masks_0"
+bkg_mask_folder = f"{root_dir}/{name}/{name}_masks_1"
os.makedirs(frg_mask_folder, exist_ok=True)
os.makedirs(bkg_mask_folder, exist_ok=True)
-files = glob(msk_folder + '/*.png')
+files = glob(msk_folder + "/*.png")
num = len(files)
for i in tqdm(range(num)):
@@ -27,4 +27,4 @@
bg_mask[bg_mask == 0] = 127
bg_mask[bg_mask == 255] = 0
bg_mask[bg_mask == 127] = 255
- cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask)
\ No newline at end of file
+ cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask)
diff --git a/datasets/__init__.py b/datasets/__init__.py
index 5d34d1a..cdbc0b1 100755
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -1,6 +1,6 @@
from .distributed_weighted_sampler import DistributedWeightedSampler
from .video_dataset import VideoDataset
-dataset_dict = {'video': VideoDataset}
+dataset_dict = {"video": VideoDataset}
-custom_sampler_dict = {'weighted': DistributedWeightedSampler}
\ No newline at end of file
+custom_sampler_dict = {"weighted": DistributedWeightedSampler}
diff --git a/datasets/distributed_weighted_sampler.py b/datasets/distributed_weighted_sampler.py
index 26049b7..f92d0d7 100755
--- a/datasets/distributed_weighted_sampler.py
+++ b/datasets/distributed_weighted_sampler.py
@@ -8,7 +8,7 @@
import torch.distributed as dist
-T_co = TypeVar('T_co', covariant=True)
+T_co = TypeVar("T_co", covariant=True)
class DistributedWeightedSampler(Sampler[T_co]):
@@ -58,9 +58,16 @@ class DistributedWeightedSampler(Sampler[T_co]):
... train(loader)
"""
- def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
- rank: Optional[int] = None, shuffle: bool = True,
- seed: int = 0, drop_last: bool = False, replacement: bool = True) -> None:
+ def __init__(
+ self,
+ dataset: Dataset,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ seed: int = 0,
+ drop_last: bool = False,
+ replacement: bool = True,
+ ) -> None:
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -72,7 +79,8 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
- " [0, {}]".format(rank, num_replicas - 1))
+ " [0, {}]".format(rank, num_replicas - 1)
+ )
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -87,7 +95,8 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
self.num_samples = math.ceil(
# `type:ignore` is required because Dataset cannot provide a default __len__
# see NOTE in pytorch/torch/utils/data/sampler.py
- (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
+ (len(self.dataset) - self.num_replicas)
+ / self.num_replicas # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
@@ -119,14 +128,16 @@ def __iter__(self) -> Iterator[T_co]:
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
- indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
+ indices += (indices * math.ceil(padding_size / len(indices)))[
+ :padding_size
+ ]
else:
# remove tail of data to make it evenly divisible.
- indices = indices[:self.total_size]
+ indices = indices[: self.total_size]
assert len(indices) == self.total_size
# subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
# subsample weights
@@ -134,21 +145,22 @@ def __iter__(self) -> Iterator[T_co]:
weights = self.weights[indices][:, 0]
assert len(weights) == self.num_samples
- ###########################################################################
+ ###########################################################################
# the upper bound category number of multinomial is 2^24, to handle this we can use chunk or using random choices
# subsample_balanced_indicies = torch.multinomial(weights, self.num_samples, self.replacement)
- ###########################################################################
+ ###########################################################################
# using random choices
- rand_tensor = np.random.choice(range(0, len(weights)),
- size=self.num_samples,
- p=weights.numpy() / torch.sum(weights).numpy(),
- replace=self.replacement)
+ rand_tensor = np.random.choice(
+ range(0, len(weights)),
+ size=self.num_samples,
+ p=weights.numpy() / torch.sum(weights).numpy(),
+ replace=self.replacement,
+ )
subsample_balanced_indicies = torch.from_numpy(rand_tensor)
dataset_indices = torch.tensor(indices)[subsample_balanced_indicies]
return iter(dataset_indices.tolist())
-
def __len__(self) -> int:
return self.num_samples
@@ -161,4 +173,4 @@ def set_epoch(self, epoch: int) -> None:
Args:
epoch (int): Epoch number.
"""
- self.epoch = epoch
\ No newline at end of file
+ self.epoch = epoch
diff --git a/datasets/video_dataset.py b/datasets/video_dataset.py
index b969ee3..942d9e0 100755
--- a/datasets/video_dataset.py
+++ b/datasets/video_dataset.py
@@ -8,19 +8,22 @@
import glob
import cv2
+
# The basic dataset of reading rays
class VideoDataset(Dataset):
- def __init__(self,
- root_dir,
- split='train',
- img_wh=(504, 378),
- mask_dir=None,
- flow_dir=None,
- canonical_wh=None,
- ref_idx=None,
- canonical_dir=None,
- test=False):
+ def __init__(
+ self,
+ root_dir,
+ split="train",
+ img_wh=(504, 378),
+ mask_dir=None,
+ flow_dir=None,
+ canonical_wh=None,
+ ref_idx=None,
+ canonical_dir=None,
+ test=False,
+ ):
self.test = test
self.root_dir = root_dir
self.split = split
@@ -41,11 +44,11 @@ def read_meta(self):
# construct grid
grid = np.indices((h, w)).astype(np.float32)
# normalize
- grid[0,:,:] = grid[0,:,:] / h
- grid[1,:,:] = grid[1,:,:] / w
- self.grid = torch.from_numpy(rearrange(grid, 'c h w -> (h w) c'))
+ grid[0, :, :] = grid[0, :, :] / h
+ grid[1, :, :] = grid[1, :, :] / w
+ self.grid = torch.from_numpy(rearrange(grid, "c h w -> (h w) c"))
warp_code = 1
- for input_image_path in sorted(glob.glob(f'{self.root_dir}/*')):
+ for input_image_path in sorted(glob.glob(f"{self.root_dir}/*")):
print(input_image_path)
all_images_path.append(input_image_path)
self.ts_w.append(torch.Tensor([warp_code]).long())
@@ -55,9 +58,9 @@ def read_meta(self):
h_c = self.canonical_wh[1]
w_c = self.canonical_wh[0]
grid_c = np.indices((h_c, w_c)).astype(np.float32)
- grid_c[0,:,:] = (grid_c[0,:,:] - (h_c - h) / 2) / h
- grid_c[1,:,:] = (grid_c[1,:,:] - (w_c - w) / 2) / w
- self.grid_c = torch.from_numpy(rearrange(grid_c, 'c h w -> (h w) c'))
+ grid_c[0, :, :] = (grid_c[0, :, :] - (h_c - h) / 2) / h
+ grid_c[1, :, :] = (grid_c[1, :, :] - (w_c - w) / 2) / w
+ self.grid_c = torch.from_numpy(rearrange(grid_c, "c h w -> (h w) c"))
else:
self.grid_c = self.grid
self.canonical_wh = self.img_wh
@@ -69,14 +72,18 @@ def read_meta(self):
else:
self.all_flows = None
- if self.split == 'train' or self.split == 'val':
+ if self.split == "train" or self.split == "val":
if self.canonical_dir is not None:
- all_images_path_ = sorted(glob.glob(f'{self.canonical_dir}/*.png'))
+ all_images_path_ = sorted(glob.glob(f"{self.canonical_dir}/*.png"))
self.canonical_img = []
for input_image_path in all_images_path_:
input_image = cv2.imread(input_image_path)
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
- input_image = cv2.resize(input_image, (self.canonical_wh[0], self.canonical_wh[1]), interpolation = cv2.INTER_AREA)
+ input_image = cv2.resize(
+ input_image,
+ (self.canonical_wh[0], self.canonical_wh[1]),
+ interpolation=cv2.INTER_AREA,
+ )
input_image_tensor = torch.from_numpy(input_image).float() / 256
self.canonical_img.append(input_image_tensor)
self.canonical_img = torch.stack(self.canonical_img, dim=0)
@@ -84,53 +91,79 @@ def read_meta(self):
for input_image_path in all_images_path:
input_image = cv2.imread(input_image_path)
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
- input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
+ input_image = cv2.resize(
+ input_image,
+ (self.img_wh[0], self.img_wh[1]),
+ interpolation=cv2.INTER_AREA,
+ )
input_image_tensor = torch.from_numpy(input_image).float() / 256
self.all_images.append(input_image_tensor)
if self.mask_dir:
input_image_name = input_image_path.split("/")[-1][:-4]
for i in range(len(self.mask_dir)):
- input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png')
- input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
+ input_mask = cv2.imread(
+ f"{self.mask_dir[i]}/{input_image_name}.png"
+ )
+ input_mask = cv2.resize(
+ input_mask,
+ (self.img_wh[0], self.img_wh[1]),
+ interpolation=cv2.INTER_AREA,
+ )
input_mask_tensor = torch.from_numpy(input_mask).float() / 256
self.all_masks.append(input_mask_tensor)
- if self.split == 'val':
+ if self.split == "val":
input_image = cv2.imread(all_images_path[0])
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)
- input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
+ input_image = cv2.resize(
+ input_image,
+ (self.img_wh[0], self.img_wh[1]),
+ interpolation=cv2.INTER_AREA,
+ )
input_image_tensor = torch.from_numpy(input_image).float() / 256
self.all_images.append(input_image_tensor)
if self.mask_dir:
input_image_name = all_images_path[0].split("/")[-1][:-4]
for i in range(len(self.mask_dir)):
- input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png')
- input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA)
+ input_mask = cv2.imread(
+ f"{self.mask_dir[i]}/{input_image_name}.png"
+ )
+ input_mask = cv2.resize(
+ input_mask,
+ (self.img_wh[0], self.img_wh[1]),
+ interpolation=cv2.INTER_AREA,
+ )
input_mask_tensor = torch.from_numpy(input_mask).float() / 256
self.all_masks.append(input_mask_tensor)
if self.flow_dir:
- for input_image_path in sorted(glob.glob(f'{self.flow_dir}/*npy')):
- flow_load=np.load(input_image_path) # (1, 2, h, w)
- flow_tensor=torch.from_numpy(flow_load).float()[:, [1, 0]]
- flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0]))
- H_,W_=flow_load.shape[-2],flow_load.shape[-1]
- flow_tensor=flow_tensor.reshape(2,-1).transpose(1,0)
+ for input_image_path in sorted(glob.glob(f"{self.flow_dir}/*npy")):
+ flow_load = np.load(input_image_path) # (1, 2, h, w)
+ flow_tensor = torch.from_numpy(flow_load).float()[:, [1, 0]]
+ flow_tensor = torch.nn.functional.interpolate(
+ flow_tensor, size=(self.img_wh[1], self.img_wh[0])
+ )
+ H_, W_ = flow_load.shape[-2], flow_load.shape[-1]
+ flow_tensor = flow_tensor.reshape(2, -1).transpose(1, 0)
flow_tensor[..., 0] /= W_
flow_tensor[..., 1] /= H_
self.all_flows.append(flow_tensor)
i = 0
- for input_image_path in sorted(glob.glob(f'{self.flow_dir}_confidence/*npy')):
- flow_load=np.load(input_image_path)
- flow_tensor=torch.from_numpy(flow_load).float()
- flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0]))
- flow_tensor=flow_tensor.reshape(1,-1).transpose(1,0)
+ for input_image_path in sorted(
+ glob.glob(f"{self.flow_dir}_confidence/*npy")
+ ):
+ flow_load = np.load(input_image_path)
+ flow_tensor = torch.from_numpy(flow_load).float()
+ flow_tensor = torch.nn.functional.interpolate(
+ flow_tensor, size=(self.img_wh[1], self.img_wh[0])
+ )
+ flow_tensor = flow_tensor.reshape(1, -1).transpose(1, 0)
flow_tensor = flow_tensor.sum(dim=-1) < 0.05
self.all_flows[i][flow_tensor] = 5
- i += 1
+ i += 1
- if self.split == 'val':
+ if self.split == "val":
self.ref_idx = 0
def __len__(self):
@@ -139,18 +172,45 @@ def __len__(self):
return 200 * len(self.all_images)
def __getitem__(self, idx):
- if self.split == 'train' or self.split == 'val':
+ if self.split == "train" or self.split == "val":
idx = idx % len(self.all_images)
- sample = {'rgbs': self.all_images[idx],
- 'canonical_img': self.all_images[idx] if self.canonical_dir is None else self.canonical_img,
- 'ts_w': self.ts_w[idx],
- 'grid': self.grid,
- 'canonical_wh': self.canonical_wh,
- 'img_wh': self.img_wh,
- 'masks': self.all_masks[len(self.mask_dir)*idx:len(self.mask_dir)*idx+len(self.mask_dir)] if self.mask_dir else [torch.ones((self.img_wh[1], self.img_wh[0], 1))],
- 'flows': self.all_flows[idx] if (idx b c h w', b=1)
+ gt_lpips = rearrange(gt_lpips, "(b h) w c -> b c h w", b=1)
gt_lpips = torch.from_numpy(gt_lpips)
predict_image_lpips = image_pred.clone().detach().cpu() * 2.0 - 1.0
- predict_image_lpips = rearrange(predict_image_lpips, '(b h) w c -> b c h w', b=1)
- lpips_result = lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy()
+ predict_image_lpips = rearrange(predict_image_lpips, "(b h) w c -> b c h w", b=1)
+ lpips_result = (
+ lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy()
+ )
return np.squeeze(lpips_result)
diff --git a/models/implicit_model.py b/models/implicit_model.py
index 133afa6..e9567e5 100755
--- a/models/implicit_model.py
+++ b/models/implicit_model.py
@@ -11,9 +11,7 @@ def init_weights(m):
class TranslationField(nn.Module):
- def __init__(self, D=6, W=128,
- in_channels_w=8, in_channels_xyz=34,
- skips=[4]):
+ def __init__(self, D=6, W=128, in_channels_w=8, in_channels_xyz=34, skips=[4]):
"""
D: number of layers for density (sigma) encoder
W: number of hidden units in each layer
@@ -32,9 +30,9 @@ def __init__(self, D=6, W=128,
# encoding layers
for i in range(D):
if i == 0:
- layer = nn.Linear(in_channels_xyz+self.in_channels_w, W)
+ layer = nn.Linear(in_channels_xyz + self.in_channels_w, W)
elif i in skips:
- layer = nn.Linear(W+in_channels_xyz+self.in_channels_w, W)
+ layer = nn.Linear(W + in_channels_xyz + self.in_channels_w, W)
else:
layer = nn.Linear(W, W)
init_weights(layer)
@@ -81,12 +79,12 @@ def __init__(self, in_channels, N_freqs, logscale=True, identity=True):
self.identity = identity
self.in_channels = in_channels
self.funcs = [torch.sin, torch.cos]
- self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
+ self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
if logscale:
- self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
+ self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
else:
- self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
+ self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)
def forward(self, x):
"""
@@ -106,13 +104,21 @@ def forward(self, x):
out = []
for freq in self.freq_bands:
for func in self.funcs:
- out += [func(freq*x)]
+ out += [func(freq * x)]
return torch.cat(out, -1)
class AnnealedEmbedding(nn.Module):
- def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, logscale=True, identity=True):
+ def __init__(
+ self,
+ in_channels,
+ N_freqs,
+ annealed_step,
+ annealed_begin_step=0,
+ logscale=True,
+ identity=True,
+ ):
"""
Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
in_channels: number of input channels (3 for both xyz and direction)
@@ -124,14 +130,14 @@ def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, l
self.annealed_step = annealed_step
self.annealed_begin_step = annealed_begin_step
self.funcs = [torch.sin, torch.cos]
- self.out_channels = in_channels*(len(self.funcs)*N_freqs+1)
- self.index = torch.linspace(0, N_freqs-1, N_freqs)
+ self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1)
+ self.index = torch.linspace(0, N_freqs - 1, N_freqs)
self.identity = identity
if logscale:
- self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs)
+ self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs)
else:
- self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs)
+ self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs)
def forward(self, x, step):
"""
@@ -157,20 +163,24 @@ def forward(self, x, step):
if step <= self.annealed_begin_step:
alpha = 0
else:
- alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
- self.annealed_step)
+ alpha = (
+ self.N_freqs
+ * (step - self.annealed_begin_step)
+ / float(self.annealed_step)
+ )
for j, freq in enumerate(self.freq_bands):
- w = (1 - torch.cos(
- math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2
+ w = (1 - torch.cos(math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2
for func in self.funcs:
- out += [w * func(freq*x)]
+ out += [w * func(freq * x)]
return torch.cat(out, -1)
class AnnealedHash(nn.Module):
- def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True):
+ def __init__(
+ self, in_channels, annealed_step, annealed_begin_step=0, identity=True
+ ):
"""
Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...)
in_channels: number of input channels (3 for both xyz and direction)
@@ -207,10 +217,21 @@ def forward(self, x_embed, step):
if step <= self.annealed_begin_step:
alpha = 0
else:
- alpha = self.N_freqs * (step - self.annealed_begin_step) / float(
- self.annealed_step)
-
- w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2
+ alpha = (
+ self.N_freqs
+ * (step - self.annealed_begin_step)
+ / float(self.annealed_step)
+ )
+
+ w = (
+ 1
+ - torch.cos(
+ math.pi
+ * torch.clamp(
+ alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1
+ )
+ )
+ ) / 2
out = x_embed * w.to(x_embed.device)
@@ -218,12 +239,15 @@ def forward(self, x_embed, step):
class ImplicitVideo(nn.Module):
- def __init__(self,
- D=8, W=256,
- in_channels_xyz=34,
- skips=[4],
- out_channels=3,
- sigmoid_offset=0):
+ def __init__(
+ self,
+ D=8,
+ W=256,
+ in_channels_xyz=34,
+ skips=[4],
+ out_channels=3,
+ sigmoid_offset=0,
+ ):
"""
D: number of layers for density (sigma) encoder
W: number of hidden units in each layer
@@ -247,7 +271,7 @@ def __init__(self,
if i == 0:
layer = nn.Linear(self.in_channels_xyz, W)
elif i in skips:
- layer = nn.Linear(W+self.in_channels_xyz, W)
+ layer = nn.Linear(W + self.in_channels_xyz, W)
else:
layer = nn.Linear(W, W)
init_weights(layer)
@@ -297,12 +321,12 @@ def forward(self, x):
class ImplicitVideo_Hash(nn.Module):
def __init__(self, config):
super().__init__()
- self.encoder = tcnn.Encoding(n_input_dims=2,
- encoding_config=config["encoding"])
- self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims +
- 2,
- n_output_dims=3,
- network_config=config["network"])
+ self.encoder = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"])
+ self.decoder = tcnn.Network(
+ n_input_dims=self.encoder.n_output_dims + 2,
+ n_output_dims=3,
+ network_config=config["network"],
+ )
def forward(self, x):
input = x
@@ -316,17 +340,20 @@ def forward(self, x):
class Deform_Hash3d(nn.Module):
def __init__(self, config):
super().__init__()
- self.encoder = tcnn.Encoding(n_input_dims=3,
- encoding_config=config["encoding_deform3d"])
- self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 3,
- n_output_dims=2,
- network_config=config["network_deform"])
+ self.encoder = tcnn.Encoding(
+ n_input_dims=3, encoding_config=config["encoding_deform3d"]
+ )
+ self.decoder = tcnn.Network(
+ n_input_dims=self.encoder.n_output_dims + 3,
+ n_output_dims=2,
+ network_config=config["network_deform"],
+ )
def forward(self, x, step=0, aneal_func=None):
input = x
input = self.encoder(input)
if aneal_func is not None:
- input = torch.cat([x, aneal_func(input,step)], dim=-1)
+ input = torch.cat([x, aneal_func(input, step)], dim=-1)
else:
input = torch.cat([x, input], dim=-1)
@@ -341,7 +368,7 @@ def __init__(self, config):
super().__init__()
self.Deform_Hash3d = Deform_Hash3d(config)
- def forward(self, xyt_norm, step=0,aneal_func=None):
- x = self.Deform_Hash3d(xyt_norm,step=step, aneal_func=aneal_func)
+ def forward(self, xyt_norm, step=0, aneal_func=None):
+ x = self.Deform_Hash3d(xyt_norm, step=step, aneal_func=aneal_func)
return x
diff --git a/opt.py b/opt.py
index b62ccb9..ab4c132 100755
--- a/opt.py
+++ b/opt.py
@@ -6,155 +6,280 @@ def get_opts():
parser = argparse.ArgumentParser()
# General Setttings
- parser.add_argument('--root_dir', type=str, default='Batman_masked_frames',
- help='root directory of dataset')
- parser.add_argument('--canonical_dir', type=str, default=None,
- help='directory of canonical dataset')
+ parser.add_argument(
+ "--root_dir",
+ type=str,
+ default="Batman_masked_frames",
+ help="root directory of dataset",
+ )
+ parser.add_argument(
+ "--canonical_dir", type=str, default=None, help="directory of canonical dataset"
+ )
# support multiple mask as input (each mask has different deformation fields)
- parser.add_argument('--mask_dir', nargs="+", type=str, default=None,
- help='mask of the dataset')
- parser.add_argument('--flow_dir', type=str,
- default=None,
- help='masks of dataset')
- parser.add_argument('--dataset_name', type=str, default='video',
- choices=['video'],
- help='which dataset to train/val')
- parser.add_argument('--img_wh', nargs="+", type=int, default=[842, 512],
- help='resolution (img_w, img_h) of the full image')
- parser.add_argument('--canonical_wh', nargs="+", type=int, default=None,
- help='default same as the img_wh, can be set to a larger range to include more content')
- parser.add_argument('--ref_idx', type=int, default=None,
- help='manually select a frame as reference (for rigid movement)')
+ parser.add_argument(
+ "--mask_dir", nargs="+", type=str, default=None, help="mask of the dataset"
+ )
+ parser.add_argument("--flow_dir", type=str, default=None, help="masks of dataset")
+ parser.add_argument(
+ "--dataset_name",
+ type=str,
+ default="video",
+ choices=["video"],
+ help="which dataset to train/val",
+ )
+ parser.add_argument(
+ "--img_wh",
+ nargs="+",
+ type=int,
+ default=[842, 512],
+ help="resolution (img_w, img_h) of the full image",
+ )
+ parser.add_argument(
+ "--canonical_wh",
+ nargs="+",
+ type=int,
+ default=None,
+ help="default same as the img_wh, can be set to a larger range to include more content",
+ )
+ parser.add_argument(
+ "--ref_idx",
+ type=int,
+ default=None,
+ help="manually select a frame as reference (for rigid movement)",
+ )
# Deformation Setting
- parser.add_argument('--encode_w', default=False, action="store_true",
- help='whether to apply warping')
+ parser.add_argument(
+ "--encode_w",
+ default=False,
+ action="store_true",
+ help="whether to apply warping",
+ )
# Training Setttings
- parser.add_argument('--batch_size', type=int, default=1,
- help='batch size')
- parser.add_argument('--num_steps', type=int, default=10000,
- help='number of training epochs')
- parser.add_argument('--valid_iters', type=int, default=30,
- help='valid iters for each epoch')
- parser.add_argument('--valid_batches', type=int, default=0,
- help='valid batches for each valid process')
- parser.add_argument('--save_model_iters', type=int, default=5000,
- help='iterations to save the models')
- parser.add_argument('--gpus', nargs="+", type=int, default=[0],
- help='gpu devices')
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size")
+ parser.add_argument(
+ "--num_steps", type=int, default=10000, help="number of training epochs"
+ )
+ parser.add_argument(
+ "--valid_iters", type=int, default=30, help="valid iters for each epoch"
+ )
+ parser.add_argument(
+ "--valid_batches",
+ type=int,
+ default=0,
+ help="valid batches for each valid process",
+ )
+ parser.add_argument(
+ "--save_model_iters",
+ type=int,
+ default=5000,
+ help="iterations to save the models",
+ )
+ parser.add_argument("--gpus", nargs="+", type=int, default=[0], help="gpu devices")
# Test Setttings
- parser.add_argument('--test', default=False, action="store_true",
- help='whether to disable identity')
+ parser.add_argument(
+ "--test", default=False, action="store_true", help="whether to disable identity"
+ )
# Model Save and Load
- parser.add_argument('--ckpt_path', type=str, default=None,
- help='pretrained checkpoint to load (including optimizers, etc)')
- parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'],
- help='the prefixes to ignore in the checkpoint state dict')
- parser.add_argument('--weight_path', type=str, default=None,
- help='pretrained model weight to load (do not load optimizers, etc)')
- parser.add_argument('--model_save_path', type=str, default='ckpts',
- help='save checkpoint to')
- parser.add_argument('--log_save_path', type=str, default='logs',
- help='save log to')
- parser.add_argument('--exp_name', type=str, default='exp',
- help='experiment name')
+ parser.add_argument(
+ "--ckpt_path",
+ type=str,
+ default=None,
+ help="pretrained checkpoint to load (including optimizers, etc)",
+ )
+ parser.add_argument(
+ "--prefixes_to_ignore",
+ nargs="+",
+ type=str,
+ default=["loss"],
+ help="the prefixes to ignore in the checkpoint state dict",
+ )
+ parser.add_argument(
+ "--weight_path",
+ type=str,
+ default=None,
+ help="pretrained model weight to load (do not load optimizers, etc)",
+ )
+ parser.add_argument(
+ "--model_save_path", type=str, default="ckpts", help="save checkpoint to"
+ )
+ parser.add_argument("--log_save_path", type=str, default="logs", help="save log to")
+ parser.add_argument("--exp_name", type=str, default="exp", help="experiment name")
# Optimize Settings
- parser.add_argument('--optimizer', type=str, default='adam',
- help='optimizer type',
- choices=['sgd', 'adam', 'radam', 'ranger'])
- parser.add_argument('--lr', type=float, default=5e-4,
- help='learning rate')
- parser.add_argument('--momentum', type=float, default=0.9,
- help='learning rate momentum')
- parser.add_argument('--weight_decay', type=float, default=0,
- help='weight decay')
- parser.add_argument('--lr_scheduler', type=str, default='steplr',
- help='scheduler type',
- choices=['steplr', 'cosine', 'poly', 'exponential'])
+ parser.add_argument(
+ "--optimizer",
+ type=str,
+ default="adam",
+ help="optimizer type",
+ choices=["sgd", "adam", "radam", "ranger"],
+ )
+ parser.add_argument("--lr", type=float, default=5e-4, help="learning rate")
+ parser.add_argument(
+ "--momentum", type=float, default=0.9, help="learning rate momentum"
+ )
+ parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")
+ parser.add_argument(
+ "--lr_scheduler",
+ type=str,
+ default="steplr",
+ help="scheduler type",
+ choices=["steplr", "cosine", "poly", "exponential"],
+ )
#### params for steplr ####
- parser.add_argument('--decay_step', nargs='+', type=int,
- default=[2500, 5000, 7500],
- help='scheduler decay step')
- parser.add_argument('--decay_gamma', type=float, default=0.5,
- help='learning rate decay amount')
+ parser.add_argument(
+ "--decay_step",
+ nargs="+",
+ type=int,
+ default=[2500, 5000, 7500],
+ help="scheduler decay step",
+ )
+ parser.add_argument(
+ "--decay_gamma", type=float, default=0.5, help="learning rate decay amount"
+ )
#### params for warmup, only applied when optimizer == 'sgd' or 'adam'
- parser.add_argument('--warmup_multiplier', type=float, default=1.0,
- help='lr is multiplied by this factor after --warmup_epochs')
- parser.add_argument('--warmup_epochs', type=int, default=0,
- help='Gradually warm-up(increasing) learning rate in optimizer')
+ parser.add_argument(
+ "--warmup_multiplier",
+ type=float,
+ default=1.0,
+ help="lr is multiplied by this factor after --warmup_epochs",
+ )
+ parser.add_argument(
+ "--warmup_epochs",
+ type=int,
+ default=0,
+ help="Gradually warm-up(increasing) learning rate in optimizer",
+ )
##### annealed positional encoding ######
- parser.add_argument('--annealed', default=False, action="store_true",
- help='whether to apply annealed positional encoding (Only in the warping field)')
- parser.add_argument('--annealed_begin_step', type=int, default=0,
- help='annealed step to begin for positional encoding')
- parser.add_argument('--annealed_step', type=int, default=5000,
- help='maximum annealed step for positional encoding')
+ parser.add_argument(
+ "--annealed",
+ default=False,
+ action="store_true",
+ help="whether to apply annealed positional encoding (Only in the warping field)",
+ )
+ parser.add_argument(
+ "--annealed_begin_step",
+ type=int,
+ default=0,
+ help="annealed step to begin for positional encoding",
+ )
+ parser.add_argument(
+ "--annealed_step",
+ type=int,
+ default=5000,
+ help="maximum annealed step for positional encoding",
+ )
##### Additional losses ######
- parser.add_argument('--flow_loss', type=float, default=None,
- help='optical flow loss weight')
- parser.add_argument('--bg_loss', type=float, default=None,
- help='regularize the rest part of each object ')
- parser.add_argument('--grad_loss', type=float, default=0.1,
- help='image gradient loss weight')
- parser.add_argument('--flow_step', type=int, default=-1,
- help='Step to begin to perform flow loss.')
- parser.add_argument('--ref_step', type=int, default=-1,
- help='Step to stop reference frame loss.')
- parser.add_argument('--self_bg', type=bool_parser, default=False,
- help='Whether to use self background as bg loss.')
+ parser.add_argument(
+ "--flow_loss", type=float, default=None, help="optical flow loss weight"
+ )
+ parser.add_argument(
+ "--bg_loss",
+ type=float,
+ default=None,
+ help="regularize the rest part of each object ",
+ )
+ parser.add_argument(
+ "--grad_loss", type=float, default=0.1, help="image gradient loss weight"
+ )
+ parser.add_argument(
+ "--flow_step", type=int, default=-1, help="Step to begin to perform flow loss."
+ )
+ parser.add_argument(
+ "--ref_step", type=int, default=-1, help="Step to stop reference frame loss."
+ )
+ parser.add_argument(
+ "--self_bg",
+ type=bool_parser,
+ default=False,
+ help="Whether to use self background as bg loss.",
+ )
##### Special cases: for black-dominated images
- parser.add_argument('--sigmoid_offset', type=float, default=0,
- help='whether to process balck-dominated images.')
+ parser.add_argument(
+ "--sigmoid_offset",
+ type=float,
+ default=0,
+ help="whether to process balck-dominated images.",
+ )
# Other miscellaneous settings.
- parser.add_argument('--save_deform', type=bool_parser, default=False,
- help='Whether to save deformation field or not.')
- parser.add_argument('--save_video', type=bool_parser, default=True,
- help='Whether to save video or not.')
- parser.add_argument('--fps', type=int, default=30,
- help='FPS of the saved video.')
+ parser.add_argument(
+ "--save_deform",
+ type=bool_parser,
+ default=False,
+ help="Whether to save deformation field or not.",
+ )
+ parser.add_argument(
+ "--save_video",
+ type=bool_parser,
+ default=True,
+ help="Whether to save video or not.",
+ )
+ parser.add_argument("--fps", type=int, default=30, help="FPS of the saved video.")
# Network settings for PE.
- parser.add_argument('--deform_D', type=int, default=6,
- help='The depth of deformation field MLP.')
- parser.add_argument('--deform_W', type=int, default=128,
- help='The width of deformation field MLP.')
- parser.add_argument('--vid_D', type=int, default=8,
- help='The depth of implicit video MLP.')
- parser.add_argument('--vid_W', type=int, default=256,
- help='The width of implicit video MLP.')
- parser.add_argument('--N_vocab_w', type=int, default=200,
- help='number of vocabulary for warp code in the dataset for nn.Embedding')
- parser.add_argument('--N_w', type=int, default=8,
- help='embeddings size for warping')
- parser.add_argument('--N_xyz_w', nargs="+", type=int, default=[8, 8],
- help='positional encoding frequency of deformation field')
+ parser.add_argument(
+ "--deform_D", type=int, default=6, help="The depth of deformation field MLP."
+ )
+ parser.add_argument(
+ "--deform_W", type=int, default=128, help="The width of deformation field MLP."
+ )
+ parser.add_argument(
+ "--vid_D", type=int, default=8, help="The depth of implicit video MLP."
+ )
+ parser.add_argument(
+ "--vid_W", type=int, default=256, help="The width of implicit video MLP."
+ )
+ parser.add_argument(
+ "--N_vocab_w",
+ type=int,
+ default=200,
+ help="number of vocabulary for warp code in the dataset for nn.Embedding",
+ )
+ parser.add_argument(
+ "--N_w", type=int, default=8, help="embeddings size for warping"
+ )
+ parser.add_argument(
+ "--N_xyz_w",
+ nargs="+",
+ type=int,
+ default=[8, 8],
+ help="positional encoding frequency of deformation field",
+ )
# Network settings for Hash, please see details in configs/hash.json
- parser.add_argument('--vid_hash', type=bool_parser, default=False,
- help='Whether to use hash encoding in implicit video system.')
- parser.add_argument('--deform_hash', type=bool_parser, default=False,
- help='Whether to use hash encoding in deformation field.')
+ parser.add_argument(
+ "--vid_hash",
+ type=bool_parser,
+ default=False,
+ help="Whether to use hash encoding in implicit video system.",
+ )
+ parser.add_argument(
+ "--deform_hash",
+ type=bool_parser,
+ default=False,
+ help="Whether to use hash encoding in deformation field.",
+ )
# Config files
- parser.add_argument('--config', type=str, default=None,
- help='path to the YAML config file.')
+ parser.add_argument(
+ "--config", type=str, default=None, help="path to the YAML config file."
+ )
args = parser.parse_args()
if args.config is not None:
- with open(args.config, 'r') as f:
+ with open(args.config, "r") as f:
config = yaml.safe_load(f)
args_dict = vars(args)
args_dict.update(config)
@@ -170,8 +295,8 @@ def bool_parser(arg):
return arg
if arg is None:
return False
- if arg.lower() in ['1', 'true', 't', 'yes', 'y']:
+ if arg.lower() in ["1", "true", "t", "yes", "y"]:
return True
- if arg.lower() in ['0', 'false', 'f', 'no', 'n']:
+ if arg.lower() in ["0", "false", "f", "no", "n"]:
return False
- raise ValueError(f'`{arg}` cannot be converted to boolean!')
+ raise ValueError(f"`{arg}` cannot be converted to boolean!")
diff --git a/prepare_video.sh b/prepare_video.sh
new file mode 100644
index 0000000..616cb25
--- /dev/null
+++ b/prepare_video.sh
@@ -0,0 +1,2 @@
+ffmpeg -i $1 -vf "crop=1080:1080:420:0" cropped.mp4
+ffmpeg -i cropped.mp4 -start_number 0 -vf "fps=25" /root/CoDeF/all_sequences/heygen/heygen/%05d.png
\ No newline at end of file
diff --git a/resize_image.py b/resize_image.py
new file mode 100644
index 0000000..b14e779
--- /dev/null
+++ b/resize_image.py
@@ -0,0 +1,13 @@
+from PIL import Image
+
+
+path = "/root/CoDeF/all_sequences/heygen/base_control/canonical_0.png"
+
+# Open the original image
+original_image = Image.open(path)
+
+# Resize the image
+resized_image = original_image.resize((1080, 1080))
+
+# Save the resized image
+resized_image.save(path)
diff --git a/scripts/test_canonical.sh b/scripts/test_canonical.sh
index a60140a..33b27b7 100755
--- a/scripts/test_canonical.sh
+++ b/scripts/test_canonical.sh
@@ -1,22 +1,19 @@
GPUS=0
-NAME=scene_0
+NAME=heygen
EXP_NAME=base
ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
LOG_SAVE_PATH="logs/test_all_sequences/$NAME"
-MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
-
CANONICAL_DIR="all_sequences/${NAME}/${EXP_NAME}_control"
-WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
-# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt
+# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
+WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=200000.ckpt
python train.py --test --encode_w \
--root_dir $ROOT_DIRECTORY \
--log_save_path $LOG_SAVE_PATH \
- --mask_dir $MASK_DIRECTORY \
--weight_path $WEIGHT_PATH \
--gpus $GPUS \
--canonical_dir $CANONICAL_DIR \
diff --git a/scripts/test_multi.sh b/scripts/test_multi.sh
index 1beeda7..c491b63 100755
--- a/scripts/test_multi.sh
+++ b/scripts/test_multi.sh
@@ -1,20 +1,17 @@
GPUS=0
-NAME=scene_0
+NAME=heygen
EXP_NAME=base
ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
LOG_SAVE_PATH="logs/test_all_sequences/$NAME"
-MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
-
-WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
-# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt
+# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt
+WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=200000.ckpt
python train.py --test --encode_w \
--root_dir $ROOT_DIRECTORY \
--log_save_path $LOG_SAVE_PATH \
- --mask_dir $MASK_DIRECTORY \
--weight_path $WEIGHT_PATH \
--gpus $GPUS \
--config configs/${NAME}/${EXP_NAME}.yaml \
diff --git a/scripts/train_multi.sh b/scripts/train_multi.sh
index 297457f..66459e7 100755
--- a/scripts/train_multi.sh
+++ b/scripts/train_multi.sh
@@ -1,19 +1,17 @@
GPUS=0
-NAME=scene_0
+NAME=heygen
EXP_NAME=base
ROOT_DIRECTORY="all_sequences/$NAME/$NAME"
MODEL_SAVE_PATH="ckpts/all_sequences/$NAME"
LOG_SAVE_PATH="logs/all_sequences/$NAME"
-MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1"
FLOW_DIRECTORY="all_sequences/$NAME/${NAME}_flow"
python train.py --root_dir $ROOT_DIRECTORY \
--model_save_path $MODEL_SAVE_PATH \
--log_save_path $LOG_SAVE_PATH \
- --mask_dir $MASK_DIRECTORY \
--flow_dir $FLOW_DIRECTORY \
--gpus $GPUS \
--encode_w --annealed \
diff --git a/train.py b/train.py
index 8f30b3b..50cd154 100755
--- a/train.py
+++ b/train.py
@@ -42,15 +42,15 @@ class ImplicitVideoSystem(LightningModule):
def __init__(self, hparams):
super(ImplicitVideoSystem, self).__init__()
self.save_hyperparameters(hparams)
- self.color_loss = loss_dict['mse'](coef=1)
+ self.color_loss = loss_dict["mse"](coef=1)
if hparams.save_video:
self.video_visualizer = VideoVisualizer(fps=hparams.fps)
self.raw_video_visualizer = VideoVisualizer(fps=hparams.fps)
self.dual_video_visualizer = VideoVisualizer(fps=hparams.fps)
- self.models_to_train=[]
+ self.models_to_train = []
self.embedding_xyz = Embedding(2, 8)
- self.embeddings = {'xyz': self.embedding_xyz}
+ self.embeddings = {"xyz": self.embedding_xyz}
self.models = {}
# Construct normalized meshgrid.
@@ -69,8 +69,8 @@ def __init__(self, hparams):
# Multiple deformation MLP.
# Progressive Training for the Deformation (Annealed PE).
# No trainable parameters.
- self.embeddings['xyz_w'] = []
- assert (isinstance(self.hparams.N_xyz_w, list))
+ self.embeddings["xyz_w"] = []
+ assert isinstance(self.hparams.N_xyz_w, list)
in_channels_xyz = []
for i in range(self.num_models):
N_xyz_w = self.hparams.N_xyz_w[i]
@@ -78,59 +78,62 @@ def __init__(self, hparams):
if hparams.annealed:
if hparams.deform_hash:
self.embedding_hash = AnnealedHash(
- in_channels=2,
- annealed_step=hparams.annealed_step,
- annealed_begin_step=hparams.annealed_begin_step)
- self.embeddings['aneal_hash'] = self.embedding_hash
+ in_channels=2,
+ annealed_step=hparams.annealed_step,
+ annealed_begin_step=hparams.annealed_begin_step,
+ )
+ self.embeddings["aneal_hash"] = self.embedding_hash
else:
self.embedding_xyz_w = AnnealedEmbedding(
in_channels=2,
N_freqs=N_xyz_w,
annealed_step=hparams.annealed_step,
- annealed_begin_step=hparams.annealed_begin_step)
- self.embeddings['xyz_w'] += [self.embedding_xyz_w]
+ annealed_begin_step=hparams.annealed_begin_step,
+ )
+ self.embeddings["xyz_w"] += [self.embedding_xyz_w]
else:
self.embedding_xyz_w = Embedding(2, N_xyz_w)
- self.embeddings['xyz_w'] += [self.embedding_xyz_w]
+ self.embeddings["xyz_w"] += [self.embedding_xyz_w]
for i in range(self.num_models):
embedding_w = torch.nn.Embedding(hparams.N_vocab_w, hparams.N_w)
torch.nn.init.uniform_(embedding_w.weight, -0.05, 0.05)
- load_ckpt(embedding_w, hparams.weight_path, model_name=f'w_{i}')
- self.embeddings[f'w_{i}'] = embedding_w
- self.models_to_train += [self.embeddings[f'w_{i}']]
+ load_ckpt(embedding_w, hparams.weight_path, model_name=f"w_{i}")
+ self.embeddings[f"w_{i}"] = embedding_w
+ self.models_to_train += [self.embeddings[f"w_{i}"]]
# Add warping field mlp.
if hparams.deform_hash:
- with open('configs/hash.json') as f:
+ with open("configs/hash.json") as f:
config = json.load(f)
warping_field = Deform_Hash3d_Warp(config=config)
else:
warping_field = TranslationField(
D=self.hparams.deform_D,
W=self.hparams.deform_W,
- in_channels_xyz=in_channels_xyz[i])
+ in_channels_xyz=in_channels_xyz[i],
+ )
- load_ckpt(warping_field,
- hparams.weight_path,
- model_name=f'warping_field_{i}')
- self.models[f'warping_field_{i}'] = warping_field
+ load_ckpt(
+ warping_field, hparams.weight_path, model_name=f"warping_field_{i}"
+ )
+ self.models[f"warping_field_{i}"] = warping_field
# Set up the canonical model.
if hparams.canonical_dir is None:
for i in range(self.num_models):
if hparams.vid_hash:
- with open('configs/hash.json') as f:
+ with open("configs/hash.json") as f:
config = json.load(f)
implicit_video = ImplicitVideo_Hash(config=config)
else:
implicit_video = ImplicitVideo(
D=hparams.vid_D,
W=hparams.vid_W,
- sigmoid_offset=hparams.sigmoid_offset)
- load_ckpt(implicit_video, hparams.weight_path,
- f'implicit_video_{i}')
- self.models[f'implicit_video_{i}'] = implicit_video
+ sigmoid_offset=hparams.sigmoid_offset,
+ )
+ load_ckpt(implicit_video, hparams.weight_path, f"implicit_video_{i}")
+ self.models[f"implicit_video_{i}"] = implicit_video
for key in self.embeddings:
setattr(self, key, self.embeddings[key])
@@ -144,50 +147,46 @@ def deform_pts(self, ts_w, grid, encode_w, step=0, i=0):
ts_w_norm = ts_w / self.seq_len
ts_w_norm = ts_w_norm.repeat(grid.shape[0], 1)
input_xyt = torch.cat([grid, ts_w_norm], dim=-1)
- if 'aneal_hash' in self.embeddings.keys():
- deform = self.models[f'warping_field_{i}'](
- input_xyt,
- step=step,
- aneal_func=self.embeddings['aneal_hash'])
+ if "aneal_hash" in self.embeddings.keys():
+ deform = self.models[f"warping_field_{i}"](
+ input_xyt, step=step, aneal_func=self.embeddings["aneal_hash"]
+ )
else:
- deform = self.models[f'warping_field_{i}'](input_xyt)
+ deform = self.models[f"warping_field_{i}"](input_xyt)
if encode_w:
deformed_grid = deform + grid
else:
deformed_grid = grid
else:
if encode_w:
- e_w = self.embeddings[f'w_{i}'](repeat(ts_w, 'b n -> (b l) n ',
- l=grid.shape[0])[:, 0])
+ e_w = self.embeddings[f"w_{i}"](
+ repeat(ts_w, "b n -> (b l) n ", l=grid.shape[0])[:, 0]
+ )
# Whether to use annealed positional encoding.
if self.hparams.annealed:
- pe_w = self.embeddings['xyz_w'][i](grid, step)
+ pe_w = self.embeddings["xyz_w"][i](grid, step)
else:
- pe_w = self.embeddings['xyz_w'][i](grid)
+ pe_w = self.embeddings["xyz_w"][i](grid)
# Warping field type.
- deform = self.models[f'warping_field_{i}'](torch.cat(
- [e_w, pe_w], 1))
+ deform = self.models[f"warping_field_{i}"](torch.cat([e_w, pe_w], 1))
deformed_grid = deform + grid
else:
deformed_grid = grid
return deformed_grid
- def forward(self,
- ts_w,
- grid,
- encode_w,
- step=0,
- flows=None):
+ def forward(self, ts_w, grid, encode_w, step=0, flows=None):
# grid -> positional encoding
# ts_w -> embedding
- grid = rearrange(grid, 'b n c -> (b n) c')
+ grid = rearrange(grid, "b n c -> (b n) c")
results_list = []
flow_loss_list = []
deform_list = []
for i in range(self.num_models):
- deformed_grid = self.deform_pts(ts_w, grid, encode_w, step, i) # [batch * num_pixels, 2]
+ deformed_grid = self.deform_pts(
+ ts_w, grid, encode_w, step, i
+ ) # [batch * num_pixels, 2]
deform_list.append(deformed_grid)
# Compute optical flow loss.
flow_loss = 0
@@ -195,13 +194,14 @@ def forward(self,
if flows.max() > -1e2 and step > self.hparams.flow_step:
grid_new = grid + flows.squeeze(0)
deformed_grid_new = self.deform_pts(
- ts_w + 1, grid_new, encode_w, step, i)
+ ts_w + 1, grid_new, encode_w, step, i
+ )
flow_loss = (deformed_grid_new, deformed_grid)
flow_loss_list.append(flow_loss)
if self.hparams.vid_hash:
pe_deformed_grid = (deformed_grid + 0.3) / 1.6
else:
- pe_deformed_grid = self.embeddings['xyz'](deformed_grid)
+ pe_deformed_grid = self.embeddings["xyz"](deformed_grid)
if not self.training and self.hparams.canonical_dir is not None:
w, h = self.img_wh
canonical_img = self.canonical_img.squeeze(0)
@@ -212,19 +212,18 @@ def forward(self,
if len(canonical_img.shape) == 3:
canonical_img = canonical_img.unsqueeze(0)
results = torch.nn.functional.grid_sample(
- canonical_img[i:i + 1].permute(0, 3, 1, 2),
+ canonical_img[i : i + 1].permute(0, 3, 1, 2),
grid_new.unsqueeze(1).unsqueeze(0),
- mode='bilinear',
- padding_mode='border')
- results = results.squeeze().permute(1,0)
+ mode="bilinear",
+ padding_mode="border",
+ )
+ results = results.squeeze().permute(1, 0)
else:
- results = self.models[f'implicit_video_{i}'](pe_deformed_grid)
+ results = self.models[f"implicit_video_{i}"](pe_deformed_grid)
results_list.append(results)
- ret = edict(rgbs=results_list,
- flow_loss=flow_loss_list,
- deform=deform_list)
+ ret = edict(rgbs=results_list, flow_loss=flow_loss_list, deform=deform_list)
return ret
@@ -232,16 +231,16 @@ def setup(self, stage):
if not self.hparams.test:
dataset = dataset_dict[self.hparams.dataset_name]
kwargs = {
- 'root_dir': self.hparams.root_dir,
- 'img_wh': tuple(self.hparams.img_wh),
- 'mask_dir': self.hparams.mask_dir,
- 'flow_dir': self.hparams.flow_dir,
- 'canonical_wh': self.hparams.canonical_wh,
- 'ref_idx': self.hparams.ref_idx,
- 'canonical_dir': self.hparams.canonical_dir
+ "root_dir": self.hparams.root_dir,
+ "img_wh": tuple(self.hparams.img_wh),
+ "mask_dir": self.hparams.mask_dir,
+ "flow_dir": self.hparams.flow_dir,
+ "canonical_wh": self.hparams.canonical_wh,
+ "ref_idx": self.hparams.ref_idx,
+ "canonical_dir": self.hparams.canonical_dir,
}
- self.train_dataset = dataset(split='train', **kwargs)
- self.val_dataset = dataset(split='val', **kwargs)
+ self.train_dataset = dataset(split="train", **kwargs)
+ self.val_dataset = dataset(split="val", **kwargs)
def configure_optimizers(self):
self.optimizer = get_optimizer(self.hparams, self.models_to_train)
@@ -251,11 +250,13 @@ def configure_optimizers(self):
def train_dataloader(self):
sampler = DistributedSampler(self.train_dataset, shuffle=True)
- return DataLoader(self.train_dataset,
- num_workers=4,
- batch_size=self.hparams.batch_size,
- sampler=sampler,
- pin_memory=True)
+ return DataLoader(
+ self.train_dataset,
+ num_workers=4,
+ batch_size=self.hparams.batch_size,
+ sampler=sampler,
+ pin_memory=True,
+ )
def val_dataloader(self):
return DataLoader(
@@ -263,64 +264,68 @@ def val_dataloader(self):
shuffle=False,
num_workers=4,
batch_size=1, # validate one image (H*W rays) at a time.
- pin_memory=True)
+ pin_memory=True,
+ )
def test_dataloader(self):
dataset = dataset_dict[self.hparams.dataset_name]
kwargs = {
- 'root_dir': self.hparams.root_dir,
- 'img_wh': tuple(self.hparams.img_wh),
- 'mask_dir': self.hparams.mask_dir,
- 'canonical_wh': self.hparams.canonical_wh,
- 'canonical_dir': self.hparams.canonical_dir,
- 'test': self.hparams.test
+ "root_dir": self.hparams.root_dir,
+ "img_wh": tuple(self.hparams.img_wh),
+ "mask_dir": self.hparams.mask_dir,
+ "canonical_wh": self.hparams.canonical_wh,
+ "canonical_dir": self.hparams.canonical_dir,
+ "test": self.hparams.test,
}
- self.train_dataset = dataset(split='train', **kwargs)
+ self.train_dataset = dataset(split="train", **kwargs)
return DataLoader(
self.train_dataset,
shuffle=False,
num_workers=4,
batch_size=1, # validate one image (H*W rays) at a time.
- pin_memory=True)
+ pin_memory=True,
+ )
def training_step(self, batch, batch_idx):
# Fetch training data.
- rgbs = batch['rgbs']
- ts_w = batch['ts_w']
- grid = batch['grid']
- mk = batch['masks']
- flows = batch['flows']
- grid_c = batch['grid_c']
- ref_batch = batch['reference']
- self.seq_len = batch['seq_len']
+ rgbs = batch["rgbs"]
+ ts_w = batch["ts_w"]
+ grid = batch["grid"]
+ mk = batch["masks"]
+ flows = batch["flows"]
+ grid_c = batch["grid_c"]
+ ref_batch = batch["reference"]
+ self.seq_len = batch["seq_len"]
loss = 0
- rgbs_flattend = rearrange(rgbs, 'b h w c -> (b h w) c')
+ rgbs_flattend = rearrange(rgbs, "b h w c -> (b h w) c")
# Forward the model.
- ret = self.forward(ts_w,
- grid,
- self.hparams.encode_w,
- self.global_step,
- flows=flows)
+ ret = self.forward(
+ ts_w, grid, self.hparams.encode_w, self.global_step, flows=flows
+ )
# Mannually set a reference frame.
- if self.hparams.ref_step < 0: self.hparams.step = 1e10
- if (self.hparams.ref_idx is not None
- and self.global_step < self.hparams.ref_step):
- rgbs_c_flattend = rearrange(ref_batch[0],
- 'b h w c -> (b h w) c')
+ if self.hparams.ref_step < 0:
+ self.hparams.step = 1e10
+ if (
+ self.hparams.ref_idx is not None
+ and self.global_step < self.hparams.ref_step
+ ):
+ rgbs_c_flattend = rearrange(ref_batch[0], "b h w c -> (b h w) c")
ret_c = self(ts_w, grid, False, self.global_step, flows=flows)
# Loss computation.
for i in range(self.num_models):
results = ret.rgbs[i]
- mk_t = rearrange(mk[i], 'b h w c -> (b h w) c')
+ mk_t = rearrange(mk[i], "b h w c -> (b h w) c")
mk_t = mk_t.sum(dim=-1) > 0.05
- if (self.hparams.ref_idx is not None
- and self.global_step < self.hparams.ref_step):
- mk_c_t = rearrange(ref_batch[1][i], 'b h w c -> (b h w) c')
+ if (
+ self.hparams.ref_idx is not None
+ and self.global_step < self.hparams.ref_step
+ ):
+ mk_c_t = rearrange(ref_batch[1][i], "b h w c -> (b h w) c")
mk_c_t = mk_c_t.sum(dim=-1) > 0.05
# Background regularization.
@@ -329,84 +334,84 @@ def training_step(self, batch, batch_idx):
if self.hparams.self_bg:
grid_flattened = rgbs_flattend
else:
- grid_flattened = rearrange(grid, 'b n c -> (b n) c')
+ grid_flattened = rearrange(grid, "b n c -> (b n) c")
grid_flattened = torch.cat(
- [grid_flattened, grid_flattened[:, :1]], -1)
+ [grid_flattened, grid_flattened[:, :1]], -1
+ )
if self.hparams.bg_loss and self.hparams.mask_dir:
loss = loss + self.hparams.bg_loss * self.color_loss(
- results[mk1], grid_flattened[mk1])
+ results[mk1], grid_flattened[mk1]
+ )
# MSE color loss.
- loss = loss + self.color_loss(results[mk_t],
- rgbs_flattend[mk_t])
+ loss = loss + self.color_loss(results[mk_t], rgbs_flattend[mk_t])
# Image gradient loss.
- img_pred = rearrange(results,
- '(b h w) c -> b h w c',
- b=1,
- h=self.h,
- w=self.w)
- rgbs_gt = rearrange(rgbs_flattend,
- '(b h w) c -> b h w c',
- b=1,
- h=self.h,
- w=self.w)
- mk_t_re = rearrange(mk_t,
- '(b h w c) -> b h w c',
- b=1,
- h=self.h,
- w=self.w)
- grad_loss = compute_gradient_loss(rgbs_gt.permute(0, 3, 1, 2),
- img_pred.permute(0, 3, 1, 2),
- mask=mk_t_re.permute(0, 3, 1, 2))
+ img_pred = rearrange(
+ results, "(b h w) c -> b h w c", b=1, h=self.h, w=self.w
+ )
+ rgbs_gt = rearrange(
+ rgbs_flattend, "(b h w) c -> b h w c", b=1, h=self.h, w=self.w
+ )
+ mk_t_re = rearrange(mk_t, "(b h w c) -> b h w c", b=1, h=self.h, w=self.w)
+ grad_loss = compute_gradient_loss(
+ rgbs_gt.permute(0, 3, 1, 2),
+ img_pred.permute(0, 3, 1, 2),
+ mask=mk_t_re.permute(0, 3, 1, 2),
+ )
loss = loss + grad_loss * self.hparams.grad_loss
# Optical flow loss.
if ret.flow_loss[0] != 0:
- mk_flow_t = torch.logical_and(mk_t, flows[0].sum(dim=-1)< 3)
- loss = loss + torch.nn.functional.l1_loss(
- ret.flow_loss[i][0][mk_flow_t], ret.flow_loss[i][1]
- [mk_flow_t]) * self.hparams.flow_loss
+ mk_flow_t = torch.logical_and(mk_t, flows[0].sum(dim=-1) < 3)
+ loss = (
+ loss
+ + torch.nn.functional.l1_loss(
+ ret.flow_loss[i][0][mk_flow_t], ret.flow_loss[i][1][mk_flow_t]
+ )
+ * self.hparams.flow_loss
+ )
# Reference loss.
- if (self.hparams.ref_idx is not None
- and self.global_step < self.hparams.ref_step):
+ if (
+ self.hparams.ref_idx is not None
+ and self.global_step < self.hparams.ref_step
+ ):
results_c = ret_c.rgbs[i]
- loss += self.color_loss(results_c[mk_c_t],
- rgbs_c_flattend[mk_c_t])
+ loss += self.color_loss(results_c[mk_c_t], rgbs_c_flattend[mk_c_t])
# PSNR metric.
with torch.no_grad():
if i == 0:
psnr_ = psnr(results[mk_t], rgbs_flattend[mk_t])
- self.log('lr', get_learning_rate(self.optimizer), prog_bar=True)
- self.log('train/loss', loss, prog_bar=True)
- self.log('train/psnr', psnr_, prog_bar=True)
+ self.log("lr", get_learning_rate(self.optimizer), prog_bar=True)
+ self.log("train/loss", loss, prog_bar=True)
+ self.log("train/psnr", psnr_, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
- rgbs = batch['rgbs']
- ts_w = batch['ts_w']
- grid = batch['grid']
- mk = batch['masks']
+ rgbs = batch["rgbs"]
+ ts_w = batch["ts_w"]
+ grid = batch["grid"]
+ mk = batch["masks"]
grid_c = grid # batch['grid_c']
- self.seq_len = batch['seq_len']
+ self.seq_len = batch["seq_len"]
ret = self(ts_w, grid, self.hparams.encode_w, self.global_step)
ret_c = self(ts_w, grid_c, False, self.global_step)
log = {}
W, H = self.hparams.img_wh
- rgbs_flattend = rearrange(rgbs, 'b h w c -> (b h w) c')
+ rgbs_flattend = rearrange(rgbs, "b h w c -> (b h w) c")
img_gt = rgbs_flattend.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
stack_list = [img_gt]
for i in range(self.num_models):
results = ret.rgbs[i]
results_c = ret_c.rgbs[i]
- mk_t = rearrange(mk[i], 'b h w c -> (b h w) c')
+ mk_t = rearrange(mk[i], "b h w c -> (b h w) c")
if batch_idx == 0:
results[mk_t.sum(dim=-1) <= 0.05] = 0
img = results.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W)
@@ -414,43 +419,47 @@ def validation_step(self, batch, batch_idx):
stack_list.append(img)
stack_list.append(img_c)
- stack = torch.stack(stack_list) # (3, 3, H, W)
- self.logger.experiment.add_images('val/GT_Reconstructed', stack,
- self.global_step)
+ stack = torch.stack(stack_list) # (3, 3, H, W)
+ self.logger.experiment.add_images(
+ "val/GT_Reconstructed", stack, self.global_step
+ )
return log
def test_step(self, batch, batch_idx):
- ts_w = batch['ts_w']
- grid = batch['grid']
- mk = batch['masks']
- grid_c = batch['grid_c']
+ ts_w = batch["ts_w"]
+ grid = batch["grid"]
+ mk = batch["masks"]
+ grid_c = batch["grid_c"]
W, H = self.hparams.img_wh
- self.seq_len = batch['seq_len']
+ self.seq_len = batch["seq_len"]
if self.hparams.canonical_dir is not None:
- self.canonical_img = batch['canonical_img']
- self.img_wh = batch['img_wh']
-
- save_dir = os.path.join('results',
- self.hparams.root_dir.split('/')[0],
- self.hparams.root_dir.split('/')[1],
- self.hparams.exp_name)
- sample_name = self.hparams.root_dir.split('/')[1]
+ self.canonical_img = batch["canonical_img"]
+ self.img_wh = batch["img_wh"]
+
+ save_dir = os.path.join(
+ "results",
+ self.hparams.root_dir.split("/")[0],
+ self.hparams.root_dir.split("/")[1],
+ self.hparams.exp_name,
+ )
+ sample_name = self.hparams.root_dir.split("/")[1]
if self.hparams.canonical_dir is not None:
- test_dir = f'{save_dir}_transformed'
- video_name = f'{sample_name}_{self.hparams.exp_name}_transformed'
+ test_dir = f"{save_dir}_transformed"
+ video_name = f"{sample_name}_{self.hparams.exp_name}_transformed"
else:
- test_dir = f'{save_dir}'
- video_name = f'{sample_name}_{self.hparams.exp_name}'
+ test_dir = f"{save_dir}"
+ video_name = f"{sample_name}_{self.hparams.exp_name}"
Path(test_dir).mkdir(parents=True, exist_ok=True)
if batch_idx > 0 and self.hparams.save_video:
- self.video_visualizer.set_path(os.path.join(
- test_dir, f'{video_name}.mp4'))
- self.raw_video_visualizer.set_path(os.path.join(
- test_dir, f'{video_name}_raw.mp4'))
- self.dual_video_visualizer.set_path(os.path.join(
- test_dir, f'{video_name}_dual.mp4'))
+ self.video_visualizer.set_path(os.path.join(test_dir, f"{video_name}.mp4"))
+ self.raw_video_visualizer.set_path(
+ os.path.join(test_dir, f"{video_name}_raw.mp4")
+ )
+ self.dual_video_visualizer.set_path(
+ os.path.join(test_dir, f"{video_name}_dual.mp4")
+ )
if batch_idx == 0 and self.hparams.canonical_dir is None:
# Save the canonical image.
@@ -463,46 +472,56 @@ def test_step(self, batch, batch_idx):
if batch_idx == 0 and self.hparams.canonical_dir is None:
results_c = ret.rgbs[i]
if self.hparams.canonical_wh:
- img_c = results_c.view(self.hparams.canonical_wh[1],
- self.hparams.canonical_wh[0],
- 3).float().cpu().numpy()
+ img_c = (
+ results_c.view(
+ self.hparams.canonical_wh[1],
+ self.hparams.canonical_wh[0],
+ 3,
+ )
+ .float()
+ .cpu()
+ .numpy()
+ )
else:
img_c = results_c.view(H, W, 3).float().cpu().numpy()
img_c = cv2.cvtColor(img_c, cv2.COLOR_BGR2RGB)
- cv2.imwrite(f'{test_dir}/canonical_{i}.png', img_c * 255)
+ cv2.imwrite(f"{test_dir}/canonical_{i}.png", img_c * 255)
- mk_n = rearrange(mk[i], 'b h w c -> (b h w) c')
+ mk_n = rearrange(mk[i], "b h w c -> (b h w) c")
mk_n = mk_n.sum(dim=-1) > 0.05
mk_n = mk_n.cpu().numpy()
results = ret_n.rgbs[i]
results = results.cpu().numpy() # (3, H, W)
img[mk_n] = results[mk_n]
- img = rearrange(img, '(h w) c -> h w c', h=H, w=W)
+ img = rearrange(img, "(h w) c -> h w c", h=H, w=W)
img = img * 255
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
- cv2.imwrite(f'{test_dir}/{batch_idx:05d}.png', img)
+ cv2.imwrite(f"{test_dir}/{batch_idx:05d}.png", img)
if batch_idx > 0 and self.hparams.save_video:
img = img[..., ::-1]
self.video_visualizer.add(img)
- rgbs = batch['rgbs'].view(H, W, 3).cpu().numpy() * 255
+ rgbs = batch["rgbs"].view(H, W, 3).cpu().numpy() * 255
rgbs = rgbs.astype(np.uint8)
self.raw_video_visualizer.add(rgbs)
dual_img = np.concatenate((rgbs, img), axis=1)
self.dual_video_visualizer.add(dual_img)
if self.hparams.save_deform:
- save_deform_dir = f'{test_dir}_deform'
+ save_deform_dir = f"{test_dir}_deform"
Path(save_deform_dir).mkdir(parents=True, exist_ok=True)
deformation_field = ret_n.deform[0]
- deformation_field = rearrange(deformation_field,
- '(h w) c -> h w c', h=H, w=W)
- grid_ = rearrange(grid[0], '(h w) c -> h w c', h=H, w=W)
+ deformation_field = rearrange(
+ deformation_field, "(h w) c -> h w c", h=H, w=W
+ )
+ grid_ = rearrange(grid[0], "(h w) c -> h w c", h=H, w=W)
deformation_delta = deformation_field - grid_
- np.save(f'{save_deform_dir}/{batch_idx:05d}.npy',
- deformation_delta.cpu().numpy())
+ np.save(
+ f"{save_deform_dir}/{batch_idx:05d}.npy",
+ deformation_delta.cpu().numpy(),
+ )
def on_test_epoch_end(self):
if self.hparams.save_video:
@@ -515,36 +534,38 @@ def get_progress_bar_dict(self):
items.pop("v_num", None)
return items
+
def main(hparams):
system = ImplicitVideoSystem(hparams)
if not hparams.test:
- os.makedirs(f'{hparams.model_save_path}/{hparams.exp_name}',
- exist_ok=True)
+ os.makedirs(f"{hparams.model_save_path}/{hparams.exp_name}", exist_ok=True)
checkpoint_callback = ModelCheckpoint(
- dirpath=f'{hparams.model_save_path}/{hparams.exp_name}',
- filename='{step:d}',
- mode='max',
+ dirpath=f"{hparams.model_save_path}/{hparams.exp_name}",
+ filename="{step:d}",
+ mode="max",
save_top_k=-1,
every_n_train_steps=hparams.save_model_iters,
- save_last=True)
-
- logger = TensorBoardLogger(save_dir=hparams.log_save_path,
- name=hparams.exp_name)
-
- trainer = Trainer(max_steps=hparams.num_steps,
- precision=16 if hparams.vid_hash == True else 32,
- callbacks=[checkpoint_callback],
- logger=logger,
- accelerator='gpu',
- devices=hparams.gpus,
- num_sanity_val_steps=1,
- benchmark=True,
- profiler="simple" if len(hparams.gpus) == 1 else None,
- val_check_interval=hparams.valid_iters,
- limit_val_batches=hparams.valid_batches,
- strategy="ddp_find_unused_parameters_true")
+ save_last=True,
+ )
+
+ logger = TensorBoardLogger(save_dir=hparams.log_save_path, name=hparams.exp_name)
+
+ trainer = Trainer(
+ max_steps=hparams.num_steps,
+ precision=16 if hparams.vid_hash == True else 32,
+ callbacks=[checkpoint_callback],
+ logger=logger,
+ accelerator="gpu",
+ devices=hparams.gpus,
+ num_sanity_val_steps=1,
+ benchmark=True,
+ profiler="simple" if len(hparams.gpus) == 1 else None,
+ val_check_interval=hparams.valid_iters,
+ limit_val_batches=hparams.valid_batches,
+ strategy="ddp_find_unused_parameters_true",
+ )
if hparams.test:
trainer.test(system, dataloaders=system.test_dataloader())
@@ -552,6 +573,6 @@ def main(hparams):
trainer.fit(system, ckpt_path=hparams.ckpt_path)
-if __name__ == '__main__':
+if __name__ == "__main__":
hparams = get_opts()
main(hparams)
diff --git a/utils/__init__.py b/utils/__init__.py
index 444ac71..ee6ac86 100755
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -1,7 +1,9 @@
import torch
+
# optimizer
from torch.optim import SGD, Adam
import torch_optimizer as optim
+
# scheduler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.lr_scheduler import MultiStepLR
@@ -23,79 +25,99 @@ def get_parameters(models):
# print("is dict")
for model in models.values():
parameters += get_parameters(model)
- else: # models is actually a single pytorch model
+ else: # models is actually a single pytorch model
parameters += list(models.parameters())
return parameters
+
def get_optimizer(hparams, models):
eps = 1e-8
parameters = get_parameters(models)
- if hparams.optimizer == 'sgd':
- optimizer = SGD(parameters, lr=hparams.lr,
- momentum=hparams.momentum, weight_decay=hparams.weight_decay)
- elif hparams.optimizer == 'adam':
- optimizer = Adam(parameters, lr=hparams.lr, eps=eps,
- weight_decay=hparams.weight_decay)
- elif hparams.optimizer == 'radam':
- optimizer = optim.RAdam(parameters, lr=hparams.lr, eps=eps,
- weight_decay=hparams.weight_decay)
- elif hparams.optimizer == 'ranger':
- optimizer = optim.Ranger(parameters, lr=hparams.lr, eps=eps,
- weight_decay=hparams.weight_decay)
+ if hparams.optimizer == "sgd":
+ optimizer = SGD(
+ parameters,
+ lr=hparams.lr,
+ momentum=hparams.momentum,
+ weight_decay=hparams.weight_decay,
+ )
+ elif hparams.optimizer == "adam":
+ optimizer = Adam(
+ parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
+ )
+ elif hparams.optimizer == "radam":
+ optimizer = optim.RAdam(
+ parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
+ )
+ elif hparams.optimizer == "ranger":
+ optimizer = optim.Ranger(
+ parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
+ )
else:
- raise ValueError('optimizer not recognized!')
+ raise ValueError("optimizer not recognized!")
return optimizer
+
def get_scheduler(hparams, optimizer):
eps = 1e-8
- if hparams.lr_scheduler == 'steplr':
- scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step,
- gamma=hparams.decay_gamma)
- elif hparams.lr_scheduler == 'cosine':
+ if hparams.lr_scheduler == "steplr":
+ scheduler = MultiStepLR(
+ optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma
+ )
+ elif hparams.lr_scheduler == "cosine":
scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
- elif hparams.lr_scheduler == 'poly':
- scheduler = LambdaLR(optimizer,
- lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp)
- elif hparams.lr_scheduler == 'exponential':
+ elif hparams.lr_scheduler == "poly":
+ scheduler = LambdaLR(
+ optimizer,
+ lambda epoch: (1 - epoch / hparams.num_epochs) ** hparams.poly_exp,
+ )
+ elif hparams.lr_scheduler == "exponential":
# Adaptively adjust the schedule
- scheduler = LambdaLR(optimizer,
- lambda step: hparams.exponent_base**(step/(2 * hparams.num_steps)))
+ scheduler = LambdaLR(
+ optimizer,
+ lambda step: hparams.exponent_base ** (step / (2 * hparams.num_steps)),
+ )
else:
- raise ValueError('scheduler not recognized!')
+ raise ValueError("scheduler not recognized!")
- if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']:
- scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier,
- total_epoch=hparams.warmup_epochs, after_scheduler=scheduler)
+ if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]:
+ scheduler = GradualWarmupScheduler(
+ optimizer,
+ multiplier=hparams.warmup_multiplier,
+ total_epoch=hparams.warmup_epochs,
+ after_scheduler=scheduler,
+ )
return scheduler
+
def get_learning_rate(optimizer):
for param_group in optimizer.param_groups:
- return param_group['lr']
+ return param_group["lr"]
+
-def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]):
- checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
+def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]):
+ checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
checkpoint_ = {}
- if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint
- checkpoint = checkpoint['state_dict']
+ if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint
+ checkpoint = checkpoint["state_dict"]
for k, v in checkpoint.items():
if not k.startswith(model_name):
continue
- k = k[len(model_name)+1:]
+ k = k[len(model_name) + 1 :]
for prefix in prefixes_to_ignore:
if k.startswith(prefix):
- print('ignore', k)
+ print("ignore", k)
break
else:
checkpoint_[k] = v
return checkpoint_
-def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]):
+
+def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]):
if not ckpt_path:
return
model_dict = model.state_dict()
checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
model_dict.update(checkpoint_)
model.load_state_dict(model_dict)
-
diff --git a/utils/image_utils.py b/utils/image_utils.py
index aad5698..54deb14 100755
--- a/utils/image_utils.py
+++ b/utils/image_utils.py
@@ -12,10 +12,20 @@
# File extensions regarding images (not including GIFs).
IMAGE_EXTENSIONS = (
- '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp',
- '.tiff', '.tif'
+ ".bmp",
+ ".ppm",
+ ".pgm",
+ ".jpeg",
+ ".jpg",
+ ".jpe",
+ ".jp2",
+ ".png",
+ ".webp",
+ ".tiff",
+ ".tif",
)
+
def check_file_ext(filename, *ext_list):
"""Checks whether the given filename is with target extension(s).
@@ -31,7 +41,7 @@ def check_file_ext(filename, *ext_list):
"""
if len(ext_list) == 0:
return False
- ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list]
+ ext_list = [ext if ext.startswith(".") else "." + ext for ext in ext_list]
ext_list = [ext.lower() for ext in ext_list]
basename = os.path.basename(filename)
ext = os.path.splitext(basename)[1].lower()
@@ -143,14 +153,16 @@ def resize_image(image, *args, **kwargs):
return cv2.resize(image, *args, **kwargs)
-def add_text_to_image(image,
- text='',
- position=None,
- font=cv2.FONT_HERSHEY_TRIPLEX,
- font_size=1.0,
- line_type=cv2.LINE_8,
- line_width=1,
- color=(255, 255, 255)):
+def add_text_to_image(
+ image,
+ text="",
+ position=None,
+ font=cv2.FONT_HERSHEY_TRIPLEX,
+ font_size=1.0,
+ line_type=cv2.LINE_8,
+ line_width=1,
+ color=(255, 255, 255),
+):
"""Overlays text on given image.
NOTE: The input image is assumed to be with `RGB` channel order.
@@ -174,15 +186,17 @@ def add_text_to_image(image,
return image
_check_2d_image(image)
- cv2.putText(img=image,
- text=text,
- org=position,
- fontFace=font,
- fontScale=font_size,
- color=color,
- thickness=line_width,
- lineType=line_type,
- bottomLeftOrigin=False)
+ cv2.putText(
+ img=image,
+ text=text,
+ org=position,
+ fontFace=font,
+ fontScale=font_size,
+ color=color,
+ thickness=line_width,
+ lineType=line_type,
+ bottomLeftOrigin=False,
+ )
return image
@@ -254,7 +268,7 @@ def parse_image_size(obj):
Raises:
If the input is invalid, i.e., neither a list or tuple, nor a string.
"""
- if obj is None or obj == '':
+ if obj is None or obj == "":
height = 0
width = 0
elif isinstance(obj, int):
@@ -262,7 +276,7 @@ def parse_image_size(obj):
width = obj
elif isinstance(obj, (list, tuple, str, np.ndarray)):
if isinstance(obj, str):
- splits = obj.replace(' ', '').split(',')
+ splits = obj.replace(" ", "").split(",")
numbers = tuple(map(int, splits))
else:
numbers = tuple(obj)
@@ -276,9 +290,9 @@ def parse_image_size(obj):
height = int(numbers[0])
width = int(numbers[1])
else:
- raise ValueError('At most two elements for image size.')
+ raise ValueError("At most two elements for image size.")
else:
- raise ValueError(f'Invalid type of input: `{type(obj)}`!')
+ raise ValueError(f"Invalid type of input: `{type(obj)}`!")
return (max(0, height), max(0, width))
diff --git a/utils/video_visualizer.py b/utils/video_visualizer.py
index 8262a3d..bc1c8bc 100755
--- a/utils/video_visualizer.py
+++ b/utils/video_visualizer.py
@@ -9,13 +9,15 @@
class VideoVisualizer(object):
"""Defines the video visualizer that presents images as a video."""
- def __init__(self,
- path=None,
- frame_size=None,
- fps=25.0,
- codec='libx264',
- pix_fmt='yuv420p',
- crf=1):
+ def __init__(
+ self,
+ path=None,
+ frame_size=None,
+ fps=25.0,
+ codec="libx264",
+ pix_fmt="yuv420p",
+ crf=1,
+ ):
"""Initializes the video visualizer.
Args:
@@ -53,11 +55,11 @@ def set_fps(self, fps=25.0):
"""Sets the FPS (frame per second) of the video."""
self.fps = fps
- def set_codec(self, codec='libx264'):
+ def set_codec(self, codec="libx264"):
"""Sets the video codec."""
self.codec = codec
- def set_pix_fmt(self, pix_fmt='yuv420p'):
+ def set_pix_fmt(self, pix_fmt="yuv420p"):
"""Sets the video pixel format."""
self.pix_fmt = pix_fmt
@@ -71,11 +73,11 @@ def init_video(self):
assert self.frame_width > 0
video_setting = {
- '-r': f'{self.fps:.2f}',
- '-s': f'{self.frame_width}x{self.frame_height}',
- '-vcodec': f'{self.codec}',
- '-crf': f'{self.crf}',
- '-pix_fmt': f'{self.pix_fmt}',
+ "-r": f"{self.fps:.2f}",
+ "-s": f"{self.frame_width}x{self.frame_height}",
+ "-vcodec": f"{self.codec}",
+ "-crf": f"{self.crf}",
+ "-pix_fmt": f"{self.pix_fmt}",
}
self.video = FFmpegWriter(self.path, outputdict=video_setting)
@@ -126,16 +128,15 @@ def save(self):
self.set_path(None)
-if __name__ == '__main__':
+if __name__ == "__main__":
from glob import glob
import cv2
- video_visualizer = VideoVisualizer(path='output.mp4',
- frame_size=None,
- fps=25.0)
- img_folder = 'src_images/'
- imgs = sorted(glob(img_folder + '/*.png'))
+
+ video_visualizer = VideoVisualizer(path="output.mp4", frame_size=None, fps=25.0)
+ img_folder = "src_images/"
+ imgs = sorted(glob(img_folder + "/*.png"))
for img in imgs:
image = cv2.imread(img)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
video_visualizer.add(image)
- video_visualizer.save()
\ No newline at end of file
+ video_visualizer.save()
diff --git a/utils/warmup_scheduler.py b/utils/warmup_scheduler.py
index 03208c4..6a80cee 100755
--- a/utils/warmup_scheduler.py
+++ b/utils/warmup_scheduler.py
@@ -1,8 +1,9 @@
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
+
class GradualWarmupScheduler(_LRScheduler):
- """ Gradually warm-up(increasing) learning rate in optimizer.
+ """Gradually warms up (increases) the learning rate in the optimizer.
Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
Args:
optimizer (Optimizer): Wrapped optimizer.
@@ -13,8 +14,8 @@ class GradualWarmupScheduler(_LRScheduler):
def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
self.multiplier = multiplier
- if self.multiplier < 1.:
- raise ValueError('multiplier should be greater thant or equal to 1.')
+ if self.multiplier < 1.0:
+ raise ValueError("multiplier should be greater thant or equal to 1.")
self.total_epoch = total_epoch
self.after_scheduler = after_scheduler
self.finished = False
@@ -24,21 +25,33 @@ def get_lr(self):
if self.last_epoch > self.total_epoch:
if self.after_scheduler:
if not self.finished:
- self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+ self.after_scheduler.base_lrs = [
+ base_lr * self.multiplier for base_lr in self.base_lrs
+ ]
self.finished = True
return self.after_scheduler.get_lr()
return [base_lr * self.multiplier for base_lr in self.base_lrs]
- return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+ return [
+ base_lr
+ * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
def step_ReduceLROnPlateau(self, metrics, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
- self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+ self.last_epoch = (
+ epoch if epoch != 0 else 1
+ ) # ReduceLROnPlateau is called at the end of an epoch, whereas other schedulers are called at the beginning
if self.last_epoch <= self.total_epoch:
- warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+ warmup_lr = [
+ base_lr
+ * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
+ for base_lr in self.base_lrs
+ ]
for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
- param_group['lr'] = lr
+ param_group["lr"] = lr
else:
if epoch is None:
self.after_scheduler.step(metrics, None)
@@ -55,4 +68,4 @@ def step(self, epoch=None, metrics=None):
else:
return super(GradualWarmupScheduler, self).step(epoch)
else:
- self.step_ReduceLROnPlateau(metrics, epoch)
\ No newline at end of file
+ self.step_ReduceLROnPlateau(metrics, epoch)