Commit 14ba4a6

openai api: max_tokens -> max_completion_tokens (#1222)
1 parent 60c18fe commit 14ba4a6
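
In short: both OpenAI-compatible endpoints now prefer max_completion_tokens and keep max_tokens only as a deprecated fallback, matching the OpenAI API's own rename. A minimal client sketch of the new field, assuming a LightLLM server at a hypothetical http://localhost:8000 with a placeholder model name (neither is part of this commit):

    import requests

    # Hypothetical request against the OpenAI-compatible chat endpoint;
    # host, port, and model are placeholders.
    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "model": "placeholder-model",
            "messages": [{"role": "user", "content": "Hello"}],
            # preferred field after this commit; "max_tokens" still works but is deprecated
            "max_completion_tokens": 256,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])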

5 files changed: 33 additions & 15 deletions


lightllm/models/deepseek3_2/triton_kernel/act_quant.py

Lines changed: 5 additions & 2 deletions
@@ -1,4 +1,6 @@
-# Adapted from https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/python/sglang/srt/layers/attention/nsa/triton_kernel.py
+# Adapted from sglang Triton kernel implementation:
+# https://github.com/sgl-project/sglang/blob/ce6b17c0f94e6bf53633c8f324176a891e67fa7f/
+# python/sglang/srt/layers/attention/nsa/triton_kernel.py
 from typing import Optional, Tuple

 import torch
@@ -91,7 +93,8 @@ def act_quant(
     Quantizes the input tensor `x` using block-wise quantization with Triton.

     Args:
-        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
+        x (torch.Tensor): The input tensor to be quantized. Must be
+            contiguous and its last dimension size must be divisible by `block_size`.
         block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
         scale_fmt (Optional[str], optional): The format of the scale. Default is None.
     Returns:
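
For orientation, a torch-only sketch of what block-wise activation quantization does, honoring the contract documented above (contiguous input, last dimension divisible by block_size). This is a conceptual reference, not the Triton kernel: the fp8 target dtype, the 448.0 e4m3 max, and the returned (quantized, scales) layout are illustrative assumptions.

    import torch

    def act_quant_ref(x: torch.Tensor, block_size: int = 128):
        # Split the last dimension into blocks, scale each block so its max
        # magnitude maps onto the fp8-e4m3 range, then cast.
        assert x.is_contiguous() and x.shape[-1] % block_size == 0
        blocks = x.view(*x.shape[:-1], -1, block_size)
        amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-4)
        scales = amax / 448.0  # 448.0 is the float8_e4m3fn max
        q = (blocks / scales).view_as(x).to(torch.float8_e4m3fn)
        return q, scales.squeeze(-1)

    x = torch.randn(4, 256)
    q, s = act_quant_ref(x)
    print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])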

lightllm/server/api_models.py

Lines changed: 9 additions & 2 deletions
@@ -1,4 +1,5 @@
 import time
+from typing_extensions import deprecated
 import uuid

 from pydantic import BaseModel, Field, field_validator, model_validator
@@ -119,7 +120,10 @@ class CompletionRequest(BaseModel):
     # prompt: string or tokens
     prompt: Union[str, List[str], List[int], List[List[int]]]
     suffix: Optional[str] = None
-    max_tokens: Optional[int] = 8192
+    max_tokens: Optional[int] = Field(
+        default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+    )
+    max_completion_tokens: Optional[int] = None
     temperature: Optional[float] = 1.0
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
@@ -192,7 +196,10 @@ class ChatCompletionRequest(BaseModel):
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     stop: Optional[Union[str, List[str]]] = None
-    max_tokens: Optional[int] = 8192
+    max_tokens: Optional[int] = Field(
+        default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
+    )
+    max_completion_tokens: Optional[int] = None
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[Dict[str, float]] = None
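
The Field(deprecated=...) marker relies on pydantic >= 2.7: reading such a field emits a DeprecationWarning carrying the given message. A standalone sketch (a demo model, not LightLLM's actual request classes):

    import warnings
    from typing import Optional
    from pydantic import BaseModel, Field

    class DemoRequest(BaseModel):
        max_tokens: Optional[int] = Field(
            default=16384, deprecated="max_tokens is deprecated, please use max_completion_tokens instead"
        )
        max_completion_tokens: Optional[int] = None

    req = DemoRequest(max_tokens=256)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = req.max_tokens  # attribute access triggers the warning
    print(caught[0].message)  # -> max_tokens is deprecated, please use ...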

lightllm/server/api_openai.py

Lines changed: 13 additions & 4 deletions
@@ -233,14 +233,19 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         "top_p": request.top_p,
         "top_k": request.top_k,
         "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
         "n": request.n,
         "best_of": request.n,
         "add_special_tokens": False,
         "seed": request.seed,
     }

+    if request.max_completion_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
+    elif request.max_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_tokens
+    if request.stop is not None:
+        sampling_params_dict["stop_sequences"] = request.stop
+
     # Structured output handling
     if request.response_format:
         if request.response_format.type == "json_schema":
@@ -571,13 +576,17 @@ async def completions_impl(request: CompletionRequest, raw_request: Request) ->
         "top_p": request.top_p,
         "top_k": request.top_k,
         "ignore_eos": request.ignore_eos,
-        "max_new_tokens": request.max_tokens,
-        "stop_sequences": request.stop,
         "n": request.n,
         "best_of": request.best_of,
         "add_special_tokens": False,
         "seed": request.seed,
     }
+    if request.max_completion_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_completion_tokens
+    elif request.max_tokens is not None:
+        sampling_params_dict["max_new_tokens"] = request.max_tokens
+    if request.stop is not None:
+        sampling_params_dict["stop_sequences"] = request.stop

     if request.response_format:
         if request.response_format.type == "json_schema":
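
Both handlers now apply the same precedence: max_completion_tokens wins when both are set, max_tokens is the deprecated fallback, and stop is forwarded only when given. A standalone sketch of that rule (the helper name and the default-fallthrough are illustrative, not part of the codebase):

    from typing import Optional

    def resolve_max_new_tokens(
        max_completion_tokens: Optional[int],
        max_tokens: Optional[int],
        default: int = 16384,
    ) -> int:
        # Mirror the precedence added above: new field, then deprecated field.
        if max_completion_tokens is not None:
            return max_completion_tokens
        if max_tokens is not None:
            return max_tokens
        return default

    assert resolve_max_new_tokens(100, 200) == 100   # new field takes precedence
    assert resolve_max_new_tokens(None, 200) == 200  # deprecated fallback
    assert resolve_max_new_tokens(None, None) == 16384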

lightllm/server/core/objs/py_sampling_params.py

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@ def __init__(
     top_k: int = None,  # -1 is for all
     ignore_eos: bool = False,
     image_max_patch_num: int = -1,
-    max_new_tokens: int = 16,
+    max_new_tokens: int = 16384,
     min_new_tokens: int = 1,
     stop_sequences: Optional[Union[str, List[str], List[List[int]]]] = None,  # stop-sequence condition
     skip_special_tokens: bool = True,  # whether to skip special tokens when decoding
@@ -142,9 +142,9 @@ def verify(self):
     if self.top_k < -1 or self.top_k == 0:
         raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
     if self.max_new_tokens < 1:
-        raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
+        raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
     if self.min_new_tokens < 1:
-        raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
+        raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
     if self.min_new_tokens > self.max_new_tokens:
         raise ValueError(
             f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."

lightllm/server/core/objs/sampling_params.py

Lines changed: 3 additions & 4 deletions
@@ -345,7 +345,7 @@ def init(self, tokenizer, **kwargs):
     self.top_k = kwargs.get("top_k", SamplingParams._top_k)
     self.ignore_eos = kwargs.get("ignore_eos", False)
     self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
-    self.max_new_tokens = kwargs.get("max_new_tokens", 16)
+    self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
     self.min_new_tokens = kwargs.get("min_new_tokens", 1)
     self.input_penalty = kwargs.get("input_penalty", DEFAULT_INPUT_PENALTY)
     self.group_request_id = kwargs.get("group_request_id", -1)
@@ -440,14 +440,13 @@ def verify(self):
     if self.top_k < -1 or self.top_k == 0:
         raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
     if self.max_new_tokens < 1:
-        raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
+        raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
     if self.min_new_tokens < 1:
-        raise ValueError(f"min_new_tokens must be at least 1, got {self.min_new_tokens}.")
+        raise ValueError(f"min_new_tokens must be at least 1 , got {self.min_new_tokens}.")
     if self.min_new_tokens > self.max_new_tokens:
         raise ValueError(
             f"min_new_tokens must <= max_new_tokens, but got min {self.min_new_tokens}, max {self.max_new_tokens}."
         )
-
     self._verify_allowed_token_ids()
     self._verify_grammar_constraint()
