Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion examples/bench.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import infinicore
from infinilm.modeling_utils import load_model_state_dict_by_file
from infinilm.distributed import DistConfig
from infinilm.infer_engine import GenerationConfig, InferEngine
from infinilm.infer_engine import (
GenerationConfig,
InferEngine,
read_hf_generation_config,
)
from infinilm.base_config import BaseConfig
from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
from infinilm.processors import AutoInfinilmProcessor
Expand Down Expand Up @@ -88,6 +92,23 @@ def read_json_file(file_path):
return json.load(file)


def resolve_generation_defaults(model_path, top_k, top_p, temperature):
generation_config = read_hf_generation_config(model_path)

def resolve(value, name, fallback, cast):
if value is None:
value = generation_config.get(name)
if value is None:
value = fallback
return cast(value)

return (
resolve(top_k, "top_k", 1, int),
resolve(top_p, "top_p", 1.0, float),
resolve(temperature, "temperature", 1.0, float),
)


def get_test_cases(
model_path: str,
batch_size_list: list[int],
Expand Down Expand Up @@ -286,6 +307,9 @@ def run(
enable_paged_attn = cfg.enable_paged_attn
enable_graph = cfg.enable_graph
attn_backend = cfg.attn
cfg.top_k, cfg.top_p, cfg.temperature = resolve_generation_defaults(
model_path, cfg.top_k, cfg.top_p, cfg.temperature
)

if isinstance(batch_size, int):
batch_size = [batch_size]
Expand Down
6 changes: 3 additions & 3 deletions examples/test_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def test(
tp=1,
enable_paged_attn=False,
enable_graph=False,
top_k=1,
top_p=1.0,
temperature=1.0,
top_k=None,
top_p=None,
temperature=None,
attn_backend="default",
use_mla=False,
image_path=None,
Expand Down
6 changes: 3 additions & 3 deletions python/infinilm/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def _add_common_args(self):
self.parser.add_argument(
"--prompt", type=str, default="How are you", help="default prompt text"
)
self.parser.add_argument("--top-k", type=int, default=1)
self.parser.add_argument("--top-p", type=float, default=1.0)
self.parser.add_argument("--temperature", type=float, default=1.0)
self.parser.add_argument("--top-k", type=int, default=None)
self.parser.add_argument("--top-p", type=float, default=None)
self.parser.add_argument("--temperature", type=float, default=None)

# --- debug ---
self.parser.add_argument("--warmup", action="store_true")
Expand Down
6 changes: 3 additions & 3 deletions python/infinilm/config/engine_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ class EngineConfig:
num_blocks: int = 512
block_size: int = 256
max_cache_len: int = 4096
temperature: float = 1.0
top_p: float = 0.8
top_k: int = 1
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
enable_graph: bool = False
attn_backend: str = "default"
use_mla: bool = False
Expand Down
42 changes: 30 additions & 12 deletions python/infinilm/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,33 @@
from infinilm.multimodal.multimodal import resolve_multimodal_inputs
from infinilm.config.kv_transfer import KVTransferConfig
from infinilm.config.engine_config import EngineConfig
from infinilm.infer_engine import read_hf_generation_config
from infinilm.kv_connector import KVConnectorRole, KVConnectorFactory

logger = logging.getLogger(__name__)


def _resolve_generation_defaults(config: EngineConfig) -> None:
generation_config = read_hf_generation_config(config.model_path)

def resolve(name: str, fallback, cast):
value = getattr(config, name)
if value is None:
value = generation_config.get(name)
if value is None:
value = fallback
setattr(config, name, cast(value))

resolve("top_k", 1, int)
resolve("top_p", 1.0, float)
resolve("temperature", 1.0, float)


class LLMEngine:
"""Low-level LLM engine that handles inference execution."""

def __init__(self, config: EngineConfig):
_resolve_generation_defaults(config)
self.config = config

self.model_runner = ModelRunner(config)
Expand Down Expand Up @@ -296,9 +314,9 @@ def __init__(
num_blocks: int = 512,
block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
enable_graph: bool = False,
attn_backend: str = "default",
use_mla: bool = False,
Expand Down Expand Up @@ -490,9 +508,9 @@ def __init__(
num_blocks: int = 512,
block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
enable_graph: bool = False,
attn_backend: str = "default",
kv_transfer_config: Optional[KVTransferConfig] = None,
Expand Down Expand Up @@ -715,13 +733,13 @@ def add_request(
elif prompt is not None:
prompt_token_ids = self.engine.tokenize(prompt)
else:
assert messages is not None, (
"Either messages or prompt/prompt_token_ids must be provided"
)
assert (
messages is not None
), "Either messages or prompt/prompt_token_ids must be provided"

assert apply_chat_template, (
"apply_chat_template needs to be true for multi-role conversation"
)
assert (
apply_chat_template
), "apply_chat_template needs to be true for multi-role conversation"

prompt = self.engine.apply_chat_template(
messages, add_generation_prompt=add_generation_prompt
Expand Down
19 changes: 13 additions & 6 deletions python/infinilm/server/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ def __init__(
num_blocks: int = 512,
block_size: int = 256,
max_cache_len: int = 4096,
temperature: float = 1.0,
top_p: float = 0.8,
top_k: int = 1,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
host: str = "0.0.0.0",
port: int = 8000,
enable_graph: bool = False,
Expand Down Expand Up @@ -194,6 +194,9 @@ async def lifespan(app: FastAPI):
use_mla=self.use_mla,
kv_transfer_config=self.kv_transfer_config,
)
self.temperature = self.engine.config.temperature
self.top_p = self.engine.config.top_p
self.top_k = self.engine.config.top_k
self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}")
logger.info(f" enable_graph: {self.enable_graph}")
Expand Down Expand Up @@ -337,10 +340,14 @@ def pick(key: str, default):
if isinstance(stop, str):
stop = [stop]

temperature = pick("temperature", self.temperature)
top_p = pick("top_p", self.top_p)
top_k = pick("top_k", self.top_k)

return SamplingParams(
temperature=float(pick("temperature", self.temperature)),
top_p=float(pick("top_p", self.top_p)),
top_k=int(pick("top_k", self.top_k)),
temperature=float(1.0 if temperature is None else temperature),
top_p=float(1.0 if top_p is None else top_p),
top_k=int(1 if top_k is None else top_k),
max_tokens=int(max_tokens) if max_tokens is not None else None,
stop=stop,
ignore_eos=self.ignore_eos,
Expand Down