Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 257 additions & 0 deletions mainv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
import os
import clip
import open_clip
import torch
import numpy as np
import torchvision
import wandb
import argparse
from PIL import Image
import matplotlib.pyplot as plt
import random
import omegaconf
from omegaconf import OmegaConf

import helpers.data_helpers as dh
import methods.clip_transformations as CLIPTransformations
from utils import read_unknowns, nest_dict
from clip_utils import get_features, evaluate, zeroshot_classifier, get_ensamble_preds, get_pred_overlap, get_nn_metrics
import methods.augmentations

parser = argparse.ArgumentParser(description='CLIP Advice')
parser.add_argument('--config', default='configs/base.yaml', help="config file")
parser.add_argument('overrides', nargs='*', help="Any key=value arguments to override config values "
"(use dots for.nested=overrides)")
# flags = parser.parse_args()
flags, unknown = parser.parse_known_args()

overrides = OmegaConf.from_cli(flags.overrides)
cfg = OmegaConf.load(flags.config)
base = OmegaConf.load('configs/base.yaml')
args = OmegaConf.merge(base, cfg, overrides)
if len(unknown) > 0:
print(unknown)
config = nest_dict(read_unknowns(unknown))
to_merge = OmegaConf.create(config)
args = OmegaConf.merge(args, to_merge)
args.yaml = flags.config

assert args.EXP.ADVICE_METHOD != 'CNN', "main.py not for CNN baseline, use train.py"
assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py"

if args.EXP.WANDB_SILENT:
os.environ['WANDB_SILENT']="true"

def flatten_config(dic, running_key=None, flattened_dict={}):
for key, value in dic.items():
if running_key is None:
running_key_temp = key
else:
running_key_temp = '{}.{}'.format(running_key, key)
if isinstance(value, omegaconf.dictconfig.DictConfig):
flatten_config(value, running_key_temp)
else:
#print(running_key_temp, value)
flattened_dict[running_key_temp] = value
return flattened_dict

run = wandb.init(project=args.EXP.PROJ, group=args.EXP.ADVICE_METHOD, config=flatten_config(args), allow_val_change=False)
# wandb.save(flags.config)
# wandb.run.log_code(".")

torch.manual_seed(args.EXP.SEED)
np.random.seed(args.EXP.SEED)
random.seed(args.EXP.SEED)

DATASET_NAME = args.DATA.DATASET

# load data
if args.DATA.LOAD_CACHED:
print(args.DATA.LOAD_CACHED)
if args.EXP.IMAGE_FEATURES == 'clip' or args.EXP.IMAGE_FEATURES == 'openclip':
model_name = args.EXP.CLIP_MODEL
else:
model_name = args.EXP.IMAGE_FEATURES
cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.BIASED_VAL, args.EXP.IMAGE_FEATURES)
assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
data = torch.load(cache_file)
train_features, train_labels, train_groups, train_domains, train_filenames = data['train_features'], data['train_labels'], data['train_groups'], data['train_domains'], data['train_filenames']
val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'], data['val_labels'], data['val_groups'], data['val_domains'], data['val_filenames']
test_features, test_labels, test_groups, test_domains, test_filenames = data['test_features'], data['test_labels'], data['test_groups'], data['test_domains'], data['test_filenames']
# move some val data to test
if args.DATA.DATASET != 'ColoredMNISTBinary':
val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'][::2], data['val_labels'][::2], data['val_groups'][::2], data['val_domains'][::2], data['val_filenames'][::2]
test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((data['test_features'], data['val_features'][1::2])), np.concatenate((data['test_labels'], data['val_labels'][1::2])), np.concatenate((data['test_groups'], data['val_groups'][1::2])), np.concatenate((data['test_domains'], data['val_domains'][1::2])), np.concatenate((data['test_filenames'], data['val_filenames'][1::2]))
if args.METHOD.NORMALIZE:
train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True)
val_features /= np.linalg.norm(val_features, axis=-1, keepdims=True)
test_features /= np.linalg.norm(test_features, axis=-1, keepdims=True)
# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(args.EXP.IMAGE_FEATURES)
# clip_model, preprocess = clip.load(args.EXP.CLIP_MODEL, device)
if args.EXP.IMAGE_FEATURES == 'clip':
clip_model, preprocess = clip.load(args.EXP.CLIP_MODEL, device)
model, preprocess = clip.load(args.EXP.CLIP_MODEL, device)
elif args.EXP.IMAGE_FEATURES == 'openclip':
model, _, preprocess = open_clip.create_model_and_transforms(args.EXP.CLIP_MODEL, pretrained=args.EXP.CLIP_PRETRAINED_DATASET)
model = model.to(torch.device('cuda:0'))
clip_model = model
else:
model = getattr(torchvision.models, args.EXP.IMAGE_FEATURES)(pretrained=True)
model = model.to(device)

