From b5102b71fe56be336718a5cf1e60a03d3356bbaf Mon Sep 17 00:00:00 2001 From: Lisa Dunlap Date: Fri, 14 Apr 2023 15:06:47 -0700 Subject: [PATCH] cleaning up --- README.md | 11 +- configs/CUB/dpl.yaml | 85 ----- configs/CUB/lads.yaml | 3 +- configs/DomainNet/test.yaml | 22 -- configs/DomainNet/test_aug.yaml | 38 --- configs/Waterbirds/lads.yaml | 2 +- configs/Waterbirds/mlp.yaml | 2 +- configs/Waterbirds/mlpzs.yaml | 2 +- configs/Waterbirds/new_lads.yaml | 2 +- configs/base.yaml | 1 + helpers/data_helpers.py | 2 +- helpers/data_paths.py | 13 - main.py | 17 +- methods/augmentations.py | 512 +------------------------------ methods/clip_transformations.py | 42 ++- methods/predictors.py | 27 +- 16 files changed, 72 insertions(+), 709 deletions(-) delete mode 100644 configs/CUB/dpl.yaml delete mode 100644 configs/DomainNet/test.yaml delete mode 100644 configs/DomainNet/test_aug.yaml delete mode 100644 helpers/data_paths.py diff --git a/README.md b/README.md index fc63c1d..b66526d 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,15 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS) *WARNING: this is still WIP, please raise an issue if you run into any bugs.* +``` +@article{dunlap2023lads, + title={Using Language to Extend to Unseen Domains}, + author = {Dunlap, Lisa and Mohri, Clara and Guillory, Devin and Zhang, Han and Darrell, Trevor and Gonzalez, Joseph E. and Raghunathan, Aditi and Rohrbach, Anja}, + journal={International Conference on Learning Representations (ICLR)}, + year={2023} +} +``` + ## Getting started 1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup. @@ -26,7 +35,7 @@ python main.py --config configs/Waterbirds/mlp.yaml Datasets supported are in the [helpers folder](./helpers/data_helpers.py). 
Currently they are: * Waterbirds (100% and 95%) [our specific split](https://drive.google.com/file/d/1zJpQYGEt1SuwitlNfE06TFyLaWX-st1k/view) [code to generate data](https://github.com/kohpangwei/group_DRO) -* ColoredMNIST (LNTL version and simplified version) NOTEBOOK COMING SOON +* ColoredMNIST (LNTL version and simplified version) [Paper Dataset](https://drive.google.com/file/d/1GomKfufFrXIRAFJNedUBCDwHW2X9NTP-/view?usp=share_link) * DomainNet (the version used in the paper is `DATA.DATASET=DomainNetMini`) [full dataset](http://ai.bu.edu/DomainNet/) * CUB Paintings [photos dataset](https://www.vision.caltech.edu/datasets/cub_200_2011/) [paintings dataset](https://github.com/thuml/PAN) * OfficeHome COMING SOON diff --git a/configs/CUB/dpl.yaml b/configs/CUB/dpl.yaml deleted file mode 100644 index 3507f78..0000000 --- a/configs/CUB/dpl.yaml +++ /dev/null @@ -1,85 +0,0 @@ -EXP: - ADVICE_METHOD: "DPL" - WANDB_SILENT: False - PROJ: "LADS_Replication" - SEED: 0 - TEXT_PROMPTS: ['a painting of a {} bird.'] - NEUTRAL_TEXT_PROMPTS: ['a photo of a {} bird.'] - EPOCHS: 400 - CHECKPOINT_VAL: True - -DATA: - DATASET: "CUB" - BATCH_SIZE: 256 - -METHOD: - MODEL: - NUM_LAYERS: 1 - LR: 0.001 - WEIGHT_DECAY: 0.05 - CHECKPOINT_NAME: 'dpl' - BATCH_AVERAGING: True - TEST_BATCH_AVG: False - -OPTIM: - NAME: "sgd" - LR: 0.002 - MAX_EPOCH: 10 - LR_SCHEDULER: "cosine" - WARMUP_EPOCH: 1 - WEIGHT_DECAY: 5e-4 - MOMENTUM: 0.9 - SGD_DAMPNING: 0 - SGD_NESTEROV: False - RMSPROP_ALPHA: 0.99 - # The following also apply to other - # adaptive optimizers like adamw - ADAM_BETA1: 0.9 - ADAM_BETA2: 0.999 - # STAGED_LR allows different layers to have - # different lr, e.g. 
pre-trained base layers - # can be assigned a smaller lr than the new - # classification layer - STAGED_LR: False - NEW_LAYERS: () - BASE_LR_MULT: 0.1 - # -1 or 0 means the stepsize is equal to max_epoch - STEPSIZE: (-1, ) - GAMMA: 0.1 - # Either linear or constant - WARMUP_TYPE: "linear" - # Constant learning rate when type=constant - WARMUP_CONS_LR: 1e-5 - # Minimum learning rate when type=linear - WARMUP_MIN_LR: 1e-5 - # Recount epoch for the next scheduler (last_epoch=-1) - # Otherwise last_epoch=warmup_epoch - WARMUP_RECOUNT: True - -TRAINER: - COCOOP: - N_CTX: 4 - AVG: False - LOAD_CTX: False - CTX_CHECKPOINT: '' - NUM_DOM_TOKEN: 4 - CTX_INIT: "a photo of a" - PREC: "fp16" - -INPUT: - SIZE: - - 224 - - 224 - INTERPOLATION: "bicubic" - PIXEL_MEAN: - - 0.48145466 - - 0.4578275 - - 0.40821073 - PIXEL_STD: - - 0.26862954 - - 0.26130258 - - 0.27577711 - TRANSFORMS: - - "random_resized_crop" - - "random_flip" - - "normalize" \ No newline at end of file diff --git a/configs/CUB/lads.yaml b/configs/CUB/lads.yaml index 7baf68f..0b865c1 100644 --- a/configs/CUB/lads.yaml +++ b/configs/CUB/lads.yaml @@ -14,6 +14,7 @@ DATA: DATASET: "CUB" LOAD_CACHED: True BATCH_SIZE: 256 + SAVE_PATH: 'embeddings/CUB/recomputed.pth' METHOD: MODEL: @@ -27,7 +28,7 @@ METHOD: AUGMENTATION: MODEL: - LR: 0.1 + LR: 0.001 WEIGHT_DECAY: 0.05 NUM_LAYERS: 1 GENERIC: False diff --git a/configs/DomainNet/test.yaml b/configs/DomainNet/test.yaml deleted file mode 100644 index 5491163..0000000 --- a/configs/DomainNet/test.yaml +++ /dev/null @@ -1,22 +0,0 @@ -EXP: - ADVICE_METHOD: "Base" - WANDB_SILENT: False - PROJ: "DomainNetMini_LADS_Replication" - SEED: 0 - TEXT_PROMPTS: ['clipart of a {}.', 'a painting of a {}.', 'a realistic photo of a {}.'] - NEUTRAL_TEXT_PROMPTS: ['a sketch of a {}'] - EPOCHS: 400 - -DATA: - DATASET: "DomainNetMini" - BATCH_SIZE: 256 - -METHOD: - MODEL: - NUM_LAYERS: 1 - DOM_WEIGHT: 1.0 - LR: 0.0001 - CHECKPOINT: 'checkpoint/mlp.pth' - CHECKPOINT_NAME: 'DomainNetMini/mlp' - 
RESUME: False - USE_DOM_GT: False \ No newline at end of file diff --git a/configs/DomainNet/test_aug.yaml b/configs/DomainNet/test_aug.yaml deleted file mode 100644 index 4a1a373..0000000 --- a/configs/DomainNet/test_aug.yaml +++ /dev/null @@ -1,38 +0,0 @@ -EXP: - ADVICE_METHOD: "ClipMLPNew" - WANDB_SILENT: False - PROJ: "DomainNetMini_LADS_Replication" - SEED: 0 - TEXT_PROMPTS: [['a realistic photo of a {}.'], ['a painting of a {}.'], ['clipart of a {}.'], ['art of a {}.'], ['an image of a {}.']] - NEUTRAL_TEXT_PROMPTS: ['a sketch of a {}', 'a pencil drawing of a {}.', 'a drawing of a {}.'] - AUGMENTATION: 'DirectionalMulti' - EPOCHS: 400 - LOG_NN: False - ENSAMBLE: False - - -DATA: - DATASET: "DomainNetMini" - LOAD_CACHED: True - SAVE_PATH: "vit14_clip.pth" - BATCH_SIZE: 256 - -METHOD: - MODEL: - NUM_LAYERS: 1 - DOM_WEIGHT: 1.0 - LR: 0.0001 - CHECKPOINT: 'checkpoint/mlp_simple.pth' - CHECKPOINT_NAME: 'DomainNetMini-mlp-directional' - RESUME: False - USE_DOM_GT: True - -AUGMENTATION: - MODEL: - LR: 0.0001 - WEIGHT_DECAY: 0.005 - NUM_LAYERS: 1 - EPOCHS: 50 - GENERIC: False - ALPHA: 0.5 - DOM_LABELS: ['real', 'painting', 'clipart', 'painting', 'real'] \ No newline at end of file diff --git a/configs/Waterbirds/lads.yaml b/configs/Waterbirds/lads.yaml index e334a3c..413ac43 100644 --- a/configs/Waterbirds/lads.yaml +++ b/configs/Waterbirds/lads.yaml @@ -6,7 +6,7 @@ EXP: TEXT_PROMPTS: [['a photo of a {} on forest.'], ['a photo of a {} on water.']] NEUTRAL_TEXT_PROMPTS: [] AUGMENTATION: 'BiasDirectional' - EPOCHS: 200 + EPOCHS: 400 CHECKPOINT_VAL: True ENSAMBLE: False diff --git a/configs/Waterbirds/mlp.yaml b/configs/Waterbirds/mlp.yaml index 9aa8991..23cf474 100644 --- a/configs/Waterbirds/mlp.yaml +++ b/configs/Waterbirds/mlp.yaml @@ -5,7 +5,7 @@ EXP: SEED: 0 TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.'] NEUTRAL_TEXT_PROMPTS: ['a photo of a {}.'] - EPOCHS: 200 + EPOCHS: 400 BIASED_VAL: DATA: diff --git a/configs/Waterbirds/mlpzs.yaml 
b/configs/Waterbirds/mlpzs.yaml index 42f127e..fc2a1f7 100644 --- a/configs/Waterbirds/mlpzs.yaml +++ b/configs/Waterbirds/mlpzs.yaml @@ -5,7 +5,7 @@ EXP: SEED: 0 TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.'] NEUTRAL_TEXT_PROMPTS: ['a photo of a {}.'] - EPOCHS: 200 + EPOCHS: 400 TEMPLATES: 'waterbirds_templates2' DATA: diff --git a/configs/Waterbirds/new_lads.yaml b/configs/Waterbirds/new_lads.yaml index 687fa46..abaa3bd 100644 --- a/configs/Waterbirds/new_lads.yaml +++ b/configs/Waterbirds/new_lads.yaml @@ -6,7 +6,7 @@ EXP: TEXT_PROMPTS: [['a photo of a {} on forest.'], ['a photo of a {} on water.']] NEUTRAL_TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.'] AUGMENTATION: 'LADSBias' - EPOCHS: 200 + EPOCHS: 400 CHECKPOINT_VAL: True ENSAMBLE: False diff --git a/configs/base.yaml b/configs/base.yaml index 56f5d66..6d30fa5 100644 --- a/configs/base.yaml +++ b/configs/base.yaml @@ -23,6 +23,7 @@ DATA: UPWEIGHT_DOMAINS: False UPWEIGHT_CLASSES: True MODEL_DIM: 1024 + ROOT: /shared/lisabdunlap/data METHOD: MODEL: diff --git a/helpers/data_helpers.py b/helpers/data_helpers.py index 5368ccd..0adcb7a 100644 --- a/helpers/data_helpers.py +++ b/helpers/data_helpers.py @@ -210,6 +210,6 @@ def get_domain(dataset_name): def get_class(dataset_name): return DATASET_CLASSES[dataset_name] -def get_cache_file(dataset_name, model_name='ViT-B/32', biased_val=True, model_type='clip'): +def get_classes(dataset_name, model_name='ViT-B/32', model_type='clip'): assert dataset_name in DATASET_PATHS[model_type][model_name].keys(), f"{dataset_name} is not cached or not added to the DATASET_PATHS dict in helpers/dataset_helpers.py" return DATASET_PATHS[model_type][model_name][dataset_name], DATASET_CLASSES[dataset_name], DATASET_DOMAINS[dataset_name] \ No newline at end of file diff --git a/helpers/data_paths.py b/helpers/data_paths.py deleted file mode 100644 index 8eaa32c..0000000 --- a/helpers/data_paths.py +++ /dev/null @@ -1,13 +0,0 @@ 
-DATASET_PATHS = { - "clip":{ - "ViT-L/14": { - "CUB": "embeddings/CUB/clip_openai_ViT-L_14.pt", - "DomainNetMini": "embeddings/DomainNetMini/clip_openai_ViT-L_14.pt", - "Waterbirds": "embeddings/Waterbirds/clip_openai_ViT-L_14.pt", - }, - }, - "openclip":{ - } - - -} \ No newline at end of file diff --git a/main.py b/main.py index d42d987..3a568de 100644 --- a/main.py +++ b/main.py @@ -11,6 +11,7 @@ import random import omegaconf from omegaconf import OmegaConf +import collections import helpers.data_helpers as dh import methods.clip_transformations as CLIPTransformations @@ -37,7 +38,7 @@ args.yaml = flags.config assert args.EXP.ADVICE_METHOD != 'CNN', "main.py not for CNN baseline, use train.py" -assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py" +# assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py" if args.EXP.WANDB_SILENT: os.environ['WANDB_SILENT']="true" @@ -66,13 +67,14 @@ def flatten_config(dic, running_key=None, flattened_dict={}): DATASET_NAME = args.DATA.DATASET # load data -if args.DATA.LOAD_CACHED: - print(args.DATA.LOAD_CACHED) +cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt" +if os.path.exists(cache_file): if args.EXP.IMAGE_FEATURES == 'clip' or args.EXP.IMAGE_FEATURES == 'openclip': model_name = args.EXP.CLIP_MODEL else: model_name = args.EXP.IMAGE_FEATURES cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.BIASED_VAL, args.EXP.IMAGE_FEATURES) + print(f"loading embeddings from {cache_file}") assert os.path.exists(cache_file), f"{cache_file} does not exist. 
To compute embeddings, set DATA.LOAD_CACHED=False" data = torch.load(cache_file) train_features, train_labels, train_groups, train_domains, train_filenames = data['train_features'], data['train_labels'], data['train_groups'], data['train_domains'], data['train_filenames'] @@ -82,6 +84,11 @@ def flatten_config(dic, running_key=None, flattened_dict={}): if args.DATA.DATASET != 'ColoredMNISTBinary': val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'][::2], data['val_labels'][::2], data['val_groups'][::2], data['val_domains'][::2], data['val_filenames'][::2] test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((data['test_features'], data['val_features'][1::2])), np.concatenate((data['test_labels'], data['val_labels'][1::2])), np.concatenate((data['test_groups'], data['val_groups'][1::2])), np.concatenate((data['test_domains'], data['val_domains'][1::2])), np.concatenate((data['test_filenames'], data['val_filenames'][1::2])) + + # print out group counts + print("Train groups:", collections.Counter(train_groups)) + print("Val groups:", collections.Counter(val_groups)) + print("Test groups:", collections.Counter(test_groups)) if args.METHOD.NORMALIZE: train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True) val_features /= np.linalg.norm(val_features, axis=-1, keepdims=True) @@ -163,7 +170,9 @@ def flatten_config(dic, running_key=None, flattened_dict={}): print("Training set augmented!") print("SIZE of embeddings ", train_features.shape, train_domains.shape) -if args.EXP.LOG_NN: + +# Logs the Nearest Neighbors in the Extended Domain +if args.EXP.LOG_NN and args.EXP.ADVICE_METHOD != 'CLIPZS': features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, 
old_test_filenames]) # features, labels, groups, domains, filenames = old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames if len(np.unique(train_domains)) > 1: diff --git a/methods/augmentations.py b/methods/augmentations.py index 47280b3..9ce0f34 100644 --- a/methods/augmentations.py +++ b/methods/augmentations.py @@ -27,7 +27,6 @@ import helpers.data_helpers as dh from methods.clip_transformations import EmbeddingDataset from clip_utils import * -from methods import predictors import omegaconf from omegaconf import OmegaConf @@ -212,7 +211,10 @@ def augment_single(self, img_embedding, label): return list(aug_embedding) class SLERP(Addition): - + """ + Spherical Linear Interpolation (might be buggy). This was the original idea but didn't work as well as a learned method. + It's possible this could work better with a much better model + """ def __init__(self, cfg, image_features, labels, group_labels, domain_labels, filenames, text_features): super().__init__(cfg, image_features, labels, group_labels, domain_labels, filenames, text_features) # assumes text features is a list @@ -222,12 +224,6 @@ def get_interp(self, img_embedding, text_features, orig_text_embeddings): Augments a single by taking the shperical interpolation of the image emb and text emn """ return [self.slerp(img_embedding, feat, self.alpha) for feat in text_features] - - # def get_text_embeddings(self, model_name, prompts, norm=True): - # text_embs = zeroshot_classifier(prompts, self.model, model_type=self.cfg.EXP.IMAGE_FEATURES).cpu().numpy().T - # if norm: - # text_embs /= np.linalg.norm(text_embs, axis=-1, keepdims=True) - # return text_embs @staticmethod def slerp(p0, p1, t): @@ -235,7 +231,6 @@ def slerp(p0, p1, t): so = np.sin(omega) return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1 -# TO-DO: Implement learned augmentation using CLIP_{direction} loss function class MLP(nn.Module): def __init__(self, input_dim=768, hidden_dim=384): @@ 
-249,302 +244,6 @@ def forward(self, x): x = self.fc2(x) return x -class Directional(Augment): - def __init__(self, cfg, image_features, labels, group_labels, domain_labels, filenames, text_features, val_image_features, val_labels, val_group_labels,val_domain_labels, val_filenames): - super().__init__(cfg, image_features, labels, group_labels, domain_labels, filenames, text_features) - dataset = EmbeddingDataset(self.cfg, self.image_features, self.labels, self.group_labels, self.domain_labels) - self.dataset = dataset - self.train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True) - - val_dataset = EmbeddingDataset(self.cfg, val_image_features, val_labels, val_group_labels, val_domain_labels) - self.val_dataset = val_dataset - self.val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.nets = [] - self.net_checkpoints = [] - self.uid = uuid.uuid4() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if self.cfg.EXP.IMAGE_FEATURES == 'clip': - model, preprocess = clip.load(self.cfg.EXP.CLIP_MODEL, device) - elif self.cfg.EXP.IMAGE_FEATURES == 'openclip': - model, _, preprocess = open_clip.create_model_and_transforms(self.cfg.EXP.CLIP_MODEL, pretrained=self.cfg.EXP.CLIP_PRETRAINED_DATASET) - model = model.to(torch.device('cuda:1')) - # model, preprocess = clip.load(self.cfg.EXP.CLIP_MODEL, device) - model.eval() - self.model = model - if self.cfg.DATA.DATASET == 'ColoredMNISTBinary': - text_embs = zeroshot_classifier([[f'a photo of the number "{c}"'] for c in self.class_names], model, model_type=self.cfg.EXP.IMAGE_FEATURES) - else: - text_embs = zeroshot_classifier([[f"a photo of a {c}"] for c in self.class_names], model, model_type=self.cfg.EXP.IMAGE_FEATURES) - - self.class_text_embs = text_embs.float().cuda() - print("text emb shape ", self.class_text_embs.shape) - - try: - 
self.orig_prompts = torch.Tensor(self.get_orig_text_embeddings(self.prompts).transpose((1, 0, 2))).float().cuda() - self.neutral_embs = torch.Tensor(self.get_orig_text_embeddings(self.neutral_prompts).transpose((1, 0, 2))).float().cuda() - # if len(self.orig_prompts) > 1: - # raise ValueError("this only works for one domain shift atm") - print('=========================') - print('=========================') - print("task specific text emb ", self.orig_prompts.shape, self.orig_prompts[0].shape, torch.transpose(self.orig_prompts[0].float().cuda(), 1, 0).shape, self.class_text_embs.shape) - print('=========================') - print('=========================') - # stack = [torch.mean(self.neutral_embs, dim=1)] + torch.mean(self.orig_prompts, dim=1) - self.val_dom_check = torch.squeeze(torch.cat([torch.mean(self.neutral_embs, dim=1), torch.mean(self.orig_prompts, dim=1)]), dim=1).float().cuda() - self.val_dom_check = torch.transpose(self.val_dom_check, 0, 1) - print("val domain check shape ", self.val_dom_check.shape, self.class_text_embs.shape) - except: - print("can't load prompts") - - if not self.cfg.AUGMENTATION.GENERIC: - - if self.cfg.AUGMENTATION.DOM_SPECIFIC_XE: - print("DOMAIN SPECIFIC PROMPTS") - print("task specific text emb ", self.orig_prompts.shape) - for i in range(len(self.orig_prompts)): - self.train_network("sketch", self.orig_prompts[:,i], i) - else: - for i in range(len(self.text_features[0])): - self.train_network("sketch", self.text_features[:,i], i) - else: - for i in range(len(self.text_features)): - self.train_network("sketch", self.text_features[i], i) - - for net in self.nets: - net.eval() - - def directional_loss_builder(self, num_net): - """ - CLIP directional loss from gan NADA paper. Ensures that the difference in - image embeddings is similar to the difference in text embeddings of the - source and target domain. 
- """ - if not self.cfg.AUGMENTATION.GENERIC: - delta_t = torch.Tensor(self.text_features[:,num_net]) - else: - delta_t = torch.Tensor(self.text_features[num_net]) - delta_t = delta_t.type(torch.float).cuda() - - def custom_loss(predictions, labels, targets): - total_sum = None - delta_i = predictions - labels - ctr = 0 - for i, l in zip(delta_i, targets): - if not self.cfg.AUGMENTATION.GENERIC: - delta_tt = delta_t[l] - else: - delta_tt = delta_t - ctr += 1 - if total_sum == None: - numerator = torch.dot(i, delta_tt) - denominator = torch.norm(i) * torch.norm(delta_tt) - total_sum = 1 - (numerator/denominator) - else: - total_sum += 1 - (torch.dot(i, delta_tt)/ (torch.norm(i) * torch.norm(delta_tt))) - return total_sum / ctr - return custom_loss - - @staticmethod - def get_class_logits(outputs, class_embs): - outputs_norm = outputs / outputs.norm(dim=-1, keepdim=True) - return torch.matmul(outputs_norm, class_embs) - - def train_network(self, source, target, num_net): - net = MLP(hidden_dim=self.cfg.AUGMENTATION.MODEL.HIDDEN_DIM, input_dim=self.dataset.embedding_dim) - self.nets.append(net.cuda()) - self.net_checkpoints.append("") - - cudnn.benchmark = True - self.optimizer = AdamW(self.nets[num_net].parameters(), lr=self.cfg.AUGMENTATION.MODEL.LR, weight_decay=self.cfg.AUGMENTATION.MODEL.WEIGHT_DECAY) - self.directional_loss = self.directional_loss_builder(num_net) - self.class_consistency_loss = nn.CrossEntropyLoss(weight=self.dataset.class_weights.cuda()) - - if self.cfg.AUGMENTATION.CLIP_NN_LOSS: - self.clip_nn_loss = nn.CrossEntropyLoss() - - self.nets[num_net].train() - - best_train_loss, best_epoch = 10000, 0 - for epoch in range(self.cfg.AUGMENTATION.EPOCHS): - train_metrics = self.training_loop(num_net, epoch) - val_metrics = self.eval_loop(num_net, epoch) - if val_metrics['val loss'] < best_train_loss: - best_train_loss = val_metrics['val loss'] - best_epoch = epoch - self.net_checkpoints[num_net] = self.save_checkpoint(best_train_loss, epoch, num_net) - 
- wandb.summary[f"{self.prompts[num_net]} best epoch"] = best_epoch - wandb.summary[f"{self.prompts[num_net]} best train_loss"] = best_train_loss - print(f"==> loading checkpoint {self.net_checkpoints[num_net]} at epoch {best_epoch} with loss {best_train_loss}") - self.nets[num_net] = self.load_checkpoint(self.nets[num_net], self.net_checkpoints[num_net]) - - def get_nn(self, inputs_unnorm, samples_unnorm, labels, cs=False): - """ Gets nearest neighbor of that same class for each img_emb""" - inputs = inputs_unnorm / inputs_unnorm.norm(dim=-1, keepdim=True) - samples = samples_unnorm / samples_unnorm.norm(dim=-1, keepdim=True) - nns, nn_dot_prod = [], [] - for i, input in enumerate(inputs): - if self.cfg.AUGMENTATION.NN_INCLUDE_SAMPLE: - assert self.cfg.AUGMENTATION.COMPARE_BEFORE_AUG, "Must compare before augmentation" - nn_features = samples - else: - nn_features = torch.cat([samples[0:i], samples[i+1:]]) - dot_prod = input @ nn_features.T - dot_prod = dot_prod * np.exp(0.007) - nns.append(torch.argmax(dot_prod)) - nn_dot_prod.append(dot_prod) - return torch.stack(nns).long(), torch.stack(nn_dot_prod) - - def training_loop(self, num_net, epoch): - self.nets[num_net].train() - train_directional_loss, train_class_loss, train_nn_loss, train_loss, cls_correct, total = 0, 0, 0, 0, 0, 0 - for i, (inp, cls_target, cls_group, dom_target) in enumerate(self.train_loader): - inp, cls_target= inp.cuda().float(), cls_target.cuda().long() - self.optimizer.zero_grad() - cls_outputs = self.nets[num_net](inp) - # compute directional loss - directional_loss = self.cfg.AUGMENTATION.DOM_WEIGHT * self.directional_loss(cls_outputs, inp, cls_target) - - if not self.cfg.AUGMENTATION.GENERIC: - if self.cfg.AUGMENTATION.DOM_SPECIFIC_XE: - cls_logits = self.get_class_logits(cls_outputs, torch.transpose(self.orig_prompts[num_net].float().cuda(), 1, 0)) - else: - cls_logits = self.get_class_logits(cls_outputs, self.class_text_embs) - cls_consist = self.class_consistency_loss(cls_logits, 
cls_target) - loss = self.alpha * directional_loss + (1 - self.alpha) * cls_consist - train_class_loss += (1 - self.alpha) * cls_consist - train_directional_loss += self.alpha * directional_loss.item() - else: - train_directional_loss += directional_loss.item() - # wandb.log({"directional loss": directional_loss.item()}) - loss = directional_loss - - if self.cfg.AUGMENTATION.CLIP_NN_LOSS: - nn_labels, _ = self.get_nn(inp, inp, cls_target) - if self.cfg.AUGMENTATION.COMPARE_BEFORE_AUG: - _, nn_logits = self.get_nn(cls_outputs, inp, cls_target) - else: - _, nn_logits = self.get_nn(cls_outputs, cls_outputs, cls_target) - nn_loss = self.cfg.AUGMENTATION.NN_WEIGHT * self.clip_nn_loss(nn_logits, nn_labels) - loss += nn_loss - train_nn_loss += nn_loss.item() - - loss.backward(retain_graph=True) - self.optimizer.step() - - train_loss += loss.item() - - total += cls_target.size(0) - progress_bar(i, len(self.train_loader), 'Loss: %.3f' - % (train_loss/(i+1))) - - metrics = {"train class loss": train_class_loss/(i+1), "train directional loss": train_directional_loss/(i+1), "train nn loss": train_nn_loss/(i+1), "train loss": train_loss/(i+1), "epoch": epoch} - wandb.log(metrics) - return metrics - - def eval_loop(self, num_net, epoch): - """ - Checkpoint aug netowrk on the eval set. Try both minizing the loss and maximizing the zeroshot accuracy with the correct class and domain. 
- """ - m = nn.Softmax(dim=1) - self.nets[num_net].eval() - nn_correct, dom_correct, total = 0, 0, 0 - train_directional_loss, train_class_loss, train_nn_loss, train_loss, cls_correct, total = 0, 0, 0, 0, 0, 0 - for i, (inp, cls_target, cls_group, dom_target) in enumerate(self.val_loader): - with torch.no_grad(): - inp, cls_target= inp.cuda().float(), cls_target.cuda().long() - cls_outputs = self.nets[num_net](inp) - # compute directional loss - # directional_loss = self.directional_loss(cls_outputs, inp, cls_target) - try: - cls_logits = self.get_class_logits(cls_outputs, self.val_dom_check) - _, cls_predicted = m(cls_logits).max(1) - targets = torch.Tensor([1 for _ in range(cls_target.size(0))]).cuda().float() - dom_correct += (cls_predicted == targets).sum().item() - except: - dom_correct = 0 - - nn_labels, _ = self.get_nn(inp, inp, cls_target) - nns, nn_logits = self.get_nn(cls_outputs, inp, cls_target) - nn_correct += nns.eq(nn_labels).sum().item() - - cls_outputs = self.nets[num_net](inp) - # compute directional loss - directional_loss = self.directional_loss(cls_outputs, inp, cls_target) - cls_logits = self.get_class_logits(cls_outputs, self.class_text_embs) - cls_consist = self.class_consistency_loss(cls_logits, cls_target) - loss = self.cfg.AUGMENTATION.DOM_WEIGHT * self.alpha * directional_loss + (1 - self.alpha) * cls_consist - train_directional_loss += self.cfg.AUGMENTATION.DOM_WEIGHT * self.alpha * directional_loss.item() - - if self.cfg.AUGMENTATION.CLIP_NN_LOSS: - nn_labels, _ = self.get_nn(inp, inp, cls_target) - if self.cfg.AUGMENTATION.COMPARE_BEFORE_AUG: - _, nn_logits = self.get_nn(cls_outputs, inp, cls_target) - else: - _, nn_logits = self.get_nn(cls_outputs, cls_outputs, cls_target) - nn_loss = self.cfg.AUGMENTATION.NN_WEIGHT * self.clip_nn_loss(nn_logits, nn_labels) - train_nn_loss += nn_loss.item() - - train_loss += loss.item() - - total += cls_target.size(0) - - - metrics = {"val loss": train_loss/(i+1), "val dom acc": dom_correct/total, 
"val nn acc": nn_correct/total, "epoch": epoch} - wandb.log(metrics) - return metrics - - def eval_lang_loop(self, num_net, epoch): - self.nets[num_net].eval() - train_directional_loss, train_class_loss, train_nn_loss, train_loss, cls_correct, total = 0, 0, 0, 0, 0, 0 - for i, (inp, cls_target, cls_group, dom_target) in enumerate(self.val_loader): - inp, cls_target= inp.cuda().float(), cls_target.cuda().long() - return - - def augment_single(self, img_embedding, label): - keep = img_embedding - if self.cfg.AUGMENTATION.INCLUDE_ORIG_TRAINING: - output = [keep] - else: - output = [] - img_embedding = torch.tensor(img_embedding) - img_embedding = img_embedding.type(torch.float32) - img_embedding = img_embedding.cuda() - img_embedding /= img_embedding.norm(dim=-1, keepdim=True) - for net in self.nets: - o = net(img_embedding) - o /= o.norm(dim=-1, keepdim=True) - - o = o.detach().cpu().numpy() - output.append(o) - wandb.log({"cos sim:": distance.cosine(o, img_embedding.cpu())}) - # output = self.net(img_embedding) - # val = torch.tensor(output) - return list(np.array(output)) - - def save_checkpoint(self, acc, epoch, num_net): - checkpoint_dir = os.path.join("./checkpoint", self.cfg.DATA.DATASET) - if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) - path = f'./checkpoint/{self.cfg.DATA.DATASET}/{self.prompts[num_net]}-{self.cfg.EXP.SEED}-{self.uid}.pth' - print(f'Saving checkpoint with acc {acc} to {path}...') - state = { - "acc": acc, - "epoch": epoch, - "net": self.nets[num_net].state_dict() - } - torch.save(state, path) - # wandb.save(path) - return path - - def load_checkpoint(self, net, path): - checkpoint = torch.load(path) - net.load_state_dict(checkpoint['net']) - print(f"...loaded checkpoint with acc {checkpoint['acc']}") - return net - from clip_utils import get_domain_text_embs class DirectionLoss(torch.nn.Module): @@ -872,205 +571,4 @@ def augment_dataset(self): augmented_group_labels += [self.group_labels[i], 
self.get_inv(self.group_labels[i])] augmented_filenames += [self.filenames[i], self.filenames[i]] return np.array(augmented_features), np.array(augmented_labels), np.array(augmented_domain_labels), np.array(augmented_group_labels), np.array(augmented_filenames) - - -class BiasDirectional(Directional): - """ - This implements the similar directional loss as the directional class, but routes examples - based on clips classification of their domain. - """ - - def __init__(self, cfg, image_features, labels, group_labels, domain_labels, filenames, text_features, val_image_features, val_labels, val_group_labels,val_domain_labels, val_filenames): - super(Directional, self).__init__(cfg, image_features, labels, group_labels, domain_labels, filenames, text_features) - if self.cfg.AUGMENTATION.DOM_SPECIFIC_XE: - self.orig_prompts = torch.Tensor(self.get_orig_text_embeddings(self.prompts).transpose((1, 0, 2))).float().cuda() - print("orig prompts shape ", self.orig_prompts.shape) - self.text_features = torch.mean(self.orig_prompts, dim=1) - text_features = self.text_features - else: - self.text_features = torch.tensor(text_features).float().cuda() - print("text features shape ", self.text_features.shape) - - print("text features", text_features.shape) - dataset = EmbeddingDataset(self.cfg, self.image_features, self.labels, self.group_labels, self.domain_labels, text_emb=self.text_features) - self.dataset = dataset - self.train_loader = torch.utils.data.DataLoader(dataset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.nets = [] - self.net_checkpoints = [] - self.domain_indexes = [0] - self.uid = uuid.uuid4() - - if self.cfg.AUGMENTATION.DOM_SPECIFIC_XE: - self.class_text_embs = self.orig_prompts - else: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if self.cfg.EXP.IMAGE_FEATURES == 'clip': - model, preprocess = clip.load(self.cfg.EXP.CLIP_MODEL, device) - elif 
self.cfg.EXP.IMAGE_FEATURES == 'openclip': - model, _, preprocess = open_clip.create_model_and_transforms(self.cfg.EXP.CLIP_MODEL, pretrained=self.cfg.EXP.CLIP_PRETRAINED_DATASET) - model = model.to(torch.device('cuda:1')) - # model, preprocess = clip.load(self.cfg.EXP.CLIP_MODEL, device) - model.eval() - self.model = model - text_embs = zeroshot_classifier([[f"a photo of a {c}"] for c in self.class_names], model, model_type=self.cfg.EXP.IMAGE_FEATURES, cuda_device='1') - self.class_text_embs = text_embs.float().cuda() - - print("text emb shape ", self.class_text_embs.shape) - - print("text features ", self.text_features.shape) - self.train_network("sketch", self.text_features, 0) - - for net in self.nets: - net.eval() - - def directional_loss_builder(self, num_net): - """ - CLIP directional loss from gan NADA paper. Ensures that the difference in - image embeddings is similar to the difference in text embeddings of the - source and target domain. - This modification changes the delta depending on - """ - def custom_loss(predictions, labels, targets, domain_labels): - total_sum = None - delta_i = predictions - labels - ctr = 0 - for i, d, l in zip(delta_i, domain_labels, targets): - if d == 0: - # delta_tt = self.text_features[1] - self.text_features[0] - delta_tt = self.text_features[1][l] - self.text_features[0][l] - else: - # delta_tt = self.text_features[0] - self.text_features[1] - delta_tt = self.text_features[0][l] - self.text_features[1][l] - try: - delta_tt /= np.linalg.norm(delta_tt, axis=-1, keepdims=True) - delta_tt = torch.Tensor(delta_tt).type(torch.float).cuda() - except: - delta_tt /= delta_tt.norm(dim=-1, keepdim=True) - ctr += 1 - if total_sum == None: - numerator = torch.dot(i, delta_tt) - denominator = torch.norm(i) * torch.norm(delta_tt) - total_sum = 1 - (numerator/denominator) - else: - total_sum += 1 - (torch.dot(i, delta_tt)/ (torch.norm(i) * torch.norm(delta_tt))) - return total_sum / ctr - return custom_loss - - def loss_builder(self): - 
cos = torch.nn.CosineSimilarity(dim=0) - def custom_loss(predictions, inputs, labels): - total_sum = None - ctr = 0 - for p, _, l in zip(predictions, inputs, labels): - # get class similarity - dom_one_sim = cos(p, torch.Tensor(self.text_features[0]).cuda()) - - # get sketch similarity - dom_two_sim = cos(p, torch.Tensor(self.text_features[1]).cuda()) - - loss = torch.abs(dom_one_sim - dom_two_sim) - if total_sum == None: - total_sum = loss - else: - total_sum += loss - ctr += 1 - return total_sum / ctr - return custom_loss - - @staticmethod - def get_class_logits(outputs, class_embs, dom_labels, dom_specific=False): - outputs_norm = outputs / outputs.norm(dim=-1, keepdim=True) - if dom_specific: - ret = [] - for o, d in zip(outputs, dom_labels): - idx = 1 if d == 0 else 0 - ret.append(torch.matmul(o, class_embs[idx].transpose(0, 1))) - return torch.stack(ret).cuda() - else: - return torch.matmul(outputs_norm, class_embs) - - def train_network(self, source, target, num_net): - net = MLP(hidden_dim=self.cfg.AUGMENTATION.MODEL.HIDDEN_DIM, input_dim=self.dataset.embedding_dim) - self.nets.append(net.cuda()) - self.net_checkpoints.append("") - - cudnn.benchmark = True - self.optimizer = AdamW(self.nets[num_net].parameters(), lr=self.cfg.AUGMENTATION.MODEL.LR, weight_decay=self.cfg.AUGMENTATION.MODEL.WEIGHT_DECAY) - self.directional_loss = self.directional_loss_builder(num_net) - self.class_consistency_loss = nn.CrossEntropyLoss(weight=self.dataset.class_weights.cuda()) - - if self.cfg.AUGMENTATION.CLIP_NN_LOSS: - self.clip_nn_loss = nn.CrossEntropyLoss() - - self.nets[num_net].train() - - best_train_loss, best_epoch = 10000, 0 - for epoch in range(self.cfg.AUGMENTATION.EPOCHS): - train_metrics = self.training_loop(num_net, epoch) - if train_metrics['train loss'] < best_train_loss: - best_train_loss = train_metrics['train loss'] - best_epoch = epoch - self.net_checkpoints[num_net] = self.save_checkpoint(best_train_loss, epoch, num_net) - - 
wandb.summary[f"{self.prompts[num_net]} best epoch"] = best_epoch - wandb.summary[f"{self.prompts[num_net]} best train_loss"] = best_train_loss - - def training_loop(self, num_net, epoch): - train_directional_loss, train_class_loss, train_nn_loss, train_loss, cls_correct, total = 0, 0, 0, 0, 0, 0 - for i, (inp, cls_target, cls_group, dom_target) in enumerate(self.train_loader): - inp, cls_target= inp.cuda().float(), cls_target.cuda().long() - self.optimizer.zero_grad() - cls_outputs = self.nets[num_net](inp) - # compute directional loss - directional_loss = self.directional_loss(cls_outputs, inp, cls_target, dom_target) - - cls_logits = self.get_class_logits(cls_outputs, self.class_text_embs, dom_target, self.cfg.AUGMENTATION.DOM_SPECIFIC_XE) - cls_consist = self.class_consistency_loss(cls_logits, cls_target) - loss = self.cfg.AUGMENTATION.DOM_WEIGHT * self.alpha * directional_loss + (1 - self.alpha) * cls_consist - train_class_loss += (1 - self.alpha) * cls_consist - train_directional_loss += self.cfg.AUGMENTATION.DOM_WEIGHT * self.alpha * directional_loss.item() - - if self.cfg.AUGMENTATION.CLIP_NN_LOSS: - nn_labels, _ = self.get_nn(inp, inp, cls_target) - if self.cfg.AUGMENTATION.COMPARE_BEFORE_AUG: - _, nn_logits = self.get_nn(cls_outputs, inp, cls_target) - else: - _, nn_logits = self.get_nn(cls_outputs, cls_outputs, cls_target) - nn_loss = self.cfg.AUGMENTATION.NN_WEIGHT * self.clip_nn_loss(nn_logits, nn_labels) - loss += nn_loss - train_nn_loss += nn_loss.item() - - loss.backward(retain_graph=True) - self.optimizer.step() - - train_loss += loss.item() - - total += cls_target.size(0) - progress_bar(i, len(self.train_loader), 'Loss: %.3f' - % (train_loss/(i+1))) - - metrics = {"train class loss": train_class_loss/(i+1), "train directional loss": train_directional_loss/(i+1), "train nn loss": train_nn_loss/(i+1), "train loss": train_loss/(i+1), "epoch": epoch} - wandb.log(metrics) - return metrics - - @staticmethod - def get_inv(label): - return 1 if label == 0 
else 0 - - def augment_dataset(self): - """ - Augments the dataset - """ - augmented_features = [] - augmented_labels = [] - augmented_domain_labels = [] - augmented_group_labels = [] - augmented_filenames = [] - for i, feature in enumerate(self.image_features): - augmented_features += self.augment_single(feature, self.labels[i]) - augmented_labels += [self.labels[i], self.labels[i]] - augmented_domain_labels += [self.domain_labels[i], self.get_inv(self.domain_labels[i])] - augmented_group_labels += [self.group_labels[i], self.get_inv(self.group_labels[i])] - augmented_filenames += [self.filenames[i], self.filenames[i]] - return np.array(augmented_features), np.array(augmented_labels), np.array(augmented_domain_labels), np.array(augmented_group_labels), np.array(augmented_filenames) + \ No newline at end of file diff --git a/methods/clip_transformations.py b/methods/clip_transformations.py index 2c96e9a..fdb92ed 100644 --- a/methods/clip_transformations.py +++ b/methods/clip_transformations.py @@ -63,18 +63,6 @@ def __init__(self, text_prompts, model, cfg, neutral_prompts=[]): self.text_embeddings = target_embeddings - source_embeddings.mean(axis=0) else: self.text_embeddings = target_embeddings - source_embeddings - - # @staticmethod - # def compose_text_with_templates(text: str, templates=imagenet_templates) -> list: - # return [template.format(text) for template in templates] - - # @staticmethod - # def get_embedding(text_prompts, model): - # text_inputs = torch.cat([clip.tokenize(t) for t in text_prompts]).cuda() - # # Calculate features - # with torch.no_grad(): - # text_features = model.encode_text(text_inputs) - # return text_features.cpu().numpy() @staticmethod def normalize(inputs): @@ -89,6 +77,36 @@ def apply(self, inputs, labels=None): return self.normalize(inputs) return inputs + +class CLIPZS(Base): + """ + CLIPZS method. Computes CLIP embeddings and then applies a zero-shot classifier. 
+ """ + def __init__(self, text_prompts, model, cfg, neutral_prompts=[]): + super().__init__(text_prompts, model, cfg, neutral_prompts) + templates = getattr(helpers.text_templates, cfg.EXP.TEMPLATES) + text_embs = zeroshot_classifier([[p.format(c) for p in templates] for c in self.class_names], model, model_type=self.cfg.EXP.IMAGE_FEATURES, cuda_device='1') + self.class_text_embs = text_embs.float().cuda() + print("class text embs", self.class_text_embs.shape) + + def train_debias(self, inputs, labels, groups, dom_gt, test_inputs, test_labels, test_groups, test_dom_gt): + pass + + def eval(self, inputs): + with torch.no_grad(): + preds, probs = np.array([]), [] + generator = chunks(torch.tensor(inputs).cuda().float(), self.cfg.DATA.BATCH_SIZE) + for i, images in enumerate(generator): + images = images.cuda() + images /= images.norm(dim=-1, keepdim=True) + # predict + logits = (100. * images @ self.class_text_embs).float().softmax(dim=-1) + clip_pred = torch.argmax(logits, dim=-1) + preds = np.append(preds, clip_pred.cpu().numpy()) + probs.append(logits.detach().cpu().numpy()) + return preds, np.concatenate(probs, axis=0) + + class EmbeddingDataset: """ Takes in CLIP embeddings (INPUTS), labels, and CLIP text embedding (TEXT_EMB of shape (num_domains, clip emb shape)). 
diff --git a/methods/predictors.py b/methods/predictors.py index 7e06aee..dd74069 100644 --- a/methods/predictors.py +++ b/methods/predictors.py @@ -9,26 +9,6 @@ from omegaconf import OmegaConf import torch.nn.functional as F -class Predictor(nn.Module): - def __init__(self, input_ch=32, num_classes=8): - super(Predictor, self).__init__() - self.pred_conv1 = nn.Conv2d(input_ch, input_ch, kernel_size=3, - stride=1, padding=1) - self.pred_bn1 = nn.BatchNorm2d(input_ch) - self.relu = nn.ReLU(inplace=True) - self.pred_conv2 = nn.Conv2d(input_ch, num_classes, kernel_size=3, - stride=1, padding=1) - self.softmax = nn.Softmax(dim=1) - - def forward(self, x): - x = self.pred_conv1(x) - x = self.pred_bn1(x) - x = self.relu(x) - x = self.pred_conv2(x) - px = self.softmax(x) - - return x,px - class MLP(nn.Module): def __init__(self, cfg): super(MLP, self).__init__() @@ -57,6 +37,9 @@ def forward(self, x): return h class MPLZS(MLP): + """ + MLP initialized with CLIP text embeddings + """ def __init__(self, cfg, text_embeddings): super(MPLZS, self).__init__(cfg) assert self.num_layers == 1, 'Only one layer supported' @@ -117,7 +100,9 @@ def _convert_weights_to_fp16(l): model.apply(_convert_weights_to_fp16) class CLIPFinetune(nn.Module): - + """ + Finetune the CLIP backbone. This theoretically should be usable.... + """ def __init__(self, clip_model, num_classes=8): super(CLIPFinetune, self).__init__() convert_weights(clip_model)