
[blip_2] Support attn_implementation=sdpa dispatch #46401

Open
YangKai0616 wants to merge 2 commits into huggingface:main from YangKai0616:sdpa-Blip2QFormerModel

Conversation

@YangKai0616
Contributor

What does this PR do?

As per the title.

@YangKai0616
Contributor Author

@vasqu

class Blip2QFormerModel(Blip2PreTrainedModel):
    config: Blip2QFormerConfig

    _supports_attention_backend = False  # adds position on attn weights before last matmul
Member

have to set it explicitly to True here and in other models

Contributor Author

Done.

config: Blip2QFormerConfig

_supports_attention_backend = False # adds position on attn weights before last matmul
_supports_flash_attn = False
Member

i think we can do FA and flex now, no?

Contributor Author

FA still looks blocked by the QFormer fp32 path. Flex does not seem blocked in the same way, but it fails with attention dropout in training, and QFormer attention recording expects attention weights while flex returns the LSE (log-sum-exp) instead. Perhaps we should keep this PR to SDPA only? 👀
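To make the trade-off above concrete, here is a minimal, library-independent sketch of how class-level support flags can gate which attention implementation is accepted. The names `ToyQFormer`, `FLAG_FOR_IMPLEMENTATION`, and `resolve_attention` are hypothetical and are not the transformers dispatch code; only the flag names mirror those discussed in this PR.

```python
# Hypothetical sketch: NOT the transformers dispatch logic.


class ToyQFormer:
    # Flag names mirror the PR; the values follow the SDPA-only proposal above.
    _supports_sdpa = True
    _supports_flash_attn = False  # blocked by the fp32 Q-Former path, per the thread
    _supports_flex_attn = False  # dropout in training / LSE instead of weights, per the thread


# Assumed mapping from an attn_implementation string to the flag that gates it.
FLAG_FOR_IMPLEMENTATION = {
    "sdpa": "_supports_sdpa",
    "flash_attention_2": "_supports_flash_attn",
    "flex_attention": "_supports_flex_attn",
}


def resolve_attention(model_cls, requested: str) -> str:
    """Return `requested` if the class allows it, otherwise raise ValueError."""
    if requested == "eager":
        return requested  # eager is assumed to be always available in this sketch
    flag = FLAG_FOR_IMPLEMENTATION.get(requested)
    if flag is None:
        raise ValueError(f"unknown attention implementation: {requested!r}")
    if not getattr(model_cls, flag, False):
        raise ValueError(f"{model_cls.__name__} does not support {requested!r}")
    return requested
```

With these flags, requesting `"sdpa"` or `"eager"` succeeds, while `"flash_attention_2"` and `"flex_attention"` are rejected up front instead of failing later inside a kernel.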

Comment on lines -599 to +583
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

mixed_query_layer = self.query(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

attention_scores = attention_scores / math.sqrt(self.attention_head_size)

if attention_mask is not None:
    # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
    attention_scores = attention_scores + attention_mask
current_states = hidden_states
Member

ohh nice, i forgot we got rid of those position_embeddings which weren't used by official ckpt

@github-actions
Contributor

github-actions Bot commented Jun 5, 2026

[For maintainers] Suggested jobs to run (before merge)

run-slow: blip_2, instructblip, instructblipvideo

@YangKai0616
Contributor Author

For the test case test_modeling_instructblip.py::InstructBlipModelIntegrationTest::test_inference_flant5_xl, the output from the current default SDPA branch differs from the default eager mode in upstream/main.
However, it matches when using eager mode on the current branch or float32 dtype. Expectations should be updated, but since I don't know the device used in your CI, this PR will not be updated for now.

Contributor

vasqu left a comment

Just a few smaller comments but overall looks good already

Comment on lines -564 to -579
def save_attn_gradients(self, attn_gradients):
    self.attn_gradients = attn_gradients

def get_attn_gradients(self):
    return self.attn_gradients

def save_attention_map(self, attention_map):
    self.attention_map = attention_map

def get_attention_map(self):
    return self.attention_map

def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)
Contributor

Oh god that was ugly 😬 glad to get rid of this
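For readers unfamiliar with the removed `transpose_for_scores` helper, the sketch below reproduces its reshaping on plain nested lists: the hidden dimension is split into heads, and the head axis is moved ahead of the sequence axis. The function name `split_heads` is hypothetical, and this is an illustration of the shape change only, not the code that replaces the helper in the PR.

```python
def split_heads(x, num_heads):
    """Reshape [batch][seq][hidden] into [batch][heads][seq][head_size].

    Mirrors view(..., num_heads, head_size) followed by permute(0, 2, 1, 3),
    using nested Python lists instead of tensors.
    """
    hidden = len(x[0][0])
    if hidden % num_heads != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    head_size = hidden // num_heads
    return [
        [
            # For head h, take that head's slice from every token in the sequence.
            [token[h * head_size:(h + 1) * head_size] for token in sequence]
            for h in range(num_heads)
        ]
        for sequence in x
    ]
```

For example, one sequence of two tokens with hidden size 4 and two heads yields two heads, each holding a 2-element slice per token.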

_supports_attention_backend = True
_supports_sdpa = True
_supports_flash_attn = False # Q-Former is kept in fp32, which blocks reliable Flash Attention dispatch.
_supports_flex_attn = False
Contributor

Flex attention should work no?

_supports_sdpa = False
_supports_attention_backend = True
_supports_sdpa = True
_supports_flash_attn = False # Q-Former is kept in fp32, which blocks reliable Flash Attention dispatch.
Contributor

Oh yea I remember that one...

_supports_attention_backend = False # adds position on attn weights before last matmul
_supports_attention_backend = True
_supports_sdpa = True
_supports_flash_attn = False
Contributor

Let's add comments explaining why not

Contributor

And same re flex

_supports_sdpa = True
_supports_flash_attn = False
_supports_sdpa = False
_supports_flex_attn = False
Contributor

Same here

def test_model_base_model_prefix(self):
    pass

def test_sdpa_can_dispatch_on_flash(self):
Contributor

Yes, but let's use the @unittest.skip decorator please
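As a minimal illustration of the request above, the standard-library `unittest.skip` decorator marks a test as skipped with a recorded reason, so the runner reports it as skipped rather than passed. The class name `ToyModelTest`, the helper `run_suite`, and the reason string are hypothetical; the PR's actual test body and skip reason are not shown in this thread.

```python
import unittest

# Hypothetical reason text, loosely based on the fp32 explanation in this thread.
SKIP_REASON = "Q-Former is kept in fp32, so flash dispatch is not exercised"


class ToyModelTest(unittest.TestCase):
    @unittest.skip(SKIP_REASON)
    def test_sdpa_can_dispatch_on_flash(self):
        # Never executed: the decorator skips the test before the body runs.
        raise RuntimeError("should not run")

    def test_something_else(self):
        self.assertTrue(True)


def run_suite():
    """Run the toy test case and return the collected TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ToyModelTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

The skip and its reason then show up in `result.skipped`, which keeps the omission visible in test reports.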

self.all_head_size = self.num_attention_heads * self.attention_head_size
self.scaling = self.attention_head_size**-0.5
self.is_causal = False
self.attention_dropout = config.attention_probs_dropout_prob
Contributor

Suggested change
- self.attention_dropout = config.attention_probs_dropout_prob
+ self.dropout = config.attention_probs_dropout_prob

nit
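The removed eager path divided the raw scores by `math.sqrt(self.attention_head_size)`, while the new code stores `self.scaling = self.attention_head_size**-0.5` and presumably multiplies by it. The sketch below, with hypothetical function names, checks that the two forms agree up to floating-point rounding.

```python
import math


def scale_by_division(score: float, head_size: int) -> float:
    # Old style: divide the raw attention score by sqrt(head_size).
    return score / math.sqrt(head_size)


def scale_by_multiplier(score: float, head_size: int) -> float:
    # New style: precompute head_size ** -0.5 once and multiply.
    scaling = head_size ** -0.5
    return score * scaling
```

Both functions compute the same scaled score; precomputing the multiplier simply lets a single value be passed to whichever attention backend is dispatched.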

@vasqu
Contributor

vasqu commented Jun 5, 2026

> For the test case test_modeling_instructblip.py::InstructBlipModelIntegrationTest::test_inference_flant5_xl, the output from the current default SDPA branch differs from the default eager mode in upstream/main.
> However, it matches when using eager mode on the current branch or float32 dtype. Expectations should be updated, but since I don't know the device used in your CI, this PR will not be updated for now.

Can we force eager attention at load time instead for now?
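One way to read the suggestion above is to pin the attention implementation when the model is loaded. The sketch below wraps that in a small helper; `load_with_eager_attention` is a hypothetical name, and it assumes the model class accepts the `attn_implementation` keyword in `from_pretrained`, as the PR title implies.

```python
def load_with_eager_attention(model_cls, checkpoint, **kwargs):
    """Load `checkpoint` with the attention implementation pinned to eager.

    `model_cls` is any class exposing a `from_pretrained` classmethod that
    accepts `attn_implementation` (assumed here, not verified against the PR).
    """
    return model_cls.from_pretrained(checkpoint, attn_implementation="eager", **kwargs)
```

In the integration test this would mean passing the model class and the test's existing checkpoint to the helper (or adding the keyword directly to the existing `from_pretrained` call), so the current expectations keep applying until they are regenerated for SDPA.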


3 participants