Unified diff for main.py and methods/predictors.py, re-flowed from a
whitespace-collapsed copy. Line breaks follow the hunk headers; leading
indentation inside the hunks is inferred and must be checked against the
target files before applying (git apply --ignore-whitespace tolerates
whitespace mismatches in context lines).

main.py
  * Gate embedding loading on args.DATA.LOAD_CACHED and read/write the fixed
    cache file data/CUB/vit14_new.pth instead of the path built from
    DATA.SAVE_PATH, DATA.DATASET and the EXP.* settings.
  * Keep even-index validation rows as the validation split and append the
    odd-index rows to the test split.
    NOTE(review): the nesting level of these three statements was lost; they
    are placed at top level here so both branches get the same split. Confirm.
  * Review fix: the directory created before torch.save is now the directory
    of cache_file. The submitted change still created SAVE_PATH/DATASET, which
    is not where the file is written. The redundant re-assignment of
    cache_file before torch.save is dropped. Because of this the new-side
    line count is 44 and the post-image blob id on the main.py index line no
    longer matches the resulting file.

methods/predictors.py
  * Import SimpleTokenizer from clip.simple_tokenizer instead of
    CLIP.clip.simple_tokenizer.

diff --git a/main.py b/main.py
index 4681484..61ccde8 100644
--- a/main.py
+++ b/main.py
@@ -80,38 +80,44 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
 model = model.to(device)
 
 # # load data
-# if args.DATA.LOAD_CACHED:
-#     cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
-#     dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
-#     assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
-#     print(f"Loading cached embeddings from {cache_file}")
-#     train_features, train_labels, train_groups, train_domains, train_filenames, val_features, val_labels, val_groups, val_domains, val_filenames, test_features, test_labels, test_groups, test_domains, test_filenames = load_embeddings(cache_file, args.DATA.DATASET)
-cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
-dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
-if os.path.exists(cache_file):
+if args.DATA.LOAD_CACHED:
+    cache_file = "data/CUB/vit14_new.pth"
+    dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
+    assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
     print(f"Loading cached embeddings from {cache_file}")
     train_features, train_labels, train_groups, train_domains, train_filenames, val_features, val_labels, val_groups, val_domains, val_filenames, test_features, test_labels, test_groups, test_domains, test_filenames = load_embeddings(cache_file, args.DATA.DATASET)
 else:
-    print(f"Computing embeddings and saving to {cache_file}")
-    trainset, valset, testset = dh.get_dataset(DATASET_NAME, preprocess, biased_val=args.EXP.BIASED_VAL)
-    dataset_classes, dataset_domains = dh.get_class(DATASET_NAME), dh.get_domain(DATASET_NAME)
-    train_loader = torch.utils.data.DataLoader(trainset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)
-    val_loader = torch.utils.data.DataLoader(valset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
-    test_loader = torch.utils.data.DataLoader(testset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
-    train_features, train_labels, train_groups, train_domains, train_filenames = get_features(train_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
-    val_features, val_labels, val_groups, val_domains, val_filenames = get_features(val_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
-    test_features, test_labels, test_groups, test_domains, test_filenames = get_features(test_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
-    save_dict = {
-        "train_features": train_features, "train_labels": train_labels, "train_groups": train_groups, "train_domains": train_domains, "train_filenames": train_filenames,
-        "val_features": val_features, "val_labels": val_labels, "val_groups": val_groups, "val_domains": val_domains, "val_filenames": val_filenames,
-        "test_features": test_features, "test_labels": test_labels, "test_groups": test_groups, "test_domains": test_domains, "test_filenames": test_filenames,
-        "seed": args.EXP.SEED
-    }
-    if not os.path.exists(f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}"):
-        os.makedirs(f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}")
-    cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
-    torch.save(save_dict, cache_file)
-    print(f"Saved CLIP embeddings to {cache_file}")
+    cache_file = "data/CUB/vit14_new.pth"
+    dataset_classes, dataset_domains = dh.DATASET_CLASSES[args.DATA.DATASET], dh.DATASET_DOMAINS[args.DATA.DATASET]
+    if os.path.exists(cache_file):
+        print(f"Loading cached embeddings from {cache_file}")
+        train_features, train_labels, train_groups, train_domains, train_filenames, val_features, val_labels, val_groups, val_domains, val_filenames, test_features, test_labels, test_groups, test_domains, test_filenames = load_embeddings(cache_file, args.DATA.DATASET)
+    else:
+        print(f"Computing embeddings and saving to {cache_file}")
+        trainset, valset, testset = dh.get_dataset(DATASET_NAME, preprocess, biased_val=args.EXP.BIASED_VAL)
+        dataset_classes, dataset_domains = dh.get_class(DATASET_NAME), dh.get_domain(DATASET_NAME)
+        train_loader = torch.utils.data.DataLoader(trainset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=True)
+        val_loader = torch.utils.data.DataLoader(valset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
+        test_loader = torch.utils.data.DataLoader(testset, batch_size=cfg.DATA.BATCH_SIZE, shuffle=False)
+        train_features, train_labels, train_groups, train_domains, train_filenames = get_features(train_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
+        val_features, val_labels, val_groups, val_domains, val_filenames = get_features(val_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
+        test_features, test_labels, test_groups, test_domains, test_filenames = get_features(test_loader, model, device, model_type=args.EXP.IMAGE_FEATURES)
+        save_dict = {
+            "train_features": train_features, "train_labels": train_labels, "train_groups": train_groups, "train_domains": train_domains, "train_filenames": train_filenames,
+            "val_features": val_features, "val_labels": val_labels, "val_groups": val_groups, "val_domains": val_domains, "val_filenames": val_filenames,
+            "test_features": test_features, "test_labels": test_labels, "test_groups": test_groups, "test_domains": test_domains, "test_filenames": test_filenames,
+            "seed": args.EXP.SEED
+        }
+        # Create the directory that will actually hold cache_file; the
+        # SAVE_PATH/DATASET directory is no longer where the file is written.
+        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
+        torch.save(save_dict, cache_file)
+        print(f"Saved CLIP embeddings to {cache_file}")
+# Even-index validation rows stay in the validation split; odd-index rows are
+# appended to the test split.
+old_val_features, old_val_labels, old_val_groups, old_val_domains, old_val_filenames = val_features, val_labels, val_groups, val_domains, val_filenames
+val_features, val_labels, val_groups, val_domains, val_filenames = val_features[::2], val_labels[::2], val_groups[::2], val_domains[::2], val_filenames[::2]
+test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((test_features, old_val_features[1::2])), np.concatenate((test_labels, old_val_labels[1::2])), np.concatenate((test_groups, old_val_groups[1::2])), np.concatenate((test_domains, old_val_domains[1::2])), np.concatenate((test_filenames, old_val_filenames[1::2]))
 
 if args.METHOD.NORMALIZE:
     train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True)
diff --git a/methods/predictors.py b/methods/predictors.py
index 7a68335..6f322b0 100644
--- a/methods/predictors.py
+++ b/methods/predictors.py
@@ -518,7 +518,7 @@ def predict(self, image_feature):
     # def predict(self, img_embeddings, label=None):
     #     return self.prompt_learner.predict(img_embeddings.cuda())
 
-from CLIP.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
+from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
 from collections import OrderedDict
 
 _tokenizer = _Tokenizer()