diff --git a/README.md b/README.md index 6b85bff..3c813c9 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ The [currently available algorithms](./subpopbench/learning/algorithms.py) are: * Label-Distribution-Aware Margin Loss (**LDAM**, [Cao et al., 2019](https://arxiv.org/abs/1906.07413)) * Balanced Softmax (**BSoftmax**, [Ren et al., 2020](https://arxiv.org/abs/2007.10740)) * Classifier Re-Training (**CRT**, [Kang et al., 2020](https://arxiv.org/abs/1910.09217)) +* Uniform Risk Minimization (**URM**, [Krishnamachari et al., 2024](https://openreview.net/forum?id=PgLbS5yp8n)) Send us a PR to add your algorithm! Our implementations use the hyper-parameter grids [described here](./subpopbench/hparams_registry.py). diff --git a/subpopbench/hparams_registry.py b/subpopbench/hparams_registry.py index ff925e8..b208c3b 100644 --- a/subpopbench/hparams_registry.py +++ b/subpopbench/hparams_registry.py @@ -89,6 +89,23 @@ def _hparam(name, default_val, random_val_fn): _hparam('stage1_model', 'model.pkl', lambda r: 'model.pkl') _hparam('dfr_reg', .1, lambda r: 10**r.uniform(-2, 0.5)) + elif algorithm == 'URM': + _hparam('urm_lambda', 0.1, lambda r: float(r.uniform(0,0.2))) + + _hparam('urm_discriminator_hidden_layers', 1, lambda r: int(r.choice([1,2,3]))) + _hparam('urm_generator_output', 'tanh', lambda r: str(r.choice(['tanh', 'relu']))) + _hparam('urm_discriminator_update_freq', 1, lambda r: int(r.choice([1]))) + + if dataset in IMAGE_DATASETS + TABULAR_DATASET: + _hparam('urm_discriminator_lr', 1e-3, lambda r: 10**r.uniform(-5, -3)) + else: + _hparam('urm_discriminator_lr', 1e-5, lambda r: 10**r.uniform(-6, -5)) + + if dataset in TEXT_DATASETS: + _hparam('urm_discriminator_optimizer', 'adamw', lambda r: str(r.choice(['adamw']))) + else: + _hparam('urm_discriminator_optimizer', 'sgd', lambda r: str(r.choice(['sgd']))) + # Dataset-and-algorithm-specific hparam definitions # Each block of code below corresponds to exactly one hparam. Avoid nested conditionals diff --git a/subpopbench/learning/algorithms.py b/subpopbench/learning/algorithms.py index 24613a1..9eae462 100644 --- a/subpopbench/learning/algorithms.py +++ b/subpopbench/learning/algorithms.py @@ -36,7 +36,8 @@ 'BSoftmax', 'CRT', 'ReWeightCRT', - 'VanillaCRT' + 'VanillaCRT', + 'URM' ] @@ -174,6 +175,187 @@ def return_feats(self, x): def predict(self, x): return self.network(x) +class URM(ERM): + def __init__(self, data_type, input_shape, num_classes, num_attributes, num_examples, hparams, grp_sizes=None): + ERM.__init__(self, data_type, input_shape, num_classes, num_attributes, num_examples, hparams, grp_sizes=grp_sizes) + + self._setup_adversarial_net() + + def _modify_generator_output(self): + """ + Modifies the output activation of the encoder/featurizer + """ + print('--> Modifying encoder output:', self.hparams['urm_generator_output']) + + if self.hparams['urm_generator_output'] == 'tanh': + if self.data_type == 'images' and self.hparams['image_arch'] == 'resnet_sup_in1k': + self.featurizer.network.layer4[2].relu = nn.Tanh() + + elif self.data_type == 'text' and self.hparams['text_arch'] == 'bert-base-uncased': + # self.featurizer.activation = nn.Tanh() + # # it's already Tanh, no change needed + assert type(self.featurizer.model.pooler.activation) is torch.nn.modules.activation.Tanh + elif self.data_type == 'tabular': + self.featurizer.activation = nn.Tanh() + else: + raise Exception('unimplemented data_type: %s' % self.data_type) + + elif self.hparams['urm_generator_output'] == 'relu': + if self.data_type == 'images' and self.hparams['image_arch'] == 'resnet_sup_in1k': + pass # unchanged + elif self.data_type == 'text' and self.hparams['text_arch'] == 'bert-base-uncased': + # self.featurizer.activation = nn.ReLU() + self.featurizer.model.pooler.activation = nn.ReLU() + elif self.data_type == 'tabular': + self.featurizer.activation = nn.ReLU() + else: + raise Exception('unimplemented data_type: %s' % self.data_type) + + else: + raise Exception('unrecognized output activation: %s' % self.hparams['urm_generator_output']) + + # define min and max of output values + if self.hparams['urm_generator_output'] == 'tanh': + self.a, self.b = -1,1 + elif self.hparams['urm_generator_output'] == 'identity': + self.a, self.b = 0,1 + # a,b = self.hparams['urm_noise_range'][0], self.hparams['urm_noise_range'][1] + elif self.hparams['urm_generator_output'] == 'relu': + self.a, self.b = 0,1 + # self.a,self.b = self.hparams['urm_noise_range'][0], self.hparams['urm_noise_range'][1] + elif self.hparams['urm_generator_output'] in ['sigmoid']: + self.a, self.b = 0,1 + elif self.hparams['urm_generator_output'] in ['brelu']: + self.a, self.b = 0,3 + else: + raise Exception('unrecognized output activation: %s' % self.hparams['urm_generator_output']) + + def _setup_adversarial_net(self): + print('--> Initializing discriminator <--') + self.discriminator = self._init_discriminator() + + self.discriminator_loss = torch.nn.BCEWithLogitsLoss(reduction="mean") # apply on logit + + # featurizer optimized by self.optimizer only + if self.hparams["urm_discriminator_optimizer"] == 'sgd': + self.discriminator_optimizer = torch.optim.SGD(self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr'], \ + weight_decay=self.hparams['weight_decay'], momentum=0.9) + elif self.hparams["urm_discriminator_optimizer"] == 'adam': + self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr']) + elif self.hparams["urm_discriminator_optimizer"] == 'adamw': + self.discriminator_optimizer = torch.optim.AdamW(self.discriminator.parameters(), lr=self.hparams['urm_discriminator_lr'], weight_decay=self.hparams['weight_decay']) + else: + raise Exception('%s unimplemented' % self.hparams["urm_discriminator_optimizer"]) + + self._modify_generator_output() + self.sigmoid = nn.Sigmoid() # to compute discriminator acc. + + def _init_discriminator(self): + """ + 3 hidden layer MLP + """ + model = nn.Sequential() + + model.add_module("dense1", nn.Linear(self.featurizer.n_outputs, 100)) + model.add_module("act1", nn.LeakyReLU()) + + for _ in range(self.hparams['urm_discriminator_hidden_layers']): + model.add_module("dense%d" % (2+_), nn.Linear(100, 100)) + model.add_module("act2%d" % (2+_), nn.LeakyReLU()) + + model.add_module("output", nn.Linear(100, 1)) # model outputs logit, used with BCEWithLogitsLoss (numerically more stable) + + return model + + def _generate_noise(self, feats): + """ + If U is a random variable uniformly distributed on [0, 1), then (b-a)*U + a is uniformly distributed on [a, b). + """ + uniform_noise = torch.rand(feats.size(), dtype=feats.dtype, layout=feats.layout, device=feats.device) # U~[0,1] + n = ((self.b-self.a) * uniform_noise) + self.a # n ~ [a,b) + + return n + + def _generate_soft_labels(self, size, device, a ,b): + # returns size random numbers in [a,b] + uniform_noise = torch.rand(size, device=device) # U~[0,1] + return ((b-a) * uniform_noise) + a + + def get_accuracy(self, y_true, y_prob): + # y_prob is binary probability + assert y_true.ndim == 1 and y_true.size() == y_prob.size() + y_prob = y_prob > 0.5 + return (y_true == y_prob).sum().item() / y_true.size(0) + + def _update_discriminator(self, i, x, y, a, step, feats): + feats = feats.detach() # don't backbrop through encoder in this step + noise = self._generate_noise(feats) + + noise_logits = self.discriminator(noise) # (N,1) + feats_logits = self.discriminator(feats) # (N,1) + + # hard targets + hard_true_y = torch.tensor([1] * noise.shape[0], device=noise.device, dtype=noise.dtype) # [1,1...1] noise is true + hard_fake_y = torch.tensor([0] * feats.shape[0], device=feats.device, dtype=feats.dtype) # [0,0...0] feats are fake (generated) + + true_y = hard_true_y + fake_y = hard_fake_y + + noise_loss = self.discriminator_loss(noise_logits.squeeze(1), true_y) # pass logits to BCEWithLogitsLoss + feats_loss = self.discriminator_loss(feats_logits.squeeze(1), fake_y) # pass logits to BCEWithLogitsLoss + + d_loss = 1*noise_loss + self.hparams['urm_lambda']*feats_loss + + # update discriminator + self.discriminator_optimizer.zero_grad() + d_loss.backward() + self.discriminator_optimizer.step() + + def _compute_loss(self, i, x, y, a, step): + self.activations = {} # reset activations + + feats = self.return_feats(x) + + classifier_output = self.classifier(feats) + + # train generator/encoder to make discriminator classify feats as noise (label 1) + true_y = torch.tensor(feats.shape[0]*[1], device=feats.device, dtype=feats.dtype) + + g_logits = self.discriminator(feats) + g_loss = self.discriminator_loss(g_logits.squeeze(1), true_y) # apply BCEWithLogitsLoss to discriminator's logit output + + loss = ce_loss + self.hparams['urm_lambda']*g_loss + + return loss, feats + + def predict(self, x): + # for inference, used in eval_helper.py + return self.network(x) + + def update(self, minibatch, step): + all_i, all_x, all_y, all_a = minibatch + + loss, feats = self._compute_loss(all_i, all_x, all_y, all_a, step) + + self.optimizer.zero_grad() + + loss.backward() + if self.clip_grad: + torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0) + self.optimizer.step() + + if self.lr_scheduler is not None: + self.lr_scheduler.step() + + if self.data_type == "text": + self.network.zero_grad() + + # update discriminator after updating encoder-classifier (alternating updates) + if (step % self.hparams['urm_discriminator_update_freq'] == 0): + self._update_discriminator(all_i, all_x, all_y, all_a, step, feats) + + return {'loss': loss.item()} + class GroupDRO(ERM): """ diff --git a/subpopbench/models/networks.py b/subpopbench/models/networks.py index 4928271..88ae575 100644 --- a/subpopbench/models/networks.py +++ b/subpopbench/models/networks.py @@ -28,6 +28,8 @@ def __init__(self, n_inputs, n_outputs, hparams): self.output = nn.Linear(hparams['mlp_width'], n_outputs) self.n_outputs = n_outputs + self.activation = nn.Identity() # added for URM, does not affect other algorithms + def forward(self, x): x = self.input(x) x = self.dropout(x) @@ -37,6 +39,9 @@ def forward(self, x): x = self.dropout(x) x = F.relu(x) x = self.output(x) + + x = self.activation(x) # added for URM, does not affect other algorithms + return x @@ -191,6 +196,8 @@ def __init__(self, model, hparams): ) self.dropout = nn.Dropout(classifier_dropout) + self.activation = nn.Identity() # added for URM, does not affect other algorithms + def forward(self, x): kwargs = { 'input_ids': x[:, :, 0], @@ -199,11 +206,15 @@ def forward(self, x): if x.shape[-1] == 3: kwargs['token_type_ids'] = x[:, :, 2] output = self.model(**kwargs) + if hasattr(output, 'pooler_output'): - return self.dropout(output.pooler_output) + output = self.dropout(output.pooler_output) else: - return self.dropout(output.last_hidden_state[:, 0, :]) + output = self.dropout(output.last_hidden_state[:, 0, :]) + + output = self.activation(output) # added for URM, does not affect other algorithms + return output def replace_module_prefix(state_dict, prefix, replace_with=""): state_dict = {