From ef8837c7cd5e917325b72472b5b432305ae421d6 Mon Sep 17 00:00:00 2001 From: Paris Zhang Date: Tue, 8 Nov 2022 04:36:37 +0000 Subject: [PATCH] add dalle visualization --- clip_advice.py | 12 ++++++++++++ dalle/dalle2.json | 38 ++++++++++++++++++++++++++++++++++++++ dalle/setup.py | 17 +++++++++++++++++ dalle/visualization.py | 38 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+) create mode 100644 dalle/dalle2.json create mode 100644 dalle/setup.py create mode 100644 dalle/visualization.py diff --git a/clip_advice.py b/clip_advice.py index 0845cb1..3d4dc57 100644 --- a/clip_advice.py +++ b/clip_advice.py @@ -36,6 +36,10 @@ import omegaconf import ast +# from dalle2_laion import ModelLoadConfig, DalleModelManager +# from dalle2_laion.scripts import InferenceScript +# from test import ExampleInference + parser = argparse.ArgumentParser(description='CLIP Advice') parser.add_argument('--config', default='configs/Noop.yaml', help="config file") parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values " @@ -221,6 +225,14 @@ def flatten_config(dic, running_key=None, flattened_dict={}): except: print(f"sample idx {sample_idx} is not a valid index") wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique}) + + sample_idxs = random.sample(list(range(len(train_features))), 10) + with open('dalle/embeddings.npy', 'wb') as f: + for i in sample_idxs: + np.save(f, train_features[i]) + np.save(f, train_features[i+1]) + print(f"original domain: {train_domains[i]}, augmented domain: {train_domains[i+1]}") + wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains) if args.EXP.LOG_EMB_DRIFT: sample_idxs = random.sample(list(range(len(val_features))), min([len(val_features), 1000])) diff --git a/dalle/dalle2.json b/dalle/dalle2.json new file mode 100644 index 0000000..de947a6 --- /dev/null +++ b/dalle/dalle2.json @@ -0,0 +1,38 @@ +{ + "decoder": { + "unet_sources": [ + { + "unet_numbers": [1], + "default_cond_scale": [1.7], + "load_model_from": { + "load_type": "url", + "path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/v1.0.2/latest.pth", + "cache_dir": "./dalle/models", + "filename_override": "new_decoder.pth" + } + }, + { + "unet_numbers": [2], + "load_model_from": { + "load_type": "url", + "path": "https://huggingface.co/Veldrovive/upsamplers/resolve/main/working/latest.pth", + "cache_dir": "./dalle/models", + "filename_override": "second_decoder.pth" + }, + "load_config_from": { + "load_type": "url", + "path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json", + "checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json", + "cache_dir": "./dalle/models", + "filename_override": "second_decoder_config.json" + } + } + ] + }, + "clip": { + "make": "openai", + "model": "ViT-L/14" + }, + "devices": "cuda:0", + "strict_loading": false +} \ No newline at end of file diff --git a/dalle/setup.py b/dalle/setup.py new file mode 100644 index 0000000..2660de4 --- /dev/null +++ b/dalle/setup.py @@ -0,0 +1,17 @@ +from setuptools import setup, find_packages + +setup( + name = "dalle2-laion", + version = "0.0.1", + packages = find_packages(exclude=[]), + include_package_data = True, + install_requires = [ + "packaging>=21.0", + "pydantic>=1.9.0", + "torch>=1.10", + "Pillow>=9.0.0", + "numpy>=1.20.0", + "click>=8.0.0" + "dalle2-pytorch" + ] +) diff --git a/dalle/visualization.py b/dalle/visualization.py new file mode 100644 index 0000000..873b1e2 --- /dev/null +++ b/dalle/visualization.py @@ -0,0 +1,38 @@ +from dalle2_laion import ModelLoadConfig, DalleModelManager +from dalle2_laion.scripts import InferenceScript +import torch +import numpy as np +from PIL import Image + +class ExampleInference(InferenceScript): + def run(self, text, image_embedding) -> Image: + image_map = self._sample_decoder(text=text, image_embed=image_embedding) + return image_map[0][0] + +model_config = ModelLoadConfig.from_json_path("./dalle/dalle2.json") +model_manager = DalleModelManager(model_config) +inference = ExampleInference(model_manager) + +original_embed = [] +augmented_embed = [] +with open('./dalle/embeddings.npy', 'rb') as f: + for _ in range(10): + a = np.load(f) + original_embed.append(torch.from_numpy(a)) + b = np.load(f) + augmented_embed.append(torch.from_numpy(b)) + +# original_domain = [1, 1, 0, 0, 0, 0, 1, 1, 0, 1] +# original_text = ['a photo of a landbird in the forest.', 'a photo of a waterbird on water.'] +# augmented_text = ['a photo of a waterbird in the forest.', 'a photo of a landbird on water.'] +for i in range(10): + image = inference.run(["a photo"], [original_embed[i]]) + image.save(f'./dalle/images/waterbird-original-photo-{i}.png') + image = inference.run(["a photo"], [augmented_embed[i]]) + image.save(f'./dalle/images/waterbird-augment-photo-{i}.png') + +for i in range(10): + image = inference.run(["a car"], [original_embed[i]]) + image.save(f'./dalle/images/waterbird-original-car-{i}.png') + image = inference.run(["a car"], [augmented_embed[i]]) + image.save(f'./dalle/images/waterbird-augment-car-{i}.png') \ No newline at end of file