diff --git a/.gitignore b/.gitignore index 8fa68fe..f68a4af 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ all_sequences/ ckpts/ logs/ # configs/ + +*.mp4 diff --git a/README.md b/README.md index c6534e6..4e0290c 100755 --- a/README.md +++ b/README.md @@ -1,7 +1,13 @@ -# CoDeF: Content Deformation Fields for Temporally Consistent Video Processing - +# HeyGen ❤️ Superwoman + +The slides for the project: https://docs.google.com/presentation/d/1y4qP2dALviAZm_D4BU_udy2IHE-1A6nwD4zS3RAqfvc/edit?usp=sharing + +The config for a HeyGen avatar stylization: `configs/heygen/base.yaml` + +--------- +## CoDeF: Content Deformation Fields for Temporally Consistent Video Processing [Hao Ouyang](https://ken-ouyang.github.io/)\*, [Qiuyu Wang](https://github.com/qiuyu96/)\*, [Yuxi Xiao](https://henry123-boy.github.io/)\*, [Qingyan Bai](https://scholar.google.com/citations?user=xUMjxi4AAAAJ&hl=en), [Juntao Zhang](https://github.com/JordanZh), [Kecheng Zheng](https://scholar.google.com/citations?user=hMDQifQAAAAJ), [Xiaowei Zhou](https://xzhou.me/), [Qifeng Chen](https://cqf.io/)†, [Yujun Shen](https://shenyujun.github.io/)† (*equal contribution, †corresponding author) diff --git a/configs/heygen/base.yaml b/configs/heygen/base.yaml new file mode 100755 index 0000000..531bb04 --- /dev/null +++ b/configs/heygen/base.yaml @@ -0,0 +1,26 @@ +mask_dir: null +flow_dir: null + +img_wh: [1080, 1080] +canonical_wh: [1080, 1080] + +lr: 0.001 +bg_loss: 0.0 + +ref_idx: null # 0 + +N_xyz_w: [8, 8] +flow_loss: 1 +flow_step: -1 +self_bg: False + +deform_hash: True +vid_hash: True + +num_steps: 200000 +decay_step: [2500, 5000, 7500] +annealed_begin_step: 4000 +annealed_step: 4000 +save_model_iters: 2000 + +fps: 25 diff --git a/data_preprocessing/RAFT/core/corr.py b/data_preprocessing/RAFT/core/corr.py index cffcbc8..42a7817 100755 --- a/data_preprocessing/RAFT/core/corr.py +++ b/data_preprocessing/RAFT/core/corr.py @@ -19,10 +19,10 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4): corr = CorrBlock.corr(fmap1, fmap2) batch, h1, w1, dim, h2, w2 = corr.shape - corr = corr.reshape(batch*h1*w1, dim, h2, w2) - + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + self.corr_pyramid.append(corr) - for i in range(self.num_levels-1): + for i in range(self.num_levels - 1): corr = F.avg_pool2d(corr, 2, stride=2) self.corr_pyramid.append(corr) @@ -34,12 +34,12 @@ def __call__(self, coords): out_pyramid = [] for i in range(self.num_levels): corr = self.corr_pyramid[i] - dx = torch.linspace(-r, r, 2*r+1, device=coords.device) - dy = torch.linspace(-r, r, 2*r+1, device=coords.device) + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) - centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i - delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) coords_lvl = centroid_lvl + delta_lvl corr = bilinear_sampler(corr, coords_lvl) @@ -52,12 +52,12 @@ def __call__(self, coords): @staticmethod def corr(fmap1, fmap2): batch, dim, ht, wd = fmap1.shape - fmap1 = fmap1.view(batch, dim, ht*wd) - fmap2 = fmap2.view(batch, dim, ht*wd) - - corr = torch.matmul(fmap1.transpose(1,2), fmap2) + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) corr = corr.view(batch, ht, wd, 1, ht, wd) - return corr / torch.sqrt(torch.tensor(dim).float()) + return corr / torch.sqrt(torch.tensor(dim).float()) class AlternateCorrBlock: @@ -83,7 +83,7 @@ def __call__(self, coords): fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() - corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + (corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) corr_list.append(corr.squeeze(1)) corr = torch.stack(corr_list, dim=1) diff --git a/data_preprocessing/RAFT/core/datasets.py b/data_preprocessing/RAFT/core/datasets.py index 3411fda..2513967 100755 --- a/data_preprocessing/RAFT/core/datasets.py +++ b/data_preprocessing/RAFT/core/datasets.py @@ -66,8 +66,8 @@ def __getitem__(self, index): # grayscale images if len(img1.shape) == 2: - img1 = np.tile(img1[...,None], (1, 1, 3)) - img2 = np.tile(img2[...,None], (1, 1, 3)) + img1 = np.tile(img1[..., None], (1, 1, 3)) + img2 = np.tile(img2[..., None], (1, 1, 3)) else: img1 = img1[..., :3] img2 = img2[..., :3] @@ -89,147 +89,203 @@ def __getitem__(self, index): return img1, img2, flow, valid.float() - def __rmul__(self, v): self.flow_list = v * self.flow_list self.image_list = v * self.image_list return self - + def __len__(self): return len(self.image_list) - + class MpiSintel(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + def __init__( + self, aug_params=None, split="training", root="datasets/Sintel", dstype="clean" + ): super(MpiSintel, self).__init__(aug_params) - flow_root = osp.join(root, split, 'flow') + flow_root = osp.join(root, split, "flow") image_root = osp.join(root, split, dstype) - if split == 'test': + if split == "test": self.is_test = True for scene in os.listdir(image_root): - image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) - for i in range(len(image_list)-1): - self.image_list += [ [image_list[i], image_list[i+1]] ] - self.extra_info += [ (scene, i) ] # scene and frame_id + image_list = sorted(glob(osp.join(image_root, scene, "*.png"))) + for i in range(len(image_list) - 1): + self.image_list += [[image_list[i], image_list[i + 1]]] + self.extra_info += [(scene, i)] # scene and frame_id - if split != 'test': - self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + if split != "test": + self.flow_list += sorted(glob(osp.join(flow_root, scene, "*.flo"))) class FlyingChairs(FlowDataset): - def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + def __init__( + self, aug_params=None, split="train", root="datasets/FlyingChairs_release/data" + ): super(FlyingChairs, self).__init__(aug_params) - images = sorted(glob(osp.join(root, '*.ppm'))) - flows = sorted(glob(osp.join(root, '*.flo'))) - assert (len(images)//2 == len(flows)) + images = sorted(glob(osp.join(root, "*.ppm"))) + flows = sorted(glob(osp.join(root, "*.flo"))) + assert len(images) // 2 == len(flows) - split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + split_list = np.loadtxt("chairs_split.txt", dtype=np.int32) for i in range(len(flows)): xid = split_list[i] - if (split=='training' and xid==1) or (split=='validation' and xid==2): - self.flow_list += [ flows[i] ] - self.image_list += [ [images[2*i], images[2*i+1]] ] + if (split == "training" and xid == 1) or ( + split == "validation" and xid == 2 + ): + self.flow_list += [flows[i]] + self.image_list += [[images[2 * i], images[2 * i + 1]]] class FlyingThings3D(FlowDataset): - def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + def __init__( + self, aug_params=None, root="datasets/FlyingThings3D", dstype="frames_cleanpass" + ): super(FlyingThings3D, self).__init__(aug_params) - for cam in ['left']: - for direction in ['into_future', 'into_past']: - image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + for cam in ["left"]: + for direction in ["into_future", "into_past"]: + image_dirs = sorted(glob(osp.join(root, dstype, "TRAIN/*/*"))) image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) - flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted(glob(osp.join(root, "optical_flow/TRAIN/*/*"))) flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) for idir, fdir in zip(image_dirs, flow_dirs): - images = sorted(glob(osp.join(idir, '*.png')) ) - flows = sorted(glob(osp.join(fdir, '*.pfm')) ) - for i in range(len(flows)-1): - if direction == 'into_future': - self.image_list += [ [images[i], images[i+1]] ] - self.flow_list += [ flows[i] ] - elif direction == 'into_past': - self.image_list += [ [images[i+1], images[i]] ] - self.flow_list += [ flows[i+1] ] - + images = sorted(glob(osp.join(idir, "*.png"))) + flows = sorted(glob(osp.join(fdir, "*.pfm"))) + for i in range(len(flows) - 1): + if direction == "into_future": + self.image_list += [[images[i], images[i + 1]]] + self.flow_list += [flows[i]] + elif direction == "into_past": + self.image_list += [[images[i + 1], images[i]]] + self.flow_list += [flows[i + 1]] + class KITTI(FlowDataset): - def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + def __init__(self, aug_params=None, split="training", root="datasets/KITTI"): super(KITTI, self).__init__(aug_params, sparse=True) - if split == 'testing': + if split == "testing": self.is_test = True root = osp.join(root, split) - images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) - images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + images1 = sorted(glob(osp.join(root, "image_2/*_10.png"))) + images2 = sorted(glob(osp.join(root, "image_2/*_11.png"))) for img1, img2 in zip(images1, images2): - frame_id = img1.split('/')[-1] - self.extra_info += [ [frame_id] ] - self.image_list += [ [img1, img2] ] + frame_id = img1.split("/")[-1] + self.extra_info += [[frame_id]] + self.image_list += [[img1, img2]] - if split == 'training': - self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + if split == "training": + self.flow_list = sorted(glob(osp.join(root, "flow_occ/*_10.png"))) class HD1K(FlowDataset): - def __init__(self, aug_params=None, root='datasets/HD1k'): + def __init__(self, aug_params=None, root="datasets/HD1k"): super(HD1K, self).__init__(aug_params, sparse=True) seq_ix = 0 while 1: - flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) - images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + flows = sorted( + glob(os.path.join(root, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix)) + ) + images = sorted( + glob(os.path.join(root, "hd1k_input", "image_2/%06d_*.png" % seq_ix)) + ) if len(flows) == 0: break - for i in range(len(flows)-1): + for i in range(len(flows) - 1): self.flow_list += [flows[i]] - self.image_list += [ [images[i], images[i+1]] ] + self.image_list += [[images[i], images[i + 1]]] seq_ix += 1 -def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): - """ Create the data loader for the corresponding trainign set """ - - if args.stage == 'chairs': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} - train_dataset = FlyingChairs(aug_params, split='training') - - elif args.stage == 'things': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} - clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') - final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') +def fetch_dataloader(args, TRAIN_DS="C+T+K+S+H"): + """Create the data loader for the corresponding trainign set""" + + if args.stage == "chairs": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.1, + "max_scale": 1.0, + "do_flip": True, + } + train_dataset = FlyingChairs(aug_params, split="training") + + elif args.stage == "things": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.4, + "max_scale": 0.8, + "do_flip": True, + } + clean_dataset = FlyingThings3D(aug_params, dstype="frames_cleanpass") + final_dataset = FlyingThings3D(aug_params, dstype="frames_finalpass") train_dataset = clean_dataset + final_dataset - elif args.stage == 'sintel': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} - things = FlyingThings3D(aug_params, dstype='frames_cleanpass') - sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') - sintel_final = MpiSintel(aug_params, split='training', dstype='final') - - if TRAIN_DS == 'C+T+K+S+H': - kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) - hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) - train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things - - elif TRAIN_DS == 'C+T+K/S': - train_dataset = 100*sintel_clean + 100*sintel_final + things - - elif args.stage == 'kitti': - aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} - train_dataset = KITTI(aug_params, split='training') - - train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, - pin_memory=False, shuffle=True, num_workers=4, drop_last=True) - - print('Training with %d image pairs' % len(train_dataset)) + elif args.stage == "sintel": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.6, + "do_flip": True, + } + things = FlyingThings3D(aug_params, dstype="frames_cleanpass") + sintel_clean = MpiSintel(aug_params, split="training", dstype="clean") + sintel_final = MpiSintel(aug_params, split="training", dstype="final") + + if TRAIN_DS == "C+T+K+S+H": + kitti = KITTI( + { + "crop_size": args.image_size, + "min_scale": -0.3, + "max_scale": 0.5, + "do_flip": True, + } + ) + hd1k = HD1K( + { + "crop_size": args.image_size, + "min_scale": -0.5, + "max_scale": 0.2, + "do_flip": True, + } + ) + train_dataset = ( + 100 * sintel_clean + + 100 * sintel_final + + 200 * kitti + + 5 * hd1k + + things + ) + + elif TRAIN_DS == "C+T+K/S": + train_dataset = 100 * sintel_clean + 100 * sintel_final + things + + elif args.stage == "kitti": + aug_params = { + "crop_size": args.image_size, + "min_scale": -0.2, + "max_scale": 0.4, + "do_flip": False, + } + train_dataset = KITTI(aug_params, split="training") + + train_loader = data.DataLoader( + train_dataset, + batch_size=args.batch_size, + pin_memory=False, + shuffle=True, + num_workers=4, + drop_last=True, + ) + + print("Training with %d image pairs" % len(train_dataset)) return train_loader - diff --git a/data_preprocessing/RAFT/core/extractor.py b/data_preprocessing/RAFT/core/extractor.py index 9a9c759..4215b79 100755 --- a/data_preprocessing/RAFT/core/extractor.py +++ b/data_preprocessing/RAFT/core/extractor.py @@ -4,34 +4,36 @@ class ResidualBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): super(ResidualBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride + ) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 - if norm_fn == 'group': + if norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': + + elif norm_fn == "batch": self.norm1 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes) if not stride == 1: self.norm3 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': + + elif norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm3 = nn.InstanceNorm2d(planes) - elif norm_fn == 'none': + elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() if not stride == 1: @@ -39,11 +41,11 @@ def __init__(self, in_planes, planes, norm_fn='group', stride=1): if stride == 1: self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 + ) def forward(self, x): y = x @@ -53,43 +55,44 @@ def forward(self, x): if self.downsample is not None: x = self.downsample(x) - return self.relu(x+y) - + return self.relu(x + y) class BottleneckBlock(nn.Module): - def __init__(self, in_planes, planes, norm_fn='group', stride=1): + def __init__(self, in_planes, planes, norm_fn="group", stride=1): super(BottleneckBlock, self).__init__() - - self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) - self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) - self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + + self.conv1 = nn.Conv2d(in_planes, planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride + ) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) num_groups = planes // 8 - if norm_fn == 'group': - self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) - self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + if norm_fn == "group": + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes // 4) self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) if not stride == 1: self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) - - elif norm_fn == 'batch': - self.norm1 = nn.BatchNorm2d(planes//4) - self.norm2 = nn.BatchNorm2d(planes//4) + + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) self.norm3 = nn.BatchNorm2d(planes) if not stride == 1: self.norm4 = nn.BatchNorm2d(planes) - - elif norm_fn == 'instance': - self.norm1 = nn.InstanceNorm2d(planes//4) - self.norm2 = nn.InstanceNorm2d(planes//4) + + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) self.norm3 = nn.InstanceNorm2d(planes) if not stride == 1: self.norm4 = nn.InstanceNorm2d(planes) - elif norm_fn == 'none': + elif norm_fn == "none": self.norm1 = nn.Sequential() self.norm2 = nn.Sequential() self.norm3 = nn.Sequential() @@ -98,11 +101,11 @@ def __init__(self, in_planes, planes, norm_fn='group', stride=1): if stride == 1: self.downsample = None - - else: - self.downsample = nn.Sequential( - nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4 + ) def forward(self, x): y = x @@ -113,30 +116,31 @@ def forward(self, x): if self.downsample is not None: x = self.downsample(x) - return self.relu(x+y) + return self.relu(x + y) + class BasicEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): super(BasicEncoder, self).__init__() self.norm_fn = norm_fn - if self.norm_fn == 'group': + if self.norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) - - elif self.norm_fn == 'batch': + + elif self.norm_fn == "batch": self.norm1 = nn.BatchNorm2d(64) - elif self.norm_fn == 'instance': + elif self.norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(64) - elif self.norm_fn == 'none': + elif self.norm_fn == "none": self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) self.relu1 = nn.ReLU(inplace=True) self.in_planes = 64 - self.layer1 = self._make_layer(64, stride=1) + self.layer1 = self._make_layer(64, stride=1) self.layer2 = self._make_layer(96, stride=2) self.layer3 = self._make_layer(128, stride=2) @@ -149,7 +153,7 @@ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) @@ -160,11 +164,10 @@ def _make_layer(self, dim, stride=1): layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) - + self.in_planes = dim return nn.Sequential(*layers) - def forward(self, x): # if input is list, combine batch dimension @@ -193,39 +196,39 @@ def forward(self, x): class SmallEncoder(nn.Module): - def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + def __init__(self, output_dim=128, norm_fn="batch", dropout=0.0): super(SmallEncoder, self).__init__() self.norm_fn = norm_fn - if self.norm_fn == 'group': + if self.norm_fn == "group": self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) - - elif self.norm_fn == 'batch': + + elif self.norm_fn == "batch": self.norm1 = nn.BatchNorm2d(32) - elif self.norm_fn == 'instance': + elif self.norm_fn == "instance": self.norm1 = nn.InstanceNorm2d(32) - elif self.norm_fn == 'none': + elif self.norm_fn == "none": self.norm1 = nn.Sequential() self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) self.relu1 = nn.ReLU(inplace=True) self.in_planes = 32 - self.layer1 = self._make_layer(32, stride=1) + self.layer1 = self._make_layer(32, stride=1) self.layer2 = self._make_layer(64, stride=2) self.layer3 = self._make_layer(96, stride=2) self.dropout = None if dropout > 0: self.dropout = nn.Dropout2d(p=dropout) - + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) for m in self.modules(): if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) @@ -236,11 +239,10 @@ def _make_layer(self, dim, stride=1): layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) layers = (layer1, layer2) - + self.in_planes = dim return nn.Sequential(*layers) - def forward(self, x): # if input is list, combine batch dimension diff --git a/data_preprocessing/RAFT/core/raft.py b/data_preprocessing/RAFT/core/raft.py index 652b81a..72ed994 100755 --- a/data_preprocessing/RAFT/core/raft.py +++ b/data_preprocessing/RAFT/core/raft.py @@ -15,8 +15,10 @@ class autocast: def __init__(self, enabled): pass + def __enter__(self): pass + def __exit__(self, *args): pass @@ -31,28 +33,36 @@ def __init__(self, args): self.context_dim = cdim = 64 args.corr_levels = 4 args.corr_radius = 3 - + else: self.hidden_dim = hdim = 128 self.context_dim = cdim = 128 args.corr_levels = 4 args.corr_radius = 4 - if 'dropout' not in self.args: + if "dropout" not in self.args: self.args.dropout = 0 - if 'alternate_corr' not in self.args: + if "alternate_corr" not in self.args: self.args.alternate_corr = False # feature network, context network, and update block if args.small: - self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) - self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.fnet = SmallEncoder( + output_dim=128, norm_fn="instance", dropout=args.dropout + ) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, norm_fn="none", dropout=args.dropout + ) self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) else: - self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) - self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.fnet = BasicEncoder( + output_dim=256, norm_fn="instance", dropout=args.dropout + ) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, norm_fn="batch", dropout=args.dropout + ) self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) def freeze_bn(self): @@ -61,30 +71,31 @@ def freeze_bn(self): m.eval() def initialize_flow(self, img): - """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + """Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" N, C, H, W = img.shape - coords0 = coords_grid(N, H//8, W//8, device=img.device) - coords1 = coords_grid(N, H//8, W//8, device=img.device) + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) # optical flow computed as difference: flow = coords1 - coords0 return coords0, coords1 def upsample_flow(self, flow, mask): - """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + """Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination""" N, _, H, W = flow.shape mask = mask.view(N, 1, 9, 8, 8, H, W) mask = torch.softmax(mask, dim=2) - up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = F.unfold(8 * flow, [3, 3], padding=1) up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) up_flow = torch.sum(mask * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) - return up_flow.reshape(N, 2, 8*H, 8*W) - + return up_flow.reshape(N, 2, 8 * H, 8 * W) - def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): - """ Estimate optical flow between pair of frames """ + def forward( + self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False + ): + """Estimate optical flow between pair of frames""" image1 = 2 * (image1 / 255.0) - 1.0 image2 = 2 * (image2 / 255.0) - 1.0 @@ -97,8 +108,8 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_ # run the feature network with autocast(enabled=self.args.mixed_precision): - fmap1, fmap2 = self.fnet([image1, image2]) - + fmap1, fmap2 = self.fnet([image1, image2]) + fmap1 = fmap1.float() fmap2 = fmap2.float() if self.args.alternate_corr: @@ -121,7 +132,7 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_ flow_predictions = [] for itr in range(iters): coords1 = coords1.detach() - corr = corr_fn(coords1) # index correlation volume + corr = corr_fn(coords1) # index correlation volume flow = coords1 - coords0 with autocast(enabled=self.args.mixed_precision): @@ -135,10 +146,10 @@ def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_ flow_up = upflow8(coords1 - coords0) else: flow_up = self.upsample_flow(coords1 - coords0, up_mask) - + flow_predictions.append(flow_up) if test_mode: return coords1 - coords0, flow_up - + return flow_predictions diff --git a/data_preprocessing/RAFT/core/update.py b/data_preprocessing/RAFT/core/update.py index f940497..ced6df0 100755 --- a/data_preprocessing/RAFT/core/update.py +++ b/data_preprocessing/RAFT/core/update.py @@ -13,56 +13,70 @@ def __init__(self, input_dim=128, hidden_dim=256): def forward(self, x): return self.conv2(self.relu(self.conv1(x))) + class ConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): + def __init__(self, hidden_dim=128, input_dim=192 + 128): super(ConvGRU, self).__init__() - self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) - self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, 3, padding=1) def forward(self, h, x): hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz(hx)) r = torch.sigmoid(self.convr(hx)) - q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) - h = (1-z) * h + z * q + h = (1 - z) * h + z * q return h + class SepConvGRU(nn.Module): - def __init__(self, hidden_dim=128, input_dim=192+128): + def __init__(self, hidden_dim=128, input_dim=192 + 128): super(SepConvGRU, self).__init__() - self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) - - self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) - + self.convz1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convr1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + self.convq1 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (1, 5), padding=(0, 2) + ) + + self.convz2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convr2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) + self.convq2 = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, (5, 1), padding=(2, 0) + ) def forward(self, h, x): # horizontal hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz1(hx)) r = torch.sigmoid(self.convr1(hx)) - q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q # vertical hx = torch.cat([h, x], dim=1) z = torch.sigmoid(self.convz2(hx)) r = torch.sigmoid(self.convr2(hx)) - q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) - h = (1-z) * h + z * q + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q return h + class SmallMotionEncoder(nn.Module): def __init__(self, args): super(SmallMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) self.convf1 = nn.Conv2d(2, 64, 7, padding=3) self.convf2 = nn.Conv2d(64, 32, 3, padding=1) @@ -76,15 +90,16 @@ def forward(self, flow, corr): out = F.relu(self.conv(cor_flo)) return torch.cat([out, flow], dim=1) + class BasicMotionEncoder(nn.Module): def __init__(self, args): super(BasicMotionEncoder, self).__init__() - cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + cor_planes = args.corr_levels * (2 * args.corr_radius + 1) ** 2 self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) self.convc2 = nn.Conv2d(256, 192, 3, padding=1) self.convf1 = nn.Conv2d(2, 128, 7, padding=3) self.convf2 = nn.Conv2d(128, 64, 3, padding=1) - self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1) def forward(self, flow, corr): cor = F.relu(self.convc1(corr)) @@ -96,11 +111,12 @@ def forward(self, flow, corr): out = F.relu(self.conv(cor_flo)) return torch.cat([out, flow], dim=1) + class SmallUpdateBlock(nn.Module): def __init__(self, args, hidden_dim=96): super(SmallUpdateBlock, self).__init__() self.encoder = SmallMotionEncoder(args) - self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64) self.flow_head = FlowHead(hidden_dim, hidden_dim=128) def forward(self, net, inp, corr, flow): @@ -111,18 +127,20 @@ def forward(self, net, inp, corr, flow): return net, None, delta_flow + class BasicUpdateBlock(nn.Module): def __init__(self, args, hidden_dim=128, input_dim=128): super(BasicUpdateBlock, self).__init__() self.args = args self.encoder = BasicMotionEncoder(args) - self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128 + hidden_dim) self.flow_head = FlowHead(hidden_dim, hidden_dim=256) self.mask = nn.Sequential( nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(256, 64*9, 1, padding=0)) + nn.Conv2d(256, 64 * 9, 1, padding=0), + ) def forward(self, net, inp, corr, flow, upsample=True): motion_features = self.encoder(flow, corr) @@ -132,8 +150,5 @@ def forward(self, net, inp, corr, flow, upsample=True): delta_flow = self.flow_head(net) # scale mask to balence gradients - mask = .25 * self.mask(net) + mask = 0.25 * self.mask(net) return net, mask, delta_flow - - - diff --git a/data_preprocessing/RAFT/core/utils/augmentor.py b/data_preprocessing/RAFT/core/utils/augmentor.py index e81c4f2..ae675df 100755 --- a/data_preprocessing/RAFT/core/utils/augmentor.py +++ b/data_preprocessing/RAFT/core/utils/augmentor.py @@ -4,6 +4,7 @@ from PIL import Image import cv2 + cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False) @@ -14,7 +15,7 @@ class FlowAugmentor: def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): - + # spatial augmentation params self.crop_size = crop_size self.min_scale = min_scale @@ -29,12 +30,14 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): self.v_flip_prob = 0.1 # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.photo_aug = ColorJitter( + brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14 + ) self.asymmetric_color_aug_prob = 0.2 self.eraser_aug_prob = 0.5 def color_transform(self, img1, img2): - """ Photometric augmentation """ + """Photometric augmentation""" # asymmetric if np.random.rand() < self.asymmetric_color_aug_prob: @@ -44,13 +47,15 @@ def color_transform(self, img1, img2): # symmetric else: image_stack = np.concatenate([img1, img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) img1, img2 = np.split(image_stack, 2, axis=0) return img1, img2 def eraser_transform(self, img1, img2, bounds=[50, 100]): - """ Occlusion augmentation """ + """Occlusion augmentation""" ht, wd = img1.shape[:2] if np.random.rand() < self.eraser_aug_prob: @@ -60,7 +65,7 @@ def eraser_transform(self, img1, img2, bounds=[50, 100]): y0 = np.random.randint(0, ht) dx = np.random.randint(bounds[0], bounds[1]) dy = np.random.randint(bounds[0], bounds[1]) - img2[y0:y0+dy, x0:x0+dx, :] = mean_color + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color return img1, img2 @@ -68,8 +73,8 @@ def spatial_transform(self, img1, img2, flow): # randomly sample scale ht, wd = img1.shape[:2] min_scale = np.maximum( - (self.crop_size[0] + 8) / float(ht), - (self.crop_size[1] + 8) / float(wd)) + (self.crop_size[0] + 8) / float(ht), (self.crop_size[1] + 8) / float(wd) + ) scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) scale_x = scale @@ -77,34 +82,40 @@ def spatial_transform(self, img1, img2, flow): if np.random.rand() < self.stretch_prob: scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) - + scale_x = np.clip(scale_x, min_scale, None) scale_y = np.clip(scale_y, min_scale, None) if np.random.rand() < self.spatial_aug_prob: # rescale the images - img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow = cv2.resize( + flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) flow = flow * [scale_x, scale_y] if self.do_flip: - if np.random.rand() < self.h_flip_prob: # h-flip + if np.random.rand() < self.h_flip_prob: # h-flip img1 = img1[:, ::-1] img2 = img2[:, ::-1] flow = flow[:, ::-1] * [-1.0, 1.0] - if np.random.rand() < self.v_flip_prob: # v-flip + if np.random.rand() < self.v_flip_prob: # v-flip img1 = img1[::-1, :] img2 = img2[::-1, :] flow = flow[::-1, :] * [1.0, -1.0] y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) - - img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] return img1, img2, flow @@ -119,6 +130,7 @@ def __call__(self, img1, img2, flow): return img1, img2, flow + class SparseFlowAugmentor: def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): # spatial augmentation params @@ -135,13 +147,17 @@ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): self.v_flip_prob = 0.1 # photometric augmentation params - self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.photo_aug = ColorJitter( + brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3 / 3.14 + ) self.asymmetric_color_aug_prob = 0.2 self.eraser_aug_prob = 0.5 - + def color_transform(self, img1, img2): image_stack = np.concatenate([img1, img2], axis=0) - image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + image_stack = np.array( + self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8 + ) img1, img2 = np.split(image_stack, 2, axis=0) return img1, img2 @@ -154,7 +170,7 @@ def eraser_transform(self, img1, img2): y0 = np.random.randint(0, ht) dx = np.random.randint(50, 100) dy = np.random.randint(50, 100) - img2[y0:y0+dy, x0:x0+dx, :] = mean_color + img2[y0 : y0 + dy, x0 : x0 + dx, :] = mean_color return img1, img2 @@ -167,8 +183,8 @@ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): flow = flow.reshape(-1, 2).astype(np.float32) valid = valid.reshape(-1).astype(np.float32) - coords0 = coords[valid>=1] - flow0 = flow[valid>=1] + coords0 = coords[valid >= 1] + flow0 = flow[valid >= 1] ht1 = int(round(ht * fy)) wd1 = int(round(wd * fx)) @@ -176,8 +192,8 @@ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): coords1 = coords0 * [fx, fy] flow1 = flow0 * [fx, fy] - xx = np.round(coords1[:,0]).astype(np.int32) - yy = np.round(coords1[:,1]).astype(np.int32) + xx = np.round(coords1[:, 0]).astype(np.int32) + yy = np.round(coords1[:, 1]).astype(np.int32) v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) xx = xx[v] @@ -197,8 +213,8 @@ def spatial_transform(self, img1, img2, flow, valid): ht, wd = img1.shape[:2] min_scale = np.maximum( - (self.crop_size[0] + 1) / float(ht), - (self.crop_size[1] + 1) / float(wd)) + (self.crop_size[0] + 1) / float(ht), (self.crop_size[1] + 1) / float(wd) + ) scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) scale_x = np.clip(scale, min_scale, None) @@ -206,12 +222,18 @@ def spatial_transform(self, img1, img2, flow, valid): if np.random.rand() < self.spatial_aug_prob: # rescale the images - img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) - flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + img1 = cv2.resize( + img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + img2 = cv2.resize( + img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR + ) + flow, valid = self.resize_sparse_flow_map( + flow, valid, fx=scale_x, fy=scale_y + ) if self.do_flip: - if np.random.rand() < 0.5: # h-flip + if np.random.rand() < 0.5: # h-flip img1 = img1[:, ::-1] img2 = img2[:, ::-1] flow = flow[:, ::-1] * [-1.0, 1.0] @@ -226,13 +248,12 @@ def spatial_transform(self, img1, img2, flow, valid): y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) - img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] - valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img1 = img1[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + img2 = img2[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + flow = flow[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] + valid = valid[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] return img1, img2, flow, valid - def __call__(self, img1, img2, flow, valid): img1, img2 = self.color_transform(img1, img2) img1, img2 = self.eraser_transform(img1, img2) diff --git a/data_preprocessing/RAFT/core/utils/flow_viz.py b/data_preprocessing/RAFT/core/utils/flow_viz.py index dcee65e..fec0836 100755 --- a/data_preprocessing/RAFT/core/utils/flow_viz.py +++ b/data_preprocessing/RAFT/core/utils/flow_viz.py @@ -17,6 +17,7 @@ import numpy as np + def make_colorwheel(): """ Generates a color wheel for optical flow visualization as presented in: @@ -43,27 +44,27 @@ def make_colorwheel(): # RY colorwheel[0:RY, 0] = 255 - colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) - col = col+RY + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY # YG - colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) - colorwheel[col:col+YG, 1] = 255 - col = col+YG + colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG # GC - colorwheel[col:col+GC, 1] = 255 - colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) - col = col+GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC # CB - colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) - colorwheel[col:col+CB, 2] = 255 - col = col+CB + colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB # BM - colorwheel[col:col+BM, 2] = 255 - colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) - col = col+BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM # MR - colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) - colorwheel[col:col+MR, 0] = 255 + colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 return colorwheel @@ -86,23 +87,23 @@ def flow_uv_to_colors(u, v, convert_to_bgr=False): colorwheel = make_colorwheel() # shape [55x3] ncols = colorwheel.shape[0] rad = np.sqrt(np.square(u) + np.square(v)) - a = np.arctan2(-v, -u)/np.pi - fk = (a+1) / 2*(ncols-1) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) k0 = np.floor(fk).astype(np.int32) k1 = k0 + 1 k1[k1 == ncols] = 0 f = fk - k0 for i in range(colorwheel.shape[1]): - tmp = colorwheel[:,i] + tmp = colorwheel[:, i] col0 = tmp[k0] / 255.0 col1 = tmp[k1] / 255.0 - col = (1-f)*col0 + f*col1 - idx = (rad <= 1) - col[idx] = 1 - rad[idx] * (1-col[idx]) - col[~idx] = col[~idx] * 0.75 # out of range + col = (1 - f) * col0 + f * col1 + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range # Note the 2-i => BGR instead of RGB - ch_idx = 2-i if convert_to_bgr else i - flow_image[:,:,ch_idx] = np.floor(255 * col) + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) return flow_image @@ -118,15 +119,15 @@ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): Returns: np.ndarray: Flow visualization image of shape [H,W,3] """ - assert flow_uv.ndim == 3, 'input flow must have three dimensions' - assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + assert flow_uv.ndim == 3, "input flow must have three dimensions" + assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" if clip_flow is not None: flow_uv = np.clip(flow_uv, 0, clip_flow) - u = flow_uv[:,:,0] - v = flow_uv[:,:,1] + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] rad = np.sqrt(np.square(u) + np.square(v)) rad_max = np.max(rad) epsilon = 1e-5 u = u / (rad_max + epsilon) v = v / (rad_max + epsilon) - return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file + return flow_uv_to_colors(u, v, convert_to_bgr) diff --git a/data_preprocessing/RAFT/core/utils/frame_utils.py b/data_preprocessing/RAFT/core/utils/frame_utils.py index 6c49113..bee554e 100755 --- a/data_preprocessing/RAFT/core/utils/frame_utils.py +++ b/data_preprocessing/RAFT/core/utils/frame_utils.py @@ -4,34 +4,37 @@ import re import cv2 + cv2.setNumThreads(0) cv2.ocl.setUseOpenCL(False) TAG_CHAR = np.array([202021.25], np.float32) + def readFlow(fn): - """ Read .flo file in Middlebury format""" + """Read .flo file in Middlebury format""" # Code adapted from: # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy # WARNING: this will work on little-endian architectures (eg Intel x86) only! # print 'fn = %s'%(fn) - with open(fn, 'rb') as f: + with open(fn, "rb") as f: magic = np.fromfile(f, np.float32, count=1) if 202021.25 != magic: - print('Magic number incorrect. Invalid .flo file') + print("Magic number incorrect. Invalid .flo file") return None else: w = np.fromfile(f, np.int32, count=1) h = np.fromfile(f, np.int32, count=1) # print 'Reading %d x %d flo file\n' % (w, h) - data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) # Reshape data into 3D array (columns, rows, bands) # The reshape here is for visualization, the original code is (w,h,2) return np.resize(data, (int(h), int(w), 2)) + def readPFM(file): - file = open(file, 'rb') + file = open(file, "rb") color = None width = None @@ -40,36 +43,37 @@ def readPFM(file): endian = None header = file.readline().rstrip() - if header == b'PF': + if header == b"PF": color = True - elif header == b'Pf': + elif header == b"Pf": color = False else: - raise Exception('Not a PFM file.') + raise Exception("Not a PFM file.") - dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + dim_match = re.match(rb"^(\d+)\s(\d+)\s$", file.readline()) if dim_match: width, height = map(int, dim_match.groups()) else: - raise Exception('Malformed PFM header.') + raise Exception("Malformed PFM header.") scale = float(file.readline().rstrip()) - if scale < 0: # little-endian - endian = '<' + if scale < 0: # little-endian + endian = "<" scale = -scale else: - endian = '>' # big-endian + endian = ">" # big-endian - data = np.fromfile(file, endian + 'f') + data = np.fromfile(file, endian + "f") shape = (height, width, 3) if color else (height, width) data = np.reshape(data, shape) data = np.flipud(data) return data -def writeFlow(filename,uv,v=None): - """ Write optical flow to file. - + +def writeFlow(filename, uv, v=None): + """Write optical flow to file. + If v is None, uv is assumed to contain both u and v channels, stacked in depth. Original code by Deqing Sun, adapted from Daniel Scharstein. @@ -77,35 +81,36 @@ def writeFlow(filename,uv,v=None): nBands = 2 if v is None: - assert(uv.ndim == 3) - assert(uv.shape[2] == 2) - u = uv[:,:,0] - v = uv[:,:,1] + assert uv.ndim == 3 + assert uv.shape[2] == 2 + u = uv[:, :, 0] + v = uv[:, :, 1] else: u = uv - assert(u.shape == v.shape) - height,width = u.shape - f = open(filename,'wb') + assert u.shape == v.shape + height, width = u.shape + f = open(filename, "wb") # write the header f.write(TAG_CHAR) np.array(width).astype(np.int32).tofile(f) np.array(height).astype(np.int32).tofile(f) # arrange into matrix form - tmp = np.zeros((height, width*nBands)) - tmp[:,np.arange(width)*2] = u - tmp[:,np.arange(width)*2 + 1] = v + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v tmp.astype(np.float32).tofile(f) f.close() def readFlowKITTI(filename): - flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) - flow = flow[:,:,::-1].astype(np.float32) + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) flow, valid = flow[:, :, :2], flow[:, :, 2] flow = (flow - 2**15) / 64.0 return flow, valid + def readDispKITTI(filename): disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 valid = disp > 0.0 @@ -118,20 +123,20 @@ def writeFlowKITTI(filename, uv): valid = np.ones([uv.shape[0], uv.shape[1], 1]) uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) cv2.imwrite(filename, uv[..., ::-1]) - + def read_gen(file_name, pil=False): ext = splitext(file_name)[-1] - if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + if ext == ".png" or ext == ".jpeg" or ext == ".ppm" or ext == ".jpg": return Image.open(file_name) - elif ext == '.bin' or ext == '.raw': + elif ext == ".bin" or ext == ".raw": return np.load(file_name) - elif ext == '.flo': + elif ext == ".flo": return readFlow(file_name).astype(np.float32) - elif ext == '.pfm': + elif ext == ".pfm": flow = readPFM(file_name).astype(np.float32) if len(flow.shape) == 2: return flow else: return flow[:, :, :-1] - return [] \ No newline at end of file + return [] diff --git a/data_preprocessing/RAFT/core/utils/utils.py b/data_preprocessing/RAFT/core/utils/utils.py index 741ccfe..781dc55 100755 --- a/data_preprocessing/RAFT/core/utils/utils.py +++ b/data_preprocessing/RAFT/core/utils/utils.py @@ -5,23 +5,30 @@ class InputPadder: - """ Pads images such that dimensions are divisible by 8 """ - def __init__(self, dims, mode='sintel'): + """Pads images such that dimensions are divisible by 8""" + + def __init__(self, dims, mode="sintel"): self.ht, self.wd = dims[-2:] pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 - if mode == 'sintel': - self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + if mode == "sintel": + self._pad = [ + pad_wd // 2, + pad_wd - pad_wd // 2, + pad_ht // 2, + pad_ht - pad_ht // 2, + ] else: - self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] def pad(self, *inputs): - return [F.pad(x, self._pad, mode='replicate') for x in inputs] + return [F.pad(x, self._pad, mode="replicate") for x in inputs] - def unpad(self,x): + def unpad(self, x): ht, wd = x.shape[-2:] - c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] - return x[..., c[0]:c[1], c[2]:c[3]] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0] : c[1], c[2] : c[3]] + def forward_interpolate(flow): flow = flow.detach().cpu().numpy() @@ -32,7 +39,7 @@ def forward_interpolate(flow): x1 = x0 + dx y1 = y0 + dy - + x1 = x1.reshape(-1) y1 = y1.reshape(-1) dx = dx.reshape(-1) @@ -45,21 +52,23 @@ def forward_interpolate(flow): dy = dy[valid] flow_x = interpolate.griddata( - (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 + ) flow_y = interpolate.griddata( - (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 + ) flow = np.stack([flow_x, flow_y], axis=0) return torch.from_numpy(flow).float() -def bilinear_sampler(img, coords, mode='bilinear', mask=False): - """ Wrapper for grid_sample, uses pixel coordinates """ +def bilinear_sampler(img, coords, mode="bilinear", mask=False): + """Wrapper for grid_sample, uses pixel coordinates""" H, W = img.shape[-2:] - xgrid, ygrid = coords.split([1,1], dim=-1) - xgrid = 2*xgrid/(W-1) - 1 - ygrid = 2*ygrid/(H-1) - 1 + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 grid = torch.cat([xgrid, ygrid], dim=-1) img = F.grid_sample(img, grid, align_corners=True) @@ -72,11 +81,13 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False): def coords_grid(batch, ht, wd, device): - coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.meshgrid( + torch.arange(ht, device=device), torch.arange(wd, device=device) + ) coords = torch.stack(coords[::-1], dim=0).float() return coords[None].repeat(batch, 1, 1, 1) -def upflow8(flow, mode='bilinear'): +def upflow8(flow, mode="bilinear"): new_size = (8 * flow.shape[2], 8 * flow.shape[3]) - return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/data_preprocessing/RAFT/demo.py b/data_preprocessing/RAFT/demo.py index b6df6bc..9a447b9 100755 --- a/data_preprocessing/RAFT/demo.py +++ b/data_preprocessing/RAFT/demo.py @@ -1,5 +1,6 @@ import sys -sys.path.append('core') + +sys.path.append("core") import argparse import os @@ -14,7 +15,7 @@ from utils import flow_viz from utils.utils import InputPadder -DEVICE = 'cuda' +DEVICE = "cuda" def load_image(imfile): @@ -23,15 +24,15 @@ def load_image(imfile): return img[None].to(DEVICE) -def viz(img, flo,img_name=None): - img = img[0].permute(1,2,0).cpu().numpy() - flo = flo[0].permute(1,2,0).cpu().numpy() +def viz(img, flo, img_name=None): + img = img[0].permute(1, 2, 0).cpu().numpy() + flo = flo[0].permute(1, 2, 0).cpu().numpy() # map flow to rgb image flo = flow_viz.flow_to_image(flo) img_flo = np.concatenate([img, flo], axis=0) - cv2.imwrite(f'{img_name}', img_flo[:, :, [2,1,0]]) + cv2.imwrite(f"{img_name}", img_flo[:, :, [2, 1, 0]]) def demo(args): @@ -44,66 +45,76 @@ def demo(args): os.makedirs(args.outdir, exist_ok=True) os.makedirs(args.outdir_conf, exist_ok=True) with torch.no_grad(): - images = glob.glob(os.path.join(args.path, '*.png')) + \ - glob.glob(os.path.join(args.path, '*.jpg')) + images = glob.glob(os.path.join(args.path, "*.png")) + glob.glob( + os.path.join(args.path, "*.jpg") + ) images = sorted(images) - i=0 + i = 0 for imfile1, imfile2 in zip(images[:-1], images[1:]): image1 = load_image(imfile1) image2 = load_image(imfile2) if args.if_mask: - mk_file1=imfile1.split("/") - mk_file1[-2]=f"{args.name}_masks" - mk_file1='/'.join(mk_file1) - mk_file2=imfile2.split("/") - mk_file2[-2]=f"{args.name}_masks" - mk_file2='/'.join(mk_file2) - mask1=cv2.imread(mk_file1.replace('jpg','png') - ,0) - mask2=cv2.imread(mk_file2.replace('jpg','png'), - 0) - mask1=torch.from_numpy(mask1).to(DEVICE).float() - mask2=torch.from_numpy(mask2).to(DEVICE).float() - mask1[mask1>0]=1 - mask2[mask2>0]=1 - image1*=mask1 - image2*=mask2 + mk_file1 = imfile1.split("/") + mk_file1[-2] = f"{args.name}_masks" + mk_file1 = "/".join(mk_file1) + mk_file2 = imfile2.split("/") + mk_file2[-2] = f"{args.name}_masks" + mk_file2 = "/".join(mk_file2) + mask1 = cv2.imread(mk_file1.replace("jpg", "png"), 0) + mask2 = cv2.imread(mk_file2.replace("jpg", "png"), 0) + mask1 = torch.from_numpy(mask1).to(DEVICE).float() + mask2 = torch.from_numpy(mask2).to(DEVICE).float() + mask1[mask1 > 0] = 1 + mask2[mask2 > 0] = 1 + image1 *= mask1 + image2 *= mask2 padder = InputPadder(image1.shape) image1, image2 = padder.pad(image1, image2) if args.if_mask: - mask1,mask2=padder.pad(mask1.unsqueeze(0).unsqueeze(0), - mask2.unsqueeze(0).unsqueeze(0)) - mask1=mask1.squeeze() - mask2=mask2.squeeze() + mask1, mask2 = padder.pad( + mask1.unsqueeze(0).unsqueeze(0), mask2.unsqueeze(0).unsqueeze(0) + ) + mask1 = mask1.squeeze() + mask2 = mask2.squeeze() flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) flow_low_, flow_up_ = model(image2, image1, iters=20, test_mode=True) flow_1to2 = flow_up.clone() flow_2to1 = flow_up_.clone() - _,_,H,W=image1.shape + _, _, H, W = image1.shape x = torch.linspace(0, 1, W) y = torch.linspace(0, 1, H) - grid_x,grid_y=torch.meshgrid(x,y) - grid=torch.stack([grid_x,grid_y],dim=0).to(DEVICE) - grid=grid.permute(0,2,1) - grid[0]*=W - grid[1]*=H + grid_x, grid_y = torch.meshgrid(x, y) + grid = torch.stack([grid_x, grid_y], dim=0).to(DEVICE) + grid = grid.permute(0, 2, 1) + grid[0] *= W + grid[1] *= H if args.if_mask: - flow_up[:,:,mask1.long()==0]=10000 - grid_=grid+flow_up.squeeze() - - grid_norm=grid_.clone() - grid_norm[0,...]=2*grid_norm[0,...]/(W-1)-1 - grid_norm[1,...]=2*grid_norm[1,...]/(H-1)-1 - - flow_bilinear_=F.grid_sample(flow_up_,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros') - - rgb_bilinear_=F.grid_sample(image2,grid_norm.unsqueeze(0).permute(0,2,3,1),mode='bilinear',padding_mode='zeros') - rgb_np=rgb_bilinear_.squeeze().permute(1,2,0).cpu().numpy()[:, :, ::-1] - cv2.imwrite(f'{args.outdir}/warped.png',rgb_np) + flow_up[:, :, mask1.long() == 0] = 10000 + grid_ = grid + flow_up.squeeze() + + grid_norm = grid_.clone() + grid_norm[0, ...] = 2 * grid_norm[0, ...] / (W - 1) - 1 + grid_norm[1, ...] = 2 * grid_norm[1, ...] / (H - 1) - 1 + + flow_bilinear_ = F.grid_sample( + flow_up_, + grid_norm.unsqueeze(0).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + ) + + rgb_bilinear_ = F.grid_sample( + image2, + grid_norm.unsqueeze(0).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + ) + rgb_np = rgb_bilinear_.squeeze().permute(1, 2, 0).cpu().numpy()[:, :, ::-1] + cv2.imwrite(f"{args.outdir}/warped.png", rgb_np) if args.confidence: ### Calculate confidence map using cycle consistency. @@ -118,62 +129,92 @@ def demo(args): norm_grid_2to1 = grid_2to1.clone() norm_grid_2to1[0, ...] = 2 * norm_grid_2to1[0, ...] / (W - 1) - 1 norm_grid_2to1[1, ...] = 2 * norm_grid_2to1[1, ...] / (H - 1) - 1 - warped_image2 = F.grid_sample(image1, norm_grid_2to1.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros') + warped_image2 = F.grid_sample( + image1, + norm_grid_2to1.unsqueeze(0).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + ) grid_1to2 = grid + flow_1to2.squeeze() norm_grid_1to2 = grid_1to2.clone() norm_grid_1to2[0, ...] = 2 * norm_grid_1to2[0, ...] / (W - 1) - 1 norm_grid_1to2[1, ...] = 2 * norm_grid_1to2[1, ...] / (H - 1) - 1 - warped_image1 = F.grid_sample(warped_image2, norm_grid_1to2.unsqueeze(0).permute(0,2,3,1), mode='bilinear', padding_mode='zeros') + warped_image1 = F.grid_sample( + warped_image2, + norm_grid_1to2.unsqueeze(0).permute(0, 2, 3, 1), + mode="bilinear", + padding_mode="zeros", + ) error = torch.abs(image1 - warped_image1) confidence_map = torch.mean(error, dim=1, keepdim=True) confidence_map[confidence_map < args.thres] = 1 confidence_map[confidence_map >= args.thres] = 0 - grid_bck=grid+flow_up.squeeze()+flow_bilinear_.squeeze() - res=grid-grid_bck - res=torch.norm(res,dim=0) - mk=(res<10)&(flow_up.norm(dim=1).squeeze()>5) + grid_bck = grid + flow_up.squeeze() + flow_bilinear_.squeeze() + res = grid - grid_bck + res = torch.norm(res, dim=0) + mk = (res < 10) & (flow_up.norm(dim=1).squeeze() > 5) - pts_src=grid[:,mk] + pts_src = grid[:, mk] - pts_dst=(grid[:,mk]+flow_up.squeeze()[:,mk]) + pts_dst = grid[:, mk] + flow_up.squeeze()[:, mk] - pts_src=pts_src.permute(1,0).cpu().numpy() - pts_dst=pts_dst.permute(1,0).cpu().numpy() - indx=torch.randperm(pts_src.shape[0])[:30] + pts_src = pts_src.permute(1, 0).cpu().numpy() + pts_dst = pts_dst.permute(1, 0).cpu().numpy() + indx = torch.randperm(pts_src.shape[0])[:30] # use cv2 to draw the matches in image1 and image2 - img_new=np.zeros((H,W*2,3),dtype=np.uint8) - img_new[:,:W,:]=image1[0].permute(1,2,0).cpu().numpy() - img_new[:,W:,:]=image2[0].permute(1,2,0).cpu().numpy() + img_new = np.zeros((H, W * 2, 3), dtype=np.uint8) + img_new[:, :W, :] = image1[0].permute(1, 2, 0).cpu().numpy() + img_new[:, W:, :] = image2[0].permute(1, 2, 0).cpu().numpy() for j in indx: - cv2.line(img_new,(int(pts_src[j,0]),int(pts_src[j,1])),(int(pts_dst[j,0])+W,int(pts_dst[j,1])),(0,255,0),1) + cv2.line( + img_new, + (int(pts_src[j, 0]), int(pts_src[j, 1])), + (int(pts_dst[j, 0]) + W, int(pts_dst[j, 1])), + (0, 255, 0), + 1, + ) - cv2.imwrite(f'{args.outdir}/matches.png',img_new) + cv2.imwrite(f"{args.outdir}/matches.png", img_new) - np.save(f'{args.outdir}/{i:06d}.npy', flow_up.cpu().numpy()) + np.save(f"{args.outdir}/{i:06d}.npy", flow_up.cpu().numpy()) if args.confidence: - np.save(f'{args.outdir_conf}/{i:06d}_c.npy', confidence_map.cpu().numpy()) + np.save( + f"{args.outdir_conf}/{i:06d}_c.npy", confidence_map.cpu().numpy() + ) i += 1 - viz(image1, flow_up,f'{args.outdir}/flow_up{i:03d}.png') + viz(image1, flow_up, f"{args.outdir}/flow_up{i:03d}.png") -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', help="restore checkpoint") - parser.add_argument('--path', help="dataset for evaluation") - parser.add_argument('--outdir',help="directory for the ouput the result") - parser.add_argument('--small', action='store_true', help='use small model') - parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') - parser.add_argument('--if_mask', action='store_true', help='if using the image mask to mask the color img') - parser.add_argument('--confidence', action='store_true', help='if saving the confidence map') - parser.add_argument('--discrete', action='store_true', help='if saving the confidence map in discrete') - parser.add_argument('--thres', default=4, help='Threshold value for confidence map') - parser.add_argument('--outdir_conf', help="directory to save flow confidence") - parser.add_argument('--name', help="the name of a sequence") + parser.add_argument("--model", help="restore checkpoint") + parser.add_argument("--path", help="dataset for evaluation") + parser.add_argument("--outdir", help="directory for the ouput the result") + parser.add_argument("--small", action="store_true", help="use small model") + parser.add_argument( + "--mixed_precision", action="store_true", help="use mixed precision" + ) + parser.add_argument( + "--if_mask", + action="store_true", + help="if using the image mask to mask the color img", + ) + parser.add_argument( + "--confidence", action="store_true", help="if saving the confidence map" + ) + parser.add_argument( + "--discrete", + action="store_true", + help="if saving the confidence map in discrete", + ) + parser.add_argument("--thres", default=4, help="Threshold value for confidence map") + parser.add_argument("--outdir_conf", help="directory to save flow confidence") + parser.add_argument("--name", help="the name of a sequence") args = parser.parse_args() demo(args) diff --git a/data_preprocessing/RAFT/run_raft.sh b/data_preprocessing/RAFT/run_raft.sh index 164bffd..be789a1 100755 --- a/data_preprocessing/RAFT/run_raft.sh +++ b/data_preprocessing/RAFT/run_raft.sh @@ -1,6 +1,6 @@ -NAME=beauty_1 -ROOT_DIR=/home/xxx/code/CoDeF/all_sequences -CODE_DIR=/home/xxx/code/CoDeF/data_preprocessing/RAFT +NAME=heygen +ROOT_DIR=/root/CoDeF/all_sequences +CODE_DIR=/root/CoDeF/data_preprocessing/RAFT IMG_DIR=$ROOT_DIR/${NAME}/${NAME} FLOW_DIR=$ROOT_DIR/${NAME}/${NAME}_flow diff --git a/data_preprocessing/preproc_mask.py b/data_preprocessing/preproc_mask.py index d78a38a..3cacb15 100755 --- a/data_preprocessing/preproc_mask.py +++ b/data_preprocessing/preproc_mask.py @@ -4,17 +4,17 @@ from glob import glob from tqdm import tqdm -root_dir = '/home/xxx/code/CoDeF/all_sequences' -name = 'beauty_1' +root_dir = "/root/CoDeF/all_sequences" +name = "heygen" -msk_folder = f'{root_dir}/{name}/{name}_masks' -img_folder = f'{root_dir}/{name}/{name}' -frg_mask_folder = f'{root_dir}/{name}/{name}_masks_0' -bkg_mask_folder = f'{root_dir}/{name}/{name}_masks_1' +msk_folder = f"{root_dir}/{name}/{name}_masks" +img_folder = f"{root_dir}/{name}/{name}" +frg_mask_folder = f"{root_dir}/{name}/{name}_masks_0" +bkg_mask_folder = f"{root_dir}/{name}/{name}_masks_1" os.makedirs(frg_mask_folder, exist_ok=True) os.makedirs(bkg_mask_folder, exist_ok=True) -files = glob(msk_folder + '/*.png') +files = glob(msk_folder + "/*.png") num = len(files) for i in tqdm(range(num)): @@ -27,4 +27,4 @@ bg_mask[bg_mask == 0] = 127 bg_mask[bg_mask == 255] = 0 bg_mask[bg_mask == 127] = 255 - cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask) \ No newline at end of file + cv2.imwrite(os.path.join(bkg_mask_folder, file_n), bg_mask) diff --git a/datasets/__init__.py b/datasets/__init__.py index 5d34d1a..cdbc0b1 100755 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -1,6 +1,6 @@ from .distributed_weighted_sampler import DistributedWeightedSampler from .video_dataset import VideoDataset -dataset_dict = {'video': VideoDataset} +dataset_dict = {"video": VideoDataset} -custom_sampler_dict = {'weighted': DistributedWeightedSampler} \ No newline at end of file +custom_sampler_dict = {"weighted": DistributedWeightedSampler} diff --git a/datasets/distributed_weighted_sampler.py b/datasets/distributed_weighted_sampler.py index 26049b7..f92d0d7 100755 --- a/datasets/distributed_weighted_sampler.py +++ b/datasets/distributed_weighted_sampler.py @@ -8,7 +8,7 @@ import torch.distributed as dist -T_co = TypeVar('T_co', covariant=True) +T_co = TypeVar("T_co", covariant=True) class DistributedWeightedSampler(Sampler[T_co]): @@ -58,9 +58,16 @@ class DistributedWeightedSampler(Sampler[T_co]): ... train(loader) """ - def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, - rank: Optional[int] = None, shuffle: bool = True, - seed: int = 0, drop_last: bool = False, replacement: bool = True) -> None: + def __init__( + self, + dataset: Dataset, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + replacement: bool = True, + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -72,7 +79,8 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, if rank >= num_replicas or rank < 0: raise ValueError( "Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, num_replicas - 1)) + " [0, {}]".format(rank, num_replicas - 1) + ) self.dataset = dataset self.num_replicas = num_replicas self.rank = rank @@ -87,7 +95,8 @@ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, self.num_samples = math.ceil( # `type:ignore` is required because Dataset cannot provide a default __len__ # see NOTE in pytorch/torch/utils/data/sampler.py - (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] ) else: self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] @@ -119,14 +128,16 @@ def __iter__(self) -> Iterator[T_co]: if padding_size <= len(indices): indices += indices[:padding_size] else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + indices += (indices * math.ceil(padding_size / len(indices)))[ + :padding_size + ] else: # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] + indices = indices[: self.total_size] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples # subsample weights @@ -134,21 +145,22 @@ def __iter__(self) -> Iterator[T_co]: weights = self.weights[indices][:, 0] assert len(weights) == self.num_samples - ########################################################################### + ########################################################################### # the upper bound category number of multinomial is 2^24, to handle this we can use chunk or using random choices # subsample_balanced_indicies = torch.multinomial(weights, self.num_samples, self.replacement) - ########################################################################### + ########################################################################### # using random choices - rand_tensor = np.random.choice(range(0, len(weights)), - size=self.num_samples, - p=weights.numpy() / torch.sum(weights).numpy(), - replace=self.replacement) + rand_tensor = np.random.choice( + range(0, len(weights)), + size=self.num_samples, + p=weights.numpy() / torch.sum(weights).numpy(), + replace=self.replacement, + ) subsample_balanced_indicies = torch.from_numpy(rand_tensor) dataset_indices = torch.tensor(indices)[subsample_balanced_indicies] return iter(dataset_indices.tolist()) - def __len__(self) -> int: return self.num_samples @@ -161,4 +173,4 @@ def set_epoch(self, epoch: int) -> None: Args: epoch (int): Epoch number. """ - self.epoch = epoch \ No newline at end of file + self.epoch = epoch diff --git a/datasets/video_dataset.py b/datasets/video_dataset.py index b969ee3..942d9e0 100755 --- a/datasets/video_dataset.py +++ b/datasets/video_dataset.py @@ -8,19 +8,22 @@ import glob import cv2 + # The basic dataset of reading rays class VideoDataset(Dataset): - def __init__(self, - root_dir, - split='train', - img_wh=(504, 378), - mask_dir=None, - flow_dir=None, - canonical_wh=None, - ref_idx=None, - canonical_dir=None, - test=False): + def __init__( + self, + root_dir, + split="train", + img_wh=(504, 378), + mask_dir=None, + flow_dir=None, + canonical_wh=None, + ref_idx=None, + canonical_dir=None, + test=False, + ): self.test = test self.root_dir = root_dir self.split = split @@ -41,11 +44,11 @@ def read_meta(self): # construct grid grid = np.indices((h, w)).astype(np.float32) # normalize - grid[0,:,:] = grid[0,:,:] / h - grid[1,:,:] = grid[1,:,:] / w - self.grid = torch.from_numpy(rearrange(grid, 'c h w -> (h w) c')) + grid[0, :, :] = grid[0, :, :] / h + grid[1, :, :] = grid[1, :, :] / w + self.grid = torch.from_numpy(rearrange(grid, "c h w -> (h w) c")) warp_code = 1 - for input_image_path in sorted(glob.glob(f'{self.root_dir}/*')): + for input_image_path in sorted(glob.glob(f"{self.root_dir}/*")): print(input_image_path) all_images_path.append(input_image_path) self.ts_w.append(torch.Tensor([warp_code]).long()) @@ -55,9 +58,9 @@ def read_meta(self): h_c = self.canonical_wh[1] w_c = self.canonical_wh[0] grid_c = np.indices((h_c, w_c)).astype(np.float32) - grid_c[0,:,:] = (grid_c[0,:,:] - (h_c - h) / 2) / h - grid_c[1,:,:] = (grid_c[1,:,:] - (w_c - w) / 2) / w - self.grid_c = torch.from_numpy(rearrange(grid_c, 'c h w -> (h w) c')) + grid_c[0, :, :] = (grid_c[0, :, :] - (h_c - h) / 2) / h + grid_c[1, :, :] = (grid_c[1, :, :] - (w_c - w) / 2) / w + self.grid_c = torch.from_numpy(rearrange(grid_c, "c h w -> (h w) c")) else: self.grid_c = self.grid self.canonical_wh = self.img_wh @@ -69,14 +72,18 @@ def read_meta(self): else: self.all_flows = None - if self.split == 'train' or self.split == 'val': + if self.split == "train" or self.split == "val": if self.canonical_dir is not None: - all_images_path_ = sorted(glob.glob(f'{self.canonical_dir}/*.png')) + all_images_path_ = sorted(glob.glob(f"{self.canonical_dir}/*.png")) self.canonical_img = [] for input_image_path in all_images_path_: input_image = cv2.imread(input_image_path) input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) - input_image = cv2.resize(input_image, (self.canonical_wh[0], self.canonical_wh[1]), interpolation = cv2.INTER_AREA) + input_image = cv2.resize( + input_image, + (self.canonical_wh[0], self.canonical_wh[1]), + interpolation=cv2.INTER_AREA, + ) input_image_tensor = torch.from_numpy(input_image).float() / 256 self.canonical_img.append(input_image_tensor) self.canonical_img = torch.stack(self.canonical_img, dim=0) @@ -84,53 +91,79 @@ def read_meta(self): for input_image_path in all_images_path: input_image = cv2.imread(input_image_path) input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) - input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) + input_image = cv2.resize( + input_image, + (self.img_wh[0], self.img_wh[1]), + interpolation=cv2.INTER_AREA, + ) input_image_tensor = torch.from_numpy(input_image).float() / 256 self.all_images.append(input_image_tensor) if self.mask_dir: input_image_name = input_image_path.split("/")[-1][:-4] for i in range(len(self.mask_dir)): - input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png') - input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) + input_mask = cv2.imread( + f"{self.mask_dir[i]}/{input_image_name}.png" + ) + input_mask = cv2.resize( + input_mask, + (self.img_wh[0], self.img_wh[1]), + interpolation=cv2.INTER_AREA, + ) input_mask_tensor = torch.from_numpy(input_mask).float() / 256 self.all_masks.append(input_mask_tensor) - if self.split == 'val': + if self.split == "val": input_image = cv2.imread(all_images_path[0]) input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) - input_image = cv2.resize(input_image, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) + input_image = cv2.resize( + input_image, + (self.img_wh[0], self.img_wh[1]), + interpolation=cv2.INTER_AREA, + ) input_image_tensor = torch.from_numpy(input_image).float() / 256 self.all_images.append(input_image_tensor) if self.mask_dir: input_image_name = all_images_path[0].split("/")[-1][:-4] for i in range(len(self.mask_dir)): - input_mask = cv2.imread(f'{self.mask_dir[i]}/{input_image_name}.png') - input_mask = cv2.resize(input_mask, (self.img_wh[0], self.img_wh[1]), interpolation = cv2.INTER_AREA) + input_mask = cv2.imread( + f"{self.mask_dir[i]}/{input_image_name}.png" + ) + input_mask = cv2.resize( + input_mask, + (self.img_wh[0], self.img_wh[1]), + interpolation=cv2.INTER_AREA, + ) input_mask_tensor = torch.from_numpy(input_mask).float() / 256 self.all_masks.append(input_mask_tensor) if self.flow_dir: - for input_image_path in sorted(glob.glob(f'{self.flow_dir}/*npy')): - flow_load=np.load(input_image_path) # (1, 2, h, w) - flow_tensor=torch.from_numpy(flow_load).float()[:, [1, 0]] - flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0])) - H_,W_=flow_load.shape[-2],flow_load.shape[-1] - flow_tensor=flow_tensor.reshape(2,-1).transpose(1,0) + for input_image_path in sorted(glob.glob(f"{self.flow_dir}/*npy")): + flow_load = np.load(input_image_path) # (1, 2, h, w) + flow_tensor = torch.from_numpy(flow_load).float()[:, [1, 0]] + flow_tensor = torch.nn.functional.interpolate( + flow_tensor, size=(self.img_wh[1], self.img_wh[0]) + ) + H_, W_ = flow_load.shape[-2], flow_load.shape[-1] + flow_tensor = flow_tensor.reshape(2, -1).transpose(1, 0) flow_tensor[..., 0] /= W_ flow_tensor[..., 1] /= H_ self.all_flows.append(flow_tensor) i = 0 - for input_image_path in sorted(glob.glob(f'{self.flow_dir}_confidence/*npy')): - flow_load=np.load(input_image_path) - flow_tensor=torch.from_numpy(flow_load).float() - flow_tensor=torch.nn.functional.interpolate(flow_tensor,size=(self.img_wh[1],self.img_wh[0])) - flow_tensor=flow_tensor.reshape(1,-1).transpose(1,0) + for input_image_path in sorted( + glob.glob(f"{self.flow_dir}_confidence/*npy") + ): + flow_load = np.load(input_image_path) + flow_tensor = torch.from_numpy(flow_load).float() + flow_tensor = torch.nn.functional.interpolate( + flow_tensor, size=(self.img_wh[1], self.img_wh[0]) + ) + flow_tensor = flow_tensor.reshape(1, -1).transpose(1, 0) flow_tensor = flow_tensor.sum(dim=-1) < 0.05 self.all_flows[i][flow_tensor] = 5 - i += 1 + i += 1 - if self.split == 'val': + if self.split == "val": self.ref_idx = 0 def __len__(self): @@ -139,18 +172,45 @@ def __len__(self): return 200 * len(self.all_images) def __getitem__(self, idx): - if self.split == 'train' or self.split == 'val': + if self.split == "train" or self.split == "val": idx = idx % len(self.all_images) - sample = {'rgbs': self.all_images[idx], - 'canonical_img': self.all_images[idx] if self.canonical_dir is None else self.canonical_img, - 'ts_w': self.ts_w[idx], - 'grid': self.grid, - 'canonical_wh': self.canonical_wh, - 'img_wh': self.img_wh, - 'masks': self.all_masks[len(self.mask_dir)*idx:len(self.mask_dir)*idx+len(self.mask_dir)] if self.mask_dir else [torch.ones((self.img_wh[1], self.img_wh[0], 1))], - 'flows': self.all_flows[idx] if (idx b c h w', b=1) + gt_lpips = rearrange(gt_lpips, "(b h) w c -> b c h w", b=1) gt_lpips = torch.from_numpy(gt_lpips) predict_image_lpips = image_pred.clone().detach().cpu() * 2.0 - 1.0 - predict_image_lpips = rearrange(predict_image_lpips, '(b h) w c -> b c h w', b=1) - lpips_result = lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy() + predict_image_lpips = rearrange(predict_image_lpips, "(b h) w c -> b c h w", b=1) + lpips_result = ( + lpips_model.forward(predict_image_lpips, gt_lpips).cpu().detach().numpy() + ) return np.squeeze(lpips_result) diff --git a/models/implicit_model.py b/models/implicit_model.py index 133afa6..e9567e5 100755 --- a/models/implicit_model.py +++ b/models/implicit_model.py @@ -11,9 +11,7 @@ def init_weights(m): class TranslationField(nn.Module): - def __init__(self, D=6, W=128, - in_channels_w=8, in_channels_xyz=34, - skips=[4]): + def __init__(self, D=6, W=128, in_channels_w=8, in_channels_xyz=34, skips=[4]): """ D: number of layers for density (sigma) encoder W: number of hidden units in each layer @@ -32,9 +30,9 @@ def __init__(self, D=6, W=128, # encoding layers for i in range(D): if i == 0: - layer = nn.Linear(in_channels_xyz+self.in_channels_w, W) + layer = nn.Linear(in_channels_xyz + self.in_channels_w, W) elif i in skips: - layer = nn.Linear(W+in_channels_xyz+self.in_channels_w, W) + layer = nn.Linear(W + in_channels_xyz + self.in_channels_w, W) else: layer = nn.Linear(W, W) init_weights(layer) @@ -81,12 +79,12 @@ def __init__(self, in_channels, N_freqs, logscale=True, identity=True): self.identity = identity self.in_channels = in_channels self.funcs = [torch.sin, torch.cos] - self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) if logscale: - self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) + self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs) else: - self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) + self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) def forward(self, x): """ @@ -106,13 +104,21 @@ def forward(self, x): out = [] for freq in self.freq_bands: for func in self.funcs: - out += [func(freq*x)] + out += [func(freq * x)] return torch.cat(out, -1) class AnnealedEmbedding(nn.Module): - def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, logscale=True, identity=True): + def __init__( + self, + in_channels, + N_freqs, + annealed_step, + annealed_begin_step=0, + logscale=True, + identity=True, + ): """ Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) in_channels: number of input channels (3 for both xyz and direction) @@ -124,14 +130,14 @@ def __init__(self, in_channels, N_freqs, annealed_step, annealed_begin_step=0, l self.annealed_step = annealed_step self.annealed_begin_step = annealed_begin_step self.funcs = [torch.sin, torch.cos] - self.out_channels = in_channels*(len(self.funcs)*N_freqs+1) - self.index = torch.linspace(0, N_freqs-1, N_freqs) + self.out_channels = in_channels * (len(self.funcs) * N_freqs + 1) + self.index = torch.linspace(0, N_freqs - 1, N_freqs) self.identity = identity if logscale: - self.freq_bands = 2**torch.linspace(0, N_freqs-1, N_freqs) + self.freq_bands = 2 ** torch.linspace(0, N_freqs - 1, N_freqs) else: - self.freq_bands = torch.linspace(1, 2**(N_freqs-1), N_freqs) + self.freq_bands = torch.linspace(1, 2 ** (N_freqs - 1), N_freqs) def forward(self, x, step): """ @@ -157,20 +163,24 @@ def forward(self, x, step): if step <= self.annealed_begin_step: alpha = 0 else: - alpha = self.N_freqs * (step - self.annealed_begin_step) / float( - self.annealed_step) + alpha = ( + self.N_freqs + * (step - self.annealed_begin_step) + / float(self.annealed_step) + ) for j, freq in enumerate(self.freq_bands): - w = (1 - torch.cos( - math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2 + w = (1 - torch.cos(math.pi * torch.clamp(alpha - self.index[j], 0, 1))) / 2 for func in self.funcs: - out += [w * func(freq*x)] + out += [w * func(freq * x)] return torch.cat(out, -1) class AnnealedHash(nn.Module): - def __init__(self, in_channels, annealed_step, annealed_begin_step=0, identity=True): + def __init__( + self, in_channels, annealed_step, annealed_begin_step=0, identity=True + ): """ Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) in_channels: number of input channels (3 for both xyz and direction) @@ -207,10 +217,21 @@ def forward(self, x_embed, step): if step <= self.annealed_begin_step: alpha = 0 else: - alpha = self.N_freqs * (step - self.annealed_begin_step) / float( - self.annealed_step) - - w = (1 - torch.cos(math.pi * torch.clamp(alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1))) / 2 + alpha = ( + self.N_freqs + * (step - self.annealed_begin_step) + / float(self.annealed_step) + ) + + w = ( + 1 + - torch.cos( + math.pi + * torch.clamp( + alpha * torch.ones_like(self.index_2) - self.index_2, 0, 1 + ) + ) + ) / 2 out = x_embed * w.to(x_embed.device) @@ -218,12 +239,15 @@ def forward(self, x_embed, step): class ImplicitVideo(nn.Module): - def __init__(self, - D=8, W=256, - in_channels_xyz=34, - skips=[4], - out_channels=3, - sigmoid_offset=0): + def __init__( + self, + D=8, + W=256, + in_channels_xyz=34, + skips=[4], + out_channels=3, + sigmoid_offset=0, + ): """ D: number of layers for density (sigma) encoder W: number of hidden units in each layer @@ -247,7 +271,7 @@ def __init__(self, if i == 0: layer = nn.Linear(self.in_channels_xyz, W) elif i in skips: - layer = nn.Linear(W+self.in_channels_xyz, W) + layer = nn.Linear(W + self.in_channels_xyz, W) else: layer = nn.Linear(W, W) init_weights(layer) @@ -297,12 +321,12 @@ def forward(self, x): class ImplicitVideo_Hash(nn.Module): def __init__(self, config): super().__init__() - self.encoder = tcnn.Encoding(n_input_dims=2, - encoding_config=config["encoding"]) - self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + - 2, - n_output_dims=3, - network_config=config["network"]) + self.encoder = tcnn.Encoding(n_input_dims=2, encoding_config=config["encoding"]) + self.decoder = tcnn.Network( + n_input_dims=self.encoder.n_output_dims + 2, + n_output_dims=3, + network_config=config["network"], + ) def forward(self, x): input = x @@ -316,17 +340,20 @@ def forward(self, x): class Deform_Hash3d(nn.Module): def __init__(self, config): super().__init__() - self.encoder = tcnn.Encoding(n_input_dims=3, - encoding_config=config["encoding_deform3d"]) - self.decoder = tcnn.Network(n_input_dims=self.encoder.n_output_dims + 3, - n_output_dims=2, - network_config=config["network_deform"]) + self.encoder = tcnn.Encoding( + n_input_dims=3, encoding_config=config["encoding_deform3d"] + ) + self.decoder = tcnn.Network( + n_input_dims=self.encoder.n_output_dims + 3, + n_output_dims=2, + network_config=config["network_deform"], + ) def forward(self, x, step=0, aneal_func=None): input = x input = self.encoder(input) if aneal_func is not None: - input = torch.cat([x, aneal_func(input,step)], dim=-1) + input = torch.cat([x, aneal_func(input, step)], dim=-1) else: input = torch.cat([x, input], dim=-1) @@ -341,7 +368,7 @@ def __init__(self, config): super().__init__() self.Deform_Hash3d = Deform_Hash3d(config) - def forward(self, xyt_norm, step=0,aneal_func=None): - x = self.Deform_Hash3d(xyt_norm,step=step, aneal_func=aneal_func) + def forward(self, xyt_norm, step=0, aneal_func=None): + x = self.Deform_Hash3d(xyt_norm, step=step, aneal_func=aneal_func) return x diff --git a/opt.py b/opt.py index b62ccb9..ab4c132 100755 --- a/opt.py +++ b/opt.py @@ -6,155 +6,280 @@ def get_opts(): parser = argparse.ArgumentParser() # General Setttings - parser.add_argument('--root_dir', type=str, default='Batman_masked_frames', - help='root directory of dataset') - parser.add_argument('--canonical_dir', type=str, default=None, - help='directory of canonical dataset') + parser.add_argument( + "--root_dir", + type=str, + default="Batman_masked_frames", + help="root directory of dataset", + ) + parser.add_argument( + "--canonical_dir", type=str, default=None, help="directory of canonical dataset" + ) # support multiple mask as input (each mask has different deformation fields) - parser.add_argument('--mask_dir', nargs="+", type=str, default=None, - help='mask of the dataset') - parser.add_argument('--flow_dir', type=str, - default=None, - help='masks of dataset') - parser.add_argument('--dataset_name', type=str, default='video', - choices=['video'], - help='which dataset to train/val') - parser.add_argument('--img_wh', nargs="+", type=int, default=[842, 512], - help='resolution (img_w, img_h) of the full image') - parser.add_argument('--canonical_wh', nargs="+", type=int, default=None, - help='default same as the img_wh, can be set to a larger range to include more content') - parser.add_argument('--ref_idx', type=int, default=None, - help='manually select a frame as reference (for rigid movement)') + parser.add_argument( + "--mask_dir", nargs="+", type=str, default=None, help="mask of the dataset" + ) + parser.add_argument("--flow_dir", type=str, default=None, help="masks of dataset") + parser.add_argument( + "--dataset_name", + type=str, + default="video", + choices=["video"], + help="which dataset to train/val", + ) + parser.add_argument( + "--img_wh", + nargs="+", + type=int, + default=[842, 512], + help="resolution (img_w, img_h) of the full image", + ) + parser.add_argument( + "--canonical_wh", + nargs="+", + type=int, + default=None, + help="default same as the img_wh, can be set to a larger range to include more content", + ) + parser.add_argument( + "--ref_idx", + type=int, + default=None, + help="manually select a frame as reference (for rigid movement)", + ) # Deformation Setting - parser.add_argument('--encode_w', default=False, action="store_true", - help='whether to apply warping') + parser.add_argument( + "--encode_w", + default=False, + action="store_true", + help="whether to apply warping", + ) # Training Setttings - parser.add_argument('--batch_size', type=int, default=1, - help='batch size') - parser.add_argument('--num_steps', type=int, default=10000, - help='number of training epochs') - parser.add_argument('--valid_iters', type=int, default=30, - help='valid iters for each epoch') - parser.add_argument('--valid_batches', type=int, default=0, - help='valid batches for each valid process') - parser.add_argument('--save_model_iters', type=int, default=5000, - help='iterations to save the models') - parser.add_argument('--gpus', nargs="+", type=int, default=[0], - help='gpu devices') + parser.add_argument("--batch_size", type=int, default=1, help="batch size") + parser.add_argument( + "--num_steps", type=int, default=10000, help="number of training epochs" + ) + parser.add_argument( + "--valid_iters", type=int, default=30, help="valid iters for each epoch" + ) + parser.add_argument( + "--valid_batches", + type=int, + default=0, + help="valid batches for each valid process", + ) + parser.add_argument( + "--save_model_iters", + type=int, + default=5000, + help="iterations to save the models", + ) + parser.add_argument("--gpus", nargs="+", type=int, default=[0], help="gpu devices") # Test Setttings - parser.add_argument('--test', default=False, action="store_true", - help='whether to disable identity') + parser.add_argument( + "--test", default=False, action="store_true", help="whether to disable identity" + ) # Model Save and Load - parser.add_argument('--ckpt_path', type=str, default=None, - help='pretrained checkpoint to load (including optimizers, etc)') - parser.add_argument('--prefixes_to_ignore', nargs='+', type=str, default=['loss'], - help='the prefixes to ignore in the checkpoint state dict') - parser.add_argument('--weight_path', type=str, default=None, - help='pretrained model weight to load (do not load optimizers, etc)') - parser.add_argument('--model_save_path', type=str, default='ckpts', - help='save checkpoint to') - parser.add_argument('--log_save_path', type=str, default='logs', - help='save log to') - parser.add_argument('--exp_name', type=str, default='exp', - help='experiment name') + parser.add_argument( + "--ckpt_path", + type=str, + default=None, + help="pretrained checkpoint to load (including optimizers, etc)", + ) + parser.add_argument( + "--prefixes_to_ignore", + nargs="+", + type=str, + default=["loss"], + help="the prefixes to ignore in the checkpoint state dict", + ) + parser.add_argument( + "--weight_path", + type=str, + default=None, + help="pretrained model weight to load (do not load optimizers, etc)", + ) + parser.add_argument( + "--model_save_path", type=str, default="ckpts", help="save checkpoint to" + ) + parser.add_argument("--log_save_path", type=str, default="logs", help="save log to") + parser.add_argument("--exp_name", type=str, default="exp", help="experiment name") # Optimize Settings - parser.add_argument('--optimizer', type=str, default='adam', - help='optimizer type', - choices=['sgd', 'adam', 'radam', 'ranger']) - parser.add_argument('--lr', type=float, default=5e-4, - help='learning rate') - parser.add_argument('--momentum', type=float, default=0.9, - help='learning rate momentum') - parser.add_argument('--weight_decay', type=float, default=0, - help='weight decay') - parser.add_argument('--lr_scheduler', type=str, default='steplr', - help='scheduler type', - choices=['steplr', 'cosine', 'poly', 'exponential']) + parser.add_argument( + "--optimizer", + type=str, + default="adam", + help="optimizer type", + choices=["sgd", "adam", "radam", "ranger"], + ) + parser.add_argument("--lr", type=float, default=5e-4, help="learning rate") + parser.add_argument( + "--momentum", type=float, default=0.9, help="learning rate momentum" + ) + parser.add_argument("--weight_decay", type=float, default=0, help="weight decay") + parser.add_argument( + "--lr_scheduler", + type=str, + default="steplr", + help="scheduler type", + choices=["steplr", "cosine", "poly", "exponential"], + ) #### params for steplr #### - parser.add_argument('--decay_step', nargs='+', type=int, - default=[2500, 5000, 7500], - help='scheduler decay step') - parser.add_argument('--decay_gamma', type=float, default=0.5, - help='learning rate decay amount') + parser.add_argument( + "--decay_step", + nargs="+", + type=int, + default=[2500, 5000, 7500], + help="scheduler decay step", + ) + parser.add_argument( + "--decay_gamma", type=float, default=0.5, help="learning rate decay amount" + ) #### params for warmup, only applied when optimizer == 'sgd' or 'adam' - parser.add_argument('--warmup_multiplier', type=float, default=1.0, - help='lr is multiplied by this factor after --warmup_epochs') - parser.add_argument('--warmup_epochs', type=int, default=0, - help='Gradually warm-up(increasing) learning rate in optimizer') + parser.add_argument( + "--warmup_multiplier", + type=float, + default=1.0, + help="lr is multiplied by this factor after --warmup_epochs", + ) + parser.add_argument( + "--warmup_epochs", + type=int, + default=0, + help="Gradually warm-up(increasing) learning rate in optimizer", + ) ##### annealed positional encoding ###### - parser.add_argument('--annealed', default=False, action="store_true", - help='whether to apply annealed positional encoding (Only in the warping field)') - parser.add_argument('--annealed_begin_step', type=int, default=0, - help='annealed step to begin for positional encoding') - parser.add_argument('--annealed_step', type=int, default=5000, - help='maximum annealed step for positional encoding') + parser.add_argument( + "--annealed", + default=False, + action="store_true", + help="whether to apply annealed positional encoding (Only in the warping field)", + ) + parser.add_argument( + "--annealed_begin_step", + type=int, + default=0, + help="annealed step to begin for positional encoding", + ) + parser.add_argument( + "--annealed_step", + type=int, + default=5000, + help="maximum annealed step for positional encoding", + ) ##### Additional losses ###### - parser.add_argument('--flow_loss', type=float, default=None, - help='optical flow loss weight') - parser.add_argument('--bg_loss', type=float, default=None, - help='regularize the rest part of each object ') - parser.add_argument('--grad_loss', type=float, default=0.1, - help='image gradient loss weight') - parser.add_argument('--flow_step', type=int, default=-1, - help='Step to begin to perform flow loss.') - parser.add_argument('--ref_step', type=int, default=-1, - help='Step to stop reference frame loss.') - parser.add_argument('--self_bg', type=bool_parser, default=False, - help='Whether to use self background as bg loss.') + parser.add_argument( + "--flow_loss", type=float, default=None, help="optical flow loss weight" + ) + parser.add_argument( + "--bg_loss", + type=float, + default=None, + help="regularize the rest part of each object ", + ) + parser.add_argument( + "--grad_loss", type=float, default=0.1, help="image gradient loss weight" + ) + parser.add_argument( + "--flow_step", type=int, default=-1, help="Step to begin to perform flow loss." + ) + parser.add_argument( + "--ref_step", type=int, default=-1, help="Step to stop reference frame loss." + ) + parser.add_argument( + "--self_bg", + type=bool_parser, + default=False, + help="Whether to use self background as bg loss.", + ) ##### Special cases: for black-dominated images - parser.add_argument('--sigmoid_offset', type=float, default=0, - help='whether to process balck-dominated images.') + parser.add_argument( + "--sigmoid_offset", + type=float, + default=0, + help="whether to process balck-dominated images.", + ) # Other miscellaneous settings. - parser.add_argument('--save_deform', type=bool_parser, default=False, - help='Whether to save deformation field or not.') - parser.add_argument('--save_video', type=bool_parser, default=True, - help='Whether to save video or not.') - parser.add_argument('--fps', type=int, default=30, - help='FPS of the saved video.') + parser.add_argument( + "--save_deform", + type=bool_parser, + default=False, + help="Whether to save deformation field or not.", + ) + parser.add_argument( + "--save_video", + type=bool_parser, + default=True, + help="Whether to save video or not.", + ) + parser.add_argument("--fps", type=int, default=30, help="FPS of the saved video.") # Network settings for PE. - parser.add_argument('--deform_D', type=int, default=6, - help='The depth of deformation field MLP.') - parser.add_argument('--deform_W', type=int, default=128, - help='The width of deformation field MLP.') - parser.add_argument('--vid_D', type=int, default=8, - help='The depth of implicit video MLP.') - parser.add_argument('--vid_W', type=int, default=256, - help='The width of implicit video MLP.') - parser.add_argument('--N_vocab_w', type=int, default=200, - help='number of vocabulary for warp code in the dataset for nn.Embedding') - parser.add_argument('--N_w', type=int, default=8, - help='embeddings size for warping') - parser.add_argument('--N_xyz_w', nargs="+", type=int, default=[8, 8], - help='positional encoding frequency of deformation field') + parser.add_argument( + "--deform_D", type=int, default=6, help="The depth of deformation field MLP." + ) + parser.add_argument( + "--deform_W", type=int, default=128, help="The width of deformation field MLP." + ) + parser.add_argument( + "--vid_D", type=int, default=8, help="The depth of implicit video MLP." + ) + parser.add_argument( + "--vid_W", type=int, default=256, help="The width of implicit video MLP." + ) + parser.add_argument( + "--N_vocab_w", + type=int, + default=200, + help="number of vocabulary for warp code in the dataset for nn.Embedding", + ) + parser.add_argument( + "--N_w", type=int, default=8, help="embeddings size for warping" + ) + parser.add_argument( + "--N_xyz_w", + nargs="+", + type=int, + default=[8, 8], + help="positional encoding frequency of deformation field", + ) # Network settings for Hash, please see details in configs/hash.json - parser.add_argument('--vid_hash', type=bool_parser, default=False, - help='Whether to use hash encoding in implicit video system.') - parser.add_argument('--deform_hash', type=bool_parser, default=False, - help='Whether to use hash encoding in deformation field.') + parser.add_argument( + "--vid_hash", + type=bool_parser, + default=False, + help="Whether to use hash encoding in implicit video system.", + ) + parser.add_argument( + "--deform_hash", + type=bool_parser, + default=False, + help="Whether to use hash encoding in deformation field.", + ) # Config files - parser.add_argument('--config', type=str, default=None, - help='path to the YAML config file.') + parser.add_argument( + "--config", type=str, default=None, help="path to the YAML config file." + ) args = parser.parse_args() if args.config is not None: - with open(args.config, 'r') as f: + with open(args.config, "r") as f: config = yaml.safe_load(f) args_dict = vars(args) args_dict.update(config) @@ -170,8 +295,8 @@ def bool_parser(arg): return arg if arg is None: return False - if arg.lower() in ['1', 'true', 't', 'yes', 'y']: + if arg.lower() in ["1", "true", "t", "yes", "y"]: return True - if arg.lower() in ['0', 'false', 'f', 'no', 'n']: + if arg.lower() in ["0", "false", "f", "no", "n"]: return False - raise ValueError(f'`{arg}` cannot be converted to boolean!') + raise ValueError(f"`{arg}` cannot be converted to boolean!") diff --git a/prepare_video.sh b/prepare_video.sh new file mode 100644 index 0000000..616cb25 --- /dev/null +++ b/prepare_video.sh @@ -0,0 +1,2 @@ +ffmpeg -i $1 -vf "crop=1080:1080:420:0" cropped.mp4 +ffmpeg -i cropped.mp4 -start_number 0 -vf "fps=25" /root/CoDeF/all_sequences/heygen/heygen/%05d.png \ No newline at end of file diff --git a/resize_image.py b/resize_image.py new file mode 100644 index 0000000..b14e779 --- /dev/null +++ b/resize_image.py @@ -0,0 +1,13 @@ +from PIL import Image + + +path = "/root/CoDeF/all_sequences/heygen/base_control/canonical_0.png" + +# Open the original image +original_image = Image.open(path) + +# Resize the image +resized_image = original_image.resize((1080, 1080)) + +# Save the resized image +resized_image.save(path) diff --git a/scripts/test_canonical.sh b/scripts/test_canonical.sh index a60140a..33b27b7 100755 --- a/scripts/test_canonical.sh +++ b/scripts/test_canonical.sh @@ -1,22 +1,19 @@ GPUS=0 -NAME=scene_0 +NAME=heygen EXP_NAME=base ROOT_DIRECTORY="all_sequences/$NAME/$NAME" LOG_SAVE_PATH="logs/test_all_sequences/$NAME" -MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" - CANONICAL_DIR="all_sequences/${NAME}/${EXP_NAME}_control" -WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt -# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt +# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt +WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=200000.ckpt python train.py --test --encode_w \ --root_dir $ROOT_DIRECTORY \ --log_save_path $LOG_SAVE_PATH \ - --mask_dir $MASK_DIRECTORY \ --weight_path $WEIGHT_PATH \ --gpus $GPUS \ --canonical_dir $CANONICAL_DIR \ diff --git a/scripts/test_multi.sh b/scripts/test_multi.sh index 1beeda7..c491b63 100755 --- a/scripts/test_multi.sh +++ b/scripts/test_multi.sh @@ -1,20 +1,17 @@ GPUS=0 -NAME=scene_0 +NAME=heygen EXP_NAME=base ROOT_DIRECTORY="all_sequences/$NAME/$NAME" LOG_SAVE_PATH="logs/test_all_sequences/$NAME" -MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" - -WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt -# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=10000.ckpt +# WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/${NAME}.ckpt +WEIGHT_PATH=ckpts/all_sequences/$NAME/${EXP_NAME}/step=200000.ckpt python train.py --test --encode_w \ --root_dir $ROOT_DIRECTORY \ --log_save_path $LOG_SAVE_PATH \ - --mask_dir $MASK_DIRECTORY \ --weight_path $WEIGHT_PATH \ --gpus $GPUS \ --config configs/${NAME}/${EXP_NAME}.yaml \ diff --git a/scripts/train_multi.sh b/scripts/train_multi.sh index 297457f..66459e7 100755 --- a/scripts/train_multi.sh +++ b/scripts/train_multi.sh @@ -1,19 +1,17 @@ GPUS=0 -NAME=scene_0 +NAME=heygen EXP_NAME=base ROOT_DIRECTORY="all_sequences/$NAME/$NAME" MODEL_SAVE_PATH="ckpts/all_sequences/$NAME" LOG_SAVE_PATH="logs/all_sequences/$NAME" -MASK_DIRECTORY="all_sequences/$NAME/${NAME}_masks_0 all_sequences/$NAME/${NAME}_masks_1" FLOW_DIRECTORY="all_sequences/$NAME/${NAME}_flow" python train.py --root_dir $ROOT_DIRECTORY \ --model_save_path $MODEL_SAVE_PATH \ --log_save_path $LOG_SAVE_PATH \ - --mask_dir $MASK_DIRECTORY \ --flow_dir $FLOW_DIRECTORY \ --gpus $GPUS \ --encode_w --annealed \ diff --git a/train.py b/train.py index 8f30b3b..50cd154 100755 --- a/train.py +++ b/train.py @@ -42,15 +42,15 @@ class ImplicitVideoSystem(LightningModule): def __init__(self, hparams): super(ImplicitVideoSystem, self).__init__() self.save_hyperparameters(hparams) - self.color_loss = loss_dict['mse'](coef=1) + self.color_loss = loss_dict["mse"](coef=1) if hparams.save_video: self.video_visualizer = VideoVisualizer(fps=hparams.fps) self.raw_video_visualizer = VideoVisualizer(fps=hparams.fps) self.dual_video_visualizer = VideoVisualizer(fps=hparams.fps) - self.models_to_train=[] + self.models_to_train = [] self.embedding_xyz = Embedding(2, 8) - self.embeddings = {'xyz': self.embedding_xyz} + self.embeddings = {"xyz": self.embedding_xyz} self.models = {} # Construct normalized meshgrid. @@ -69,8 +69,8 @@ def __init__(self, hparams): # Multiple deformation MLP. # Progressive Training for the Deformation (Annealed PE). # No trainable parameters. - self.embeddings['xyz_w'] = [] - assert (isinstance(self.hparams.N_xyz_w, list)) + self.embeddings["xyz_w"] = [] + assert isinstance(self.hparams.N_xyz_w, list) in_channels_xyz = [] for i in range(self.num_models): N_xyz_w = self.hparams.N_xyz_w[i] @@ -78,59 +78,62 @@ def __init__(self, hparams): if hparams.annealed: if hparams.deform_hash: self.embedding_hash = AnnealedHash( - in_channels=2, - annealed_step=hparams.annealed_step, - annealed_begin_step=hparams.annealed_begin_step) - self.embeddings['aneal_hash'] = self.embedding_hash + in_channels=2, + annealed_step=hparams.annealed_step, + annealed_begin_step=hparams.annealed_begin_step, + ) + self.embeddings["aneal_hash"] = self.embedding_hash else: self.embedding_xyz_w = AnnealedEmbedding( in_channels=2, N_freqs=N_xyz_w, annealed_step=hparams.annealed_step, - annealed_begin_step=hparams.annealed_begin_step) - self.embeddings['xyz_w'] += [self.embedding_xyz_w] + annealed_begin_step=hparams.annealed_begin_step, + ) + self.embeddings["xyz_w"] += [self.embedding_xyz_w] else: self.embedding_xyz_w = Embedding(2, N_xyz_w) - self.embeddings['xyz_w'] += [self.embedding_xyz_w] + self.embeddings["xyz_w"] += [self.embedding_xyz_w] for i in range(self.num_models): embedding_w = torch.nn.Embedding(hparams.N_vocab_w, hparams.N_w) torch.nn.init.uniform_(embedding_w.weight, -0.05, 0.05) - load_ckpt(embedding_w, hparams.weight_path, model_name=f'w_{i}') - self.embeddings[f'w_{i}'] = embedding_w - self.models_to_train += [self.embeddings[f'w_{i}']] + load_ckpt(embedding_w, hparams.weight_path, model_name=f"w_{i}") + self.embeddings[f"w_{i}"] = embedding_w + self.models_to_train += [self.embeddings[f"w_{i}"]] # Add warping field mlp. if hparams.deform_hash: - with open('configs/hash.json') as f: + with open("configs/hash.json") as f: config = json.load(f) warping_field = Deform_Hash3d_Warp(config=config) else: warping_field = TranslationField( D=self.hparams.deform_D, W=self.hparams.deform_W, - in_channels_xyz=in_channels_xyz[i]) + in_channels_xyz=in_channels_xyz[i], + ) - load_ckpt(warping_field, - hparams.weight_path, - model_name=f'warping_field_{i}') - self.models[f'warping_field_{i}'] = warping_field + load_ckpt( + warping_field, hparams.weight_path, model_name=f"warping_field_{i}" + ) + self.models[f"warping_field_{i}"] = warping_field # Set up the canonical model. if hparams.canonical_dir is None: for i in range(self.num_models): if hparams.vid_hash: - with open('configs/hash.json') as f: + with open("configs/hash.json") as f: config = json.load(f) implicit_video = ImplicitVideo_Hash(config=config) else: implicit_video = ImplicitVideo( D=hparams.vid_D, W=hparams.vid_W, - sigmoid_offset=hparams.sigmoid_offset) - load_ckpt(implicit_video, hparams.weight_path, - f'implicit_video_{i}') - self.models[f'implicit_video_{i}'] = implicit_video + sigmoid_offset=hparams.sigmoid_offset, + ) + load_ckpt(implicit_video, hparams.weight_path, f"implicit_video_{i}") + self.models[f"implicit_video_{i}"] = implicit_video for key in self.embeddings: setattr(self, key, self.embeddings[key]) @@ -144,50 +147,46 @@ def deform_pts(self, ts_w, grid, encode_w, step=0, i=0): ts_w_norm = ts_w / self.seq_len ts_w_norm = ts_w_norm.repeat(grid.shape[0], 1) input_xyt = torch.cat([grid, ts_w_norm], dim=-1) - if 'aneal_hash' in self.embeddings.keys(): - deform = self.models[f'warping_field_{i}']( - input_xyt, - step=step, - aneal_func=self.embeddings['aneal_hash']) + if "aneal_hash" in self.embeddings.keys(): + deform = self.models[f"warping_field_{i}"]( + input_xyt, step=step, aneal_func=self.embeddings["aneal_hash"] + ) else: - deform = self.models[f'warping_field_{i}'](input_xyt) + deform = self.models[f"warping_field_{i}"](input_xyt) if encode_w: deformed_grid = deform + grid else: deformed_grid = grid else: if encode_w: - e_w = self.embeddings[f'w_{i}'](repeat(ts_w, 'b n -> (b l) n ', - l=grid.shape[0])[:, 0]) + e_w = self.embeddings[f"w_{i}"]( + repeat(ts_w, "b n -> (b l) n ", l=grid.shape[0])[:, 0] + ) # Whether to use annealed positional encoding. if self.hparams.annealed: - pe_w = self.embeddings['xyz_w'][i](grid, step) + pe_w = self.embeddings["xyz_w"][i](grid, step) else: - pe_w = self.embeddings['xyz_w'][i](grid) + pe_w = self.embeddings["xyz_w"][i](grid) # Warping field type. - deform = self.models[f'warping_field_{i}'](torch.cat( - [e_w, pe_w], 1)) + deform = self.models[f"warping_field_{i}"](torch.cat([e_w, pe_w], 1)) deformed_grid = deform + grid else: deformed_grid = grid return deformed_grid - def forward(self, - ts_w, - grid, - encode_w, - step=0, - flows=None): + def forward(self, ts_w, grid, encode_w, step=0, flows=None): # grid -> positional encoding # ts_w -> embedding - grid = rearrange(grid, 'b n c -> (b n) c') + grid = rearrange(grid, "b n c -> (b n) c") results_list = [] flow_loss_list = [] deform_list = [] for i in range(self.num_models): - deformed_grid = self.deform_pts(ts_w, grid, encode_w, step, i) # [batch * num_pixels, 2] + deformed_grid = self.deform_pts( + ts_w, grid, encode_w, step, i + ) # [batch * num_pixels, 2] deform_list.append(deformed_grid) # Compute optical flow loss. flow_loss = 0 @@ -195,13 +194,14 @@ def forward(self, if flows.max() > -1e2 and step > self.hparams.flow_step: grid_new = grid + flows.squeeze(0) deformed_grid_new = self.deform_pts( - ts_w + 1, grid_new, encode_w, step, i) + ts_w + 1, grid_new, encode_w, step, i + ) flow_loss = (deformed_grid_new, deformed_grid) flow_loss_list.append(flow_loss) if self.hparams.vid_hash: pe_deformed_grid = (deformed_grid + 0.3) / 1.6 else: - pe_deformed_grid = self.embeddings['xyz'](deformed_grid) + pe_deformed_grid = self.embeddings["xyz"](deformed_grid) if not self.training and self.hparams.canonical_dir is not None: w, h = self.img_wh canonical_img = self.canonical_img.squeeze(0) @@ -212,19 +212,18 @@ def forward(self, if len(canonical_img.shape) == 3: canonical_img = canonical_img.unsqueeze(0) results = torch.nn.functional.grid_sample( - canonical_img[i:i + 1].permute(0, 3, 1, 2), + canonical_img[i : i + 1].permute(0, 3, 1, 2), grid_new.unsqueeze(1).unsqueeze(0), - mode='bilinear', - padding_mode='border') - results = results.squeeze().permute(1,0) + mode="bilinear", + padding_mode="border", + ) + results = results.squeeze().permute(1, 0) else: - results = self.models[f'implicit_video_{i}'](pe_deformed_grid) + results = self.models[f"implicit_video_{i}"](pe_deformed_grid) results_list.append(results) - ret = edict(rgbs=results_list, - flow_loss=flow_loss_list, - deform=deform_list) + ret = edict(rgbs=results_list, flow_loss=flow_loss_list, deform=deform_list) return ret @@ -232,16 +231,16 @@ def setup(self, stage): if not self.hparams.test: dataset = dataset_dict[self.hparams.dataset_name] kwargs = { - 'root_dir': self.hparams.root_dir, - 'img_wh': tuple(self.hparams.img_wh), - 'mask_dir': self.hparams.mask_dir, - 'flow_dir': self.hparams.flow_dir, - 'canonical_wh': self.hparams.canonical_wh, - 'ref_idx': self.hparams.ref_idx, - 'canonical_dir': self.hparams.canonical_dir + "root_dir": self.hparams.root_dir, + "img_wh": tuple(self.hparams.img_wh), + "mask_dir": self.hparams.mask_dir, + "flow_dir": self.hparams.flow_dir, + "canonical_wh": self.hparams.canonical_wh, + "ref_idx": self.hparams.ref_idx, + "canonical_dir": self.hparams.canonical_dir, } - self.train_dataset = dataset(split='train', **kwargs) - self.val_dataset = dataset(split='val', **kwargs) + self.train_dataset = dataset(split="train", **kwargs) + self.val_dataset = dataset(split="val", **kwargs) def configure_optimizers(self): self.optimizer = get_optimizer(self.hparams, self.models_to_train) @@ -251,11 +250,13 @@ def configure_optimizers(self): def train_dataloader(self): sampler = DistributedSampler(self.train_dataset, shuffle=True) - return DataLoader(self.train_dataset, - num_workers=4, - batch_size=self.hparams.batch_size, - sampler=sampler, - pin_memory=True) + return DataLoader( + self.train_dataset, + num_workers=4, + batch_size=self.hparams.batch_size, + sampler=sampler, + pin_memory=True, + ) def val_dataloader(self): return DataLoader( @@ -263,64 +264,68 @@ def val_dataloader(self): shuffle=False, num_workers=4, batch_size=1, # validate one image (H*W rays) at a time. - pin_memory=True) + pin_memory=True, + ) def test_dataloader(self): dataset = dataset_dict[self.hparams.dataset_name] kwargs = { - 'root_dir': self.hparams.root_dir, - 'img_wh': tuple(self.hparams.img_wh), - 'mask_dir': self.hparams.mask_dir, - 'canonical_wh': self.hparams.canonical_wh, - 'canonical_dir': self.hparams.canonical_dir, - 'test': self.hparams.test + "root_dir": self.hparams.root_dir, + "img_wh": tuple(self.hparams.img_wh), + "mask_dir": self.hparams.mask_dir, + "canonical_wh": self.hparams.canonical_wh, + "canonical_dir": self.hparams.canonical_dir, + "test": self.hparams.test, } - self.train_dataset = dataset(split='train', **kwargs) + self.train_dataset = dataset(split="train", **kwargs) return DataLoader( self.train_dataset, shuffle=False, num_workers=4, batch_size=1, # validate one image (H*W rays) at a time. - pin_memory=True) + pin_memory=True, + ) def training_step(self, batch, batch_idx): # Fetch training data. - rgbs = batch['rgbs'] - ts_w = batch['ts_w'] - grid = batch['grid'] - mk = batch['masks'] - flows = batch['flows'] - grid_c = batch['grid_c'] - ref_batch = batch['reference'] - self.seq_len = batch['seq_len'] + rgbs = batch["rgbs"] + ts_w = batch["ts_w"] + grid = batch["grid"] + mk = batch["masks"] + flows = batch["flows"] + grid_c = batch["grid_c"] + ref_batch = batch["reference"] + self.seq_len = batch["seq_len"] loss = 0 - rgbs_flattend = rearrange(rgbs, 'b h w c -> (b h w) c') + rgbs_flattend = rearrange(rgbs, "b h w c -> (b h w) c") # Forward the model. - ret = self.forward(ts_w, - grid, - self.hparams.encode_w, - self.global_step, - flows=flows) + ret = self.forward( + ts_w, grid, self.hparams.encode_w, self.global_step, flows=flows + ) # Mannually set a reference frame. - if self.hparams.ref_step < 0: self.hparams.step = 1e10 - if (self.hparams.ref_idx is not None - and self.global_step < self.hparams.ref_step): - rgbs_c_flattend = rearrange(ref_batch[0], - 'b h w c -> (b h w) c') + if self.hparams.ref_step < 0: + self.hparams.step = 1e10 + if ( + self.hparams.ref_idx is not None + and self.global_step < self.hparams.ref_step + ): + rgbs_c_flattend = rearrange(ref_batch[0], "b h w c -> (b h w) c") ret_c = self(ts_w, grid, False, self.global_step, flows=flows) # Loss computation. for i in range(self.num_models): results = ret.rgbs[i] - mk_t = rearrange(mk[i], 'b h w c -> (b h w) c') + mk_t = rearrange(mk[i], "b h w c -> (b h w) c") mk_t = mk_t.sum(dim=-1) > 0.05 - if (self.hparams.ref_idx is not None - and self.global_step < self.hparams.ref_step): - mk_c_t = rearrange(ref_batch[1][i], 'b h w c -> (b h w) c') + if ( + self.hparams.ref_idx is not None + and self.global_step < self.hparams.ref_step + ): + mk_c_t = rearrange(ref_batch[1][i], "b h w c -> (b h w) c") mk_c_t = mk_c_t.sum(dim=-1) > 0.05 # Background regularization. @@ -329,84 +334,84 @@ def training_step(self, batch, batch_idx): if self.hparams.self_bg: grid_flattened = rgbs_flattend else: - grid_flattened = rearrange(grid, 'b n c -> (b n) c') + grid_flattened = rearrange(grid, "b n c -> (b n) c") grid_flattened = torch.cat( - [grid_flattened, grid_flattened[:, :1]], -1) + [grid_flattened, grid_flattened[:, :1]], -1 + ) if self.hparams.bg_loss and self.hparams.mask_dir: loss = loss + self.hparams.bg_loss * self.color_loss( - results[mk1], grid_flattened[mk1]) + results[mk1], grid_flattened[mk1] + ) # MSE color loss. - loss = loss + self.color_loss(results[mk_t], - rgbs_flattend[mk_t]) + loss = loss + self.color_loss(results[mk_t], rgbs_flattend[mk_t]) # Image gradient loss. - img_pred = rearrange(results, - '(b h w) c -> b h w c', - b=1, - h=self.h, - w=self.w) - rgbs_gt = rearrange(rgbs_flattend, - '(b h w) c -> b h w c', - b=1, - h=self.h, - w=self.w) - mk_t_re = rearrange(mk_t, - '(b h w c) -> b h w c', - b=1, - h=self.h, - w=self.w) - grad_loss = compute_gradient_loss(rgbs_gt.permute(0, 3, 1, 2), - img_pred.permute(0, 3, 1, 2), - mask=mk_t_re.permute(0, 3, 1, 2)) + img_pred = rearrange( + results, "(b h w) c -> b h w c", b=1, h=self.h, w=self.w + ) + rgbs_gt = rearrange( + rgbs_flattend, "(b h w) c -> b h w c", b=1, h=self.h, w=self.w + ) + mk_t_re = rearrange(mk_t, "(b h w c) -> b h w c", b=1, h=self.h, w=self.w) + grad_loss = compute_gradient_loss( + rgbs_gt.permute(0, 3, 1, 2), + img_pred.permute(0, 3, 1, 2), + mask=mk_t_re.permute(0, 3, 1, 2), + ) loss = loss + grad_loss * self.hparams.grad_loss # Optical flow loss. if ret.flow_loss[0] != 0: - mk_flow_t = torch.logical_and(mk_t, flows[0].sum(dim=-1)< 3) - loss = loss + torch.nn.functional.l1_loss( - ret.flow_loss[i][0][mk_flow_t], ret.flow_loss[i][1] - [mk_flow_t]) * self.hparams.flow_loss + mk_flow_t = torch.logical_and(mk_t, flows[0].sum(dim=-1) < 3) + loss = ( + loss + + torch.nn.functional.l1_loss( + ret.flow_loss[i][0][mk_flow_t], ret.flow_loss[i][1][mk_flow_t] + ) + * self.hparams.flow_loss + ) # Reference loss. - if (self.hparams.ref_idx is not None - and self.global_step < self.hparams.ref_step): + if ( + self.hparams.ref_idx is not None + and self.global_step < self.hparams.ref_step + ): results_c = ret_c.rgbs[i] - loss += self.color_loss(results_c[mk_c_t], - rgbs_c_flattend[mk_c_t]) + loss += self.color_loss(results_c[mk_c_t], rgbs_c_flattend[mk_c_t]) # PSNR metric. with torch.no_grad(): if i == 0: psnr_ = psnr(results[mk_t], rgbs_flattend[mk_t]) - self.log('lr', get_learning_rate(self.optimizer), prog_bar=True) - self.log('train/loss', loss, prog_bar=True) - self.log('train/psnr', psnr_, prog_bar=True) + self.log("lr", get_learning_rate(self.optimizer), prog_bar=True) + self.log("train/loss", loss, prog_bar=True) + self.log("train/psnr", psnr_, prog_bar=True) return loss def validation_step(self, batch, batch_idx): - rgbs = batch['rgbs'] - ts_w = batch['ts_w'] - grid = batch['grid'] - mk = batch['masks'] + rgbs = batch["rgbs"] + ts_w = batch["ts_w"] + grid = batch["grid"] + mk = batch["masks"] grid_c = grid # batch['grid_c'] - self.seq_len = batch['seq_len'] + self.seq_len = batch["seq_len"] ret = self(ts_w, grid, self.hparams.encode_w, self.global_step) ret_c = self(ts_w, grid_c, False, self.global_step) log = {} W, H = self.hparams.img_wh - rgbs_flattend = rearrange(rgbs, 'b h w c -> (b h w) c') + rgbs_flattend = rearrange(rgbs, "b h w c -> (b h w) c") img_gt = rgbs_flattend.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) stack_list = [img_gt] for i in range(self.num_models): results = ret.rgbs[i] results_c = ret_c.rgbs[i] - mk_t = rearrange(mk[i], 'b h w c -> (b h w) c') + mk_t = rearrange(mk[i], "b h w c -> (b h w) c") if batch_idx == 0: results[mk_t.sum(dim=-1) <= 0.05] = 0 img = results.view(H, W, 3).permute(2, 0, 1).cpu() # (3, H, W) @@ -414,43 +419,47 @@ def validation_step(self, batch, batch_idx): stack_list.append(img) stack_list.append(img_c) - stack = torch.stack(stack_list) # (3, 3, H, W) - self.logger.experiment.add_images('val/GT_Reconstructed', stack, - self.global_step) + stack = torch.stack(stack_list) # (3, 3, H, W) + self.logger.experiment.add_images( + "val/GT_Reconstructed", stack, self.global_step + ) return log def test_step(self, batch, batch_idx): - ts_w = batch['ts_w'] - grid = batch['grid'] - mk = batch['masks'] - grid_c = batch['grid_c'] + ts_w = batch["ts_w"] + grid = batch["grid"] + mk = batch["masks"] + grid_c = batch["grid_c"] W, H = self.hparams.img_wh - self.seq_len = batch['seq_len'] + self.seq_len = batch["seq_len"] if self.hparams.canonical_dir is not None: - self.canonical_img = batch['canonical_img'] - self.img_wh = batch['img_wh'] - - save_dir = os.path.join('results', - self.hparams.root_dir.split('/')[0], - self.hparams.root_dir.split('/')[1], - self.hparams.exp_name) - sample_name = self.hparams.root_dir.split('/')[1] + self.canonical_img = batch["canonical_img"] + self.img_wh = batch["img_wh"] + + save_dir = os.path.join( + "results", + self.hparams.root_dir.split("/")[0], + self.hparams.root_dir.split("/")[1], + self.hparams.exp_name, + ) + sample_name = self.hparams.root_dir.split("/")[1] if self.hparams.canonical_dir is not None: - test_dir = f'{save_dir}_transformed' - video_name = f'{sample_name}_{self.hparams.exp_name}_transformed' + test_dir = f"{save_dir}_transformed" + video_name = f"{sample_name}_{self.hparams.exp_name}_transformed" else: - test_dir = f'{save_dir}' - video_name = f'{sample_name}_{self.hparams.exp_name}' + test_dir = f"{save_dir}" + video_name = f"{sample_name}_{self.hparams.exp_name}" Path(test_dir).mkdir(parents=True, exist_ok=True) if batch_idx > 0 and self.hparams.save_video: - self.video_visualizer.set_path(os.path.join( - test_dir, f'{video_name}.mp4')) - self.raw_video_visualizer.set_path(os.path.join( - test_dir, f'{video_name}_raw.mp4')) - self.dual_video_visualizer.set_path(os.path.join( - test_dir, f'{video_name}_dual.mp4')) + self.video_visualizer.set_path(os.path.join(test_dir, f"{video_name}.mp4")) + self.raw_video_visualizer.set_path( + os.path.join(test_dir, f"{video_name}_raw.mp4") + ) + self.dual_video_visualizer.set_path( + os.path.join(test_dir, f"{video_name}_dual.mp4") + ) if batch_idx == 0 and self.hparams.canonical_dir is None: # Save the canonical image. @@ -463,46 +472,56 @@ def test_step(self, batch, batch_idx): if batch_idx == 0 and self.hparams.canonical_dir is None: results_c = ret.rgbs[i] if self.hparams.canonical_wh: - img_c = results_c.view(self.hparams.canonical_wh[1], - self.hparams.canonical_wh[0], - 3).float().cpu().numpy() + img_c = ( + results_c.view( + self.hparams.canonical_wh[1], + self.hparams.canonical_wh[0], + 3, + ) + .float() + .cpu() + .numpy() + ) else: img_c = results_c.view(H, W, 3).float().cpu().numpy() img_c = cv2.cvtColor(img_c, cv2.COLOR_BGR2RGB) - cv2.imwrite(f'{test_dir}/canonical_{i}.png', img_c * 255) + cv2.imwrite(f"{test_dir}/canonical_{i}.png", img_c * 255) - mk_n = rearrange(mk[i], 'b h w c -> (b h w) c') + mk_n = rearrange(mk[i], "b h w c -> (b h w) c") mk_n = mk_n.sum(dim=-1) > 0.05 mk_n = mk_n.cpu().numpy() results = ret_n.rgbs[i] results = results.cpu().numpy() # (3, H, W) img[mk_n] = results[mk_n] - img = rearrange(img, '(h w) c -> h w c', h=H, w=W) + img = rearrange(img, "(h w) c -> h w c", h=H, w=W) img = img * 255 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - cv2.imwrite(f'{test_dir}/{batch_idx:05d}.png', img) + cv2.imwrite(f"{test_dir}/{batch_idx:05d}.png", img) if batch_idx > 0 and self.hparams.save_video: img = img[..., ::-1] self.video_visualizer.add(img) - rgbs = batch['rgbs'].view(H, W, 3).cpu().numpy() * 255 + rgbs = batch["rgbs"].view(H, W, 3).cpu().numpy() * 255 rgbs = rgbs.astype(np.uint8) self.raw_video_visualizer.add(rgbs) dual_img = np.concatenate((rgbs, img), axis=1) self.dual_video_visualizer.add(dual_img) if self.hparams.save_deform: - save_deform_dir = f'{test_dir}_deform' + save_deform_dir = f"{test_dir}_deform" Path(save_deform_dir).mkdir(parents=True, exist_ok=True) deformation_field = ret_n.deform[0] - deformation_field = rearrange(deformation_field, - '(h w) c -> h w c', h=H, w=W) - grid_ = rearrange(grid[0], '(h w) c -> h w c', h=H, w=W) + deformation_field = rearrange( + deformation_field, "(h w) c -> h w c", h=H, w=W + ) + grid_ = rearrange(grid[0], "(h w) c -> h w c", h=H, w=W) deformation_delta = deformation_field - grid_ - np.save(f'{save_deform_dir}/{batch_idx:05d}.npy', - deformation_delta.cpu().numpy()) + np.save( + f"{save_deform_dir}/{batch_idx:05d}.npy", + deformation_delta.cpu().numpy(), + ) def on_test_epoch_end(self): if self.hparams.save_video: @@ -515,36 +534,38 @@ def get_progress_bar_dict(self): items.pop("v_num", None) return items + def main(hparams): system = ImplicitVideoSystem(hparams) if not hparams.test: - os.makedirs(f'{hparams.model_save_path}/{hparams.exp_name}', - exist_ok=True) + os.makedirs(f"{hparams.model_save_path}/{hparams.exp_name}", exist_ok=True) checkpoint_callback = ModelCheckpoint( - dirpath=f'{hparams.model_save_path}/{hparams.exp_name}', - filename='{step:d}', - mode='max', + dirpath=f"{hparams.model_save_path}/{hparams.exp_name}", + filename="{step:d}", + mode="max", save_top_k=-1, every_n_train_steps=hparams.save_model_iters, - save_last=True) - - logger = TensorBoardLogger(save_dir=hparams.log_save_path, - name=hparams.exp_name) - - trainer = Trainer(max_steps=hparams.num_steps, - precision=16 if hparams.vid_hash == True else 32, - callbacks=[checkpoint_callback], - logger=logger, - accelerator='gpu', - devices=hparams.gpus, - num_sanity_val_steps=1, - benchmark=True, - profiler="simple" if len(hparams.gpus) == 1 else None, - val_check_interval=hparams.valid_iters, - limit_val_batches=hparams.valid_batches, - strategy="ddp_find_unused_parameters_true") + save_last=True, + ) + + logger = TensorBoardLogger(save_dir=hparams.log_save_path, name=hparams.exp_name) + + trainer = Trainer( + max_steps=hparams.num_steps, + precision=16 if hparams.vid_hash == True else 32, + callbacks=[checkpoint_callback], + logger=logger, + accelerator="gpu", + devices=hparams.gpus, + num_sanity_val_steps=1, + benchmark=True, + profiler="simple" if len(hparams.gpus) == 1 else None, + val_check_interval=hparams.valid_iters, + limit_val_batches=hparams.valid_batches, + strategy="ddp_find_unused_parameters_true", + ) if hparams.test: trainer.test(system, dataloaders=system.test_dataloader()) @@ -552,6 +573,6 @@ def main(hparams): trainer.fit(system, ckpt_path=hparams.ckpt_path) -if __name__ == '__main__': +if __name__ == "__main__": hparams = get_opts() main(hparams) diff --git a/utils/__init__.py b/utils/__init__.py index 444ac71..ee6ac86 100755 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1,7 +1,9 @@ import torch + # optimizer from torch.optim import SGD, Adam import torch_optimizer as optim + # scheduler from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import MultiStepLR @@ -23,79 +25,99 @@ def get_parameters(models): # print("is dict") for model in models.values(): parameters += get_parameters(model) - else: # models is actually a single pytorch model + else: # models is actually a single pytorch model parameters += list(models.parameters()) return parameters + def get_optimizer(hparams, models): eps = 1e-8 parameters = get_parameters(models) - if hparams.optimizer == 'sgd': - optimizer = SGD(parameters, lr=hparams.lr, - momentum=hparams.momentum, weight_decay=hparams.weight_decay) - elif hparams.optimizer == 'adam': - optimizer = Adam(parameters, lr=hparams.lr, eps=eps, - weight_decay=hparams.weight_decay) - elif hparams.optimizer == 'radam': - optimizer = optim.RAdam(parameters, lr=hparams.lr, eps=eps, - weight_decay=hparams.weight_decay) - elif hparams.optimizer == 'ranger': - optimizer = optim.Ranger(parameters, lr=hparams.lr, eps=eps, - weight_decay=hparams.weight_decay) + if hparams.optimizer == "sgd": + optimizer = SGD( + parameters, + lr=hparams.lr, + momentum=hparams.momentum, + weight_decay=hparams.weight_decay, + ) + elif hparams.optimizer == "adam": + optimizer = Adam( + parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay + ) + elif hparams.optimizer == "radam": + optimizer = optim.RAdam( + parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay + ) + elif hparams.optimizer == "ranger": + optimizer = optim.Ranger( + parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay + ) else: - raise ValueError('optimizer not recognized!') + raise ValueError("optimizer not recognized!") return optimizer + def get_scheduler(hparams, optimizer): eps = 1e-8 - if hparams.lr_scheduler == 'steplr': - scheduler = MultiStepLR(optimizer, milestones=hparams.decay_step, - gamma=hparams.decay_gamma) - elif hparams.lr_scheduler == 'cosine': + if hparams.lr_scheduler == "steplr": + scheduler = MultiStepLR( + optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma + ) + elif hparams.lr_scheduler == "cosine": scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) - elif hparams.lr_scheduler == 'poly': - scheduler = LambdaLR(optimizer, - lambda epoch: (1-epoch/hparams.num_epochs)**hparams.poly_exp) - elif hparams.lr_scheduler == 'exponential': + elif hparams.lr_scheduler == "poly": + scheduler = LambdaLR( + optimizer, + lambda epoch: (1 - epoch / hparams.num_epochs) ** hparams.poly_exp, + ) + elif hparams.lr_scheduler == "exponential": # Adaptively adjust the schedule - scheduler = LambdaLR(optimizer, - lambda step: hparams.exponent_base**(step/(2 * hparams.num_steps))) + scheduler = LambdaLR( + optimizer, + lambda step: hparams.exponent_base ** (step / (2 * hparams.num_steps)), + ) else: - raise ValueError('scheduler not recognized!') + raise ValueError("scheduler not recognized!") - if hparams.warmup_epochs > 0 and hparams.optimizer not in ['radam', 'ranger']: - scheduler = GradualWarmupScheduler(optimizer, multiplier=hparams.warmup_multiplier, - total_epoch=hparams.warmup_epochs, after_scheduler=scheduler) + if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]: + scheduler = GradualWarmupScheduler( + optimizer, + multiplier=hparams.warmup_multiplier, + total_epoch=hparams.warmup_epochs, + after_scheduler=scheduler, + ) return scheduler + def get_learning_rate(optimizer): for param_group in optimizer.param_groups: - return param_group['lr'] + return param_group["lr"] + -def extract_model_state_dict(ckpt_path, model_name='model', prefixes_to_ignore=[]): - checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) +def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]): + checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu")) checkpoint_ = {} - if 'state_dict' in checkpoint: # if it's a pytorch-lightning checkpoint - checkpoint = checkpoint['state_dict'] + if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint + checkpoint = checkpoint["state_dict"] for k, v in checkpoint.items(): if not k.startswith(model_name): continue - k = k[len(model_name)+1:] + k = k[len(model_name) + 1 :] for prefix in prefixes_to_ignore: if k.startswith(prefix): - print('ignore', k) + print("ignore", k) break else: checkpoint_[k] = v return checkpoint_ -def load_ckpt(model, ckpt_path, model_name='model', prefixes_to_ignore=[]): + +def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]): if not ckpt_path: return model_dict = model.state_dict() checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) model_dict.update(checkpoint_) model.load_state_dict(model_dict) - diff --git a/utils/image_utils.py b/utils/image_utils.py index aad5698..54deb14 100755 --- a/utils/image_utils.py +++ b/utils/image_utils.py @@ -12,10 +12,20 @@ # File extensions regarding images (not including GIFs). IMAGE_EXTENSIONS = ( - '.bmp', '.ppm', '.pgm', '.jpeg', '.jpg', '.jpe', '.jp2', '.png', '.webp', - '.tiff', '.tif' + ".bmp", + ".ppm", + ".pgm", + ".jpeg", + ".jpg", + ".jpe", + ".jp2", + ".png", + ".webp", + ".tiff", + ".tif", ) + def check_file_ext(filename, *ext_list): """Checks whether the given filename is with target extension(s). @@ -31,7 +41,7 @@ def check_file_ext(filename, *ext_list): """ if len(ext_list) == 0: return False - ext_list = [ext if ext.startswith('.') else '.' + ext for ext in ext_list] + ext_list = [ext if ext.startswith(".") else "." + ext for ext in ext_list] ext_list = [ext.lower() for ext in ext_list] basename = os.path.basename(filename) ext = os.path.splitext(basename)[1].lower() @@ -143,14 +153,16 @@ def resize_image(image, *args, **kwargs): return cv2.resize(image, *args, **kwargs) -def add_text_to_image(image, - text='', - position=None, - font=cv2.FONT_HERSHEY_TRIPLEX, - font_size=1.0, - line_type=cv2.LINE_8, - line_width=1, - color=(255, 255, 255)): +def add_text_to_image( + image, + text="", + position=None, + font=cv2.FONT_HERSHEY_TRIPLEX, + font_size=1.0, + line_type=cv2.LINE_8, + line_width=1, + color=(255, 255, 255), +): """Overlays text on given image. NOTE: The input image is assumed to be with `RGB` channel order. @@ -174,15 +186,17 @@ def add_text_to_image(image, return image _check_2d_image(image) - cv2.putText(img=image, - text=text, - org=position, - fontFace=font, - fontScale=font_size, - color=color, - thickness=line_width, - lineType=line_type, - bottomLeftOrigin=False) + cv2.putText( + img=image, + text=text, + org=position, + fontFace=font, + fontScale=font_size, + color=color, + thickness=line_width, + lineType=line_type, + bottomLeftOrigin=False, + ) return image @@ -254,7 +268,7 @@ def parse_image_size(obj): Raises: If the input is invalid, i.e., neither a list or tuple, nor a string. """ - if obj is None or obj == '': + if obj is None or obj == "": height = 0 width = 0 elif isinstance(obj, int): @@ -262,7 +276,7 @@ def parse_image_size(obj): width = obj elif isinstance(obj, (list, tuple, str, np.ndarray)): if isinstance(obj, str): - splits = obj.replace(' ', '').split(',') + splits = obj.replace(" ", "").split(",") numbers = tuple(map(int, splits)) else: numbers = tuple(obj) @@ -276,9 +290,9 @@ def parse_image_size(obj): height = int(numbers[0]) width = int(numbers[1]) else: - raise ValueError('At most two elements for image size.') + raise ValueError("At most two elements for image size.") else: - raise ValueError(f'Invalid type of input: `{type(obj)}`!') + raise ValueError(f"Invalid type of input: `{type(obj)}`!") return (max(0, height), max(0, width)) diff --git a/utils/video_visualizer.py b/utils/video_visualizer.py index 8262a3d..bc1c8bc 100755 --- a/utils/video_visualizer.py +++ b/utils/video_visualizer.py @@ -9,13 +9,15 @@ class VideoVisualizer(object): """Defines the video visualizer that presents images as a video.""" - def __init__(self, - path=None, - frame_size=None, - fps=25.0, - codec='libx264', - pix_fmt='yuv420p', - crf=1): + def __init__( + self, + path=None, + frame_size=None, + fps=25.0, + codec="libx264", + pix_fmt="yuv420p", + crf=1, + ): """Initializes the video visualizer. Args: @@ -53,11 +55,11 @@ def set_fps(self, fps=25.0): """Sets the FPS (frame per second) of the video.""" self.fps = fps - def set_codec(self, codec='libx264'): + def set_codec(self, codec="libx264"): """Sets the video codec.""" self.codec = codec - def set_pix_fmt(self, pix_fmt='yuv420p'): + def set_pix_fmt(self, pix_fmt="yuv420p"): """Sets the video pixel format.""" self.pix_fmt = pix_fmt @@ -71,11 +73,11 @@ def init_video(self): assert self.frame_width > 0 video_setting = { - '-r': f'{self.fps:.2f}', - '-s': f'{self.frame_width}x{self.frame_height}', - '-vcodec': f'{self.codec}', - '-crf': f'{self.crf}', - '-pix_fmt': f'{self.pix_fmt}', + "-r": f"{self.fps:.2f}", + "-s": f"{self.frame_width}x{self.frame_height}", + "-vcodec": f"{self.codec}", + "-crf": f"{self.crf}", + "-pix_fmt": f"{self.pix_fmt}", } self.video = FFmpegWriter(self.path, outputdict=video_setting) @@ -126,16 +128,15 @@ def save(self): self.set_path(None) -if __name__ == '__main__': +if __name__ == "__main__": from glob import glob import cv2 - video_visualizer = VideoVisualizer(path='output.mp4', - frame_size=None, - fps=25.0) - img_folder = 'src_images/' - imgs = sorted(glob(img_folder + '/*.png')) + + video_visualizer = VideoVisualizer(path="output.mp4", frame_size=None, fps=25.0) + img_folder = "src_images/" + imgs = sorted(glob(img_folder + "/*.png")) for img in imgs: image = cv2.imread(img) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) video_visualizer.add(image) - video_visualizer.save() \ No newline at end of file + video_visualizer.save() diff --git a/utils/warmup_scheduler.py b/utils/warmup_scheduler.py index 03208c4..6a80cee 100755 --- a/utils/warmup_scheduler.py +++ b/utils/warmup_scheduler.py @@ -1,8 +1,9 @@ from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import ReduceLROnPlateau + class GradualWarmupScheduler(_LRScheduler): - """ Gradually warm-up(increasing) learning rate in optimizer. + """Gradually warm-up(increasing) learning rate in optimizer. Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. Args: optimizer (Optimizer): Wrapped optimizer. @@ -13,8 +14,8 @@ class GradualWarmupScheduler(_LRScheduler): def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): self.multiplier = multiplier - if self.multiplier < 1.: - raise ValueError('multiplier should be greater thant or equal to 1.') + if self.multiplier < 1.0: + raise ValueError("multiplier should be greater thant or equal to 1.") self.total_epoch = total_epoch self.after_scheduler = after_scheduler self.finished = False @@ -24,21 +25,33 @@ def get_lr(self): if self.last_epoch > self.total_epoch: if self.after_scheduler: if not self.finished: - self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.after_scheduler.base_lrs = [ + base_lr * self.multiplier for base_lr in self.base_lrs + ] self.finished = True return self.after_scheduler.get_lr() return [base_lr * self.multiplier for base_lr in self.base_lrs] - return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + return [ + base_lr + * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) + for base_lr in self.base_lrs + ] def step_ReduceLROnPlateau(self, metrics, epoch=None): if epoch is None: epoch = self.last_epoch + 1 - self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning + self.last_epoch = ( + epoch if epoch != 0 else 1 + ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning if self.last_epoch <= self.total_epoch: - warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + warmup_lr = [ + base_lr + * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) + for base_lr in self.base_lrs + ] for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): - param_group['lr'] = lr + param_group["lr"] = lr else: if epoch is None: self.after_scheduler.step(metrics, None) @@ -55,4 +68,4 @@ def step(self, epoch=None, metrics=None): else: return super(GradualWarmupScheduler, self).step(epoch) else: - self.step_ReduceLROnPlateau(metrics, epoch) \ No newline at end of file + self.step_ReduceLROnPlateau(metrics, epoch)