# Calculate the image features
prompts = list(args.EXP.TEXT_PROMPTS)
if len(prompts) >0 and type(prompts[0]) == omegaconf.listconfig.ListConfig:
prompts = [list(p) for p in prompts]

neutral_prompts = list(args.EXP.NEUTRAL_TEXT_PROMPTS)
if len(neutral_prompts) >0 and type(neutral_prompts[0]) == omegaconf.listconfig.ListConfig:
neutral_prompts = [list(p) for p in neutral_prompts]
print("Advice Method", args.EXP.ADVICE_METHOD)
bias_correction = getattr(CLIPTransformations, args.EXP.ADVICE_METHOD)(prompts, clip_model, args, neutral_prompts)
if args.DATA.LOAD_CACHED == False:
trainset, valset, testset = dh.get_dataset(DATASET_NAME, preprocess, biased_val=args.EXP.BIASED_VAL)
dataset_classes = dh.get_class(DATASET_NAME)
dataset_domains = dh.get_domain(DATASET_NAME)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(testset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
train_features, train_labels, train_groups, train_domains, train_filenames = get_features(train_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
val_features, val_labels, val_groups, val_domains, val_filenames = get_features(val_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
test_features, test_labels, test_groups, test_domains, test_filenames = get_features(test_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
data = {
"train_features": train_features, "train_labels": train_labels, "train_groups": train_groups, "train_domains": train_domains, "train_filenames": train_filenames,
"val_features": val_features, "val_labels": val_labels, "val_groups": val_groups, "val_domains": val_domains, "val_filenames": val_filenames,
"test_features": test_features, "test_labels": test_labels, "test_groups": test_groups, "test_domains": test_domains, "test_filenames": test_filenames
}
data_dir = '/'.join(args.DATA.SAVE_PATH.split('/')[:-1])
if not os.path.exists(data_dir):
os.makedirs(data_dir)
torch.save(data, args.DATA.SAVE_PATH)
if args.METHOD.NORMALIZE:
train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True)
val_features /= np.linalg.norm(val_features, axis=-1, keepdims=True)
test_features /= np.linalg.norm(test_features, axis=-1, keepdims=True)

old_train_features, old_train_labels, old_train_groups, old_train_domains, old_train_filenames = train_features, train_labels, train_groups, train_domains, train_filenames
old_val_features, old_val_labels, old_val_groups, old_val_domains, old_val_filenames = val_features, val_labels, val_groups, val_domains, val_filenames
old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames = test_features, test_labels, test_groups, test_domains, test_filenames


print("SIZE of embeddings ", old_train_features.shape)
# set zeroshot weights if doing a ensamble
if args.EXP.ENSAMBLE:
all_prompts = neutral_prompts + prompts
print("Setting zeroshot weights...")
zeroshot_weights = zeroshot_classifier([[p.format(c) for p in all_prompts] for c in dataset_classes], model, model_type=args.EXP.IMAGE_FEATURES)
dataset_doms = [d.replace('real', 'photo') for d in dataset_domains]
dom_zeroshot_weights = zeroshot_classifier([[f"a {d} of an object."] for d in dataset_doms], model, model_type=args.EXP.IMAGE_FEATURES)
print("Zeroshot weights set!")

# if we want to do any augmentations, do them here
num_augmentations = 1
print("SIZE of embeddings ", train_features.shape, train_domains.shape)
if args.EXP.AUGMENTATION != None and args.EXP.AUGMENTATION != 'None':
print("Augmenting training set...")
if "LADS" in args.EXP.AUGMENTATION or 'Directional' in args.EXP.AUGMENTATION:
augment = getattr(methods.augmentations, args.EXP.AUGMENTATION)(args, train_features, train_labels, train_groups, train_domains, train_filenames, bias_correction.text_embeddings, val_features, val_labels, val_groups, val_domains, val_filenames)
else:
augment = getattr(methods.augmentations, args.EXP.AUGMENTATION)(args, train_features, train_labels, train_groups, train_domains, train_filenames, bias_correction.text_embeddings)
train_features, train_labels, train_domains, train_groups, train_filenames = augment.augment_dataset()
print("Training set augmented!")
print("SIZE of embeddings ", train_features.shape, train_domains.shape)

if args.EXP.LOG_NN:
features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames])
# features, labels, groups, domains, filenames = old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames
if len(np.unique(train_domains)) > 1:
filtered_idxs = np.where(train_domains != train_domains[0])
sample_features, sample_domains, sample_labels, sample_filenames = np.array(train_features[filtered_idxs]), train_domains[filtered_idxs], train_labels[filtered_idxs], train_filenames[filtered_idxs]
sample_idxs = random.sample(list(range(len(sample_filenames))), min((len(train_filenames), 1000)))
sample_features, sample_domains, sample_labels, sample_filenames = sample_features[sample_idxs], sample_domains[sample_idxs], sample_labels[sample_idxs], sample_filenames[sample_idxs]
else:
sample_idxs = random.sample(list(range(len(train_filenames))), min((len(train_filenames), 1000)))
sample_features, sample_domains, sample_labels, sample_filenames = train_features[sample_idxs], train_domains[sample_idxs], train_labels[sample_idxs], train_filenames[sample_idxs]
neighbor_domains, neighbor_labels, domain_acc, class_acc, neighbor_samples, prop_unique, mean_cs = get_nn_metrics(sample_features, sample_domains, sample_labels, features, domains, labels)
plt.rcParams["figure.figsize"] = (20,5)
f, (axs_orig, axs_new) = plt.subplots(2, 10, sharey=True)
for i, (original_idx, sample_idx) in enumerate(neighbor_samples):
try:
axs_orig[i].imshow(Image.open(sample_filenames[original_idx]).resize((224, 224)))
axs_orig[i].set_title(f"{dataset_domains[int(sample_domains[int(original_idx)])]} - {sample_labels[int(original_idx)]}")
axs_orig[i].axis('off')
axs_new[i].imshow(Image.open(filenames[sample_idx]).resize((224, 224)))
axs_new[i].set_title(f"{dataset_domains[int(domains[int(sample_idx)])]} - {labels[int(sample_idx)]}")
axs_new[i].axis('off')
except:
print(f"sample idx {sample_idx} is not a valid index")
wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique})
# wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains)
print("Plotted Nearest Neighbors")

