diff --git a/examples/elephant_mask.png b/examples/elephant_mask.png new file mode 100755 index 0000000..50f25b2 Binary files /dev/null and b/examples/elephant_mask.png differ diff --git a/examples/pile_mask.png b/examples/pile_mask.png new file mode 100755 index 0000000..fce149a Binary files /dev/null and b/examples/pile_mask.png differ diff --git a/examples/pile_top_mask.png b/examples/pile_top_mask.png new file mode 100755 index 0000000..e1b604d Binary files /dev/null and b/examples/pile_top_mask.png differ diff --git a/region_based.py b/region_based.py old mode 100644 new mode 100755 index f512839..c0a347f --- a/region_based.py +++ b/region_based.py @@ -10,7 +10,7 @@ import argparse import numpy as np from PIL import Image - +from tqdm import tqdm def seed_everything(seed): torch.manual_seed(seed) @@ -115,7 +115,7 @@ def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, text_embeds = self.get_text_embeds(prompts, negative_prompts) # [2 * len(prompts), 77, 768] # Define panorama grid and get views - latent = torch.randn((1, self.unet.in_channels, height // 8, width // 8), device=self.device) + latent = torch.randn((1, self.unet.config.in_channels, height // 8, width // 8), device=self.device) noise = latent.clone().repeat(len(prompts) - 1, 1, 1, 1) views = get_views(height, width) count = torch.zeros_like(latent) @@ -124,7 +124,7 @@ def generate(self, masks, prompts, negative_prompts='', height=512, width=2048, self.scheduler.set_timesteps(num_inference_steps) with torch.autocast('cuda'): - for i, t in enumerate(self.scheduler.timesteps): + for i, t in enumerate(tqdm(self.scheduler.timesteps)): count.zero_() value.zero_() @@ -175,20 +175,20 @@ def preprocess_mask(mask_path, h, w, device): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('--mask_paths', type=list) + parser.add_argument('--bg_prompt', type=str, default='blurred image') + parser.add_argument('--bg_negative', type=str, default='') # 
'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image' + parser.add_argument('--fg_prompts', nargs='*', type=str, required=True, help='one or more prompts for masks') + parser.add_argument('--fg_negative', nargs='*', type=str, help='optional negative prompts for masks', default=None) + parser.add_argument('--fg_masks', nargs='*', type=str, required=True, help='one or more paths to mask images') # important: it is necessary that SD output high-quality images for the bg/fg prompts. - parser.add_argument('--bg_prompt', type=str) - parser.add_argument('--bg_negative', type=str) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image' - parser.add_argument('--fg_prompts', type=list) - parser.add_argument('--fg_negative', type=list) # 'artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image' - parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'], - help="stable diffusion version") + parser.add_argument('--sd_version', type=str, default='2.0', choices=['1.5', '2.0'], help="stable diffusion version") parser.add_argument('--H', type=int, default=768) parser.add_argument('--W', type=int, default=512) parser.add_argument('--seed', type=int, default=0) parser.add_argument('--steps', type=int, default=50) - # bootstrapping encourages high fidelity to tight masks, the value can be lowered is most cases + # bootstrapping encourages high fidelity to tight masks, the value can be lowered in most cases parser.add_argument('--bootstrapping', type=int, default=20) + parser.add_argument('--outfile', type=str, default='out.png') opt = parser.parse_args() seed_everything(opt.seed) @@ -197,15 +197,18 @@ def preprocess_mask(mask_path, h, w, device): sd = MultiDiffusion(device, opt.sd_version) - fg_masks = 
torch.cat([preprocess_mask(mask_path, opt.H // 8, opt.W // 8, device) for mask_path in opt.fg_masks]) bg_mask = 1 - torch.sum(fg_masks, dim=0, keepdim=True) bg_mask[bg_mask < 0] = 0 masks = torch.cat([bg_mask, fg_masks]) + # fill in optional negative prompts if none given + if opt.fg_negative is None: + opt.fg_negative = ['' for _ in opt.fg_prompts] prompts = [opt.bg_prompt] + opt.fg_prompts neg_prompts = [opt.bg_negative] + opt.fg_negative img = sd.generate(masks, prompts, neg_prompts, opt.H, opt.W, opt.steps, bootstrapping=opt.bootstrapping) # save image - img.save('out.png') \ No newline at end of file + img.save(opt.outfile)