From f4e5432511370309389399e870aa01cad6b0624f Mon Sep 17 00:00:00 2001
From: shihaobai <1798930569@qq.com>
Date: Mon, 9 Mar 2026 07:36:07 +0000
Subject: [PATCH] refactor max_tokens

---
 lightllm/server/api_models.py                 | 11 ++++++++--
 lightllm/server/api_openai.py                 | 17 +++++++++++----
 .../server/core/objs/py_sampling_params.py    | 21 +++++++++++--------
 lightllm/server/core/objs/sampling_params.py  | 12 ++++++-----
 lightllm/server/httpserver/manager.py         |  3 +++
 5 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index e194a24d0f..82a57ab1c3 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -1,4 +1,5 @@
 import time
+from typing_extensions import deprecated
 import uuid
 
 from pydantic import BaseModel, Field, field_validator, model_validator
@@ -114,7 +115,10 @@ class CompletionRequest(BaseModel):
     # prompt: string or tokens
     prompt: Union[str, List[str], List[int], List[List[int]]]
     suffix: Optional[str] = None
-    max_tokens: Optional[int] = 8192
+    max_tokens: Optional[int] = Field(
+        default=None, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+    )
+    max_completion_tokens: Optional[int] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
@@ -187,7 +191,10 @@ class ChatCompletionRequest(BaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     stop: Optional[Union[str, List[str]]] = None
-    max_tokens: Optional[int] = 8192
+    max_tokens: Optional[int] = Field(
+        default=None, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+    )
+    max_completion_tokens: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 11e24612b0..f24cddc331 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -203,14 +203,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         "top_p": request.top_p,
         "top_k": request.top_k,
         "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
         "n": request.n,
         "best_of": request.n,
         "add_special_tokens": False,
         "seed": request.seed,
     }
 
+    if request.max_completion_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
+    if request.max_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_tokens
+    if request.stop is not None:
+        sampling_params_dict["stop_sequences"] = request.stop
+
     # Structured output handling
     if request.response_format:
         if request.response_format.type == "json_schema":
@@ -533,13 +538,17 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) ->
         "top_p": request.top_p,
         "top_k": request.top_k,
         "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
         "n": request.n,
         "best_of": request.best_of,
         "add_special_tokens": False,
         "seed": request.seed,
     }
+    if request.max_completion_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
+    if request.max_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_tokens
+    if request.stop is not None:
+        sampling_params_dict["stop_sequences"] = request.stop
 
     if request.response_format:
         if request.response_format.type == "json_schema":
diff --git a/lightllm/server/core/objs/py_sampling_params.py b/lightllm/server/core/objs/py_sampling_params.py
index 51c6add0a5..840ffd8bc9 100644
--- a/lightllm/server/core/objs/py_sampling_params.py
+++ b/lightllm/server/core/objs/py_sampling_params.py
@@ -38,7 +38,7 @@ def __init__(
         top_k: int = None,  # -1 is for all
         ignore_eos: bool = False,
         image_max_patch_num: int = -1,
-        max_new_tokens: int = 16,
+        max_new_tokens: int = -1,
         min_new_tokens: int = 1,
         stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # 停止句子条件
         skip_special_tokens: bool = True,  # whether to skip special tokens when decoding
@@ -141,14 +141,6 @@ def verify(self):
             raise ValueError(f"top_p must in (0.0, 1.0], got {self.top_p}")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
-        if self.max_new_tokens < 1:
-            raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
-        if self.min_new_tokens < 1:
-            raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
-        if self.min_new_tokens > self.max_new_tokens:
-            raise ValueError(
-                f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
-            )
 
         if len(self.exponential_decay_length_penalty) != 2:
             raise ValueError(
@@ -201,6 +193,17 @@ def verify(self):
 
         return
 
+    def verify_length(self):
+        if self.max_new_tokens < 1:
+            raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
+        if self.min_new_tokens < 1:
+            raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
+        if self.min_new_tokens > self.max_new_tokens:
+            raise ValueError(
+                f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
+            )
+        return
+
     def _verify_allowed_token_ids(self):
         if self.allowed_token_ids is not None:
             if (not isinstance(self.allowed_token_ids, list)) or (
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index b9ad314dd9..fc98da4693 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
         self.top_k = kwargs.get("top_k", SamplingParams._top_k)
         self.ignore_eos = kwargs.get("ignore_eos", False)
         self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
-        self.max_new_tokens = kwargs.get("max_new_tokens", 16)
+        self.max_new_tokens = kwargs.get("max_new_tokens", -1)
         self.min_new_tokens = kwargs.get("min_new_tokens", 1)
         self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
         self.group_request_id = kwargs.get("group_request_id", -1)
@@ -439,6 +439,12 @@ def verify(self):
             raise ValueError(f"top_p must be in (0.0, 1.0], got {self.top_p}")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
+        self._verify_allowed_token_ids()
+        self._verify_grammar_constraint()
+
+        return
+
+    def verify_length(self):
         if self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
         if self.min_new_tokens < 1:
@@ -447,10 +453,6 @@ def verify(self):
             raise ValueError(
                 f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
             )
-
-        self._verify_allowed_token_ids()
-        self._verify_grammar_constraint()
-
         return
 
     def _verify_grammar_constraint(self):
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index d51f88cdda..d084436ac3 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -456,6 +456,8 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
         if not prompt_ids:
             raise ValueError("prompt_ids is empty")
         prompt_tokens = len(prompt_ids)
+        if sampling_params.max_new_tokens == -1:
+            sampling_params.max_new_tokens = self.max_req_total_len - prompt_tokens
         if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
             # use long_truncation_mode to truncate long input len req.
             if self.args.long_truncation_mode is None:
@@ -472,6 +474,7 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
                 assert prompt_tokens == req_input_len
             else:
                 assert False, "error args"
+        sampling_params.verify_length()
 
         # last repaired
         req_total_len = len(prompt_ids) + sampling_params.max_new_tokens
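
For reference (not part of the patch itself): a minimal standalone sketch of the max_new_tokens resolution order the diff implements. resolve_max_new_tokens is a hypothetical helper written only for illustration; prompt_tokens and max_req_total_len stand in for the values used in httpserver/manager.py.

from typing import Optional

def resolve_max_new_tokens(
    max_completion_tokens: Optional[int],
    max_tokens: Optional[int],  # deprecated request field, still honored
    prompt_tokens: int,
    max_req_total_len: int,
) -> int:
    # api_openai.py: apply max_completion_tokens first, then let a non-None
    # max_tokens overwrite it; otherwise the -1 default from SamplingParams stays.
    max_new_tokens = -1
    if max_completion_tokens is not None:
        max_new_tokens = max_completion_tokens
    if max_tokens is not None:
        max_new_tokens = max_tokens
    # httpserver/manager.py: the -1 sentinel is expanded to the remaining
    # request budget before verify_length() checks it.
    if max_new_tokens == -1:
        max_new_tokens = max_req_total_len - prompt_tokens
    return max_new_tokens

# e.g. neither field set, 1000-token prompt, 8192-token total budget -> 7192
print(resolve_max_new_tokens(None, None, 1000, 8192))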