Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added examples/elephant_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/pile_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/pile_top_mask.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
29 changes: 16 additions & 13 deletions region_based.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import argparse
import numpy as np
from PIL import Image

from tqdm import tqdm

def seed_everything(seed):
torch.manual_seed(seed)
Expand Down Expand Up @@ -115,7 +115,7 @@ def generate(self, masks, prompts, negative_prompts='', height=512, width=2048,
text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2 * len(prompts), 77, 768]

# Define panorama grid and get views
latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device)
latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device)
noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1)
views = get_views(height, width)
count = torch.zeros_like(latent)
Expand All @@ -124,7 +124,7 @@ def generate(self, masks, prompts, negative_prompts='', height=512, width=2048,
self.scheduler.set_timesteps(num_inference_steps)

with torch.autocast('cuda'):
for i, t in enumerate(self.scheduler.timesteps):
for i, t in enumerate(tqdm(self.scheduler.timesteps)):
count.zero_()
value.zero_()

Expand Down Expand Up @@ -175,20 +175,20 @@ def preprocess_mask(mask_path, h, w, device):

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--mask_paths', type=list)
parser.add_argument('--bg_prompt', type=str, default='blurred image')
parser.add_argument('--bg_negative', type=str, default='') # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
parser.add_argument('--fg_prompts', nargs='*', type=str, required=True, help='one ore more prompts for masks')
parser.add_argument('--fg_negative', nargs='*', type=str, help='optional negative prompts for masks', default=None)
parser.add_argument('--fg_masks', nargs='*', type=str, required=True, help='one or more paths to mask images')
# important: it is necessary that SD output high-quality images for the bg/fg prompts.
parser.add_argument('--bg_prompt', type=str)
parser.add_argument('--bg_negative', type=str) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
parser.add_argument('--fg_prompts', type=list)
parser.add_argument('--fg_negative', type=list) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image'
parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'],
help="stable diffusion version")
parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'], help="stable diffusion version")
parser.add_argument('--H', type=int, default=768)
parser.add_argument('--W', type=int, default=512)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--steps', type=int, default=50)
# bootstrapping encourages high fidelity to tight masks, the value can be lowered is most cases
# bootstrapping encourages high fidelity to tight masks, the value can be lowered in most cases
parser.add_argument('--bootstrapping', type=int, default=20)
parser.add_argument('--outfile', type=str, default='out.png')
opt = parser.parse_args()

seed_everything(opt.seed)
Expand All @@ -197,15 +197,18 @@ def preprocess_mask(mask_path, h, w, device):

sd = MultiDiffusion(device, opt.sd_version)

fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.mask_paths])
fg_masks = torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.fg_masks])
bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True)
bg_mask[bg_mask < 0] = 0
masks = torch.cat([bg_mask, fg_masks])

# fill in optional negative prompts if none given
if opt.fg_negative is None:
opt.fg_negative = ['' for _ in opt.fg_prompts]
prompts = [opt.bg_prompt] + opt.fg_prompts
neg_prompts = [opt.bg_negative] + opt.fg_negative

img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping)

# save image
img.save('out.png')
img.save(opt.outfile)