-
Notifications
You must be signed in to change notification settings - Fork 7k
add SP support for _flash_3_varlen_hub backend
#13809
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -340,6 +340,8 @@ class _HubKernelConfig: | |
| AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig( | ||
| repo_id="kernels-community/flash-attn3", | ||
| function_attr="flash_attn_varlen_func", | ||
| wrapped_forward_attr="flash_attn_interface._flash_attn_forward", | ||
| wrapped_backward_attr="flash_attn_interface._flash_attn_backward", | ||
| version=1, | ||
| ), | ||
| AttentionBackendName.FLASH_HUB: _HubKernelConfig( | ||
|
|
@@ -1612,6 +1614,194 @@ def _flash_attention_3_hub_backward_op( | |
| return grad_query, grad_key, grad_value | ||
|
|
||
|
|
||
| def _flash_attention_3_varlen_hub_forward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| query: torch.Tensor, | ||
| key: torch.Tensor, | ||
| value: torch.Tensor, | ||
| attn_mask: torch.Tensor | None = None, | ||
| dropout_p: float = 0.0, | ||
| is_causal: bool = False, | ||
| scale: float | None = None, | ||
| enable_gqa: bool = False, | ||
| return_lse: bool = False, | ||
| _save_ctx: bool = True, | ||
| _parallel_config: "ParallelConfig" | None = None, | ||
| *, | ||
| window_size: tuple[int, int] = (-1, -1), | ||
| softcap: float = 0.0, | ||
| num_splits: int = 1, | ||
| pack_gqa: bool | None = None, | ||
| deterministic: bool = False, | ||
| sm_margin: int = 0, | ||
| ): | ||
| if dropout_p != 0.0: | ||
| raise ValueError("`dropout_p` is not yet supported for flash-attn 3 varlen hub kernels.") | ||
| if enable_gqa: | ||
| raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 varlen hub kernels.") | ||
|
|
||
| config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB] | ||
| wrapped_forward_fn = config.wrapped_forward_fn | ||
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_forward_fn is None or wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_forward` and " | ||
| "`flash_attn_interface._flash_attn_backward` for context parallel execution." | ||
| ) | ||
|
|
||
| if scale is None: | ||
| scale = query.shape[-1] ** (-0.5) | ||
|
|
||
| batch_size, seq_len_q, num_heads, _ = query.shape | ||
| _, seq_len_kv, _, _ = key.shape | ||
|
|
||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (_, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device) | ||
| ) | ||
| indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten() | ||
| query_packed = query.flatten(0, 1) | ||
| key_packed = key.reshape(-1, *key.shape[2:])[indices_k] | ||
| value_packed = value.reshape(-1, *value.shape[2:])[indices_k] | ||
| max_seqlen_q = seq_len_q | ||
| else: | ||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) | ||
| ) | ||
| query_packed = query.flatten(0, 1) | ||
| key_packed = key.flatten(0, 1) | ||
| value_packed = value.flatten(0, 1) | ||
| seqlens_k = None | ||
|
|
||
| out_packed, softmax_lse, *_ = wrapped_forward_fn( | ||
| query_packed, | ||
| key_packed, | ||
| value_packed, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, | ||
| None, | ||
| None, | ||
| max_seqlen_q, | ||
| max_seqlen_k, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| None, | ||
| scale, | ||
| is_causal, | ||
| window_size[0], | ||
| window_size[1], | ||
| 0, | ||
| softcap, | ||
| True, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What does |
||
| None, | ||
| num_splits, | ||
| pack_gqa, | ||
| sm_margin, | ||
| ) | ||
|
|
||
| out = out_packed.view(batch_size, seq_len_q, *out_packed.shape[1:]) | ||
|
|
||
| if _save_ctx: | ||
| ctx.save_for_backward( | ||
| query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k | ||
| ) | ||
| ctx.seqlens_k = seqlens_k # None if unmasked | ||
| ctx.indices_k = indices_k if attn_mask is not None else None | ||
| ctx.max_seqlen_q = max_seqlen_q | ||
| ctx.max_seqlen_k = max_seqlen_k | ||
| ctx.batch_size = batch_size | ||
| ctx.seq_len_q = seq_len_q | ||
| ctx.seq_len_kv = seq_len_kv | ||
| ctx.num_heads = num_heads | ||
| ctx.scale = scale | ||
| ctx.is_causal = is_causal | ||
| ctx.window_size = window_size | ||
| ctx.softcap = softcap | ||
| ctx.deterministic = deterministic | ||
| ctx.sm_margin = sm_margin | ||
|
|
||
| # softmax_lse in varlen mode: (num_heads, total_q) -> (batch_size, seq_len_q, num_heads) | ||
| lse_sp = softmax_lse.view(num_heads, batch_size, seq_len_q).permute(1, 2, 0).contiguous() | ||
|
|
||
| return (out, lse_sp) if return_lse else out | ||
|
|
||
|
|
||
| def _flash_attention_3_varlen_hub_backward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| grad_out: torch.Tensor, | ||
| *args, | ||
| **kwargs, | ||
| ): | ||
| config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB] | ||
| wrapped_backward_fn = config.wrapped_backward_fn | ||
| if wrapped_backward_fn is None: | ||
| raise RuntimeError( | ||
| "Flash attention 3 varlen hub kernels must expose `flash_attn_interface._flash_attn_backward` " | ||
| "for context parallel execution." | ||
| ) | ||
|
|
||
| query_packed, key_packed, value_packed, out_packed, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors | ||
|
|
||
| grad_out_packed = grad_out.flatten(0, 1) | ||
| grad_query, grad_key, grad_value = ( | ||
| torch.empty_like(query_packed), | ||
| torch.empty_like(key_packed), | ||
| torch.empty_like(value_packed), | ||
| ) | ||
|
|
||
| wrapped_backward_fn( | ||
| grad_out_packed, | ||
| query_packed, | ||
| key_packed, | ||
| value_packed, | ||
| out_packed, | ||
| softmax_lse, | ||
| cu_seqlens_q, | ||
| cu_seqlens_k, | ||
| None, | ||
| None, | ||
| ctx.max_seqlen_q, | ||
| ctx.max_seqlen_k, | ||
| grad_query, | ||
| grad_key, | ||
| grad_value, | ||
| ctx.scale, | ||
| ctx.is_causal, | ||
| ctx.window_size[0], | ||
| ctx.window_size[1], | ||
| ctx.softcap, | ||
| ctx.deterministic, | ||
| ctx.sm_margin, | ||
| ) | ||
|
|
||
| grad_query = grad_query.view(ctx.batch_size, ctx.seq_len_q, *grad_query.shape[1:]) | ||
|
|
||
| if ctx.seqlens_k is not None: | ||
| grad_key = _unpad_to_padded(grad_key, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) | ||
| grad_value = _unpad_to_padded(grad_value, ctx.indices_k, ctx.batch_size, ctx.seq_len_kv) | ||
| else: | ||
| grad_key = grad_key.view(ctx.batch_size, ctx.seq_len_kv, *grad_key.shape[1:]) | ||
| grad_value = grad_value.view(ctx.batch_size, ctx.seq_len_kv, *grad_value.shape[1:]) | ||
|
|
||
| grad_query = grad_query[..., : grad_out.shape[-1]] | ||
| grad_key = grad_key[..., : grad_out.shape[-1]] | ||
| grad_value = grad_value[..., : grad_out.shape[-1]] | ||
|
|
||
| return grad_query, grad_key, grad_value | ||
|
|
||
|
|
||
| def _sage_attention_forward_op( | ||
| ctx: torch.autograd.function.FunctionCtx, | ||
| query: torch.Tensor, | ||
|
|
@@ -2986,7 +3176,7 @@ def _flash_attention_3_hub( | |
| @_AttentionBackendRegistry.register( | ||
| AttentionBackendName._FLASH_3_VARLEN_HUB, | ||
| constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], | ||
| supports_context_parallel=False, | ||
| supports_context_parallel=True, | ||
| ) | ||
| def _flash_attention_3_varlen_hub( | ||
| query: torch.Tensor, | ||
|
|
@@ -2998,41 +3188,73 @@ def _flash_attention_3_varlen_hub( | |
| return_lse: bool = False, | ||
| _parallel_config: "ParallelConfig" | None = None, | ||
| ) -> torch.Tensor: | ||
| if _parallel_config is not None and _parallel_config.context_parallel_config.ring_degree > 1: | ||
| raise NotImplementedError("`ring_degree > 1` is not yet supported for the _FLASH_3_VARLEN_HUB backend.") | ||
|
|
||
| lse = None | ||
| batch_size, seq_len_q, _, _ = query.shape | ||
| _, seq_len_kv, _, _ = key.shape | ||
|
|
||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
|
|
||
| (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen( | ||
| batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device | ||
| ) | ||
| ) | ||
|
|
||
| key_valid, value_valid = [], [] | ||
| for b in range(batch_size): | ||
| valid_len = seqlens_k[b] | ||
| key_valid.append(key[b, :valid_len]) | ||
| value_valid.append(value[b, :valid_len]) | ||
|
Comment on lines
-3004
to
-3017
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems like should come under |
||
| if _parallel_config is None: | ||
| if attn_mask is not None: | ||
| attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) | ||
| (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask, query.device) | ||
| ) | ||
| indices_k = attn_mask.flatten().nonzero(as_tuple=False).flatten() | ||
| key_packed = key.reshape(-1, *key.shape[2:])[indices_k] | ||
| value_packed = value.reshape(-1, *value.shape[2:])[indices_k] | ||
| else: | ||
| (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( | ||
| _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device) | ||
| ) | ||
| key_packed = key.flatten(0, 1) | ||
| value_packed = value.flatten(0, 1) | ||
|
|
||
| query_packed = query.flatten(0, 1) | ||
| key_packed = torch.cat(key_valid, dim=0) | ||
| value_packed = torch.cat(value_valid, dim=0) | ||
| query_packed = query.flatten(0, 1) | ||
|
|
||
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn | ||
| out, lse, *_ = func( | ||
| q=query_packed, | ||
| k=key_packed, | ||
| v=value_packed, | ||
| cu_seqlens_q=cu_seqlens_q, | ||
| cu_seqlens_k=cu_seqlens_k, | ||
| max_seqlen_q=max_seqlen_q, | ||
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| ) | ||
| out = out.unflatten(0, (batch_size, -1)) | ||
| func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn | ||
| out = func( | ||
| q=query_packed, | ||
| k=key_packed, | ||
| v=value_packed, | ||
| cu_seqlens_q=cu_seqlens_q, | ||
| cu_seqlens_k=cu_seqlens_k, | ||
| max_seqlen_q=max_seqlen_q, | ||
| max_seqlen_k=max_seqlen_k, | ||
| softmax_scale=scale, | ||
| causal=is_causal, | ||
| return_attn_probs=return_lse, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems like an extra argument?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Note: The original code always unpacked |
||
| ) | ||
| if return_lse: | ||
| out, lse, *_ = out | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why do we need to initialize |
||
| out = out.unflatten(0, (batch_size, -1)) | ||
| else: | ||
| forward_op = functools.partial( | ||
| _flash_attention_3_varlen_hub_forward_op, | ||
| window_size=(-1, -1), | ||
| softcap=0.0, | ||
| num_splits=1, | ||
| pack_gqa=None, | ||
| deterministic=False, | ||
| sm_margin=0, | ||
| ) | ||
| out = _templated_context_parallel_attention( | ||
| query, | ||
| key, | ||
| value, | ||
| attn_mask, | ||
| 0.0, | ||
| is_causal, | ||
| scale, | ||
| False, | ||
| return_lse, | ||
| forward_op=forward_op, | ||
| backward_op=_flash_attention_3_varlen_hub_backward_op, | ||
| _parallel_config=_parallel_config, | ||
| ) | ||
| if return_lse: | ||
| out, lse = out | ||
|
|
||
| return (out, lse) if return_lse else out | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: The non-varlen `_flash_attention_3_hub_forward_op` uses keyword arguments for the trailing parameters (`causal=is_causal`, `window_size_left=window_size[0]`, etc.), but here everything is passed positionally with no inline comments explaining what each `None` corresponds to. This makes the code harder to audit and fragile if the upstream signature changes. Consider either:
`None` values (like the non-varlen version does with `# k_new, v_new`, `# cu_seqlens_q/k/k_new`, etc.)