Fix FLAVA bugs #530 and #533 (attention masks and CLS tokens) #545
Open
Vishal-sys-code wants to merge 1 commit into facebookresearch:main from
Fix FLAVA bugs #530 and #533 (attention masks and CLS tokens)

This commit fixes two open bugs in the FLAVA model implementation:

- Issue facebookresearch#533: Fixed a bug in `FLAVATransformerWithoutEmbeddings` where prepending the CLS token to `hidden_states` caused sequence length mismatches because `attention_mask` was not correspondingly padded. It now dynamically concatenates an active mask token to `attention_mask`, matching the CLS token, in a dimension-agnostic way (`dim=-1`). Also updated the docstring to clarify that `attention_mask` needs to be pre-expanded to `(B, 1, 1, seq_len)`.
- Issue facebookresearch#530: The multimodal encoder `FLAVAModel.encode_mm` was previously ignoring padded text tokens. The text padding mask is now dynamically generated from the input text (or can be passed directly), combined with an all-ones image mask, and passed to the multimodal cross-attention encoder. This parameter is properly threaded through both `FLAVAForPreTraining` and `FLAVAForClassification`.
- Added a regression test (`test_attention_mask_affects_output`) directly targeting `FLAVATransformerWithoutEmbeddings` to verify that padded tokens properly alter the final encoder output.
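As a rough sketch of the #533 fix: the helper below mirrors the described behavior under stated assumptions — the function name, shapes, and the 1 = attend mask convention are illustrative, not the actual FLAVA source.

```python
import torch

def prepend_cls(hidden_states, cls_token, attention_mask=None):
    """Hypothetical helper: keep the attention mask aligned when a CLS
    token is prepended. hidden_states: (B, seq, D); cls_token: (1, 1, D);
    attention_mask: e.g. pre-expanded to (B, 1, 1, seq), 1 = attend."""
    batch_size = hidden_states.shape[0]
    cls = cls_token.expand(batch_size, -1, -1)
    hidden_states = torch.cat([cls, hidden_states], dim=1)
    if attention_mask is not None:
        # Concatenate an "active" (1) entry for the new CLS position along
        # the last dimension, so the same code handles both (B, seq) and
        # pre-expanded (B, 1, 1, seq) masks.
        ones = torch.ones(
            *attention_mask.shape[:-1], 1,
            dtype=attention_mask.dtype, device=attention_mask.device,
        )
        attention_mask = torch.cat([ones, attention_mask], dim=-1)
    return hidden_states, attention_mask
```

With a `(2, 4, 8)` input and a `(2, 1, 1, 4)` mask, this returns a `(2, 5, 8)` tensor and a `(2, 1, 1, 5)` mask whose first sequence entry is active — the mismatch the issue describes no longer occurs.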
Summary:
This PR addresses two bugs in the FLAVA model's handling of attention masks:
- `FLAVATransformerWithoutEmbeddings`: Prepending the CLS token to `hidden_states` previously caused a sequence length mismatch with the `attention_mask`. This PR dynamically concatenates an active mask token (1) to the mask in a dimension-agnostic way (`dim=-1`) to keep sequences aligned. A docstring was also added clarifying that the mask should be pre-expanded to `(B, 1, 1, seq_len)`.
- `FLAVAModel.encode_mm`: The multimodal encoder was previously ignoring text padding tokens. The model now dynamically infers `text_pad_mask` (via `self.text_encoder.embeddings.pad_token_id`) or accepts it directly, builds a combined text/image mask, and propagates it through to the cross-modal attention layers. `FLAVAForClassification` and `FLAVAForPreTraining` were updated to correctly thread this new parameter.

Test plan:
- Added `test_attention_mask_affects_output` in `tests/models/flava/test_flava.py`. It directly targets `flava_multimodal_encoder` with non-trivial random tensors to verify that padding masks successfully alter the model's `last_hidden_state`.
- Ran `pytest tests/models/flava/test_flava.py` to ensure backwards compatibility and verify that all golden values and fallback behaviors remain intact.

Fixes #533, Fixes #530
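The `encode_mm` mask construction described in the summary can be sketched roughly as follows; the function name, the `pad_token_id` default, the 1 = attend / 0 = ignore convention, and the image-first token ordering are all assumptions for illustration, not the exact FLAVA API.

```python
import torch

def build_mm_attention_mask(text_ids, num_image_tokens, pad_token_id=0):
    """Hypothetical sketch: derive a text padding mask from the input ids
    and combine it with an all-ones image mask for the multimodal encoder."""
    # 1 = attend, 0 = padded text token to be ignored by cross-attention
    text_pad_mask = (text_ids != pad_token_id).long()      # (B, text_len)
    # Image patches are never padded, so their mask is all ones
    image_mask = torch.ones(
        text_ids.shape[0], num_image_tokens,
        dtype=torch.long, device=text_ids.device,
    )                                                      # (B, img_len)
    # Assumed ordering: image tokens precede text tokens in the
    # concatenated multimodal sequence
    return torch.cat([image_mask, text_pad_mask], dim=-1)  # (B, img + text)
```

For `text_ids = [[5, 6, 0], [7, 0, 0]]` with two image tokens, this yields a `(2, 5)` mask with 4 and 3 active positions respectively, so padded text tokens are excluded from multimodal attention instead of being silently attended to.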