Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions lightllm/common/basemodel/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,4 @@
from .base_att import BaseAttBackend, BasePrefillAttState, BaseDecodeAttState, AttControl
from .triton.fp import TritonAttBackend
from .triton.int4kv import Int4kvTritonAttBackend
from .triton.int8kv import Int8kvTritonAttBackend
from .triton.mla import MlaTritonAttBackend
from .fa3.fp import Fa3AttBackend
from .fa3.fp8 import Fp8Fa3AttBackend
from .fa3.mla import MlaFa3AttBackend
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend

# NSA backend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend

from .create_utils import (
get_prefill_att_backend_class,
Expand Down
298 changes: 187 additions & 111 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
@@ -1,150 +1,226 @@
"""Attention backend selection utilities."""
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.utils.log_utils import init_logger
from lightllm.platform.base.attention import (
AttCategory,
att_backend_registry,
ensure_att_backends_loaded,
)
from lightllm.common.basemodel.attention.base_att import BaseAttBackend
from lightllm.utils.backend_validator import validate
from typing import Dict
from .base_att import BaseAttBackend
from .triton.fp import TritonAttBackend
from .triton.int4kv import Int4kvTritonAttBackend
from .triton.int8kv import Int8kvTritonAttBackend
from .triton.mla import MlaTritonAttBackend
from .fa3.fp import Fa3AttBackend
from .fa3.fp8 import Fp8Fa3AttBackend
from .fa3.mla import MlaFa3AttBackend
from .flashinfer.fp8 import Fp8FlashInferAttBackend
from .flashinfer.fp import FlashInferAttBackend
from .flashinfer.mla import MlaFlashInferAttBackend
from .nsa.flashmla_sparse import NsaFlashMlaSparseAttBackend
from .nsa.fp8_flashmla_sparse import NsaFlashMlaFp8SparseAttBackend
from lightllm.utils.envs_utils import get_env_start_args, get_page_size
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)

# Backend class lookup tables.
# Outer key: KV-cache data type string (the literal string "None" means no
# KV quantization). Inner key: backend name as given on the command line.

# Standard attention backends.
data_type_to_backend = {
    "None": {
        "triton": TritonAttBackend,
        "fa3": Fa3AttBackend,
        "flashinfer": FlashInferAttBackend,
    },
    "int4kv": {
        "triton": Int4kvTritonAttBackend,
        # NOTE(review): "fa3" (Fp8Fa3AttBackend) and "flashinfer"
        # (Fp8FlashInferAttBackend) entries were left commented out here;
        # presumably not supported for int4kv -- confirm before enabling.
    },
    "int8kv": {
        "triton": Int8kvTritonAttBackend,
        # NOTE(review): "fa3" (Fp8Fa3AttBackend) and "flashinfer"
        # (Fp8FlashInferAttBackend) entries were left commented out here;
        # presumably not supported for int8kv -- confirm before enabling.
    },
    "fp8kv_sph": {
        "fa3": Fp8Fa3AttBackend,
    },
    "fp8kv_spt": {
        "flashinfer": Fp8FlashInferAttBackend,
    },
}

# MLA attention backends.
mla_data_type_to_backend = {
    "None": {
        "triton": MlaTritonAttBackend,
        "fa3": MlaFa3AttBackend,
        "flashinfer": MlaFlashInferAttBackend,
    },
}

# NSA (sparse) attention backends.
nsa_data_type_to_backend = {
    "None": {
        "flashmla_sparse": NsaFlashMlaSparseAttBackend,
        # Future backends: "fa3", "tilelang", "aiter"
    },
    "fp8kv_dsa": {
        "flashmla_sparse": NsaFlashMlaFp8SparseAttBackend,
    },
}

def _resolve_registry_name(
    name: str,
    *,
    category: AttCategory,
    kv_type: str,
    platform: str | None,
) -> str:
    """Map a user-facing backend name to the name used for registry lookup.

    The only rewrite performed is ``"fa3"`` -> ``"paged_fa3"``, and only when
    all of the following hold: the category is ``"standard"``, ``kv_type`` is
    the string ``"None"``, the configured page size is greater than 1, and a
    ``"paged_fa3"`` backend is registered for the same category / kv_type /
    platform. Every other input is returned unchanged.

    Args:
        name: Backend name as requested (e.g. ``"fa3"``, ``"triton"``).
        category: Attention category used for the registry lookup.
        kv_type: KV-cache data type string.
        platform: Hardware platform passed to the registry, or ``None``.

    Returns:
        The registry name to resolve: ``"paged_fa3"`` or ``name``.
    """
    # Cheap string checks first: anything that is not the plain standard fa3
    # request never consults the page size or the registry.
    if name != "fa3" or category != "standard" or kv_type != "None":
        return name

    if get_page_size() > 1 and att_backend_registry.is_registered(
        category=category,
        name="paged_fa3",
        kv_type=kv_type,
        platform=platform,
    ):
        return "paged_fa3"
    return name


