Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions info_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def forward(self, global_enc, local_enc):
global_enc = global_enc.unsqueeze(1) # (batch, 1, global_size)
global_enc = global_enc.expand(-1, local_enc.size(1), -1) # (batch, seq_len, global_size)
# (batch, seq_len, global_size) * (batch, seq_len, local_size) -> (batch, seq_len, 1)
scores = self.bilinear(global_enc.contiguous(), local_enc.contiguous())
scores = self.bilinear(global_enc.contiguous(), local_enc.contiguous()) # critic function

return scores

Expand All @@ -36,12 +36,16 @@ class SquadDIMLoss(nn.Module):
Deep infomax loss for SQuAD question answering dataset.
As the difference between GC and LC only lies in whether we do summarization over x,
this class can be used as both GC and LC.
Instead of using BCELoss, we use SMILE lower bound to estimate mutual information
In our case,
`MI(X,Y) >= E_p[g(X,Y)] - (1/2)*log(E_N[exp(g(\bar{X},Y))]) - (1/2)*log(E_N[exp(g(X,\bar{Y}))])`
'''
def __init__(self, feature_size):
super(SquadDIMLoss, self).__init__()
self.discriminator = SquadDiscriminator(feature_size)
self.summarize = Summarize()
self.bce_loss = nn.BCEWithLogitsLoss()
# self.bce_loss = nn.BCEWithLogitsLoss()
self.smile_lower_bound = smile_lower_bound()

self.dropout = nn.Dropout(0.5)

Expand All @@ -51,22 +55,23 @@ def forward(self, x_enc, x_fake, y_enc, y_fake, do_summarize=False):
global_enc, local_enc: (batch, seq, dim)
global_fake, local_fake: (batch, seq, dim)
'''

# Compute g(x, y)
if do_summarize:
x_enc = self.summarize(x_enc)
x_enc = self.dropout(x_enc)
y_enc = self.dropout(y_enc)
logits = self.discriminator(x_enc, y_enc)
logits = self.discriminator(x_enc, y_enc)
batch_size1, n_seq1 = y_enc.size(0), y_enc.size(1)
labels = torch.ones(batch_size1, n_seq1)

# Compute 1 - g(x, y^(\bar))
y_fake = self.dropout(y_fake)
# Compute g(x, y^(\bar))
y_fake = self.dropout(y_fake)
_logits = self.discriminator(x_enc, y_fake)
batch_size2, n_seq2 = y_fake.size(0), y_fake.size(1)
_labels = torch.zeros(batch_size2, n_seq2)

logits, labels = torch.cat((logits, _logits), dim=1), torch.cat((labels, _labels), dim=1)
logits = torch.cat((logits, _logits), dim=1)

# Compute 1 - g(x^(\bar), y)
if do_summarize:
Expand All @@ -75,7 +80,7 @@ def forward(self, x_enc, x_fake, y_enc, y_fake, do_summarize=False):
_logits = self.discriminator(x_fake, y_enc)
_labels = torch.zeros(batch_size1, n_seq1)

logits, labels = torch.cat((logits, _logits), dim=1), torch.cat((labels, _labels), dim=1)
logits = torch.cat((logits, _logits), dim=1)

loss = self.bce_loss(logits.squeeze(2), labels.cuda())

Expand Down
Binary file added log/events.out.tfevents.1620625258.cocoa.81980.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620628951.cocoa.37161.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620629064.cocoa.39884.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620629321.cocoa.45909.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620629492.cocoa.49984.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620791617.cocoa.4135.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620793983.cocoa.59139.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620795608.cocoa.97024.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620796058.cocoa.107605.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620796134.cocoa.109471.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620796503.cocoa.118138.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620796732.cocoa.123577.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620797076.cocoa.1073.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620797577.cocoa.12915.0
Binary file not shown.
Binary file added log/events.out.tfevents.1620797645.cocoa.14671.0
Binary file not shown.
16 changes: 10 additions & 6 deletions pytorch_pretrained_bert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_al
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=torch.float32) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

