
Fix FLAVA bugs #530 and #533 (attention masks and CLS tokens)#545

Open
Vishal-sys-code wants to merge 1 commit into facebookresearch:main from Vishal-sys-code:main

Conversation

@Vishal-sys-code

Summary:
This PR addresses two bugs in the FLAVA model's handling of attention masks:

  1. Bug in FLAVATransformerWithoutEmbeddings: Prepending the CLS token to hidden_states previously caused a sequence length mismatch with the attention_mask. This PR dynamically concatenates an active mask token (1) to the mask in a dimension-agnostic way (dim=-1) to keep sequences aligned. A docstring was also added clarifying that the mask should be pre-expanded to (B, 1, 1, seq_len).
  2. Bug in FLAVAModel.encode_mm: The multimodal encoder was previously ignoring text padding tokens. The model now dynamically infers text_pad_mask (via self.text_encoder.embeddings.pad_token_id) or accepts it directly, builds a combined text/image mask, and propagates it through to the cross-modal attention layers. FLAVAForClassification and FLAVAForPreTraining were updated to correctly thread this new parameter.
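The two mask fixes described above can be sketched in plain PyTorch. This is a minimal, hypothetical illustration — `pad_mask_for_cls` and `build_multimodal_mask` are illustrative names, not the actual FLAVA functions, and the image-first concatenation order is an assumption:

```python
import torch


def pad_mask_for_cls(attention_mask: torch.Tensor) -> torch.Tensor:
    # Sketch of the FLAVATransformerWithoutEmbeddings fix: prepend an
    # "active" (1) entry along the last dimension so the mask length matches
    # hidden_states after the CLS token is prepended. Concatenating on dim=-1
    # keeps this correct for both a (B, seq_len) mask and one pre-expanded
    # to (B, 1, 1, seq_len).
    cls_mask = attention_mask.new_ones(*attention_mask.shape[:-1], 1)
    return torch.cat([cls_mask, attention_mask], dim=-1)


def build_multimodal_mask(
    text_ids: torch.Tensor, num_image_tokens: int, pad_token_id: int = 0
) -> torch.Tensor:
    # Sketch of the encode_mm fix: infer the text padding mask from
    # pad_token_id, pair it with an all-ones image mask, and concatenate
    # into one combined mask for the cross-modal attention layers.
    text_mask = (text_ids != pad_token_id).long()
    image_mask = text_mask.new_ones(text_ids.size(0), num_image_tokens)
    return torch.cat([image_mask, text_mask], dim=-1)
```

With a pre-expanded `(2, 1, 1, 5)` mask, `pad_mask_for_cls` returns shape `(2, 1, 1, 6)`; with `text_ids` of shape `(2, 3)` and two image tokens, `build_multimodal_mask` returns a `(2, 5)` combined mask whose trailing entries are 0 wherever the text was padding.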

Test plan:

  • Added a robust regression test test_attention_mask_affects_output in tests/models/flava/test_flava.py. It directly targets flava_multimodal_encoder with non-trivial random tensors to verify that padding masks successfully alter the model's last_hidden_state.
  • Ran the test suite via pytest tests/models/flava/test_flava.py to ensure backwards compatibility and verify that all golden values and fallback behaviors remain completely intact.
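The shape of the regression test above can be sketched as follows. This is a hypothetical stand-in, not the actual test: `TinyEncoder` substitutes a single `torch.nn.MultiheadAttention` layer for the real FLAVA encoder so the check is self-contained, and the assertion mirrors the idea that masking padded positions must change `last_hidden_state`:

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    # Minimal stand-in for the encoder under test (NOT the real FLAVA module).
    def __init__(self, dim: int = 32, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, hidden, key_padding_mask=None):
        out, _ = self.attn(hidden, hidden, hidden, key_padding_mask=key_padding_mask)
        return out


def masking_changes_output(encoder: nn.Module) -> bool:
    # Run the same non-trivial random input twice: once with no padding and
    # once with the last three positions marked as padding (True = ignore),
    # then check that the outputs differ.
    torch.manual_seed(0)
    hidden = torch.randn(2, 7, 32)
    no_pad = torch.zeros(2, 7, dtype=torch.bool)
    pad = no_pad.clone()
    pad[:, -3:] = True
    out_full = encoder(hidden, key_padding_mask=no_pad)
    out_masked = encoder(hidden, key_padding_mask=pad)
    return not torch.allclose(out_full, out_masked)
```

If the encoder silently dropped the mask (the pre-fix behavior), `masking_changes_output` would return `False`, which is exactly the regression this test guards against.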

Fixes #533, Fixes #530

…on masks and CLS tokens)

This commit fixes two open bugs in the FLAVA model implementation:

- Issue facebookresearch#533: Fixed a bug in `FLAVATransformerWithoutEmbeddings` where prepending the CLS token to `hidden_states` caused sequence length mismatches because `attention_mask` was not correspondingly padded. It now dynamically concatenates an active mask token to `attention_mask` matching the CLS token in a dimension-agnostic way (`dim=-1`). Also updated the docstring to clarify that `attention_mask` needs to be pre-expanded to `(B, 1, 1, seq_len)`.
- Issue facebookresearch#530: The multimodal encoder `FLAVAModel.encode_mm` was previously ignoring padded text tokens. The text padding mask is now dynamically generated from input text (or can be passed directly), combined with an all-ones image mask, and passed to the multimodal cross-attention encoder. This parameter is properly threaded through both `FLAVAForPreTraining` and `FLAVAForClassification`.
- Added a robust regression test (`test_attention_mask_affects_output`) directly targeting `FLAVATransformerWithoutEmbeddings` to verify that padded tokens properly alter the final encoder output.
@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 8, 2026


Development

Successfully merging this pull request may close these issues.

Bug in FLAVATransformerWithoutEmbeddings (#533)
Probable FLAVA multimodal encoder bug (#530)
