Skip to content

Commit e756ee6

Browse files
SangChengC, sangchengmeng, wangzaijun, shihaobai, hiworldwzj
authored
add-choose-vit-backend (#1191)
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com> Co-authored-by: wangzaijun <wangzaijun@sensetime.com> Co-authored-by: shihaobai <1798930569@qq.com> Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 871ace6 commit e756ee6

25 files changed

Lines changed: 435 additions & 39 deletions

File tree

docs/CN/source/tutorial/api_server_args.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,17 @@ PD 分离模式参数
284284

285285
为 ViT 构建分布式环境的 NCCL 端口列表,例如 29500 29501 29502,默认为 [29500]
286286

287+
.. option:: --vit_att_backend
288+
289+
设置 ViT 使用的注意力后端。可选值为:
290+
291+
* ``auto``: 自动选择最佳后端(默认值),优先级为 fa3 > xformers > sdpa > triton
292+
* ``fa3``: 使用 Flash-Attention 3 后端
293+
* ``xformers``: 使用 xformers 后端
294+
* ``sdpa``: 使用 sdpa 后端
295+
* ``triton``: 使用 Triton 后端
296+
297+
287298
性能优化参数
288299
------------
289300

docs/EN/source/tutorial/api_server_args.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,17 @@ Multimodal Parameters
282282

283283
List of NCCL ports for ViT, e.g., 29500 29501 29502, default is [29500]
284284

285+
.. option:: --vit_att_backend
286+
287+
Set the attention backend for ViT. Available options:
288+
289+
* ``auto``: Automatically select the best backend (default), with priority fa3 > xformers > sdpa > triton
290+
* ``fa3``: Use Flash-Attention 3 backend
291+
* ``xformers``: Use xformers backend
292+
* ``sdpa``: Use sdpa backend
293+
* ``triton``: Use Triton backend
294+
295+
285296
Performance Optimization Parameters
286297
-----------------------------------
287298

lightllm/common/basemodel/attention_vit/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
from abc import ABC, abstractmethod
3+
4+
5+
class BaseVitAttBackend(ABC):
    """Abstract base for ViT attention backends (fa3, sdpa, triton, xformers, ...).

    Uses a per-subclass singleton pattern: each concrete backend class has
    exactly one instance, cached in ``_instances``.
    """

    # Maps each concrete subclass to its unique cached instance.
    _instances = {}

    def __new__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, creating it on first use."""
        if cls not in cls._instances:
            cls._instances[cls] = super().__new__(cls)
        return cls._instances[cls]

    def __init__(self):
        # NOTE: __init__ runs on every construction even though __new__ returns
        # a cached instance, so it must stay side-effect free.
        pass

    @abstractmethod
    def _vit_att_fwd(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        o: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Run variable-length attention over packed sequences.

        Args:
            q, k, v: packed input tensors; presumably [total_tokens, heads,
                head_dim] — TODO confirm against a concrete backend.
            o: pre-allocated output tensor, written in place.
            cu_seqlens: cumulative sequence lengths delimiting each sample.
            max_seqlen: maximum per-sample sequence length.

        Returns:
            The output tensor ``o``.
        """
        raise NotImplementedError("not impl")
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
from lightllm.utils.log_utils import init_logger
3+
from lightllm.utils.envs_utils import get_env_start_args
4+
from lightllm.utils.backend_validator import _validate
5+
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
6+
from lightllm.common.basemodel.attention_vit.fa3.fp import Fa3VitAttBackend
7+
from lightllm.common.basemodel.attention_vit.triton.fp import TritonVitAttBackend
8+
from lightllm.common.basemodel.attention_vit.sdpa.fp import SdpaVitAttBackend
9+
from lightllm.common.basemodel.attention_vit.xformers.fp import XformersVitAttBackend
10+
11+
logger = init_logger(__name__)
12+
13+
14+
# Registry of available ViT attention backend implementations, keyed by the
# user-facing backend name accepted on the command line.
vit_att_backend = {
    "triton": TritonVitAttBackend,
    "sdpa": SdpaVitAttBackend,
    "fa3": Fa3VitAttBackend,
    "xformers": XformersVitAttBackend,
}


def get_vit_att_backend_class(backend_name: str) -> BaseVitAttBackend:
    """Return the backend class registered under ``backend_name``.

    Raises KeyError if the name is not a registered backend.
    """
    return vit_att_backend[backend_name]
25+
26+
27+
def init_vit_att_backend(index=0, priority_list=("fa3", "xformers", "sdpa", "triton")) -> str:
    """Resolve the ViT attention backend name from the start args.

    Fix: the default ``priority_list`` was a mutable list literal (shared
    across calls); it is now an immutable tuple with the same contents.

    Args:
        index: Which entry of ``args.vit_att_backend`` to read.
            NOTE(review): presumably one entry per visual model — confirm.
        priority_list: Candidate order used when the configured value is
            "auto".

    Returns:
        The chosen backend name, e.g. "fa3" / "xformers" / "sdpa" / "triton".
    """
    args = get_env_start_args()
    backend_name = args.vit_att_backend[index]
    if backend_name != "auto":
        # Explicit user choice: honor it without validation.
        logger.info(f"Selected {backend_name} backend for VIT")
        return backend_name
    else:
        # "auto": probe candidates in priority order, first validated wins.
        return _select_vit_backend(priority_list=list(priority_list))
35+
36+
37+
def _select_vit_backend(priority_list=("fa3", "xformers", "sdpa", "triton")) -> str:
    """Auto-select the best available backend with validation for VIT.

    Fix: the default ``priority_list`` was a mutable list literal (shared
    across calls); it is now an immutable tuple with the same contents.

    Priority: FA3 > Xformers > Sdpa > Triton.
    Each backend is validated via ``_validate`` (per the original docstring,
    in a subprocess with ground-truth checks) before being chosen.

    Returns:
        The first backend name that validates, or "triton" as a last resort.
    """
    for backend_name in priority_list:
        if _validate(backend_name):
            logger.info(f"Auto-selected {backend_name} backend (validated) for VIT")
            return backend_name

    # Fallback to triton without validation (should not happen).
    logger.warning("No backend validation succeeded, falling back to triton")
    return "triton"

lightllm/common/basemodel/attention_vit/fa3/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import dataclasses
2+
import torch
3+
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
4+
5+
6+
class Fa3VitAttBackend(BaseVitAttBackend):
    """ViT attention backend backed by the Flash-Attention 3 sgl-kernel op."""

    @staticmethod
    def _vit_att_fwd(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        o: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        """Non-causal, full-window varlen attention via the FA3 kernel.

        ``o`` is written in place by the kernel and also returned.

        Fix: the return annotation was ``-> None`` although the method
        returns ``o`` (and the base-class contract is ``torch.Tensor``).
        """
        head_dim = q.shape[-1]
        softmax_scale = head_dim ** -0.5  # standard 1/sqrt(d) scaling
        window_size = (-1, -1)  # (-1, -1): no sliding-window limit
        # The kernel takes a long positional argument list; unused optional
        # features (paged KV, rotary embedding, softcap, ...) are None/0.
        torch.ops.sgl_kernel.fwd.default(
            q,
            k,
            v,
            None,  # k_new
            None,  # v_new
            None,  # qv
            o,  # out
            cu_seqlens,  # cu_seqlens for q
            cu_seqlens,  # cu_seqlens for k (self-attention: same as q)
            None,  # cu_seqlens_k_new
            None,
            None,
            max_seqlen,  # max seqlen for q
            max_seqlen,  # max seqlen for k
            None,  # page_table,
            None,  # kv_batch_idx
            None,  # leftpad_k
            None,  # rotary cos
            None,  # rotary sin
            None,  # seqlens_rotary
            None,
            None,
            None,
            softmax_scale,
            False,  # non-causal (ViT attends bidirectionally)
            window_size[0],
            window_size[1],
            0.0,
            is_rotary_interleaved=False,
            scheduler_metadata=None,
            num_splits=1,
            pack_gqa=None,
            sm_margin=0,
            sinks=None,
        )

        return o

lightllm/common/basemodel/attention_vit/sdpa/__init__.py

Whitespace-only changes.
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import torch
2+
import torch.nn.functional as F
3+
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
4+
5+
6+
class SdpaVitAttBackend(BaseVitAttBackend):
    """ViT attention backend using torch scaled_dot_product_attention.

    Runs SDPA once per packed sequence; a simple reference-style fallback
    compared with the fused varlen kernels.
    """

    @staticmethod
    def _vit_att_fwd(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        o: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
    ) -> torch.Tensor:
        assert q.ndim == k.ndim == v.ndim == o.ndim == 3
        assert cu_seqlens is not None and cu_seqlens.ndim == 1
        # Sequence boundaries are consumed as python ints; pull them to the
        # host once up front instead of syncing inside the loop.
        bounds = cu_seqlens.detach().to("cpu")
        num_seqs = bounds.numel() - 1

        with torch.no_grad():
            for idx in range(num_seqs):
                start = int(bounds[idx])
                end = int(bounds[idx + 1])
                seq_len = end - start
                if seq_len <= 0:
                    continue
                if max_seqlen:
                    assert seq_len <= max_seqlen

                # [L, H, D] -> [1, H, L, D], the layout SDPA expects.
                q_b = q[start:end].permute(1, 0, 2).unsqueeze(0)
                k_b = k[start:end].permute(1, 0, 2).unsqueeze(0)
                v_b = v[start:end].permute(1, 0, 2).unsqueeze(0)

                attn_out = F.scaled_dot_product_attention(
                    q_b,
                    k_b,
                    v_b,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False,
                )
                # [1, H, L, D] -> [L, H, D], written into the shared output.
                o[start:end].copy_(attn_out.squeeze(0).permute(1, 0, 2))

        return o

lightllm/common/basemodel/attention_vit/triton/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)