diff --git a/src/pie_modules/models/components/pointer_head.py b/src/pie_modules/models/components/pointer_head.py index e8a02acfa..2bd3634d3 100644 --- a/src/pie_modules/models/components/pointer_head.py +++ b/src/pie_modules/models/components/pointer_head.py @@ -261,6 +261,7 @@ def forward( decoder_attention_mask: Optional[torch.LongTensor] = None, constraints: Optional[torch.LongTensor] = None, ): + min_float_val = torch.finfo(last_hidden_state.dtype).min # assemble the logits logits = last_hidden_state.new_full( ( @@ -268,7 +269,8 @@ def forward( last_hidden_state.size(1), self.pointer_offset + encoder_input_ids.size(-1), ), - fill_value=-1e24, + fill_value=min_float_val, + dtype=last_hidden_state.dtype, ) # eos and label scores depend only on the decoder output @@ -295,7 +297,8 @@ def forward( # never point to the padding or the eos token in the encoder input # TODO: why not excluding the bos token? seems to give worse results, but not tested extensively mask_invalid = encoder_attention_mask.eq(0) | encoder_input_ids.eq(self.eos_token_id) - avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), -1e32) + min_float_val = torch.finfo(avg_word_scores.dtype).min + avg_word_scores = avg_word_scores.masked_fill(mask_invalid.unsqueeze(1), min_float_val) # Note: the remaining row in logits contains the score for the bos token which should be never generated! logits[:, :, [self.eos_id]] = eos_scores @@ -331,13 +334,15 @@ def forward( constraints_word_scores = torch.einsum( "blh,bnh->bln", last_hidden_state, constraints_src_outputs ) + min_float_val = torch.finfo(last_hidden_state.dtype).min constraints_logits = last_hidden_state.new_full( ( last_hidden_state.size(0), last_hidden_state.size(1), self.pointer_offset + encoder_input_ids.size(-1), ), - fill_value=-1e24, + fill_value=min_float_val, + dtype=last_hidden_state.dtype, ) constraints_logits[:, :, self.label_ids] = constraints_label_scores constraints_logits[:, :, self.pointer_offset :] = constraints_word_scores @@ -346,7 +351,8 @@ def forward( constraints_logits_valid = constraints_logits[mask] constraints_valid = constraints[mask] loss_c = F.binary_cross_entropy( - torch.sigmoid(constraints_logits_valid), constraints_valid.float() + torch.sigmoid(constraints_logits_valid), + constraints_valid.to(dtype=constraints_logits_valid.dtype), ) if loss is None: diff --git a/tests/models/base_models/test_bart_as_pointer_network.py b/tests/models/base_models/test_bart_as_pointer_network.py index 96892e8a3..036fda2b3 100644 --- a/tests/models/base_models/test_bart_as_pointer_network.py +++ b/tests/models/base_models/test_bart_as_pointer_network.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch from transformers import ( @@ -18,7 +19,7 @@ # this is a small model that can be used for testing MODEL_NAME_OR_PATH = "sshleifer/bart-tiny-random" DECODER_POSITION_ID_PATTERN = [0, 0, 1, 0, 0, 1, 1] -CONFIGS = [{}, {"decoder_position_id_mode": "pattern"}] +CONFIGS = [{}, {"decoder_position_id_mode": "pattern"}, {"torch_dtype": "torch.float16"}] CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} @@ -80,8 +81,15 @@ def taskmodule(document): @pytest.fixture(scope="module") def model(config) -> BartAsPointerNetwork: + config = config.copy() model_name_or_path = MODEL_NAME_OR_PATH - + torch_dtype_str = config.pop("torch_dtype", "torch.float32") + if torch_dtype_str == "torch.float16": + torch_dtype = torch.float16 + elif torch_dtype_str == "torch.float32": + torch_dtype = torch.float32 + else: + raise ValueError(f"Invalid torch_dtype: {torch_dtype_str}") torch.random.manual_seed(42) model = BartAsPointerNetwork.from_pretrained( model_name_or_path, @@ -101,9 +109,10 @@ def model(config) -> BartAsPointerNetwork: 50267: [354, 1215, 9006], }, decoder_position_id_pattern=DECODER_POSITION_ID_PATTERN, + # torch_dtype=torch_dtype, **config, ) - + model = model.to(dtype=torch_dtype) return model @@ -111,111 +120,209 @@ def test_model(model, config): assert model is not None named_parameters = dict(model.named_parameters()) parameter_means = {k: trunc_number(v.mean().item(), 7) for k, v in named_parameters.items()} - parameter_means_expected = { - "model.shared.weight": -1.41e-05, - "model.encoder.embed_positions.weight": -0.0001324, - "model.encoder.layers.0.self_attn.k_proj.weight": -0.0004574, - "model.encoder.layers.0.self_attn.k_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.v_proj.weight": -0.0005457, - "model.encoder.layers.0.self_attn.v_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.q_proj.weight": -0.0009775, - "model.encoder.layers.0.self_attn.q_proj.bias": 0.0, - "model.encoder.layers.0.self_attn.out_proj.weight": -0.0001075, - "model.encoder.layers.0.self_attn.out_proj.bias": 0.0, - "model.encoder.layers.0.self_attn_layer_norm.weight": 1.0, - "model.encoder.layers.0.self_attn_layer_norm.bias": 0.0, - "model.encoder.layers.0.fc1.weight": -0.0008655, - "model.encoder.layers.0.fc1.bias": 0.0, - "model.encoder.layers.0.fc2.weight": 0.0015535, - "model.encoder.layers.0.fc2.bias": 0.0, - "model.encoder.layers.0.final_layer_norm.weight": 1.0, - "model.encoder.layers.0.final_layer_norm.bias": 0.0, - "model.encoder.layers.1.self_attn.k_proj.weight": -0.0007831, - "model.encoder.layers.1.self_attn.k_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.v_proj.weight": 0.0001186, - "model.encoder.layers.1.self_attn.v_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.q_proj.weight": 0.0006847, - "model.encoder.layers.1.self_attn.q_proj.bias": 0.0, - "model.encoder.layers.1.self_attn.out_proj.weight": 0.0011724, - "model.encoder.layers.1.self_attn.out_proj.bias": 0.0, - "model.encoder.layers.1.self_attn_layer_norm.weight": 1.0, - "model.encoder.layers.1.self_attn_layer_norm.bias": 0.0, - "model.encoder.layers.1.fc1.weight": 0.0007757, - "model.encoder.layers.1.fc1.bias": 0.0, - "model.encoder.layers.1.fc2.weight": -0.0002014, - "model.encoder.layers.1.fc2.bias": 0.0, - "model.encoder.layers.1.final_layer_norm.weight": 1.0, - "model.encoder.layers.1.final_layer_norm.bias": 0.0, - "model.encoder.layernorm_embedding.weight": 1.0, - "model.encoder.layernorm_embedding.bias": 0.0, - "model.decoder.embed_positions.weight": -0.0001275, - "model.decoder.layers.0.self_attn.k_proj.weight": -0.0010682, - "model.decoder.layers.0.self_attn.k_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.v_proj.weight": 0.0005057, - "model.decoder.layers.0.self_attn.v_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.q_proj.weight": 0.0003248, - "model.decoder.layers.0.self_attn.q_proj.bias": 0.0, - "model.decoder.layers.0.self_attn.out_proj.weight": -0.0002014, - "model.decoder.layers.0.self_attn.out_proj.bias": 0.0, - "model.decoder.layers.0.self_attn_layer_norm.weight": 1.0, - "model.decoder.layers.0.self_attn_layer_norm.bias": 0.0, - "model.decoder.layers.0.encoder_attn.k_proj.weight": -0.0004254, - "model.decoder.layers.0.encoder_attn.k_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.v_proj.weight": -0.0004049, - "model.decoder.layers.0.encoder_attn.v_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.q_proj.weight": -0.0003516, - "model.decoder.layers.0.encoder_attn.q_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn.out_proj.weight": 0.0009908, - "model.decoder.layers.0.encoder_attn.out_proj.bias": 0.0, - "model.decoder.layers.0.encoder_attn_layer_norm.weight": 1.0, - "model.decoder.layers.0.encoder_attn_layer_norm.bias": 0.0, - "model.decoder.layers.0.fc1.weight": 0.0008378, - "model.decoder.layers.0.fc1.bias": 0.0, - "model.decoder.layers.0.fc2.weight": -2e-05, - "model.decoder.layers.0.fc2.bias": 0.0, - "model.decoder.layers.0.final_layer_norm.weight": 1.0, - "model.decoder.layers.0.final_layer_norm.bias": 0.0, - "model.decoder.layers.1.self_attn.k_proj.weight": -0.0007669, - "model.decoder.layers.1.self_attn.k_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.v_proj.weight": -0.0007123, - "model.decoder.layers.1.self_attn.v_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.q_proj.weight": 0.0012958, - "model.decoder.layers.1.self_attn.q_proj.bias": 0.0, - "model.decoder.layers.1.self_attn.out_proj.weight": -0.0006818, - "model.decoder.layers.1.self_attn.out_proj.bias": 0.0, - "model.decoder.layers.1.self_attn_layer_norm.weight": 1.0, - "model.decoder.layers.1.self_attn_layer_norm.bias": 0.0, - "model.decoder.layers.1.encoder_attn.k_proj.weight": -0.0006906, - "model.decoder.layers.1.encoder_attn.k_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.v_proj.weight": -0.0009213, - "model.decoder.layers.1.encoder_attn.v_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.q_proj.weight": -0.000842, - "model.decoder.layers.1.encoder_attn.q_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn.out_proj.weight": 0.0008073, - "model.decoder.layers.1.encoder_attn.out_proj.bias": 0.0, - "model.decoder.layers.1.encoder_attn_layer_norm.weight": 1.0, - "model.decoder.layers.1.encoder_attn_layer_norm.bias": 0.0, - "model.decoder.layers.1.fc1.weight": 0.0015493, - "model.decoder.layers.1.fc1.bias": 0.0, - "model.decoder.layers.1.fc2.weight": -0.0009827, - "model.decoder.layers.1.fc2.bias": 0.0, - "model.decoder.layers.1.final_layer_norm.weight": 1.0, - "model.decoder.layers.1.final_layer_norm.bias": 0.0, - "model.decoder.layernorm_embedding.weight": 1.0, - "model.decoder.layernorm_embedding.bias": 0.0, - "pointer_head.encoder_mlp.0.weight": 0.0004805, - "pointer_head.encoder_mlp.0.bias": 0.0, - "pointer_head.encoder_mlp.3.weight": 0.0001837, - "pointer_head.encoder_mlp.3.bias": 0.0, - } + if config == {"torch_dtype": "torch.float16"}: + parameter_means_expected = { + "model.shared.weight": -1.41e-05, + "model.encoder.embed_positions.weight": -0.0001324, + "model.encoder.layers.0.self_attn.k_proj.weight": -0.0004575, + "model.encoder.layers.0.self_attn.k_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.v_proj.weight": -0.0005455, + "model.encoder.layers.0.self_attn.v_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.q_proj.weight": -0.0009775, + "model.encoder.layers.0.self_attn.q_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.out_proj.weight": -0.0001074, + "model.encoder.layers.0.self_attn.out_proj.bias": 0.0, + "model.encoder.layers.0.self_attn_layer_norm.weight": 1.0, + "model.encoder.layers.0.self_attn_layer_norm.bias": 0.0, + "model.encoder.layers.0.fc1.weight": -0.0008654, + "model.encoder.layers.0.fc1.bias": 0.0, + "model.encoder.layers.0.fc2.weight": 0.0015535, + "model.encoder.layers.0.fc2.bias": 0.0, + "model.encoder.layers.0.final_layer_norm.weight": 1.0, + "model.encoder.layers.0.final_layer_norm.bias": 0.0, + "model.encoder.layers.1.self_attn.k_proj.weight": -0.0007829, + "model.encoder.layers.1.self_attn.k_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.v_proj.weight": 0.0001187, + "model.encoder.layers.1.self_attn.v_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.q_proj.weight": 0.0006847, + "model.encoder.layers.1.self_attn.q_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.out_proj.weight": 0.001172, + "model.encoder.layers.1.self_attn.out_proj.bias": 0.0, + "model.encoder.layers.1.self_attn_layer_norm.weight": 1.0, + "model.encoder.layers.1.self_attn_layer_norm.bias": 0.0, + "model.encoder.layers.1.fc1.weight": 0.0007753, + "model.encoder.layers.1.fc1.bias": 0.0, + "model.encoder.layers.1.fc2.weight": -0.0002021, + "model.encoder.layers.1.fc2.bias": 0.0, + "model.encoder.layers.1.final_layer_norm.weight": 1.0, + "model.encoder.layers.1.final_layer_norm.bias": 0.0, + "model.encoder.layernorm_embedding.weight": 1.0, + "model.encoder.layernorm_embedding.bias": 0.0, + "model.decoder.embed_positions.weight": -0.0001275, + "model.decoder.layers.0.self_attn.k_proj.weight": -0.0010681, + "model.decoder.layers.0.self_attn.k_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.v_proj.weight": 0.0005054, + "model.decoder.layers.0.self_attn.v_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.q_proj.weight": 0.0003247, + "model.decoder.layers.0.self_attn.q_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.out_proj.weight": -0.0002011, + "model.decoder.layers.0.self_attn.out_proj.bias": 0.0, + "model.decoder.layers.0.self_attn_layer_norm.weight": 1.0, + "model.decoder.layers.0.self_attn_layer_norm.bias": 0.0, + "model.decoder.layers.0.encoder_attn.k_proj.weight": -0.0004251, + "model.decoder.layers.0.encoder_attn.k_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.v_proj.weight": -0.0004045, + "model.decoder.layers.0.encoder_attn.v_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.q_proj.weight": -0.0003519, + "model.decoder.layers.0.encoder_attn.q_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.out_proj.weight": 0.0009908, + "model.decoder.layers.0.encoder_attn.out_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn_layer_norm.weight": 1.0, + "model.decoder.layers.0.encoder_attn_layer_norm.bias": 0.0, + "model.decoder.layers.0.fc1.weight": 0.0008378, + "model.decoder.layers.0.fc1.bias": 0.0, + "model.decoder.layers.0.fc2.weight": -2e-05, + "model.decoder.layers.0.fc2.bias": 0.0, + "model.decoder.layers.0.final_layer_norm.weight": 1.0, + "model.decoder.layers.0.final_layer_norm.bias": 0.0, + "model.decoder.layers.1.self_attn.k_proj.weight": -0.0007667, + "model.decoder.layers.1.self_attn.k_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.v_proj.weight": -0.0007123, + "model.decoder.layers.1.self_attn.v_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.q_proj.weight": 0.001295, + "model.decoder.layers.1.self_attn.q_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.out_proj.weight": -0.0006818, + "model.decoder.layers.1.self_attn.out_proj.bias": 0.0, + "model.decoder.layers.1.self_attn_layer_norm.weight": 1.0, + "model.decoder.layers.1.self_attn_layer_norm.bias": 0.0, + "model.decoder.layers.1.encoder_attn.k_proj.weight": -0.0006904, + "model.decoder.layers.1.encoder_attn.k_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.v_proj.weight": -0.0009212, + "model.decoder.layers.1.encoder_attn.v_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.q_proj.weight": -0.000842, + "model.decoder.layers.1.encoder_attn.q_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.out_proj.weight": 0.0008072, + "model.decoder.layers.1.encoder_attn.out_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn_layer_norm.weight": 1.0, + "model.decoder.layers.1.encoder_attn_layer_norm.bias": 0.0, + "model.decoder.layers.1.fc1.weight": 0.0015497, + "model.decoder.layers.1.fc1.bias": 0.0, + "model.decoder.layers.1.fc2.weight": -0.0009822, + "model.decoder.layers.1.fc2.bias": 0.0, + "model.decoder.layers.1.final_layer_norm.weight": 1.0, + "model.decoder.layers.1.final_layer_norm.bias": 0.0, + "model.decoder.layernorm_embedding.weight": 1.0, + "model.decoder.layernorm_embedding.bias": 0.0, + "pointer_head.encoder_mlp.0.weight": 0.0004801, + "pointer_head.encoder_mlp.0.bias": 0.0, + "pointer_head.encoder_mlp.3.weight": 0.0001838, + "pointer_head.encoder_mlp.3.bias": 0.0, + } + else: + parameter_means_expected = { + "model.shared.weight": -1.41e-05, + "model.encoder.embed_positions.weight": -0.0001324, + "model.encoder.layers.0.self_attn.k_proj.weight": -0.0004574, + "model.encoder.layers.0.self_attn.k_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.v_proj.weight": -0.0005457, + "model.encoder.layers.0.self_attn.v_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.q_proj.weight": -0.0009775, + "model.encoder.layers.0.self_attn.q_proj.bias": 0.0, + "model.encoder.layers.0.self_attn.out_proj.weight": -0.0001075, + "model.encoder.layers.0.self_attn.out_proj.bias": 0.0, + "model.encoder.layers.0.self_attn_layer_norm.weight": 1.0, + "model.encoder.layers.0.self_attn_layer_norm.bias": 0.0, + "model.encoder.layers.0.fc1.weight": -0.0008655, + "model.encoder.layers.0.fc1.bias": 0.0, + "model.encoder.layers.0.fc2.weight": 0.0015535, + "model.encoder.layers.0.fc2.bias": 0.0, + "model.encoder.layers.0.final_layer_norm.weight": 1.0, + "model.encoder.layers.0.final_layer_norm.bias": 0.0, + "model.encoder.layers.1.self_attn.k_proj.weight": -0.0007831, + "model.encoder.layers.1.self_attn.k_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.v_proj.weight": 0.0001186, + "model.encoder.layers.1.self_attn.v_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.q_proj.weight": 0.0006847, + "model.encoder.layers.1.self_attn.q_proj.bias": 0.0, + "model.encoder.layers.1.self_attn.out_proj.weight": 0.0011724, + "model.encoder.layers.1.self_attn.out_proj.bias": 0.0, + "model.encoder.layers.1.self_attn_layer_norm.weight": 1.0, + "model.encoder.layers.1.self_attn_layer_norm.bias": 0.0, + "model.encoder.layers.1.fc1.weight": 0.0007757, + "model.encoder.layers.1.fc1.bias": 0.0, + "model.encoder.layers.1.fc2.weight": -0.0002014, + "model.encoder.layers.1.fc2.bias": 0.0, + "model.encoder.layers.1.final_layer_norm.weight": 1.0, + "model.encoder.layers.1.final_layer_norm.bias": 0.0, + "model.encoder.layernorm_embedding.weight": 1.0, + "model.encoder.layernorm_embedding.bias": 0.0, + "model.decoder.embed_positions.weight": -0.0001275, + "model.decoder.layers.0.self_attn.k_proj.weight": -0.0010682, + "model.decoder.layers.0.self_attn.k_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.v_proj.weight": 0.0005057, + "model.decoder.layers.0.self_attn.v_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.q_proj.weight": 0.0003248, + "model.decoder.layers.0.self_attn.q_proj.bias": 0.0, + "model.decoder.layers.0.self_attn.out_proj.weight": -0.0002014, + "model.decoder.layers.0.self_attn.out_proj.bias": 0.0, + "model.decoder.layers.0.self_attn_layer_norm.weight": 1.0, + "model.decoder.layers.0.self_attn_layer_norm.bias": 0.0, + "model.decoder.layers.0.encoder_attn.k_proj.weight": -0.0004254, + "model.decoder.layers.0.encoder_attn.k_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.v_proj.weight": -0.0004049, + "model.decoder.layers.0.encoder_attn.v_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.q_proj.weight": -0.0003516, + "model.decoder.layers.0.encoder_attn.q_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn.out_proj.weight": 0.0009908, + "model.decoder.layers.0.encoder_attn.out_proj.bias": 0.0, + "model.decoder.layers.0.encoder_attn_layer_norm.weight": 1.0, + "model.decoder.layers.0.encoder_attn_layer_norm.bias": 0.0, + "model.decoder.layers.0.fc1.weight": 0.0008378, + "model.decoder.layers.0.fc1.bias": 0.0, + "model.decoder.layers.0.fc2.weight": -2e-05, + "model.decoder.layers.0.fc2.bias": 0.0, + "model.decoder.layers.0.final_layer_norm.weight": 1.0, + "model.decoder.layers.0.final_layer_norm.bias": 0.0, + "model.decoder.layers.1.self_attn.k_proj.weight": -0.0007669, + "model.decoder.layers.1.self_attn.k_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.v_proj.weight": -0.0007123, + "model.decoder.layers.1.self_attn.v_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.q_proj.weight": 0.0012958, + "model.decoder.layers.1.self_attn.q_proj.bias": 0.0, + "model.decoder.layers.1.self_attn.out_proj.weight": -0.0006818, + "model.decoder.layers.1.self_attn.out_proj.bias": 0.0, + "model.decoder.layers.1.self_attn_layer_norm.weight": 1.0, + "model.decoder.layers.1.self_attn_layer_norm.bias": 0.0, + "model.decoder.layers.1.encoder_attn.k_proj.weight": -0.0006906, + "model.decoder.layers.1.encoder_attn.k_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.v_proj.weight": -0.0009213, + "model.decoder.layers.1.encoder_attn.v_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.q_proj.weight": -0.000842, + "model.decoder.layers.1.encoder_attn.q_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn.out_proj.weight": 0.0008073, + "model.decoder.layers.1.encoder_attn.out_proj.bias": 0.0, + "model.decoder.layers.1.encoder_attn_layer_norm.weight": 1.0, + "model.decoder.layers.1.encoder_attn_layer_norm.bias": 0.0, + "model.decoder.layers.1.fc1.weight": 0.0015493, + "model.decoder.layers.1.fc1.bias": 0.0, + "model.decoder.layers.1.fc2.weight": -0.0009827, + "model.decoder.layers.1.fc2.bias": 0.0, + "model.decoder.layers.1.final_layer_norm.weight": 1.0, + "model.decoder.layers.1.final_layer_norm.bias": 0.0, + "model.decoder.layernorm_embedding.weight": 1.0, + "model.decoder.layernorm_embedding.bias": 0.0, + "pointer_head.encoder_mlp.0.weight": 0.0004805, + "pointer_head.encoder_mlp.0.bias": 0.0, + "pointer_head.encoder_mlp.3.weight": 0.0001837, + "pointer_head.encoder_mlp.3.bias": 0.0, + } + assert parameter_means == parameter_means_expected assert isinstance(model, BartAsPointerNetwork) - if config == {}: - assert isinstance(model.model, BartModel) - elif config == {"decoder_position_id_mode": "pattern"}: + if config == {"decoder_position_id_mode": "pattern"}: assert isinstance(model.model, BartModelWithDecoderPositionIds) else: - raise ValueError(f"Unknown config: {config}") + assert isinstance(model.model, BartModel) @pytest.fixture(scope="module") @@ -308,58 +415,109 @@ def test_forward(model, batch, decoder_input_ids, config): # shape: (batch_size, output_seq_len, target_size=num_target_ids+num_offsets) assert outputs.logits.shape == (2, 8, 17) # check exact values only for the first sequence output - torch.testing.assert_close( - outputs.logits[:, 0, :], - torch.tensor( - [ + if config == {"torch_dtype": "torch.float16"}: + torch.testing.assert_close( + outputs.logits[:, 0, :], + torch.tensor( [ - -1.0000000138484279e24, - -0.23238050937652588, - 0.2958170175552368, - 0.05529244244098663, - 0.04253090173006058, - 0.10081345587968826, - -0.07145103067159653, - 0.12317530065774918, - -0.06861806660890579, - 0.07819556444883347, - 0.006490768864750862, - -0.040455855429172516, - 0.03176971897482872, - 0.05362509936094284, - 0.04528001323342323, - -0.0684177577495575, - -1.0000000331813535e32, + [ + -65504.0, + -0.2322998046875, + 0.295654296875, + 0.05535888671875, + 0.042449951171875, + 0.10089111328125, + -0.07147216796875, + 0.12322998046875, + -0.068603515625, + 0.0782470703125, + 0.00640869140625, + -0.040435791015625, + 0.03179931640625, + 0.0535888671875, + 0.045318603515625, + -0.06842041015625, + -65504.0, + ], + [ + -65504.0, + -0.232666015625, + 0.296142578125, + 0.055511474609375, + 0.042755126953125, + 0.10076904296875, + -0.0714111328125, + 0.1231689453125, + 0.06494140625, + 0.0794677734375, + -0.0794677734375, + -65504.0, + -65504.0, + -65504.0, + -65504.0, + -65504.0, + -65504.0, + ], ], + dtype=outputs.logits.dtype, + ), + ) + else: + torch.testing.assert_close( + outputs.logits[:, 0, :], + torch.tensor( [ - -1.0000000138484279e24, - -0.23274855315685272, - 0.2960396707057953, - 0.05556505173444748, - 0.04273710399866104, - 0.10071954131126404, - -0.071356862783432, - 0.12314081937074661, - 0.06498698145151138, - 0.07938676327466965, - -0.07943986356258392, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, - -1.0000000331813535e32, + [ + -3.4028234663852886e38, + -0.23238050937652588, + 0.2958170175552368, + 0.05529244244098663, + 0.04253090173006058, + 0.10081345587968826, + -0.07145103067159653, + 0.12317530065774918, + -0.06861806660890579, + 0.07819556444883347, + 0.006490768864750862, + -0.040455855429172516, + 0.03176971897482872, + 0.05362509936094284, + 0.04528001323342323, + -0.0684177577495575, + -3.4028234663852886e38, + ], + [ + -3.4028234663852886e38, + -0.23274855315685272, + 0.2960396707057953, + 0.05556505173444748, + 0.04273710399866104, + 0.10071954131126404, + -0.071356862783432, + 0.12314081937074661, + 0.06498698145151138, + 0.07938676327466965, + -0.07943986356258392, + -3.4028234663852886e38, + -3.4028234663852886e38, + -3.4028234663852886e38, + -3.4028234663852886e38, + -3.4028234663852886e38, + -3.4028234663852886e38, + ], ], - ] - ), - ) + dtype=outputs.logits.dtype, + ), + ) # check the sum of all logits if config == {}: + # ensure that no individual value is -inf + assert outputs.logits.min() > -np.inf torch.testing.assert_close( outputs.logits.sum(0).sum(0), torch.tensor( [ - -1.6000000221574846e25, + -np.inf, -0.9064984321594238, 1.189674735069275, 0.9796359539031982, @@ -370,21 +528,24 @@ def test_forward(model, batch, decoder_input_ids, config): -0.12306825071573257, 0.6218758225440979, -0.4374474287033081, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -1.6000000530901656e33, - ] + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + ], + dtype=outputs.logits.dtype, ), ) elif config == {"decoder_position_id_mode": "pattern"}: + # ensure that no individual value is -inf + assert outputs.logits.min() > -np.inf torch.testing.assert_close( outputs.logits.sum(0).sum(0), torch.tensor( [ - -1.6000000221574846e25, + -np.inf, -0.5539568662643433, 0.7004716396331787, 1.5720455646514893, @@ -395,13 +556,42 @@ def test_forward(model, batch, decoder_input_ids, config): -0.04344810172915459, 0.3674442768096924, -0.6838937997817993, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -8.000000265450828e32, - -1.6000000530901656e33, - ] + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + ], + dtype=outputs.logits.dtype, + ), + ) + elif config == {"torch_dtype": "torch.float16"}: + # ensure that no individual value is -inf + assert outputs.logits.min() > -np.inf + torch.testing.assert_close( + outputs.logits.sum(0).sum(0), + torch.tensor( + [ + -np.inf, + -0.90625, + 1.189453125, + 0.9794921875, + 0.183837890625, + 1.3076171875, + -0.121337890625, + 0.53173828125, + -0.12322998046875, + 0.6220703125, + -0.437255859375, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + -np.inf, + ], + dtype=outputs.logits.dtype, ), ) else: @@ -416,7 +606,13 @@ def test_forward_with_labels(model, batch, config): assert set(inputs) == {"input_ids", "attention_mask"} assert set(targets_without_constraints) == {"labels", "decoder_attention_mask"} torch.manual_seed(42) - outputs = model(**inputs, **targets_without_constraints) + if config == {"torch_dtype": "torch.float16"}: + with pytest.raises(RuntimeError) as excinfo: + outputs = model(**inputs, **targets_without_constraints) + assert str(excinfo.value) == "\"nll_loss_out_frame\" not implemented for 'Half'" + return # skip the rest of the test + else: + outputs = model(**inputs, **targets_without_constraints) loss = outputs.loss if config == {}: torch.testing.assert_close(loss, torch.tensor(2.4516539573669434)) @@ -431,7 +627,13 @@ def test_forward_with_labels_and_constraints(model, batch_with_constraints, conf assert set(inputs) == {"input_ids", "attention_mask"} assert set(targets) == {"labels", "decoder_attention_mask", "constraints"} torch.manual_seed(42) - outputs = model(**inputs, **targets) + if config == {"torch_dtype": "torch.float16"}: + with pytest.raises(RuntimeError) as excinfo: + outputs = model(**inputs, **targets) + assert str(excinfo.value) == "\"nll_loss_out_frame\" not implemented for 'Half'" + return # skip the rest of the test + else: + outputs = model(**inputs, **targets) loss = outputs.loss if config == {}: torch.testing.assert_close(loss, torch.tensor(4.776531219482422)) @@ -559,7 +761,7 @@ def test_prepare_inputs_for_generation( assert result["head_mask"] is None assert result["decoder_head_mask"] is None assert result["cross_attn_head_mask"] is None - if config == {}: + if config == {} or config == {"torch_dtype": "torch.float16"}: assert "decoder_position_ids" not in result elif config == {"decoder_position_id_mode": "pattern"}: torch.testing.assert_close(result["decoder_position_ids"], torch.tensor([[0], [0]])) @@ -604,7 +806,7 @@ def test_prepare_inputs_for_generation_with_past_key_values( ) result = model.prepare_inputs_for_generation(past_key_values=dummy_past_key_values, **kwargs) - if config == {}: + if config == {} or config == {"torch_dtype": "torch.float16"}: assert len(result) == 10 elif config == {"decoder_position_id_mode": "pattern"}: assert len(result) == 11 @@ -646,7 +848,7 @@ def test_generate(model, batch, empty_decoder_input_ids, config): batch_size, seq_len = inputs["input_ids"].shape torch.manual_seed(42) outputs = model.generate(**inputs) - if config == {}: + if config == {} or config == {"torch_dtype": "torch.float16"}: assert outputs.shape == (batch_size, 20) # note that 20 is the model.config.max_length torch.testing.assert_close( outputs, diff --git a/tests/models/components/test_pointer_head.py b/tests/models/components/test_pointer_head.py index eb2650f68..882905aad 100644 --- a/tests/models/components/test_pointer_head.py +++ b/tests/models/components/test_pointer_head.py @@ -317,8 +317,14 @@ def test_prepare_decoder_inputs(): ] -def test_forward(): +@pytest.mark.parametrize( + "half_precision", + [True, False], +) +def test_forward(half_precision): pointer_head = get_pointer_head() + if half_precision: + pointer_head.half() # shape: (batch_size=2, input_sequence_length=5) encoder_input_ids = torch.tensor( [ @@ -357,121 +363,240 @@ def test_forward(): assert logits is not None # shape: (batch_size=2, target_sequence_length=4, num_targets+num_offsets=6+5==11) assert logits.shape == (2, 4, 11) - torch.testing.assert_close( - logits, - torch.tensor( - [ + if not half_precision: + torch.testing.assert_close( + logits, + torch.tensor( [ [ - -1.0000000138484279e24, - -0.9407045245170593, - -1.0000000138484279e24, - 0.5535521507263184, - 0.04295700043439865, - 1.0467679500579834, - -1.110795497894287, - 1.1652655601501465, - 0.09444020688533783, - 0.43052661418914795, - -1.0437036752700806, - ], - [ - -1.0000000138484279e24, - 1.1563994884490967, - -1.0000000138484279e24, - -0.8941665887832642, - -0.6862093806266785, - -1.154745101928711, - 1.6984729766845703, - -1.3889904022216797, - -0.4076152741909027, - -1.0112841129302979, - 0.9846026301383972, + [ + -3.4028234663852886e38, + -0.9407045245170593, + -3.4028234663852886e38, + 0.5535521507263184, + 0.04295700043439865, + 1.0467679500579834, + -1.110795497894287, + 1.1652655601501465, + 0.09444020688533783, + 0.43052661418914795, + -1.0437036752700806, + ], + [ + -3.4028234663852886e38, + 1.1563994884490967, + -3.4028234663852886e38, + -0.8941665887832642, + -0.6862093806266785, + -1.154745101928711, + 1.6984729766845703, + -1.3889904022216797, + -0.4076152741909027, + -1.0112841129302979, + 0.9846026301383972, + ], + [ + -3.4028234663852886e38, + -1.9377808570861816, + -3.4028234663852886e38, + 2.437451124191284, + 0.041493892669677734, + 0.5383729338645935, + -1.5238577127456665, + 1.6700562238693237, + -0.07231226563453674, + 1.0911093950271606, + -0.9189060926437378, + ], + [ + -3.4028234663852886e38, + -1.880744218826294, + -3.4028234663852886e38, + 3.8719429969787598, + 0.07287894189357758, + -1.3378281593322754, + -0.653921365737915, + 0.783344566822052, + -0.3344290256500244, + 1.3571363687515259, + 0.5505899786949158, + ], ], [ - -1.0000000138484279e24, - -1.9377808570861816, - -1.0000000138484279e24, - 2.437451124191284, - 0.041493892669677734, - 0.5383729338645935, - -1.5238577127456665, - 1.6700562238693237, - -0.07231226563453674, - 1.0911093950271606, - -0.9189060926437378, - ], - [ - -1.0000000138484279e24, - -1.880744218826294, - -1.0000000138484279e24, - 3.8719429969787598, - 0.07287894189357758, - -1.3378281593322754, - -0.653921365737915, - 0.783344566822052, - -0.3344290256500244, - 1.3571363687515259, - 0.5505899786949158, + [ + -3.4028234663852886e38, + -0.9407045245170593, + -3.4028234663852886e38, + 0.5535521507263184, + 0.04295700043439865, + 1.0467679500579834, + -1.0019789934158325, + 0.6891120672225952, + -0.002076566219329834, + 0.7561025619506836, + -3.4028234663852886e38, + ], + [ + -3.4028234663852886e38, + -1.880744218826294, + -3.4028234663852886e38, + 3.8719429969787598, + 0.07287894189357758, + -1.3378281593322754, + -1.3875324726104736, + -2.124865770339966, + -2.559859275817871, + 0.5425653457641602, + -3.4028234663852886e38, + ], + [ + -3.4028234663852886e38, + -1.479057788848877, + -3.4028234663852886e38, + 1.7857770919799805, + 0.6723557114601135, + 0.6378745436668396, + -2.262815475463867, + -0.1536862850189209, + -0.5338708758354187, + 1.3628911972045898, + -3.4028234663852886e38, + ], + [ + -3.4028234663852886e38, + 1.1815755367279053, + -3.4028234663852886e38, + -1.880744218826294, + -0.10646091401576996, + 0.1437276005744934, + 1.0795626640319824, + 0.6434042453765869, + 1.0681594610214233, + -0.5814396142959595, + -3.4028234663852886e38, + ], ], ], + dtype=logits.dtype, + ), + ) + else: + torch.testing.assert_close( + logits, + torch.tensor( [ [ - -1.0000000138484279e24, - -0.9407045245170593, - -1.0000000138484279e24, - 0.5535521507263184, - 0.04295700043439865, - 1.0467679500579834, - -1.0019789934158325, - 0.6891120672225952, - -0.002076566219329834, - 0.7561025619506836, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - -1.880744218826294, - -1.0000000138484279e24, - 3.8719429969787598, - 0.07287894189357758, - -1.3378281593322754, - -1.3875324726104736, - -2.124865770339966, - -2.559859275817871, - 0.5425653457641602, - -1.0000000331813535e32, + [ + -65504.0, + -0.94091796875, + -65504.0, + 0.5537109375, + 0.04302978515625, + 1.0478515625, + -1.111328125, + 1.166015625, + 0.09442138671875, + 0.4306640625, + -1.044921875, + ], + [ + -65504.0, + 1.15625, + -65504.0, + -0.89404296875, + -0.68603515625, + -1.1552734375, + 1.69921875, + -1.3896484375, + -0.40771484375, + -1.01171875, + 0.984375, + ], + [ + -65504.0, + -1.9375, + -65504.0, + 2.4375, + 0.04156494140625, + 0.53955078125, + -1.5244140625, + 1.6708984375, + -0.0723876953125, + 1.0908203125, + -0.91943359375, + ], + [ + -65504.0, + -1.880859375, + -65504.0, + 3.87109375, + 0.0726318359375, + -1.337890625, + -0.65380859375, + 0.78369140625, + -0.33447265625, + 1.357421875, + 0.55029296875, + ], ], [ - -1.0000000138484279e24, - -1.479057788848877, - -1.0000000138484279e24, - 1.7857770919799805, - 0.6723557114601135, - 0.6378745436668396, - -2.262815475463867, - -0.1536862850189209, - -0.5338708758354187, - 1.3628911972045898, - -1.0000000331813535e32, - ], - [ - -1.0000000138484279e24, - 1.1815755367279053, - -1.0000000138484279e24, - -1.880744218826294, - -0.10646091401576996, - 0.1437276005744934, - 1.0795626640319824, - 0.6434042453765869, - 1.0681594610214233, - -0.5814396142959595, - -1.0000000331813535e32, + [ + -65504.0, + -0.94091796875, + -65504.0, + 0.5537109375, + 0.04302978515625, + 1.0478515625, + -1.001953125, + 0.689453125, + -0.001953125, + 0.7568359375, + -65504.0, + ], + [ + -65504.0, + -1.880859375, + -65504.0, + 3.87109375, + 0.0726318359375, + -1.337890625, + -1.38671875, + -2.125, + -2.55859375, + 0.54296875, + -65504.0, + ], + [ + -65504.0, + -1.4794921875, + -65504.0, + 1.7861328125, + 0.67236328125, + 0.63818359375, + -2.263671875, + -0.1536865234375, + -0.53369140625, + 1.36328125, + -65504.0, + ], + [ + -65504.0, + 1.181640625, + -65504.0, + -1.880859375, + -0.10638427734375, + 0.1436767578125, + 1.0791015625, + 0.6435546875, + 1.068359375, + -0.58154296875, + -65504.0, + ], ], ], - ] - ), - ) + dtype=logits.dtype, + ), + ) @pytest.mark.parametrize(