def _get_att_backend_class(
    *,
    category: AttCategory,
    backend_name: str,
    kv_type: str,
    platform: str | None,
) -> type:
    """Resolve an explicitly requested backend to its registered class.

    Loads the backend registrations, translates ``backend_name`` through
    ``_resolve_registry_name`` and asks the registry for the class. No
    validation run is performed on this path.

    Args:
        category: Attention category to look up.
        backend_name: Backend name requested by the user.
        kv_type: KV-cache data type string.
        platform: Hardware platform passed to the registry, or ``None``.

    Returns:
        Whatever ``att_backend_registry.resolve_backend_cls`` returns for the
        resolved name.
    """
    # Registrations must be loaded before any registry query.
    ensure_att_backends_loaded()
    registry_name = _resolve_registry_name(
        backend_name,
        category=category,
        kv_type=kv_type,
        platform=platform,
    )
    return att_backend_registry.resolve_backend_cls(
        category=category,
        name=registry_name,
        kv_type=kv_type,
        platform=platform,
    )


def _fallback_att_backend_class(
    *,
    category: AttCategory,
    kv_type: str,
    platform: str | None,
) -> type:
    """Pick a registered backend without validation when auto-selection fails.

    Prefers ``"triton"`` when it is among the registered names, otherwise the
    first name the registry lists. The chosen name is still passed through
    ``_resolve_registry_name`` before the class is resolved.

    Args:
        category: Attention category to look up.
        kv_type: KV-cache data type string.
        platform: Hardware platform passed to the registry, or ``None``.

    Returns:
        The backend class resolved by the registry.

    Raises:
        ValueError: If nothing is registered for this category / kv_type /
            platform combination.
    """
    registered = att_backend_registry.list_names(
        category=category,
        kv_type=kv_type,
        platform=platform,
    )
    if not registered:
        raise ValueError(
            f"No attention backends are registered for "
            f"category={category!r}, kv_type={kv_type!r}, platform={platform!r}"
        )

    fallback_name = "triton" if "triton" in registered else registered[0]
    resolved_name = _resolve_registry_name(
        fallback_name,
        category=category,
        kv_type=kv_type,
        platform=platform,
    )
    # Warn loudly: the returned backend has not been validated.
    logger.warning(
        f"No backend validation succeeded, falling back to {fallback_name!r} "
        f"(resolved as {resolved_name!r})"
    )
    return att_backend_registry.resolve_backend_cls(
        category=category,
        name=resolved_name,
        kv_type=kv_type,
        platform=platform,
    )


def _auto_select_backend(
    *,
    category: AttCategory,
    kv_type: str,
    platform: str | None,
    priority_list: list[str],
) -> type:
    """Auto-select the best available backend with validation.

    Candidates are tried in ``priority_list`` order. A candidate is skipped
    when its resolved registry name is not registered for the given
    category / kv_type / platform; otherwise it is passed to ``validate`` and
    the first one that validates is returned. When expert-parallel MoE is
    enabled, ``"flashinfer"`` is removed from the candidates first. If no
    candidate validates, ``_fallback_att_backend_class`` chooses one without
    validation.

    Args:
        category: Attention category to look up.
        kv_type: KV-cache data type string.
        platform: Hardware platform passed to the registry, or ``None``.
        priority_list: Backend names in descending order of preference.

    Returns:
        The selected backend class.
    """
    ensure_att_backends_loaded()

    if get_env_start_args().enable_ep_moe:
        logger.info("Expert parallelism with MoE enabled, excluding flashinfer attention backend")
        # Rebind to a filtered copy so the caller's list is never mutated.
        priority_list = [name for name in priority_list if name != "flashinfer"]

    for backend_name in priority_list:
        resolved_name = _resolve_registry_name(
            backend_name,
            category=category,
            kv_type=kv_type,
            platform=platform,
        )
        if not att_backend_registry.is_registered(
            category=category,
            name=resolved_name,
            kv_type=kv_type,
            platform=platform,
        ):
            continue

        # The registered spec may name a different target for validation;
        # fall back to the requested backend name when no spec is returned.
        spec = att_backend_registry.get_spec(
            category=category,
            name=resolved_name,
            kv_type=kv_type,
        )
        validate_name = spec.effective_validate_name() if spec is not None else backend_name
        if validate(validate_name):
            logger.info(f"Auto-selected {backend_name} backend (validated)")
            return att_backend_registry.resolve_backend_cls(
                category=category,
                name=resolved_name,
                kv_type=kv_type,
                platform=platform,
            )

    return _fallback_att_backend_class(
        category=category,
        kv_type=kv_type,
        platform=platform,
    )


