
Question about Equation (5) in the paper #4

@ToneLi


```python
import torch
from allennlp.nn.util import tiny_value_of_dtype


def masked_log_softmax(vector: torch.Tensor, mask: torch.BoolTensor, dim: int = -1) -> torch.Tensor:
    if mask is not None:
        # Add singleton dimensions so the mask broadcasts against `vector`.
        while mask.dim() < vector.dim():
            mask = mask.unsqueeze(1)
        # vector + mask.log() is an easy way to zero out masked elements in logspace, but it
        # results in nans when the whole vector is masked. We need a very small value instead of a
        # zero in the mask for these cases.
        vector = vector + (mask + tiny_value_of_dtype(vector.dtype)).log()
    return torch.nn.functional.log_softmax(vector, dim=dim)
```
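To see what the masking line does numerically, here is a dependency-free sketch in plain Python that mirrors the same arithmetic on a single 1-D list of scores. It is not the AllenNLP code: the helper names `log_softmax` and `masked_log_softmax` here are hypothetical, and `TINY = 1e-13` is an assumption standing in for what `tiny_value_of_dtype` returns for float32.

```python
import math
from typing import List, Optional

# Assumption: stand-in for AllenNLP's tiny_value_of_dtype(torch.float32), taken to be 1e-13.
TINY = 1e-13


def log_softmax(scores: List[float]) -> List[float]:
    """Numerically stable log-softmax: x_k - logsumexp(x)."""
    m = max(scores)
    log_z = m + math.log(sum(math.exp(x - m) for x in scores))
    return [x - log_z for x in scores]


def masked_log_softmax(scores: List[float], mask: Optional[List[bool]]) -> List[float]:
    """Hypothetical plain-Python analogue of the torch function, for one 1-D score list."""
    if mask is not None:
        # Kept entries get + log(1 + TINY), i.e. essentially + 0.
        # Masked entries get + log(TINY), about -29.9, rather than log(0) = -inf;
        # with every entry at -inf, log-softmax would compute -inf - (-inf) = NaN.
        scores = [x + math.log(float(keep) + TINY) for x, keep in zip(scores, mask)]
    return log_softmax(scores)


# Raw similarities may be negative; the last entry is masked out (e.g. padding).
scores = [-2.0, 0.5, 3.0]
mask = [True, True, False]
probs = [math.exp(lp) for lp in masked_log_softmax(scores, mask)]
print(probs)
```

In this example the masked entry ends up with a probability on the order of 1e-12 even though it has the highest raw score, while the two kept probabilities still sum to almost exactly 1. If every entry is masked, all scores are shifted by the same constant log(TINY), which log-softmax ignores, so the output stays finite instead of becoming NaN.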

Sorry, I can't follow this part: why is `vector` (= `type_linear_output @ span_linear_output`) only masked before it is passed to `log_softmax`? How does this guarantee that the numerator of Equation (5), exp(sim(s_{i,j}, e_k)), is positive?
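A worked derivation may help here. It assumes that Equation (5) has the usual softmax form with the numerator quoted above (the paper's exact notation is not reproduced in this issue), treats each entry of `type_linear_output @ span_linear_output` as sim(s_{i,j}, e_k), as the question does, and writes m_k ∈ {0, 1} for the mask entry and ε for the value returned by `tiny_value_of_dtype`:

```math
\begin{aligned}
\tilde{x}_k &= \mathrm{sim}(s_{i,j}, e_k) + \log(m_k + \varepsilon),\\
\exp(\tilde{x}_k) &= (m_k + \varepsilon)\,\exp\big(\mathrm{sim}(s_{i,j}, e_k)\big) > 0,\\
\log p_k &= \tilde{x}_k - \log \sum_{k'} \exp(\tilde{x}_{k'})
= \log \frac{(m_k + \varepsilon)\,\exp\big(\mathrm{sim}(s_{i,j}, e_k)\big)}{\sum_{k'} (m_{k'} + \varepsilon)\,\exp\big(\mathrm{sim}(s_{i,j}, e_{k'})\big)}.
\end{aligned}
```

Since exp(x) > 0 for every real x, the numerator is positive whatever the sign of the raw similarity, so no separate step is needed to make it positive. Under this reading, the mask only rescales the numerator by (m_k + ε), which is 1 + ε for kept entries and ε for masked ones, and `log_softmax` returns the logarithm of the resulting ratio, computed stably as a score minus a log-sum-exp.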
