Vit sep #1234 (Open)

41 changes: 9 additions & 32 deletions lightllm/common/basemodel/attention_vit/fa3/fp.py
@@ -1,6 +1,7 @@
import dataclasses
import torch
from lightllm.common.basemodel.attention_vit.base_att import BaseVitAttBackend
from lightllm.utils.sgl_utils import flash_attn_varlen_func


class Fa3VitAttBackend(BaseVitAttBackend):
@@ -17,42 +18,18 @@ def _vit_att_fwd(
head_dim = q.shape[-1]
softmax_scale = head_dim ** -0.5
window_size = (-1, -1)
torch.ops.sgl_kernel.fwd.default(
o = flash_attn_varlen_func(
q,
k,
v,
None, # k_new
None, # v_new
None, # qv
o, # out
cu_seqlens,
cu_seqlens,
None, # cu_seqlens_k_new
None,
None,
max_seqlen,
max_seqlen,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
None,
None,
None,
softmax_scale,
False,
window_size[0],
window_size[1],
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
softmax_scale=softmax_scale,
causal=False,
window_size=window_size,
attention_chunk=0,
softcap=0.0,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0,
sinks=None,
)

return o
13 changes: 7 additions & 6 deletions lightllm/models/internvl/model.py
@@ -68,21 +68,22 @@ def init_imageitem_extral_params(
img.extra_params["image_patch_max_num"] = 6
elif num_images > 6:
img.extra_params["image_patch_max_num"] = 0
img.patch_num = self.get_image_patch(img)
return

def init_audioitem_extral_params(
self, audio: AudioItem, multi_params: MultimodalParams, sampling_params: SamplingParams
):
return

def get_image_token_length(self, img: ImageItem):
return (
self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)
* self.image_length
def get_image_patch(self, img: ImageItem):
return self.get_image_patch_func(
img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True
)

def get_image_token_length(self, img: ImageItem):
return self.get_image_patch(img) * self.image_length

def get_audio_token_length(self, audio: AudioItem):
L = audio.audio_length
audio_token_num = 0
36 changes: 35 additions & 1 deletion lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py
@@ -1,11 +1,15 @@
import rpyc
import socket
import torch
import torch.distributed as dist

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.server.embed_cache.utils import get_shm_name_embed, load_tensor_afs
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce
from lightllm.utils.envs_utils import get_env_start_args


"""
@@ -26,24 +30,41 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
def __init__(self, network_config):
super().__init__(network_config)
self.args = get_env_start_args()
self.cache_client = None
if self.args.enable_remote_vit:
self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
return

def _copy_loaded_embed_to_cache(
self, embed_tensor: torch.Tensor, cpu_embed_cache_tensor: torch.Tensor, start_index: int
):
if embed_tensor.ndim == 2:
embed_tensor = embed_tensor.unsqueeze(1)

token_num, layer_num, hidden_size = embed_tensor.shape
cpu_embed_cache_tensor[start_index : start_index + token_num, :layer_num, :hidden_size].copy_(embed_tensor)
return

def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
img_start_token_ids = []
img_token_lens = []
img_start_locs_in_cache = []
unique_uids = []
device = layer_weight.wte_weight_.weight.device
dtype = layer_weight.wte_weight_.weight.dtype
hidden_size = layer_weight.wte_weight_.weight.shape[1]

for batch_id, p in enumerate(infer_state.multimodal_params):
for _, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
# skip the same image
if img["token_id"] in img_start_token_ids:
continue
img_start_token_ids.append(img["token_id"])
img_token_lens.append(img["token_num"])
img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
unique_uids.append(img["uuid"])
out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device)

from lightllm.server.router.model_infer.infer_batch import g_infer_context
@@ -55,6 +76,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
else cpu_embed_cache_client.cpu_embed_cache_tensor
)

if self.args.enable_remote_vit:
release_ids = []
for _, p in enumerate(infer_state.multimodal_params):
for img in p["images"] + p["audios"]:
release_ids.append(img["uuid"])
Review comment from a Contributor on lines +81 to +83 (severity: medium):

This loop over infer_state.multimodal_params is redundant. You can collect release_ids in the first loop over the same data structure (lines 59-67) to improve efficiency and code clarity; this avoids iterating over the same data twice. You can add release_ids.append(img["uuid"]) to the first loop when self.args.enable_remote_vit is true. (A sketch of the merged loop follows this file's diff.)

for uid, start_index_in_embed_cache in zip(unique_uids, img_start_locs_in_cache):
embed_tensor = load_tensor_afs(get_shm_name_embed(uid), self.args.image_embed_dir)
self._copy_loaded_embed_to_cache(embed_tensor, cpu_embed_cache_tensor, start_index_in_embed_cache)

if release_ids:
self.cache_client.root.release(release_ids)

assert cpu_embed_cache_tensor.shape[2] == hidden_size, (
f"Dimension mismatch: text weight dimension is {hidden_size}, "
f"but image embed dimension is {cpu_embed_cache_tensor.shape[2]}"
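
Following the review comment above, a minimal sketch of the merged loop (illustrative only, not code from this PR). Variable names follow the diff; release_ids keeps every uuid, duplicates included, matching the second loop it would replace:

# Sketch of the reviewer's suggestion: collect release_ids during the first pass
# over infer_state.multimodal_params so the structure is iterated only once.
release_ids = []
for p in infer_state.multimodal_params:
    for img in p["images"] + p["audios"]:
        if self.args.enable_remote_vit:
            release_ids.append(img["uuid"])
        # skip the same image
        if img["token_id"] in img_start_token_ids:
            continue
        img_start_token_ids.append(img["token_id"])
        img_token_lens.append(img["token_num"])
        img_start_locs_in_cache.append(img["start_index_in_embed_cache"])
        unique_uids.append(img["uuid"])
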
44 changes: 10 additions & 34 deletions lightllm/models/vit/triton_kernel/flashattention_nopad.py
@@ -167,44 +167,20 @@ def flash_attention_v3_fwd(
head_dim = q.shape[-1]
softmax_scale = head_dim ** -0.5
window_size = (-1, -1)
torch.ops.sgl_kernel.fwd.default(
o = flash_attn_varlen_func(
q,
k,
v,
None, # k_new
None, # v_new
None, # qv
o, # out
cu_seqlens,
cu_seqlens,
None, # cu_seqlens_k_new
None,
None,
max_seqlen,
max_seqlen,
None, # page_table,
None, # kv_batch_idx
None, # leftpad_k
None, # rotary cos
None, # rotary sin
None, # seqlens_rotary
None,
None,
None,
softmax_scale,
False,
window_size[0],
window_size[1],
0.0,
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=1,
pack_gqa=None,
sm_margin=0,
sinks=None,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
softmax_scale=softmax_scale,
causal=False,
window_size=window_size,
softcap=0.0,
)

return
return o

except ImportError:
print("Failed to import _flash_attn_forward from hopper.flash_attn_interface.")
46 changes: 45 additions & 1 deletion lightllm/server/api_cli.py
@@ -7,7 +7,17 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--run_mode",
type=str,
choices=["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"],
choices=[
"normal",
"prefill",
"decode",
"nixl_prefill",
"nixl_decode",
"pd_master",
"config_server",
"visual",
"visual_only",
],
default="normal",
help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode,
config_server is for pd split mode used to register pd_master node, and get pd_master node list,
@@ -593,6 +603,40 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=0.03,
help="""The interval of the schedule time, default is 30ms.""",
)
parser.add_argument(
"--image_embed_dir",
type=str,
default=None,
help="path for vit embed",
)
parser.add_argument(
"--enable_remote_vit",
action="store_true",
help="Whether to enable remote vit for multimodal service.",
)
parser.add_argument(
"--remote_vit_port",
type=int,
default=12346,
help="The port number for the remote vit service.",
)
parser.add_argument(
"--redis_port",
type=int,
default=6379,
help="The port number for the redis service in config_server mode.",
)
parser.add_argument(
"--redis_evict_fraction",
type=float,
default=0.3,
help="The evict fraction for the redis service in config_server mode.",
)
parser.add_argument(
"--start_redis",
action="store_true",
help="Whether to start the redis service in config_server mode.",
)
parser.add_argument(
"--enable_cpu_cache",
action="store_true",
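
A quick sanity-check sketch for the new CLI flags, using the parser defined in this file. The flag names come from this diff; the embed directory path is a placeholder, and other options a real launch needs (for example --model_dir) are omitted, so treat this as an assumed usage example rather than a full launch command:

# Hypothetical: parse the new remote-ViT flags with the existing argument parser.
from lightllm.server.api_cli import make_argument_parser

parser = make_argument_parser()
args = parser.parse_args(
    [
        "--run_mode", "visual",
        "--enable_remote_vit",
        "--remote_vit_port", "12346",
        "--image_embed_dir", "/tmp/vit_embeds",  # placeholder directory
    ]
)
print(args.run_mode, args.enable_remote_vit, args.remote_vit_port, args.image_embed_dir)
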
20 changes: 18 additions & 2 deletions lightllm/server/api_http.py
@@ -43,7 +43,7 @@
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster
from .api_lightllm import lightllm_get_score
from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding
from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size
from lightllm.utils.log_utils import init_logger
from lightllm.utils.error_utils import ServerBusyError
@@ -92,6 +92,8 @@ def set_args(self, args: StartArgs):
self.httpserver_manager = HttpServerManagerForPDMaster(
args=args,
)
elif args.run_mode == "visual":
self.metric_client = MetricClient(args.metric_port)
else:
init_tokenizer(args) # for openai api
SamplingParams.load_generation_cfg(args.model_dir)
@@ -136,7 +138,7 @@ def get_model_name():
@app.get("/health", summary="Check server health")
@app.head("/health", summary="Check server health")
async def healthcheck(request: Request):
if g_objs.args.run_mode == "pd_master":
if g_objs.args.run_mode in ["pd_master", "visual"]:
return JSONResponse({"message": "Ok"}, status_code=200)

if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true":
@@ -221,6 +223,18 @@ async def get_score(request: Request) -> Response:
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/get_image_embedding")
async def get_image_embed(request: Request) -> Response:
try:
return await lightllm_get_image_embedding(request, g_objs.httpserver_manager)
except ServerBusyError as e:
logger.error("%s", str(e), exc_info=True)
return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e))
except Exception as e:
logger.error("An error occurred: %s", str(e), exc_info=True)
return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e))


@app.post("/")
async def compat_generate(request: Request) -> Response:
if get_env_start_args().run_mode in ["prefill", "decode", "nixl_prefill", "nixl_decode"]:
@@ -359,6 +373,8 @@ async def startup_event():
logger.info("server start up")
loop = asyncio.get_event_loop()
g_objs.set_args(get_env_start_args())
if g_objs.httpserver_manager is None:
return
loop.create_task(g_objs.httpserver_manager.handle_loop())
logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}")
return
18 changes: 17 additions & 1 deletion lightllm/server/api_lightllm.py
@@ -1,7 +1,7 @@
import collections
from typing import AsyncGenerator
from fastapi import BackgroundTasks, Request
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import Response, StreamingResponse, JSONResponse
from lightllm.server.core.objs.sampling_params import SamplingParams
from .multimodal_params import MultimodalParams
from .httpserver.manager import HttpServerManager
@@ -150,3 +150,19 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

background_tasks = BackgroundTasks()
return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)


async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response:
request_dict = await request.json()
# request_dict: {'parameters': {'max_new_tokens': 128},
# 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}}
sample_params_dict = request_dict["parameters"]
sampling_params = SamplingParams()
sampling_params.init(tokenizer=None, **sample_params_dict)
sampling_params.verify()
multimodal_params_dict = request_dict.get("multimodal_params", {})
multimodal_params = MultimodalParams(**multimodal_params_dict)

await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request)

return JSONResponse({"message": "OK"}, status_code=200)
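
A minimal client sketch for the new /get_image_embedding route. The endpoint path and request-body shape come from this diff; the server address, port, local image file, and timeout are assumptions for illustration:

# Hypothetical client call to the /get_image_embedding endpoint added in this PR.
# The JSON body mirrors the comment in lightllm_get_image_embedding.
import base64
import requests

with open("demo.jpg", "rb") as f:  # assumed local image
    img_b64 = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8000/get_image_embedding",  # assumed server address and port
    json={
        "parameters": {"max_new_tokens": 128},
        "multimodal_params": {"images": [{"type": "base64", "data": img_b64}]},
    },
    timeout=60,
)
print(resp.status_code, resp.json())  # expect 200 and {"message": "OK"}
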
4 changes: 3 additions & 1 deletion lightllm/server/api_server.py
@@ -5,11 +5,13 @@
torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess
parser = make_argument_parser()
args = parser.parse_args()
from .api_start import pd_master_start, normal_or_p_d_start, config_server_start
from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start

if args.run_mode == "pd_master":
pd_master_start(args)
elif args.run_mode == "config_server":
config_server_start(args)
elif args.run_mode == "visual":
visual_start(args)
else:
normal_or_p_d_start(args)