Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions clip_advice.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
import omegaconf
import ast

# from dalle2_laion import ModelLoadConfig, DalleModelManager
# from dalle2_laion.scripts import InferenceScript
# from test import ExampleInference

parser = argparse.ArgumentParser(description='CLIP Advice')
parser.add_argument('--config', default='configs/Noop.yaml', help="config file")
parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values "
Expand Down Expand Up @@ -221,6 +225,14 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
except:
print(f"sample idx {sample_idx} is not a valid index")
wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique})

sample_idxs = random.sample(list(range(len(train_features))), 10)
with open('dalle/embeddings.npy', 'wb') as f:
for i in sample_idxs:
np.save(f, train_features[i])
np.save(f, train_features[i+1])
print(f"original domain: {train_domains[i]}, augmented domain: {train_domains[i+1]}")

wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains)
if args.EXP.LOG_EMB_DRIFT:
sample_idxs = random.sample(list(range(len(val_features))), min([len(val_features), 1000]))
Expand Down
38 changes: 38 additions & 0 deletions dalle/dalle2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"decoder": {
"unet_sources": [
{
"unet_numbers": [1],
"default_cond_scale": [1.7],
"load_model_from": {
"load_type": "url",
"path": "https://huggingface.co/laion/DALLE2-PyTorch/resolve/main/decoder/v1.0.2/latest.pth",
"cache_dir": "./dalle/models",
"filename_override": "new_decoder.pth"
}
},
{
"unet_numbers": [2],
"load_model_from": {
"load_type": "url",
"path": "https://huggingface.co/Veldrovive/upsamplers/resolve/main/working/latest.pth",
"cache_dir": "./dalle/models",
"filename_override": "second_decoder.pth"
},
"load_config_from": {
"load_type": "url",
"path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"checksum_file_path": "https://huggingface.co/Veldrovive/upsamplers/raw/main/working/decoder_config.json",
"cache_dir": "./dalle/models",
"filename_override": "second_decoder_config.json"
}
}
]
},
"clip": {
"make": "openai",
"model": "ViT-L/14"
},
"devices": "cuda:0",
"strict_loading": false
}
17 changes: 17 additions & 0 deletions dalle/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages

setup(
name = "dalle2-laion",
version = "0.0.1",
packages = find_packages(exclude=[]),
include_package_data = True,
install_requires = [
"packaging>=21.0",
"pydantic>=1.9.0",
"torch>=1.10",
"Pillow>=9.0.0",
"numpy>=1.20.0",
"click>=8.0.0"
"dalle2-pytorch"
]
)
38 changes: 38 additions & 0 deletions dalle/visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dalle2_laion import ModelLoadConfig, DalleModelManager
from dalle2_laion.scripts import InferenceScript
import torch
import numpy as np
from PIL import Image

class ExampleInference(InferenceScript):
def run(self, text, image_embedding) -> Image:
image_map = self._sample_decoder(text=text, image_embed=image_embedding)
return image_map[0][0]

model_config = ModelLoadConfig.from_json_path("./dalle/dalle2.json")
model_manager = DalleModelManager(model_config)
inference = ExampleInference(model_manager)

original_embed = []
augmented_embed = []
with open('./dalle/embeddings.npy', 'rb') as f:
for _ in range(10):
a = np.load(f)
original_embed.append(torch.from_numpy(a))
b = np.load(f)
augmented_embed.append(torch.from_numpy(b))

# original_domain = [1, 1, 0, 0, 0, 0, 1, 1, 0, 1]
# original_text = ['a photo of a landbird in the forest.', 'a photo of a waterbird on water.']
# augmented_text = ['a photo of a waterbird in the forest.', 'a photo of a landbird on water.']
for i in range(10):
image = inference.run(["a photo"], [original_embed[i]])
image.save(f'./dalle/images/waterbird-original-photo-{i}.png')
image = inference.run(["a photo"], [augmented_embed[i]])
image.save(f'./dalle/images/waterbird-augment-photo-{i}.png')

for i in range(10):
image = inference.run(["a car"], [original_embed[i]])
image.save(f'./dalle/images/waterbird-original-car-{i}.png')
image = inference.run(["a car"], [augmented_embed[i]])
image.save(f'./dalle/images/waterbird-augment-car-{i}.png')