def _select_att_backend_class(
    *,
    category: AttCategory,
    backend_str: str,
    priority_list: list[str],
) -> type:
    """Choose a backend class for ``category`` from the start arguments.

    ``kv_type`` and ``platform`` are read from the start arguments
    (``llm_kv_type`` and ``hardware_platform``).

    Args:
        category: Attention category to look up.
        backend_str: Requested backend name, or ``"auto"``.
        priority_list: Preference order used only when ``backend_str`` is
            ``"auto"``.

    Returns:
        The explicitly requested backend class, or the auto-selected one.
    """
    start_args = get_env_start_args()
    kv_type = start_args.llm_kv_type
    platform = start_args.hardware_platform

    if backend_str == "auto":
        # Auto mode: walk the priority list and validate candidates.
        return _auto_select_backend(
            category=category,
            kv_type=kv_type,
            platform=platform,
            priority_list=priority_list,
        )

    # Explicit backend: resolve it directly, without validation.
    return _get_att_backend_class(
        category=category,
        backend_name=backend_str,
        kv_type=kv_type,
        platform=platform,
    )


def get_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
    """Return the standard-attention backend class to use for prefill.

    Args:
        index: Position in ``args.llm_prefill_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"standard"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="standard",
        backend_str=args.llm_prefill_att_backend[index],
        priority_list=priority_list,
    )


def get_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
    """Return the standard-attention backend class to use for decode.

    Args:
        index: Position in ``args.llm_decode_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"standard"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="standard",
        backend_str=args.llm_decode_att_backend[index],
        priority_list=priority_list,
    )


def get_mla_prefill_att_backend_class(index=0, priority_list: list = ["fa3", "flashinfer", "triton"]) -> BaseAttBackend:
    """Return the MLA attention backend class to use for prefill.

    Args:
        index: Position in ``args.llm_prefill_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"mla"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="mla",
        backend_str=args.llm_prefill_att_backend[index],
        priority_list=priority_list,
    )


def get_mla_decode_att_backend_class(index=0, priority_list: list = ["flashinfer", "fa3", "triton"]) -> BaseAttBackend:
    """Return the MLA attention backend class to use for decode.

    Args:
        index: Position in ``args.llm_decode_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"mla"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="mla",
        backend_str=args.llm_decode_att_backend[index],
        priority_list=priority_list,
    )


def get_nsa_prefill_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend:
    """Return the NSA attention backend class to use for prefill.

    Args:
        index: Position in ``args.llm_prefill_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"nsa"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="nsa",
        backend_str=args.llm_prefill_att_backend[index],
        priority_list=priority_list,
    )


def get_nsa_decode_att_backend_class(index=0, priority_list: list = ["flashmla_sparse"]) -> BaseAttBackend:
    """Return the NSA attention backend class to use for decode.

    Args:
        index: Position in ``args.llm_decode_att_backend`` of the backend
            name to use (``"auto"`` triggers priority-based selection).
        priority_list: Preference order for auto-selection. The default list
            object is shared across calls and is only passed through here.

    Returns:
        The backend class (not an instance) chosen by
        ``_select_att_backend_class`` for category ``"nsa"``.
    """
    args = get_env_start_args()
    return _select_att_backend_class(
        category="nsa",
        backend_str=args.llm_decode_att_backend[index],
        priority_list=priority_list,
    )
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/attention/fa3/fp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.platform.base.attention import register_att_backend


@register_att_backend(name="fa3", category="standard", platforms=("cuda",))
class Fa3AttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/attention/fa3/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.triton_kernel.quantization.q_per_head_fp8_quant import q_per_head_fp8_quant
from lightllm.utils.vllm_utils import HAS_VLLM, vllm_ops
from lightllm.platform.base.attention import register_att_backend
from typing import Union
from .fp import Fa3AttBackend, Fa3PrefillAttState, Fa3DecodeAttState

Expand All @@ -15,6 +16,7 @@
scaled_fp8_quant = None


@register_att_backend(name="fa3", category="standard", kv_types=("fp8kv_sph",), platforms=("cuda",))
class Fp8Fa3AttBackend(Fa3AttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down
2 changes: 2 additions & 0 deletions lightllm/common/basemodel/attention/fa3/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from lightllm.common.basemodel.triton_kernel.fa3_utils import page_table_copy
from lightllm.common.basemodel.triton_kernel.gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.utils.sgl_utils import flash_attn_varlen_func
from lightllm.platform.base.attention import register_att_backend


@register_att_backend(name="fa3", category="mla", platforms=("cuda",))
class MlaFa3AttBackend(BaseAttBackend):
def __init__(self, model):
super().__init__(model=model)
Expand Down
Loading