Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@ Official Implementation of [LADS (Latent Augmentation using Domain descriptionS)

*WARNING: this is still WIP, please raise an issue if you run into any bugs.*

```
@article{dunlap2023lads,
title={Using Language to Entend to Unseen Domains},
author = {Dunlap, Lisa and Mohri, Clara and Guillory, Devin and Zhang, Han and Darrell, Trevor and Gonzalez, Joseph E. and Raghunathan, Aditi and Rohrbach, Anja},
journal={International Conference on Learning Representations (ICLR)},
year={2023}
}
```

## Getting started

1. Install the dependencies for our code using Conda. You may need to adjust the environment YAML file depending on your setup.
Expand All @@ -26,7 +35,7 @@ python main.py --config configs/Waterbirds/mlp.yaml

Datasets supported are in the [helpers folder](./helpers/data_helpers.py). Currently they are:
* Waterbirds (100% and 95%) [our specific split](https://drive.google.com/file/d/1zJpQYGEt1SuwitlNfE06TFyLaWX-st1k/view) [code to generate data](https://github.com/kohpangwei/group_DRO)
* ColoredMNIST (LNTL version and simplified version) NOTEBOOK COMING SOON
* ColoredMNIST (LNTL version and simplified version) [Paper Dataset](https://drive.google.com/file/d/1GomKfufFrXIRAFJNedUBCDwHW2X9NTP-/view?usp=share_link)
* DomainNet (the version used in the paper is `DATA.DATASET=DomainNetMini`) [full dataset](http://ai.bu.edu/DomainNet/)
* CUB Paintings [photos dataset](https://www.vision.caltech.edu/datasets/cub_200_2011/) [paintings dataset](https://github.com/thuml/PAN)
* OfficeHome COMING SOON
Expand Down
85 changes: 0 additions & 85 deletions configs/CUB/dpl.yaml

This file was deleted.

3 changes: 2 additions & 1 deletion configs/CUB/lads.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ DATA:
DATASET: "CUB"
LOAD_CACHED: True
BATCH_SIZE: 256
SAVE_PATH: 'embeddings/CUB/recomputed.pth'

METHOD:
MODEL:
Expand All @@ -27,7 +28,7 @@ METHOD:

AUGMENTATION:
MODEL:
LR: 0.1
LR: 0.001
WEIGHT_DECAY: 0.05
NUM_LAYERS: 1
GENERIC: False
Expand Down
22 changes: 0 additions & 22 deletions configs/DomainNet/test.yaml

This file was deleted.

38 changes: 0 additions & 38 deletions configs/DomainNet/test_aug.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion configs/Waterbirds/lads.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ EXP:
TEXT_PROMPTS: [['a photo of a {} on forest.'], ['a photo of a {} on water.']]
NEUTRAL_TEXT_PROMPTS: []
AUGMENTATION: 'BiasDirectional'
EPOCHS: 200
EPOCHS: 400
CHECKPOINT_VAL: True
ENSAMBLE: False

Expand Down
2 changes: 1 addition & 1 deletion configs/Waterbirds/mlp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ EXP:
SEED: 0
TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.']
NEUTRAL_TEXT_PROMPTS: ['a photo of a {}.']
EPOCHS: 200
EPOCHS: 400
BIASED_VAL:

DATA:
Expand Down
2 changes: 1 addition & 1 deletion configs/Waterbirds/mlpzs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ EXP:
SEED: 0
TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.']
NEUTRAL_TEXT_PROMPTS: ['a photo of a {}.']
EPOCHS: 200
EPOCHS: 400
TEMPLATES: 'waterbirds_templates2'

DATA:
Expand Down
2 changes: 1 addition & 1 deletion configs/Waterbirds/new_lads.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ EXP:
TEXT_PROMPTS: [['a photo of a {} on forest.'], ['a photo of a {} on water.']]
NEUTRAL_TEXT_PROMPTS: ['a photo of a {} on forest.', 'a photo of a {} on water.']
AUGMENTATION: 'LADSBias'
EPOCHS: 200
EPOCHS: 400
CHECKPOINT_VAL: True
ENSAMBLE: False

Expand Down
1 change: 1 addition & 0 deletions configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ DATA:
UPWEIGHT_DOMAINS: False
UPWEIGHT_CLASSES: True
MODEL_DIM: 1024
ROOT: /shared/lisabdunlap/data

METHOD:
MODEL:
Expand Down
2 changes: 1 addition & 1 deletion helpers/data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,6 @@ def get_domain(dataset_name):
def get_class(dataset_name):
return DATASET_CLASSES[dataset_name]

def get_cache_file(dataset_name, model_name='ViT-B/32', biased_val=True, model_type='clip'):
def get_classes(dataset_name, model_name='ViT-B/32', model_type='clip'):
assert dataset_name in DATASET_PATHS[model_type][model_name].keys(), f"{dataset_name} is not cached or not added to the DATASET_PATHS dict in helpers/dataset_helpers.py"
return DATASET_PATHS[model_type][model_name][dataset_name], DATASET_CLASSES[dataset_name], DATASET_DOMAINS[dataset_name]
13 changes: 0 additions & 13 deletions helpers/data_paths.py

This file was deleted.

17 changes: 13 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import random
import omegaconf
from omegaconf import OmegaConf
import collections

import helpers.data_helpers as dh
import methods.clip_transformations as CLIPTransformations
Expand All @@ -37,7 +38,7 @@
args.yaml = flags.config

assert args.EXP.ADVICE_METHOD != 'CNN', "main.py not for CNN baseline, use train.py"
assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py"
# assert args.EXP.ADVICE_METHOD != 'CLIPZS', "main.py not for CLIP zero-shot, use clip_zs.py"

if args.EXP.WANDB_SILENT:
os.environ['WANDB_SILENT']="true"
Expand Down Expand Up @@ -66,13 +67,14 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
DATASET_NAME = args.DATA.DATASET

# load data
if args.DATA.LOAD_CACHED:
print(args.DATA.LOAD_CACHED)
cache_file = f"{args.DATA.SAVE_PATH}/{args.DATA.DATASET}/{args.EXP.IMAGE_FEATURES}_{args.EXP.CLIP_PRETRAINED_DATASET}_{args.EXP.CLIP_MODEL.replace('/','_')}.pt"
if os.path.exists(cache_file):
if args.EXP.IMAGE_FEATURES == 'clip' or args.EXP.IMAGE_FEATURES == 'openclip':
model_name = args.EXP.CLIP_MODEL
else:
model_name = args.EXP.IMAGE_FEATURES
cache_file, dataset_classes, dataset_domains = dh.get_cache_file(DATASET_NAME, model_name, args.EXP.BIASED_VAL, args.EXP.IMAGE_FEATURES)
print(f"loading embeddings from {cache_file}")
assert os.path.exists(cache_file), f"{cache_file} does not exist. To compute embeddings, set DATA.LOAD_CACHED=False"
data = torch.load(cache_file)
train_features, train_labels, train_groups, train_domains, train_filenames = data['train_features'], data['train_labels'], data['train_groups'], data['train_domains'], data['train_filenames']
Expand All @@ -82,6 +84,11 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
if args.DATA.DATASET != 'ColoredMNISTBinary':
val_features, val_labels, val_groups, val_domains, val_filenames = data['val_features'][::2], data['val_labels'][::2], data['val_groups'][::2], data['val_domains'][::2], data['val_filenames'][::2]
test_features, test_labels, test_groups, test_domains, test_filenames = np.concatenate((data['test_features'], data['val_features'][1::2])), np.concatenate((data['test_labels'], data['val_labels'][1::2])), np.concatenate((data['test_groups'], data['val_groups'][1::2])), np.concatenate((data['test_domains'], data['val_domains'][1::2])), np.concatenate((data['test_filenames'], data['val_filenames'][1::2]))

# print out group counts
print("Train groups:", collections.Counter(train_groups))
print("Val groups:", collections.Counter(val_groups))
print("Test groups:", collections.Counter(test_groups))
if args.METHOD.NORMALIZE:
train_features /= np.linalg.norm(train_features, axis=-1, keepdims=True)
val_features /= np.linalg.norm(val_features, axis=-1, keepdims=True)
Expand Down Expand Up @@ -163,7 +170,9 @@ def flatten_config(dic, running_key=None, flattened_dict={}):
print("Training set augmented!")
print("SIZE of embeddings ", train_features.shape, train_domains.shape)

if args.EXP.LOG_NN:

# Logs the Nearest Neighbors in the Extended Domain
if args.EXP.LOG_NN and args.EXP.ADVICE_METHOD != 'CLIPZS':
features, labels, groups, domains, filenames = np.concatenate([old_val_features, old_test_features]), np.concatenate([old_val_labels, old_test_labels]), np.concatenate([old_val_groups, old_test_groups]), np.concatenate([old_val_domains, old_test_domains]), np.concatenate([old_val_filenames, old_test_filenames])
# features, labels, groups, domains, filenames = old_test_features, old_test_labels, old_test_groups, old_test_domains, old_test_filenames
if len(np.unique(train_domains)) > 1:
Expand Down
Loading