Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import bisect
import triton
from typing import Optional
from tqdm import tqdm
from lightllm.utils.log_utils import init_logger
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.distributed import dist_group_manager
Expand Down Expand Up @@ -197,7 +198,11 @@ def warmup(self, model):
model: TpPartBaseModel = model

# decode cuda graph init
for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB")
seq_len = 2
total_token_num = batch_size * seq_len
max_len_in_batch = self.graph_max_len_in_batch
Expand Down Expand Up @@ -252,7 +257,13 @@ def warmup_overlap(self, model):

model: TpPartBaseModel = model

for batch_size in self.cuda_graph_batch_sizes[::-1]:
progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs")
for batch_size in progress_bar:
avail_mem, _ = torch.cuda.mem_get_info()
avail_mem_gb = avail_mem / (1024 ** 3)
progress_bar.set_description(
f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB"
)
decode_batches = []
for micro_batch_index in [0, 1]:
# dummy decoding, capture the cudagraph
Expand Down
6 changes: 3 additions & 3 deletions lightllm/common/triton_utils/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key):

cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
if os.path.exists(cache_file):
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
logger.info(f"Loading cached configs for {self.kernel_name} - {dict(static_key)}")
with open(cache_file, "rb") as f:
self.cached_configs[static_key] = orjson.loads(f.read())
return True
Expand Down Expand Up @@ -353,9 +353,9 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
option=orjson.OPT_INDENT_2 | orjson.OPT_SORT_KEYS | orjson.OPT_NON_STR_KEYS,
)
)
logger.info(f"Saved configs for {self.kernel_name} - {_static_key}")
logger.info(f"Saved configs for {self.kernel_name} - {dict(_static_key)}")

logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {static_key} finished")
logger.info(f"rank {rank_id} tuning {self.kernel_name} _static_key {dict(static_key)} finished")

def _mutate_args_clone(self, args, kwargs):
origin_list = []
Expand Down
31 changes: 31 additions & 0 deletions lightllm/server/access_log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# ANSI colour escape for each HTTP status class, keyed by ``status // 100``
# (2xx green, 3xx cyan, 4xx yellow, 5xx red). Other classes are left uncoloured.
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
# ANSI escape that resets the terminal colour after a coloured message.
_ACCESS_LOG_RESET = "\033[0m"


class _AccessLogMiddleware:
    """ASGI middleware that logs one ``METHOD PATH STATUS`` line per HTTP request.

    The line is emitted through ``logger.info`` once the wrapped application
    has finished handling the request (normally or by raising), and is
    wrapped in an ANSI colour chosen from the status class.

    Args:
        app: The wrapped ASGI application, called as ``app(scope, receive, send)``.
        logger: Any object exposing ``info(msg)``; receives the access line.
    """

    def __init__(self, app, logger):
        self.app = app
        self.logger = logger

    async def __call__(self, scope, receive, send):
        """Forward the ASGI call to the wrapped app, logging HTTP requests.

        Non-HTTP scopes are forwarded untouched and produce no log line.
        Exceptions raised by the wrapped app are always re-raised.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # 0 means "no http.response.start message has been seen yet".
        status = 0

        async def send_wrapper(message):
            # Remember the status code as the response headers go out.
            nonlocal status
            if message["type"] == "http.response.start":
                status = message["status"]
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception:
            # The app failed before starting a response. Presumably an outer
            # error handler turns this into a 500 for the client, so log 500
            # instead of the meaningless placeholder 0. A status that was
            # already sent is kept as-is.
            if status == 0:
                status = 500
            raise
        finally:
            msg = f"{scope['method']} {scope['path']} {status}"
            color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "")
            if color:
                msg = color + msg + _ACCESS_LOG_RESET
            self.logger.info(msg)
37 changes: 2 additions & 35 deletions lightllm/server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from lightllm.utils.error_utils import ClientDisconnected, ServerBusyError
from lightllm.server.metrics.manager import MetricClient
from lightllm.utils.envs_utils import get_unique_server_name
from lightllm.server.access_log import _AccessLogMiddleware
from dataclasses import dataclass

from .api_openai import chat_completions_impl, completions_impl
Expand Down Expand Up @@ -115,41 +116,7 @@ def set_args(self, args: StartArgs):

app = FastAPI()
g_objs.app = app

_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_STATUS_COLORS = {2: "\033[32m", 3: "\033[36m", 4: "\033[33m", 5: "\033[31m"}
_ACCESS_LOG_RESET = "\033[0m"


class _AccessLogMiddleware:
def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
if scope["type"] not in ("http", "websocket"):
await self.app(scope, receive, send)
return

status_holder = {"status": 0}

async def send_wrapper(message):
if message["type"] == "http.response.start":
status_holder["status"] = message["status"]
await send(message)

try:
await self.app(scope, receive, send_wrapper)
finally:
if scope["type"] == "http":
status = status_holder["status"]
msg = f"{scope['method']} {scope['path']} {status}"
color = _ACCESS_LOG_STATUS_COLORS.get(status // 100, "")
if color:
msg = color + msg + _ACCESS_LOG_RESET
logger.info(msg)


app.add_middleware(_AccessLogMiddleware)
app.add_middleware(_AccessLogMiddleware, logger=logger)


def create_error_response(
Expand Down
6 changes: 0 additions & 6 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,6 @@ def normal_or_p_d_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
Expand Down Expand Up @@ -566,8 +564,6 @@ def pd_master_start(args):
f"{args.host}:{args.port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.api_http:app",
Expand Down Expand Up @@ -662,8 +658,6 @@ def config_server_start(args):
f"{args.config_server_host}:{args.config_server_port}",
"--log-level",
"info",
"--access-logfile",
"-",
"--error-logfile",
"-",
"lightllm.server.config_server.api_http:app",
Expand Down
2 changes: 2 additions & 0 deletions lightllm/server/config_server/api_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, List
from fastapi.responses import JSONResponse
from lightllm.utils.log_utils import init_logger
from lightllm.server.access_log import _AccessLogMiddleware
from lightllm.server.visualserver.objs import VIT_Obj
from ..pd_io_struct import PD_Master_Obj
from .nccl_tcp_store import start_tcp_store_server
Expand All @@ -18,6 +19,7 @@

logger = init_logger(__name__)
app = FastAPI()
app.add_middleware(_AccessLogMiddleware, logger=logger)

registered_pd_master_objs: Dict[str, PD_Master_Obj] = {}
registered_visual_server_objs: Dict[str, VIT_Obj] = {}
Expand Down
4 changes: 2 additions & 2 deletions lightllm/server/detokenization/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _add_new_group_req_index(self, recv_obj: GroupReqIndexes):
req.link_prompt_ids_shm_array()
req.link_logprobs_shm_array()

logger.info(
logger.debug(
f"detokenization recv req id {req.request_id} " f"cost time {time.time() - recv_obj.time_mark} s"
)

Expand Down Expand Up @@ -160,7 +160,7 @@ def remove_finished_reqs(self):

for decode_req in finished_reqs:
decode_req.req.can_released_mark = True
logger.info(f"detoken release req id {decode_req.req.request_id}")
logger.debug(f"detoken release req id {decode_req.req.request_id}")
self.shm_req_manager.put_back_req_obj(decode_req.req)
self.req_id_to_out.pop(decode_req.request_id, None)
return
Expand Down
Loading
Loading