diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
index baaa21ed1f..ffcfd8b801 100644
--- a/monai/networks/blocks/crossattention.py
+++ b/monai/networks/blocks/crossattention.py
@@ -61,8 +61,13 @@ def __init__(
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                 positional parameter size.
             attention_dtype: cast attention operations to this dtype.
-            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
-                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            use_flash_attention: if True, dispatch attention through
+                ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding``) keeps the fast path via
+                ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal`` merged with another bias), PyTorch falls
+                back to the memory-efficient or cuDNN SDPA backend.
         """
 
         super().__init__()
@@ -88,9 +93,6 @@
                 "to True. save_attn can only be used if use_flash_attention is False"
             )
 
-        if use_flash_attention and rel_pos_embedding is not None:
-            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
-
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.context_input_size = context_input_size if context_input_size else hidden_size
@@ -155,8 +157,31 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None):
             k = k.to(self.attention_dtype)
 
         if self.use_flash_attention:
+            # Additive bias path mirrors SABlock: a null bias preserves the true
+            # flash kernel fast path; rel_pos_embedding (alone, or merged with
+            # causal) forces fallback to the efficient or cuDNN SDPA backend.
+            bias: torch.Tensor | None = None
+            lq, lk = q.shape[-2], k.shape[-2]
+
+            if self.rel_positional_embedding is not None:
+                zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+                bias = self.rel_positional_embedding(x, zero_logits, q)
+
+            is_causal_arg = self.causal
+            if self.causal and bias is not None:
+                causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+                causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+                bias = bias + causal_bias
+                is_causal_arg = False
+
             x = torch.nn.functional.scaled_dot_product_attention(
-                query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+                query=q,
+                key=k,
+                value=v,
+                attn_mask=bias,
+                scale=self.scale,
+                dropout_p=self.dropout_rate if self.training else 0.0,
+                is_causal=is_causal_arg,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 2791d2fb00..b03b237ba0 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -63,8 +63,13 @@
             attention_dtype: cast attention operations to this dtype.
             include_fc: whether to include the final linear layer. Default to True.
             use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
-            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
-                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+            use_flash_attention: if True, dispatch attention through
+                ``torch.nn.functional.scaled_dot_product_attention``. PyTorch selects the backend;
+                the true flash kernel is used when no custom additive attention bias is passed.
+                Pure ``causal`` masking (with no ``rel_pos_embedding`` or ``attn_mask``) keeps the
+                fast path via ``is_causal=True``. When an additive bias is required (for example,
+                ``rel_pos_embedding``, or ``causal``/``attn_mask`` merged with another bias),
+                PyTorch falls back to the memory-efficient or cuDNN SDPA backend.
         """
 
         super().__init__()
@@ -94,9 +99,6 @@
                 "to True. save_attn can only be used if use_flash_attention is False."
             )
 
-        if use_flash_attention and rel_pos_embedding is not None:
-            raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
-
         self.num_heads = num_heads
         self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
         self.out_proj: nn.Linear | nn.Identity
@@ -174,14 +176,40 @@ def forward(self, x, attn_mask: torch.Tensor | None = None):
             k = k.to(self.attention_dtype)
 
         if self.use_flash_attention:
+            # Build an additive attention bias when we have to combine
+            # rel_pos_embedding, a causal mask, or a user attn_mask. A null bias
+            # preserves the no-mask fast path so PyTorch can still pick the true
+            # flash kernel when available.
+            bias: torch.Tensor | None = None
+            lq, lk = q.shape[-2], k.shape[-2]
+
+            if self.rel_positional_embedding is not None:
+                zero_logits = torch.zeros(q.shape[0], self.num_heads, lq, lk, dtype=q.dtype, device=q.device)
+                bias = self.rel_positional_embedding(x, zero_logits, q)
+
+            is_causal_arg = self.causal
+            if self.causal and bias is not None:
+                causal_bias = torch.zeros(lq, lk, dtype=q.dtype, device=q.device)
+                causal_bias.masked_fill_(self.causal_mask[0, 0, :lq, :lk] == 0, float("-inf"))
+                bias = bias + causal_bias
+                is_causal_arg = False
+
+            if attn_mask is not None:
+                if self.causal:
+                    raise ValueError("Causal attention does not support attention masks.")
+                mask_bias = torch.zeros_like(attn_mask, dtype=q.dtype)
+                mask_bias.masked_fill_(attn_mask == 0, float("-inf"))
+                mask_bias = mask_bias.unsqueeze(1).unsqueeze(2)
+                bias = mask_bias if bias is None else bias + mask_bias
+
             x = F.scaled_dot_product_attention(
                 query=q,
                 key=k,
                 value=v,
-                attn_mask=attn_mask,
+                attn_mask=bias,
                 scale=self.scale,
-                dropout_p=self.dropout_rate,
+                dropout_p=self.dropout_rate if self.training else 0.0,
-                is_causal=self.causal,
+                is_causal=is_causal_arg,
             )
         else:
             att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
