[blip_2] Support attn_implementation=sdpa dispatch #46401
Conversation
| class Blip2QFormerModel(Blip2PreTrainedModel): | ||
| config: Blip2QFormerConfig | ||
|
|
||
| _supports_attention_backend = False # adds position on attn weights before last matmul |
There was a problem hiding this comment.
We have to set it explicitly to True here and in the other models.
| config: Blip2QFormerConfig | ||
|
|
||
| _supports_attention_backend = False # adds position on attn weights before last matmul | ||
| _supports_flash_attn = False |
There was a problem hiding this comment.
I think we can support FA and flex attention now, no?
There was a problem hiding this comment.
FA still looks blocked by the Q-Former fp32 path. Flex attention does not seem to be blocked in the same way as FA, but it fails with attention dropout during training, and the Q-Former attention recording expects attention weights, whereas flex returns the LSE (log-sum-exp) instead. Perhaps we should keep this PR limited to SDPA only? 👀
| key_layer = self.transpose_for_scores(self.key(hidden_states)) | ||
| value_layer = self.transpose_for_scores(self.value(hidden_states)) | ||
|
|
||
| mixed_query_layer = self.query(hidden_states) | ||
|
|
||
| query_layer = self.transpose_for_scores(mixed_query_layer) | ||
|
|
||
| # Take the dot product between "query" and "key" to get the raw attention scores. | ||
| attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | ||
|
|
||
| attention_scores = attention_scores / math.sqrt(self.attention_head_size) | ||
|
|
||
| if attention_mask is not None: | ||
| # Apply the attention mask is (precomputed for all layers in BertModel forward() function) | ||
| attention_scores = attention_scores + attention_mask | ||
| current_states = hidden_states |
There was a problem hiding this comment.
Ohh nice, I forgot we got rid of those position_embeddings, which weren't used by the official checkpoint.
|
[For maintainers] Suggested jobs to run (before merge) run-slow: blip_2, instructblip, instructblipvideo |
|
For the test case |
vasqu
left a comment
There was a problem hiding this comment.
Just a few smaller comments, but overall this looks good already.
| def save_attn_gradients(self, attn_gradients): | ||
| self.attn_gradients = attn_gradients | ||
|
|
||
| def get_attn_gradients(self): | ||
| return self.attn_gradients | ||
|
|
||
| def save_attention_map(self, attention_map): | ||
| self.attention_map = attention_map | ||
|
|
||
| def get_attention_map(self): | ||
| return self.attention_map | ||
|
|
||
| def transpose_for_scores(self, x): | ||
| new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | ||
| x = x.view(*new_x_shape) | ||
| return x.permute(0, 2, 1, 3) |
There was a problem hiding this comment.
Oh god that was ugly 😬 glad to get rid of this
| _supports_attention_backend = True | ||
| _supports_sdpa = True | ||
| _supports_flash_attn = False # Q-Former is kept in fp32, which blocks reliable Flash Attention dispatch. | ||
| _supports_flex_attn = False |
There was a problem hiding this comment.
Flex attention should work, no?
| _supports_sdpa = False | ||
| _supports_attention_backend = True | ||
| _supports_sdpa = True | ||
| _supports_flash_attn = False # Q-Former is kept in fp32, which blocks reliable Flash Attention dispatch. |
There was a problem hiding this comment.
Oh yea I remember that one...
| _supports_attention_backend = False # adds position on attn weights before last matmul | ||
| _supports_attention_backend = True | ||
| _supports_sdpa = True | ||
| _supports_flash_attn = False |
There was a problem hiding this comment.
Let's add a comment explaining why whenever a backend is not supported.
| _supports_sdpa = True | ||
| _supports_flash_attn = False | ||
| _supports_sdpa = False | ||
| _supports_flex_attn = False |
| def test_model_base_model_prefix(self): | ||
| pass | ||
|
|
||
| def test_sdpa_can_dispatch_on_flash(self): |
There was a problem hiding this comment.
Yes, but let's use the @unittest.skip decorator, please.
| self.all_head_size = self.num_attention_heads * self.attention_head_size | ||
| self.scaling = self.attention_head_size**-0.5 | ||
| self.is_causal = False | ||
| self.attention_dropout = config.attention_probs_dropout_prob |
There was a problem hiding this comment.
| self.attention_dropout = config.attention_probs_dropout_prob | |
| self.dropout = config.attention_probs_dropout_prob |
nit
Can we force eager attention at load time instead for now? |
What does this PR do?
As per the title.