Thanks for sharing the code. Could you provide an example of how to use the O1 loss?
I've combined it with the CTC loss as shown below, but the performance does not seem to improve.
log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
    loss = nn.functional.ctc_loss(
        log_probs,
        flattened_targets,
        input_lengths,
        target_lengths,
        blank=self.config.pad_token_id,
        reduction=self.config.ctc_loss_reduction,  # default: sum, use_focal_loss=none
        zero_infinity=self.config.ctc_zero_infinity,  # default: false
    )
o1_loss = self.o1_loss(
    log_probs.transpose(0, 1),
    input_lengths,
    labels,
    target_lengths,
)
if self.use_o1_loss:
    o1_loss /= batch_size
    loss = 0.01 * loss + 1.0 * o1_loss
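One thing that may be worth checking: with `reduction="sum"` the CTC term is summed over the batch, while `o1_loss` is divided by `batch_size`, so the `0.01` / `1.0` weights compare a batch sum against a per-utterance average, and their effective ratio changes with batch size. Below is a minimal sketch, assuming the auxiliary (O1) term is also returned as a batch sum, that normalizes both terms the same way before weighting. The function name `combined_ctc_loss` and the `aux_loss_sum` argument are hypothetical; the real O1 loss is not implemented here and is replaced by a placeholder number.

```python
import torch
import torch.nn.functional as F


def combined_ctc_loss(
    log_probs: torch.Tensor,       # (T, B, C) log-probabilities, as F.ctc_loss expects
    targets: torch.Tensor,         # concatenated (flattened) target indices
    input_lengths: torch.Tensor,   # (B,) valid frames per utterance
    target_lengths: torch.Tensor,  # (B,) target length per utterance
    aux_loss_sum: torch.Tensor,    # hypothetical auxiliary (e.g. O1) loss, summed over the batch
    ctc_weight: float = 1.0,
    aux_weight: float = 1.0,
    blank: int = 0,
    zero_infinity: bool = False,
) -> torch.Tensor:
    batch_size = log_probs.size(1)
    # Sum CTC over the batch, then divide by batch size -> per-utterance average.
    ctc = F.ctc_loss(
        log_probs, targets, input_lengths, target_lengths,
        blank=blank, reduction="sum", zero_infinity=zero_infinity,
    ) / batch_size
    # Normalize the auxiliary term identically so the weights compare like quantities.
    aux = aux_loss_sum / batch_size
    return ctc_weight * ctc + aux_weight * aux


# Usage with random inputs (shapes only; not real model outputs).
torch.manual_seed(0)
T, B, C = 50, 2, 20
log_probs = torch.randn(T, B, C).log_softmax(-1)
target_lengths = torch.tensor([10, 7])
targets = torch.randint(1, C, (int(target_lengths.sum()),))  # avoid blank index 0
input_lengths = torch.full((B,), T, dtype=torch.long)
aux = torch.tensor(3.0)  # placeholder value standing in for a real O1 loss
loss = combined_ctc_loss(log_probs, targets, input_lengths, target_lengths, aux)
```

With both terms on the same per-utterance scale, the weights (here 1.0 each) can be tuned without their meaning drifting as the batch size changes.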