diff --git a/tests/networks/blocks/test_crossattention.py b/tests/networks/blocks/test_crossattention.py
index f691f4e534..ebd39b92b7 100644
--- a/tests/networks/blocks/test_crossattention.py
+++ b/tests/networks/blocks/test_crossattention.py
@@ -30,7 +30,7 @@
     [
         {
             **{k: v for k, v in params.items() if k not in ["rel_pos_embedding_val"]},
-            "rel_pos_embedding": params["rel_pos_embedding_val"] if not params["use_flash_attention"] else None,
+            "rel_pos_embedding": params["rel_pos_embedding_val"],
         },
         (2, 512, params["hidden_size"]),
         (2, 512, params["hidden_size"]),
@@ -69,16 +69,53 @@ def test_save_attn_with_flash_attention(self):
             hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
         )
 
+    @skipUnless(has_einops, "Requires einops")
     def test_rel_pos_embedding_with_flash_attention(self):
-        with self.assertRaises(ValueError):
-            CrossAttentionBlock(
-                hidden_size=128,
-                num_heads=3,
-                dropout_rate=0.1,
-                use_flash_attention=True,
-                save_attn=False,
-                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-            )
+        # rel_pos_embedding combined with use_flash_attention now dispatches
+        # via SDPA with an additive bias. Must match the explicit path.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            seq_len = int(np.prod(input_size))
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
 
     @skipUnless(has_einops, "Requires einops")
     def test_attention_dim_not_multiple_of_heads(self):
diff --git a/tests/networks/blocks/test_selfattention.py b/tests/networks/blocks/test_selfattention.py
index af52918612..1ff6398894 100644
--- a/tests/networks/blocks/test_selfattention.py
+++ b/tests/networks/blocks/test_selfattention.py
@@ -43,7 +43,7 @@
                 "input_size": input_size,
                 "include_fc": include_fc,
                 "use_combined_linear": use_combined_linear,
-                "use_flash_attention": True if rel_pos_embedding is None else False,
+                "use_flash_attention": True,
             },
             (2, 512, hidden_size),
             (2, 512, hidden_size),
@@ -67,16 +67,79 @@ def test_ill_arg(self):
         with self.assertRaises(ValueError):
             SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
 
+    @skipUnless(has_einops, "Requires einops")
     def test_rel_pos_embedding_with_flash_attention(self):
-        with self.assertRaises(ValueError):
-            SABlock(
-                hidden_size=128,
-                num_heads=3,
-                dropout_rate=0.1,
-                use_flash_attention=True,
-                save_attn=False,
-                rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
-            )
+        # rel_pos_embedding is now allowed with use_flash_attention; SDPA picks
+        # a fused backend that supports an additive attention bias. The two
+        # code paths must be numerically equivalent for the same weights.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, int(np.prod(input_size)), 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_causal_rel_pos_with_flash_attention(self):
+        # Exercise the merged causal-bias branch: causal=True together with
+        # rel_pos_embedding builds an additive bias and disables is_causal,
+        # so flash and reference paths must still match numerically.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+                "causal": True,
+                "sequence_length": seq_len,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data)
+                out_ref = block_ref(test_data)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_attn_mask_rel_pos_with_flash_attention(self):
+        # Exercise the user-attn-mask + rel_pos branch: the user mask is
+        # merged into the additive bias passed via SDPA's attn_mask argument.
+        for input_size in [(16, 32), (8, 8, 8)]:
+            seq_len = int(np.prod(input_size))
+            input_param = {
+                "hidden_size": 128,
+                "num_heads": 4,
+                "dropout_rate": 0.0,
+                "rel_pos_embedding": RelPosEmbedding.DECOMPOSED,
+                "input_size": input_size,
+            }
+            device = "cuda:0" if torch.cuda.is_available() else "cpu"
+            block_flash = SABlock(**input_param, use_flash_attention=True).to(device)
+            block_ref = SABlock(**input_param, use_flash_attention=False).to(device)
+            block_ref.load_state_dict(block_flash.state_dict())
+            test_data = torch.randn(2, seq_len, 128).to(device)
+            attn_mask = torch.ones(2, seq_len, dtype=torch.bool, device=device)
+            attn_mask[:, seq_len // 2 :] = False  # mask out the second half
+            with eval_mode(block_flash), eval_mode(block_ref):
+                out_flash = block_flash(test_data, attn_mask=attn_mask)
+                out_ref = block_ref(test_data, attn_mask=attn_mask)
+            assert_allclose(out_flash, out_ref, atol=1e-4)
 
     def test_save_attn_with_flash_attention(self):
         with self.assertRaises(ValueError):