diff --git a/lightllm/common/basemodel/attention_vit/fa3/fp.py b/lightllm/common/basemodel/attention_vit/fa3/fp.py index f804116f1f..f1bef078a7 100644 --- a/lightllm/common/basemodel/attention_vit/fa3/fp.py +++ b/lightllm/common/basemodel/attention_vit/fa3/fp.py @@ -1,6 +1,7 @@ import dataclasses import torch from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend +from lightllm.utils.sgl_utils import flash_attn_varlen_func class Fa3VitAttBackend(BaseVitAttBackend): @@ -17,42 +18,18 @@ def _vit_att_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( + o = flash_attn_varlen_func( q, k, v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + window_size=window_size, attention_chunk=0, softcap=0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, ) - return o diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index ccb76d3512..3e5b9c5e2a 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -68,6 +68,7 @@ def init_imageitem_extral_params( img.extra_params["image_patch_max_num"] = 6 elif num_images > 6: img.extra_params["image_patch_max_num"] = 0 + img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( @@ -75,14 +76,14 @@ def init_audioitem_extral_params( ): return - def get_image_token_length(self, img: ImageItem): - return ( - self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - * self.image_length + def get_image_patch(self, img: ImageItem): + return self.get_image_patch_func( + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True ) + def get_image_token_length(self, img: ImageItem): + return self.get_image_patch(img) * self.image_length + def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length audio_token_num = 0 diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..0127fbea8b 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,11 +1,15 @@ +import rpyc +import socket import torch import torch.distributed as dist from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args """ @@ -26,17 +30,33 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config): super().__init__(network_config) + self.args = get_env_start_args() + self.cache_client = None + if self.args.enable_remote_vit: + self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return + + def _copy_loaded_embed_to_cache( + self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int + ): + if embed_tensor.ndim == 2: + embed_tensor = embed_tensor.unsqueeze(1) + + token_num, layer_num, hidden_size = embed_tensor.shape + cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): img_start_token_ids = [] img_token_lens = [] img_start_locs_in_cache = [] + unique_uids = [] device = layer_weight.wte_weight_.weight.device dtype = layer_weight.wte_weight_.weight.dtype hidden_size = layer_weight.wte_weight_.weight.shape[1] - for batch_id, p in enumerate(infer_state.multimodal_params): + for _, p in enumerate(infer_state.multimodal_params): for img in p["images"] + p["audios"]: # skip the same image if img["token_id"] in img_start_token_ids: @@ -44,6 +64,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs_in_cache.append(img["start_index_in_embed_cache"]) + unique_uids.append(img["uuid"]) out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) from lightllm.server.router.model_infer.infer_batch import g_infer_context @@ -55,6 +76,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei else cpu_embed_cache_client.cpu_embed_cache_tensor ) + if self.args.enable_remote_vit: + release_ids = [] + for _, p in enumerate(infer_state.multimodal_params): + for img in p["images"] + p["audios"]: + release_ids.append(img["uuid"]) + + for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache): + embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir) + self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache) + + if release_ids: + self.cache_client.root.release(release_ids) + assert cpu_embed_cache_tensor.shape[2] == hidden_size, ( f"Dimension mismatch: text weight dimension is {hidden_size}, " f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}" diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..3a0b2d2069 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -167,44 +167,20 @@ def flash_attention_v3_fwd( head_dim = q.shape[-1] softmax_scale = head_dim ** -0.5 window_size = (-1, -1) - torch.ops.sgl_kernel.fwd.default( + o = flash_attn_varlen_func( q, k, v, - None, # k_new - None, # v_new - None, # qv - o, # out - cu_seqlens, - cu_seqlens, - None, # cu_seqlens_k_new - None, - None, - max_seqlen, - max_seqlen, - None, # page_table, - None, # kv_batch_idx - None, # leftpad_k - None, # rotary cos - None, # rotary sin - None, # seqlens_rotary - None, - None, - None, - softmax_scale, - False, - window_size[0], - window_size[1], - 0.0, - is_rotary_interleaved=False, - scheduler_metadata=None, - num_splits=1, - pack_gqa=None, - sm_margin=0, - sinks=None, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + softmax_scale=softmax_scale, + causal=False, + window_size=window_size, + softcap=0.0, ) - - return + return o except ImportError: print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.") diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 762d84575b..924b26fce3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,17 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"], + choices=[ + "normal", + "prefill", + "decode", + "nixl_prefill", + "nixl_decode", + "pd_master", + "config_server", + "visual", + "visual_only", + ], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -593,6 +603,40 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--image_embed_dir", + type=str, + default=None, + help="path for vit embed", + ) + parser.add_argument( + "--enable_remote_vit", + action="store_true", + help="Whether to enable remote vit for multimodal service.", + ) + parser.add_argument( + "--remote_vit_port", + type=int, + default=12346, + help="The port number for the remote vit service.", + ) + parser.add_argument( + "--redis_port", + type=int, + default=6379, + help="The port number for the redis service in config_server mode.", + ) + parser.add_argument( + "--redis_evict_fraction", + type=float, + default=0.3, + help="The evict fraction for the redis service in config_server mode.", + ) + parser.add_argument( + "--start_redis", + action="store_true", + help="Whether to start the redis service in config_server mode.", + ) parser.add_argument( "--enable_cpu_cache", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 230da5b369..bf246f8f0d 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -43,7 +43,7 @@ from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -92,6 +92,8 @@ def set_args(self, args: StartArgs): self.httpserver_manager = HttpServerManagerForPDMaster( args=args, ) + elif args.run_mode == "visual": + self.metric_client = MetricClient(args.metric_port) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -136,7 +138,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embedding") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]: @@ -359,6 +373,8 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index d3592a5f54..bfb8bff6db 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -1,7 +1,7 @@ import collections from typing import AsyncGenerator from fastapi import BackgroundTasks, Request -from fastapi.responses import Response, StreamingResponse +from fastapi.responses import Response, StreamingResponse, JSONResponse from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager @@ -150,3 +150,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: + request_dict = await request.json() + # request_dict: {'parameters': {'max_new_tokens': 128}, + # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) + + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808a..c6700c0416 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode == "visual": + visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 0db786d0bf..918ed3c2d0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -5,7 +5,7 @@ import subprocess import signal from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker -from lightllm.utils.start_utils import process_manager, kill_recursive +from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger @@ -15,12 +15,37 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import has_audio_module, has_vision_module logger = init_logger(__name__) +def _ensure_remote_vit_embed_dir(image_embed_dir: str) -> None: + if os.path.exists(image_embed_dir): + if not os.path.isdir(image_embed_dir): + raise ValueError(f"image_embed_dir is not a directory: {image_embed_dir}") + return + + os.makedirs(image_embed_dir, mode=0o777, exist_ok=True) + os.chmod(image_embed_dir, 0o777) + + +def _prepare_remote_vit_embed_dir(args): + remote_vit_mode = args.enable_remote_vit or args.run_mode in ["visual", "visual_only"] + if not remote_vit_mode: + return + + if not args.image_embed_dir: + raise ValueError("remote vit mode requires --image_embed_dir to be set") + + args.image_embed_dir = os.path.abspath(args.image_embed_dir) + _ensure_remote_vit_embed_dir(args.image_embed_dir) + + logger.info(f"using image_embed_dir: {args.image_embed_dir}") + + def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: @@ -57,11 +82,12 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): +def normal_or_p_d_start(args, only_prepare=False): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args @@ -73,7 +99,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "visual", "visual_only"]: return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 @@ -161,6 +187,8 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + args.enable_multimodal = is_multimodal_mode(args) + _prepare_remote_vit_embed_dir(args) # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -222,11 +250,16 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + if only_prepare: + return + already_uesd_ports = [args.port] if args.nccl_port is not None: already_uesd_ports.append(args.nccl_port) if args.pd_decode_rpyc_port is not None: already_uesd_ports.append(args.pd_decode_rpyc_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -234,8 +267,10 @@ def normal_or_p_d_start(args): ports_locker.lock_port() node_world_size = args.tp // args.nnodes + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp can_use_ports = alloc_can_use_network_port( - num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports + num=10 + node_world_size + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_nccl_ports=already_uesd_ports, ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -258,8 +293,12 @@ def normal_or_p_d_start(args): tp_ports_for_dp = can_use_ports[0 : args.visual_tp] visual_model_tp_ports.append(tp_ports_for_dp) can_use_ports = can_use_ports[args.visual_tp :] - visual_nccl_ports.append(can_use_ports[0]) - can_use_ports = can_use_ports[1:] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] # 将申请好的端口放入args参数中 if args.nccl_port is None: @@ -309,27 +348,27 @@ def normal_or_p_d_start(args): start_args=[(args,)], ) - if not args.disable_vision: - from .visualserver.manager import start_visual_process + if not args.disable_audio: + from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( start_funcs=[ - start_visual_process, + start_audio_process, ], start_args=[ - (args, visual_model_tp_ports), + (args,), ], ) - if not args.disable_audio: - from .audioserver.manager import start_audio_process + if not args.disable_vision and not args.enable_remote_vit: + from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( start_funcs=[ - start_audio_process, + start_visual_process, ], start_args=[ - (args,), + (args, visual_model_tp_ports), ], ) @@ -435,7 +474,6 @@ def pd_master_start(args): "-", "--error-logfile", "-", - "--preload", "lightllm.server.api_http:app", "--keep-alive", f"{get_lightllm_gunicorn_keep_alive()}", @@ -452,6 +490,81 @@ def pd_master_start(args): http_server_process.wait() +def visual_start(args): + normal_or_p_d_start(args, only_prepare=True) + + already_uesd_ports = [args.remote_vit_port] + if args.nccl_port is not None: + already_uesd_ports.append(args.nccl_port) + if args.visual_nccl_ports is not None: + already_uesd_ports.extend(args.visual_nccl_ports[: args.visual_dp]) + + need_visual_nccl_ports = 0 if args.visual_nccl_ports is not None else args.visual_dp + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp + need_visual_nccl_ports, + used_nccl_ports=already_uesd_ports, + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_model_tp_ports = [] + visual_nccl_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + visual_model_tp_ports.append(tp_ports_for_dp) + can_use_ports = can_use_ports[args.visual_tp :] + if args.visual_nccl_ports is None: + visual_nccl_ports.append(can_use_ports[0]) + can_use_ports = can_use_ports[1:] + + if args.visual_nccl_ports is not None: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + args.visual_node_id = uuid.uuid4().int + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(args,)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, visual_model_tp_ports), + ], + ) + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": @@ -459,6 +572,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.start_redis: + start_redis_service(args) + set_env_start_args(args) command = [ @@ -473,7 +589,6 @@ def config_server_start(args): "-", "--error-logfile", "-", - "--preload", "lightllm.server.config_server.api_http:app", "--keep-alive", f"{get_lightllm_gunicorn_keep_alive()}", diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index c5505acda4..c55b743480 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -9,6 +9,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger +from lightllm.server.visualserver.vit_connect import VIT_Obj from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args, get_unique_server_name @@ -19,7 +20,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -72,6 +75,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") + registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) + logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + with registered_visual_server_obj_lock: + registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj + + try: + while True: + data = await websocket.receive_text() + assert data == "heartbeat" + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") + with registered_visual_server_obj_lock: + registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None) + return + + @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: @@ -80,6 +107,14 @@ async def get_registered_objects(): return {"data": base64_encoded} +@app.get("/registered_visual_objects") +async def get_vit_registered_objects(): + with registered_visual_server_obj_lock: + serialized_data = pickle.dumps(registered_visual_server_objs) + base64_encoded = base64.b64encode(serialized_data).decode("utf-8") + return {"data": base64_encoded} + + @app.get("/allocate_global_unique_id_range") async def allocate_global_id_range(): """ diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd2562..75f2c0e2f1 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -23,7 +23,9 @@ def to_group_req_index(self): return GroupReqIndexes( group_req_id=self.group_req_id, multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs] + if self.shm_req_objs is not None + else None, time_mark=self.time_mark, ) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py new file mode 100644 index 0000000000..6ef8c66f68 --- /dev/null +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -0,0 +1,74 @@ +import uuid +import threading +import dataclasses +import requests +from typing import Union, Optional +import torch +import time +from collections import deque +import multiprocessing.shared_memory as shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis +from .naive_memory_cache import Record, InMemoryCache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MemoryCacheWithRedis(InMemoryCache): + def __init__(self, args) -> None: + super().__init__(args) + redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" + self.redis_cache = EmbedRefCountRedis( + redis_url=redis_url, + capacity=args.cache_capacity, + evict_fraction=args.redis_evict_fraction, + image_embed_dir=args.image_embed_dir, + ) + # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id + # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 + # 硬盘里的图片image embed 数量。 + self.capacity = max(1, args.cache_capacity * 2) + + # llm 负责release + def release(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + rec = self._records.get(id) + if rec is None: + continue + + redis_exist = self.redis_cache.query(str(id)) + if redis_exist: + self.redis_cache.decr(str(id)) + + # remote_vit 模式下 release 可能走“预层提前释放 + 请求结束兜底释放”两条路径, + # 这里避免本地 ref 被重复减成负数,保证 release 可重复调用。 + if rec.ref > 0: + self._update_record_ref(rec, -1) + + # vit 负责set + def set_items_embed(self, ids: list[int]) -> None: + with self.lock: + for id in ids: + self.redis_cache.insert(str(id)) + rec = self._records.get(id) + if rec is not None: + rec.embed = True + if rec.ref > 0: + self._update_record_ref_by_id(id, -1) + # 保留一份 redis 引用,直到真正的消费者读取完成后再 release, + # 避免 VIT 刚写完文件但 LLM 还没来得及读取时被 LRU 误删。 + + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: + ret = [] + for id in ids: + if embeding_only: + exist = self.redis_cache.query(str(id)) + else: + exist = self.redis_cache.query_and_incre(str(id)) + ret.append(exist) + if exist: + rec = self._records.get(id) + if rec is not None: + rec.embed = True + return ret diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 61ba46d7c6..0788803753 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -36,6 +36,7 @@ class InMemoryCache: def __init__(self, args) -> None: self.args = args self._id_to_records = dict() + self._records = self._id_to_records self._md5_to_record = dict() self._sorted_records = SortedSet(key=lambda x: (x.ref, x.visittime, x.id)) self.capacity = max(1, args.cache_capacity) @@ -125,18 +126,26 @@ def _free_to_alloc(self, free_min_count: int, new_md5_dict: Dict[str, int]) -> D def _add_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] - self._sorted_records.remove(rec) - rec.ref += 1 - self._sorted_records.add(rec) + self._update_record_ref(rec, 1) return def _del_ref(self, md5_sum): rec: Record = self._md5_to_record[md5_sum] + self._update_record_ref(rec, -1) + return + + def _update_record_ref(self, rec: Record, delta: int): self._sorted_records.remove(rec) - rec.ref -= 1 + rec.ref += delta + rec.visittime = time.time() self._sorted_records.add(rec) return + def _update_record_ref_by_id(self, id_: int, delta: int): + rec: Record = self._id_to_records[id_] + self._update_record_ref(rec, delta) + return + def _judge_enough_token_cache(self, md5sum_list: list[str], token_num_list: list[int]) -> bool: tmp_dict = {} for md5, token_num in zip(md5sum_list, token_num_list): @@ -160,14 +169,13 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l new_md5_dict[m] = token_need new_needed = len(new_md5_dict) - alloc_md5_dict = self._free_to_alloc( free_min_count=new_needed - (self.capacity - self.occupied), new_md5_dict=new_md5_dict ) if len(alloc_md5_dict) == len(new_md5_dict): for md5sum, mem_block in alloc_md5_dict.items(): token_num = new_md5_dict[md5sum] - uid_int = uuid.uuid1().int + uid_int = md5sum self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, @@ -207,15 +215,14 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l return results else: + for md5sum in add_ref_m_list: + self._del_ref(md5sum) return None def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: - rec: Record = self._id_to_records[id_] - self._sorted_records.remove(rec) - rec.ref -= 1 - self._sorted_records.add(rec) + self._update_record_ref_by_id(id_, -1) def set_items_data(self, ids: list[int]) -> None: for id_ in ids: @@ -228,5 +235,5 @@ def set_items_embed(self, ids: list[int]) -> None: for id_ in ids: self._id_to_records[id_].embed = True - def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: return [self._id_to_records.get(id_).embed if id_ in self._id_to_records else False for id_ in ids] diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 5de4df4ab3..faf48c4085 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -6,6 +6,7 @@ from lightllm.server.core.objs import StartArgs from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from rpyc.utils.classic import obtain from lightllm.utils.envs_utils import get_unique_server_name @@ -47,9 +48,16 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids) + return self._impl.get_items_embed(ids, embeding_only) + + +def get_cache_manager(args): + if args.enable_remote_vit or args.run_mode == "visual": + return MemoryCacheWithRedis(args) + else: + return InMemoryCache(args) def start_cache_manager(args: StartArgs, pipe_writer): @@ -57,7 +65,7 @@ def start_cache_manager(args: StartArgs, pipe_writer): graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::cache_manager") - manager = InMemoryCache(args) + manager = get_cache_manager(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer import lightllm.utils.rpyc_fix_utils as _ diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 367bcc91a9..caeca0b2b6 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,4 +1,59 @@ +import os +import time +import torch +import redis +import numpy as np +from typing import List, Tuple +from io import BytesIO +from pathlib import Path import multiprocessing.shared_memory as shm +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def _get_afs_path(base_dir: str, name: str) -> Path: + if not base_dir: + raise ValueError("image_embed_dir must be set before using disk-backed embed cache") + return Path(base_dir) / name + + +def tensor2bytes(t: torch.Tensor): + buf = BytesIO() + t = t.detach().cpu() + dest = torch.empty_like(t) + dest.copy_(t) + torch.save(dest, buf, _use_new_zipfile_serialization=False, pickle_protocol=4) + buf.seek(0) + return buf.read() + + +def bytes2tensor(b): + return torch.load(BytesIO(b), weights_only=False) + + +def save_tensor_afs(name: str, tensor: torch.Tensor, base_dir: str) -> None: + target_path = _get_afs_path(base_dir, name) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + + try: + with open(tmp_path, "wb") as f: + torch.save(tensor.detach().cpu(), f, _use_new_zipfile_serialization=False, pickle_protocol=4) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to save embed tensor file: {target_path}") + raise + + +def load_tensor_afs(name: str, base_dir: str) -> torch.Tensor: + path = _get_afs_path(base_dir, name) + with open(path, "rb") as f: + return torch.load(f, weights_only=False) def create_shm(name, data): @@ -11,17 +66,388 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) +def create_afs(name, data, path): + target_path = _get_afs_path(path, name) + data_size = len(data) + tmp_path = target_path.parent / f".{target_path.name}.tmp-{os.getpid()}-{time.time_ns()}" + + try: + with open(tmp_path, "wb") as f: + mem_view = memoryview(data) + f.write(mem_view[:data_size]) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, target_path) + os.chmod(target_path, 0o777) + except Exception: + try: + tmp_path.unlink(missing_ok=True) + except Exception: + pass + logger.exception(f"failed to create embed file: {target_path}") + raise + + def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data +def read_afs(name: str, base_dir) -> bytes: + path = _get_afs_path(base_dir, name) + return path.read_bytes() + + def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() +def free_afs(name: str, base_dir) -> None: + path = _get_afs_path(base_dir, name) + path.unlink(missing_ok=True) + + def get_shm_name_data(uid): return str(uid) + "-data" + + +def get_shm_name_embed(uid): + return str(uid) + "-embed" + + +""" +Importable Redis-backed MD5 refcount with LRU eviction. + +Public API: + from md5_refcount import EmbedRefCountRedis + + cache = EmbedRefCountRedis( + redis_url="redis://localhost:6379/0", + capacity=10000, + evict_fraction=0.2 + ) + + # Insert a new md5 with default ref_count=0 + success, evicted_list = cache.insert(md5) + + # Query if exists and increment ref_count if found + exists = cache.query_and_incre(md5) + + # Decrement ref_count + rc, deleted = cache.decr(md5) + + s = cache.stats() +""" + + +class EmbedRefCountRedis: + def __init__( + self, + redis_url: str = "redis://localhost:6379/0", + capacity: int = 50000, + evict_fraction: float = 0.1, + key_prefix: str = "md5:", + image_embed_dir: str = None, + path_ext: str = "-embed", + **redis_kwargs, + ) -> None: + """ + - capacity: max count of md5 entries allowed in Redis + - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity + - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds") + - path_ext: file extension for embed files (default: "-embed") + """ + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 0..1") + if capacity < 1: + raise ValueError("capacity must be >=1") + + self.capacity = int(capacity) + self.evict_fraction = float(evict_fraction) + self.zset_key = f"{key_prefix}lru" + self.ref_prefix = f"{key_prefix}rc:" + self.lock_key = f"{key_prefix}evict:lock" + self.image_embed_dir = image_embed_dir + self.path_ext = path_ext + + self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs) + + # Register Lua scripts + self._insert_script = self.r.register_script(self._INSERT_LUA) + self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA) + self._decr_script = self.r.register_script(self._DECR_LUA) + self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA) + + def insert(self, md5: str) -> Tuple[bool, List[str]]: + """Insert a new md5 with default ref_count=1. May trigger LRU eviction.""" + # 等待任何正在进行的逐出操作 + self._wait_if_eviction() + + res = self._insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + + if res[0] == 0: # No eviction needed + return True, [] + + # Need eviction - use atomic eviction script + try: + if self._try_acquire_lock(): + try: + # 原子执行逐出和插入 + evict_res = self._evict_and_insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + success = bool(evict_res[0]) + victims = evict_res[1:] if len(evict_res) > 1 else [] + + if success: + # 删除被逐出md5对应的AFS文件 + if victims and self.image_embed_dir: + self._delete_afs_files(victims) + return True, victims + else: + # 逐出失败,短暂退避后重试 + time.sleep(0.01) + return self.insert(md5) + finally: + self._release_lock() + else: + # 等待锁释放后重试 + time.sleep(0.01) + return self.insert(md5) + except Exception as e: + self._release_lock() + raise e + + def query(self, md5: str) -> bool: + """Quert if md5 exists.""" + self._wait_if_eviction() + return bool(self.r.exists(self.ref_prefix + md5)) + + def query_and_incre(self, md5: str) -> bool: + """Query if md5 exists and increment ref_count if found.""" + self._wait_if_eviction() + res = self._query_incre_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + return bool(res[0]) + + def decr(self, md5: str) -> Tuple[int, bool]: + """Decrement ref_count for md5. Returns (ref_count, deleted).""" + self._wait_if_eviction() + + res = self._decr_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + if res[0] == -1: + raise KeyError("md5 not found") + return int(res[0]), bool(res[1]) + + def stats(self) -> dict: + self._wait_if_eviction() + + size = self.r.zcard(self.zset_key) + return { + "items": size, + "capacity": self.capacity, + "evict_fraction": self.evict_fraction, + } + + def get_ref(self, md5: str) -> int | None: + self._wait_if_eviction() + val = self.r.get(self.ref_prefix + md5) + return int(val) if val is not None else None + + def _wait_if_eviction(self) -> None: + max_wait = 30 + start_time = time.time() + + while self.r.exists(self.lock_key): + if time.time() - start_time > max_wait: + raise TimeoutError("Eviction operation timeout, waited too long") + time.sleep(0.01) # 短暂等待 + + def _try_acquire_lock(self) -> bool: + return bool(self.r.set(self.lock_key, "1", nx=True, ex=30)) + + def _release_lock(self) -> None: + try: + self.r.delete(self.lock_key) + except Exception: + pass + + def _md5_to_afs_path(self, md5: str) -> str: + """Convert md5 to AFS file path.""" + if not self.image_embed_dir: + return None + return str(_get_afs_path(self.image_embed_dir, f"{md5}{self.path_ext}")) + + def _delete_afs_files(self, victims: List[str]) -> None: + """Delete AFS files for evicted md5s.""" + if not self.image_embed_dir: + return + + for md5 in victims: + try: + file_path = self._md5_to_afs_path(md5) + if file_path and os.path.exists(file_path): + os.remove(file_path) + logger.debug(f"Deleted AFS file: {file_path}") + except Exception as e: + logger.debug(f"Warning: Failed to delete AFS file for {md5}: {e}") + + # ---------------- Lua scripts ---------------- + _INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) + +local unpack = unpack or table.unpack +local ref_key = ref_prefix .. md5 +if redis.call('GET', ref_key) then + return {0} -- Already exists +end + +local size = redis.call('ZCARD', zset) +if size < capacity then + -- Insert with ref_count=1 + redis.call('SET', ref_key, 1) + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) + return {0} -- Success, no eviction +end + +return {1} -- Need eviction +""" + + _QUERY_INCRE_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. md5 +local val = redis.call('GET', ref_key) + +if not val then + return {0} -- Not found +end + +-- Found, increment ref_count and update LRU +local rc = tonumber(val) + 1 +redis.call('SET', ref_key, rc) +local now = redis.call('TIME')[1] * 1000 +redis.call('ZADD', zset, now, md5) +return {1} -- Found and incremented +""" + + _DECR_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. md5 +local val = redis.call('GET', ref_key) + +if not val then + return {-1, 0} -- Not found +end + +--ref 递减到 0 时保留键,只更新计数与 LRU +local rc = tonumber(val) - 1 +if rc < 0 then rc = 0 end +redis.call('SET', ref_key, rc) + +if rc > 0 then + -- 只有仍被引用时才更新 LRU + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) +end + +return {rc, 0} +""" + + _EVICT_AND_INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local new_md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) +local evict_fraction = tonumber(ARGV[3]) + +local unpack = unpack or table.unpack + +-- helper: now millis +local function now_ms() + local t = redis.call('TIME') + return t[1] * 1000 + math.floor(t[2] / 1000) +end + +local new_ref_key = ref_prefix .. new_md5 + +-- If already exists, treat as a hit: bump ref_count and refresh LRU +local cur = redis.call('GET', new_ref_key) +if cur then + local rc = tonumber(cur) + 1 + redis.call('SET', new_ref_key, rc) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- If not at capacity, just insert +local size = redis.call('ZCARD', zset) +if size < capacity then + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed +local max_try = math.max(1, math.floor(size * evict_fraction + 0.5)) +local victims = {} +local freed = 0 + +-- Scan from LRU (smallest score) to MRU +local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') +local i = 1 +while freed < 1 and i <= #all_keys and #victims < max_try do + local md5 = all_keys[i] + local ref_key = ref_prefix .. md5 + local v = redis.call('GET', ref_key) + if v and tonumber(v) <= 0 then + table.insert(victims, md5) + freed = freed + 1 + end + i = i + 2 -- skip score +end + +if freed >= 1 then + -- delete victims + for _, v in ipairs(victims) do + redis.call('DEL', ref_prefix .. v) + redis.call('ZREM', zset, v) + end + -- insert new + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1, unpack(victims)} +else + -- no zero-ref items found + return {0} +end +""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6481098eb9..0ee3ce03dd 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -82,15 +82,10 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + # 初始化VIT连接管理器 + from lightllm.server.visualserver.vit_connect import VITConnectionManager - if not self.args.disable_vision: - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") - - if not self.args.disable_audio: - self.send_to_audio = context.socket(zmq.PUSH) - self.send_to_audio.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") - + self.vit_manager = VITConnectionManager(args, context, args.visual_port, self.cache_client) if args.enable_cpu_cache and not self.args.enable_multimodal: self.send_to_multi_level_kv_cache = context.socket(zmq.PUSH) self.send_to_multi_level_kv_cache.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") @@ -124,10 +119,10 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _alloc_resource(self, items, md5sums, token_nums, datas): + async def _alloc_resource(self, items, uuids, token_nums, datas): while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) + records = obtain(self.cache_client.root.alloc(uuids, token_nums)) if records is None: await asyncio.sleep(0.1) @@ -147,6 +142,12 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): uid_list.append(rec["id"]) + # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server + if self.args.enable_remote_vit: + # 避免远端lru被逐出 + self.cache_client.root.get_items_embed(uid_list, False) + return + ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -166,14 +167,15 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] + items, uuids, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) - md5sums.append(md5sum) + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -181,13 +183,17 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) - md5sums.append(md5sum) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), + ) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(audio) - await self._alloc_resource(items, md5sums, tokens_nums, datas) + await self._alloc_resource(items, uuids, tokens_nums, datas) return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): @@ -408,6 +414,48 @@ async def generate( raise e return + async def get_image_embeding( + self, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + is_health_req: bool = False, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + try: + original_multimodal_params = None + if self.is_multinode_tp_master: + original_multimodal_params = copy.deepcopy(multimodal_params) + + await multimodal_params.verify_and_preload(request) + image_count = len(multimodal_params.images) + # 记录请求到达的相关信息 + + await self._log_req_header(request_headers, group_request_id) + logger.info(f"image_count:{image_count}") + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" + + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + + visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) + + await self.transfer_to_next_module_or_node( + None, sampling_params, original_multimodal_params, visual_req_status + ) + await self._release_multimodal_resources(multimodal_params) + + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e + return + def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 @@ -527,13 +575,11 @@ async def transfer_to_next_module( ): if self.pd_mode.is_P_or_NORMAL(): - if not self.args.disable_vision: - self.send_to_visual.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) - return - - if not self.args.disable_audio: - self.send_to_audio.send_pyobj(group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL) - return + if self.enable_multimodal: + await self.vit_manager.send_to_vit( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) if self.args.enable_cpu_cache: self.send_to_multi_level_kv_cache.send_pyobj( @@ -542,10 +588,11 @@ async def transfer_to_next_module( ) return - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) + if not self.enable_multimodal or self.args.enable_remote_vit: + self.send_to_router.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) return if self.pd_mode.is_D(): @@ -742,6 +789,9 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) + if self.enable_multimodal: + asyncio.create_task(self.vit_manager.vit_handle_loop()) + while True: try: await asyncio.wait_for(self.zmq_recv_socket.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 09a07455b3..c86a9a0786 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -26,6 +26,7 @@ def __init__(self, **kwargs): self.token_num = None # the audio length self.audio_length = None + self.afs_embed = False self._preload_data = None self.extra_params = {} @@ -54,10 +55,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -95,6 +97,7 @@ def __init__(self, **kwargs): self.grid_thwd = None self.image_w = 0 self.image_h = 0 + self.patch_num = 0 self._preload_data = None self.extra_params = {} @@ -128,10 +131,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -163,6 +167,15 @@ def __init__( self.audios = [AudioItem(**a) for a in audios] return + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + + def get_all_uuids(self): + return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios] + async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 1b4a1ca5cb..3aee370eb1 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -59,8 +59,8 @@ def register( self.vocab_size = vocab_size return - def init_cpu_embed_cache_client(self): - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=False) + def init_cpu_embed_cache_client(self, init_shm_data: bool = False): + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=init_shm_data) return def get_overlap_stream(self) -> torch.cuda.Stream: diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..8b0522d58e 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -145,7 +145,7 @@ def init_model(self, kvargs): wait_events.append(self.multi_level_cache_module) if self.args.enable_multimodal: - g_infer_context.init_cpu_embed_cache_client() + g_infer_context.init_cpu_embed_cache_client(init_shm_data=self.args.enable_remote_vit) model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8fba9f08d7..508860e899 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -20,7 +20,7 @@ from lightllm.utils.process_check import start_parent_check_thread from lightllm.utils.envs_utils import get_unique_server_name from rpyc.utils.classic import obtain - +from lightllm.server.embed_cache.utils import create_shm, get_shm_name_data logger = init_logger(__name__) @@ -31,13 +31,16 @@ def __init__( args: StartArgs, visual_model_rpc_ports, ): + self.args = args + self.visual_only = args.run_mode in ["visual", "visual_only"] + self.remote_vit = args.enable_remote_vit or self.visual_only + context = zmq.Context(2) - enable_audio = not args.disable_audio - if enable_audio: - self.send_to_next_module = context.socket(zmq.PUSH) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") - else: - if args.enable_cpu_cache: + if not self.visual_only: + if not args.disable_audio: + self.send_to_next_module = context.socket(zmq.PUSH) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.audio_port}") + elif args.enable_cpu_cache: self.send_to_next_module = context.socket(zmq.PUSH) self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.multi_level_kv_cache_port}") else: @@ -45,7 +48,11 @@ def __init__( self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{args.router_port}") self.zmq_recv_socket = context.socket(zmq.PULL) - self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + if self.remote_vit: + self.zmq_recv_socket.bind(f"tcp://*:{args.remote_vit_port}") + else: + self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") + self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.cache_port = args.cache_port @@ -56,13 +63,11 @@ def __init__( self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.send_batch_size = args.visual_send_batch_size self.shm_req_manager = ShmReqManager() async def wait_to_model_ready(self): - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] self.vit_attn_backend = init_vit_att_backend(index=0) for dp_rank_id in range(self.vit_dp): @@ -146,13 +151,12 @@ def flush_ready(force: bool = False): continue multimodal_params = group_req_indexes.multimodal_params - img_uuids = [img.uuid for img in multimodal_params.images] # disable prompt cache通常用来测试,需要也去掉image cache的影响 if disable_prompt_cache: ready_image = [False] * len(img_uuids) else: - ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) + ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids, True)) for img, ready in zip(multimodal_params.images, ready_image): if not ready: @@ -180,6 +184,43 @@ def flush_ready(force: bool = False): processing_group_reqs = [] flush_ready(force=True) + async def _recv_reqs(self): + recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + if not self.remote_vit: + return recv_req + + uuids = [img.uuid for img in recv_req.multimodal_params.images] + already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids, True) + if all(already_embed): + return None + + missing_uuids = [] + token_nums = [] + datas = [] + for img, embed_ready in zip(recv_req.multimodal_params.images, already_embed): + if embed_ready: + continue + missing_uuids.append(img.uuid) + token_nums.append(img.token_num) + datas.append(img.read()) + img.free() + + while True: + if await asyncio.to_thread(self.cache_client.root.alloc, missing_uuids, token_nums) is not None: + break + await asyncio.sleep(0.01) + + ready_flags = obtain(self.cache_client.root.get_items_data(missing_uuids)) + update_data_ids = [] + for uid, ready, data in zip(missing_uuids, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) + return recv_req + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -187,7 +228,9 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.zmq_recv_socket.recv_pyobj(zmq.NOBLOCK) + recv_req = await self._recv_reqs() + if recv_req is None: + continue if isinstance(recv_req, GroupReqIndexes): logger.info( f"visual recv req id {recv_req.group_req_id} " @@ -196,12 +239,31 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = int(min(self.visual_recv_max_count * 1.3, 256)) + self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 await asyncio.sleep(0.01) + async def loop_for_fwd_visual_only(self): + while True: + if len(self.waiting_reqs) == 0: + await asyncio.sleep(0.01) + continue + + images_need_infer = [] + while len(self.waiting_reqs) > 0: + visual_req = self.waiting_reqs.pop(0) + for img in visual_req.multimodal_params.images: + images_need_infer.append(img) + if len(images_need_infer) == self.infer_batch_size: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + + if len(images_need_infer) > 0: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -210,17 +272,29 @@ def clean_up(self): return +def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): + if args.run_mode in ["visual", "visual_only"]: + from .register_loop import register_loop + + loop.create_task(visualserver.loop_for_fwd_visual_only()) + loop.create_task(register_loop(args)) + else: + loop.create_task(visualserver.loop_for_fwd()) + + def start_visual_process(args, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) setproctitle.setproctitle(f"lightllm::{get_unique_server_name()}::visual_server") start_parent_check_thread() + visualserver = None try: visualserver = VisualManager(args=args, visual_model_rpc_ports=model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: logger.exception(str(e)) - visualserver.clean_up() + if visualserver is not None: + visualserver.clean_up() raise e pipe_writer.send("init ok") @@ -231,6 +305,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(args, visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..237e52ce6f 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -6,25 +6,27 @@ import inspect from datetime import timedelta from typing import Dict, List, Tuple -from transformers.configuration_utils import PretrainedConfig from rpyc.utils.classic import obtain from rpyc.utils.server import ThreadedServer +from transformers.configuration_utils import PretrainedConfig + from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer -from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel +from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel +from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.models.qwen3_omni_moe_thinker.qwen3_omni_visual import Qwen3OmniMoeVisionTransformerPretrainedModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient +from lightllm.server.embed_cache.utils import create_afs, get_shm_name_embed, tensor2bytes, save_tensor_afs from lightllm.server.visualserver import set_vit_att_backend @@ -34,6 +36,7 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist + self.args = get_env_start_args() self.vit_dp = kvargs["vit_dp"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -41,6 +44,11 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] + self.image_embed_dir = self.args.image_embed_dir + self.remote_vit = self.args.enable_remote_vit or self.args.run_mode in ["visual", "visual_only"] + if self.remote_vit and not self.image_embed_dir: + raise ValueError("remote vit mode requires image_embed_dir") + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self.data_type = kvargs["data_type"] @@ -56,6 +64,7 @@ def exposed_init_model(self, kvargs): "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], "max_batch_size": kvargs["max_batch_size"], + "remote_vit": self.remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -92,10 +101,10 @@ def exposed_init_model(self, kvargs): ) else: raise Exception(f"can not support {self.model_type} now") - self.model.load_model(weight_dir) self.model = self.model.cuda() - self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) + if not self.remote_vit: + self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True) except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -116,33 +125,47 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cuda")) - - if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - image = images[i] + + if self.tp_rank_id != 0: + return + + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids, True)) + ids_to_set = [] + cpu_embeds = None + if self.remote_vit: + cpu_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) + + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + image = images[i] + if self.remote_vit: + save_tensor_afs(get_shm_name_embed(uid), cpu_embeds[start:end], self.image_embed_dir) + else: self.cpu_embed_cache_client.copy_vision_to_cache( - embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache + embed_tensor=all_img_embeds[start:end], + start_index_in_cache=image.start_index_in_embed_cache, ) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) + ids_to_set.append(uid) + + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) + if not self.remote_vit: torch.cuda.current_stream().synchronize() return class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc + def __init__(self, conn, vit_tp, rpc_server_process=None): + self.conn = conn + self.model: VisualModelRpcServer = conn.root self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process self.use_rpc = True + self._bg = rpyc.BgServingThread(self.conn) + if self.use_rpc: def async_wrap(f): @@ -161,15 +184,12 @@ async def _func(*args, **kwargs): else: self._init_model = self.model.exposed_init_model self._encode = self.model.exposed_encode - return async def init_model(self, kvargs): ans: rpyc.AsyncResult = self._init_model(kvargs) if self.use_rpc: await ans return - else: - return async def encode(self, images: List[ImageItem]): ans = self._encode(images) @@ -215,4 +235,4 @@ async def start_model_process(port, vit_tp, device_id): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc) diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py new file mode 100644 index 0000000000..31d0f7b8ac --- /dev/null +++ b/lightllm/server/visualserver/register_loop.py @@ -0,0 +1,42 @@ +import asyncio +import pickle +import websockets +import socket +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.log_utils import init_logger +from .vit_connect import VIT_Obj + +logger = init_logger(__name__) + + +async def register_loop(args): + assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" + + if args.host in ["0.0.0.0"]: + host_ip = get_hostname_ip() + else: + host_ip = args.host + + while True: + + try: + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" + async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: + + sock = websocket.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}") + + await websocket.send(pickle.dumps(vit_obj)) + logger.info(f"Sent registration vit_obj: {vit_obj}") + + while True: + await websocket.send("heartbeat") + await asyncio.sleep(40) + + except Exception as e: + logger.error("connetion to config_server has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to config_server") diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py new file mode 100644 index 0000000000..af93effaf0 --- /dev/null +++ b/lightllm/server/visualserver/vit_connect.py @@ -0,0 +1,236 @@ +import asyncio +import zmq +import zmq.asyncio +import time +import pickle +from typing import Dict, List, Optional, Any +from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.io_objs import GroupReqObjs, GroupReqIndexes +from lightllm.server.multimodal_params import MultimodalParams +import httpx +import base64 +from dataclasses import dataclass +import rpyc + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip_port: str + + def to_log_str(self): + return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}" + + +class VITConnectionManager: + """VIT连接管理器""" + + def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection): + self.args = args + self.context = context + self.local_visual_port = local_visual_port + + self.send_to_visual = None + self.remote_vit_instances = {} + self.current_vit_index = 0 + self.remote_vit = args.enable_remote_vit + self.remote_vit_port = args.remote_vit_port + self.cache_client = cache_client + + self._setup_vit_connections() + + def _setup_vit_connections(self): + """ + 设置VIT连接,支持本地和远程VIT实例 + 支持多种连接模式: + 1. 本地VIT实例 (默认) + 2. 远程多个VIT实例 (负载均衡) + """ + if self.remote_vit: + # 远程VIT实例模式 + self._setup_remote_vit_connections() + else: + self._setup_local_vit_connection() + + def _setup_local_vit_connection(self): + self.send_to_visual = self.context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + + def _setup_remote_vit_connections(self): + """ + 初始化远程VIT连接,同步获取初始实例 + """ + logger.info("Setting up remote VIT connections...") + + self._sync_init_vit_instances() + + retry_count = 0 + max_retries = 30 # 最多等待30秒 + while len(self.remote_vit_instances) == 0 and retry_count < max_retries: + logger.info(f"Waiting for VIT instances... (attempt {retry_count + 1}/{max_retries})") + time.sleep(1) + retry_count += 1 + self._sync_init_vit_instances() + + if len(self.remote_vit_instances) == 0: + logger.warning("No VIT instances available after initialization") + else: + logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances") + + def _sync_init_vit_instances(self): + """ + 同步初始化VIT实例连接 + """ + try: + # 使用同步方式获取VIT实例 + vit_objs = self._sync_get_vit_objs() + if vit_objs: + self._update_vit_connections(vit_objs) + except Exception as e: + logger.error(f"Failed to initialize VIT instances: {e}") + + def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 同步获取VIT实例信息 + """ + import requests + + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + response = requests.get(uri, timeout=10) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.error(f"Error getting VIT instances: {e}") + return None + + def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): + """ + 更新VIT连接,添加新的连接,关闭失效的连接 + """ + # 关闭不再存在的连接 + closed_ids = [] + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance.close() + except: + pass + closed_ids.append(id) + logger.info(f"Closed VIT connection {id}") + + for id in closed_ids: + self.remote_vit_instances.pop(id) + + # 建立新的连接 + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + try: + socket = self.context.socket(zmq.PUSH) + # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) + ip, port = vit_obj.host_ip_port.split(":") + socket.connect(f"tcp://{ip}:{port}") + self.remote_vit_instances[id] = socket + logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") + except Exception as e: + logger.error(f"Failed to connect to VIT instance {id}: {e}") + + def _get_vit_instance(self): + """ + 获取下一个可用的VIT实例 (轮询负载均衡) + """ + if not self.remote_vit: + return self.send_to_visual + + if len(self.remote_vit_instances) == 0: + raise Exception("No available VIT instances") + + # 简单的轮询负载均衡 + index = (self.current_vit_index + 1) % len(self.remote_vit_instances) + self.current_vit_index = index + return list(self.remote_vit_instances.values())[index] + + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL): + """ + 发送数据到VIT实例,支持本地和远程模式 + """ + instance = self._get_vit_instance() + # 本地模式下,提前释放图片资源,降低传输开销 + if not self.remote_vit: + req.multimodal_params.free() + + try: + instance.send_pyobj(req, protocol=protocol) + except Exception as e: + logger.error(f"Failed to send to VIT instance: {e}") + raise Exception(f"Failed to send to VIT instance: {e}") + + # 远程模式下,发送完以后,在释放图片资源 + await self._wait_visual_embed_ready(req) + if self.remote_vit: + req.multimodal_params.free() + + async def vit_handle_loop(self): + """ + 异步VIT连接管理循环,由外部启动 + """ + if not self.remote_vit: + return + logger.info("Starting VIT connection management loop") + while True: + try: + id_to_vit_obj = await self._async_get_vit_objs() + if id_to_vit_obj: + self._update_vit_connections(id_to_vit_obj) + await asyncio.sleep(30) + except Exception as e: + logger.exception(f"Error in VIT handle loop: {e}") + await asyncio.sleep(10) + + async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 异步获取VIT实例信息 + """ + uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + async with httpx.AsyncClient() as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.exception(f"Error getting VIT instances: {e}") + return None + + async def _wait_visual_embed_ready( + self, + req: GroupReqIndexes, + timeout_seconds: int = 1000, + ): + # 本地模式不需要等待 + if not self.remote_vit: + return + uuids = req.multimodal_params.get_all_uuids() + + async def wait_for_embeds(): + while not all(self.cache_client.root.get_items_embed(uuids, True)): + await asyncio.sleep(0.01) + + try: + await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) + except asyncio.TimeoutError: + logger.error( + f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" + ) diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py new file mode 100644 index 0000000000..acc4deb589 --- /dev/null +++ b/lightllm/utils/redis_utils.py @@ -0,0 +1,74 @@ +import subprocess +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def start_redis_service(args): + """launch redis service""" + if not hasattr(args, "start_redis") or not args.start_redis: + return None + + config_server_host = args.config_server_host + redis_port = args.redis_port + try: + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2 + ) + subprocess.run( + ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2 + ) + except Exception: + pass + + try: + redis_command = [ + "redis-server", + "--port", + str(redis_port), + "--bind", + f"{config_server_host}", + "--daemonize", + "no", + "--logfile", + "/dev/stdout", + "--loglevel", + "notice", + "--save", + '""', # 不触发 RDB 快照 + "--appendonly", + "no", # 关闭 AOF + ] + + logger.info(f"Starting Redis service on port {redis_port}") + redis_process = subprocess.Popen(redis_command) + + import redis + import time + + max_wait = 10 + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1) + r.ping() + logger.info(f"Redis service started successfully on port {redis_port}") + del r + break + except Exception: + time.sleep(0.5) + if redis_process.poll() is not None: + logger.error("Redis service failed to start") + return None + else: + logger.error("Redis service startup timeout") + if redis_process.poll() is None: + redis_process.terminate() + return None + + return redis_process + + except Exception as e: + logger.error(f"Failed to start Redis service: {e}") + return None diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py index 372b7e1cfa..8245431084 100644 --- a/lightllm/utils/start_utils.py +++ b/lightllm/utils/start_utils.py @@ -111,4 +111,12 @@ def kill_recursive(proc): logger.warning(f"Process {proc.pid} does not exist.") +def is_multimodal_mode(args): + from transformers import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir) + is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg + return is_multimodal + + process_manager = SubmoduleManager() diff --git a/requirements.txt b/requirements.txt index 5b0b201ae3..5331227586 100644 --- a/requirements.txt +++ b/requirements.txt @@ -94,3 +94,5 @@ partial_json_parser==0.2.1.1.post6 websockets==15.0.1 cupy-cuda12x==13.6.0 nixl==0.8.0 +xformers==0.0.33.post2 +redis==7.3.0