diff --git a/info_models.py b/info_models.py index 40424cd..c62ba43 100644 --- a/info_models.py +++ b/info_models.py @@ -19,7 +19,7 @@ def forward(self, global_enc, local_enc): global_enc = global_enc.unsqueeze(1) # (batch, 1, global_size) global_enc = global_enc.expand(-1, local_enc.size(1), -1) # (batch, seq_len, global_size) # (batch, seq_len, global_size) * (batch, seq_len, local_size) -> (batch, seq_len, 1) - scores = self.bilinear(global_enc.contiguous(), local_enc.contiguous()) + scores = self.bilinear(global_enc.contiguous(), local_enc.contiguous()) # critic function return scores @@ -36,12 +36,16 @@ class SquadDIMLoss(nn.Module): Deep infomax loss for SQuAD question answering dataset. As the difference between GC and LC only lies in whether we do summarization over x, this class can be used as both GC and LC. + Instead of using BCELoss, we use the SMILE lower bound to estimate mutual information. + In our case, + `MI(X,Y) >= E_p[g(X,Y)] - (1/2)*log(E_N[exp(g(\bar{X},Y))]) - (1/2)*log(E_N[exp(g(X,\bar{Y}))])` ''' def __init__(self, feature_size): super(SquadDIMLoss, self).__init__() self.discriminator = SquadDiscriminator(feature_size) self.summarize = Summarize() - self.bce_loss = nn.BCEWithLogitsLoss() + # self.bce_loss = nn.BCEWithLogitsLoss() + self.smile_lower_bound = smile_lower_bound # store the function; it requires the score matrix, so it must not be called here self.dropout = nn.Dropout(0.5) @@ -51,22 +55,23 @@ def forward(self, x_enc, x_fake, y_enc, y_fake, do_summarize=False): global_enc, local_enc: (batch, seq, dim) global_fake, local_fake: (batch, seq, dim) ''' + # Compute g(x, y) if do_summarize: x_enc = self.summarize(x_enc) x_enc = self.dropout(x_enc) y_enc = self.dropout(y_enc) - logits = self.discriminator(x_enc, y_enc) + logits = self.discriminator(x_enc, y_enc) batch_size1, n_seq1 = y_enc.size(0), y_enc.size(1) labels = torch.ones(batch_size1, n_seq1) - # Compute 1 - g(x, y^(\bar)) - y_fake = self.dropout(y_fake) + # Compute g(x, y^(\bar)) + y_fake = self.dropout(y_fake) _logits = self.discriminator(x_enc, y_fake) 
batch_size2, n_seq2 = y_fake.size(0), y_fake.size(1) _labels = torch.zeros(batch_size2, n_seq2) - logits, labels = torch.cat((logits, _logits), dim=1), torch.cat((labels, _labels), dim=1) + logits = torch.cat((logits, _logits), dim=1) # Compute 1 - g(x^(\bar), y) if do_summarize: @@ -75,7 +80,7 @@ def forward(self, x_enc, x_fake, y_enc, y_fake, do_summarize=False): _logits = self.discriminator(x_fake, y_enc) _labels = torch.zeros(batch_size1, n_seq1) - logits, labels = torch.cat((logits, _logits), dim=1), torch.cat((labels, _labels), dim=1) + logits = torch.cat((logits, _logits), dim=1) loss = self.bce_loss(logits.squeeze(2), labels.cuda()) diff --git a/log/events.out.tfevents.1620625258.cocoa.81980.0 b/log/events.out.tfevents.1620625258.cocoa.81980.0 new file mode 100644 index 0000000..cb60713 Binary files /dev/null and b/log/events.out.tfevents.1620625258.cocoa.81980.0 differ diff --git a/log/events.out.tfevents.1620628951.cocoa.37161.0 b/log/events.out.tfevents.1620628951.cocoa.37161.0 new file mode 100644 index 0000000..1f20b72 Binary files /dev/null and b/log/events.out.tfevents.1620628951.cocoa.37161.0 differ diff --git a/log/events.out.tfevents.1620629064.cocoa.39884.0 b/log/events.out.tfevents.1620629064.cocoa.39884.0 new file mode 100644 index 0000000..99fc706 Binary files /dev/null and b/log/events.out.tfevents.1620629064.cocoa.39884.0 differ diff --git a/log/events.out.tfevents.1620629321.cocoa.45909.0 b/log/events.out.tfevents.1620629321.cocoa.45909.0 new file mode 100644 index 0000000..a6259cd Binary files /dev/null and b/log/events.out.tfevents.1620629321.cocoa.45909.0 differ diff --git a/log/events.out.tfevents.1620629492.cocoa.49984.0 b/log/events.out.tfevents.1620629492.cocoa.49984.0 new file mode 100644 index 0000000..bf02a0d Binary files /dev/null and b/log/events.out.tfevents.1620629492.cocoa.49984.0 differ diff --git a/log/events.out.tfevents.1620791617.cocoa.4135.0 b/log/events.out.tfevents.1620791617.cocoa.4135.0 new file mode 100644 index 
0000000..53c4f92 Binary files /dev/null and b/log/events.out.tfevents.1620791617.cocoa.4135.0 differ diff --git a/log/events.out.tfevents.1620793983.cocoa.59139.0 b/log/events.out.tfevents.1620793983.cocoa.59139.0 new file mode 100644 index 0000000..2a0d407 Binary files /dev/null and b/log/events.out.tfevents.1620793983.cocoa.59139.0 differ diff --git a/log/events.out.tfevents.1620795608.cocoa.97024.0 b/log/events.out.tfevents.1620795608.cocoa.97024.0 new file mode 100644 index 0000000..0e3c69b Binary files /dev/null and b/log/events.out.tfevents.1620795608.cocoa.97024.0 differ diff --git a/log/events.out.tfevents.1620796058.cocoa.107605.0 b/log/events.out.tfevents.1620796058.cocoa.107605.0 new file mode 100644 index 0000000..06da87f Binary files /dev/null and b/log/events.out.tfevents.1620796058.cocoa.107605.0 differ diff --git a/log/events.out.tfevents.1620796134.cocoa.109471.0 b/log/events.out.tfevents.1620796134.cocoa.109471.0 new file mode 100644 index 0000000..75e437c Binary files /dev/null and b/log/events.out.tfevents.1620796134.cocoa.109471.0 differ diff --git a/log/events.out.tfevents.1620796503.cocoa.118138.0 b/log/events.out.tfevents.1620796503.cocoa.118138.0 new file mode 100644 index 0000000..9a44336 Binary files /dev/null and b/log/events.out.tfevents.1620796503.cocoa.118138.0 differ diff --git a/log/events.out.tfevents.1620796732.cocoa.123577.0 b/log/events.out.tfevents.1620796732.cocoa.123577.0 new file mode 100644 index 0000000..f952f64 Binary files /dev/null and b/log/events.out.tfevents.1620796732.cocoa.123577.0 differ diff --git a/log/events.out.tfevents.1620797076.cocoa.1073.0 b/log/events.out.tfevents.1620797076.cocoa.1073.0 new file mode 100644 index 0000000..89daf51 Binary files /dev/null and b/log/events.out.tfevents.1620797076.cocoa.1073.0 differ diff --git a/log/events.out.tfevents.1620797577.cocoa.12915.0 b/log/events.out.tfevents.1620797577.cocoa.12915.0 new file mode 100644 index 0000000..9ad7ab3 Binary files /dev/null and 
b/log/events.out.tfevents.1620797577.cocoa.12915.0 differ diff --git a/log/events.out.tfevents.1620797645.cocoa.14671.0 b/log/events.out.tfevents.1620797645.cocoa.14671.0 new file mode 100644 index 0000000..962ef03 Binary files /dev/null and b/log/events.out.tfevents.1620797645.cocoa.14671.0 differ diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index c4ba710..8b53bcf 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -709,7 +709,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_al # positions we want to attend and -10000.0 for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 embedding_output = self.embeddings(input_ids, token_type_ids) @@ -776,7 +776,8 @@ def __init__(self, config): # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version # self.dropout = nn.Dropout(config.hidden_dropout_prob) self.qa_outputs = nn.Linear(config.hidden_size, 2) - + + #### insert smile-mi-estimator? (#rsk) self.global_infomax = SquadDIMLoss(config.hidden_size) # GC self.local_infomax = SquadDIMLoss(config.hidden_size) # LC @@ -818,7 +819,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_pos extend_end = min(sequence_output.size(1), end_positions[b_idx] + 5) ans_seq = sequence_output[b_idx, extend_start : extend_end, :] # (seq, hidden_size) ans_enc.append(ans_seq.unsqueeze(0)) - assert len(ans_enc) == len(context_enc) + assert len(ans_enc) == len(context_enc) # generate fake examples by shifting one index. 
context_fake = [context_enc[-1]] + context_enc[:-1] @@ -831,24 +832,27 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_pos c_enc, c_fake = context_enc[b_idx], context_fake[b_idx] a_enc, a_fake = ans_enc[b_idx], ans_fake[b_idx] - ## Compute GC ## + ## Compute GC ## Change to smile (#rsk) global_loss = global_loss + self.global_infomax(a_enc, a_fake, c_enc, c_fake, do_summarize=True) - ## Compute LC ## + ## Compute LC ## Change to smile (#rsk) # sample one rand_idx = np.random.randint(a_enc.size(1)) a_enc_word = a_enc[0, rand_idx, :] rand_idx = np.random.randint(a_fake.size(1)) a_enc_fake = a_fake[0, rand_idx, :] local_loss = local_loss + self.local_infomax(a_enc_word.unsqueeze(0), a_enc_fake.unsqueeze(0), a_enc, a_fake, do_summarize=False) - + + # Info loss info_loss = (0.5 * global_loss + local_loss) / len(ans_enc) info_loss = 0.25 * info_loss + # total loss loss_fct = CrossEntropyLoss(ignore_index=ignored_index) start_loss = loss_fct(start_logits, start_positions) end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss + info_loss) / 3 + #print('total_loss',total_loss,'info_loss',info_loss) return total_loss, info_loss else: return start_logits, end_logits, sequence_output diff --git a/run_squad.py b/run_squad.py index 7539fa1..c6223f4 100644 --- a/run_squad.py +++ b/run_squad.py @@ -990,19 +990,21 @@ def main(): if n_gpu == 1: batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self input_ids, input_mask, segment_ids, start_positions, end_positions = batch - + loss, info_loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions) + #print('total_loss',loss,'info_loss',info_loss) if n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. 
+ info_loss = info_loss.mean() if args.gradient_accumulation_steps > 1: loss = loss / args.gradient_accumulation_steps - if args.fp16: optimizer.backward(loss) else: loss.backward() if step % 50 == 0: - tqdm.write('Total Loss:{} Info Loss: {}'.format(loss.item(), info_loss.item())) + #print('total_loss',loss,'info_loss',info_loss) + tqdm.write('Total Loss:{0} Info Loss: {1}'.format(loss.item(), info_loss.item())) if (step + 1) % args.gradient_accumulation_steps == 0: if args.fp16: # modify learning rate with special warm up BERT uses diff --git a/smile/estimators.py b/smile/estimators.py new file mode 100644 index 0000000..2b6d4c9 --- /dev/null +++ b/smile/estimators.py @@ -0,0 +1,164 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def logmeanexp_diag(x, device='cuda'): + """Compute logmeanexp over the diagonal elements of x.""" + batch_size = x.size(0) + + logsumexp = torch.logsumexp(x.diag(), dim=(0,)) + num_elem = batch_size + + return logsumexp - torch.log(torch.tensor(num_elem).float()).to(device) + + +def logmeanexp_nodiag(x, dim=None, device='cuda'): + batch_size = x.size(0) + if dim is None: + dim = (0, 1) + + logsumexp = torch.logsumexp( + x - torch.diag(np.inf * torch.ones(batch_size).to(device)), dim=dim) + + try: + if len(dim) == 1: + num_elem = batch_size - 1. + else: + num_elem = batch_size * (batch_size - 1.) + except ValueError: + num_elem = batch_size - 1 + return logsumexp - torch.log(torch.tensor(num_elem)).to(device) + + +def tuba_lower_bound(scores, log_baseline=None): + if log_baseline is not None: + scores -= log_baseline[:, None] + + # First term is an expectation over samples from the joint, + # which are the diagonal elmements of the scores matrix. + joint_term = scores.diag().mean() + + # Second term is an expectation over samples from the marginal, + # which are the off-diagonal elements of the scores matrix. + marg_term = logmeanexp_nodiag(scores).exp() + return 1. 
+ joint_term - marg_term + + +def nwj_lower_bound(scores): + return tuba_lower_bound(scores - 1.) + + +def infonce_lower_bound(scores): + nll = scores.diag().mean() - scores.logsumexp(dim=1) + # Alternative implementation: + # nll = -tf.nn.sparse_softmax_cross_entropy_with_logits(logits=scores, labels=tf.range(batch_size)) + mi = torch.tensor(scores.size(0)).float().log() + nll + mi = mi.mean() + return mi + + +def js_fgan_lower_bound(f): + """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016).""" + f_diag = f.diag() + first_term = -F.softplus(-f_diag).mean() + n = f.size(0) + second_term = (torch.sum(F.softplus(f)) - + torch.sum(F.softplus(f_diag))) / (n * (n - 1.)) + return first_term - second_term + + +def js_lower_bound(f): + """Obtain density ratio from JS lower bound then output MI estimate from NWJ bound.""" + nwj = nwj_lower_bound(f) + js = js_fgan_lower_bound(f) + + with torch.no_grad(): + nwj_js = nwj - js + + return js + nwj_js + + +def dv_upper_lower_bound(f): + """ + Donsker-Varadhan lower bound, but upper bounded by using log outside. + Similar to MINE, but did not involve the term for moving averages. + """ + first_term = f.diag().mean() + second_term = logmeanexp_nodiag(f) + + return first_term - second_term + + +def mine_lower_bound(f, buffer=None, momentum=0.9): + """ + MINE lower bound based on DV inequality. 
+ """ + if buffer is None: + buffer = torch.tensor(1.0).cuda() + first_term = f.diag().mean() + + buffer_update = logmeanexp_nodiag(f).exp() + with torch.no_grad(): + second_term = logmeanexp_nodiag(f) + buffer_new = buffer * momentum + buffer_update * (1 - momentum) + buffer_new = torch.clamp(buffer_new, min=1e-4) + third_term_no_grad = buffer_update / buffer_new + + third_term_grad = buffer_update / buffer_new + + return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update + + +def smile_lower_bound(f, clip=None): + if clip is not None: + f_ = torch.clamp(f, -clip, clip) + else: + f_ = f + z = logmeanexp_nodiag(f_, dim=(0, 1)) + dv = f.diag().mean() - z + + js = js_fgan_lower_bound(f) + + with torch.no_grad(): + dv_js = dv - js + + return js + dv_js + + +def estimate_mutual_information(estimator, x, y, critic_fn, + baseline_fn=None, alpha_logit=None, **kwargs): + """Estimate variational lower bounds on mutual information. + + Args: + estimator: string specifying estimator, one of: + 'nwj', 'infonce', 'tuba', 'js', 'interpolated' + x: [batch_size, dim_x] Tensor + y: [batch_size, dim_y] Tensor + critic_fn: callable that takes x and y as input and outputs critic scores + output shape is a [batch_size, batch_size] matrix + baseline_fn (optional): callable that takes y as input + outputs a [batch_size] or [batch_size, 1] vector + alpha_logit (optional): logit(alpha) for interpolated bound + + Returns: + scalar estimate of mutual information + """ + x, y = x.cuda(), y.cuda() + scores = critic_fn(x, y) + if baseline_fn is not None: + # Some baselines' output is (batch_size, 1) which we remove here. 
+ log_baseline = torch.squeeze(baseline_fn(y)) + if estimator == 'infonce': + mi = infonce_lower_bound(scores) + elif estimator == 'nwj': + mi = nwj_lower_bound(scores) + elif estimator == 'tuba': + mi = tuba_lower_bound(scores, log_baseline if baseline_fn is not None else None) + elif estimator == 'js': + mi = js_lower_bound(scores) + elif estimator == 'smile': + mi = smile_lower_bound(scores, **kwargs) + elif estimator == 'dv': + mi = dv_upper_lower_bound(scores) + return mi