Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ all_sequences/
ckpts/
logs/
# configs/

*.mp4
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# CoDeF: Content Deformation Fields for Temporally Consistent Video Processing

<img src='docs/teaser.gif'></img>
# HeyGen ❤️ Superwoman
<img src='docs/heygen_red.gif'></img>

The slides for the project: https://docs.google.com/presentation/d/1y4qP2dALviAZm_D4BU_udy2IHE-1A6nwD4zS3RAqfvc/edit?usp=sharing

The config for a HeyGen avatar stylization: `configs/heygen/base.yaml`

---------
## CoDeF: Content Deformation Fields for Temporally Consistent Video Processing
[Hao Ouyang](https://ken-ouyang.github.io/)\*, [Qiuyu Wang](https://github.com/qiuyu96/)\*, [Yuxi Xiao](https://henry123-boy.github.io/)\*, [Qingyan Bai](https://scholar.google.com/citations?user=xUMjxi4AAAAJ&hl=en), [Juntao Zhang](https://github.com/JordanZh), [Kecheng Zheng](https://scholar.google.com/citations?user=hMDQifQAAAAJ), [Xiaowei Zhou](https://xzhou.me/),
[Qifeng Chen](https://cqf.io/)&#8224;, [Yujun Shen](https://shenyujun.github.io/)&#8224; (*equal contribution, &#8224;corresponding author)

Expand Down
26 changes: 26 additions & 0 deletions configs/heygen/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
mask_dir: null
flow_dir: null

img_wh: [1080, 1080]
canonical_wh: [1080, 1080]

lr: 0.001
bg_loss: 0.0

ref_idx: null # 0

N_xyz_w: [8, 8]
flow_loss: 1
flow_step: -1
self_bg: False

deform_hash: True
vid_hash: True

num_steps: 200000
decay_step: [2500, 5000, 7500]
annealed_begin_step: 4000
annealed_step: 4000
save_model_iters: 2000

fps: 25
26 changes: 13 additions & 13 deletions data_preprocessing/RAFT/core/corr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
corr = CorrBlock.corr(fmap1, fmap2)

batch, h1, w1, dim, h2, w2 = corr.shape
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

self.corr_pyramid.append(corr)
for i in range(self.num_levels-1):
for i in range(self.num_levels - 1):
corr = F.avg_pool2d(corr, 2, stride=2)
self.corr_pyramid.append(corr)

Expand All @@ -34,12 +34,12 @@ def __call__(self, coords):
out_pyramid = []
for i in range(self.num_levels):
corr = self.corr_pyramid[i]
dx = torch.linspace(-r, r, 2*r+1, device=coords.device)
dy = torch.linspace(-r, r, 2*r+1, device=coords.device)
dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl

corr = bilinear_sampler(corr, coords_lvl)
Expand All @@ -52,12 +52,12 @@ def __call__(self, coords):
@staticmethod
def corr(fmap1, fmap2):
batch, dim, ht, wd = fmap1.shape
fmap1 = fmap1.view(batch, dim, ht*wd)
fmap2 = fmap2.view(batch, dim, ht*wd)
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
fmap1 = fmap1.view(batch, dim, ht * wd)
fmap2 = fmap2.view(batch, dim, ht * wd)

corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
corr = corr.view(batch, ht, wd, 1, ht, wd)
return corr / torch.sqrt(torch.tensor(dim).float())
return corr / torch.sqrt(torch.tensor(dim).float())


class AlternateCorrBlock:
Expand All @@ -83,7 +83,7 @@ def __call__(self, coords):
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
(corr,) = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
corr_list.append(corr.squeeze(1))

corr = torch.stack(corr_list, dim=1)
Expand Down
Loading