# train MLP with domain adaptation loss
bias_correction.train_debias(train_features, train_labels, train_groups, train_domains, val_features, val_labels, np.squeeze(val_groups), val_domains)
if args.EXP.ENSAMBLE:
print("Ensambling predictions")
predictions, probs = bias_correction.eval(val_features, ret_probs=True)
lads_preds, zs_preds, ensamble_predictions, combined_preds = get_ensamble_preds(val_features, probs, zeroshot_weights, dataset_domains=dom_zeroshot_weights)
non_overlap, non_overlap_prop, non_overlap_prop_correct = get_pred_overlap(lads_preds, zs_preds, val_labels)
accuracy, balanced_acc, class_accuracy, group_accuracy = evaluate(ensamble_predictions, val_labels, np.squeeze(val_groups), num_augmentations=num_augmentations)
wandb.summary["ensamble val acc"] = accuracy
wandb.summary["ensamble val blanced acc"] = balanced_acc

predictions, probs = bias_correction.eval(test_features, ret_probs=True)
lads_preds, zs_preds, ensamble_predictions, combined_preds = get_ensamble_preds(test_features, probs, zeroshot_weights, dataset_domains=dom_zeroshot_weights)
non_overlap, non_overlap_prop, non_overlap_prop_correct = get_pred_overlap(lads_preds, zs_preds, test_labels)
accuracy, balanced_acc, class_accuracy, group_accuracy = evaluate(ensamble_predictions, test_labels, np.squeeze(test_groups), num_augmentations=num_augmentations)
_, _, _, domain_accuracy = evaluate(ensamble_predictions, test_labels, np.squeeze(test_domains), list(range(len(dataset_classes))), num_augmentations=num_augmentations)
wandb.summary["ensamble test acc"] = accuracy
wandb.summary["ensamble test blanced acc"] = balanced_acc
wandb.summary["ensamble test class acc"] = class_accuracy
wandb.summary["ensamble test domain acc"] = domain_accuracy
wandb.summary["ensamble test worst domain acc"] = np.min(domain_accuracy)
wandb.summary['ensamble test group acc'] = group_accuracy
else:
predictions, probs = bias_correction.eval(test_features)
val_predictions, val_probs = bias_correction.eval(val_features)
val_accuracy, val_balanced_acc, val_class_accuracy, val_group_accuracy = evaluate(val_predictions, val_labels, np.squeeze(val_groups), num_augmentations=num_augmentations)
accuracy, balanced_acc, class_accuracy, group_accuracy = evaluate(predictions, test_labels, np.squeeze(test_groups), num_augmentations=num_augmentations)
_, _, _, domain_accuracy = evaluate(predictions, test_labels, np.squeeze(test_domains), list(range(len(dataset_classes))), num_augmentations=num_augmentations)
wandb.summary["test acc"] = accuracy
wandb.summary["test blanced acc"] = balanced_acc
wandb.summary["test class acc"] = class_accuracy
wandb.summary["test domain acc"] = domain_accuracy
wandb.summary["test worst domain acc"] = np.min(domain_accuracy)
wandb.summary['test group acc'] = group_accuracy
for i in range(len(domain_accuracy)):
wandb.summary[f"{dataset_domains[i]} test acc"] = domain_accuracy[i]
print(f"Test accuracy: {group_accuracy} \n Test domain accuracy: {domain_accuracy}")

