Commit c25c078

Author: niushengxiao
Commit message: refine
1 parent: 100088a

12 files changed: 58 additions & 122 deletions
lightllm/common/basemodel/attention_vit/fa3/fp.py

Lines changed: 2 additions & 3 deletions
@@ -18,7 +18,7 @@ def _vit_att_fwd(
     head_dim = q.shape[-1]
     softmax_scale = head_dim ** -0.5
     window_size = (-1, -1)
-    attn_output = flash_attn_varlen_func(
+    o = flash_attn_varlen_func(
         q,
         k,
         v,
@@ -29,8 +29,7 @@ def _vit_att_fwd(
         softmax_scale=softmax_scale,
         causal=False,
         window_size=window_size,
+        attention_chunk=0,
         softcap=0.0,
     )
-    o.copy_(attn_output)
-
     return o
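
This hunk removes a redundant device copy: instead of materializing flash-attn's result and then `copy_`-ing it into a preallocated `o`, the wrapper returns the tensor that `flash_attn_varlen_func` allocates itself. A minimal sketch of the before/after pattern, where `attn_impl` is a stand-in for any attention call that returns its output (not the real wrapper signature):

import torch

def attn_fwd_old(q, k, v, o, attn_impl):
    # old: the output is materialized twice -- once by the attention
    # call, once more by the copy into the caller's buffer
    attn_output = attn_impl(q, k, v)
    o.copy_(attn_output)
    return o

def attn_fwd_new(q, k, v, attn_impl):
    # new: hand back the kernel's own output tensor, no extra copy
    return attn_impl(q, k, v)

# toy attention as the stand-in implementation
def toy_attn(q, k, v):
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-1, -2) * scale, dim=-1) @ v

q = k = v = torch.randn(4, 16, 64)
out = attn_fwd_new(q, k, v, toy_attn)  # one allocation instead of two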

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 9 additions & 29 deletions
@@ -6,7 +6,7 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
-from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, load_tensor_afs
+from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce
 from lightllm.utils.envs_utils import get_env_start_args
@@ -41,9 +41,7 @@ def _copy_loaded_embed_to_cache(
         self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int
     ):
         if embed_tensor.ndim == 2:
-            token_num, hidden_size = embed_tensor.shape
-            cpu_embed_cache_tensor[start_index : start_index + token_num, 0, :hidden_size].copy_(embed_tensor)
-            return
+            embed_tensor = embed_tensor.unsqueeze(1)

         token_num, layer_num, hidden_size = embed_tensor.shape
         cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor)
@@ -53,18 +51,20 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         img_start_token_ids = []
         img_token_lens = []
         img_start_locs_in_cache = []
+        unique_uids = []
         device = layer_weight.wte_weight_.weight.device
         dtype = layer_weight.wte_weight_.weight.dtype
         hidden_size = layer_weight.wte_weight_.weight.shape[1]

-        for batch_id, p in enumerate(infer_state.multimodal_params):
+        for _, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
+                unique_uids.append(img["uuid"])
         out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

         from lightllm.server.router.model_infer.infer_batch import g_infer_context
@@ -77,33 +77,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         )

         if self.args.enable_remote_vit:
-            unique_multimodal_items = []
-            seen_uuids = set()
             release_ids = []
-            for batch_id, p in enumerate(infer_state.multimodal_params):
+            for _, p in enumerate(infer_state.multimodal_params):
                 for img in p["images"] + p["audios"]:
-                    if img["token_num"] is None:
-                        continue
-                    uid = img["uuid"]
-                    release_ids.append(uid)
-                    if uid in seen_uuids:
-                        continue
-                    seen_uuids.add(uid)
-                    unique_multimodal_items.append((uid, img["start_index_in_embed_cache"]))
+                    release_ids.append(img["uuid"])

-            if self.args.image_embed_dir:
-                image_embed_dir = self.args.image_embed_dir
-
-                def load_embed_tensor(uid):
-                    return load_tensor_afs(get_shm_name_embed(uid), image_embed_dir)
-
-            else:
-
-                def load_embed_tensor(uid):
-                    return bytes2tensor(read_shm(get_shm_name_embed(uid)))
-
-            for uid, start_index_in_embed_cache in unique_multimodal_items:
-                embed_tensor = load_embed_tensor(uid)
+            for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache):
+                embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir)
                 self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache)

             if release_ids:
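
The `_copy_loaded_embed_to_cache` rewrite folds the 2-D special case into the 3-D path: a `(token_num, hidden_size)` embed is lifted to `(token_num, 1, hidden_size)` via `unsqueeze(1)`, so one slice-and-copy serves both shapes. A standalone sketch of that normalization (the cache dimensions below are made up for illustration):

import torch

def copy_embed_to_cache(embed: torch.Tensor, cache: torch.Tensor, start: int) -> None:
    # normalize (token_num, hidden) -> (token_num, 1, hidden) so a single
    # code path handles single-layer and multi-layer embeds
    if embed.ndim == 2:
        embed = embed.unsqueeze(1)
    token_num, layer_num, hidden = embed.shape
    cache[start : start + token_num, :layer_num, :hidden].copy_(embed)

cache = torch.zeros(64, 4, 128)  # hypothetical cache: 64 slots, up to 4 layers
copy_embed_to_cache(torch.randn(7, 128), cache, start=0)     # 2-D input
copy_embed_to_cache(torch.randn(7, 4, 128), cache, start=7)  # 3-D input

The loader loop also simplifies because `unique_uids` and `img_start_locs_in_cache` are built by the same dedup pass earlier in `context_forward`, so `zip`-ping them pairs each unique uuid with its cache offset without a second dedup structure.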

lightllm/models/vit/triton_kernel/flashattention_nopad.py

Lines changed: 2 additions & 4 deletions
@@ -167,7 +167,7 @@ def flash_attention_v3_fwd(
         head_dim = q.shape[-1]
         softmax_scale = head_dim ** -0.5
         window_size = (-1, -1)
-        attn_output = flash_attn_varlen_func(
+        o = flash_attn_varlen_func(
             q,
             k,
             v,
@@ -180,9 +180,7 @@ def flash_attention_v3_fwd(
             window_size=window_size,
             softcap=0.0,
         )
-        o.copy_(attn_output)
-
-        return
+        return o

 except ImportError:
     print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")

lightllm/server/api_lightllm.py

Lines changed: 1 addition & 2 deletions
@@ -1,11 +1,10 @@
 import collections
 from typing import AsyncGenerator
 from fastapi import BackgroundTasks, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import Response, StreamingResponse, JSONResponse
 from lightllm.server.core.objs.sampling_params import SamplingParams
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
-from fastapi.responses import JSONResponse
 import ujson as json


lightllm/server/api_start.py

Lines changed: 26 additions & 26 deletions
@@ -87,7 +87,7 @@ def signal_handler(sig, frame):
     return


-def check_and_set_args(args):
+def normal_or_p_d_start(args, only_prepare=False):
     from lightllm.server.core.objs.start_args_type import StartArgs

     args: StartArgs = args
@@ -219,20 +219,18 @@ def check_and_set_args(args):
         if args.batch_max_tokens is None:
             args.batch_max_tokens = args.max_req_total_len
         else:
-            assert args.batch_max_tokens >= args.max_req_total_len, (
-                f"batch_max_tokens must >= max_req_total_len"
-                f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
-            )
+            assert args.batch_max_tokens >= args.max_req_total_len, f"batch_max_tokens must >= max_req_total_len"
+            f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
     else:
         # in chunked prefill mode
         if args.batch_max_tokens is None:
             args.batch_max_tokens = 16384 // args.dp
         if args.chunked_prefill_size is None:
             args.chunked_prefill_size = args.batch_max_tokens // 2
-        assert args.batch_max_tokens >= args.chunked_prefill_size, (
-            "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
-            f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}"
-        )
+        assert (
+            args.batch_max_tokens >= args.chunked_prefill_size
+        ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
+        f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}"

     # help to manage data stored on Ceph
     if "s3://" in args.model_dir:
@@ -252,9 +250,8 @@
     args.data_type = get_dtype(args.model_dir)
     assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

-
-def normal_or_p_d_start(args):
-    check_and_set_args(args)
+    if only_prepare:
+        return

     already_uesd_ports = [args.port]
     if args.nccl_port is not None:
@@ -291,17 +288,19 @@ def normal_or_p_d_start(args):
     can_use_ports = can_use_ports[10:]

     visual_model_tp_ports = []
+    visual_nccl_ports = []
     for _ in range(args.visual_dp):
         tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
         visual_model_tp_ports.append(tp_ports_for_dp)
         can_use_ports = can_use_ports[args.visual_tp :]
+        if args.visual_nccl_ports is None:
+            visual_nccl_ports.append(can_use_ports[0])
+            can_use_ports = can_use_ports[1:]

-    if args.visual_nccl_ports is None:
-        visual_nccl_ports = can_use_ports[0 : args.visual_dp]
-        can_use_ports = can_use_ports[args.visual_dp :]
-    else:
-        visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+    if args.visual_nccl_ports is not None:
+        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

+    # store the allocated ports back into args
     if args.nccl_port is None:
         args.nccl_port = nccl_port
     if args.pd_decode_rpyc_port is None:
@@ -328,6 +327,7 @@
         args.router_max_wait_tokens = 0

     send_and_receive_node_ip(args)  # exchange node IPs between machines in multi-node setups
+    # dp must be > 1
     if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
         args.enable_dp_prompt_cache_fetch = False
         logger.warning(
@@ -491,7 +491,7 @@ def pd_master_start(args):


 def visual_start(args):
-    check_and_set_args(args)
+    normal_or_p_d_start(args, only_prepare=True)

     already_uesd_ports = [args.remote_vit_port]
     if args.nccl_port is not None:
@@ -515,23 +515,23 @@ def visual_start(args):
     can_use_ports = can_use_ports[5:]

     visual_model_tp_ports = []
+    visual_nccl_ports = []
     for _ in range(args.visual_dp):
         tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
-        can_use_ports = can_use_ports[args.visual_tp :]
         visual_model_tp_ports.append(tp_ports_for_dp)
+        can_use_ports = can_use_ports[args.visual_tp :]
+        if args.visual_nccl_ports is None:
+            visual_nccl_ports.append(can_use_ports[0])
+            can_use_ports = can_use_ports[1:]

-    if args.visual_nccl_ports is None:
-        args.visual_nccl_ports = can_use_ports[0 : args.visual_dp]
-        can_use_ports = can_use_ports[args.visual_dp :]
-    else:
+    if args.visual_nccl_ports is not None:
         args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]

     args.router_port = router_port
     args.visual_port = visual_port
     args.audio_port = audio_port
     args.cache_port = cache_port
     args.metric_port = metric_port
-    args.visual_model_rpc_ports = visual_model_tp_ports
     args.visual_node_id = uuid.uuid4().int

     logger.info(f"all start args:{args}")
@@ -586,9 +586,9 @@ def config_server_start(args):
        "--log-level",
        "info",
        "--access-logfile",
-       "/dev/stdout",
+       "-",
        "--error-logfile",
-       "/dev/stderr",
+       "-",
        "lightllm.server.config_server.api_http:app",
        "--keep-alive",
        f"{get_lightllm_gunicorn_keep_alive()}",

lightllm/server/embed_cache/impl/memory_cache_with_redis.py

Lines changed: 8 additions & 18 deletions
@@ -51,9 +51,11 @@ def set_items_embed(self, ids: list[int]) -> None:
         with self.lock:
             for id in ids:
                 self.redis_cache.insert(str(id))
-                self._records[id].embed = True
-                if self._records[id].ref > 0:
-                    self._update_record_ref_by_id(id, -1)
+                rec = self._records.get(id)
+                if rec is not None:
+                    rec.embed = True
+                    if rec.ref > 0:
+                        self._update_record_ref_by_id(id, -1)
                 # keep one redis reference and only release it after the real consumer
                 # finishes reading, so the LRU cannot evict an entry that VIT just
                 # wrote but the LLM has not read yet.

@@ -66,19 +68,7 @@ def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[O
                 exist = self.redis_cache.query_and_incre(str(id))
                 ret.append(exist)
                 if exist:
-                    self._records[id].embed = True
+                    rec = self._records.get(id)
+                    if rec is not None:
+                        rec.embed = True
         return ret
-
-    # def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]:
-    #     ret = []
-    #     for id in ids:
-    #         # if self.redis_cache.query(str(id)):
-    #         #     ret.append(True)
-    #         #     continue
-    #         # avoid incrementing the reference count twice
-    #         if self._records[id].embed:
-    #             ret.append(True)
-    #             continue
-    #         self._records[id].embed = self.redis_cache.query_and_incre(str(id))
-    #         ret.append(self._records[id].embed)
-    #     return ret
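
Both methods now tolerate ids whose record has already been dropped: `self._records[id]` raises `KeyError` for an evicted entry, while `dict.get` returns `None` and lets the Redis-side update proceed anyway. A minimal sketch of the pattern, with a `Record` shape assumed from the fields the diff touches:

from dataclasses import dataclass

@dataclass
class Record:
    embed: bool = False
    ref: int = 0

records = {1: Record(ref=1)}

def mark_embedded(record_id: int) -> None:
    rec = records.get(record_id)  # None if the record was evicted meanwhile
    if rec is not None:
        rec.embed = True
        if rec.ref > 0:
            rec.ref -= 1

mark_embedded(1)   # updates the live record
mark_embedded(99)  # missing id is skipped instead of raising KeyError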

lightllm/server/embed_cache/manager.py

Lines changed: 0 additions & 4 deletions
@@ -26,10 +26,6 @@ def on_disconnect(self, conn):
         # (to finalize the service, if needed)
         pass

-    def exposed__check_and_set_new_id_range(self, token_num: int) -> int:
-        token_num = obtain(token_num)
-        return self._impl._check_and_set_new_id_range(token_num)
-
     def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
         md5sum_list = obtain(md5sum_list)
         token_num_list = obtain(token_num_list)

lightllm/server/embed_cache/utils.py

Lines changed: 0 additions & 20 deletions
@@ -10,7 +10,6 @@
 from lightllm.utils.log_utils import init_logger

 logger = init_logger(__name__)
-_ENSURED_AFS_DIRS = set()


 def _get_afs_path(base_dir: str, name: str) -> Path:
@@ -19,21 +18,6 @@ def _get_afs_path(base_dir: str, name: str) -> Path:
     return Path(base_dir) / name


-def _ensure_afs_dir(base_dir: Path) -> None:
-    base_dir_key = str(base_dir)
-    if base_dir_key in _ENSURED_AFS_DIRS:
-        return
-    if base_dir.exists():
-        if not base_dir.is_dir():
-            raise ValueError(f"image_embed_dir is not a directory: {base_dir}")
-        _ENSURED_AFS_DIRS.add(base_dir_key)
-        return
-
-    base_dir.mkdir(parents=True, mode=0o777, exist_ok=True)
-    os.chmod(base_dir, 0o777)
-    _ENSURED_AFS_DIRS.add(base_dir_key)
-
-
 def tensor2bytes(t: torch.Tensor):
     buf = BytesIO()
     t = t.detach().cpu()
@@ -50,13 +34,11 @@ def bytes2tensor(b):

 def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None:
     target_path = _get_afs_path(base_dir, name)
-    _ensure_afs_dir(target_path.parent)
     tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}"

     try:
         with open(tmp_path, "wb") as f:
             torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4)
-        os.chmod(tmp_path, 0o777)
         os.replace(tmp_path, target_path)
         os.chmod(target_path, 0o777)
     except Exception:
@@ -86,7 +68,6 @@ def create_shm(name, data):

 def create_afs(name, data, path):
     target_path = _get_afs_path(path, name)
-    _ensure_afs_dir(target_path.parent)
     data_size = len(data)
     tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}"

@@ -96,7 +77,6 @@ def create_afs(name, data, path):
             f.write(mem_view[:data_size])
             f.flush()
             os.fsync(f.fileno())
-        os.chmod(tmp_path, 0o777)
         os.replace(tmp_path, target_path)
         os.chmod(target_path, 0o777)
     except Exception:
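
With the eager `_ensure_afs_dir` bookkeeping and the extra chmod on the temp file gone, what remains is the atomic-publish pattern: write to a hidden per-process temp file, then `os.replace` it over the final name, so a concurrent reader sees either the old file or the complete new one, never a partial write. A condensed sketch of that surviving pattern (the `atomic_write` helper is illustrative, simplified from `save_tensor_afs`/`create_afs`):

import os
import time
from pathlib import Path

def atomic_write(target_path: Path, data: bytes) -> None:
    # unique temp name per writer: pid plus a nanosecond timestamp
    tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}"
    try:
        with open(tmp_path, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # make the bytes durable before publishing
        os.replace(tmp_path, target_path)  # atomic rename on POSIX
        os.chmod(target_path, 0o777)
    except Exception:
        tmp_path.unlink(missing_ok=True)  # best-effort temp cleanup
        raise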