embedding_output = self.embeddings(input_ids, token_type_ids)
Expand Down Expand Up @@ -776,7 +776,8 @@ def __init__(self, config):
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)


#### insert smile-mi-estimator? (#rsk)
self.global_infomax = SquadDIMLoss(config.hidden_size) # GC
self.local_infomax = SquadDIMLoss(config.hidden_size) # LC

Expand Down Expand Up @@ -818,7 +819,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_pos
extend_end = min(sequence_output.size(1), end_positions[b_idx] + 5)
ans_seq = sequence_output[b_idx, extend_start : extend_end, :] # (seq, hidden_size)
ans_enc.append(ans_seq.unsqueeze(0))
assert len(ans_enc) == len(context_enc)
assert len(ans_enc) == len(context_enc)

# generate fake examples by shifting one index.
context_fake = [context_enc[-1]] + context_enc[:-1]
Expand All @@ -831,24 +832,27 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_pos
c_enc, c_fake = context_enc[b_idx], context_fake[b_idx]
a_enc, a_fake = ans_enc[b_idx], ans_fake[b_idx]

## Compute GC ##
## Compute GC ## Change to smile (#rsk)
global_loss = global_loss + self.global_infomax(a_enc, a_fake, c_enc, c_fake, do_summarize=True)

## Compute LC ##
## Compute LC ## Change to smile (#rsk)
# sample one
rand_idx = np.random.randint(a_enc.size(1))
a_enc_word = a_enc[0, rand_idx, :]
rand_idx = np.random.randint(a_fake.size(1))
a_enc_fake = a_fake[0, rand_idx, :]
local_loss = local_loss + self.local_infomax(a_enc_word.unsqueeze(0), a_enc_fake.unsqueeze(0), a_enc, a_fake, do_summarize=False)


# Info loss
info_loss = (0.5 * global_loss + local_loss) / len(ans_enc)
info_loss = 0.25 * info_loss

# total loss
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
start_loss = loss_fct(start_logits, start_positions)
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss + info_loss) / 3
#print('total_loss',total_loss,'info_loss',info_loss)
return total_loss, info_loss
else:
return start_logits, end_logits, sequence_output
Expand Down
8 changes: 5 additions & 3 deletions run_squad.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,19 +990,21 @@ def main():
if n_gpu == 1:
batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
input_ids, input_mask, segment_ids, start_positions, end_positions = batch

loss, info_loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
#print('total_loss',loss,'info_loss',info_loss)
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
info_loss = info_loss.mean()
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps

if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
if step % 50 == 0:
tqdm.write('Total Loss:{} Info Loss: {}'.format(loss.item(), info_loss.item()))
#print('total_loss',loss,'info_loss',info_loss)
tqdm.write('Total Loss:{0} Info Loss: {1}'.format(loss.item(), info_loss.item()))
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
# modify learning rate with special warm up BERT uses
Expand Down
164 changes: 164 additions & 0 deletions smile/estimators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import numpy as np
import torch
import torch.nn.functional as F


def logmeanexp_diag(x, device='cuda'):
"""Compute logmeanexp over the diagonal elements of x."""
batch_size = x.size(0)

logsumexp = torch.logsumexp(x.diag(), dim=(0,))
num_elem = batch_size

return logsumexp - torch.log(torch.tensor(num_elem).float()).to(device)


def logmeanexp_nodiag(x, dim=None, device='cuda'):
    """Compute log-mean-exp of x over the off-diagonal elements.

    Args:
        x: (batch, batch) critic score matrix; off-diagonal entries are
            scores for samples from the product of marginals.
        dim: dimension(s) to reduce; an int, a tuple, or None for both dims.
        device: device for the mask and the log(num_elem) constant
            (must match x's device).

    Returns:
        logsumexp of x with the diagonal masked out, minus the log of the
        number of off-diagonal elements reduced (i.e. a log-mean-exp).
    """
    batch_size = x.size(0)
    if dim is None:
        dim = (0, 1)

    # Mask the diagonal with -inf so it contributes nothing to logsumexp.
    logsumexp = torch.logsumexp(
        x - torch.diag(float('inf') * torch.ones(batch_size).to(device)), dim=dim)

    # Elements reduced: one row/column reduction skips a single diagonal
    # entry; a full reduction skips the whole diagonal.  `dim` may be a bare
    # int, for which len() raises TypeError (the original caught ValueError,
    # which len() never raises, so an int dim crashed).
    try:
        if len(dim) == 1:
            num_elem = batch_size - 1.
        else:
            num_elem = batch_size * (batch_size - 1.)
    except TypeError:
        # Float, not int: torch.log is not defined for integer tensors.
        num_elem = batch_size - 1.
    return logsumexp - torch.log(torch.tensor(num_elem)).to(device)


