3 changes: 2 additions & 1 deletion docker/Dockerfile
@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir

RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
cd flash-attention/hopper/ && python setup.py install
Contributor

high

This RUN command can be improved for readability and to reduce the final image size by cleaning up the cloned repository. Consider breaking down the command and removing the LightKernel directory after installation. For example:

RUN git clone https://github.com/ModelTC/LightKernel.git && \
    cd LightKernel && \
    pip install --no-deps -v . && \
    cd flash-attention/hopper/ && \
    python setup.py install && \
    cd ../../.. && \
    rm -rf LightKernel


RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel

3 changes: 2 additions & 1 deletion docker/Dockerfile.deepep
@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir

RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly

RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
cd flash-attention/hopper/ && python setup.py install
Contributor

high

This RUN command can be improved for readability and to reduce the final image size by cleaning up the cloned repository. Consider breaking down the command and removing the LightKernel directory after installation. For example:

RUN git clone https://github.com/ModelTC/LightKernel.git && \
    cd LightKernel && \
    pip install --no-deps -v . && \
    cd flash-attention/hopper/ && \
    python setup.py install && \
    cd ../../.. && \
    rm -rf LightKernel


RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev
13 changes: 11 additions & 2 deletions lightllm/common/basemodel/triton_kernel/gen_decode_params.py
@@ -2,13 +2,22 @@
import triton
import triton.language as tl
from .gen_prefill_params import gen_cumsum_pad0_tensor
from lightllm.utils.envs_utils import get_env_start_args


@torch.no_grad()
def gen_decode_params(b_seq_len: torch.Tensor):
b_kv_seq_len = b_seq_len
position_ids = b_seq_len - 1
b_q_seq_len = torch.ones_like(b_seq_len)
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
mtp_step = get_env_start_args().mtp_step
mtp_size = mtp_step + 1
enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
Contributor

medium

To improve readability and avoid redundant calls, it's better to call get_env_start_args() only once and store the result in a local variable.

Suggested change
mtp_step = get_env_start_args().mtp_step
mtp_size = mtp_step + 1
enable_fa3_mtp = get_env_start_args().enable_fa3_mtp
start_args = get_env_start_args()
mtp_step, enable_fa3_mtp = start_args.mtp_step, start_args.enable_fa3_mtp
mtp_size = mtp_step + 1


if enable_fa3_mtp:
b_q_seq_len = torch.ones_like(b_seq_len[: len(b_seq_len) // mtp_size])
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size])
else:
b_q_seq_len = torch.ones_like(b_seq_len)
b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)

return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
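
For illustration, a small sketch of the batch layout this slicing assumes (batch size and values are hypothetical): when enable_fa3_mtp is set, the decode batch appears to carry mtp_size consecutive entries per request, and the [mtp_size - 1 :: mtp_size] slice keeps only the last entry of each group.

import torch

# Hypothetical decode batch: 3 requests with mtp_step = 1, i.e. mtp_size = 2 entries per request.
mtp_step = 1
mtp_size = mtp_step + 1
b_seq_len = torch.tensor([17, 18, 5, 6, 42, 43], dtype=torch.int64)

# One query row per request (not per token) when enable_fa3_mtp is set.
b_q_seq_len = torch.ones_like(b_seq_len[: len(b_seq_len) // mtp_size])  # tensor([1, 1, 1])
# Keep only the last entry of each request's group as its kv length.
per_request_kv_len = b_seq_len[mtp_size - 1 :: mtp_size]  # tensor([18, 6, 43])
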
212 changes: 212 additions & 0 deletions lightllm/common/flash_attn.py
@@ -0,0 +1,212 @@
# This file is adapted from sgl-project/sglang:
# https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/flash_attn.py
# The original code and this file are licensed under the Apache License, Version 2.0.
#
# Copyright (c) sgl-project and other contributors.
# Modifications Copyright (c) LightLLM contributors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List, Optional, Tuple, Union
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x


try:
import flash_attn_3._C # Registers operators with PyTorch

flash_attn_3_mtp = torch.ops.flash_attn_3

def flash_attn_with_kvcache_mtp(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
mtp_step=0,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
the previous step, and update them with the new keys/values from the current step, and do
attention with the updated cache, all in 1 kernel.

If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
For example, the KV cache could be pre-allocated with the max sequence length, and you can use
cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.

Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).

See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.

Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.

If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.

If window_size != (-1, -1), implements sliding window local attention. Query at position i
will only attend to keys between
[i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.

Note: Does not support backward pass.

Arguments:
q: (batch_size, seqlen, nheads, headdim)
k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
page_block_size must be a multiple of 256.
v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
k with k_cache, starting at the indices specified by cache_seqlens.
v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
qv [optional]: (batch_size, seqlen, nheads, headdim_v)
rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
KV cache.
cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
If the indices are not distinct, and k and v are provided, the values updated in the cache
might come from any of the duplicate indices.
cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
window_size: (left, right). If not (-1, -1), implements sliding window local attention.
softcap: float. Anything > 0 activates softcapping attention.
rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
(i.e. GPT-NeoX style).
num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
to automatically determine the number of splits.
Don't change this unless you know what you are doing.
return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.

Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
if cache_seqlens is not None and isinstance(cache_seqlens, int):
cache_seqlens = torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device)
cache_seqlens = maybe_contiguous(cache_seqlens)

q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
v_cache = v_cache.contiguous() if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 else v_cache
cu_seqlens_q, cu_seqlens_k_new = [maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
page_table, cache_batch_idx, cache_leftpad = [
maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
]
rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
rotary_seqlens = maybe_contiguous(rotary_seqlens)

# out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
out, softmax_lse, *rest = flash_attn_3_mtp.fwd(
q,
k_cache,
v_cache,
k,
v,
qv,
None, # out
cu_seqlens_q,
None, # cu_seqlens_k
cu_seqlens_k_new,
None, # seqused_q
cache_seqlens,
max_seqlen_q,
None, # max_seqlen_k
page_table,
cache_batch_idx,
cache_leftpad,
rotary_cos,
rotary_sin,
rotary_seqlens,
q_descale,
k_descale,
v_descale,
softmax_scale,
causal,
window_size[0],
window_size[1],
0,
softcap,
rotary_interleaved,
scheduler_metadata,
num_splits,
pack_gqa,
sm_margin,
mtp_step,
)
return (out, softmax_lse, *rest) if return_softmax_lse else out

except Exception:  # flash_attn_3 is missing or failed to load; fall back gracefully
flash_attn_3_mtp = None
flash_attn_with_kvcache_mtp = None
logger.warning("flash_attn_3._C is not available, please install flash-attention-3 package.")
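
As a usage illustration, a minimal decode-time sketch loosely mirroring the DeepSeek call in transformer_layer_infer.py further below (all shapes, dtypes, and values are hypothetical assumptions, not the kernel's documented contract; assumes a Hopper GPU with the flash_attn_3 extension installed and 256-token pages as the docstring requires):

import torch
from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp

if flash_attn_with_kvcache_mtp is not None:  # None when flash_attn_3 failed to load
    batch, heads, head_dim, page_size, pages_per_seq = 4, 16, 64, 256, 8
    mtp_step = 1
    # One packed query row per request; the per-request MTP tokens are folded into the head dim.
    q = torch.randn(batch, heads * (mtp_step + 1), head_dim, dtype=torch.bfloat16, device="cuda")
    k_cache = torch.randn(batch * pages_per_seq, page_size, 1, head_dim, dtype=torch.bfloat16, device="cuda")
    v_cache = torch.randn(batch * pages_per_seq, page_size, 1, head_dim, dtype=torch.bfloat16, device="cuda")
    page_table = torch.arange(batch * pages_per_seq, dtype=torch.int32, device="cuda").view(batch, pages_per_seq)
    cache_seqlens = torch.full((batch,), 777, dtype=torch.int32, device="cuda")
    cu_seqlens_q = torch.arange(batch + 1, dtype=torch.int32, device="cuda")  # one query per request
    out = flash_attn_with_kvcache_mtp(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        page_table=page_table,
        cache_seqlens=cache_seqlens,
        cu_seqlens_q=cu_seqlens_q,
        max_seqlen_q=1,
        softmax_scale=head_dim ** -0.5,
        causal=True,
        mtp_step=mtp_step,
    )
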
39 changes: 38 additions & 1 deletion lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -31,6 +31,7 @@
from lightllm.utils.dist_utils import get_global_world_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2
from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp

logger = init_logger(__name__)

@@ -72,6 +73,8 @@ def __init__(self, layer_num, network_config, mode=[]):
super().__init__(layer_num, network_config, mode)
self.num_heads = network_config["num_attention_heads"]
self.num_kv_heads = network_config["num_key_value_heads"]
self.mtp_step = get_env_start_args().mtp_step
self.mtp_size = self.mtp_step + 1
return

def _bind_func(self):
@@ -97,7 +100,11 @@ def _bind_attention(self):
)
else:
self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
if get_env_start_args().enable_fa3:
if get_env_start_args().enable_fa3_mtp:
Contributor

medium

The function get_env_start_args() is called multiple times within this if/elif chain (lines 103, 107, 111). To improve performance and readability, consider calling it once before this conditional block and storing the result in a local variable.
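
For example, a minimal sketch of that refactor (the start_args local name is illustrative):

start_args = get_env_start_args()
if start_args.enable_fa3_mtp:
    self._token_attention_kernel = partial(
        Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self
    )
elif start_args.enable_fa3:
    self._token_attention_kernel = partial(
        Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
    )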

self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self
)
elif get_env_start_args().enable_fa3:
self._token_attention_kernel = partial(
Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self
)
@@ -547,6 +554,36 @@ def _context_attention_kernel_origin_fp8(
)
return o_tensor

def _token_gqa_decode_attention_mtp(
self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :]
q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1)
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim)
kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank)
k_descale, v_descale = None, None
o_tensor = flash_attn_with_kvcache_mtp(
q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim),
k_cache=k_rope,
v_cache=kv_nope,
qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size],
cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size],
cu_seqlens_q=infer_state.cu_seqlens_q,
cu_seqlens_k_new=infer_state.cu_seqlens_k,
max_seqlen_q=1,
softmax_scale=self.softmax_scale,
causal=True,
window_size=(-1, -1),
softcap=0.0,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=False,
mtp_step=self.mtp_step,
)
return o_tensor
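
A small shape sketch of the packing the flash_attn_with_kvcache_mtp call above assumes (sizes are hypothetical): each request contributes mtp_size consecutive token rows, and the reshape folds them into the head dimension so the kernel sees one packed row per request.

import torch

batch, mtp_step, heads, rope_dim = 2, 1, 16, 64
mtp_size = mtp_step + 1
# Flattened decode batch: mtp_size consecutive token rows per request.
q_rope = torch.randn(batch * mtp_size, heads, rope_dim)
# Fold each request's MTP tokens into the head dimension: one packed row per request.
q_packed = q_rope.reshape(-1, heads * mtp_size, rope_dim)
assert q_packed.shape == (batch, heads * mtp_size, rope_dim)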

def _token_gqa_decode_attention_flashattention(
self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None
):
2 changes: 1 addition & 1 deletion lightllm/models/deepseek2/model.py
@@ -69,7 +69,7 @@ def __init__(self, kvargs):
return

def _init_inferstate_cls(self):
if get_env_start_args().enable_fa3:
if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp:
Contributor

medium

The function get_env_start_args() is called twice on this line. To improve performance and readability, consider calling it once before the if statement and storing the result in a local variable.
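
For example, a minimal sketch (the start_args local name is illustrative):

start_args = get_env_start_args()
if start_args.enable_fa3 or start_args.enable_fa3_mtp:
    self.infer_state_class = Deepseek2FlashAttentionStateInfo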

self.infer_state_class = Deepseek2FlashAttentionStateInfo
elif self.enable_flashinfer:
self.infer_state_class = Deepseek2FlashInferStateInfo
5 changes: 5 additions & 0 deletions lightllm/server/api_cli.py
@@ -465,6 +465,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
but ensure that the model is compatible with the specified step count.
currently, deepseekv3 model only support 1 step""",
)
parser.add_argument(
"--enable_fa3_mtp",
action="store_true",
help="""inference backend will use the fa3_mtp kernel for decode with MTP mode""",
)
parser.add_argument(
"--kv_quant_calibration_config_path",
type=str,
3 changes: 3 additions & 0 deletions lightllm/server/api_start.py
@@ -138,6 +138,9 @@ def normal_or_p_d_start(args):
assert args.mtp_draft_model_dir is None
assert args.mtp_step == 0

if args.enable_fa3_mtp:
assert args.mtp_mode is not None, "enable_fa3_mtp requires mtp_mode to be set"

# Check that the number of GPUs is sufficient
if args.visual_gpu_ids is None:
args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
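
As a usage illustration, a hypothetical launch command combining the new flag with MTP (the entry module follows LightLLM's usual api_server invocation; the model paths and the mtp_mode value are placeholders; per the check added above, --mtp_mode must be set):

python -m lightllm.server.api_server \
    --model_dir /path/to/DeepSeek-V3 \
    --mtp_mode deepseekv3 \
    --mtp_draft_model_dir /path/to/mtp_draft_model \
    --mtp_step 1 \
    --enable_fa3_mtp
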
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -95,4 +95,5 @@ class StartArgs:
mtp_mode: Optional[str] = field(default=None)
mtp_draft_model_dir: Optional[str] = field(default=None)
mtp_step: int = field(default=0)
enable_fa3_mtp: bool = field(default=False)
kv_quant_calibration_config_path: Optional[str] = field(default=None)