I'm a little confused about how mask_words_label is generated and would appreciate any help.
As I understand it, mask_words_label is the most frequent pixel value within each word, as computed in the function _generate_mask_words_label() in dataset.py:
for nb in range(mask_words.shape[0]):
    mask_words_label_tmp = []
    for nw in range(mask_words.shape[1]):
        mask_words_label_tmp.append(np.argmax(np.bincount(mask_words[nb, nw])))
    mask_words_label.append(mask_words_label_tmp)
Thus each label is a raw pixel value in the range [0, 255] (assuming an 8-bit mask).
However, the classifier is set up for a 4-class classification task:
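To illustrate what I mean, here is a minimal sketch that runs the same loop on a toy input. The array below is made up for illustration (1 sample, 2 words, 6 pixels per word, uint8 values) and is not taken from the dataset:

```python
import numpy as np

# Hypothetical toy mask: shape (batch=1, words=2, pixels=6), 8-bit values.
mask_words = np.array(
    [[[0, 255, 255, 255, 0, 255],
      [0, 0, 0, 128, 128, 0]]],
    dtype=np.uint8,
)

mask_words_label = []
for nb in range(mask_words.shape[0]):
    mask_words_label_tmp = []
    for nw in range(mask_words.shape[1]):
        # Most frequent pixel value within this word.
        most_frequent = int(np.argmax(np.bincount(mask_words[nb, nw])))
        mask_words_label_tmp.append(most_frequent)
    mask_words_label.append(mask_words_label_tmp)

print(mask_words_label)  # [[255, 0]]
```

The first word is labeled 255, which is a raw pixel value rather than a class index.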
self.word_criterion = FocalLoss(num_class=4)
and the one-hot encoding cannot handle labels in that range:
# in source file: utils/criterions/focal_loss.py
# class FocalLoss forward()
target = target.view(-1, 1)
idx = target.cpu().long()
one_hot_key = torch.FloatTensor(target.size(0), self.num_class).zero_()
one_hot_key = one_hot_key.scatter_(1, idx, 1)
This is where I hit the following error:
  File "MuMo/utils/criterions/focal_loss.py", line 54, in forward
    one_hot_key = one_hot_key.scatter_(1, idx, 1)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: index 255 is out of bounds for dimension 1 with size 4
Is a step missing that maps mask_words_label from its raw range [0, 255] to the 4 class indices [0, 3]? Or have I misunderstood the code, or did something go wrong during execution? I would greatly appreciate any hints.
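For reference, this is roughly the kind of remapping I would have expected somewhere between the dataset and the loss. It is only a sketch: the PIXEL_TO_CLASS table and the remap_labels helper are hypothetical names of my own, and I don't know which pixel values the real masks use or how they correspond to the 4 classes.

```python
import numpy as np

# Hypothetical mapping from raw mask pixel values to class indices 0..3.
# The actual pixel values used by the dataset are an assumption here.
PIXEL_TO_CLASS = {0: 0, 85: 1, 170: 2, 255: 3}


def remap_labels(raw_labels, mapping=PIXEL_TO_CLASS):
    """Map raw pixel-value labels in [0, 255] to class indices via a lookup table."""
    raw = np.asarray(raw_labels, dtype=np.int64)
    # Lookup table with -1 marking pixel values that have no class assigned.
    lut = np.full(256, -1, dtype=np.int64)
    for pixel_value, class_index in mapping.items():
        lut[pixel_value] = class_index
    remapped = lut[raw]
    if (remapped < 0).any():
        unknown = sorted(set(raw[remapped < 0].tolist()))
        raise ValueError(f"no class defined for pixel values {unknown}")
    return remapped


print(remap_labels([[255, 0], [85, 170]]).tolist())  # [[3, 0], [1, 2]]
```

With labels remapped like this, every index passed to scatter_ would stay below num_class=4.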