Skip to content
14 changes: 10 additions & 4 deletions src/pie_modules/models/components/pointer_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,16 @@ def forward(
decoder_attention_mask: Optional[torch.LongTensor] = None,
constraints: Optional[torch.LongTensor] = None,
):
min_float_val = torch.finfo(last_hidden_state.dtype).min
# assemble the logits
logits = last_hidden_state.new_full(
(
last_hidden_state.size(0),
last_hidden_state.size(1),
self.pointer_offset + encoder_input_ids.size(-1),
),
fill_value=-1e24,
fill_value=min_float_val,
dtype=last_hidden_state.dtype,
)

# eos and label scores depend only on the decoder output
Expand All @@ -295,7 +297,8 @@ def forward(
# never point to the padding or the eos token in the encoder input
# TODO: why not excluding the bos token? seems to give worse results, but not tested extensively
mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id)
avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32)
min_float_val = torch.finfo(avg_word_scores.dtype).min
avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), min_float_val)

# Note: the remaining row in logits contains the score for the bos token which should be never generated!
logits[:, :, [self.eos_id]] = eos_scores
Expand Down Expand Up @@ -331,13 +334,15 @@ def forward(
constraints_word_scores = torch.einsum(
"blh,bnh->bln", last_hidden_state, constraints_src_outputs
)
min_float_val = torch.finfo(last_hidden_state.dtype).min
constraints_logits = last_hidden_state.new_full(
(
last_hidden_state.size(0),
last_hidden_state.size(1),
self.pointer_offset + encoder_input_ids.size(-1),
),
fill_value=-1e24,
fill_value=min_float_val,
dtype=last_hidden_state.dtype,
)
constraints_logits[:, :, self.label_ids] = constraints_label_scores
constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores
Expand All @@ -346,7 +351,8 @@ def forward(
constraints_logits_valid = constraints_logits[mask]
constraints_valid = constraints[mask]
loss_c = F.binary_cross_entropy(
torch.sigmoid(constraints_logits_valid), constraints_valid.float()
torch.sigmoid(constraints_logits_valid),
constraints_valid.to(dtype=constraints_logits_valid.dtype),
)

if loss is None:
Expand Down
Loading