if 'E2E' in args.EXP.ADVICE_METHOD:
# features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames])
aug_features, aug_labels, aug_domains, aug_filenames = bias_correction.augment_dataset(train_features, train_labels, train_domains, train_filenames)
sample_idxs = random.sample(list(range(len(aug_filenames))), 1000)
# print("SAMPLE SHAPE: ", sample_filenames.shape, sample_domains.shape)
sample_features, sample_domains, sample_labels, sample_filenames = aug_features[sample_idxs], aug_domains[sample_idxs], aug_labels[sample_idxs], aug_filenames[sample_idxs]
print("UNIQUE DOMAINS ", np.unique(aug_domains))
neighbor_domains, neighbor_labels, domain_acc, class_acc, neighbor_samples, prop_unique, mean_cs = get_nn_metrics(sample_features, sample_domains, sample_labels, old_test_features, old_test_domains, old_test_labels)
wandb.log({"mean CS for NN": mean_cs})
print(neighbor_samples)
plt.rcParams["figure.figsize"] = (20,5)
f, (axs_orig, axs_new) = plt.subplots(2, 10, sharey=True)
print("DATASET DOMAIN ", dataset_domains)
for i, (original_idx, sample_idx) in enumerate(neighbor_samples):
# try:
print(sample_filenames[original_idx])
axs_orig[i].imshow(Image.open(sample_filenames[original_idx]).resize((224, 224)))
axs_orig[i].set_title(f"{dataset_domains[int(sample_domains[int(original_idx)])]} - {sample_labels[int(original_idx)]}")
axs_orig[i].axis('off')
axs_new[i].imshow(Image.open(old_test_filenames[sample_idx]).resize((224, 224)))
axs_new[i].set_title(f"{dataset_domains[int(old_test_domains[int(sample_idx)])]} - {old_test_labels[int(sample_idx)]}")
axs_new[i].axis('off')
# except:
# print(f"sample idx {sample_idx} is not a valid index")
wandb.log({"train features NN": wandb.Image(f), "domain consistency acc": domain_acc, "class consistency acc": class_acc, "unique nn": prop_unique})
wandb.sklearn.plot_confusion_matrix(sample_domains, neighbor_domains, dataset_domains)
9 changes: 5 additions & 4 deletions methods/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def __init__(self, cfg, image_features, labels, group_labels, domain_labels, fil
for i in range(len(self.orig_prompts)):
self.train_network("sketch", self.orig_prompts[:,i], i)
else:
for i in range(len(self.text_features[0])):
self.train_network("sketch", self.text_features[:,i], i)
for i in range(len(self.text_features)):
self.train_network("sketch", self.text_features[i], i)
else:
for i in range(len(self.text_features)):
self.train_network("sketch", self.text_features[i], i)
Expand All @@ -321,12 +321,13 @@ def directional_loss_builder(self, num_net):
image embeddings is similar to the difference in text embeddings of the
source and target domain.
"""
print("text feat ", self.text_features.shape)
if not self.cfg.AUGMENTATION.GENERIC:
delta_t = torch.Tensor(self.text_features[:,num_net])
delta_t = torch.Tensor(self.text_features[num_net])
else:
delta_t = torch.Tensor(self.text_features[num_net])
delta_t = delta_t.type(torch.float).cuda()

print("text features shape ", self.text_features.shape, delta_t.shape)
def custom_loss(predictions, labels, targets):
total_sum = None
delta_i = predictions - labels
Expand Down