def tuba_lower_bound(scores, log_baseline=None):
    """TUBA variational lower bound on mutual information.

    Args:
        scores: (batch, batch) critic matrix; diagonal entries score joint
            samples, off-diagonal entries score marginal (negative) samples.
        log_baseline: optional per-sample log baseline, shape (batch,)
            (broadcast across columns).

    Returns:
        Scalar bound: 1 + E_p[f] - E_q[exp(f)].
    """
    if log_baseline is not None:
        # Out-of-place subtraction: the original in-place `-=` mutated the
        # caller's tensor (and errors on a leaf tensor requiring grad).
        scores = scores - log_baseline[:, None]

    # First term is an expectation over samples from the joint,
    # which are the diagonal elements of the scores matrix.
    joint_term = scores.diag().mean()

    # Second term is an expectation over samples from the marginal,
    # which are the off-diagonal elements of the scores matrix.
    marg_term = logmeanexp_nodiag(scores).exp()
    return 1. + joint_term - marg_term


def nwj_lower_bound(scores):
    """NWJ (f-divergence) bound: TUBA with the constant baseline e (shift by 1)."""
    shifted = scores - 1.
    return tuba_lower_bound(shifted)


def infonce_lower_bound(scores):
    """InfoNCE (CPC) lower bound on mutual information.

    Each row of `scores` is treated as a softmax classification of the
    diagonal (positive) element against the other columns (negatives);
    the bound is log(batch) plus the mean log-softmax of the positives.
    """
    batch_size = scores.size(0)
    # Mean diagonal score minus per-row log-partition gives a per-row
    # negative log-likelihood term (a vector of length batch).
    nll = scores.diag().mean() - scores.logsumexp(dim=1)
    mi_per_row = torch.tensor(batch_size).float().log() + nll
    return mi_per_row.mean()


def js_fgan_lower_bound(f):
    """Lower bound on Jensen-Shannon divergence from Nowozin et al. (2016)."""
    n = f.size(0)
    diag = torch.diagonal(f)
    # Joint term: mean of log-sigmoid over positive (diagonal) scores.
    joint = -F.softplus(-diag).mean()
    # Marginal term: mean softplus over the n*(n-1) off-diagonal scores.
    off_diag_total = torch.sum(F.softplus(f)) - torch.sum(F.softplus(diag))
    marginal = off_diag_total / (n * (n - 1.))
    return joint - marginal


def js_lower_bound(f):
    """Obtain density ratio from JS lower bound then output MI estimate from NWJ bound.

    Gradients flow only through the JS term; the detached offset shifts the
    reported value to the NWJ estimate without affecting optimization.
    """
    js_value = js_fgan_lower_bound(f)
    with torch.no_grad():
        offset = nwj_lower_bound(f) - js_value
    return js_value + offset


def dv_upper_lower_bound(f):
    """Donsker-Varadhan lower bound with log taken outside the expectation.

    Similar to MINE but without the moving-average correction term.
    """
    joint = torch.diagonal(f).mean()
    marginal = logmeanexp_nodiag(f)
    return joint - marginal


