From 9f6a390ad078c7579c0835947aeee08e060ee8e8 Mon Sep 17 00:00:00 2001 From: wooway777 Date: Thu, 18 Jun 2026 19:38:10 +0800 Subject: [PATCH] feat: adjust sampling defaults --- examples/bench.py | 26 +++++++++++++- examples/test_infer.py | 6 ++-- python/infinilm/base_config.py | 6 ++-- python/infinilm/config/engine_config.py | 6 ++-- python/infinilm/llm/llm.py | 42 +++++++++++++++------- python/infinilm/server/inference_server.py | 19 ++++++---- 6 files changed, 77 insertions(+), 28 deletions(-) diff --git a/examples/bench.py b/examples/bench.py index 37ad326c0..5754bbf6c 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -1,7 +1,11 @@ import infinicore from infinilm.modeling_utils import load_model_state_dict_by_file from infinilm.distributed import DistConfig -from infinilm.infer_engine import GenerationConfig, InferEngine +from infinilm.infer_engine import ( + GenerationConfig, + InferEngine, + read_hf_generation_config, +) from infinilm.base_config import BaseConfig from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig from infinilm.processors import AutoInfinilmProcessor @@ -88,6 +92,23 @@ def read_json_file(file_path): return json.load(file) +def resolve_generation_defaults(model_path, top_k, top_p, temperature): + generation_config = read_hf_generation_config(model_path) + + def resolve(value, name, fallback, cast): + if value is None: + value = generation_config.get(name) + if value is None: + value = fallback + return cast(value) + + return ( + resolve(top_k, "top_k", 1, int), + resolve(top_p, "top_p", 1.0, float), + resolve(temperature, "temperature", 1.0, float), + ) + + def get_test_cases( model_path: str, batch_size_list: list[int], @@ -286,6 +307,9 @@ def run( enable_paged_attn = cfg.enable_paged_attn enable_graph = cfg.enable_graph attn_backend = cfg.attn + cfg.top_k, cfg.top_p, cfg.temperature = resolve_generation_defaults( + model_path, cfg.top_k, cfg.top_p, cfg.temperature + ) if isinstance(batch_size, int): batch_size = [batch_size] diff --git a/examples/test_infer.py b/examples/test_infer.py index dffe64d7f..8c78be6be 100644 --- a/examples/test_infer.py +++ b/examples/test_infer.py @@ -12,9 +12,9 @@ def test( tp=1, enable_paged_attn=False, enable_graph=False, - top_k=1, - top_p=1.0, - temperature=1.0, + top_k=None, + top_p=None, + temperature=None, attn_backend="default", use_mla=False, image_path=None, diff --git a/python/infinilm/base_config.py b/python/infinilm/base_config.py index 6b5515275..4c0591ae7 100644 --- a/python/infinilm/base_config.py +++ b/python/infinilm/base_config.py @@ -176,9 +176,9 @@ def _add_common_args(self): self.parser.add_argument( "--prompt", type=str, default="How are you", help="default prompt text" ) - self.parser.add_argument("--top-k", type=int, default=1) - self.parser.add_argument("--top-p", type=float, default=1.0) - self.parser.add_argument("--temperature", type=float, default=1.0) + self.parser.add_argument("--top-k", type=int, default=None) + self.parser.add_argument("--top-p", type=float, default=None) + self.parser.add_argument("--temperature", type=float, default=None) # --- debug --- self.parser.add_argument("--warmup", action="store_true") diff --git a/python/infinilm/config/engine_config.py b/python/infinilm/config/engine_config.py index 044cfdda4..0dc2c45fb 100644 --- a/python/infinilm/config/engine_config.py +++ b/python/infinilm/config/engine_config.py @@ -37,9 +37,9 @@ class EngineConfig: num_blocks: int = 512 block_size: int = 256 max_cache_len: int = 4096 - temperature: float = 1.0 - top_p: float = 0.8 - top_k: int = 1 + temperature: float | None = None + top_p: float | None = None + top_k: int | None = None enable_graph: bool = False attn_backend: str = "default" use_mla: bool = False diff --git a/python/infinilm/llm/llm.py b/python/infinilm/llm/llm.py index f8171c0db..c42d39f81 100644 --- a/python/infinilm/llm/llm.py +++ b/python/infinilm/llm/llm.py @@ -29,15 +29,33 @@ from infinilm.multimodal.multimodal import resolve_multimodal_inputs from infinilm.config.kv_transfer import KVTransferConfig from infinilm.config.engine_config import EngineConfig +from infinilm.infer_engine import read_hf_generation_config from infinilm.kv_connector import KVConnectorRole, KVConnectorFactory logger = logging.getLogger(__name__) +def _resolve_generation_defaults(config: EngineConfig) -> None: + generation_config = read_hf_generation_config(config.model_path) + + def resolve(name: str, fallback, cast): + value = getattr(config, name) + if value is None: + value = generation_config.get(name) + if value is None: + value = fallback + setattr(config, name, cast(value)) + + resolve("top_k", 1, int) + resolve("top_p", 1.0, float) + resolve("temperature", 1.0, float) + + class LLMEngine: """Low-level LLM engine that handles inference execution.""" def __init__(self, config: EngineConfig): + _resolve_generation_defaults(config) self.config = config self.model_runner = ModelRunner(config) @@ -296,9 +314,9 @@ def __init__( num_blocks: int = 512, block_size: int = 256, max_cache_len: int = 4096, - temperature: float = 1.0, - top_p: float = 0.8, - top_k: int = 1, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, enable_graph: bool = False, attn_backend: str = "default", use_mla: bool = False, @@ -490,9 +508,9 @@ def __init__( num_blocks: int = 512, block_size: int = 256, max_cache_len: int = 4096, - temperature: float = 1.0, - top_p: float = 0.8, - top_k: int = 1, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, enable_graph: bool = False, attn_backend: str = "default", kv_transfer_config: Optional[KVTransferConfig] = None, @@ -715,13 +733,13 @@ def add_request( elif prompt is not None: prompt_token_ids = self.engine.tokenize(prompt) else: - assert messages is not None, ( - "Either messages or prompt/prompt_token_ids must be provided" - ) + assert ( + messages is not None + ), "Either messages or prompt/prompt_token_ids must be provided" - assert apply_chat_template, ( - "apply_chat_template needs to be true for multi-role conversation" - ) + assert ( + apply_chat_template + ), "apply_chat_template needs to be true for multi-role conversation" prompt = self.engine.apply_chat_template( messages, add_generation_prompt=add_generation_prompt diff --git a/python/infinilm/server/inference_server.py b/python/infinilm/server/inference_server.py index 645b00656..40253bdc8 100644 --- a/python/infinilm/server/inference_server.py +++ b/python/infinilm/server/inference_server.py @@ -103,9 +103,9 @@ def __init__( num_blocks: int = 512, block_size: int = 256, max_cache_len: int = 4096, - temperature: float = 1.0, - top_p: float = 0.8, - top_k: int = 1, + temperature: float | None = None, + top_p: float | None = None, + top_k: int | None = None, host: str = "0.0.0.0", port: int = 8000, enable_graph: bool = False, @@ -194,6 +194,9 @@ async def lifespan(app: FastAPI): use_mla=self.use_mla, kv_transfer_config=self.kv_transfer_config, ) + self.temperature = self.engine.config.temperature + self.top_p = self.engine.config.top_p + self.top_k = self.engine.config.top_k self.engine.start() logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f" enable_graph: {self.enable_graph}") @@ -337,10 +340,14 @@ def pick(key: str, default): if isinstance(stop, str): stop = [stop] + temperature = pick("temperature", self.temperature) + top_p = pick("top_p", self.top_p) + top_k = pick("top_k", self.top_k) + return SamplingParams( - temperature=float(pick("temperature", self.temperature)), - top_p=float(pick("top_p", self.top_p)), - top_k=int(pick("top_k", self.top_k)), + temperature=float(1.0 if temperature is None else temperature), + top_p=float(1.0 if top_p is None else top_p), + top_k=int(1 if top_k is None else top_k), max_tokens=int(max_tokens) if max_tokens is not None else None, stop=stop, ignore_eos=self.ignore_eos,