From 59d8d6777d2068562a7845f6f1211ddf2d045e4f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=92=AE=E5=9C=A3=E8=99=93?=
Date: Wed, 11 Mar 2026 11:01:19 +0800
Subject: [PATCH 1/4] feat: vit separation

---
 .../common/basemodel/attention_vit/fa3/fp.py  |  42 +-
 lightllm/models/internvl/model.py             |  13 +-
 .../qwen_vl/layer_infer/pre_layer_infer.py    |  25 ++
 .../vit/triton_kernel/flashattention_nopad.py |  42 +-
 lightllm/server/api_cli.py                    |  46 ++-
 lightllm/server/api_http.py                   |  20 +-
 lightllm/server/api_lightllm.py               |  17 +
 lightllm/server/api_server.py                 |   4 +-
 lightllm/server/api_start.py                  | 146 +++++--
 lightllm/server/config_server/api_http.py     |  35 ++
 .../server/core/objs/io_objs/group_req.py     |   4 +-
 .../impl/memory_cache_with_redis.py           |  74 ++++
 .../embed_cache/impl/naive_memory_cache.py    |   8 +-
 lightllm/server/embed_cache/manager.py        |  18 +-
 lightllm/server/embed_cache/utils.py          | 389 ++++++++++++++++++
 lightllm/server/httpserver/manager.py         | 112 +++--
 lightllm/server/multimodal_params.py          |  21 +-
 .../server/router/model_infer/infer_batch.py  |   4 +-
 .../model_infer/mode_backend/base_backend.py  |   2 +-
 lightllm/server/visualserver/manager.py       | 106 ++++-
 .../visualserver/model_infer/model_rpc.py     |  71 ++--
 lightllm/server/visualserver/register_loop.py |  42 ++
 lightllm/server/visualserver/vit_connect.py   | 236 +++++++++++
 lightllm/utils/dist_utils.py                  |  10 +-
 lightllm/utils/redis_utils.py                 |  74 ++++
 lightllm/utils/start_utils.py                 |   8 +
 requirements.txt                              |   1 +
 27 files changed, 1377 insertions(+), 193 deletions(-)
 create mode 100644 lightllm/server/embed_cache/impl/memory_cache_with_redis.py
 create mode 100644 lightllm/server/visualserver/register_loop.py
 create mode 100644 lightllm/server/visualserver/vit_connect.py
 create mode 100644 lightllm/utils/redis_utils.py

diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py
index f804116f1f..c5b8cc1076 100644
--- a/lightllm/common/basemodel/attention_vit/fa3/fp.py
+++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py
@@ -1,6 +1,7 @@
 import dataclasses
 import torch
 from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
+from lightllm.utils.sgl_utils import flash_attn_varlen_func


 class Fa3VitAttBackend(BaseVitAttBackend):
@@ -17,42 +18,19 @@ def _vit_att_fwd(
         head_dim = q.shape[-1]
         softmax_scale = head_dim ** -0.5
         window_size = (-1, -1)
-        torch.ops.sgl_kernel.fwd.default(
+        attn_output = flash_attn_varlen_func(
             q,
             k,
             v,
-            None,  # k_new
-            None,  # v_new
-            None,  # qv
-            o,  # out
-            cu_seqlens,
-            cu_seqlens,
-            None,  # cu_seqlens_k_new
-            None,
-            None,
-            max_seqlen,
-            max_seqlen,
-            None,  # page_table,
-            None,  # kv_batch_idx
-            None,  # leftpad_k
-            None,  # rotary cos
-            None,  # rotary sin
-            None,  # seqlens_rotary
-            None,
-            None,
-            None,
-            softmax_scale,
-            False,
-            window_size[0],
-            window_size[1],
-            attention_chunk=0,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=max_seqlen,
+            max_seqlen_k=max_seqlen,
+            softmax_scale=softmax_scale,
+            causal=False,
+            window_size=window_size,
             softcap=0.0,
-            is_rotary_interleaved=False,
-            scheduler_metadata=None,
-            num_splits=1,
-            pack_gqa=None,
-            sm_margin=0,
-            sinks=None,
         )
+        o.copy_(attn_output)
         return o


diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py
index ccb76d3512..3e5b9c5e2a 100644
--- a/lightllm/models/internvl/model.py
+++ b/lightllm/models/internvl/model.py
@@ -68,6 +68,7 @@ def init_imageitem_extral_params(
             img.extra_params["image_patch_max_num"] = 6
         elif num_images > 6:
img.extra_params["image_patch_max_num"] = 0 + img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( @@ -75,14 +76,14 @@ def init_audioitem_extral_params( ): return - def get_image_token_length(self, img: ImageItem): - return ( - self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - * self.image_length + def get_image_patch(self, img: ImageItem): + return self.get_image_patch_func( + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True ) + def get_image_token_length(self, img: ImageItem): + return self.get_image_patch(img) * self.image_length + def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length audio_token_num = 0 diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..faff792901 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,11 +1,15 @@ +import rpyc +import socket import torch import torch.distributed as dist from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args """ @@ -26,6 +30,11 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) + self.args = get_env_start_args() + self.cache_client = None + if self.args.enable_remote_vit: + self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -55,6 +64,22 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei else cpu_embed_cache_client.cpu_embed_cache_tensor ) + if self.args.enable_remote_vit: + for batch_id, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + if img["token_num"] is None: + continue + if self.args.image_embed_dir: + embed_bytes = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir) + else: + embed_bytes = read_shm(get_shm_name_embed(img["uuid"])) + embed_tensor = bytes2tensor(embed_bytes).to(device="cuda", non_blocking=True) + g_infer_context.cpu_embed_cache_client.copy_vision_to_cache( + embed_tensor=embed_tensor, + start_index_in_cache=img["start_index_in_embed_cache"], + ) + self.cache_client.root.release([img["uuid"]]) + assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..a38e27924b 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ 
b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -167,42 +167,20 @@ def flash_attention_v3_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( + attn_output = flash_attn_varlen_func( q, k, v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], - 0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + window_size=window_size, + softcap=0.0, ) + o.copy_(attn_output) return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 762d84575b..924b26fce3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,17 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=[ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual", + "visual_only", + ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -593,6 +603,40 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--image_embed_dir", + type=str, + default=None, + help="path for vit embed", + ) + parser.add_argument( + "--enable_remote_vit", + action="store_true", + help="Whether to enable remote vit for multimodal service.", + ) + parser.add_argument( + "--remote_vit_port", + type=int, + default=12346, + help="The port number for the remote vit service.", + ) + parser.add_argument( + "--redis_port", + type=int, + default=6379, + help="The port number for the redis service in config_server mode.", + ) + parser.add_argument( + "--redis_evict_fraction", + type=float, + default=0.3, + help="The evict fraction for the redis service in config_server mode.", + ) + parser.add_argument( + "--start_redis", + action="store_true", + help="Whether to start the redis service in config_server mode.", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..bf246f8f0d 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -43,7 +43,7 @@ from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -92,6 +92,8 @@ def set_args(self, args: 
StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) + elif args.run_mode == "visual": + self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -136,7 +138,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embedding") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -359,6 +373,8 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..fbff8e681c 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,6 +5,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from fastapi.responses import JSONResponse import ujson as json @@ -150,3 +151,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: + request_dict = await request.json() + # request_dict: {'parameters': {'max_new_tokens': 128}, + # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) + + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..c6700c0416 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() 
args = parser.parse_args()
-    from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
+    from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start

     if args.run_mode == "pd_master":
         pd_master_start(args)
     elif args.run_mode == "config_server":
         config_server_start(args)
+    elif args.run_mode == "visual":
+        visual_start(args)
     else:
         normal_or_p_d_start(args)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 0db786d0bf..e69588ff80 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -5,7 +5,7 @@
 import subprocess
 import signal
 from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker
-from lightllm.utils.start_utils import process_manager, kill_recursive
+from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode
 from .metrics.manager import start_metric_manager
 from .embed_cache.manager import start_cache_manager
 from lightllm.utils.log_utils import init_logger
@@ -15,6 +15,7 @@
 from .router.manager import start_router_process
 from lightllm.utils.process_check import is_process_active
 from lightllm.utils.multinode_utils import send_and_receive_node_ip
+from lightllm.utils.redis_utils import start_redis_service
 from lightllm.utils.shm_size_check import check_recommended_shm_size
 from lightllm.utils.config_utils import has_audio_module, has_vision_module

@@ -57,11 +58,12 @@ def signal_handler(sig, frame):

     signal.signal(signal.SIGINT, signal_handler)
     logger.info(f"start process pid {os.getpid()}")
-    logger.info(f"http server pid {http_server_process.pid}")
+    if http_server_process:
+        logger.info(f"http server pid {http_server_process.pid}")
     return


-def normal_or_p_d_start(args):
+def check_and_set_args(args):
     from lightllm.server.core.objs.start_args_type import StartArgs

     args: StartArgs = args
@@ -73,7 +75,7 @@
         enable_mps()

-    if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]:
+    if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual", "visual_only"]:
         return

     # Use the model's config to decide whether this is a multimodal model, which
     # modalities it contains, and whether the corresponding modules should be started
@@ -161,6 +163,7 @@
         assert args.mtp_draft_model_dir is None
         assert args.mtp_step == 0

+    args.enable_multimodal = is_multimodal_mode(args)
     # Check that the number of GPUs is sufficient
     if args.visual_gpu_ids is None:
         args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp))
@@ -191,18 +194,20 @@
         if args.batch_max_tokens is None:
             args.batch_max_tokens = args.max_req_total_len
         else:
-            assert args.batch_max_tokens >= args.max_req_total_len, f"batch_max_tokens must >= max_req_total_len"
-            f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
+            assert args.batch_max_tokens >= args.max_req_total_len, (
+                f"batch_max_tokens must >= max_req_total_len, "
+                f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
+            )
     else:
         # chunked prefill mode
         if args.batch_max_tokens is None:
             args.batch_max_tokens = 16384 // args.dp
         if args.chunked_prefill_size is None:
             args.chunked_prefill_size = args.batch_max_tokens // 2
-        assert (
-            args.batch_max_tokens >= args.chunked_prefill_size
-        ), "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
-        f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}"
+        assert args.batch_max_tokens >= args.chunked_prefill_size, (
+            "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
+            f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}"
+        )
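+        # For example (hypothetical values, not part of this patch): with --dp 2 and
+        # neither flag set, batch_max_tokens defaults to 16384 // 2 = 8192 and
+        # chunked_prefill_size to 8192 // 2 = 4096, which satisfies the assert above.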
{args.chunked_prefill_size}" + ) # help to manage data stored on Ceph if "s3://" in args.model_dir: @@ -222,11 +227,17 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + +def normal_or_p_d_start(args): + check_and_set_args(args) + already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -234,8 +245,10 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports + num=10 + node_world_size + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_nccl_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -253,15 +266,17 @@ def normal_or_p_d_start(args): can_use_ports = can_use_ports[10:] visual_model_tp_ports = [] - visual_nccl_ports = [] for _ in range(args.visual_dp): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] visual_model_tp_ports.append(tp_ports_for_dp) can_use_ports = can_use_ports[args.visual_tp :] - visual_nccl_ports.append(can_use_ports[0]) - can_use_ports = can_use_ports[1:] - # 将申请好的端口放入args参数中 + if args.visual_nccl_ports is None: + visual_nccl_ports = can_use_ports[0 : args.visual_dp] + can_use_ports = can_use_ports[args.visual_dp :] + else: + visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + if args.nccl_port is None: args.nccl_port = nccl_port if args.pd_decode_rpyc_port is None: @@ -288,7 +303,6 @@ def normal_or_p_d_start(args): args.router_max_wait_tokens = 0 send_and_receive_node_ip(args) # 多机用于收发node ip - # dp 必须 > 1 if args.enable_dp_prompt_cache_fetch and args.dp <= 1: args.enable_dp_prompt_cache_fetch = False logger.warning( @@ -309,27 +323,27 @@ def normal_or_p_d_start(args): start_args=[(args,)], ) - if not args.disable_vision: - from .visualserver.manager import start_visual_process + if not args.disable_audio: + from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( start_funcs=[ - start_visual_process, + start_audio_process, ], start_args=[ - (args, visual_model_tp_ports), + (args,), ], ) - if not args.disable_audio: - from .audioserver.manager import start_audio_process + if not args.disable_vision and not args.enable_remote_vit: + from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( start_funcs=[ - start_audio_process, + start_visual_process, ], start_args=[ - (args,), + (args, visual_model_tp_ports), ], ) @@ -435,7 +449,6 @@ def pd_master_start(args): "-", "--error-logfile", "-", - "--preload", "lightllm.server.api_http:app", "--keep-alive", f"{get_lightllm_gunicorn_keep_alive()}", @@ -452,6 +465,81 @@ def pd_master_start(args): http_server_process.wait() +def visual_start(args): + check_and_set_args(args) + + already_uesd_ports = [args.remote_vit_port] + if args.nccl_port is not None: + already_uesd_ports.append(args.nccl_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) + + 
need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_nccl_ports=already_uesd_ports, + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_model_tp_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + can_use_ports = can_use_ports[args.visual_tp :] + visual_model_tp_ports.append(tp_ports_for_dp) + + if args.visual_nccl_ports is None: + args.visual_nccl_ports = can_use_ports[0 : args.visual_dp] + can_use_ports = can_use_ports[args.visual_dp :] + else: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + args.visual_model_rpc_ports = visual_model_tp_ports + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(args,)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": @@ -459,6 +547,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.start_redis: + start_redis_service(args) + set_env_start_args(args) command = [ @@ -470,10 +561,9 @@ def config_server_start(args): "--log-level", "info", "--access-logfile", - "-", + "/dev/stdout", "--error-logfile", - "-", - "--preload", + "/dev/stderr", "lightllm.server.config_server.api_http:app", "--keep-alive", f"{get_lightllm_gunicorn_keep_alive()}", diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..c55b743480 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.vit_connect import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name @@ -19,7 +20,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: 
{client_ip}, Port: {client_port}")
+    registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes())
+    logger.info(f"received registered_visual_server_obj {registered_visual_server_obj}")
+    with registered_visual_server_obj_lock:
+        registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj
+
+    try:
+        while True:
+            data = await websocket.receive_text()
+            assert data == "heartbeat"
+    except (WebSocketDisconnect, Exception, RuntimeError) as e:
+        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}")
+        logger.exception(str(e))
+    finally:
+        logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed")
+        with registered_visual_server_obj_lock:
+            registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None)
+    return
+
+
 @app.get("/registered_objects")
 async def get_registered_objects():
     with registered_pd_master_obj_lock:
@@ -80,6 +107,14 @@
     return {"data": base64_encoded}


+@app.get("/registered_visual_objects")
+async def get_vit_registered_objects():
+    with registered_visual_server_obj_lock:
+        serialized_data = pickle.dumps(registered_visual_server_objs)
+    base64_encoded = base64.b64encode(serialized_data).decode("utf-8")
+    return {"data": base64_encoded}
+
+
 @app.get("/allocate_global_unique_id_range")
 async def allocate_global_id_range():
     """
diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py
index dfcbdd2562..75f2c0e2f1 100644
--- a/lightllm/server/core/objs/io_objs/group_req.py
+++ b/lightllm/server/core/objs/io_objs/group_req.py
@@ -23,7 +23,9 @@ def to_group_req_index(self):
         return GroupReqIndexes(
             group_req_id=self.group_req_id,
             multimodal_params=self.multimodal_params,
-            shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs],
+            shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs]
+            if self.shm_req_objs is not None
+            else None,
             time_mark=self.time_mark,
         )

diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py
new file mode 100644
index 0000000000..05bd0bc23e
--- /dev/null
+++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py
@@ -0,0 +1,74 @@
+import uuid
+import threading
+import dataclasses
+import requests
+from typing import Union, Optional
+import torch
+import time
+from collections import deque
+import multiprocessing.shared_memory as shm
+from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis
+from .naive_memory_cache import Record, InMemoryCache
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class MemoryCacheWithRedis(InMemoryCache):
+    def __init__(self, args) -> None:
+        super().__init__(args)
+        redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0"
+        self.redis_cache = EmbedRefCountRedis(
+            redis_url=redis_url,
+            capacity=args.cache_capacity,
+            evict_fraction=args.redis_evict_fraction,
+            image_embed_dir=args.image_embed_dir,
+        )
+        # The cache capacity is doubled here because, in the disaggregated mode, this cache
+        # service only updates the redis state and keeps the token_id of the image cache for
+        # dynamic prompt cache usage. Doubling cache_capacity guarantees that the image cache
+        # it retains is larger than the number of image embeds the redis service maintains
+        # on disk.
+        self.cache_capacity = args.cache_capacity * 2
+
+    # The LLM side is responsible for release
+    def release(self, ids: list[int]) -> None:
+        with self.lock:
+            for id in ids:
+                self._records[id].ref -= 1
+                if self.redis_cache.query(str(id)):
+                    self.redis_cache.decr(str(id))
+            # print(self.redis_cache.stats(), flush=True)
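+
+    # A sketch of the refcount lifecycle implied by this class (inferred from the
+    # code in this patch):
+    #   VIT side:  alloc(...)            -> local record created, ref += 1
+    #              set_items_embed([id]) -> redis insert(str(id)) with rc=1, then decr()
+    #                                       to drop the VIT-side reference
+    #   LLM side:  get_items_embed([id]) -> redis query_and_incre(), rc += 1 on a hit
+    #              release([id])         -> redis decr(); entries whose rc reaches 0
+    #                                       become LRU-eviction candidates
+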
+    # The VIT side is responsible for set
+    def set_items_embed(self, ids: list[int]) -> None:
+        with self.lock:
+            for id in ids:
+                self.redis_cache.insert(str(id))
+                self._records[id].embed = True
+                self._records[id].ref -= 1
+                self.redis_cache.decr(str(id))  # the VIT side does ref+1 on alloc and ref-1 once it finishes
+
+    def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]:
+        ret = []
+        for id in ids:
+            if embeding_only:
+                exist = self.redis_cache.query(str(id))
+            else:
+                exist = self.redis_cache.query_and_incre(str(id))
+            ret.append(exist)
+            if exist:
+                self._records[id].embed = True
+        return ret
+
+    # def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]:
+    #     ret = []
+    #     for id in ids:
+    #         # if self.redis_cache.query(str(id)):
+    #         #     ret.append(True)
+    #         #     continue
+    #         # avoid increasing the reference count more than once
+    #         if self._records[id].embed:
+    #             ret.append(True)
+    #             continue
+    #         self._records[id].embed = self.redis_cache.query_and_incre(str(id))
+    #         ret.append(self._records[id].embed)
+    #     return ret
diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py
index 61ba46d7c6..8b7528cd0f 100644
--- a/lightllm/server/embed_cache/impl/naive_memory_cache.py
+++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py
@@ -36,6 +36,7 @@ class InMemoryCache:
     def __init__(self, args) -> None:
         self.args = args
         self._id_to_records = dict()
+        self._records = self._id_to_records
         self._md5_to_record = dict()
         self._sorted_records = SortedSet(key=lambda x: (x.ref, x.visittime, x.id))
         self.capacity = max(1, args.cache_capacity)
@@ -160,14 +161,13 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
             new_md5_dict[m] = token_need

         new_needed = len(new_md5_dict)
-
         alloc_md5_dict = self._free_to_alloc(
             free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict
         )
         if len(alloc_md5_dict) == len(new_md5_dict):
             for md5sum, mem_block in alloc_md5_dict.items():
                 token_num = new_md5_dict[md5sum]
-                uid_int = uuid.uuid1().int
+                uid_int = md5sum
                 self._check_and_set_new_id_range(token_num)
                 rec = Record(
                     id=uid_int,
@@ -207,6 +207,8 @@
             return results

         else:
+            for md5sum in add_ref_m_list:
+                self._del_ref(md5sum)
             return None

     def release(self, ids: list[int]) -> None:
@@ -228,5 +230,5 @@ def set_items_embed(self, ids: list[int]) -> None:
         for id_ in ids:
             self._id_to_records[id_].embed = True

-    def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
+    def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]:
         return [self._id_to_records.get(id_).embed if id_ in self._id_to_records else False for id_ in ids]
diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py
index 5de4df4ab3..0dc8830cdb 100644
--- a/lightllm/server/embed_cache/manager.py
+++ b/lightllm/server/embed_cache/manager.py
@@ -6,6 +6,7 @@
 from lightllm.server.core.objs import StartArgs
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
+from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis
 from rpyc.utils.classic import obtain
 from lightllm.utils.envs_utils import get_unique_server_name

@@ -25,6 +26,10 @@ def on_disconnect(self, conn):
         # (to finalize the service, if needed)
         pass

+    def 
exposed__check_and_set_new_id_range(self, token_num: int) -> int: + token_num = obtain(token_num) + return self._impl._check_and_set_new_id_range(token_num) + def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: md5sum_list = obtain(md5sum_list) token_num_list = obtain(token_num_list) @@ -47,9 +52,16 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids) + return self._impl.get_items_embed(ids, embeding_only) + + +def get_cache_manager(args): + if args.enable_remote_vit or args.run_mode == "visual": + return MemoryCacheWithRedis(args) + else: + return InMemoryCache(args) def start_cache_manager(args: StartArgs, pipe_writer): @@ -57,7 +69,7 @@ def start_cache_manager(args: StartArgs, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::cache_manager") - manager = InMemoryCache(args) + manager = get_cache_manager(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 367bcc91a9..aa0dc903dd 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,4 +1,29 @@ +import os +import time +import torch +import redis +import numpy as np +from typing import List, Tuple +from io import BytesIO +from pathlib import Path import multiprocessing.shared_memory as shm +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def tensor2bytes(t: torch.Tensor): + buf = BytesIO() + t = t.detach().cpu() + dest = torch.empty_like(t) + dest.copy_(t) + torch.save(dest, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) + buf.seek(0) + return buf.read() + + +def bytes2tensor(b): + return torch.load(BytesIO(b), weights_only=False) def create_shm(name, data): @@ -11,17 +36,381 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) +def create_afs(name, data, path): + try: + data_size = len(data) + path = os.path.join(path, name) + with open(path, "xb") as f: + mem_view = memoryview(data) + f.write(mem_view[:data_size]) + f.flush() + os.fsync(f.fileno()) + except FileExistsError: + print("Warning create afs {} failed because of FileExistsError!".format(name)) + + def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data +def read_afs(name: str, base_dir) -> bytes: + + path = Path(base_dir) / name + return path.read_bytes() + + def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() +def free_afs(name: str, base_dir) -> None: + path = Path(base_dir) / name + path.unlink() + + def get_shm_name_data(uid): return str(uid) + "-data" + + +def get_shm_name_embed(uid): + return str(uid) + "-embed" + + +""" +Importable Redis-backed MD5 refcount with LRU eviction. 
+
+Public API:
+    from md5_refcount import EmbedRefCountRedis
+
+    cache = EmbedRefCountRedis(
+        redis_url="redis://localhost:6379/0",
+        capacity=10000,
+        evict_fraction=0.2
+    )
+
+    # Insert a new md5 with default ref_count=1
+    success, evicted_list = cache.insert(md5)
+
+    # Query if exists and increment ref_count if found
+    exists = cache.query_and_incre(md5)
+
+    # Decrement ref_count
+    rc, deleted = cache.decr(md5)
+
+    s = cache.stats()
+"""
+
+
+class EmbedRefCountRedis:
+    def __init__(
+        self,
+        redis_url: str = "redis://localhost:6379/0",
+        capacity: int = 50000,
+        evict_fraction: float = 0.1,
+        key_prefix: str = "md5:",
+        image_embed_dir: str = None,
+        path_ext: str = "-embed",
+        **redis_kwargs,
+    ) -> None:
+        """
+        - capacity: max count of md5 entries allowed in Redis
+        - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity
+        - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds")
+        - path_ext: file extension for embed files (default: "-embed")
+        """
+        if not (0.0 <= evict_fraction <= 1.0):
+            raise ValueError("evict_fraction must be 0..1")
+        if capacity < 1:
+            raise ValueError("capacity must be >=1")
+
+        self.capacity = int(capacity)
+        self.evict_fraction = float(evict_fraction)
+        self.zset_key = f"{key_prefix}lru"
+        self.ref_prefix = f"{key_prefix}rc:"
+        self.lock_key = f"{key_prefix}evict:lock"
+        self.image_embed_dir = image_embed_dir
+        self.path_ext = path_ext
+
+        self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs)
+
+        # Register Lua scripts
+        self._insert_script = self.r.register_script(self._INSERT_LUA)
+        self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA)
+        self._decr_script = self.r.register_script(self._DECR_LUA)
+        self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA)
+
+    def insert(self, md5: str) -> Tuple[bool, List[str]]:
+        """Insert a new md5 with default ref_count=1. May trigger LRU eviction."""
+        # wait for any eviction that is currently in progress
+        self._wait_if_eviction()
+
+        res = self._insert_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5, self.capacity, self.evict_fraction],
+        )
+
+        if res[0] == 0:  # No eviction needed
+            return True, []
+
+        # Need eviction - use atomic eviction script
+        try:
+            if self._try_acquire_lock():
+                try:
+                    # perform the eviction and the insertion atomically
+                    evict_res = self._evict_and_insert_script(
+                        keys=[self.zset_key, self.ref_prefix],
+                        args=[md5, self.capacity, self.evict_fraction],
+                    )
+                    success = bool(evict_res[0])
+                    victims = evict_res[1:] if len(evict_res) > 1 else []
+
+                    if success:
+                        # delete the AFS files that belong to the evicted md5s
+                        if victims and self.image_embed_dir:
+                            self._delete_afs_files(victims)
+                        return True, victims
+                    else:
+                        # eviction failed; back off briefly and retry
+                        time.sleep(0.01)
+                        return self.insert(md5)
+                finally:
+                    self._release_lock()
+            else:
+                # wait for the lock to be released, then retry
+                time.sleep(0.01)
+                return self.insert(md5)
+        except Exception as e:
+            self._release_lock()
+            raise e
+
+    def query(self, md5: str) -> bool:
+        """Query if md5 exists."""
+        self._wait_if_eviction()
+        return bool(self.r.exists(self.ref_prefix + md5))
+
+    def query_and_incre(self, md5: str) -> bool:
+        """Query if md5 exists and increment ref_count if found."""
+        self._wait_if_eviction()
+        res = self._query_incre_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5],
+        )
+        return bool(res[0])
+
+    def decr(self, md5: str) -> Tuple[int, bool]:
+        """Decrement ref_count for md5. Returns (ref_count, deleted)."""
+        self._wait_if_eviction()
+
+        res = self._decr_script(
+            keys=[self.zset_key, self.ref_prefix],
+            args=[md5],
+        )
+        if res[0] == -1:
+            raise KeyError("md5 not found")
+        return int(res[0]), bool(res[1])
+
+    def stats(self) -> dict:
+        self._wait_if_eviction()
+
+        size = self.r.zcard(self.zset_key)
+        return {
+            "items": size,
+            "capacity": self.capacity,
+            "evict_fraction": self.evict_fraction,
+        }
+
+    def get_ref(self, md5: str) -> int | None:
+        self._wait_if_eviction()
+        val = self.r.get(self.ref_prefix + md5)
+        return int(val) if val is not None else None
+
+    def _wait_if_eviction(self) -> None:
+        max_wait = 30
+        start_time = time.time()
+
+        while self.r.exists(self.lock_key):
+            if time.time() - start_time > max_wait:
+                raise TimeoutError("Eviction operation timeout, waited too long")
+            time.sleep(0.01)  # brief wait
+
+    def _try_acquire_lock(self) -> bool:
+        return bool(self.r.set(self.lock_key, "1", nx=True, ex=30))
+
+    def _release_lock(self) -> None:
+        try:
+            self.r.delete(self.lock_key)
+        except Exception:
+            pass
+
+    def _md5_to_afs_path(self, md5: str) -> str:
+        """Convert md5 to AFS file path."""
+        if not self.image_embed_dir:
+            return None
+        filename = os.path.join(self.image_embed_dir, md5 + self.path_ext)
+        return filename
+
+    def _delete_afs_files(self, victims: List[str]) -> None:
+        """Delete AFS files for evicted md5s."""
+        if not self.image_embed_dir:
+            return
+
+        for md5 in victims:
+            try:
+                file_path = self._md5_to_afs_path(md5)
+                if file_path and os.path.exists(file_path):
+                    os.remove(file_path)
+                    logger.debug(f"Deleted AFS file: {file_path}")
+            except Exception as e:
+                logger.debug(f"Warning: Failed to delete AFS file for {md5}: {e}")
+
+    # ---------------- Lua scripts ----------------
+    _INSERT_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+local capacity = tonumber(ARGV[2])
+
+local unpack = unpack or table.unpack
+local ref_key = ref_prefix .. md5
+if redis.call('GET', ref_key) then
+    return {0}  -- Already exists
+end
+
+local size = redis.call('ZCARD', zset)
+if size < capacity then
+    -- Insert with ref_count=1
+    redis.call('SET', ref_key, 1)
+    local now = redis.call('TIME')[1] * 1000
+    redis.call('ZADD', zset, now, md5)
+    return {0}  -- Success, no eviction
+end
+
+return {1}  -- Need eviction
+"""
+
+    _QUERY_INCRE_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+
+local ref_key = ref_prefix .. md5
+local val = redis.call('GET', ref_key)
+
+if not val then
+    return {0}  -- Not found
+end
+
+-- Found, increment ref_count and update LRU
+local rc = tonumber(val) + 1
+redis.call('SET', ref_key, rc)
+local now = redis.call('TIME')[1] * 1000
+redis.call('ZADD', zset, now, md5)
+return {1}  -- Found and incremented
+"""
+
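+    # Note: redis-py's register_script() runs each of these scripts as a single
+    # server-side EVAL/EVALSHA, and Redis executes Lua scripts atomically, so the
+    # GET/SET/ZADD sequences here cannot interleave with a concurrent decr or
+    # eviction on the same key.
+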
+    _DECR_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = md5
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local md5 = ARGV[1]
+
+local ref_key = ref_prefix .. md5
+local val = redis.call('GET', ref_key)
+
+if not val then
+    return {-1, 0}  -- Not found
+end
+
+-- keep the key when ref drops to 0; only update the count and the LRU
+local rc = tonumber(val) - 1
+if rc < 0 then rc = 0 end
+redis.call('SET', ref_key, rc)
+
+if rc > 0 then
+    -- only refresh the LRU while the key is still referenced
+    local now = redis.call('TIME')[1] * 1000
+    redis.call('ZADD', zset, now, md5)
+end
+
+return {rc, 0}
+"""
+
+    _EVICT_AND_INSERT_LUA = r"""
+-- KEYS[1] = zset key, KEYS[2] = ref_prefix
+-- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction
+local zset = KEYS[1]
+local ref_prefix = KEYS[2]
+local new_md5 = ARGV[1]
+local capacity = tonumber(ARGV[2])
+local evict_fraction = tonumber(ARGV[3])
+
+local unpack = unpack or table.unpack
+
+-- helper: now millis
+local function now_ms()
+    local t = redis.call('TIME')
+    return t[1] * 1000 + math.floor(t[2] / 1000)
+end
+
+local new_ref_key = ref_prefix .. new_md5
+
+-- If already exists, treat as a hit: bump ref_count and refresh LRU
+local cur = redis.call('GET', new_ref_key)
+if cur then
+    local rc = tonumber(cur) + 1
+    redis.call('SET', new_ref_key, rc)
+    redis.call('ZADD', zset, now_ms(), new_md5)
+    return {1}  -- success, no victims
+end
+
+-- If not at capacity, just insert
+local size = redis.call('ZCARD', zset)
+if size < capacity then
+    redis.call('SET', new_ref_key, 1)
+    redis.call('ZADD', zset, now_ms(), new_md5)
+    return {1}  -- success, no victims
+end
+
+-- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed
+local max_try = math.max(1, math.floor(size * evict_fraction + 0.5))
+local victims = {}
+local freed = 0
+
+-- Scan from LRU (smallest score) to MRU
+local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES')
+local i = 1
+while freed < 1 and i <= #all_keys and #victims < max_try do
+    local md5 = all_keys[i]
+    local ref_key = ref_prefix .. md5
+    local v = redis.call('GET', ref_key)
+    if v and tonumber(v) <= 0 then
+        table.insert(victims, md5)
+        freed = freed + 1
+    end
+    i = i + 2  -- skip score
+end
+
+if freed >= 1 then
+    -- delete victims
+    for _, v in ipairs(victims) do
+        redis.call('DEL', ref_prefix .. v)
+        redis.call('ZREM', zset, v)
+    end
+    -- insert new
+    redis.call('SET', new_ref_key, 1)
+    redis.call('ZADD', zset, now_ms(), new_md5)
+    return {1, unpack(victims)}
+else
+    -- no zero-ref items found
+    return {0}
+end
+"""
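
Taken together, the class above behaves as in this minimal usage sketch (illustrative only: it assumes a Redis instance reachable at the default localhost port and uses a made-up 32-character key):

    cache = EmbedRefCountRedis(redis_url="redis://localhost:6379/0", capacity=4, evict_fraction=0.5)
    ok, evicted = cache.insert("a" * 32)    # new key, rc=1; may also return evicted rc==0 victims
    hit = cache.query_and_incre("a" * 32)   # True; rc -> 2 and the LRU timestamp is refreshed
    rc, _ = cache.decr("a" * 32)            # rc -> 1
    rc, _ = cache.decr("a" * 32)            # rc -> 0; the key is kept but becomes evictable
    print(cache.stats())                    # {'items': 1, 'capacity': 4, 'evict_fraction': 0.5}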
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 6481098eb9..08966545e2 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -82,15 +82,10 @@ def __init__(
         if self.enable_multimodal:
             self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
             self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+            # initialize the VIT connection manager
+            from lightllm.server.visualserver.vit_connect import VITConnectionManager

-            if not self.args.disable_vision:
-                self.send_to_visual = context.socket(zmq.PUSH)
-                self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")
-
-            if not self.args.disable_audio:
-                self.send_to_audio = context.socket(zmq.PUSH)
-                self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}")
-
+            self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client)
         if args.enable_cpu_cache and not self.args.enable_multimodal:
             self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH)
             self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}")
@@ -124,10 +119,10 @@ def __init__(
         self.latest_success_infer_time_mark.set_value(int(time.time()))
         return

-    async def _alloc_resource(self, items, md5sums, token_nums, datas):
+    async def _alloc_resource(self, items, uuids, token_nums, datas):
         while True:
-            records = obtain(self.cache_client.root.alloc(md5sums, token_nums))
+            records = obtain(self.cache_client.root.alloc(uuids, token_nums))

             if records is None:
                 await asyncio.sleep(0.1)
@@ -147,6 +142,10 @@
                 uid_list.append(rec["id"])

+            # If the vit/audio-llm disaggregation is enabled, there is no need to cache
+            # the data in this server's memory
+            if self.args.enable_remote_vit:
+                return
+
             ready_flags = obtain(self.cache_client.root.get_items_data(uid_list))
             update_data_ids = []

@@ -166,14 +165,15 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
         # Without any lock, if request 1 and request 2 each carried 6 images while
         # cache_capacity were 10, then at some moment shm could hold 5 images of request 1
         # and 5 of request 2, and the resource contention would deadlock.
         async with self._resource_lock:
-            items, md5sums, tokens_nums, datas = [], [], [], []
+            items, uuids, tokens_nums, datas = [], [], [], []
             for img in multimodal_params.images:
                 self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
                 data = img.read()  # must after init_imageitem_extral_params
                 token_num = self.tokenizer.get_image_token_length(img)
-                md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params)))
-                md5sums.append(md5sum)
+                md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num)
+                uuid = int(md5sum, 16)
+                uuids.append(uuid)
                 tokens_nums.append(token_num)
                 datas.append(data)
                 items.append(img)
             for audio in multimodal_params.audios:
                 self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params)
                 data = audio.read()
                 token_num = self.tokenizer.get_audio_token_length(audio)
-                md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params)))
-                md5sums.append(md5sum)
+                md5sum = "{}_{}".format(
+                    hashlib.md5(data).hexdigest(),
+                    hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(),
+                )
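+                # Note: int(md5sum, 16) below works because Python's int() accepts
+                # single underscores between digits, e.g. int("deadbeef_6", 16) ==
+                # 0xdeadbeef6; the resulting integer is the content-derived uuid
+                # shared with the cache service and Redis.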
+                uuid = int(md5sum, 16)
+                uuids.append(uuid)
                 tokens_nums.append(token_num)
                 datas.append(data)
                 items.append(audio)

-        await self._alloc_resource(items, md5sums, tokens_nums, datas)
+        await self._alloc_resource(items, uuids, tokens_nums, datas)
         return

     async def _release_multimodal_resources(self, multimodal_params: MultimodalParams):
@@ -211,7 +215,7 @@
                     audio.token_id = None
                     audio.token_num = None
                     audio.start_index_in_embed_cache = None
-            if ids_to_release:
+            if ids_to_release and not self.args.enable_remote_vit:
                 self.cache_client.root.release(ids_to_release)
         return

@@ -408,6 +412,49 @@ async def generate(
             raise e
         return

+    async def get_image_embeding(
+        self,
+        sampling_params: SamplingParams,
+        multimodal_params: MultimodalParams,
+        request: Request,
+        is_health_req: bool = False,
+    ) -> Tuple[int, str, dict, FinishStatus]:
+        start_time = time.time()
+        request_headers = request.headers if request is not None else {}
+        group_request_id = self.alloc_req_id(sampling_params, is_health_req)
+
+        try:
+            original_multimodal_params = None
+            if self.is_multinode_tp_master:
+                original_multimodal_params = copy.deepcopy(multimodal_params)
+
+            if self.pd_mode.is_P_or_NORMAL():
+                await multimodal_params.verify_and_preload(request)
+
+            image_count = len(multimodal_params.images)
+
+            # log information about the arriving request
+            await self._log_req_header(request_headers, group_request_id)
+            logger.info(f"image_count:{image_count}")
+            assert (
+                len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity
+            ), "too many multimodal items!"
+
+            await self._alloc_multimodal_resources(multimodal_params, sampling_params)
+
+            visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time)
+
+            await self.transfer_to_next_module_or_node(
+                None, sampling_params, original_multimodal_params, visual_req_status, embeding_only=True
+            )
+
+        except Exception as e:
+            logger.error(f"group_request_id: {group_request_id} has exception {str(e)}")
+            await self.abort(group_request_id)
+            raise e
+        return
+
     def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]:
         image_tokens = 0
         audio_tokens = 0
@@ -509,6 +556,7 @@ async def transfer_to_next_module_or_node(
         sampling_params: SamplingParams,
         original_multimodal_params: MultimodalParams,
         group_req_objs: Optional[GroupReqObjs] = None,
+        embeding_only: Optional[bool] = False,
    ):
         # In multinode pure-tp mode, the master node needs to forward the request to the slave nodes.
        if self.is_multinode_tp_master:
                 protocol=pickle.HIGHEST_PROTOCOL,
             )

-        await self.transfer_to_next_module(group_req_objs)
+        await self.transfer_to_next_module(group_req_objs, embeding_only)
         return

     async def transfer_to_next_module(
         self,
         group_req_objs: Optional[GroupReqObjs] = None,
+        embeding_only: Optional[bool] = False,
     ):
         if self.pd_mode.is_P_or_NORMAL():
-            if not self.args.disable_vision:
-                self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL)
-                return
-
-            if not self.args.disable_audio:
-                self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL)
-                return
+            if self.enable_multimodal:
+                await self.vit_manager.send_to_vit(
+                    group_req_objs.to_group_req_index(),
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                    embeding_only=embeding_only,
+                )

             if self.args.enable_cpu_cache:
                 self.send_to_multi_level_kv_cache.send_pyobj(
                     group_req_objs.to_group_req_index(),
                     protocol=pickle.HIGHEST_PROTOCOL,
                 )
                 return

-        self.send_to_router.send_pyobj(
-            group_req_objs.to_group_req_index(),
-            protocol=pickle.HIGHEST_PROTOCOL,
-        )
+            if not self.enable_multimodal or self.args.enable_remote_vit:
+                self.send_to_router.send_pyobj(
+                    group_req_objs.to_group_req_index(),
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                )
             return

         if self.pd_mode.is_D():
-            # In D mode, there is no need to transfer the real multimodal params, since P has already processed them
+            # In D mode, there is no need to transfer the real multimodal params, since P has
+            # already processed them; sending an empty one is enough
             self.send_to_router.send_pyobj(
                 group_req_objs.to_group_req_index(),
                 protocol=pickle.HIGHEST_PROTOCOL,
@@ -742,6 +791,9 @@ async def handle_loop(self):

         asyncio.create_task(pd_handle_loop(self))

+        if self.enable_multimodal:
+            asyncio.create_task(self.vit_manager.vit_handle_loop())
+
         while True:
             try:
                 await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05)
diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py
index 09a07455b3..9616d5f84c 100644
--- a/lightllm/server/multimodal_params.py
+++ b/lightllm/server/multimodal_params.py
@@ -26,6 +26,7 @@ def __init__(self, **kwargs):
         self.token_num = None
         # the audio length
         self.audio_length = None
+        self.afs_embed = False

         self._preload_data = None
         self.extra_params = {}
@@ -54,10 +55,7 @@ async def preload(self, request: Request):

     def read(self):
         assert self._preload_data is not None
-        ans = self._preload_data
-        self._preload_data = None
-        self._data = None
-        return ans
+        return self._preload_data

     def to_dict(self):
         ret = {}
@@ -95,6 +93,7 @@ def __init__(self, **kwargs):
         self.grid_thwd = None
         self.image_w = 0
         self.image_h = 0
+        self.patch_num = 0

         self._preload_data = None
         self.extra_params = {}
@@ -128,10 +127,11 @@ async def preload(self, request: Request):

     def read(self):
         assert self._preload_data is not None
-        ans = self._preload_data
+        return self._preload_data
+
+    def free(self):
         self._preload_data = None
         self._data = None
-        return ans

     def to_dict(self):
         ret = {}
@@ -163,6 +163,15 @@ def __init__(
         self.audios = [AudioItem(**a) for a in audios]
         return

+    def free(self):
+        for image in self.images:
+            image.free()
+        for audio in self.audios:
+            audio.free()
+
+    def get_all_uuids(self):
+        return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios]
+
     async def verify_and_preload(self, request: Request):
         for image in self.images:
             await image.preload(request)
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 1b4a1ca5cb..3aee370eb1 100644
--- 
a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -59,8 +59,8 @@ def register( self.vocab_size = vocab_size return - def init_cpu_embed_cache_client(self): - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + def init_cpu_embed_cache_client(self, init_shm_data: bool = False): + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=init_shm_data) return def get_overlap_stream(self) -> torch.cuda.Stream: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..8b0522d58e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -145,7 +145,7 @@ def init_model(self, kvargs): wait_events.append(self.multi_level_cache_module) if self.args.enable_multimodal: - g_infer_context.init_cpu_embed_cache_client() + g_infer_context.init_cpu_embed_cache_client(init_shm_data=self.args.enable_remote_vit) model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..508860e899 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -20,7 +20,7 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from rpyc.utils.classic import obtain - +from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data logger = init_logger(__name__) @@ -31,13 +31,16 @@ def __init__( args: StartArgs, visual_model_rpc_ports, ): + self.args = args + self.visual_only = args.run_mode in ["visual", "visual_only"] + self.remote_vit = args.enable_remote_vit or self.visual_only + context = zmq.Context(2) - enable_audio = not args.disable_audio - if enable_audio: - self.send_to_next_module = context.socket(zmq.PUSH) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") - else: - if args.enable_cpu_cache: + if not self.visual_only: + if not args.disable_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + elif args.enable_cpu_cache: self.send_to_next_module = context.socket(zmq.PUSH) self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") else: @@ -45,7 +48,11 @@ def __init__( self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.zmq_recv_socket = context.socket(zmq.PULL) - self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + if self.remote_vit: + self.zmq_recv_socket.bind(f"tcp://*:{args.remote_vit_port}") + else: + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port @@ -56,13 +63,11 @@ def __init__( self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() async def 
wait_to_model_ready(self): - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): @@ -146,13 +151,12 @@ def flush_ready(force: bool = False): continue multimodal_params = group_req_indexes.multimodal_params - img_uuids = [img.uuid for img in multimodal_params.images] # disable prompt cache通常用来测试,需要也去掉image cache的影响 if disable_prompt_cache: ready_image = [False] * len(img_uuids) else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids, True)) for img, ready in zip(multimodal_params.images, ready_image): if not ready: @@ -180,6 +184,43 @@ def flush_ready(force: bool = False): processing_group_reqs = [] flush_ready(force=True) + async def _recv_reqs(self): + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if not self.remote_vit: + return recv_req + + uuids = [img.uuid for img in recv_req.multimodal_params.images] + already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids, True) + if all(already_embed): + return None + + missing_uuids = [] + token_nums = [] + datas = [] + for img, embed_ready in zip(recv_req.multimodal_params.images, already_embed): + if embed_ready: + continue + missing_uuids.append(img.uuid) + token_nums.append(img.token_num) + datas.append(img.read()) + img.free() + + while True: + if await asyncio.to_thread(self.cache_client.root.alloc, missing_uuids, token_nums) is not None: + break + await asyncio.sleep(0.01) + + ready_flags = obtain(self.cache_client.root.get_items_data(missing_uuids)) + update_data_ids = [] + for uid, ready, data in zip(missing_uuids, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) + return recv_req + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -187,7 +228,9 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req = await self._recv_reqs() + if recv_req is None: + continue if isinstance(recv_req, GroupReqIndexes): logger.info( f"visual recv req id {recv_req.group_req_id} " @@ -196,12 +239,31 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) + self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 await asyncio.sleep(0.01) + async def loop_for_fwd_visual_only(self): + while True: + if len(self.waiting_reqs) == 0: + await asyncio.sleep(0.01) + continue + + images_need_infer = [] + while len(self.waiting_reqs) > 0: + visual_req = self.waiting_reqs.pop(0) + for img in visual_req.multimodal_params.images: + images_need_infer.append(img) + if len(images_need_infer) == self.infer_batch_size: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + + if len(images_need_infer) > 0: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -210,17 +272,29 @@ def 
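loop_for_netio_req above tunes how many requests it drains per iteration: the budget grows multiplicatively while messages keep arriving and snaps back to the floor once recv raises ZMQError on an empty queue. The heuristic in isolation, with the same constants as the code above:

```python
# Adaptive receive budget: multiplicative increase while busy, reset on drain.
def next_recv_budget(current: int, drained: bool, floor: int = 64, ceiling: int = 256) -> int:
    if drained:
        return floor
    return min(int(current * 1.3), ceiling)

budget = 64
for drained in [False, False, False, True]:
    budget = next_recv_budget(budget, drained)
# budget history: 83 -> 107 -> 139 -> 64
```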
clean_up(self): return +def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): + if args.run_mode in ["visual", "visual_only"]: + from .register_loop import register_loop + + loop.create_task(visualserver.loop_for_fwd_visual_only()) + loop.create_task(register_loop(args)) + else: + loop.create_task(visualserver.loop_for_fwd()) + + def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() + visualserver = None try: visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) - visualserver.clean_up() + if visualserver is not None: + visualserver.clean_up() raise e pipe_writer.send("init ok") @@ -231,6 +305,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(args, visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..a9f8e1223a 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -6,25 +6,27 @@ import inspect from datetime import timedelta from typing import Dict, List, Tuple -from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer +from transformers.configuration_utils import PretrainedConfig + from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer -from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel +from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, tensor2bytes from lightllm.server.visualserver import set_vit_att_backend @@ -34,6 +36,7 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist + self.args = get_env_start_args() self.vit_dp = 
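create_forward_loop above selects the background tasks by run mode and installs them on a dedicated event loop with a custom exception handler. A stripped-down sketch of that startup shape, with trivial placeholder coroutines:

```python
import asyncio

async def fwd_loop():
    # placeholder for loop_for_fwd / loop_for_fwd_visual_only
    while True:
        await asyncio.sleep(1)

async def netio_loop():
    # placeholder for loop_for_netio_req; exits so the example terminates
    for _ in range(3):
        await asyncio.sleep(0.1)

def handle_exception(loop, context):
    print("loop error:", context.get("message"))

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)
loop.create_task(fwd_loop())           # visual_only mode would also add register_loop here
loop.run_until_complete(netio_loop())  # net IO loop is the foreground awaitable
```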
kvargs["vit_dp"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -41,6 +44,9 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] + self.image_embed_dir = self.args.image_embed_dir + self.remote_vit = self.args.enable_remote_vit or self.args.run_mode in ["visual", "visual_only"] + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] @@ -56,6 +62,7 @@ def exposed_init_model(self, kvargs): "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], "max_batch_size": kvargs["max_batch_size"], + "remote_vit": self.remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -92,10 +99,10 @@ def exposed_init_model(self, kvargs): ) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) + if not self.remote_vit: + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -116,33 +123,48 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) - - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] + + if self.tp_rank_id != 0: + return + + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + ids_to_set = [] + cpu_embeds = None + if self.remote_vit: + cpu_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) + + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + image = images[i] + if self.remote_vit: + cur_embed_bytes = tensor2bytes(cpu_embeds[start:end]) + create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.image_embed_dir) + else: self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + embed_tensor=all_img_embeds[start:end], + start_index_in_cache=image.start_index_in_embed_cache, ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) + ids_to_set.append(uid) + + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) + if not self.remote_vit: torch.cuda.current_stream().synchronize() return class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc + def __init__(self, conn, vit_tp, rpc_server_process=None): + self.conn = conn + self.model: VisualModelRpcServer = conn.root self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process self.use_rpc = True + self._bg = rpyc.BgServingThread(self.conn) + if self.use_rpc: def async_wrap(f): @@ -161,15 +183,12 @@ async def _func(*args, **kwargs): else: self._init_model = self.model.exposed_init_model self._encode = self.model.exposed_encode - 
return

    async def init_model(self, kvargs):
        ans: rpyc.AsyncResult = self._init_model(kvargs)
        if self.use_rpc:
            await ans
            return
-        else:
-            return

    async def encode(self, images: List[ImageItem]):
        ans = self._encode(images)
@@ -215,4 +234,4 @@ async def start_model_process(port, vit_tp, device_id):
        raise Exception("init rpc env error!")

    assert proc.is_alive()
-    return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc)
+    return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc)
diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py
new file mode 100644
index 0000000000..31d0f7b8ac
--- /dev/null
+++ b/lightllm/server/visualserver/register_loop.py
@@ -0,0 +1,42 @@
+import asyncio
+import pickle
+import websockets
+import socket
+from lightllm.utils.net_utils import get_hostname_ip
+from lightllm.utils.log_utils import init_logger
+from .vit_connect import VIT_Obj
+
+logger = init_logger(__name__)
+
+
+async def register_loop(args):
+    assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip"
+
+    if args.host in ["0.0.0.0"]:
+        host_ip = get_hostname_ip()
+    else:
+        host_ip = args.host
+
+    while True:
+
+        try:
+            uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register"
+            async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket:
+
+                sock = websocket.transport.get_extra_info("socket")
+                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+                vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}")
+
+                await websocket.send(pickle.dumps(vit_obj))
+                logger.info(f"Sent registration vit_obj: {vit_obj}")
+
+                while True:
+                    await websocket.send("heartbeat")
+                    await asyncio.sleep(40)
+
+        except Exception as e:
+            logger.error("connection to config_server failed")
+            logger.exception(str(e))
+            await asyncio.sleep(10)
+            logger.info("reconnecting to config_server")
diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py
new file mode 100644
index 0000000000..7a1443f025
--- /dev/null
+++ b/lightllm/server/visualserver/vit_connect.py
@@ -0,0 +1,236 @@
+import asyncio
+import zmq
+import zmq.asyncio
+import time
+import pickle
+from typing import Dict, List, Optional, Any
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.core.objs.io_objs import GroupReqObjs, GroupReqIndexes
+from lightllm.server.multimodal_params import MultimodalParams
+import httpx
+import base64
+from dataclasses import dataclass
+import rpyc
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class VIT_Obj:
+    node_id: int
+    host_ip_port: str
+
+    def to_log_str(self):
+        return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}"
+
+
+class VITConnectionManager:
+    """VIT连接管理器"""
+
+    def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection):
+        self.args = args
+        self.context = context
+        self.local_visual_port = local_visual_port
+
+        self.send_to_visual = None
+        self.remote_vit_instances = {}
+        self.current_vit_index = 0
+        self.remote_vit = args.enable_remote_vit
+        self.remote_vit_port = args.remote_vit_port
+        self.cache_client = cache_client
+
+        self._setup_vit_connections()
+
+    def _setup_vit_connections(self):
+        """
+        设置VIT连接,支持本地和远程VIT实例
+        支持多种连接模式:
+        1. 本地VIT实例 (默认)
+        2. 
远程多个VIT实例 (负载均衡) + """ + if self.remote_vit: + # 远程VIT实例模式 + self._setup_remote_vit_connections() + else: + print("not remote") + self._setup_local_vit_connection() + + def _setup_local_vit_connection(self): + self.send_to_visual = self.context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + + def _setup_remote_vit_connections(self): + """ + 初始化远程VIT连接,同步获取初始实例 + """ + logger.info("Setting up remote VIT connections...") + + self._sync_init_vit_instances() + + retry_count = 0 + max_retries = 30 # 最多等待30秒 + while len(self.remote_vit_instances) == 0 and retry_count < max_retries: + logger.info(f"Waiting for VIT instances... (attempt {retry_count + 1}/{max_retries})") + time.sleep(1) + retry_count += 1 + self._sync_init_vit_instances() + + if len(self.remote_vit_instances) == 0: + logger.warning("No VIT instances available after initialization") + else: + logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances") + + def _sync_init_vit_instances(self): + """ + 同步初始化VIT实例连接 + """ + try: + # 使用同步方式获取VIT实例 + vit_objs = self._sync_get_vit_objs() + if vit_objs: + self._update_vit_connections(vit_objs) + except Exception as e: + logger.error(f"Failed to initialize VIT instances: {e}") + + def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 同步获取VIT实例信息 + """ + import requests + + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + response = requests.get(uri, timeout=10) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.error(f"Error getting VIT instances: {e}") + return None + + def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): + """ + 更新VIT连接,添加新的连接,关闭失效的连接 + """ + # 关闭不再存在的连接 + closed_ids = [] + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance.close() + except: + pass + closed_ids.append(id) + logger.info(f"Closed VIT connection {id}") + + for id in closed_ids: + self.remote_vit_instances.pop(id) + + # 建立新的连接 + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + try: + socket = self.context.socket(zmq.PUSH) + # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) + ip, port = vit_obj.host_ip_port.split(":") + socket.connect(f"tcp://{ip}:{port}") + self.remote_vit_instances[id] = socket + logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") + except Exception as e: + logger.error(f"Failed to connect to VIT instance {id}: {e}") + + def _get_vit_instance(self): + """ + 获取下一个可用的VIT实例 (轮询负载均衡) + """ + if not self.remote_vit: + return self.send_to_visual + + if len(self.remote_vit_instances) == 0: + raise Exception("No available VIT instances") + + # 简单的轮询负载均衡 + index = (self.current_vit_index + 1) % len(self.remote_vit_instances) + self.current_vit_index = index + return list(self.remote_vit_instances.values())[index] + + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, embeding_only=False): + """ + 发送数据到VIT实例,支持本地和远程模式 + """ + instance = self._get_vit_instance() + # 本地模式下,提前释放图片资源,降低传输开销 + if not self.remote_vit: 
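_get_vit_instance above does plain round-robin over whatever connections are currently registered; because the modulo is taken against the live dict size, instances that register or drop out between refreshes are folded in automatically. The selection logic in isolation, with string placeholders for the ZMQ sockets:

```python
class RoundRobin:
    def __init__(self):
        self.instances = {}  # node_id -> zmq socket in the real manager
        self.cursor = 0

    def pick(self):
        if not self.instances:
            raise RuntimeError("no available VIT instances")
        self.cursor = (self.cursor + 1) % len(self.instances)
        return list(self.instances.values())[self.cursor]


rr = RoundRobin()
rr.instances = {1: "vit-a", 2: "vit-b", 3: "vit-c"}
print([rr.pick() for _ in range(4)])  # ['vit-b', 'vit-c', 'vit-a', 'vit-b']
```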
+ req.multimodal_params.free() + + try: + print(instance, flush=True) + instance.send_pyobj(req, protocol=protocol) + except Exception as e: + logger.error(f"Failed to send to VIT instance: {e}") + raise Exception(f"Failed to send to VIT instance: {e}") + + # 远程模式下,发送完以后,在释放图片资源 + await self._wait_visual_embed_ready(req, embeding_only) + if self.remote_vit: + req.multimodal_params.free() + + async def vit_handle_loop(self): + """ + 异步VIT连接管理循环,由外部启动 + """ + if not self.remote_vit: + return + logger.info("Starting VIT connection management loop") + while True: + try: + id_to_vit_obj = await self._async_get_vit_objs() + if id_to_vit_obj: + self._update_vit_connections(id_to_vit_obj) + await asyncio.sleep(30) + except Exception as e: + logger.exception(f"Error in VIT handle loop: {e}") + await asyncio.sleep(10) + + async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 异步获取VIT实例信息 + """ + uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient() as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.exception(f"Error getting VIT instances: {e}") + return None + + async def _wait_visual_embed_ready( + self, req: GroupReqIndexes, embeding_only: bool = False, timeout_seconds: int = 1000 + ): + # 本地模式不需要等待 + if not self.remote_vit: + return + uuids = req.multimodal_params.get_all_uuids() + + async def wait_for_embeds(): + while not all(self.cache_client.root.get_items_embed(uuids, embeding_only)): + await asyncio.sleep(0.01) + + try: + await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) + except asyncio.TimeoutError: + logger.error( + f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" + ) diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 65ac401d4c..f0b06ead1c 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -55,19 +55,23 @@ def get_environ(environ_name): def init_vision_distributed_env(kvargs): - tp_world_size = kvargs["vit_tp"] + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + tp_world_size = args.visual_tp dp_size = 1 tp_rank_id = kvargs["tp_rank_id"] set_dp_size(dp_size) set_dp_world_size(tp_world_size) set_current_rank_in_dp(tp_rank_id) - visual_gpu_ids = kvargs["visual_gpu_ids"] + visual_gpu_ids = args.visual_gpu_ids device_id = visual_gpu_ids[kvargs["vit_rank_id"]] set_current_device_id(device_id) torch.cuda.set_device(device_id) + visual_nccl_port = args.visual_nccl_ports[kvargs["dp_rank_id"]] dist.init_process_group( "nccl", - init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', + init_method=f"tcp://127.0.0.1:{visual_nccl_port}", rank=kvargs["tp_rank_id"], world_size=tp_world_size, device_id=torch.device(f"cuda:{device_id}"), diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py new file mode 100644 index 0000000000..acc4deb589 --- /dev/null +++ b/lightllm/utils/redis_utils.py @@ -0,0 +1,74 @@ +import subprocess +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def start_redis_service(args): + """launch redis service""" + if not hasattr(args, "start_redis") or not args.start_redis: + 
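_wait_visual_embed_ready above is an unbounded poll wrapped in asyncio.wait_for, so a dead remote VIT degrades into a logged timeout rather than a request that hangs forever. A generic sketch of the pattern; is_ready stands in for the cache_client readiness query:

```python
import asyncio

async def wait_until(is_ready, interval: float = 0.01, timeout: float = 1000.0) -> bool:
    async def poll():
        # unbounded on its own; the bound comes from wait_for below
        while not is_ready():
            await asyncio.sleep(interval)

    try:
        await asyncio.wait_for(poll(), timeout=timeout)
        return True
    except asyncio.TimeoutError:
        return False

print(asyncio.run(wait_until(lambda: True)))  # True
```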
return None + + config_server_host = args.config_server_host + redis_port = args.redis_port + try: + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2 + ) + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2 + ) + except Exception: + pass + + try: + redis_command = [ + "redis-server", + "--port", + str(redis_port), + "--bind", + f"{config_server_host}", + "--daemonize", + "no", + "--logfile", + "/dev/stdout", + "--loglevel", + "notice", + "--save", + '""', # 不触发 RDB 快照 + "--appendonly", + "no", # 关闭 AOF + ] + + logger.info(f"Starting Redis service on port {redis_port}") + redis_process = subprocess.Popen(redis_command) + + import redis + import time + + max_wait = 10 + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1) + r.ping() + logger.info(f"Redis service started successfully on port {redis_port}") + del r + break + except Exception: + time.sleep(0.5) + if redis_process.poll() is not None: + logger.error("Redis service failed to start") + return None + else: + logger.error("Redis service startup timeout") + if redis_process.poll() is None: + redis_process.terminate() + return None + + return redis_process + + except Exception as e: + logger.error(f"Failed to start Redis service: {e}") + return None diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py index 372b7e1cfa..8245431084 100644 --- a/lightllm/utils/start_utils.py +++ b/lightllm/utils/start_utils.py @@ -111,4 +111,12 @@ def kill_recursive(proc): logger.warning(f"Process {proc.pid} does not exist.") +def is_multimodal_mode(args): + from transformers import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir) + is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg + return is_multimodal + + process_manager = SubmoduleManager() diff --git a/requirements.txt b/requirements.txt index 5b0b201ae3..298c5fe6ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,3 +94,4 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +xformers==0.0.33.post2 From f80b864d952692a15c97a1dbaa3054447815ae52 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 12 Mar 2026 14:08:10 +0800 Subject: [PATCH 2/4] fix --- lightllm/server/api_start.py | 25 ++++++++++ .../impl/memory_cache_with_redis.py | 22 ++++++--- .../embed_cache/impl/naive_memory_cache.py | 21 +++++---- lightllm/server/embed_cache/utils.py | 47 ++++++++++++++----- lightllm/server/httpserver/manager.py | 14 +++--- .../visualserver/model_infer/model_rpc.py | 4 +- lightllm/server/visualserver/vit_connect.py | 10 ++-- 7 files changed, 107 insertions(+), 36 deletions(-) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index e69588ff80..f628ec46e3 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -22,6 +22,30 @@ logger = init_logger(__name__) +def _ensure_remote_vit_embed_dir(image_embed_dir: str) -> None: + if os.path.exists(image_embed_dir): + if not os.path.isdir(image_embed_dir): + raise ValueError(f"image_embed_dir is not a directory: {image_embed_dir}") + return + + os.makedirs(image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(image_embed_dir, 0o777) + + +def _prepare_remote_vit_embed_dir(args): + remote_vit_mode = args.enable_remote_vit or args.run_mode 
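The startup code above gives redis-server a bounded window to come up: keep PINGing until it answers, and abort immediately if the child process has already exited. A sketch of that probe, assuming the redis-py client pinned later in requirements.txt and a subprocess.Popen handle for the server:

```python
import time
import redis  # redis-py


def wait_redis_ready(host: str, port: int, proc, max_wait: float = 10.0) -> bool:
    deadline = time.time() + max_wait
    while time.time() < deadline:
        try:
            redis.Redis(host=host, port=port, socket_connect_timeout=1).ping()
            return True
        except redis.exceptions.ConnectionError:
            if proc.poll() is not None:  # redis-server exited during startup
                return False
            time.sleep(0.5)
    return False  # still not answering after max_wait
```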
in ["visual", "visual_only"] + if not remote_vit_mode: + return + + if not args.image_embed_dir: + raise ValueError("remote vit mode requires --image_embed_dir to be set") + + args.image_embed_dir = os.path.abspath(args.image_embed_dir) + _ensure_remote_vit_embed_dir(args.image_embed_dir) + + logger.info(f"using image_embed_dir: {args.image_embed_dir}") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -164,6 +188,7 @@ def check_and_set_args(args): assert args.mtp_step == 0 args.enable_multimodal = is_multimodal_mode(args) + _prepare_remote_vit_embed_dir(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 05bd0bc23e..9cfc8364a4 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -27,16 +27,24 @@ def __init__(self, args) -> None: # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 # 硬盘里的图片image embed 数量。 - self.cache_capacity = args.cache_capacity * 2 + self.capacity = max(1, args.cache_capacity * 2) # llm 负责release def release(self, ids: list[int]) -> None: with self.lock: for id in ids: - self._records[id].ref -= 1 - if self.redis_cache.query(str(id)): + rec = self._records.get(id) + if rec is None: + continue + + redis_exist = self.redis_cache.query(str(id)) + if redis_exist: self.redis_cache.decr(str(id)) - # print(self.redis_cache.stats(), flush=True) + + # remote_vit 模式下 release 可能走“预层提前释放 + 请求结束兜底释放”两条路径, + # 这里避免本地 ref 被重复减成负数,保证 release 可重复调用。 + if rec.ref > 0: + self._update_record_ref(rec, -1) # vit 负责set def set_items_embed(self, ids: list[int]) -> None: @@ -44,8 +52,10 @@ def set_items_embed(self, ids: list[int]) -> None: for id in ids: self.redis_cache.insert(str(id)) self._records[id].embed = True - self._records[id].ref -= 1 - self.redis_cache.decr(str(id)) # vit端alloc之后ref+1 vit完成后ref-1 + if self._records[id].ref > 0: + self._update_record_ref_by_id(id, -1) + # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, + # 避免 VIT 刚写完文件但 LLM 还没来得及读取时被 LRU 误删。 def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: ret = [] diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 8b7528cd0f..0788803753 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -126,18 +126,26 @@ def _free_to_alloc(self, free_min_count: int, new_md5_dict: Dict[str, int]) -> D def _add_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] - self._sorted_records.remove(rec) - rec.ref += 1 - self._sorted_records.add(rec) + self._update_record_ref(rec, 1) return def _del_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] + self._update_record_ref(rec, -1) + return + + def _update_record_ref(self, rec: Record, delta: int): self._sorted_records.remove(rec) - rec.ref -= 1 + rec.ref += delta + rec.visittime = time.time() self._sorted_records.add(rec) return + def _update_record_ref_by_id(self, id_: int, delta: int): + rec: Record = self._id_to_records[id_] + self._update_record_ref(rec, delta) + return + def _judge_enough_token_cache(self, 
md5sum_list: list[str], token_num_list: list[int]) -> bool: tmp_dict = {} for md5, token_num in zip(md5sum_list, token_num_list): @@ -214,10 +222,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: - rec: Record = self._id_to_records[id_] - self._sorted_records.remove(rec) - rec.ref -= 1 - self._sorted_records.add(rec) + self._update_record_ref_by_id(id_, -1) def set_items_data(self, ids: list[int]) -> None: for id_ in ids: diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index aa0dc903dd..5e09aafed3 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -12,6 +12,22 @@ logger = init_logger(__name__) +def _get_afs_path(base_dir: str, name: str) -> Path: + if not base_dir: + raise ValueError("image_embed_dir must be set before using disk-backed embed cache") + return Path(base_dir) / name + + +def _ensure_afs_dir(base_dir: Path) -> None: + if base_dir.exists(): + if not base_dir.is_dir(): + raise ValueError(f"image_embed_dir is not a directory: {base_dir}") + return + + base_dir.mkdir(parents=True, mode=0o777, exist_ok=True) + os.chmod(base_dir, 0o777) + + def tensor2bytes(t: torch.Tensor): buf = BytesIO() t = t.detach().cpu() @@ -37,16 +53,27 @@ def create_shm(name, data): def create_afs(name, data, path): + target_path = _get_afs_path(path, name) + _ensure_afs_dir(target_path.parent) + data_size = len(data) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + try: - data_size = len(data) - path = os.path.join(path, name) - with open(path, "xb") as f: + with open(tmp_path, "wb") as f: mem_view = memoryview(data) f.write(mem_view[:data_size]) f.flush() os.fsync(f.fileno()) - except FileExistsError: - print("Warning create afs {} failed because of FileExistsError!".format(name)) + os.chmod(tmp_path, 0o777) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to create embed file: {target_path}") + raise def read_shm(name): @@ -56,8 +83,7 @@ def read_shm(name): def read_afs(name: str, base_dir) -> bytes: - - path = Path(base_dir) / name + path = _get_afs_path(base_dir, name) return path.read_bytes() @@ -68,8 +94,8 @@ def free_shm(name): def free_afs(name: str, base_dir) -> None: - path = Path(base_dir) / name - path.unlink() + path = _get_afs_path(base_dir, name) + path.unlink(missing_ok=True) def get_shm_name_data(uid): @@ -250,8 +276,7 @@ def _md5_to_afs_path(self, md5: str) -> str: """Convert md5 to AFS file path.""" if not self.image_embed_dir: return None - filename = self.image_embed_dir + md5 + self.path_ext - return filename + return str(_get_afs_path(self.image_embed_dir, f"{md5}{self.path_ext}")) def _delete_afs_files(self, victims: List[str]) -> None: """Delete AFS files for evicted md5s.""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 08966545e2..9c5ad1aeb6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -144,6 +144,9 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server if self.args.enable_remote_vit: + # 对已经命中的 embed 立即增加一份引用,确保从命中这一刻开始到 LLM 读取完成前 + # 都不会被 LRU 提前淘汰。对未 ready 的项,该调用不会增加引用。 + 
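The reworked create_afs above follows the usual write-temp-then-rename protocol: readers can never observe a half-written embed file, because os.replace publishes it atomically within one filesystem. The protocol in isolation, with an illustrative target path:

```python
import os
import time
from pathlib import Path

def atomic_write(target: Path, data: bytes) -> None:
    tmp = target.parent / f".{target.name}.tmp-{os.getpid()}-{time.time_ns()}"
    try:
        with open(tmp, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # bytes durable before the rename publishes them
        os.replace(tmp, target)   # atomic: readers see the old state or the full file
    except Exception:
        tmp.unlink(missing_ok=True)  # no half-written temp files left behind
        raise

atomic_write(Path("/tmp/embed.bin"), b"\x00" * 16)  # illustrative path
```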
obtain(self.cache_client.root.get_items_embed(uid_list, False)) return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) @@ -215,7 +218,7 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam audio.token_id = None audio.token_num = None audio.start_index_in_embed_cache = None - if ids_to_release and not self.args.enable_remote_vit: + if ids_to_release: self.cache_client.root.release(ids_to_release) return @@ -446,11 +449,13 @@ async def get_image_embeding( visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) await self.transfer_to_next_module_or_node( - None, sampling_params, original_multimodal_params, visual_req_status, embeding_only=True + None, sampling_params, original_multimodal_params, visual_req_status ) + await self._release_multimodal_resources(multimodal_params) except Exception as e: logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self._release_multimodal_resources(multimodal_params) await self.abort(group_request_id) raise e return @@ -556,7 +561,6 @@ async def transfer_to_next_module_or_node( sampling_params: SamplingParams, original_multimodal_params: MultimodalParams, group_req_objs: Optional[GroupReqObjs] = None, - embeding_only: Optional[bool] = False, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. if self.is_multinode_tp_master: @@ -566,13 +570,12 @@ async def transfer_to_next_module_or_node( protocol=pickle.HIGHEST_PROTOCOL, ) - await self.transfer_to_next_module(group_req_objs, embeding_only) + await self.transfer_to_next_module(group_req_objs) return async def transfer_to_next_module( self, group_req_objs: Optional[GroupReqObjs] = None, - embeding_only: Optional[bool] = False, ): if self.pd_mode.is_P_or_NORMAL(): @@ -580,7 +583,6 @@ async def transfer_to_next_module( await self.vit_manager.send_to_vit( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, - embeding_only=embeding_only, ) if self.args.enable_cpu_cache: diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a9f8e1223a..9db68569ca 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -46,6 +46,8 @@ def exposed_init_model(self, kvargs): self.vit_rank_id = kvargs["vit_rank_id"] self.image_embed_dir = self.args.image_embed_dir self.remote_vit = self.args.enable_remote_vit or self.args.run_mode in ["visual", "visual_only"] + if self.remote_vit and not self.image_embed_dir: + raise ValueError("remote vit mode requires image_embed_dir") self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -127,7 +129,7 @@ def exposed_encode(self, images: List[ImageItem]): if self.tp_rank_id != 0: return - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids, True)) ids_to_set = [] cpu_embeds = None if self.remote_vit: diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index 7a1443f025..f0ca614d7d 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -159,7 +159,7 @@ def _get_vit_instance(self): self.current_vit_index = index return list(self.remote_vit_instances.values())[index] - async def send_to_vit(self, req: 
GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, embeding_only=False): + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL): """ 发送数据到VIT实例,支持本地和远程模式 """ @@ -176,7 +176,7 @@ async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOC raise Exception(f"Failed to send to VIT instance: {e}") # 远程模式下,发送完以后,在释放图片资源 - await self._wait_visual_embed_ready(req, embeding_only) + await self._wait_visual_embed_ready(req) if self.remote_vit: req.multimodal_params.free() @@ -217,7 +217,9 @@ async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: return None async def _wait_visual_embed_ready( - self, req: GroupReqIndexes, embeding_only: bool = False, timeout_seconds: int = 1000 + self, + req: GroupReqIndexes, + timeout_seconds: int = 1000, ): # 本地模式不需要等待 if not self.remote_vit: @@ -225,7 +227,7 @@ async def _wait_visual_embed_ready( uuids = req.multimodal_params.get_all_uuids() async def wait_for_embeds(): - while not all(self.cache_client.root.get_items_embed(uuids, embeding_only)): + while not all(self.cache_client.root.get_items_embed(uuids, True)): await asyncio.sleep(0.01) try: From 100088aff01d08f34c6b565bd87be78257c665b0 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Thu, 12 Mar 2026 20:36:57 +0800 Subject: [PATCH 3/4] opt: opt perf --- .../qwen_vl/layer_infer/pre_layer_infer.py | 51 +++++++++++++++---- lightllm/server/embed_cache/utils.py | 32 ++++++++++++ .../visualserver/model_infer/model_rpc.py | 5 +- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index faff792901..242707e22c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,7 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs +from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, load_tensor_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args @@ -37,6 +37,18 @@ def __init__(self, network_config): self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return + def _copy_loaded_embed_to_cache( + self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int + ): + if embed_tensor.ndim == 2: + token_num, hidden_size = embed_tensor.shape + cpu_embed_cache_tensor[start_index : start_index + token_num, 0, :hidden_size].copy_(embed_tensor) + return + + token_num, layer_num, hidden_size = embed_tensor.shape + cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor) + return + def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_start_token_ids = [] img_token_lens = [] @@ -65,20 +77,37 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei ) if self.args.enable_remote_vit: + unique_multimodal_items = [] + seen_uuids = set() + release_ids = [] for batch_id, p in 
enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: if img["token_num"] is None: continue - if self.args.image_embed_dir: - embed_bytes = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir) - else: - embed_bytes = read_shm(get_shm_name_embed(img["uuid"])) - embed_tensor = bytes2tensor(embed_bytes).to(device="cuda", non_blocking=True) - g_infer_context.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=embed_tensor, - start_index_in_cache=img["start_index_in_embed_cache"], - ) - self.cache_client.root.release([img["uuid"]]) + uid = img["uuid"] + release_ids.append(uid) + if uid in seen_uuids: + continue + seen_uuids.add(uid) + unique_multimodal_items.append((uid, img["start_index_in_embed_cache"])) + + if self.args.image_embed_dir: + image_embed_dir = self.args.image_embed_dir + + def load_embed_tensor(uid): + return load_tensor_afs(get_shm_name_embed(uid), image_embed_dir) + + else: + + def load_embed_tensor(uid): + return bytes2tensor(read_shm(get_shm_name_embed(uid))) + + for uid, start_index_in_embed_cache in unique_multimodal_items: + embed_tensor = load_embed_tensor(uid) + self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) + + if release_ids: + self.cache_client.root.release(release_ids) assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 5e09aafed3..7201703ec1 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -10,6 +10,7 @@ from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) +_ENSURED_AFS_DIRS = set() def _get_afs_path(base_dir: str, name: str) -> Path: @@ -19,13 +20,18 @@ def _get_afs_path(base_dir: str, name: str) -> Path: def _ensure_afs_dir(base_dir: Path) -> None: + base_dir_key = str(base_dir) + if base_dir_key in _ENSURED_AFS_DIRS: + return if base_dir.exists(): if not base_dir.is_dir(): raise ValueError(f"image_embed_dir is not a directory: {base_dir}") + _ENSURED_AFS_DIRS.add(base_dir_key) return base_dir.mkdir(parents=True, mode=0o777, exist_ok=True) os.chmod(base_dir, 0o777) + _ENSURED_AFS_DIRS.add(base_dir_key) def tensor2bytes(t: torch.Tensor): @@ -42,6 +48,32 @@ def bytes2tensor(b): return torch.load(BytesIO(b), weights_only=False) +def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None: + target_path = _get_afs_path(base_dir, name) + _ensure_afs_dir(target_path.parent) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + + try: + with open(tmp_path, "wb") as f: + torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4) + os.chmod(tmp_path, 0o777) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to save embed tensor file: {target_path}") + raise + + +def load_tensor_afs(name: str, base_dir: str) -> torch.Tensor: + path = _get_afs_path(base_dir, name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) + + def create_shm(name, data): try: data_size = len(data) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 9db68569ca..237e52ce6f 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ 
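save_tensor_afs / load_tensor_afs above replace the tensor2bytes detour: instead of serializing into an in-memory BytesIO and then writing the resulting bytes out, torch.save streams directly to the file, avoiding one full in-RAM copy of the embedding. A small sketch contrasting the two paths, using an illustrative /tmp path:

```python
from io import BytesIO
import torch

t = torch.randn(16, 8)

# staged path (tensor2bytes-style): tensor -> BytesIO -> bytes -> file
buf = BytesIO()
torch.save(t, buf)
blob = buf.getvalue()  # one extra full copy lives in RAM here

# direct path (save_tensor_afs-style): tensor -> file
with open("/tmp/embed.pt", "wb") as f:
    torch.save(t, f, pickle_protocol=4)

with open("/tmp/embed.pt", "rb") as f:
    loaded = torch.load(f, weights_only=False)
assert torch.equal(t, loaded)
```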
b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -26,7 +26,7 @@ from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient -from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, tensor2bytes +from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, tensor2bytes, save_tensor_afs from lightllm.server.visualserver import set_vit_att_backend @@ -142,8 +142,7 @@ def exposed_encode(self, images: List[ImageItem]): start, end = valid_ids[i] image = images[i] if self.remote_vit: - cur_embed_bytes = tensor2bytes(cpu_embeds[start:end]) - create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.image_embed_dir) + save_tensor_afs(get_shm_name_embed(uid), cpu_embeds[start:end], self.image_embed_dir) else: self.cpu_embed_cache_client.copy_vision_to_cache( embed_tensor=all_img_embeds[start:end], From 8d63cbf54354538314fda6af174ac19e40d8d6b0 Mon Sep 17 00:00:00 2001 From: niushengxiao Date: Fri, 13 Mar 2026 14:07:45 +0800 Subject: [PATCH 4/4] refine --- .../common/basemodel/attention_vit/fa3/fp.py | 5 +- .../qwen_vl/layer_infer/pre_layer_infer.py | 38 ++++---------- .../vit/triton_kernel/flashattention_nopad.py | 6 +-- lightllm/server/api_lightllm.py | 3 +- lightllm/server/api_start.py | 52 +++++++++---------- .../impl/memory_cache_with_redis.py | 26 +++------- lightllm/server/embed_cache/manager.py | 4 -- lightllm/server/embed_cache/utils.py | 20 ------- lightllm/server/httpserver/manager.py | 10 ++-- lightllm/server/multimodal_params.py | 4 ++ lightllm/server/visualserver/vit_connect.py | 2 - lightllm/utils/dist_utils.py | 10 ++-- requirements.txt | 1 + 13 files changed, 59 insertions(+), 122 deletions(-) diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index c5b8cc1076..f1bef078a7 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -18,7 +18,7 @@ def _vit_att_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - attn_output = flash_attn_varlen_func( + o = flash_attn_varlen_func( q, k, v, @@ -29,8 +29,7 @@ def _vit_att_fwd( softmax_scale=softmax_scale, causal=False, window_size=window_size, + attention_chunk=0, softcap=0.0, ) - o.copy_(attn_output) - return o diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 242707e22c..0127fbea8b 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,7 +6,7 @@ from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer -from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, load_tensor_afs +from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce from lightllm.utils.envs_utils import get_env_start_args @@ -41,9 +41,7 @@ def _copy_loaded_embed_to_cache( self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int ): if embed_tensor.ndim == 2: - 
token_num, hidden_size = embed_tensor.shape - cpu_embed_cache_tensor[start_index : start_index + token_num, 0, :hidden_size].copy_(embed_tensor) - return + embed_tensor = embed_tensor.unsqueeze(1) token_num, layer_num, hidden_size = embed_tensor.shape cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor) @@ -53,11 +51,12 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] + unique_uids = [] device = layer_weight.wte_weight_.weight.device dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.weight.shape[1] - for batch_id, p in enumerate(infer_state.multimodal_params): + for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image if img["token_id"] in img_start_token_ids: @@ -65,6 +64,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs_in_cache.append(img["start_index_in_embed_cache"]) + unique_uids.append(img["uuid"]) out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -77,33 +77,13 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei ) if self.args.enable_remote_vit: - unique_multimodal_items = [] - seen_uuids = set() release_ids = [] - for batch_id, p in enumerate(infer_state.multimodal_params): + for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: - if img["token_num"] is None: - continue - uid = img["uuid"] - release_ids.append(uid) - if uid in seen_uuids: - continue - seen_uuids.add(uid) - unique_multimodal_items.append((uid, img["start_index_in_embed_cache"])) + release_ids.append(img["uuid"]) - if self.args.image_embed_dir: - image_embed_dir = self.args.image_embed_dir - - def load_embed_tensor(uid): - return load_tensor_afs(get_shm_name_embed(uid), image_embed_dir) - - else: - - def load_embed_tensor(uid): - return bytes2tensor(read_shm(get_shm_name_embed(uid))) - - for uid, start_index_in_embed_cache in unique_multimodal_items: - embed_tensor = load_embed_tensor(uid) + for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache): + embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir) self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) if release_ids: diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index a38e27924b..3a0b2d2069 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -167,7 +167,7 @@ def flash_attention_v3_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - attn_output = flash_attn_varlen_func( + o = flash_attn_varlen_func( q, k, v, @@ -180,9 +180,7 @@ def flash_attention_v3_fwd( window_size=window_size, softcap=0.0, ) - o.copy_(attn_output) - - return + return o except ImportError: print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index fbff8e681c..bfb8bff6db 100644 --- a/lightllm/server/api_lightllm.py +++ 
b/lightllm/server/api_lightllm.py
@@ -1,11 +1,10 @@
 import collections
 from typing import AsyncGenerator
 from fastapi import BackgroundTasks, Request
-from fastapi.responses import Response, StreamingResponse
+from fastapi.responses import Response, StreamingResponse, JSONResponse
 from lightllm.server.core.objs.sampling_params import SamplingParams
 from .multimodal_params import MultimodalParams
 from .httpserver.manager import HttpServerManager
-from fastapi.responses import JSONResponse
 import ujson as json
diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index f628ec46e3..918ed3c2d0 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -87,7 +87,7 @@ def signal_handler(sig, frame):
     return


-def check_and_set_args(args):
+def normal_or_p_d_start(args, only_prepare=False):
     from lightllm.server.core.objs.start_args_type import StartArgs

     args: StartArgs = args
@@ -219,20 +219,18 @@ def check_and_set_args(args):
         if args.batch_max_tokens is None:
             args.batch_max_tokens = args.max_req_total_len
         else:
-            assert args.batch_max_tokens >= args.max_req_total_len, (
-                f"batch_max_tokens must >= max_req_total_len"
-                f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
-            )
+            assert args.batch_max_tokens >= args.max_req_total_len, (
+                f"batch_max_tokens must >= max_req_total_len, "
+                f"but got {args.batch_max_tokens}, {args.max_req_total_len}"
+            )
     else:
         # chunked 模式下
         if args.batch_max_tokens is None:
             args.batch_max_tokens = 16384 // args.dp
         if args.chunked_prefill_size is None:
             args.chunked_prefill_size = args.batch_max_tokens // 2
         assert args.batch_max_tokens >= args.chunked_prefill_size, (
             "chunked prefill mode, batch_max_tokens must >= chunked_prefill_size, "
             f"but got {args.batch_max_tokens}, {args.chunked_prefill_size}"
         )

     # help to manage data stored on Ceph
     if "s3://" in args.model_dir:
@@ -252,9 +250,8 @@ def check_and_set_args(args):
         args.data_type = get_dtype(args.model_dir)
     assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]

-
-def normal_or_p_d_start(args):
-    check_and_set_args(args)
+    if only_prepare:
+        return

     already_uesd_ports = [args.port]
     if args.nccl_port is not None:
@@ -291,17 +288,19 @@ def normal_or_p_d_start(args):
     can_use_ports = can_use_ports[10:]

     visual_model_tp_ports = []
+    visual_nccl_ports = []
     for _ in range(args.visual_dp):
         tp_ports_for_dp = can_use_ports[0 : args.visual_tp]
         visual_model_tp_ports.append(tp_ports_for_dp)
         can_use_ports = can_use_ports[args.visual_tp :]
+        if args.visual_nccl_ports is None:
+            visual_nccl_ports.append(can_use_ports[0])
+            can_use_ports = can_use_ports[1:]

-    if args.visual_nccl_ports is None:
-        visual_nccl_ports = can_use_ports[0 : args.visual_dp]
-        can_use_ports = can_use_ports[args.visual_dp :]
-    else:
-        visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+    if args.visual_nccl_ports is not None:
+        args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
+
     # 将申请好的端口放入args参数中
     if args.nccl_port is None:
         args.nccl_port = nccl_port
     if args.pd_decode_rpyc_port is None:
@@ -328,6 +327,7 @@ def normal_or_p_d_start(args):
         args.router_max_wait_tokens = 0

     send_and_receive_node_ip(args)  # 多机用于收发node ip
+    # dp must be > 1 for this optimization
     if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
         args.enable_dp_prompt_cache_fetch = False
         logger.warning(
@@ -491,7 +491,7 @@ def 
pd_master_start(args): def visual_start(args): - check_and_set_args(args) + normal_or_p_d_start(args, only_prepare=True) already_uesd_ports = [args.remote_vit_port] if args.nccl_port is not None: @@ -515,15 +515,16 @@ def visual_start(args): can_use_ports = can_use_ports[5:] visual_model_tp_ports = [] + visual_nccl_ports = [] for _ in range(args.visual_dp): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] - can_use_ports = can_use_ports[args.visual_tp :] visual_model_tp_ports.append(tp_ports_for_dp) + can_use_ports = can_use_ports[args.visual_tp :] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] - if args.visual_nccl_ports is None: - args.visual_nccl_ports = can_use_ports[0 : args.visual_dp] - can_use_ports = can_use_ports[args.visual_dp :] - else: + if args.visual_nccl_ports is not None: args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] args.router_port = router_port @@ -531,7 +532,6 @@ def visual_start(args): args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port - args.visual_model_rpc_ports = visual_model_tp_ports args.visual_node_id = uuid.uuid4().int logger.info(f"all start args:{args}") @@ -586,9 +586,9 @@ def config_server_start(args): "--log-level", "info", "--access-logfile", - "/dev/stdout", + "-", "--error-logfile", - "/dev/stderr", + "-", "lightllm.server.config_server.api_http:app", "--keep-alive", f"{get_lightllm_gunicorn_keep_alive()}", diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 9cfc8364a4..6ef8c66f68 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -51,9 +51,11 @@ def set_items_embed(self, ids: list[int]) -> None: with self.lock: for id in ids: self.redis_cache.insert(str(id)) - self._records[id].embed = True - if self._records[id].ref > 0: - self._update_record_ref_by_id(id, -1) + rec = self._records.get(id) + if rec is not None: + rec.embed = True + if rec.ref > 0: + self._update_record_ref_by_id(id, -1) # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, # 避免 VIT 刚写完文件但 LLM 还没来得及读取时被 LRU 误删。 @@ -66,19 +68,7 @@ def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[O exist = self.redis_cache.query_and_incre(str(id)) ret.append(exist) if exist: - self._records[id].embed = True + rec = self._records.get(id) + if rec is not None: + rec.embed = True return ret - - # def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]: - # ret = [] - # for id in ids: - # # if self.redis_cache.query(str(id)): - # # ret.append(True) - # # continue - # # 避免重复的引用计数增加 - # if self._records[id].embed: - # ret.append(True) - # continue - # self._records[id].embed = self.redis_cache.query_and_incre(str(id)) - # ret.append(self._records[id].embed) - # return ret diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 0dc8830cdb..faf48c4085 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -26,10 +26,6 @@ def on_disconnect(self, conn): # (to finalize the service, if needed) pass - def exposed__check_and_set_new_id_range(self, token_num: int) -> int: - token_num = obtain(token_num) - return self._impl._check_and_set_new_id_range(token_num) - def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: md5sum_list = 
obtain(md5sum_list) token_num_list = obtain(token_num_list) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 7201703ec1..caeca0b2b6 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -10,7 +10,6 @@ from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -_ENSURED_AFS_DIRS = set() def _get_afs_path(base_dir: str, name: str) -> Path: @@ -19,21 +18,6 @@ def _get_afs_path(base_dir: str, name: str) -> Path: return Path(base_dir) / name -def _ensure_afs_dir(base_dir: Path) -> None: - base_dir_key = str(base_dir) - if base_dir_key in _ENSURED_AFS_DIRS: - return - if base_dir.exists(): - if not base_dir.is_dir(): - raise ValueError(f"image_embed_dir is not a directory: {base_dir}") - _ENSURED_AFS_DIRS.add(base_dir_key) - return - - base_dir.mkdir(parents=True, mode=0o777, exist_ok=True) - os.chmod(base_dir, 0o777) - _ENSURED_AFS_DIRS.add(base_dir_key) - - def tensor2bytes(t: torch.Tensor): buf = BytesIO() t = t.detach().cpu() @@ -50,13 +34,11 @@ def bytes2tensor(b): def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None: target_path = _get_afs_path(base_dir, name) - _ensure_afs_dir(target_path.parent) tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" try: with open(tmp_path, "wb") as f: torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4) - os.chmod(tmp_path, 0o777) os.replace(tmp_path, target_path) os.chmod(target_path, 0o777) except Exception: @@ -86,7 +68,6 @@ def create_shm(name, data): def create_afs(name, data, path): target_path = _get_afs_path(path, name) - _ensure_afs_dir(target_path.parent) data_size = len(data) tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" @@ -96,7 +77,6 @@ def create_afs(name, data, path): f.write(mem_view[:data_size]) f.flush() os.fsync(f.fileno()) - os.chmod(tmp_path, 0o777) os.replace(tmp_path, target_path) os.chmod(target_path, 0o777) except Exception: diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 9c5ad1aeb6..0ee3ce03dd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -144,9 +144,8 @@ async def _alloc_resource(self, items, uuids, token_nums, datas): # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server if self.args.enable_remote_vit: - # 对已经命中的 embed 立即增加一份引用,确保从命中这一刻开始到 LLM 读取完成前 - # 都不会被 LRU 提前淘汰。对未 ready 的项,该调用不会增加引用。 - obtain(self.cache_client.root.get_items_embed(uid_list, False)) + # 避免远端lru被逐出 + self.cache_client.root.get_items_embed(uid_list, False) return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) @@ -431,9 +430,6 @@ async def get_image_embeding( if self.is_multinode_tp_master: original_multimodal_params = copy.deepcopy(multimodal_params) - if self.pd_mode.is_P_or_NORMAL(): - await multimodal_params.verify_and_preload(request) - await multimodal_params.verify_and_preload(request) image_count = len(multimodal_params.images) # 记录请求到达的相关信息 @@ -600,7 +596,7 @@ async def transfer_to_next_module( return if self.pd_mode.is_D(): - # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了, 传输一个空的即可 + # 在 D 模式下,不需要传输真的多模态参数,因为其已经被 P 处理好了 self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 9616d5f84c..c86a9a0786 
100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -57,6 +57,10 @@ def read(self): assert self._preload_data is not None return self._preload_data + def free(self): + self._preload_data = None + self._data = None + def to_dict(self): ret = {} ret["uuid"] = self.uuid diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index f0ca614d7d..af93effaf0 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -52,7 +52,6 @@ def _setup_vit_connections(self): # 远程VIT实例模式 self._setup_remote_vit_connections() else: - print("not remote") self._setup_local_vit_connection() def _setup_local_vit_connection(self): @@ -169,7 +168,6 @@ async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOC req.multimodal_params.free() try: - print(instance, flush=True) instance.send_pyobj(req, protocol=protocol) except Exception as e: logger.error(f"Failed to send to VIT instance: {e}") diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index f0b06ead1c..65ac401d4c 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -55,23 +55,19 @@ def get_environ(environ_name): def init_vision_distributed_env(kvargs): - from lightllm.utils.envs_utils import get_env_start_args - - args = get_env_start_args() - tp_world_size = args.visual_tp + tp_world_size = kvargs["vit_tp"] dp_size = 1 tp_rank_id = kvargs["tp_rank_id"] set_dp_size(dp_size) set_dp_world_size(tp_world_size) set_current_rank_in_dp(tp_rank_id) - visual_gpu_ids = args.visual_gpu_ids + visual_gpu_ids = kvargs["visual_gpu_ids"] device_id = visual_gpu_ids[kvargs["vit_rank_id"]] set_current_device_id(device_id) torch.cuda.set_device(device_id) - visual_nccl_port = args.visual_nccl_ports[kvargs["dp_rank_id"]] dist.init_process_group( "nccl", - init_method=f"tcp://127.0.0.1:{visual_nccl_port}", + init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', rank=kvargs["tp_rank_id"], world_size=tp_world_size, device_id=torch.device(f"cuda:{device_id}"), diff --git a/requirements.txt b/requirements.txt index 298c5fe6ca..5331227586 100644 --- a/requirements.txt +++ b/requirements.txt @@ -95,3 +95,4 @@ websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 xformers==0.0.33.post2 +redis==7.3.0
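After the final refinement, _copy_loaded_embed_to_cache handles both embed layouts with one slice assignment by viewing a 2-D [token, hidden] tensor as [token, 1, hidden]. A self-contained sketch of that copy, with made-up cache dimensions:

```python
import torch

cache = torch.zeros(100, 4, 32)  # [capacity, layer, hidden]; illustrative sizes

def copy_to_cache(embed: torch.Tensor, start: int) -> None:
    if embed.ndim == 2:
        embed = embed.unsqueeze(1)  # [token, hidden] -> [token, 1, hidden]
    token_num, layer_num, hidden = embed.shape
    cache[start : start + token_num, :layer_num, :hidden].copy_(embed)

copy_to_cache(torch.randn(10, 32), start=0)      # single-layer embed
copy_to_cache(torch.randn(10, 4, 32), start=10)  # multi-layer embed
```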