def mine_lower_bound(f, buffer=None, momentum=0.9):
    """
    MINE lower bound based on DV inequality.

    Args:
        f: (batch, batch) critic score matrix; diagonal = joint samples.
        buffer: running estimate of E_q[exp(f)] used for gradient
            bias-correction; defaults to 1.0 on CUDA.
        momentum: exponential-moving-average factor for the buffer.

    Returns:
        Tuple of (bound value, new buffer value to feed back on the next call).
    """
    if buffer is None:
        # NOTE(review): hard-codes CUDA for the default buffer — matches the
        # device defaults elsewhere in this module, but fails on CPU-only runs.
        buffer = torch.tensor(1.0).cuda()
    first_term = f.diag().mean()

    # buffer_update carries grad; it is the current E_q[exp(f)] estimate.
    buffer_update = logmeanexp_nodiag(f).exp()
    with torch.no_grad():
        # Value of the marginal term, detached so its (biased) gradient is
        # replaced by the corrected one below.
        second_term = logmeanexp_nodiag(f)
        buffer_new = buffer * momentum + buffer_update * (1 - momentum)
        # Clamp to keep the denominator of the correction strictly positive.
        buffer_new = torch.clamp(buffer_new, min=1e-4)
        third_term_no_grad = buffer_update / buffer_new

    third_term_grad = buffer_update / buffer_new

    # third_term_grad and third_term_no_grad have equal values, so they cancel
    # in the returned value; only the gradient of third_term_grad survives,
    # implementing MINE's moving-average bias correction.
    return first_term - second_term - third_term_grad + third_term_no_grad, buffer_update


def smile_lower_bound(f, clip=None):
    """SMILE estimator: clipped-DV value reported, JS gradients used.

    Clipping the partition term bounds the variance of the DV estimate;
    gradients come from the lower-variance JS objective via a detached offset.
    """
    clipped = f if clip is None else torch.clamp(f, -clip, clip)
    # DV bound: unclipped joint term minus log-mean-exp of clipped negatives.
    dv_value = torch.diagonal(f).mean() - logmeanexp_nodiag(clipped, dim=(0, 1))

    js_value = js_fgan_lower_bound(f)

    with torch.no_grad():
        offset = dv_value - js_value

    return js_value + offset


def estimate_mutual_information(estimator, x, y, critic_fn,
                                baseline_fn=None, alpha_logit=None, **kwargs):
    """Estimate variational lower bounds on mutual information.

    Args:
        estimator: string specifying estimator, one of:
            'infonce', 'nwj', 'tuba', 'js', 'smile', 'dv'
        x: [batch_size, dim_x] Tensor
        y: [batch_size, dim_y] Tensor
        critic_fn: callable that takes x and y as input and outputs critic scores
            output shape is a [batch_size, batch_size] matrix
        baseline_fn (optional): callable that takes y as input
            outputs a [batch_size] or [batch_size, 1] vector (used by 'tuba')
        alpha_logit (optional): logit(alpha) for interpolated bound
            (accepted for API compatibility; no interpolated bound implemented)
        **kwargs: forwarded to the chosen bound (e.g. `clip` for 'smile').

    Returns:
        scalar estimate of mutual information

    Raises:
        ValueError: if `estimator` is not one of the supported names.
    """
    # NOTE(review): hard-codes CUDA, consistent with the rest of this module.
    x, y = x.cuda(), y.cuda()
    scores = critic_fn(x, y)
    # Default the baseline so 'tuba' without a baseline_fn does not hit an
    # unbound local (the original raised NameError in that case).
    log_baseline = None
    if baseline_fn is not None:
        # Some baselines' output is (batch_size, 1) which we remove here.
        log_baseline = torch.squeeze(baseline_fn(y))
    if estimator == 'infonce':
        mi = infonce_lower_bound(scores)
    elif estimator == 'nwj':
        mi = nwj_lower_bound(scores)
    elif estimator == 'tuba':
        mi = tuba_lower_bound(scores, log_baseline)
    elif estimator == 'js':
        mi = js_lower_bound(scores)
    elif estimator == 'smile':
        mi = smile_lower_bound(scores, **kwargs)
    elif estimator == 'dv':
        mi = dv_upper_lower_bound(scores)
    else:
        # Previously an unknown name fell through to `return mi` with `mi`
        # unbound, raising an opaque NameError.
        raise ValueError('Unknown estimator: {!r}'.format(estimator))
    return mi