From 31523779052d8e8f178fea4ed96823d9e0de0d5c Mon Sep 17 00:00:00 2001 From: zhangtaoshan Date: Mon, 18 May 2026 19:39:41 +0800 Subject: [PATCH 1/5] initial GGUF support --- lightllm/common/basemodel/basemodel.py | 92 +++++- .../layer_weights/gguf_load_utils.py | 121 ++++++++ .../meta_weights/mm_weight/mm_weight.py | 6 +- lightllm/common/quantization/__init__.py | 3 + lightllm/common/quantization/awq.py | 16 +- lightllm/common/quantization/deepgemm.py | 8 +- lightllm/common/quantization/gguf.py | 172 +++++++++++ lightllm/common/quantization/no_quant.py | 8 +- .../common/quantization/quantize_method.py | 58 +++- lightllm/common/quantization/w8a8.py | 32 +- lightllm/common/quantization/w8a8gx.py | 16 +- lightllm/common/req_manager.py | 3 +- .../gguf/triton_kernel/dequantization.py | 276 ++++++++++++++++++ lightllm/server/api_cli.py | 20 +- lightllm/server/api_start.py | 36 ++- lightllm/server/build_prompt.py | 7 +- lightllm/server/detokenization/manager.py | 7 +- lightllm/server/httpserver/manager.py | 12 +- .../httpserver_for_pd_master/manager.py | 7 +- .../model_infer/mode_backend/base_backend.py | 3 +- .../impl_for_outlines_constraint_mode.py | 7 +- .../chunked_prefill/impl_for_token_healing.py | 5 +- .../chunked_prefill/impl_for_xgrammar_mode.py | 5 +- lightllm/server/tokenizer.py | 3 + lightllm/utils/config_utils.py | 255 ++++++++++++++-- lightllm/utils/llm_utils.py | 6 +- lightllm/utils/shm_size_check.py | 7 +- 27 files changed, 1114 insertions(+), 77 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/gguf_load_utils.py create mode 100644 lightllm/common/quantization/gguf.py create mode 100644 lightllm/models/gguf/triton_kernel/dequantization.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 05aaaadca8..ab7efc271c 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -16,13 +16,13 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes -from lightllm.common.build_utils import repair_config from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token +from lightllm.utils.config_utils import find_gguf_path, get_model_config_dict from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -58,6 +58,7 @@ def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] + self.gguf_path_ = find_gguf_path(self.weight_dir_) self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") @@ -103,7 +104,10 @@ def __init__(self, kvargs): self._verify_must() self._verify_params() self._init_quant() + self._align_quant_type_for_gguf_weights() + # read gguf and get quant shape + self._init_gguf() self._init_weights() self._init_req_manager() self._init_mem_manager() @@ -144,12 +148,10 @@ def _wait_other_modules_ready(self): return def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - self.config = json.load(json_file) - # rename keys - repair_config(self.config, same_names=["num_attention_heads", "n_head"]) - repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + self.config = get_model_config_dict( + config_path=self.args.config_path, + model_dir=self.weight_dir_, + ) if self.finetune_config: self.config["vocab_size"] = self.finetune_config.vocab_size return @@ -168,6 +170,53 @@ def _init_quant(self): self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") + def _align_quant_type_for_gguf_weights(self): + if self.gguf_path_ is None: + return + if self.quant_cfg.quant_type == "gguf": + return + logger.warning( + f"model_dir contains GGUF weights ({self.gguf_path_!r}) but quant_type is " + f"{self.quant_cfg.quant_type!r}; using quant_type='gguf' instead." + ) + self.quant_cfg.quant_type = "gguf" + + def _init_gguf(self): + if self.gguf_path_ is None: + return + + import numpy as np + from gguf.gguf_reader import GGUFReader + from lightllm.common.basemodel.layer_weights.gguf_load_utils import ( + DEQUANT_KEYS, + build_gguf_to_hf_mapping, + ) + from lightllm.common.quantization.quantize_method import GGUFWeightMeta, numpy_dtype_to_torch + + self._gguf_to_hf = build_gguf_to_hf_mapping(self.config) + self._gguf_reader = GGUFReader(self.gguf_path_) + + gguf_quant_meta_map = {} + for t in self._gguf_reader.tensors: + if t.name in DEQUANT_KEYS: + continue + hf_name = self._gguf_to_hf.get(t.name) + if hf_name is None: + continue + np_data = np.asarray(t.data) + logical_shape = tuple(reversed([int(x) for x in t.shape.tolist()])) + gguf_quant_meta_map[hf_name] = GGUFWeightMeta( + shape=logical_shape, + dtype=numpy_dtype_to_torch(np_data.dtype), + quant_type=t.tensor_type, + ) + self.quant_cfg.gguf_quant_meta_map = gguf_quant_meta_map + + def _release_gguf_reader(self): + if getattr(self, "_gguf_reader", None) is not None: + del self._gguf_reader + self._gguf_reader = None + def _init_weights(self, start_layer_index=0): self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ @@ -182,13 +231,28 @@ def _init_weights(self, start_layer_index=0): return def _load_hf_weights(self): - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) + if self.gguf_path_ is not None: + from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_gguf_weights + + load_gguf_weights( + self.data_type, + weight_dir=self.gguf_path_, + config=self.config, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + reader=getattr(self, "_gguf_reader", None), + gguf_to_hf=getattr(self, "_gguf_to_hf", None), + release_reader=self._release_gguf_reader, + ) + else: + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) self.pre_post_weight.verify_load() [weight.verify_load() for weight in self.trans_layers_weight] return diff --git a/lightllm/common/basemodel/layer_weights/gguf_load_utils.py b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py new file mode 100644 index 0000000000..2cf6c2afa4 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py @@ -0,0 +1,121 @@ +import gguf +import numpy as np +import torch +from gguf import dequantize +from gguf.gguf_reader import GGUFReader, ReaderTensor +from transformers import AutoModelForCausalLM +from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from typing import Any, Callable, Dict, Optional + +from lightllm.utils.dist_utils import get_current_device_id + +DEQUANT_KEYS = ["token_embd.weight", "output.weight"] + + +def build_gguf_to_hf_mapping(config: Dict[str, Any]) -> Dict[str, str]: + num_layers = config.get("num_hidden_layers") + assert num_layers is not None, "num_hidden_layers is not found in config" + model_type = config.get("model_type") + assert model_type is not None, "model_type is not found in config" + arch = None + for gguf_arch, hf_arch in gguf.MODEL_ARCH_NAMES.items(): + if hf_arch == model_type: + arch = gguf_arch + break + assert arch is not None, "model_type is not found in gguf.MODEL_ARCH_NAMES" + tensor_name_map = gguf.get_tensor_name_map(arch, num_layers) + + config_cls = CONFIG_MAPPING[model_type] + assert config_cls is not None, f"config_cls is not found in CONFIG_MAPPING for model_type={model_type}" + hf_config = config_cls(**config) + with torch.device("meta"): + dummy_model = AutoModelForCausalLM.from_config(hf_config) + gguf_to_hf_name_mapping = {} + for hf_name in dummy_model.state_dict(): + name, extension = hf_name.rsplit(".", 1) + gguf_name = tensor_name_map.get_name(name) + gguf_to_hf_name_mapping[f"{gguf_name}.{extension}"] = hf_name + + return gguf_to_hf_name_mapping + + +def _reader_tensor_to_torch_cpu(rt: ReaderTensor, dequant: bool = False) -> torch.Tensor: + assert rt.shape.ndim <= 2, "GGUF tensor must be 2D or less" + if dequant: + d_tensor = dequantize(rt.data, rt.tensor_type) + else: + d_tensor = rt.data + arr = np.array(d_tensor, copy=True) + + return torch.from_numpy(arr) + + +def _gguf_reader_to_weight_dict(reader: GGUFReader) -> Dict[str, torch.Tensor]: + gguf_weights = {} + for t in reader.tensors: + if t.name in DEQUANT_KEYS: + gguf_weights[t.name] = _reader_tensor_to_torch_cpu(t, dequant=True) + else: + gguf_weights[t.name] = _reader_tensor_to_torch_cpu(t, dequant=False) + return gguf_weights + + +def rename_weights( + weights: Dict[str, torch.Tensor], + config: Dict[str, Any], + gguf_to_hf: Optional[Dict[str, str]] = None, +) -> Dict[str, torch.Tensor]: + if gguf_to_hf is None: + gguf_to_hf = build_gguf_to_hf_mapping(config) + return {gguf_to_hf[k]: v for k, v in weights.items() if k in gguf_to_hf} + + +def load_gguf_weights( + data_type: str, + weight_dir: str, + config: Dict[str, Any], + pre_post_layer: Any = None, + transformer_layer_list: Any = None, + weight_dict: Optional[Dict[str, torch.Tensor]] = None, + reader: Optional[GGUFReader] = None, + gguf_to_hf: Optional[Dict[str, str]] = None, + release_reader: Optional[Callable[[], None]] = None, +) -> None: + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + if transformer_layer_list is not None: + assert transformer_layer_list[ + 0].data_type_ == data_type, "type is not right" + if weight_dict: + torch.cuda.set_device(get_current_device_id()) + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weight_dict) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weight_dict) + del weight_dict + if release_reader is not None: + release_reader() + return + + need_init_reader = reader is None + if need_init_reader: + reader = GGUFReader(weight_dir) + try: + weights = _gguf_reader_to_weight_dict(reader) + weights = rename_weights(weights, config=config, gguf_to_hf=gguf_to_hf) + finally: + if need_init_reader: + del reader + elif release_reader is not None: + release_reader() + + torch.cuda.set_device(get_current_device_id()) + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(weights) + del weights diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..87530e7951 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -100,7 +100,11 @@ def _create_weight(self): self.mm_param: WeightPack = None self.mm_param_list: List[WeightPack] = None self.mm_param, self.mm_param_list = self.quant_method.create_weight( - in_dim=self.in_dim, out_dims=self.out_dims, dtype=self.data_type_, device_id=get_current_device_id() + in_dim=self.in_dim, + out_dims=self.out_dims, + dtype=self.data_type_, + device_id=get_current_device_id(), + weight_names=self.weight_names, ) return diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1f08432c6a..34801955d3 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -6,6 +6,7 @@ from .deepgemm import * from .awq import * from .no_quant import * +from .gguf import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -16,6 +17,7 @@ def __init__(self, network_config, quant_type="none", custom_cfg_path=None): self.layer_num = network_config["n_layer"] self.quant_type = quant_type self.network_config_ = network_config + self.gguf_quant_meta_map = None self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) @@ -80,4 +82,5 @@ def get_quant_method(self, layer_num, name): quant_type = self.get_quant_type(layer_num, name) quant_method = QUANTMETHODS.get(quant_type) quant_method.hf_quantization_config = self.hf_quantization_config + quant_method.gguf_quant_meta_map = self.gguf_quant_meta_map return quant_method diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index f3c7623975..dd917ddc88 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -109,7 +109,13 @@ def apply( return out def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) group_size = self.hf_quantization_config["group_size"] @@ -206,7 +212,13 @@ def apply( return out def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) self.n = out_dim diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..779559db76 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -98,7 +98,13 @@ def apply( return out def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims weight_scale_out_dims = [(_out_dim + self.block_size - 1) // self.block_size for _out_dim in out_dims] diff --git a/lightllm/common/quantization/gguf.py b/lightllm/common/quantization/gguf.py new file mode 100644 index 0000000000..17037e0a03 --- /dev/null +++ b/lightllm/common/quantization/gguf.py @@ -0,0 +1,172 @@ +import torch +from gguf import GGMLQuantizationType, quant_shape_from_byte_shape +from gguf.quants import quant_shape_to_byte_shape +from typing import List, Optional, Tuple + +from .registry import QUANTMETHODS +from .quantize_method import GGUFWeightMeta, QuantizationMethod, WeightPack +from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.models.gguf.triton_kernel.dequantization import get_gguf_dequant_fn + +# These types are not quantized, so they are directly used +UNQUANTIZED_TYPES = {GGMLQuantizationType.F32, GGMLQuantizationType.F16} + + +def _linear( + input_tensor: torch.Tensor, + weight: torch.Tensor, + out: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + +@QUANTMETHODS.register("gguf", platform="cuda") +class GGUFQuantizationMethod(QuantizationMethod): + + def __init__(self): + super().__init__() + + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + raise NotImplementedError("GGUF online quantization is not supported") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # allocate output tensor if not provided + assert weight_pack.gguf_quant_type is not None, "gguf_quant_type must be set on WeightPack" + if out is None: + if weight_pack.gguf_quant_type in UNQUANTIZED_TYPES: + out_features = weight_pack.weight.shape[0] + else: + out_features, _ = quant_shape_from_byte_shape( + weight_pack.weight.shape, + weight_pack.gguf_quant_type, + ) + shape = (input_tensor.shape[0], out_features) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty(shape, dtype=input_tensor.dtype, device=input_tensor.device) + # unquantized types are directly used + if weight_pack.gguf_quant_type in UNQUANTIZED_TYPES: + weight = weight_pack.weight.t() + return _linear(input_tensor, weight, out, bias) + # quantized types are dequantized and then used + dequant_fn = get_gguf_dequant_fn(weight_pack.gguf_quant_type) + if dequant_fn is None: + raise ValueError( + f"Unsupported GGUF quantization type: {weight_pack.gguf_quant_type}" + ) + m, n = quant_shape_from_byte_shape(weight_pack.weight.shape, + weight_pack.gguf_quant_type) + alloc_func = torch.empty if not use_custom_tensor_mananger else g_cache_manager.empty + dequantized_weight = alloc_func((m, n), + dtype=input_tensor.dtype, + device=input_tensor.device) + dequant_fn( + weight_pack.weight, + m, + n, + input_tensor.dtype, + out=dequantized_weight, + ) + weight = dequantized_weight.t() + + return _linear(input_tensor, weight, out, bias) + + @property + def method_name(self): + return "gguf" + + def load_weight(self, weight: torch.Tensor, + weight_pack: WeightPack) -> None: + device = weight_pack.weight.device + weight_pack.weight.copy_(weight.contiguous().to( + device=device, dtype=weight_pack.weight.dtype)) + weight_pack.load_ok[0] = True + + def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def _create_weight( + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, + ) -> Tuple[WeightPack, List[WeightPack]]: + assert weight_names is not None and len( + weight_names) > 0, "weight_names must be provided" + if self.gguf_quant_meta_map is None: + raise ValueError( + f"Cannot load GGUF-quantized weights {weight_names!r}: no GGUF metadata was built. " + f"quant_type is 'gguf', but model_dir has no .gguf file (only HuggingFace/safetensors). " + f"Use --quant_type to no set gguf for HF checkpoints, or set model_dir to a path that contains " + f"exactly one .gguf file." + ) + + assert len(weight_names) == len( + out_dims), "weight_names and out_dims must align" + weight_dtype = None + gguf_quant_types = set() + for weight_name in weight_names: + meta: GGUFWeightMeta = self.gguf_quant_meta_map[weight_name] + gguf_quant_types.add(meta.quant_type) + quant_shape = meta.shape + assert len( + quant_shape + ) == 2, f"GGUF linear weight must be 2D, got {quant_shape} for {weight_name}" + _, in_d = quant_shape[0], quant_shape[1] + assert in_d == in_dim, ( + f"GGUF tensor {weight_name} has in_features {in_d}, layer expects in_dim {in_dim}" + ) + if weight_dtype is None: + weight_dtype = meta.dtype + else: + assert weight_dtype == meta.dtype, f"merged GGUF weights must share dtype, got {weight_dtype} vs {meta.dtype}" + assert len( + gguf_quant_types + ) == 1, f"merged GGUF weights must share quant_type, got {gguf_quant_types}" + gguf_quant_type = gguf_quant_types.pop() + + if gguf_quant_type not in UNQUANTIZED_TYPES: + if get_gguf_dequant_fn(gguf_quant_type) is None: + raise ValueError( + f"No CUDA dequant registered for GGUF type {gguf_quant_type!r}; " + f"add @register_gguf_dequant in " + f"lightllm/models/gguf/triton_kernel/dequantization.py" + ) + + # Buffer sizes follow layer/tp shard dims from the caller (load_path slices file weights into this storage). + logical_shape_rowmajor = (sum(out_dims), in_dim) + expert_prefix = (num_experts, ) if num_experts > 1 else () + if gguf_quant_type in UNQUANTIZED_TYPES: + full_shape = expert_prefix + logical_shape_rowmajor + else: + full_shape = expert_prefix + quant_shape_to_byte_shape( + logical_shape_rowmajor, gguf_quant_type) + weight = torch.empty(full_shape, dtype=weight_dtype).cuda(device_id) + mm_param = WeightPack( + weight=weight, + weight_scale=None, + weight_zero_point=None, + gguf_quant_type=gguf_quant_type, + ) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + ) + + return mm_param, mm_param_list diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index fa926ad6f0..2ad5a29d74 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -35,7 +35,13 @@ def apply( return torch.addmm(bias, input_tensor, weight, out=out) def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 95d8d806f9..ef244f770f 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,8 +1,21 @@ +import numpy as np import torch from abc import ABC, abstractmethod from dataclasses import dataclass +from gguf import GGMLQuantizationType from lightllm.utils.dist_utils import get_current_device_id -from typing import Optional, List, Tuple +from typing import Dict, Optional, List, Tuple + + +@dataclass(frozen=True) +class GGUFWeightMeta: + shape: Tuple[int, ...] + dtype: torch.dtype + quant_type: GGMLQuantizationType + + +def numpy_dtype_to_torch(np_dtype: np.dtype) -> torch.dtype: + return torch.from_numpy(np.zeros((), dtype=np_dtype)).dtype @dataclass @@ -10,6 +23,7 @@ class WeightPack: weight: Optional[torch.Tensor] = None weight_scale: Optional[torch.Tensor] = None weight_zero_point: Optional[torch.Tensor] = None + gguf_quant_type: Optional[GGMLQuantizationType] = None def __post_init__(self): self.load_ok = [False, self.weight_scale is None, self.weight_zero_point is None] @@ -19,7 +33,12 @@ def get_expert(self, expert_idx: int): weight = self.weight[expert_idx] weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None - return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + return WeightPack( + weight=weight, + weight_scale=weight_scale, + weight_zero_point=weight_zero_point, + gguf_quant_type=self.gguf_quant_type, + ) class QuantizationMethod(ABC): @@ -37,6 +56,8 @@ def __init__(self): # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None + # GGUF hf_name -> gguf storage meta info + self.gguf_quant_meta_map: Optional[Dict[str, GGUFWeightMeta]] = None @abstractmethod def quantize( @@ -64,17 +85,30 @@ def method_name(self): pass def create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: return self._create_weight( out_dims=out_dims, in_dim=in_dim, dtype=dtype, device_id=device_id, + num_experts=1, + weight_names=weight_names, ) def create_moe_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: return self._create_weight( out_dims=out_dims, @@ -82,6 +116,7 @@ def create_moe_weight( dtype=dtype, device_id=device_id, num_experts=num_experts, + weight_names=weight_names, ) def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: @@ -112,7 +147,13 @@ def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[str] = None, ) -> Tuple[WeightPack, List[WeightPack]]: pass @@ -144,6 +185,11 @@ def _split_weight_pack( ) for weight, weight_scale, weight_zero_point in zip(weight, weight_scale, weight_zero_point): mm_param_list.append( - WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + WeightPack( + weight=weight, + weight_scale=weight_scale, + weight_zero_point=weight_zero_point, + gguf_quant_type=weight_pack.gguf_quant_type, + ) ) return mm_param_list diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 65ec6cd145..0cf5c3bb22 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -49,7 +49,13 @@ def method_name(self): return "w8a8-base" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: raise NotImplementedError("Not implemented") @@ -99,7 +105,13 @@ def method_name(self): return "vllm-w8a8" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () @@ -163,7 +175,13 @@ def method_name(self): return "vllm-fp8w8a8" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () @@ -241,7 +259,13 @@ def method_name(self): return "vllm-fp8w8a8-b128" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/quantization/w8a8gx.py b/lightllm/common/quantization/w8a8gx.py index c25136697d..704a82c197 100644 --- a/lightllm/common/quantization/w8a8gx.py +++ b/lightllm/common/quantization/w8a8gx.py @@ -32,7 +32,13 @@ def method_name(self): return "w8a8gx-base" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: raise NotImplementedError("Not implemented") @@ -119,7 +125,13 @@ def method_name(self): return f"triton-fp8w8a8g{self.act_quant_group_size}" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..ab5967866d 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -112,7 +112,8 @@ class ReqSamplingParamsManager: def __init__(self, max_request_num): # mode ["cpu_counter", "pin_mem_counter", "gpu_counter"] self.penalty_counter_mode = get_env_start_args().penalty_counter_mode - self.vocab_size = get_vocab_size(get_env_start_args().model_dir) + start_args = get_env_start_args() + self.vocab_size = get_vocab_size(config_path=start_args.config_path, model_dir=start_args.model_dir) self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") diff --git a/lightllm/models/gguf/triton_kernel/dequantization.py b/lightllm/models/gguf/triton_kernel/dequantization.py new file mode 100644 index 0000000000..75b1502b05 --- /dev/null +++ b/lightllm/models/gguf/triton_kernel/dequantization.py @@ -0,0 +1,276 @@ +""" +ref: https://github.com/ggml-org/llama.cpp/discussions/17393 +""" +import torch +import triton +import triton.language as tl +from gguf import GGMLQuantizationType +from typing import Callable, Optional + +_GGUF_DEQUANT_REGISTRY: dict[GGMLQuantizationType, Callable[..., torch.Tensor]] = {} + + +def register_gguf_dequant(quant_type: GGMLQuantizationType): + + def _wrap(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + if quant_type in _GGUF_DEQUANT_REGISTRY: + raise ValueError( + f"Duplicate GGUF dequant registration for {quant_type}: " + f"{_GGUF_DEQUANT_REGISTRY[quant_type]!r} vs {fn!r}") + _GGUF_DEQUANT_REGISTRY[quant_type] = fn + return fn + + return _wrap + + +def get_gguf_dequant_fn( + quant_type: GGMLQuantizationType, +) -> Optional[Callable[..., torch.Tensor]]: + return _GGUF_DEQUANT_REGISTRY.get(quant_type) + + +QK5_0 = 32 +BLOCK_Q5_0_BYTES = 22 +""" +# Each block is represented by 22 consecutive values in the final ndarray. +typedef struct { + ggml_half d; // scale, total 16 bits + uint8_t qh[4]; // all 32 high bits of quants, total 8*4=32 bits + uint8_t qs[QK5_0 / 2]; // all 32 low bits of quants, total 8*(32/2)=128 bits +} block_q5_0; // total 172 bits = 22 bytes +""" + + +@triton.jit +def dequantize_q5_0_kernel( + w_ptr, + out_ptr, + k, + QK: tl.constexpr, + BLOCK_BYTES: tl.constexpr, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + block_idx = offs // QK + pos_in_block = offs % QK + + half_k = QK // 2 + first_half = pos_in_block < half_k + # [0, 15] + qs_low_bits_idx = tl.where(first_half, pos_in_block, pos_in_block - half_k) + + block_base = block_idx * BLOCK_BYTES + w_f16 = tl.cast(w_ptr, tl.pointer_type(tl.float16)) + d = tl.load(w_f16 + block_base // 2, mask=mask, other=0).to(tl.float32) + + qh = (tl.load(w_ptr + block_base + 2, mask=mask, other=0).to(tl.uint32) + | (tl.load(w_ptr + block_base + 3, mask=mask, other=0).to(tl.uint32) + << 8) + | (tl.load(w_ptr + block_base + 4, mask=mask, other=0).to(tl.uint32) + << 16) + | (tl.load(w_ptr + block_base + 5, mask=mask, other=0).to(tl.uint32) + << 24)) + + qsb = tl.load(w_ptr + block_base + 6 + qs_low_bits_idx, mask=mask, + other=0).to(tl.int32) + nib_lo = qsb & 0xF + nib_hi = (qsb >> 4) & 0xF + + qs_low_bits_idx_u32 = qs_low_bits_idx.to(tl.uint32) + xh_0 = ((qh >> (qs_low_bits_idx_u32 + 0)) << 4) & 0x10 + xh_1 = (qh >> (qs_low_bits_idx_u32 + 12)) & 0x10 + + val = tl.where(first_half, nib_lo | xh_0, nib_hi | xh_1) + y = (val.to(tl.float32) - 16.0) * d + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.Q5_0) +def dequantize_q5_0( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + if k % QK5_0 != 0: + raise ValueError(f"element count {k} must be a multiple of {QK5_0}") + + nb = k // QK5_0 + expected_bytes = nb * BLOCK_Q5_0_BYTES + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_q5_0_kernel[grid]( + weight_uint8, + out.view(-1), + k, + QK=QK5_0, + BLOCK_BYTES=BLOCK_Q5_0_BYTES, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + ) + return out + + +QK8_0 = 32 +BLOCK_Q8_0_BYTES = 34 +""" +# Each block is represented by 34 consecutive values in the final ndarray. +typedef struct { + ggml_half d; // scale, total 16 bits + int8_t qs[QK8_0]; // all 32 bits of quants, total 8*32=256 bits +} block_q8_0; // total 272 bits = 34 bytes +""" + + +@triton.jit +def dequantize_q8_0_kernel( + w_ptr, + out_ptr, + k, + QK: tl.constexpr, + BLOCK_BYTES: tl.constexpr, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + block_idx = offs // QK + pos_in_block = offs % QK + + block_base = block_idx * BLOCK_BYTES + w_f16 = tl.cast(w_ptr, tl.pointer_type(tl.float16)) + d = tl.load(w_f16 + block_base // 2, mask=mask, other=0).to(tl.float32) + + qb = tl.load(w_ptr + block_base + 2 + pos_in_block, mask=mask, + other=0).to(tl.int32) + qs = tl.where(qb > 127, qb - 256, qb) + y = qs.to(tl.float32) * d + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.Q8_0) +def dequantize_q8_0( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + if k % QK8_0 != 0: + raise ValueError(f"element count {k} must be a multiple of {QK8_0}") + + nb = k // QK8_0 + expected_bytes = nb * BLOCK_Q8_0_BYTES + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_q8_0_kernel[grid]( + weight_uint8, + out.view(-1), + k, + QK=QK8_0, + BLOCK_BYTES=BLOCK_Q8_0_BYTES, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + ) + return out + + +""" +Two uint8 values form a uint16 value, which is then converted to bfloat16. +""" + + +@triton.jit +def dequantize_bf16_kernel( + w_ptr, + out_ptr, + k, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + w_bf16 = tl.cast(w_ptr, tl.pointer_type(tl.bfloat16)) + y = tl.load(w_bf16 + offs, mask=mask, other=0).to(tl.float32) + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.BF16) +def dequantize_bf16( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + expected_bytes = k * 2 + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_bf16_kernel[grid]( + weight_uint8, + out.view(-1), + k, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + ) + return out diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index f33f58b86d..837b09aa8d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -121,6 +121,19 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="the model weight dir path, the app will load config, weights and tokenizer from this dir", ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help="the config path (JSON). If not set, use model_dir/config.json when present; " + "otherwise derive config from GGUF metadata when model_dir is a .gguf file or contains one", + ) + parser.add_argument( + "--tokenizer_dir", + type=str, + default=None, + help="tokenizer directory; required for GGUF models, otherwise defaults to model_dir", + ) parser.add_argument( "--tokenizer_mode", type=str, @@ -589,9 +602,10 @@ def make_argument_parser() -> argparse.ArgumentParser: "--quant_type", type=str, default="none", - help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 - | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin | - | triton-fp8w8a8g128 (weight perchannel quant and act per group quant) | + help="""Quantization method: none | gguf (requires .gguf weights under model_dir) | + vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 | deepgemm-fp8w8a8-b128 | + triton-fp8w8a8-block128 | awq | awq_marlin | + triton-fp8w8a8g128 (weight perchannel quant and act per group quant) | triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""", ) parser.add_argument( diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8c6af128c8..90635034de 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -18,6 +18,9 @@ from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import ( + align_quant_type_for_gguf_model, + check_gguf_quant_model_dir, + check_gguf_tokenizer_dir, has_audio_module, has_vision_module, is_linear_att_mixed_model, @@ -142,6 +145,15 @@ def normal_or_p_d_start(args): if not args.disable_shm_warning: check_recommended_shm_size(args) + check_gguf_tokenizer_dir(args.model_dir, args.tokenizer_dir) + aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) + if aligned_quant_type != args.quant_type: + logger.warning( + f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" + ) + args.quant_type = aligned_quant_type + check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) + assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 if args.zmq_mode == "ipc:///tmp/": @@ -311,7 +323,9 @@ def normal_or_p_d_start(args): if args.tool_call_parser is None: from lightllm.utils.config_utils import get_tool_call_parser_for_model - args.tool_call_parser = get_tool_call_parser_for_model(args.model_dir) + args.tool_call_parser = get_tool_call_parser_for_model( + config_path=args.config_path, model_dir=args.model_dir + ) if args.tool_call_parser: logger.info(f"Auto set tool_call_parser to {args.tool_call_parser} based on model type") @@ -319,7 +333,9 @@ def normal_or_p_d_start(args): if args.reasoning_parser is None: from lightllm.utils.config_utils import get_reasoning_parser_for_model - args.reasoning_parser = get_reasoning_parser_for_model(args.model_dir) + args.reasoning_parser = get_reasoning_parser_for_model( + config_path=args.config_path, model_dir=args.model_dir + ) if args.reasoning_parser: logger.info(f"Auto set reasoning_parser to {args.reasoning_parser} based on model type") @@ -552,6 +568,22 @@ def pd_master_start(args): set_env_start_args(args) + check_gguf_tokenizer_dir(args.model_dir, args.tokenizer_dir) + aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) + if aligned_quant_type != args.quant_type: + logger.warning( + f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" + ) + args.quant_type = aligned_quant_type + check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) + aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) + if aligned_quant_type != args.quant_type: + logger.warning( + f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" + ) + args.quant_type = aligned_quant_type + check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) + process_manager.start_submodule_processes( start_funcs=[ start_metric_manager, diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 84044fccce..04e0c12c31 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -11,7 +11,12 @@ def init_tokenizer(args): global tokenizer - tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + tokenizer_name=args.model_dir, + tokenizer_dir=args.tokenizer_dir, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..b5b512d31e 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -34,7 +34,12 @@ def __init__( self.pub_to_httpserver = context.socket(zmq.PUB) self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.tokenizer = get_tokenizer( + tokenizer_name=args.model_dir, + tokenizer_dir=args.tokenizer_dir, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} self.eos_id = args.eos_id diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4c049f77c0..2b0a827d14 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -102,7 +102,12 @@ def __init__( self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.tokenizer = get_tokenizer( + tokenizer_name=args.model_dir, + tokenizer_dir=args.tokenizer_dir, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) self.forwarding_queue: AsyncQueue = None # p d 分离模式使用的转发队列, 需要延迟初始化 @@ -116,7 +121,10 @@ def __init__( self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() # 有的模型的vocab size 读取tokenizer和config.json中不一致 - self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) + self.vocab_size = max( + get_vocab_size(config_path=args.config_path, model_dir=args.model_dir), + self.tokenizer.vocab_size, + ) # The timemark of the latest inference(prefill/decode) which is used to check the health status of the system. # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index af7a1e29fa..afa198bb45 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -42,7 +42,12 @@ def __init__( self.req_id_to_out_inf: Dict[int, ReqStatus] = {} self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.tokenizer = get_tokenizer( + args.model_dir, + args.tokenizer_dir, + args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f0..0254cec1d7 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -19,6 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify +from lightllm.utils.config_utils import get_model_config_dict from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -150,7 +151,7 @@ def init_model(self, kvargs): if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() - model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) + model_cfg = get_model_config_dict(config_path=self.args.config_path, model_dir=self.weight_dir) model_kvargs = { "weight_dir": self.weight_dir, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index 9be5fcd1f5..8f4745d487 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -35,7 +35,12 @@ def init_custom(self): from outlines.models.transformers import TransformerTokenizer self.tokenizer = TransformerTokenizer( - get_tokenizer(self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code) + get_tokenizer( + self.args.model_dir, + self.args.tokenizer_dir, + self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, + ) ) eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index d8de132cac..26ab259863 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -21,7 +21,10 @@ def init_custom(self): 初始化tokenizer 词表相关的的操作 """ self.tokenizer = get_tokenizer( - self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code + self.args.model_dir, + self.args.tokenizer_dir, + self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, ) vob_dict = self.tokenizer.get_vocab() self.token_to_token_id = vob_dict diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index b159b98f25..2bdae895f7 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -23,7 +23,10 @@ def init_custom(self): import xgrammar as xgr self.tokenizer = get_tokenizer( - self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code + self.args.model_dir, + self.args.tokenizer_dir, + self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, ) self.tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..7d17db4fbc 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -21,6 +21,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import convert_slow_tokenizer from transformers.configuration_utils import PretrainedConfig +from lightllm.utils.config_utils import resolve_tokenizer_dir from lightllm.utils.log_utils import init_logger from ..models.tarsier2.model import Tarsier2Tokenizer @@ -39,12 +40,14 @@ def get_tokenizer( tokenizer_name: str, + tokenizer_dir: str = None, tokenizer_mode: str = "auto", trust_remote_code: bool = False, *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: """Gets a tokenizer for the given model name via Huggingface.""" + tokenizer_name = resolve_tokenizer_dir(tokenizer_name, tokenizer_dir) if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c64e8a912b..ee7e87543b 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -1,17 +1,218 @@ import json import os -from typing import Optional, List from functools import lru_cache +from typing import List, Optional + from .envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -def get_config_json(model_path: str): - with open(os.path.join(model_path, "config.json"), "r") as file: - json_obj = json.load(file) - return json_obj +def find_gguf_path(model_dir: Optional[str]) -> Optional[str]: + if not model_dir: + return None + if model_dir.endswith(".gguf") and os.path.isfile(model_dir): + return model_dir + if os.path.isdir(model_dir): + gguf_files = sorted( + os.path.join(model_dir, name) for name in os.listdir(model_dir) if name.endswith(".gguf") + ) + if not gguf_files: + return None + if len(gguf_files) > 1: + raise ValueError( + f"multiple GGUF files found in {model_dir} is not supported, please specify the target gguf file." + ) + return gguf_files[0] + return None + + +def is_gguf_model_path(model_dir: Optional[str]) -> bool: + return find_gguf_path(model_dir) is not None + + +def check_gguf_tokenizer_dir(model_dir: Optional[str], tokenizer_dir: Optional[str]) -> None: + if not is_gguf_model_path(model_dir): + return + if not tokenizer_dir: + raise ValueError( + f"GGUF model requires --tokenizer_dir (model_dir={model_dir!r}). " + "Provide a HuggingFace tokenizer directory separate from the .gguf weights." + ) + if not os.path.isdir(tokenizer_dir): + raise FileNotFoundError(f"tokenizer_dir is not a directory: {tokenizer_dir!r}") + for name in ("tokenizer.json", "tokenizer_config.json", "vocab.json"): + if os.path.isfile(os.path.join(tokenizer_dir, name)): + return + raise FileNotFoundError( + f"tokenizer_dir missing tokenizer files (tokenizer.json / tokenizer_config.json / vocab.json): " + f"{tokenizer_dir!r}" + ) + + +def resolve_tokenizer_dir(model_dir: Optional[str], tokenizer_dir: Optional[str]) -> str: + check_gguf_tokenizer_dir(model_dir, tokenizer_dir) + if is_gguf_model_path(model_dir): + return tokenizer_dir + return tokenizer_dir or model_dir + + +def uses_gguf_quant_type(quant_type: str, quant_cfg_path: Optional[str] = None) -> bool: + if quant_type == "gguf": + return True + if quant_cfg_path is None: + return False + import yaml + + with open(quant_cfg_path, "r") as file: + data = yaml.safe_load(file) or {} + if data.get("quant_type") == "gguf": + return True + for layer_quant_cfg in data.get("mix_bits", []) or []: + if layer_quant_cfg.get("quant_type") == "gguf": + return True + return False + + +def get_gguf_quant_conflicts(quant_type: str, quant_cfg_path: Optional[str] = None) -> List[str]: + """Non-gguf quant types that cannot be auto-overridden on GGUF weights ('none' is allowed).""" + conflicts = [] + if quant_type not in ("gguf", "none"): + conflicts.append(quant_type) + if quant_cfg_path is None: + return sorted(set(conflicts)) + import yaml + + with open(quant_cfg_path, "r") as file: + data = yaml.safe_load(file) or {} + cfg_quant_type = data.get("quant_type") + if cfg_quant_type is not None and cfg_quant_type not in ("gguf", "none"): + conflicts.append(cfg_quant_type) + for layer_quant_cfg in data.get("mix_bits", []) or []: + layer_quant_type = layer_quant_cfg.get("quant_type") + if layer_quant_type is not None and layer_quant_type != "gguf": + conflicts.append(layer_quant_type) + return sorted(set(conflicts)) + + +def align_quant_type_for_gguf_model( + model_dir: Optional[str], + quant_type: str, + quant_cfg_path: Optional[str] = None, +) -> str: + """GGUF weights only support gguf quantization; auto-align CLI default 'none' to 'gguf'.""" + if find_gguf_path(model_dir) is None: + return quant_type + conflicts = get_gguf_quant_conflicts(quant_type, quant_cfg_path) + if conflicts: + raise ValueError( + f"model_dir contains GGUF weights but quantization is configured as {conflicts!r}. " + "GGUF checkpoints only support --quant_type gguf. " + "Use a HuggingFace safetensors directory for awq/fp8/none, or remove non-gguf entries from --quant_cfg mix_bits." + ) + if quant_type == "gguf": + return quant_type + return "gguf" + + +def check_gguf_quant_model_dir( + model_dir: Optional[str], + quant_type: str, + quant_cfg_path: Optional[str] = None, +) -> None: + if not uses_gguf_quant_type(quant_type, quant_cfg_path): + return + if find_gguf_path(model_dir) is not None: + return + raise ValueError( + f"--quant_type gguf requires a .gguf weights file, but none found under model_dir={model_dir!r}. " + "Point model_dir to a .gguf file or a directory with exactly one .gguf file. " + "For HuggingFace safetensors checkpoints, use --quant_type none (or awq / fp8, etc.), not gguf." + ) + + +def load_model_config(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> dict: + # load config from config_path + if config_path is not None: + if not os.path.isfile(config_path): + raise FileNotFoundError(f"config file not found: {config_path}") + with open(config_path, "r") as file: + return json.load(file) + + # load config from model_dir/config.json + if model_dir and not model_dir.endswith(".gguf"): + default_json = os.path.join(model_dir, "config.json") + if os.path.isfile(default_json): + with open(default_json, "r") as file: + return json.load(file) + + # load config from GGUF metadata + gguf_path = find_gguf_path(model_dir) + if gguf_path is not None: + try: + from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint + except ImportError as e: + raise ImportError( + "Loading config from GGUF requires transformers with GGUF support and the gguf package." + ) from e + config = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] + logger.info(f"loaded model config from GGUF metadata: {gguf_path}") + + return config + + raise FileNotFoundError( + f"no model config found (config_path={config_path!r}, model_dir={model_dir!r}). " + "Provide --config_path, place config.json under model_dir, or use a .gguf model path." + ) + + +def normalize_model_config(config: dict) -> dict: + from lightllm.common.build_utils import repair_config + + repair_config(config, same_names=["num_attention_heads", "n_head"]) + repair_config(config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(config, same_names=["num_hidden_layers", "n_layer"]) + if config.get("head_dim") is None: + hidden_size = config.get("hidden_size") or config.get("n_embd") or config.get("n_embed") + num_heads = config.get("num_attention_heads") or config.get("n_head") + if hidden_size and num_heads: + config["head_dim"] = hidden_size // num_heads + return config + + +@lru_cache(maxsize=None) +def get_model_config_dict(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> dict: + return normalize_model_config(load_model_config(config_path=config_path, model_dir=model_dir)) + + +def _get_config_from_model_path(model_path: str) -> dict: + if os.path.isfile(model_path) and not model_path.endswith(".gguf"): + return get_model_config_dict(config_path=model_path, model_dir=None) + return get_model_config_dict(model_dir=model_path) + + +def resolve_model_config( + config_path: Optional[str] = None, + model_dir: Optional[str] = None, +) -> dict: + if config_path is not None or model_dir is not None: + return get_model_config_dict(config_path=config_path, model_dir=model_dir) + raise ValueError("resolve_model_config requires config_path or model_dir") + + +@lru_cache(maxsize=1) +def get_start_args_model_config() -> dict: + start_args = get_env_start_args() + return get_model_config_dict(config_path=start_args.config_path, model_dir=start_args.model_dir) + + +def get_config_json(model_path: str) -> dict: + logger.warning( + "get_config_json(model_path) is deprecated; " + "use get_model_config_dict(config_path=..., model_dir=...) instead." + ) + return _get_config_from_model_path(model_path) def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int]: @@ -254,9 +455,9 @@ def get_model_architectures(model_path: str): return "unknown_architecture" -def get_vocab_size(model_path: str): +def get_vocab_size(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> int: try: - config_json = get_config_json(model_path) + config_json = resolve_model_config(config_path=config_path, model_dir=model_dir) # qwen3-omini special if "thinker_config" in config_json: config_json = config_json["thinker_config"] @@ -286,8 +487,7 @@ def get_dtype(model_path: str): @lru_cache(maxsize=None) def get_fixed_kv_len(): - start_args = get_env_start_args() - model_cfg = get_config_json(start_args.model_dir) + model_cfg = get_start_args_model_config() if "prompt_cache_token_ids" in model_cfg: return len(model_cfg["prompt_cache_token_ids"]) else: @@ -297,9 +497,7 @@ def get_fixed_kv_len(): @lru_cache(maxsize=None) def has_vision_module(model_path: str) -> bool: try: - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config_dict(model_dir=model_path) model_type = model_cfg["model_type"] if model_type == "qwen": # QWenVisionTransformer @@ -345,9 +543,7 @@ def has_vision_module(model_path: str) -> bool: @lru_cache(maxsize=None) def has_audio_module(model_path: str) -> bool: try: - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config_dict(model_dir=model_path) if model_cfg.get("thinker_config") is not None: model_cfg = model_cfg["thinker_config"] audio_config = model_cfg["audio_config"] @@ -368,9 +564,7 @@ def has_audio_module(model_path: str) -> bool: @lru_cache(maxsize=None) def is_linear_att_mixed_model(model_path: str) -> bool: try: - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config_dict(model_dir=model_path) model_type = model_cfg["model_type"] if model_type in ["qwen3_5", "qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text"]: return True @@ -381,20 +575,26 @@ def is_linear_att_mixed_model(model_path: str) -> bool: return False -def get_model_type(model_path: str) -> Optional[str]: - """Get model type from config.json""" +def get_model_type( + config_path: Optional[str] = None, + model_dir: Optional[str] = None, +) -> Optional[str]: + """Get model type from model config.""" try: - config_json = get_config_json(model_path) + config_json = resolve_model_config(config_path=config_path, model_dir=model_dir) model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") return model_type except Exception as e: - logger.error(f"Failed to get model_type from {model_path}: {e}") + logger.error(f"Failed to get model_type (config_path={config_path!r}, model_dir={model_dir!r}): {e}") return None -def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: +def get_tool_call_parser_for_model( + config_path: Optional[str] = None, + model_dir: Optional[str] = None, +) -> Optional[str]: """Auto-detect tool_call_parser based on model type""" - model_type = get_model_type(model_path) + model_type = get_model_type(config_path=config_path, model_dir=model_dir) if model_type is None: return None @@ -421,9 +621,12 @@ def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: return None -def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: +def get_reasoning_parser_for_model( + config_path: Optional[str] = None, + model_dir: Optional[str] = None, +) -> Optional[str]: """Auto-detect reasoning_parser based on model type""" - model_type = get_model_type(model_path) + model_type = get_model_type(config_path=config_path, model_dir=model_dir) if model_type is None: return None diff --git a/lightllm/utils/llm_utils.py b/lightllm/utils/llm_utils.py index ced75615d6..298352e72d 100644 --- a/lightllm/utils/llm_utils.py +++ b/lightllm/utils/llm_utils.py @@ -1,5 +1,4 @@ from functools import lru_cache -from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger @@ -8,10 +7,9 @@ @lru_cache(maxsize=None) def get_llm_model_class(): - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(get_env_start_args().model_dir) from lightllm.models import get_model_class + from lightllm.utils.config_utils import get_start_args_model_config + model_cfg = get_start_args_model_config() model_class = get_model_class(model_cfg=model_cfg) return model_class diff --git a/lightllm/utils/shm_size_check.py b/lightllm/utils/shm_size_check.py index 2f99a855e7..dcf9aaadc7 100644 --- a/lightllm/utils/shm_size_check.py +++ b/lightllm/utils/shm_size_check.py @@ -86,7 +86,12 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_ """ 获取所需的 /dev/shm 大小(以GB为单位)。 """ - tokenizer = get_tokenizer(args.model_dir, trust_remote_code=True) + tokenizer = get_tokenizer( + args.model_dir, + args.tokenizer_dir, + args.tokenizer_mode, + trust_remote_code=True, + ) # 估算input_token和logprob占用shm大小,由于是double和int64,所以固定占用8个字节 input_token_logprob_size_bytes = args.running_max_req_size * 8 * 2 * args.max_req_total_len From 9752f96ce0e4ae934e0f8594771f223ac624e0f0 Mon Sep 17 00:00:00 2001 From: zhangtaoshan Date: Wed, 20 May 2026 16:00:15 +0800 Subject: [PATCH 2/5] misc --- lightllm/models/gguf/triton_kernel/dequantization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lightllm/models/gguf/triton_kernel/dequantization.py b/lightllm/models/gguf/triton_kernel/dequantization.py index 75b1502b05..40387f2006 100644 --- a/lightllm/models/gguf/triton_kernel/dequantization.py +++ b/lightllm/models/gguf/triton_kernel/dequantization.py @@ -32,11 +32,11 @@ def get_gguf_dequant_fn( QK5_0 = 32 BLOCK_Q5_0_BYTES = 22 """ -# Each block is represented by 22 consecutive values in the final ndarray. +# Each block is represented by 22 consecutive values in the final ndarray, will be dequantized into 32 floats. typedef struct { ggml_half d; // scale, total 16 bits - uint8_t qh[4]; // all 32 high bits of quants, total 8*4=32 bits - uint8_t qs[QK5_0 / 2]; // all 32 low bits of quants, total 8*(32/2)=128 bits + uint8_t qh[4]; // each 1 bit high bits of quants, total 8*4=32 bits + uint8_t qs[QK5_0 / 2]; // each 4 bits low bits of quants, total 4*2*(32/2)=128 bits } block_q5_0; // total 172 bits = 22 bytes """ @@ -136,10 +136,10 @@ def dequantize_q5_0( QK8_0 = 32 BLOCK_Q8_0_BYTES = 34 """ -# Each block is represented by 34 consecutive values in the final ndarray. +# Each block is represented by 34 consecutive values in the final ndarray, will be dequantized into 32 floats. typedef struct { ggml_half d; // scale, total 16 bits - int8_t qs[QK8_0]; // all 32 bits of quants, total 8*32=256 bits + int8_t qs[QK8_0]; // each 8 bits of quants, total 8*32=256 bits } block_q8_0; // total 272 bits = 34 bytes """ From bc5eb9ed18bf170964ea130b61e3d4ae162b517b Mon Sep 17 00:00:00 2001 From: zhangtaoshan Date: Mon, 15 Jun 2026 13:01:00 +0800 Subject: [PATCH 3/5] Add 'ModelPaths' as a path manager --- lightllm/common/basemodel/basemodel.py | 74 +-- .../layer_weights/gguf_load_utils.py | 535 +++++++++++++++--- .../fused_moe/fused_moe_weight.py | 24 +- .../meta_weights/mm_weight/mm_slicer.py | 10 +- .../meta_weights/mm_weight/mm_weight.py | 11 +- .../meta_weights/mm_weight/rowmm_weight.py | 2 + lightllm/common/gguf_kernel/__init__.py | 0 .../gguf_kernel}/dequantization.py | 0 lightllm/common/quantization/gguf.py | 106 ++-- .../common/quantization/quantize_method.py | 8 +- lightllm/common/req_manager.py | 5 +- lightllm/models/gemma3/model.py | 2 + lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 31 +- lightllm/models/qwen2_vl/model.py | 4 +- lightllm/models/qwen2_vl/qwen2_visual.py | 9 +- lightllm/models/qwen2_vl/vision_process.py | 13 + .../qwen3_omni_visual.py | 9 +- lightllm/models/qwen3_vl/qwen3_visual.py | 9 +- lightllm/models/tarsier2/tarsier2_visual.py | 7 +- lightllm/server/api_cli.py | 8 +- lightllm/server/api_openai.py | 5 + lightllm/server/api_start.py | 82 ++- lightllm/server/build_prompt.py | 5 +- lightllm/server/core/objs/start_args_type.py | 3 + lightllm/server/detokenization/manager.py | 4 +- lightllm/server/httpserver/manager.py | 9 +- .../httpserver_for_pd_master/manager.py | 6 +- .../model_infer/mode_backend/base_backend.py | 4 +- .../impl_for_outlines_constraint_mode.py | 6 +- .../chunked_prefill/impl_for_token_healing.py | 6 +- .../chunked_prefill/impl_for_xgrammar_mode.py | 6 +- lightllm/server/tokenizer.py | 160 ++++-- lightllm/server/visualserver/manager.py | 6 +- .../visualserver/model_infer/model_rpc.py | 6 +- .../visualserver/visual_only_manager.py | 6 +- lightllm/utils/config_utils.py | 484 ++++++++-------- lightllm/utils/gguf_tokenizer_utils.py | 207 +++++++ lightllm/utils/shm_size_check.py | 12 +- 38 files changed, 1295 insertions(+), 589 deletions(-) create mode 100644 lightllm/common/gguf_kernel/__init__.py rename lightllm/{models/gguf/triton_kernel => common/gguf_kernel}/dequantization.py (100%) create mode 100644 lightllm/utils/gguf_tokenizer_utils.py diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ab7efc271c..8d2db60c99 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -22,7 +22,11 @@ from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token -from lightllm.utils.config_utils import find_gguf_path, get_model_config_dict +from lightllm.utils.config_utils import ( + apply_gguf_quant_type, + _create_model_paths, + get_model_config, +) from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -58,7 +62,7 @@ def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] - self.gguf_path_ = find_gguf_path(self.weight_dir_) + self.model_paths_ = _create_model_paths(self.weight_dir_) self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") @@ -148,10 +152,7 @@ def _wait_other_modules_ready(self): return def _init_config(self): - self.config = get_model_config_dict( - config_path=self.args.config_path, - model_dir=self.weight_dir_, - ) + self.config = get_model_config(self.model_paths_) if self.finetune_config: self.config["vocab_size"] = self.finetune_config.vocab_size return @@ -171,51 +172,20 @@ def _init_quant(self): logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") def _align_quant_type_for_gguf_weights(self): - if self.gguf_path_ is None: - return - if self.quant_cfg.quant_type == "gguf": + if self.model_paths_.gguf_path is None: return - logger.warning( - f"model_dir contains GGUF weights ({self.gguf_path_!r}) but quant_type is " - f"{self.quant_cfg.quant_type!r}; using quant_type='gguf' instead." - ) - self.quant_cfg.quant_type = "gguf" + aligned = apply_gguf_quant_type(self.model_paths_, self.quant_cfg.quant_type) + if aligned != self.quant_cfg.quant_type: + self.quant_cfg.quant_type = aligned def _init_gguf(self): - if self.gguf_path_ is None: + if self.model_paths_.gguf_path is None: return - import numpy as np - from gguf.gguf_reader import GGUFReader - from lightllm.common.basemodel.layer_weights.gguf_load_utils import ( - DEQUANT_KEYS, - build_gguf_to_hf_mapping, - ) - from lightllm.common.quantization.quantize_method import GGUFWeightMeta, numpy_dtype_to_torch - - self._gguf_to_hf = build_gguf_to_hf_mapping(self.config) - self._gguf_reader = GGUFReader(self.gguf_path_) - - gguf_quant_meta_map = {} - for t in self._gguf_reader.tensors: - if t.name in DEQUANT_KEYS: - continue - hf_name = self._gguf_to_hf.get(t.name) - if hf_name is None: - continue - np_data = np.asarray(t.data) - logical_shape = tuple(reversed([int(x) for x in t.shape.tolist()])) - gguf_quant_meta_map[hf_name] = GGUFWeightMeta( - shape=logical_shape, - dtype=numpy_dtype_to_torch(np_data.dtype), - quant_type=t.tensor_type, - ) - self.quant_cfg.gguf_quant_meta_map = gguf_quant_meta_map + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader - def _release_gguf_reader(self): - if getattr(self, "_gguf_reader", None) is not None: - del self._gguf_reader - self._gguf_reader = None + reader = get_gguf_reader(self.model_paths_.gguf_path) + self.quant_cfg.gguf_quant_meta_map = reader.build_quant_meta_map(self.config) def _init_weights(self, start_layer_index=0): self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) @@ -231,19 +201,15 @@ def _init_weights(self, start_layer_index=0): return def _load_hf_weights(self): - if self.gguf_path_ is not None: - from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_gguf_weights + if self.model_paths_.gguf_path is not None: + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader - load_gguf_weights( - self.data_type, - weight_dir=self.gguf_path_, + reader = get_gguf_reader(self.model_paths_.gguf_path) + reader.load_weights( + data_type=self.data_type, config=self.config, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - reader=getattr(self, "_gguf_reader", None), - gguf_to_hf=getattr(self, "_gguf_to_hf", None), - release_reader=self._release_gguf_reader, ) else: load_hf_weights( diff --git a/lightllm/common/basemodel/layer_weights/gguf_load_utils.py b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py index 2cf6c2afa4..a19ae3f3b0 100644 --- a/lightllm/common/basemodel/layer_weights/gguf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py @@ -1,121 +1,502 @@ import gguf import numpy as np +import os import torch -from gguf import dequantize +from functools import lru_cache +from gguf import GGMLQuantizationType, dequantize, quant_shape_to_byte_shape from gguf.gguf_reader import GGUFReader, ReaderTensor -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoModelForImageTextToText from transformers.models.auto.configuration_auto import CONFIG_MAPPING -from typing import Any, Callable, Dict, Optional +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.log_utils import init_logger +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -DEQUANT_KEYS = ["token_embd.weight", "output.weight"] +try: + from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +except ImportError: + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = {} + +logger = init_logger(__name__) + +DEQUANT_KEYS = frozenset(["token_embd.weight", "output.weight"]) + +# Native float tensors in GGUF are stored as plain arrays, not byte-packed blocks. +UNQUANTIZED_GGML_TYPES = {GGMLQuantizationType.F32, GGMLQuantizationType.F16} + + +class LightLLMGGUFReader: + + def __init__(self, gguf_path: str): + self.gguf_path = gguf_path + self._reader: Optional[GGUFReader] = None + self._config: Optional[dict] = None + self._gguf_to_hf: Optional[Dict[str, str]] = None + self._quant_meta_map: Optional[Dict[str, Any]] = None + + @property + def reader(self) -> GGUFReader: + if self._reader is None: + self._reader = GGUFReader(self.gguf_path) + return self._reader + + def read_field(self, field_name: str) -> Any: + """ Read a field from the GGUF reader. """ + field = self.reader.fields.get(field_name) + if field is None: + return None + return field.contents() + + def load_config(self) -> dict: + """ Load model config from the GGUF reader. """ + if self._config is not None: + return self._config + + try: + from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint + except ImportError as e: + raise ImportError( + "Loading config from GGUF requires transformers with GGUF support and the gguf package." + ) from e + config: dict = load_gguf_checkpoint(self.gguf_path, return_tensors=False)["config"] + config["architectures"] = self._resolve_hf_architectures(config) + + self._config = config + return config + + def get_gguf_to_hf_mapping(self, config: Optional[dict] = None) -> Dict[str, str]: + """ Build a mapping from gguf tensor name to hf tensor name. """ + if self._gguf_to_hf is not None: + return self._gguf_to_hf + + if config is None: + config = self.load_config() + + self._gguf_to_hf = build_gguf_to_hf_mapping(config) + return self._gguf_to_hf + + def build_quant_meta_map(self, config: Optional[dict] = None) -> Dict[str, Any]: + """ Build a quant metadata map for GGUF quantized weights. """ + if self._quant_meta_map is not None: + return self._quant_meta_map + + from lightllm.common.quantization.quantize_method import GGUFWeightMeta + + if config is None: + config = self.load_config() + gguf_to_hf = self.get_gguf_to_hf_mapping(config) + gguf_quant_meta_map = {} + for t in self.reader.tensors: + if t.name in DEQUANT_KEYS: + continue + hf_name = gguf_to_hf.get(t.name) + if hf_name is None: + continue + logical_shape = _gguf_logical_shape(t) + if len(logical_shape) != 2: + continue + np_data = np.asarray(t.data) + gguf_quant_meta_map[hf_name] = GGUFWeightMeta( + shape=logical_shape, + dtype=_numpy_dtype_to_torch(np_data.dtype), + quant_type=t.tensor_type, + ) + + self._quant_meta_map = gguf_quant_meta_map + return gguf_quant_meta_map + + def load_weights( + self, + data_type: str, + config: dict, + pre_post_layer: Any = None, + transformer_layer_list: Any = None, + ) -> None: + """ Load GGUF weights into model layers, then release the reader. """ + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + + try: + gguf_to_hf = self.get_gguf_to_hf_mapping(config) + gguf_tensor_names = {t.name for t in self.reader.tensors} + validate_gguf_weight_mapping( + gguf_to_hf=gguf_to_hf, + gguf_tensor_names=gguf_tensor_names, + ) + torch.cuda.set_device(get_current_device_id()) + gguf_weights = _gguf_reader_to_weight_dict(self.reader) + hf_weights = rename_weights(gguf_weights, gguf_to_hf) + del gguf_weights + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(hf_weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(hf_weights) + del hf_weights + finally: + self.close() + + def close(self) -> None: + self._reader = None + + def __enter__(self) -> "LightLLMGGUFReader": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def _resolve_hf_architectures(self, config: Dict[str, Any]) -> List[str]: + """ Resolve HuggingFace architectures from model_type. """ + model_type = config.get("model_type") + assert model_type is not None, "model_type is not found in config" + assert model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, ( + f"model_type {model_type!r} is not found in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES" + ) + + return [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]] + + +@lru_cache(maxsize=10) +def get_gguf_reader(gguf_path: str) -> LightLLMGGUFReader: + gguf_path = os.path.abspath(gguf_path) + return LightLLMGGUFReader(gguf_path) + + +def _numpy_dtype_to_torch(np_dtype: np.dtype) -> torch.dtype: + return torch.from_numpy(np.zeros((), dtype=np_dtype)).dtype + + +def _gguf_logical_shape(rt: ReaderTensor) -> Tuple[int, ...]: + return tuple(reversed([int(x) for x in rt.shape.tolist()])) + + +def _resolve_model_type(model_type: str) -> str: + MODEL_TYPE_MAPPING = { + "qwen2_vl": "qwen2vl", + "qwen2_5_vl": "qwen2vl", + } + + return MODEL_TYPE_MAPPING.get(model_type, model_type) + + +def _normalize_hf_weight_name(hf_name: str) -> str: + multimodal_lm_prefix = "model.language_model." + if hf_name.startswith(multimodal_lm_prefix): + return "model." + hf_name[len(multimodal_lm_prefix) :] + + return hf_name + + +def _is_multimodal(config: Dict[str, Any]) -> bool: + if config.get("vision_config") is not None: + return True + model_type = config.get("model_type") + return model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + + +def _dummy_hf_model_from_config(config: Dict[str, Any], hf_config: Any): + with torch.device("meta"): + if _is_multimodal(config): + return AutoModelForImageTextToText.from_config(hf_config) + return AutoModelForCausalLM.from_config(hf_config) def build_gguf_to_hf_mapping(config: Dict[str, Any]) -> Dict[str, str]: + """ + Build a mapping from gguf tensor name to hf tensor name, + e.g., 'token_embd.weight' -> 'model.embed_tokens.weight'. + """ num_layers = config.get("num_hidden_layers") assert num_layers is not None, "num_hidden_layers is not found in config" model_type = config.get("model_type") assert model_type is not None, "model_type is not found in config" arch = None + resolve_model_type = _resolve_model_type(model_type) for gguf_arch, hf_arch in gguf.MODEL_ARCH_NAMES.items(): - if hf_arch == model_type: + if hf_arch == resolve_model_type: arch = gguf_arch break - assert arch is not None, "model_type is not found in gguf.MODEL_ARCH_NAMES" + assert arch is not None, f"model_type {model_type!r} is not found in gguf.MODEL_ARCH_NAMES" tensor_name_map = gguf.get_tensor_name_map(arch, num_layers) config_cls = CONFIG_MAPPING[model_type] assert config_cls is not None, f"config_cls is not found in CONFIG_MAPPING for model_type={model_type}" hf_config = config_cls(**config) - with torch.device("meta"): - dummy_model = AutoModelForCausalLM.from_config(hf_config) - gguf_to_hf_name_mapping = {} + dummy_model = _dummy_hf_model_from_config(config, hf_config) + + gguf_to_hf_name_mapping: Dict[str, str] = {} + invalid_hf_names: List[str] = [] for hf_name in dummy_model.state_dict(): - name, extension = hf_name.rsplit(".", 1) + lightllm_hf_name = _normalize_hf_weight_name(hf_name) + name, extension = lightllm_hf_name.rsplit(".", 1) gguf_name = tensor_name_map.get_name(name) - gguf_to_hf_name_mapping[f"{gguf_name}.{extension}"] = hf_name + if gguf_name is None: + invalid_hf_names.append(hf_name) + continue + gguf_key = f"{gguf_name}.{extension}" + if gguf_key in gguf_to_hf_name_mapping: + raise ValueError( + f"duplicate GGUF tensor key {gguf_key!r} while mapping HF weights; " + f"existing={gguf_to_hf_name_mapping[gguf_key]!r}, new={lightllm_hf_name!r}" + ) + gguf_to_hf_name_mapping[gguf_key] = lightllm_hf_name + if invalid_hf_names: + logger.warning( + "skipped %d HF weight(s) with no GGUF tensor name mapping (first 5: %s)", + len(invalid_hf_names), + invalid_hf_names[:5], + ) return gguf_to_hf_name_mapping -def _reader_tensor_to_torch_cpu(rt: ReaderTensor, dequant: bool = False) -> torch.Tensor: +def validate_gguf_weight_mapping( + gguf_to_hf: Dict[str, str], + gguf_tensor_names: Set[str], +) -> None: + """ Validate GGUF↔HF mapping and log tensors that will not be loaded. """ + # Check if all gguf tensors are mapped to hf tensors + mapped_gguf_names = set(gguf_to_hf.keys()) + unmapped_gguf = sorted(gguf_tensor_names - mapped_gguf_names) + if unmapped_gguf: + logger.warning( + "GGUF file contains %d tensor(s) without HF mapping; they will be skipped (first 10: %s)", + len(unmapped_gguf), + unmapped_gguf[:10], + ) + + # Check if all hf tensors are mapped to gguf tensors + missing_in_gguf = sorted(mapped_gguf_names - gguf_tensor_names) + if missing_in_gguf: + logger.warning( + "HF mapping expects %d GGUF tensor(s) that are absent in the file (first 10: %s)", + len(missing_in_gguf), + missing_in_gguf[:10], + ) + + +def rename_weights( + weights: Dict[str, torch.Tensor], + gguf_to_hf: Dict[str, str], +) -> Dict[str, torch.Tensor]: + """ Rename GGUF weights to HF names. """ + renamed: Dict[str, torch.Tensor] = {} + for gguf_name, tensor in weights.items(): + hf_name = gguf_to_hf.get(gguf_name) + if hf_name is None: + continue + renamed[hf_name] = tensor + + return renamed + + +def _normalize_visual_module_key(module_key: str) -> str: + """ LightLLM visual modules omit the HF 'visual.' prefix used by GGUF MMPROJ. """ + if module_key.startswith("visual."): + return module_key + return f"visual.{module_key}" + + +def build_mmproj_to_hf_mapping( + module_keys: Iterable[str], + depth: int, +) -> Tuple[Dict[str, str], List[str]]: + """ Build gguf tensor name -> LightLLM visual module state dict name mapping. """ + tensor_name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, depth) + gguf_to_hf: Dict[str, str] = {} + skipped: List[str] = [] + for hf_name in module_keys: + prefix, extension = hf_name.rsplit(".", 1) + lookup_prefix = _normalize_visual_module_key(prefix) + gguf_prefix = tensor_name_map.get_name(lookup_prefix) + if gguf_prefix is None: + skipped.append(hf_name) + continue + gguf_key = f"{gguf_prefix}.{extension}" + if gguf_key in gguf_to_hf: + raise ValueError( + f"duplicate GGUF tensor key {gguf_key!r} while mapping mmproj weights; " + f"existing={gguf_to_hf[gguf_key]!r}, new={hf_name!r}" + ) + gguf_to_hf[gguf_key] = hf_name + + if skipped: + logger.warning( + "skipped %d visual module weight(s) with no GGUF MMPROJ mapping (first 5: %s)", + len(skipped), + skipped[:5], + ) + return gguf_to_hf, skipped + + +def _merge_mmproj_patch_embed(gguf_weights: Dict[str, torch.Tensor]) -> None: + """ Merge split Conv2d patch-embed slices back into a Conv3d weight. """ + patch_key = "v.patch_embd.weight" + w0 = gguf_weights.pop(patch_key, None) + w1 = gguf_weights.pop(f"{patch_key}.1", None) + if w0 is None: + return + if w1 is not None: + # GGUF stores temporal_patch_size=2 as two [out, in, H, W] slices. + gguf_weights[patch_key] = torch.stack([w0, w1], dim=2) + else: + gguf_weights[patch_key] = w0 + + +def _merge_mmproj_attn_qkv(gguf_weights: Dict[str, torch.Tensor], depth: int) -> None: + """ Merge split q/k/v tensors in mmproj into fused qkv weights. """ + for i in range(depth): + prefix = f"v.blk.{i}" + qw = gguf_weights.pop(f"{prefix}.attn_q.weight", None) + kw = gguf_weights.pop(f"{prefix}.attn_k.weight", None) + vw = gguf_weights.pop(f"{prefix}.attn_v.weight", None) + if qw is not None: + gguf_weights[f"{prefix}.attn_qkv.weight"] = torch.cat([qw, kw, vw], dim=0) + qb = gguf_weights.pop(f"{prefix}.attn_q.bias", None) + kb = gguf_weights.pop(f"{prefix}.attn_k.bias", None) + vb = gguf_weights.pop(f"{prefix}.attn_v.bias", None) + if qb is not None: + gguf_weights[f"{prefix}.attn_qkv.bias"] = torch.cat([qb, kb, vb], dim=0) + + +def _supplement_mmproj_qkv_mapping( + gguf_to_hf: Dict[str, str], + gguf_weights: Dict[str, torch.Tensor], + depth: int, +) -> None: + """ Add gguf->module mappings for qkv tensors created after q/k/v merge. """ + for i in range(depth): + prefix = f"v.blk.{i}" + for extension in ("weight", "bias"): + gguf_key = f"{prefix}.attn_qkv.{extension}" + if gguf_key in gguf_weights: + gguf_to_hf[gguf_key] = f"blocks.{i}.attn.qkv.{extension}" + + +def load_mmproj_gguf_weights( + mmproj_path: str, + module_keys: Iterable[str], + depth: int, + dtype: torch.dtype, +) -> Dict[str, torch.Tensor]: + """ Load and dequantize mmproj GGUF weights into LightLLM visual module key names. """ + gguf_to_hf, _ = build_mmproj_to_hf_mapping(module_keys, depth) + reader = get_gguf_reader(mmproj_path) + try: + gguf_weights = _gguf_reader_to_weight_dict(reader.reader, dequant_all=True) + _merge_mmproj_patch_embed(gguf_weights) + _merge_mmproj_attn_qkv(gguf_weights, depth) + _supplement_mmproj_qkv_mapping(gguf_to_hf, gguf_weights, depth) + weight_dict = rename_weights(gguf_weights, gguf_to_hf) + return {name: tensor.to(dtype) for name, tensor in weight_dict.items()} + finally: + reader.close() + + +def _reader_tensor_to_torch_cpu( + rt: ReaderTensor, + dequant: bool = False, + logical_shape: Tuple[int, ...] = None, +) -> torch.Tensor: + """ Read a GGUF tensor to a CPU torch tensor. """ assert rt.shape.ndim <= 2, "GGUF tensor must be 2D or less" if dequant: d_tensor = dequantize(rt.data, rt.tensor_type) - else: - d_tensor = rt.data - arr = np.array(d_tensor, copy=True) + if logical_shape is not None: + d_tensor = d_tensor.reshape(logical_shape) + + return torch.from_numpy(np.array(d_tensor, copy=True)) - return torch.from_numpy(arr) + arr = np.array(rt.data, copy=True) + if logical_shape is None: + logical_shape = _gguf_logical_shape(rt) + # Norm/bias vectors and native float weights are plain arrays, not block-quantized bytes. + if rt.tensor_type in UNQUANTIZED_GGML_TYPES or len(logical_shape) == 1: + if arr.size != np.prod(logical_shape): + raise ValueError( + f"GGUF tensor {rt.name} has size {arr.size} but expected {np.prod(logical_shape)}" + ) + return torch.from_numpy(arr.reshape(logical_shape)) -def _gguf_reader_to_weight_dict(reader: GGUFReader) -> Dict[str, torch.Tensor]: + byte_shape = quant_shape_to_byte_shape(logical_shape, rt.tensor_type) + if arr.size != np.prod(byte_shape): + raise ValueError(f"GGUF tensor {rt.name} has size {arr.size} but expected {np.prod(byte_shape)}") + + arr = arr.reshape(byte_shape) + + return torch.from_numpy(arr) + + +def _gguf_reader_to_weight_dict( + reader: GGUFReader, + *, + dequant_all: bool = False, +) -> Dict[str, torch.Tensor]: + """ Read GGUF reader to torch tensor dictionary. """ gguf_weights = {} for t in reader.tensors: - if t.name in DEQUANT_KEYS: - gguf_weights[t.name] = _reader_tensor_to_torch_cpu(t, dequant=True) + if dequant_all or t.name in DEQUANT_KEYS: + gguf_weights[t.name] = _reader_tensor_to_torch_cpu( + t, dequant=True, logical_shape=_gguf_logical_shape(t) + ) else: gguf_weights[t.name] = _reader_tensor_to_torch_cpu(t, dequant=False) return gguf_weights -def rename_weights( - weights: Dict[str, torch.Tensor], - config: Dict[str, Any], - gguf_to_hf: Optional[Dict[str, str]] = None, -) -> Dict[str, torch.Tensor]: - if gguf_to_hf is None: - gguf_to_hf = build_gguf_to_hf_mapping(config) - return {gguf_to_hf[k]: v for k, v in weights.items() if k in gguf_to_hf} - - -def load_gguf_weights( - data_type: str, - weight_dir: str, - config: Dict[str, Any], - pre_post_layer: Any = None, - transformer_layer_list: Any = None, - weight_dict: Optional[Dict[str, torch.Tensor]] = None, - reader: Optional[GGUFReader] = None, - gguf_to_hf: Optional[Dict[str, str]] = None, - release_reader: Optional[Callable[[], None]] = None, -) -> None: - if isinstance(data_type, str): - data_type = torch.float16 if data_type == "fp16" else torch.float32 - if pre_post_layer is not None: - assert pre_post_layer.data_type_ == data_type, "type is not right" - if transformer_layer_list is not None: - assert transformer_layer_list[ - 0].data_type_ == data_type, "type is not right" - if weight_dict: - torch.cuda.set_device(get_current_device_id()) - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weight_dict) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weight_dict) - del weight_dict - if release_reader is not None: - release_reader() - return +def dequant_gguf_weight( + byte_weight: torch.Tensor, + quant_type: gguf.GGMLQuantizationType, + logical_shape: Tuple[int, int], + device: torch.device, + dtype: torch.dtype +) -> torch.Tensor: + """ + Dequantize a GGUF weight to a CPU torch tensor during loading for GGUF. + """ + expected_byte_shape = quant_shape_to_byte_shape(logical_shape, quant_type) + if tuple(byte_weight.shape) != expected_byte_shape: + raise ValueError( + f"byte shard shape {byte_weight.shape} != expected {expected_byte_shape}" + ) + fp32 = dequantize(byte_weight.contiguous().numpy(), quant_type) + arr = np.asarray(fp32, dtype=np.float32).reshape(logical_shape) - need_init_reader = reader is None - if need_init_reader: - reader = GGUFReader(weight_dir) - try: - weights = _gguf_reader_to_weight_dict(reader) - weights = rename_weights(weights, config=config, gguf_to_hf=gguf_to_hf) - finally: - if need_init_reader: - del reader - elif release_reader is not None: - release_reader() - - torch.cuda.set_device(get_current_device_id()) - if pre_post_layer is not None: - pre_post_layer.load_hf_weights(weights) - if transformer_layer_list is not None: - for layer in transformer_layer_list: - layer.load_hf_weights(weights) - del weights + return torch.from_numpy(arr).to(device=device, dtype=dtype) + + +def load_weight_shard( + *, + raw_weight: torch.Tensor, + param_name: str, + weight_pack: Any, + slicer: Any, + quant_method: Any, +) -> None: + """ + Load a GGUF linear weight shard with TP slicing. + For predquant weights, dequantize the full tensor first, then slice. + """ + if weight_pack.gguf_load_predquant: + meta = quant_method.gguf_quant_meta_map[param_name] + full_fp = dequant_gguf_weight( + byte_weight=raw_weight, + quant_type=meta.quant_type, + logical_shape=meta.shape, + device=weight_pack.weight.device, + dtype=weight_pack.weight.dtype, + ) + shard = slicer._slice_weight(full_fp) + weight_pack.weight.copy_(shard) + weight_pack.load_ok[0] = True + else: + shard = slicer._slice_weight(raw_weight) + quant_method.load_weight(shard, weight_pack) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..75b210c744 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -1,6 +1,7 @@ import torch import threading from typing import Dict, Any, Optional, Tuple, List +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_weight_shard from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization.quantize_method import WeightPack from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( @@ -332,14 +333,21 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" - row_slice_func = self.row_slicer._slice_weight - col_slice_func = self.col_slicer._slice_weight - if w1_weight in weights: - self.quant_method.load_weight(row_slice_func(weights[w1_weight]), self.w1_list[local_expert_idx]) - if w3_weight in weights: - self.quant_method.load_weight(row_slice_func(weights[w3_weight]), self.w3_list[local_expert_idx]) - if w2_weight in weights: - self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + expert_weights = ( + (w1_weight, self.w1_list[local_expert_idx], self.row_slicer), + (w3_weight, self.w3_list[local_expert_idx], self.row_slicer), + (w2_weight, self.w2_list[local_expert_idx], self.col_slicer), + ) + for param_name, weight_pack, slicer in expert_weights: + if param_name not in weights: + continue + load_weight_shard( + raw_weight=weights[param_name], + param_name=param_name, + weight_pack=weight_pack, + slicer=slicer, + quant_method=self.quant_method, + ) def _load_expert_scale( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..fb670fafef 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -148,10 +148,9 @@ def get_row_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) - elif quant_method_name == "none": + if quant_method_name == "none" or quant_method_name.startswith("gguf"): return RowSliceMixin(tp_rank, tp_world_size, repeat_times) - else: - return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) + return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) def get_col_slice_mixin( @@ -159,7 +158,6 @@ def get_col_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) - elif quant_method_name == "none": + if quant_method_name == "none" or quant_method_name.startswith("gguf"): return ColSliceMixin(tp_rank, tp_world_size, repeat_times) - else: - return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) + return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 87530e7951..dd7c36cc58 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_weight_shard from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.no_quant import NoQuantization @@ -127,9 +128,13 @@ def _load_weight( if quanted_param_name in weights: param_name = quanted_param_name if param_name in weights: - slicer = self._get_param_slicer(sub_child_index) - weight = slicer._slice_weight(weights[param_name]) - self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) + load_weight_shard( + raw_weight=weights[param_name], + param_name=param_name, + weight_pack=self.mm_param_list[sub_child_index], + slicer=self._get_param_slicer(sub_child_index), + quant_method=self.quant_method, + ) return def _load_bias( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index fb50398368..9fda4967df 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -22,6 +22,8 @@ def __init__( ) -> None: self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + if isinstance(out_dims, int): + out_dims = [out_dims] out_dims = [self._get_tp_dim(out_dim) for out_dim in out_dims] super().__init__( in_dim=in_dim, diff --git a/lightllm/common/gguf_kernel/__init__.py b/lightllm/common/gguf_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gguf/triton_kernel/dequantization.py b/lightllm/common/gguf_kernel/dequantization.py similarity index 100% rename from lightllm/models/gguf/triton_kernel/dequantization.py rename to lightllm/common/gguf_kernel/dequantization.py diff --git a/lightllm/common/quantization/gguf.py b/lightllm/common/quantization/gguf.py index 17037e0a03..fd9e22888b 100644 --- a/lightllm/common/quantization/gguf.py +++ b/lightllm/common/quantization/gguf.py @@ -6,10 +6,27 @@ from .registry import QUANTMETHODS from .quantize_method import GGUFWeightMeta, QuantizationMethod, WeightPack from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager -from lightllm.models.gguf.triton_kernel.dequantization import get_gguf_dequant_fn +from lightllm.common.basemodel.layer_weights.gguf_load_utils import UNQUANTIZED_GGML_TYPES +from lightllm.common.gguf_kernel.dequantization import get_gguf_dequant_fn +from lightllm.utils.log_utils import init_logger -# These types are not quantized, so they are directly used -UNQUANTIZED_TYPES = {GGMLQuantizationType.F32, GGMLQuantizationType.F16} +logger = init_logger(__name__) + +# To avoid logging the warning multiple times +_predquant_warned = False + + +def _warn_predquant_once(unsupported_quant_type: str) -> None: + global _predquant_warned + if _predquant_warned: + return + _predquant_warned = True + logger.warning( + "The current GGUF model contains quantization formats that do not support runtime " + f"dequantization (e.g., {unsupported_quant_type}). These weights will be dequantized during model loading, which " + "may increase GPU memory usage. To add support, register a dequantization implementation via " + "register_gguf_dequant in lightllm/common/gguf_kernel/dequantization.py.", + ) def _linear( @@ -41,11 +58,14 @@ def apply( use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # allocate output tensor if not provided + # Allocate output tensor based on the shape of the weight pack assert weight_pack.gguf_quant_type is not None, "gguf_quant_type must be set on WeightPack" if out is None: - if weight_pack.gguf_quant_type in UNQUANTIZED_TYPES: - out_features = weight_pack.weight.shape[0] + if ( + weight_pack.gguf_quant_type in UNQUANTIZED_GGML_TYPES + or weight_pack.gguf_load_predquant + ): + out_features = weight_pack.weight.shape[-2] else: out_features, _ = quant_shape_from_byte_shape( weight_pack.weight.shape, @@ -56,11 +76,14 @@ def apply( out = g_cache_manager.alloc_tensor(shape, input_tensor.dtype, device=input_tensor.device) else: out = torch.empty(shape, dtype=input_tensor.dtype, device=input_tensor.device) - # unquantized types are directly used - if weight_pack.gguf_quant_type in UNQUANTIZED_TYPES: + # Unquantized types and load-time dequantized weights are directly used + if ( + weight_pack.gguf_quant_type in UNQUANTIZED_GGML_TYPES + or weight_pack.gguf_load_predquant + ): weight = weight_pack.weight.t() return _linear(input_tensor, weight, out, bias) - # quantized types are dequantized and then used + # For quantized types, we need to dequantize the weight and then use it dequant_fn = get_gguf_dequant_fn(weight_pack.gguf_quant_type) if dequant_fn is None: raise ValueError( @@ -87,11 +110,18 @@ def apply( def method_name(self): return "gguf" - def load_weight(self, weight: torch.Tensor, - weight_pack: WeightPack) -> None: + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: + assert not weight_pack.gguf_load_predquant, ( + "predquant GGUF weights must be loaded via load_weight_shard" + ) + if weight.shape != weight_pack.weight.shape: + raise ValueError( + f"GGUF weight shape {weight.shape} does not match weight pack shape {weight_pack.weight.shape}" + ) device = weight_pack.weight.device - weight_pack.weight.copy_(weight.contiguous().to( - device=device, dtype=weight_pack.weight.dtype)) + weight_pack.weight.copy_( + weight.contiguous().to(device=device, dtype=weight_pack.weight.dtype) + ) weight_pack.load_ok[0] = True def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: @@ -106,67 +136,57 @@ def _create_weight( num_experts: int = 1, weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: - assert weight_names is not None and len( - weight_names) > 0, "weight_names must be provided" - if self.gguf_quant_meta_map is None: - raise ValueError( - f"Cannot load GGUF-quantized weights {weight_names!r}: no GGUF metadata was built. " - f"quant_type is 'gguf', but model_dir has no .gguf file (only HuggingFace/safetensors). " - f"Use --quant_type to no set gguf for HF checkpoints, or set model_dir to a path that contains " - f"exactly one .gguf file." - ) + assert weight_names is not None and len(weight_names) > 0, "weight_names must be provided" + assert len(weight_names) == len(out_dims), "weight_names and out_dims must align" - assert len(weight_names) == len( - out_dims), "weight_names and out_dims must align" weight_dtype = None - gguf_quant_types = set() + shard_quant_types: List[GGMLQuantizationType] = [] for weight_name in weight_names: meta: GGUFWeightMeta = self.gguf_quant_meta_map[weight_name] - gguf_quant_types.add(meta.quant_type) + shard_quant_types.append(meta.quant_type) quant_shape = meta.shape assert len( quant_shape ) == 2, f"GGUF linear weight must be 2D, got {quant_shape} for {weight_name}" - _, in_d = quant_shape[0], quant_shape[1] - assert in_d == in_dim, ( - f"GGUF tensor {weight_name} has in_features {in_d}, layer expects in_dim {in_dim}" - ) if weight_dtype is None: weight_dtype = meta.dtype else: assert weight_dtype == meta.dtype, f"merged GGUF weights must share dtype, got {weight_dtype} vs {meta.dtype}" - assert len( - gguf_quant_types - ) == 1, f"merged GGUF weights must share quant_type, got {gguf_quant_types}" - gguf_quant_type = gguf_quant_types.pop() - if gguf_quant_type not in UNQUANTIZED_TYPES: + # If there are mixed quant types, we need to dequant each shard at load time + mixed_quant_types = len(set(shard_quant_types)) > 1 + gguf_quant_type = shard_quant_types[0] + gguf_load_predquant = mixed_quant_types + if not mixed_quant_types and gguf_quant_type not in UNQUANTIZED_GGML_TYPES: if get_gguf_dequant_fn(gguf_quant_type) is None: - raise ValueError( - f"No CUDA dequant registered for GGUF type {gguf_quant_type!r}; " - f"add @register_gguf_dequant in " - f"lightllm/models/gguf/triton_kernel/dequantization.py" - ) + gguf_load_predquant = True + _warn_predquant_once(gguf_quant_type.name) - # Buffer sizes follow layer/tp shard dims from the caller (load_path slices file weights into this storage). logical_shape_rowmajor = (sum(out_dims), in_dim) expert_prefix = (num_experts, ) if num_experts > 1 else () - if gguf_quant_type in UNQUANTIZED_TYPES: + if gguf_quant_type in UNQUANTIZED_GGML_TYPES or gguf_load_predquant: full_shape = expert_prefix + logical_shape_rowmajor + storage_dtype = dtype if gguf_load_predquant else weight_dtype else: full_shape = expert_prefix + quant_shape_to_byte_shape( logical_shape_rowmajor, gguf_quant_type) - weight = torch.empty(full_shape, dtype=weight_dtype).cuda(device_id) + storage_dtype = weight_dtype + + weight = torch.empty(full_shape, dtype=storage_dtype).cuda(device_id) + mm_param = WeightPack( weight=weight, weight_scale=None, weight_zero_point=None, gguf_quant_type=gguf_quant_type, + gguf_load_predquant=gguf_load_predquant, ) mm_param_list = self._split_weight_pack( mm_param, weight_out_dims=out_dims, weight_split_dim=-2, ) + for pack, shard_quant_type in zip(mm_param_list, shard_quant_types): + pack.gguf_quant_type = shard_quant_type return mm_param, mm_param_list diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index ef244f770f..e2b1be0250 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -14,16 +14,14 @@ class GGUFWeightMeta: quant_type: GGMLQuantizationType -def numpy_dtype_to_torch(np_dtype: np.dtype) -> torch.dtype: - return torch.from_numpy(np.zeros((), dtype=np_dtype)).dtype - - @dataclass class WeightPack: weight: Optional[torch.Tensor] = None weight_scale: Optional[torch.Tensor] = None weight_zero_point: Optional[torch.Tensor] = None gguf_quant_type: Optional[GGMLQuantizationType] = None + # Dequantize unsupported quantization formats during loading for GGUF + gguf_load_predquant: bool = False def __post_init__(self): self.load_ok = [False, self.weight_scale is None, self.weight_zero_point is None] @@ -38,6 +36,7 @@ def get_expert(self, expert_idx: int): weight_scale=weight_scale, weight_zero_point=weight_zero_point, gguf_quant_type=self.gguf_quant_type, + gguf_load_predquant=self.gguf_load_predquant, ) @@ -190,6 +189,7 @@ def _split_weight_pack( weight_scale=weight_scale, weight_zero_point=weight_zero_point, gguf_quant_type=weight_pack.gguf_quant_type, + gguf_load_predquant=weight_pack.gguf_load_predquant, ) ) return mm_param_list diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index ab5967866d..fc94222b5b 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -8,7 +8,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args -from lightllm.utils.config_utils import get_vocab_size +from lightllm.utils.config_utils import get_model_paths, get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager @@ -112,8 +112,7 @@ class ReqSamplingParamsManager: def __init__(self, max_request_num): # mode ["cpu_counter", "pin_mem_counter", "gpu_counter"] self.penalty_counter_mode = get_env_start_args().penalty_counter_mode - start_args = get_env_start_args() - self.vocab_size = get_vocab_size(config_path=start_args.config_path, model_dir=start_args.model_dir) + self.vocab_size = get_vocab_size(get_model_paths()) self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index 9931c31713..e49b3ff813 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -11,6 +11,7 @@ from lightllm.models.gemma3.layer_infer.transformer_layer_infer import Gemma3TransformerLayerInfer from lightllm.models.gemma3.layer_weights.pre_and_post_layer_weight import Gemma3PreAndPostLayerWeight from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel from lightllm.models.llama.model import LlamaTpPartModel from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem from lightllm.server.core.objs import SamplingParams @@ -151,6 +152,7 @@ def _init_mem_manager(self): return def _init_config(self): + TpPartBaseModel._init_config(self) with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) # rename keys diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..7d51c2d001 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -4,11 +4,13 @@ import torch.nn.functional as F from PIL import Image from typing import List, Optional +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_mmproj_gguf_weights +from lightllm.utils.log_utils import init_logger from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO import torch.nn as nn from transformers.activations import ACT2FN -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor, resize_image from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding @@ -17,6 +19,8 @@ from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton +logger = init_logger(__name__) + class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -154,8 +158,7 @@ def __init__( fullatt_block_indexes=[7, 15, 23, 31], **kwargs, ): - super().__init__() - self.weight_dir = kvargs["weight_dir"] + super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -204,10 +207,8 @@ def __init__( self.gradient_checkpointing = False - processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + processor_dir = kvargs.get("processor_dir") + self.processor = load_image_processor(processor_dir, Qwen2VLImageProcessor) self._init_datatype() @@ -350,6 +351,22 @@ def load_image(self, img: List[ImageItem]): return pixel_values.to(dtype=self.data_type), image_grid_thw def load_model(self, weight_dir): + if weight_dir.endswith(".gguf"): + load_dtype = self.data_type + if not isinstance(load_dtype, torch.dtype): + load_dtype = torch.bfloat16 if load_dtype in ("bf16", "bfloat16") else torch.float16 + weight_dict = load_mmproj_gguf_weights( + mmproj_path=weight_dir, + module_keys=self.state_dict().keys(), + depth=self.depth, + dtype=load_dtype, + ) + missing, unexpected = self.load_state_dict(weight_dict, strict=False) + if missing: + logger.warning("mmproj missing keys (first 10): %s", missing[:10]) + if unexpected: + logger.warning("mmproj unexpected keys (first 10): %s", unexpected[:10]) + return bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..13d2bd884b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -8,6 +8,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer +from lightllm.utils.config_utils import get_model_config from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel @@ -100,8 +101,7 @@ def __init__(self, kvargs): return def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - self.config = json.load(json_file) + self.config = get_model_config(self.model_paths_) # rename keys repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 6076756043..17cb9cfa72 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -33,7 +33,7 @@ from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.server.visualserver import get_vit_attn_backend -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -192,6 +192,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir") self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -239,11 +240,7 @@ def _init_datatype(self): return def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index bc313fe467..1f9cd5842e 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json import math +import os import torch import numpy as np from PIL import Image @@ -284,3 +286,14 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc image_grid_thw = torch.as_tensor(processed_grids) return pixel_values, image_grid_thw + + +def load_image_processor(processor_dir: str, processor_cls): + assert os.path.isdir(processor_dir), f"Processor directory not found at {processor_dir}" + config_path = os.path.join(processor_dir, "preprocessor_config.json") + + assert os.path.exists(config_path), f"Processor config file not found at {config_path} for {processor_dir}" + with open(config_path, "r") as f: + processor_config = json.load(f) + + return processor_cls(**processor_config) \ No newline at end of file diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..bf582d49a3 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -27,7 +27,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention @@ -139,6 +139,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir", kvargs.get("weight_dir")) self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -220,11 +221,7 @@ def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature return all_img_embeds_ds, valid_ids def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..019137e36a 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -27,7 +27,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention from lightllm.utils.log_utils import init_logger @@ -135,6 +135,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir", kvargs.get("weight_dir")) self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -215,11 +216,7 @@ def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature return all_img_embeds_ds, valid_ids def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf08575..2fd221529f 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -16,7 +16,7 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -217,10 +217,7 @@ def forward( return image_features def load_model(self, weight_dir): - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 837b09aa8d..b8ffd5ba59 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -132,7 +132,13 @@ def make_argument_parser() -> argparse.ArgumentParser: "--tokenizer_dir", type=str, default=None, - help="tokenizer directory; required for GGUF models, otherwise defaults to model_dir", + help="tokenizer directory; for gguf, load tokenizer from gguf is default, override it by providing this param", + ) + parser.add_argument( + "--mmproj_path", + type=str, + default=None, + help="The path of mmproj for multimodal mode, only supported for GGUF models", ) parser.add_argument( "--tokenizer_mode", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index c324df19c8..16b7f42e12 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -32,6 +32,7 @@ from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.error_utils import ClientDisconnected +from lightllm.utils.config_utils import ModelPaths, has_vision_module from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name @@ -232,6 +233,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req created_time = int(time.time()) + supports_vl_chat = has_vision_module(ModelPaths.from_args(g_objs.args)) multimodal_params_dict = {"images": [], "audios": []} for message in request.messages: if isinstance(message.content, list): @@ -276,6 +278,9 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req else: raise ValueError("Unrecognized audio input. Supports local path, http url, base64.") + if not supports_vl_chat: + message.content = "\n".join(texts) if texts else "" + tools = None if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 90635034de..094db272d0 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -18,13 +18,13 @@ from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import ( - align_quant_type_for_gguf_model, - check_gguf_quant_model_dir, - check_gguf_tokenizer_dir, + ModelPaths, + apply_gguf_quant_type, has_audio_module, has_vision_module, is_linear_att_mixed_model, auto_set_max_req_total_len, + check_gguf_multimodal_paths, ) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args @@ -77,7 +77,9 @@ def normal_or_p_d_start(args): args: StartArgs = args - auto_set_max_req_total_len(args) + paths = ModelPaths.from_args(args) + + auto_set_max_req_total_len(args, paths) set_unique_server_name(args) if args.enable_mps: @@ -89,16 +91,22 @@ def normal_or_p_d_start(args): return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 - if args.disable_vision is None: - if has_vision_module(args.model_dir): - args.disable_vision = False - else: - args.disable_vision = True - if args.disable_audio is None: - if has_audio_module(args.model_dir): - args.disable_audio = False - else: + if paths.is_gguf: + if args.disable_vision is None: + args.disable_vision = not bool(args.mmproj_path) + if args.disable_audio is None: args.disable_audio = True + else: + if args.disable_vision is None: + if has_vision_module(paths): + args.disable_vision = False + else: + args.disable_vision = True + if args.disable_audio is None: + if has_audio_module(paths): + args.disable_audio = False + else: + args.disable_audio = True # pd 分离模式下,不启动多模态的模块 if args.run_mode in ["decode", "nixl_decode"]: @@ -110,6 +118,9 @@ def normal_or_p_d_start(args): else: args.enable_multimodal = True + # Check if the tokenizer directory and mmproj path are provided for GGUF multimodal models + check_gguf_multimodal_paths(paths, enable_multimodal=args.enable_multimodal) + if args.enable_cpu_cache: # 生成一个用于创建cpu kv cache的共享内存id。 args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 @@ -145,14 +156,7 @@ def normal_or_p_d_start(args): if not args.disable_shm_warning: check_recommended_shm_size(args) - check_gguf_tokenizer_dir(args.model_dir, args.tokenizer_dir) - aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) - if aligned_quant_type != args.quant_type: - logger.warning( - f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" - ) - args.quant_type = aligned_quant_type - check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) + args.quant_type = apply_gguf_quant_type(paths, args.quant_type) assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 @@ -303,7 +307,7 @@ def normal_or_p_d_start(args): # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 args.linear_att_cache_size = args.running_max_req_size * 2 - if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): + if args.enable_cpu_cache and is_linear_att_mixed_model(paths): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num logger.info(f"set cpu_cache_token_page_size to {args.cpu_cache_token_page_size} for linear hybrid att model") @@ -317,15 +321,13 @@ def normal_or_p_d_start(args): if args.eos_id is None: from lightllm.utils.config_utils import get_eos_token_ids - args.eos_id = get_eos_token_ids(args.model_dir) + args.eos_id = get_eos_token_ids(paths) # 如果 tool_call_parser 是 None,尝试根据模型类型自动设置 if args.tool_call_parser is None: from lightllm.utils.config_utils import get_tool_call_parser_for_model - args.tool_call_parser = get_tool_call_parser_for_model( - config_path=args.config_path, model_dir=args.model_dir - ) + args.tool_call_parser = get_tool_call_parser_for_model(paths) if args.tool_call_parser: logger.info(f"Auto set tool_call_parser to {args.tool_call_parser} based on model type") @@ -333,16 +335,14 @@ def normal_or_p_d_start(args): if args.reasoning_parser is None: from lightllm.utils.config_utils import get_reasoning_parser_for_model - args.reasoning_parser = get_reasoning_parser_for_model( - config_path=args.config_path, model_dir=args.model_dir - ) + args.reasoning_parser = get_reasoning_parser_for_model(paths) if args.reasoning_parser: logger.info(f"Auto set reasoning_parser to {args.reasoning_parser} based on model type") if args.data_type is None: from lightllm.utils.config_utils import get_dtype - args.data_type = get_dtype(args.model_dir) + args.data_type = get_dtype(paths) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] already_uesd_ports = [args.port] @@ -543,7 +543,8 @@ def pd_master_start(args): if args.run_mode != "pd_master": return - auto_set_max_req_total_len(args) + paths = ModelPaths.from_args(args) + auto_set_max_req_total_len(args, paths) # when use config_server to support multi pd_master node, we # need generate unique node id for each pd_master node. @@ -568,21 +569,7 @@ def pd_master_start(args): set_env_start_args(args) - check_gguf_tokenizer_dir(args.model_dir, args.tokenizer_dir) - aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) - if aligned_quant_type != args.quant_type: - logger.warning( - f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" - ) - args.quant_type = aligned_quant_type - check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) - aligned_quant_type = align_quant_type_for_gguf_model(args.model_dir, args.quant_type, args.quant_cfg) - if aligned_quant_type != args.quant_type: - logger.warning( - f"model_dir contains GGUF weights; overriding --quant_type {args.quant_type!r} -> {aligned_quant_type!r}" - ) - args.quant_type = aligned_quant_type - check_gguf_quant_model_dir(args.model_dir, args.quant_type, args.quant_cfg) + args.quant_type = apply_gguf_quant_type(paths, args.quant_type) process_manager.start_submodule_processes( start_funcs=[ @@ -623,6 +610,7 @@ def visual_only_start(args): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args + paths = ModelPaths.from_args(args) if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -641,7 +629,7 @@ def visual_only_start(args): if args.data_type is None: from lightllm.utils.config_utils import get_dtype - args.data_type = get_dtype(args.model_dir) + args.data_type = get_dtype(paths) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] logger.info(f"alloced ports: {can_use_ports}") diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 04e0c12c31..5f34166cf4 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,5 +1,6 @@ import os import json +from lightllm.utils.config_utils import ModelPaths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -11,9 +12,9 @@ def init_tokenizer(args): global tokenizer + paths = ModelPaths.from_args(args) tokenizer = get_tokenizer( - tokenizer_name=args.model_dir, - tokenizer_dir=args.tokenizer_dir, + paths, tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 954daa50fe..2e8000cf08 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -29,7 +29,10 @@ class StartArgs: select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) + tokenizer_dir: Optional[str] = field(default=None) tokenizer_mode: str = field(default="slow") + config_path: Optional[str] = field(default=None) + mmproj_path: Optional[str] = field(default=None) load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index b5b512d31e..588c78201d 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -12,6 +12,7 @@ from .decode import decode_token from .decode_mode_fix import decode_mode_fix from .decode_req import DecodeReq +from lightllm.utils.config_utils import ModelPaths from ..tokenizer import get_tokenizer import pickle import time @@ -35,8 +36,7 @@ def __init__( self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer( - tokenizer_name=args.model_dir, - tokenizer_dir=args.tokenizer_dir, + ModelPaths.from_args(args), tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 2b0a827d14..e8e0070ffa 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -32,7 +32,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage -from lightllm.utils.config_utils import get_vocab_size +from lightllm.utils.config_utils import ModelPaths, get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import ClientDisconnected, NixlPrefillNodeStopGenToken from rpyc.utils.classic import obtain @@ -102,9 +102,10 @@ def __init__( self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") + paths = ModelPaths.from_args(args) + self.tokenizer = get_tokenizer( - tokenizer_name=args.model_dir, - tokenizer_dir=args.tokenizer_dir, + paths, tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) @@ -122,7 +123,7 @@ def __init__( self.per_token_costs = MovingAverage() # 有的模型的vocab size 读取tokenizer和config.json中不一致 self.vocab_size = max( - get_vocab_size(config_path=args.config_path, model_dir=args.model_dir), + get_vocab_size(paths), self.tokenizer.vocab_size, ) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index afa198bb45..3c317abe68 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -12,6 +12,7 @@ from ..pd_io_struct import PD_Client_Obj, UpKVStatus, NixlUpKVStatus, ObjType, NodeRole, NIXLDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams +from lightllm.utils.config_utils import ModelPaths from ..tokenizer import get_tokenizer from ..req_id_generator import ReqIDGenerator, convert_sub_id_to_group_id from fastapi import Request @@ -43,9 +44,8 @@ def __init__( self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 self.tokenizer = get_tokenizer( - args.model_dir, - args.tokenizer_dir, - args.tokenizer_mode, + ModelPaths.from_args(args), + tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 0254cec1d7..1d83245bfa 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -19,7 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.config_utils import get_model_config_dict +from lightllm.utils.config_utils import _create_model_paths, get_model_config from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -151,7 +151,7 @@ def init_model(self, kvargs): if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() - model_cfg = get_model_config_dict(config_path=self.args.config_path, model_dir=self.weight_dir) + model_cfg = get_model_config(_create_model_paths(self.weight_dir)) model_kvargs = { "weight_dir": self.weight_dir, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index 8f4745d487..c9314abafa 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -6,6 +6,7 @@ from .impl import ChunkedPrefillBackend from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.config_utils import ModelPaths from lightllm.server.tokenizer import get_tokenizer from typing import List, Tuple from lightllm.utils.log_utils import init_logger @@ -36,9 +37,8 @@ def init_custom(self): self.tokenizer = TransformerTokenizer( get_tokenizer( - self.args.model_dir, - self.args.tokenizer_dir, - self.args.tokenizer_mode, + ModelPaths.from_args(self.args), + tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, ) ) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index 26ab259863..887f89c2f4 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -2,6 +2,7 @@ from .impl import ChunkedPrefillBackend from typing import List from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.config_utils import ModelPaths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -21,9 +22,8 @@ def init_custom(self): 初始化tokenizer 词表相关的的操作 """ self.tokenizer = get_tokenizer( - self.args.model_dir, - self.args.tokenizer_dir, - self.args.tokenizer_mode, + ModelPaths.from_args(self.args), + tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, ) vob_dict = self.tokenizer.get_vocab() diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 2bdae895f7..09fde5cf75 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -5,6 +5,7 @@ from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.config_utils import ModelPaths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -23,9 +24,8 @@ def init_custom(self): import xgrammar as xgr self.tokenizer = get_tokenizer( - self.args.model_dir, - self.args.tokenizer_dir, - self.args.tokenizer_mode, + ModelPaths.from_args(self.args), + tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, ) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 7d17db4fbc..a8ebe96665 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -21,7 +21,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import convert_slow_tokenizer from transformers.configuration_utils import PretrainedConfig -from lightllm.utils.config_utils import resolve_tokenizer_dir +from lightllm.utils.config_utils import ModelPaths, get_model_config +from lightllm.utils.gguf_tokenizer_utils import load_tokenizer_from_gguf from lightllm.utils.log_utils import init_logger from ..models.tarsier2.model import Tarsier2Tokenizer @@ -38,100 +39,147 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" -def get_tokenizer( - tokenizer_name: str, - tokenizer_dir: str = None, - tokenizer_mode: str = "auto", - trust_remote_code: bool = False, +def _load_hf_tokenizer( + path: str, + trust_remote_code: bool, *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface.""" - tokenizer_name = resolve_tokenizer_dir(tokenizer_name, tokenizer_dir) + try: + return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code, *args, **kwargs) + except TypeError as e: + logger.warning(f"load fast tokenizer fail: {str(e)}") + kwargs["use_fast"] = False + return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code, *args, **kwargs) + + +def _load_base_tokenizer( + load_path: str, + from_gguf: bool, + model_cfg: dict, + trust_remote_code: bool, + tokenizer_mode: str, + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Load the base tokenizer based on AutoTokenizer. """ if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False + + if from_gguf: + logger.info(f"Loading tokenizer from GGUF file: {load_path}") + return load_tokenizer_from_gguf(load_path, model_cfg, *args, **kwargs) - if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): + if "llama" in load_path.lower() and kwargs.get("use_fast", True): logger.info( "For some LLaMA-based models, initializing the fast tokenizer may " "take a long time. To eliminate the initialization time, consider " f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " "tokenizer." ) - # tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name) - # tokenizer = convert_slow_tokenizer(tokenizer) - # return tokenizer - - try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) - except TypeError as e: - # The LLaMA tokenizer causes a protobuf error in some environments, using slow mode. - # you can try pip install protobuf==3.20.0 to try repair - logger.warning(f"load fast tokenizer fail: {str(e)}") - kwargs["use_fast"] = False - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) - + + tokenizer = _load_hf_tokenizer(load_path, trust_remote_code, *args, **kwargs) if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.info( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + return tokenizer + + +def _wrap_tokenizer( + base_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + model_cfg: dict, + processor_dir: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: model_type = model_cfg.get("model_type", "") # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with # a Python-based apply_chat_template that uses encoding_dsv32.py. if model_type == "deepseek_v32": from ..models.deepseek3_2.model import DeepSeekV32Tokenizer - hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs - ) logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") - return DeepSeekV32Tokenizer(hf_tokenizer) + return DeepSeekV32Tokenizer(base_tokenizer) if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor - - image_processor = Qwen2VLImageProcessor.from_pretrained(tokenizer_name) - tokenizer = Tarsier2Tokenizer(tokenizer=tokenizer, image_processor=image_processor, model_cfg=model_cfg) - elif model_type == "llava" or model_type == "internlmxcomposer2": - tokenizer = LlavaTokenizer(tokenizer, model_cfg) - elif model_type == "qwen" and "visual" in model_cfg: - tokenizer = QWenVLTokenizer(tokenizer, model_cfg) - elif model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg: - from transformers import AutoProcessor + from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor + + image_processor = load_image_processor(processor_dir, Qwen2VLImageProcessor) + return Tarsier2Tokenizer(tokenizer=base_tokenizer, image_processor=image_processor, model_cfg=model_cfg) + + if model_type == "llava" or model_type == "internlmxcomposer2": + return LlavaTokenizer(base_tokenizer, model_cfg) + + if model_type == "qwen" and "visual" in model_cfg: + return QWenVLTokenizer(base_tokenizer, model_cfg) - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen2VLTokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + if model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg: + from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor + + image_processor = load_image_processor(processor_dir, Qwen2VLImageProcessor) + return QWen2VLTokenizer( + tokenizer=base_tokenizer, image_processor=image_processor, model_cfg=model_cfg ) - elif model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg: + + if model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg: from transformers import AutoProcessor - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3VLTokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3VLTokenizer( + tokenizer=base_tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) - elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: + + if model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: from transformers import AutoProcessor from ..models.qwen3_5.model import QWen3_5Tokenizer - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3_5Tokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3_5Tokenizer( + tokenizer=base_tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) - elif model_cfg.get("thinker_config") is not None: + + if model_cfg.get("thinker_config") is not None: from transformers import AutoProcessor model_cfg = model_cfg["thinker_config"] - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg) - elif model_type == "internvl_chat": - tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) - elif model_type == "gemma3": - tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3OmniTokenizer(base_tokenizer, processor=processor, model_cfg=model_cfg) - return tokenizer + if model_type == "internvl_chat": + return InternvlTokenizer(base_tokenizer, model_cfg, weight_dir=processor_dir) + + if model_type == "gemma3": + return Gemma3Tokenizer(base_tokenizer, model_cfg) + + return base_tokenizer + + +def get_tokenizer( + paths: ModelPaths, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Load base tokenizer (HF or GGUF), then wrap for model-specific behavior if needed.""" + model_cfg = get_model_config(paths) + load_path, from_gguf = paths.tokenizer_load_path + + base_tokenizer = _load_base_tokenizer( + load_path=load_path, + from_gguf=from_gguf, + model_cfg=model_cfg, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + *args, + **kwargs, + ) + + return _wrap_tokenizer( + base_tokenizer=base_tokenizer, + model_cfg=model_cfg, + processor_dir=paths.processor_dir, + ) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf681..274a8e7478 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -17,6 +17,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.config_utils import ModelPaths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -50,7 +51,7 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.model_weightdir = args.model_dir + self.visual_weight_dir, self.processor_dir = ModelPaths.from_args(args).resolve_visual_dirs() self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp # image 最大推理 batch size @@ -74,7 +75,8 @@ async def wait_to_model_ready(self): for tp_rank_id in range(self.vit_tp): device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { - "weight_dir": self.model_weightdir, + "weight_dir": self.visual_weight_dir, + "processor_dir": self.processor_dir, "device_id": device_id, "vit_tp": self.vit_tp, "cache_port": self.args.cache_port, diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 92ca2e3836..8b5aa15239 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -7,7 +7,7 @@ import time import torch.distributed as dist from typing import Dict, List, Tuple, Deque, Optional -from transformers.configuration_utils import PretrainedConfig +from lightllm.utils.config_utils import get_model_paths, get_model_config from rpyc.utils.classic import obtain from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel @@ -52,6 +52,7 @@ def exposed_init_model(self, kvargs): # } weight_dir = kvargs["weight_dir"] + processor_dir = kvargs.get("processor_dir", weight_dir) self.infer_max_batch_size = kvargs["max_batch_size"] self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] @@ -63,11 +64,12 @@ def exposed_init_model(self, kvargs): self.vit_attn_backend = kvargs["vit_attn_backend"] set_vit_att_backend(self.vit_attn_backend) init_vision_distributed_env(kvargs) - model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + model_cfg = get_model_config(get_model_paths()) try: kvargs = { "weight_dir": weight_dir, + "processor_dir": processor_dir, "data_type": self.data_type, "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 15705a1140..79ab281c4c 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -21,6 +21,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.config_utils import ModelPaths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -39,7 +40,7 @@ def __init__( args: StartArgs, ): self.args = args - self.model_weightdir = args.model_dir + self.visual_weight_dir, self.processor_dir = ModelPaths.from_args(args).resolve_visual_dirs() self.vit_dp = args.visual_dp assert self.vit_dp == 1 self.vit_tp = args.visual_tp @@ -106,7 +107,8 @@ async def wait_to_model_ready(self): for tp_rank_id in range(self.vit_tp): device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { - "weight_dir": self.model_weightdir, + "weight_dir": self.visual_weight_dir, + "processor_dir": self.processor_dir, "device_id": device_id, "vit_tp": self.vit_tp, "cache_port": None, # visual only 模式下不使用 embed cache diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index ee7e87543b..b47c07fbd1 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -1,19 +1,166 @@ import json import os +from dataclasses import dataclass, field from functools import lru_cache -from typing import List, Optional +from typing import List, Optional, Union -from .envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -def find_gguf_path(model_dir: Optional[str]) -> Optional[str]: +@dataclass(frozen=True) +class ModelPaths: + """ Resolved model-related paths for config / tokenizer / multimodal loading """ + model_dir: str + tokenizer_dir: Optional[str] = None + config_path: Optional[str] = None + mmproj_path: Optional[str] = None + _gguf_path: Optional[str] = field(default=None, init=False, repr=False, compare=False) + + @classmethod + def from_args(cls, args) -> "ModelPaths": + return _fill_paths_from_env( + cls( + model_dir=args.model_dir, + tokenizer_dir=getattr(args, "tokenizer_dir", None), + config_path=getattr(args, "config_path", None), + mmproj_path=getattr(args, "mmproj_path", None), + ) + ) + + @classmethod + def from_env(cls) -> "ModelPaths": + return cls.from_args(get_env_start_args()) + + @property + def is_gguf(self) -> bool: + return self._gguf_path is not None + + @property + def gguf_path(self) -> Optional[str]: + return self._gguf_path + + @property + def processor_dir(self) -> str: + return self.tokenizer_dir or self.model_dir + + @property + def tokenizer_load_path(self) -> tuple[str, bool]: + if self._gguf_path is not None and not self.tokenizer_dir: + return self._gguf_path, True + return self.processor_dir, False + + def resolve_visual_dirs(self) -> tuple[Optional[str], Optional[str]]: + if self.is_gguf: + return self.mmproj_path, self.tokenizer_dir + return self.model_dir, self.model_dir + + def load_config(self) -> dict: + if self.config_path is not None: + return _load_config_from_path(self.config_path) + + gguf_path = self.gguf_path + + if gguf_path is not None and self.tokenizer_dir is not None: + hf_config_path = os.path.join(self.tokenizer_dir, "config.json") + assert os.path.isfile(hf_config_path), f"config.json {hf_config_path} is not found" + return _load_config_from_path(hf_config_path) + + if gguf_path is None and self.model_dir: + config_json_path = os.path.join(self.model_dir, "config.json") + if os.path.isfile(config_json_path): + return _load_config_from_path(config_json_path) + + if gguf_path is not None: + return _load_config_from_gguf(gguf_path) + + raise FileNotFoundError( + f"no model config found (config_path={self.config_path!r}, model_dir={self.model_dir!r}). " + "Provide --config_path, place config.json under model_dir, or use a .gguf model path." + ) + + def align_quant_type(self, quant_type: str) -> str: + if self.is_gguf: + return "gguf" + + if quant_type == "gguf": + raise ValueError("--quant_type gguf is not supported for non-GGUF models") + + return quant_type + + def __post_init__(self): + object.__setattr__(self, "_gguf_path", _find_gguf_path_cached(self.model_dir)) + + +def _fill_paths_from_env(paths: ModelPaths) -> ModelPaths: + try: + start_args = get_env_start_args() + except KeyError: + return paths + + config_path = paths.config_path if paths.config_path is not None else getattr(start_args, "config_path", None) + tokenizer_dir = ( + paths.tokenizer_dir if paths.tokenizer_dir is not None else getattr(start_args, "tokenizer_dir", None) + ) + mmproj_path = paths.mmproj_path if paths.mmproj_path is not None else getattr(start_args, "mmproj_path", None) + + if config_path == paths.config_path and tokenizer_dir == paths.tokenizer_dir and mmproj_path == paths.mmproj_path: + return paths + + return ModelPaths( + model_dir=paths.model_dir, + config_path=config_path, + tokenizer_dir=tokenizer_dir, + mmproj_path=mmproj_path, + ) + + +def _create_model_paths( + model_dir_or_paths: Union[str, ModelPaths], + config_path: Optional[str] = None, + tokenizer_dir: Optional[str] = None, + mmproj_path: Optional[str] = None, +) -> ModelPaths: + if isinstance(model_dir_or_paths, ModelPaths): + paths = model_dir_or_paths + else: + paths = ModelPaths( + model_dir=model_dir_or_paths, + config_path=config_path, + tokenizer_dir=tokenizer_dir, + mmproj_path=mmproj_path, + ) + return _fill_paths_from_env(paths) + + +@lru_cache(maxsize=1) +def get_model_paths() -> ModelPaths: + return ModelPaths.from_env() + + +def _load_config_from_path(config_path: str) -> dict: + if not os.path.isfile(config_path): + raise FileNotFoundError(f"config file not found: {config_path}") + with open(config_path, "r") as file: + return json.load(file) + + +def _load_config_from_gguf(gguf_path: str) -> dict: + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader + + return get_gguf_reader(gguf_path).load_config() + + +@lru_cache(maxsize=128) +def _find_gguf_path_cached(model_dir: Optional[str]) -> Optional[str]: if not model_dir: return None + if model_dir.endswith(".gguf") and os.path.isfile(model_dir): return model_dir + if os.path.isdir(model_dir): gguf_files = sorted( os.path.join(model_dir, name) for name in os.listdir(model_dir) if name.endswith(".gguf") @@ -25,146 +172,20 @@ def find_gguf_path(model_dir: Optional[str]) -> Optional[str]: f"multiple GGUF files found in {model_dir} is not supported, please specify the target gguf file." ) return gguf_files[0] - return None - - -def is_gguf_model_path(model_dir: Optional[str]) -> bool: - return find_gguf_path(model_dir) is not None - - -def check_gguf_tokenizer_dir(model_dir: Optional[str], tokenizer_dir: Optional[str]) -> None: - if not is_gguf_model_path(model_dir): - return - if not tokenizer_dir: - raise ValueError( - f"GGUF model requires --tokenizer_dir (model_dir={model_dir!r}). " - "Provide a HuggingFace tokenizer directory separate from the .gguf weights." - ) - if not os.path.isdir(tokenizer_dir): - raise FileNotFoundError(f"tokenizer_dir is not a directory: {tokenizer_dir!r}") - for name in ("tokenizer.json", "tokenizer_config.json", "vocab.json"): - if os.path.isfile(os.path.join(tokenizer_dir, name)): - return - raise FileNotFoundError( - f"tokenizer_dir missing tokenizer files (tokenizer.json / tokenizer_config.json / vocab.json): " - f"{tokenizer_dir!r}" - ) - -def resolve_tokenizer_dir(model_dir: Optional[str], tokenizer_dir: Optional[str]) -> str: - check_gguf_tokenizer_dir(model_dir, tokenizer_dir) - if is_gguf_model_path(model_dir): - return tokenizer_dir - return tokenizer_dir or model_dir + return None -def uses_gguf_quant_type(quant_type: str, quant_cfg_path: Optional[str] = None) -> bool: - if quant_type == "gguf": - return True - if quant_cfg_path is None: - return False - import yaml - - with open(quant_cfg_path, "r") as file: - data = yaml.safe_load(file) or {} - if data.get("quant_type") == "gguf": - return True - for layer_quant_cfg in data.get("mix_bits", []) or []: - if layer_quant_cfg.get("quant_type") == "gguf": - return True - return False - - -def get_gguf_quant_conflicts(quant_type: str, quant_cfg_path: Optional[str] = None) -> List[str]: - """Non-gguf quant types that cannot be auto-overridden on GGUF weights ('none' is allowed).""" - conflicts = [] - if quant_type not in ("gguf", "none"): - conflicts.append(quant_type) - if quant_cfg_path is None: - return sorted(set(conflicts)) - import yaml - - with open(quant_cfg_path, "r") as file: - data = yaml.safe_load(file) or {} - cfg_quant_type = data.get("quant_type") - if cfg_quant_type is not None and cfg_quant_type not in ("gguf", "none"): - conflicts.append(cfg_quant_type) - for layer_quant_cfg in data.get("mix_bits", []) or []: - layer_quant_type = layer_quant_cfg.get("quant_type") - if layer_quant_type is not None and layer_quant_type != "gguf": - conflicts.append(layer_quant_type) - return sorted(set(conflicts)) - - -def align_quant_type_for_gguf_model( - model_dir: Optional[str], - quant_type: str, - quant_cfg_path: Optional[str] = None, -) -> str: - """GGUF weights only support gguf quantization; auto-align CLI default 'none' to 'gguf'.""" - if find_gguf_path(model_dir) is None: - return quant_type - conflicts = get_gguf_quant_conflicts(quant_type, quant_cfg_path) - if conflicts: - raise ValueError( - f"model_dir contains GGUF weights but quantization is configured as {conflicts!r}. " - "GGUF checkpoints only support --quant_type gguf. " - "Use a HuggingFace safetensors directory for awq/fp8/none, or remove non-gguf entries from --quant_cfg mix_bits." +def apply_gguf_quant_type(paths: Union[str, ModelPaths], quant_type: str) -> str: + """Align quant_type for GGUF models and log when overriding.""" + if not isinstance(paths, ModelPaths): + paths = _create_model_paths(paths) + aligned = paths.align_quant_type(quant_type) + if aligned != quant_type: + logger.warning( + f"model_dir contains GGUF weights; overriding --quant_type {quant_type!r} -> {aligned!r}" ) - if quant_type == "gguf": - return quant_type - return "gguf" - - -def check_gguf_quant_model_dir( - model_dir: Optional[str], - quant_type: str, - quant_cfg_path: Optional[str] = None, -) -> None: - if not uses_gguf_quant_type(quant_type, quant_cfg_path): - return - if find_gguf_path(model_dir) is not None: - return - raise ValueError( - f"--quant_type gguf requires a .gguf weights file, but none found under model_dir={model_dir!r}. " - "Point model_dir to a .gguf file or a directory with exactly one .gguf file. " - "For HuggingFace safetensors checkpoints, use --quant_type none (or awq / fp8, etc.), not gguf." - ) - - -def load_model_config(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> dict: - # load config from config_path - if config_path is not None: - if not os.path.isfile(config_path): - raise FileNotFoundError(f"config file not found: {config_path}") - with open(config_path, "r") as file: - return json.load(file) - - # load config from model_dir/config.json - if model_dir and not model_dir.endswith(".gguf"): - default_json = os.path.join(model_dir, "config.json") - if os.path.isfile(default_json): - with open(default_json, "r") as file: - return json.load(file) - - # load config from GGUF metadata - gguf_path = find_gguf_path(model_dir) - if gguf_path is not None: - try: - from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint - except ImportError as e: - raise ImportError( - "Loading config from GGUF requires transformers with GGUF support and the gguf package." - ) from e - config = load_gguf_checkpoint(gguf_path, return_tensors=False)["config"] - logger.info(f"loaded model config from GGUF metadata: {gguf_path}") - - return config - - raise FileNotFoundError( - f"no model config found (config_path={config_path!r}, model_dir={model_dir!r}). " - "Provide --config_path, place config.json under model_dir, or use a .gguf model path." - ) + return aligned def normalize_model_config(config: dict) -> dict: @@ -182,40 +203,56 @@ def normalize_model_config(config: dict) -> dict: @lru_cache(maxsize=None) -def get_model_config_dict(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> dict: - return normalize_model_config(load_model_config(config_path=config_path, model_dir=model_dir)) +def _get_model_config_cached(paths: ModelPaths) -> dict: + return normalize_model_config(paths.load_config()) -def _get_config_from_model_path(model_path: str) -> dict: - if os.path.isfile(model_path) and not model_path.endswith(".gguf"): - return get_model_config_dict(config_path=model_path, model_dir=None) - return get_model_config_dict(model_dir=model_path) +def get_model_config(paths: Union[str, ModelPaths]) -> dict: + if isinstance(paths, str): + paths = _create_model_paths(paths) + return _get_model_config_cached(paths) -def resolve_model_config( - config_path: Optional[str] = None, - model_dir: Optional[str] = None, -) -> dict: - if config_path is not None or model_dir is not None: - return get_model_config_dict(config_path=config_path, model_dir=model_dir) - raise ValueError("resolve_model_config requires config_path or model_dir") +def check_gguf_multimodal_paths(paths: ModelPaths, enable_multimodal: bool = False) -> None: + if not enable_multimodal or not paths.is_gguf: + return + if not paths.tokenizer_dir: + raise ValueError("tokenizer_dir is required when enable_multimodal is True for GGUF models") + if not os.path.isdir(paths.tokenizer_dir): + raise FileNotFoundError(f"tokenizer_dir {paths.tokenizer_dir} is not found") -@lru_cache(maxsize=1) -def get_start_args_model_config() -> dict: - start_args = get_env_start_args() - return get_model_config_dict(config_path=start_args.config_path, model_dir=start_args.model_dir) + effective_config_path = paths.config_path + if effective_config_path is not None: + if not os.path.isfile(effective_config_path): + raise FileNotFoundError(f"config.json {effective_config_path} is not found") + else: + effective_config_path = os.path.join(paths.tokenizer_dir, "config.json") + if not os.path.isfile(effective_config_path): + raise FileNotFoundError( + f"config.json is not provided and not found in tokenizer_dir: " + f"{paths.tokenizer_dir} when enable_multimodal is True for GGUF models" + ) + processor_path = os.path.join(paths.tokenizer_dir, "preprocessor_config.json") + if not os.path.isfile(processor_path): + raise FileNotFoundError( + f"preprocessor_config.json not found in tokenizer_dir: " + f"{paths.tokenizer_dir} when enable_multimodal is True for GGUF models" + ) -def get_config_json(model_path: str) -> dict: - logger.warning( - "get_config_json(model_path) is deprecated; " - "use get_model_config_dict(config_path=..., model_dir=...) instead." - ) - return _get_config_from_model_path(model_path) + if not paths.mmproj_path: + raise ValueError("mmproj_path is required when enable_multimodal is True for GGUF models") + if not os.path.isfile(paths.mmproj_path): + raise FileNotFoundError(f"mmproj_path {paths.mmproj_path} is not found") + + +@lru_cache(maxsize=1) +def get_start_args_model_config() -> dict: + return get_model_config(get_model_paths()) -def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int]: +def _derive_max_req_total_len_from_model_config(paths: ModelPaths) -> Optional[int]: """ Derive `max_req_total_len` from model config.json. @@ -225,7 +262,7 @@ def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int] """ try: - cfg = get_config_json(model_dir) + cfg = get_model_config(paths) except Exception as e: logger.warning(f"failed to load config.json for max_req_total_len derive: {e}") return None @@ -314,7 +351,7 @@ def _find_rope_scaling() -> dict: return None -def auto_set_max_req_total_len(args) -> None: +def auto_set_max_req_total_len(args, paths: Optional[ModelPaths] = None) -> None: """ Ensure `args.max_req_total_len` is an int. @@ -326,14 +363,16 @@ def auto_set_max_req_total_len(args) -> None: if args.max_req_total_len is not None: return - model_dir = args.model_dir - if not model_dir: + if paths is None: + paths = ModelPaths.from_args(args) + + if not paths.model_dir: logger.warning("model_dir is empty; fallback max_req_total_len=16384") args.max_req_total_len = default_fallback return try: - derived = _derive_max_req_total_len_from_model_config(model_dir) + derived = _derive_max_req_total_len_from_model_config(paths) except Exception as e: logger.warning(f"failed to derive max_req_total_len from model config: {e}") derived = None @@ -347,8 +386,8 @@ def auto_set_max_req_total_len(args) -> None: logger.info(f"auto derived max_req_total_len={args.max_req_total_len} from model config") -def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): - config_json = get_config_json(model_path) +def _get_config_llm_keyvalue(paths: ModelPaths, key_name: list[str]): + config_json = get_model_config(paths) for key in key_name: try: value = config_json[key] @@ -368,53 +407,57 @@ def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): return None -def get_hidden_size(model_path: str) -> Optional[int]: - hidden_size = _get_config_llm_keyvalue(model_path=model_path, key_name=["hidden_size", "n_embd", "n_embed"]) +def get_hidden_size(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[int]: + paths = _create_model_paths(model_dir_or_paths) + hidden_size = _get_config_llm_keyvalue(paths, key_name=["hidden_size", "n_embd", "n_embed"]) if isinstance(hidden_size, int): return hidden_size return None @lru_cache(maxsize=None) -def get_num_key_value_heads(model_path: str) -> int: - num_key_value_heads = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_key_value_heads"]) +def get_num_key_value_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = _create_model_paths(model_dir_or_paths) + num_key_value_heads = _get_config_llm_keyvalue(paths, key_name=["num_key_value_heads"]) if isinstance(num_key_value_heads, int): return num_key_value_heads return None @lru_cache(maxsize=None) -def get_num_attention_heads(model_path: str) -> int: - num_attention_heads = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_attention_heads"]) +def get_num_attention_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = _create_model_paths(model_dir_or_paths) + num_attention_heads = _get_config_llm_keyvalue(paths, key_name=["num_attention_heads"]) if isinstance(num_attention_heads, int): return num_attention_heads return None @lru_cache(maxsize=None) -def get_head_dim(model_path: str) -> int: - head_dim = _get_config_llm_keyvalue(model_path=model_path, key_name=["head_dim"]) +def get_head_dim(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = _create_model_paths(model_dir_or_paths) + head_dim = _get_config_llm_keyvalue(paths, key_name=["head_dim"]) if isinstance(head_dim, int): return head_dim - # calcu head_dim - head_dim = get_hidden_size(model_path=model_path) // get_num_attention_heads(model_path=model_path) + head_dim = get_hidden_size(paths) // get_num_attention_heads(paths) return head_dim @lru_cache(maxsize=None) -def get_layer_num(model_path: str) -> int: - num_hidden_layers = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_hidden_layers"]) +def get_layer_num(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = _create_model_paths(model_dir_or_paths) + num_hidden_layers = _get_config_llm_keyvalue(paths, key_name=["num_hidden_layers"]) if isinstance(num_hidden_layers, int): return num_hidden_layers return None -def get_eos_token_ids(model_path: str) -> Optional[List[int]]: +def get_eos_token_ids(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[List[int]]: + paths = _create_model_paths(model_dir_or_paths) try: - # qwen3-omini special eos_token_id - config_json = get_config_json(model_path) + config_json = get_model_config(paths) assert config_json["architectures"][0] == "Qwen3OmniMoeForConditionalGeneration" return [151645] except: @@ -422,20 +465,20 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: # Qwen3.5 checkpoints can have an eos_token_id in config that differs from # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable - # stop id (<|im_end|>) for detokenization/stop behavior. + # stop id (<|im_end|>) try: - config_json = get_config_json(model_path) + config_json = get_model_config(paths) model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + tokenizer = AutoTokenizer.from_pretrained(paths.processor_dir, trust_remote_code=False) if tokenizer.eos_token_id is not None: return [int(tokenizer.eos_token_id)] except Exception: pass - eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) + eos_token_id = _get_config_llm_keyvalue(paths, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] if isinstance(eos_token_id, list): @@ -445,9 +488,10 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: return -def get_model_architectures(model_path: str): +def get_model_architectures(model_dir_or_paths: Union[str, ModelPaths]): + paths = _create_model_paths(model_dir_or_paths) try: - config_json = get_config_json(model_path) + config_json = get_model_config(paths) arch = config_json["architectures"][0] return arch except: @@ -455,9 +499,13 @@ def get_model_architectures(model_path: str): return "unknown_architecture" -def get_vocab_size(config_path: Optional[str] = None, model_dir: Optional[str] = None) -> int: +def get_vocab_size(paths: Optional[Union[str, ModelPaths]] = None) -> int: try: - config_json = resolve_model_config(config_path=config_path, model_dir=model_dir) + if paths is None: + paths = get_model_paths() + elif not isinstance(paths, ModelPaths): + paths = _create_model_paths(paths) + config_json = get_model_config(paths) # qwen3-omini special if "thinker_config" in config_json: config_json = config_json["thinker_config"] @@ -476,8 +524,9 @@ def get_vocab_size(config_path: Optional[str] = None, model_dir: Optional[str] = return 0 -def get_dtype(model_path: str): - torch_dtype = _get_config_llm_keyvalue(model_path=model_path, key_name=["torch_dtype", "dtype", "model_dtype"]) +def get_dtype(model_dir_or_paths: Union[str, ModelPaths]): + paths = _create_model_paths(model_dir_or_paths) + torch_dtype = _get_config_llm_keyvalue(paths, key_name=["torch_dtype", "dtype", "model_dtype"]) if torch_dtype is None: logger.warning("torch_dtype not in config.json, use float16 as default") return "float16" @@ -495,9 +544,10 @@ def get_fixed_kv_len(): @lru_cache(maxsize=None) -def has_vision_module(model_path: str) -> bool: +def has_vision_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = _create_model_paths(model_dir_or_paths) try: - model_cfg = get_model_config_dict(model_dir=model_path) + model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] if model_type == "qwen": # QWenVisionTransformer @@ -536,14 +586,15 @@ def has_vision_module(model_path: str) -> bool: else: raise Exception("unknown vision model type") except: - logger.info(f"model path: {model_path} does not has vision module") + logger.info(f"model path: {paths.model_dir} does not has vision module") return False @lru_cache(maxsize=None) -def has_audio_module(model_path: str) -> bool: +def has_audio_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = _create_model_paths(model_dir_or_paths) try: - model_cfg = get_model_config_dict(model_dir=model_path) + model_cfg = get_model_config(paths) if model_cfg.get("thinker_config") is not None: model_cfg = model_cfg["thinker_config"] audio_config = model_cfg["audio_config"] @@ -557,44 +608,39 @@ def has_audio_module(model_path: str) -> bool: else: raise Exception("unknown audio model type") except: - logger.info(f"model path: {model_path} does not has audio module") + logger.info(f"model path: {paths.model_dir} does not has audio module") return False @lru_cache(maxsize=None) -def is_linear_att_mixed_model(model_path: str) -> bool: +def is_linear_att_mixed_model(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = _create_model_paths(model_dir_or_paths) try: - model_cfg = get_model_config_dict(model_dir=model_path) + model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] if model_type in ["qwen3_5", "qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text"]: return True else: return False except: - logger.info(f"model path: {model_path} does not has linear hybrid attention") + logger.info(f"model path: {paths.model_dir} does not has linear hybrid attention") return False -def get_model_type( - config_path: Optional[str] = None, - model_dir: Optional[str] = None, -) -> Optional[str]: +def get_model_type(paths: Union[str, ModelPaths]) -> Optional[str]: """Get model type from model config.""" try: - config_json = resolve_model_config(config_path=config_path, model_dir=model_dir) + config_json = get_model_config(paths) model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") return model_type except Exception as e: - logger.error(f"Failed to get model_type (config_path={config_path!r}, model_dir={model_dir!r}): {e}") + logger.error(f"Failed to get model_type (paths={paths!r}): {e}") return None -def get_tool_call_parser_for_model( - config_path: Optional[str] = None, - model_dir: Optional[str] = None, -) -> Optional[str]: +def get_tool_call_parser_for_model(paths: Union[str, ModelPaths]) -> Optional[str]: """Auto-detect tool_call_parser based on model type""" - model_type = get_model_type(config_path=config_path, model_dir=model_dir) + model_type = get_model_type(paths) if model_type is None: return None @@ -621,12 +667,8 @@ def get_tool_call_parser_for_model( return None -def get_reasoning_parser_for_model( - config_path: Optional[str] = None, - model_dir: Optional[str] = None, -) -> Optional[str]: - """Auto-detect reasoning_parser based on model type""" - model_type = get_model_type(config_path=config_path, model_dir=model_dir) +def get_reasoning_parser_for_model(paths: Union[str, ModelPaths]) -> Optional[str]: + model_type = get_model_type(paths) if model_type is None: return None diff --git a/lightllm/utils/gguf_tokenizer_utils.py b/lightllm/utils/gguf_tokenizer_utils.py new file mode 100644 index 0000000000..8fafc20cec --- /dev/null +++ b/lightllm/utils/gguf_tokenizer_utils.py @@ -0,0 +1,207 @@ +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.integrations.ggml import ( + GGUF_TO_FAST_CONVERTERS, + GGUF_TOKENIZER_MAPPING, + convert_gguf_tokenizer, +) +from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, tokenizer_class_from_name +from typing import Any, Dict, Optional, Tuple, Union + +from lightllm.common.basemodel.layer_weights.gguf_load_utils import LightLLMGGUFReader, get_gguf_reader +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def _build_gguf_tokenizer_field_map() -> Dict[str, Tuple[str, str]]: + """ Build a mapping from GGUF tokenizer field name to HuggingFace tokenizer field name. """ + field_map: Dict[str, Tuple[str, str]] = {} + # Build mapping for tokenizer fields + for gguf_suffix, hf_key in GGUF_TOKENIZER_MAPPING["tokenizer"].items(): + # e.g., "ggml.model" -> ("tokenizer", "model_type") + field_map[f"tokenizer.{gguf_suffix}"] = ("tokenizer", hf_key) + # Build mapping for tokenizer_config fields + for gguf_suffix, hf_key in GGUF_TOKENIZER_MAPPING["tokenizer_config"].items(): + # e.g., "ggml.model" -> ("tokenizer_config", "model_type") + gguf_key = f"tokenizer.{gguf_suffix}" + if gguf_key not in field_map: + field_map[gguf_key] = ("tokenizer_config", hf_key) + # Add special case for add_bos_token + field_map["tokenizer.ggml.add_bos_token"] = ("tokenizer_config", "add_bos_token") + + return field_map + + +def _parse_gguf_tokenizer_from_reader(reader: LightLLMGGUFReader) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ Parse tokenizer metadata from a GGUFReader using ReaderField.contents(). """ + tokenizer: Dict[str, Any] = {} + tokenizer_config: Dict[str, Any] = {} + + tokenizer_field_map = _build_gguf_tokenizer_field_map() + for gguf_key, (bucket, hf_key) in tokenizer_field_map.items(): + value = reader.read_field(gguf_key) + if value is None: + continue + # Add value to tokenizer or tokenizer_config + if bucket == "tokenizer": + tokenizer[hf_key] = value + else: + tokenizer_config[hf_key] = value + + if "model_type" not in tokenizer_config and tokenizer.get("tokenizer_type") is not None: + tokenizer_config["model_type"] = tokenizer["tokenizer_type"] + + for id_key in ("bos_token_id", "eos_token_id", "unk_token_id", "pad_token_id"): + token_id = tokenizer.get(id_key) + if token_id is not None: + tokenizer_config[id_key] = token_id + + return tokenizer, tokenizer_config + + +def _architecture_to_converter_key(architecture: str) -> str: + """ Convert GGUF architecture to HuggingFace converter key. """ + arch = architecture.replace("-", "_") + if arch in GGUF_TO_FAST_CONVERTERS: + return arch + # e.g., qwen2moe / qwen3moe -> qwen2_moe / qwen3_moe + if arch.endswith("moe") and not arch.endswith("_moe"): + alt = f"{arch[:-3]}_moe" + if alt in GGUF_TO_FAST_CONVERTERS: + return alt + + return arch + + +def _resolve_gguf_tokenizer_architecture(model_config: Dict[str, Any], reader: LightLLMGGUFReader) -> str: + """ Resolve GGUF tokenizer architecture from GGUF reader or model config. """ + architecture = None + # Read architecture from GGUF reader + arch = reader.read_field("general.architecture") + if isinstance(arch, (list, tuple)): + arch = arch[0] + if isinstance(arch, str): + architecture = arch + + # Read architecture from model config + if not architecture: + architecture = model_config.get("model_type") + + if not architecture: + raise ValueError( + "Can't resolve GGUF tokenizer architecture: " + "missing general.architecture in GGUF and model_type in config" + ) + + return _architecture_to_converter_key(architecture) + + +def _build_tokenizer_init_kwargs( + tokenizer_dict: Dict[str, Any], + tokenizer_config: Dict[str, Any], + additional_kwargs: Dict[str, Any], +) -> Dict[str, Any]: + """ Build tokenizer initialization kwargs from tokenizer dictionary and configuration. """ + # Merge tokenizer configuration and additional kwargs + init_kwargs = {**tokenizer_config, **additional_kwargs} + + tokens = tokenizer_dict.get("tokens") or [] + for id_key, token_key in ( + ("bos_token_id", "bos_token"), + ("eos_token_id", "eos_token"), + ("unk_token_id", "unk_token"), + ("pad_token_id", "pad_token"), + ): + token_id = tokenizer_dict.get(id_key) + if token_id is None: + token_id = init_kwargs.get(id_key) + + if token_id is not None and token_id < len(tokens): + init_kwargs.setdefault(token_key, tokens[token_id]) + + return init_kwargs + + +def _get_hf_tokenizer_class_from_config( + model_config: Dict[str, Any], + use_fast: bool = True, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Resolve the same tokenizer class AutoTokenizer would use for this config. """ + # Get tokenizer class name from tokenizer_class key of model config + tokenizer_class_name: str = model_config.get("tokenizer_class") + if tokenizer_class_name: + candidates = [] + if use_fast and not tokenizer_class_name.endswith("Fast"): + candidates.append(f"{tokenizer_class_name}Fast") + candidates.append(tokenizer_class_name) + for candidate in candidates: + tokenizer_class = tokenizer_class_from_name(candidate) + if tokenizer_class is not None: + return tokenizer_class + + # Get HuggingFace tokenizer class from model_type key of model config + model_type = model_config.get("model_type") + # CONFIG_MAPPING: model_type -> config class + if not model_type or model_type not in CONFIG_MAPPING: + return PreTrainedTokenizerFast + + config_class = CONFIG_MAPPING[model_type] + # TOKENIZER_MAPPING: config class -> (tokenizer class, fast tokenizer class) + if config_class not in TOKENIZER_MAPPING: + return PreTrainedTokenizerFast + + tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config_class] + if use_fast and tokenizer_class_fast is not None: + return tokenizer_class_fast + + if tokenizer_class_py is not None: + return tokenizer_class_py + + return PreTrainedTokenizerFast + + +def load_tokenizer_from_gguf( + gguf_path: str, + model_config: Optional[Dict[str, Any]] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Load the tokenizer from GGUF file. """ + if model_config is None: + raise ValueError("model_config is required when loading tokenizer from GGUF") + + reader = get_gguf_reader(gguf_path) + tokenizer_dict, tokenizer_config = _parse_gguf_tokenizer_from_reader(reader) + architecture = _resolve_gguf_tokenizer_architecture(model_config, reader=reader) + + if "tokens" not in tokenizer_dict: + raise ValueError(f"GGUF file does not contain tokenizer.ggml.tokens: {gguf_path}") + + if architecture not in GGUF_TO_FAST_CONVERTERS: + supported = ", ".join(sorted(GGUF_TO_FAST_CONVERTERS.keys())) + raise ValueError( + f"unsupported GGUF tokenizer architecture {architecture!r} for {gguf_path}; " + f"supported: {supported}" + ) + + logger.info( + f"loading tokenizer from GGUF ReaderField metadata: {gguf_path} " + f"(architecture={architecture}, vocab_size={len(tokenizer_dict['tokens'])})" + ) + + # Convert GGUF tokenizer to HuggingFace tokenizer + fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict) + + init_kwargs = _build_tokenizer_init_kwargs(tokenizer_dict, tokenizer_config, additional_kwargs) + + tokenizer_kwargs = {k: v for k, v in kwargs.items() if k != "trust_remote_code"} + init_kwargs.update(tokenizer_kwargs) + + use_fast = kwargs.pop("use_fast", True) + # Get HuggingFace tokenizer class from model config + tokenizer_class = _get_hf_tokenizer_class_from_config(model_config, use_fast=use_fast) + + logger.info(f"GGUF tokenizer class: {tokenizer_class.__name__}") + + # Initialize HuggingFace tokenizer + return tokenizer_class(tokenizer_object=fast_tokenizer, **init_kwargs) diff --git a/lightllm/utils/shm_size_check.py b/lightllm/utils/shm_size_check.py index dcf9aaadc7..0b3e9b0280 100644 --- a/lightllm/utils/shm_size_check.py +++ b/lightllm/utils/shm_size_check.py @@ -6,7 +6,7 @@ from lightllm.server.core.objs.req import ChunkedPrefillReq, TokenHealingReq from lightllm.server.multimodal_params import ImageItem from lightllm.server.tokenizer import get_tokenizer -from lightllm.utils.config_utils import get_hidden_size +from lightllm.utils.config_utils import ModelPaths, get_hidden_size from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -86,11 +86,11 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_ """ 获取所需的 /dev/shm 大小(以GB为单位)。 """ + paths = ModelPaths.from_args(args) tokenizer = get_tokenizer( - args.model_dir, - args.tokenizer_dir, - args.tokenizer_mode, - trust_remote_code=True, + paths, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, ) # 估算input_token和logprob占用shm大小,由于是double和int64,所以固定占用8个字节 @@ -127,7 +127,7 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_ max_image_tokens = tokenizer.get_image_token_length(fake_image_item) # 估算图片 token 所需的资源 - hidden_size = get_hidden_size(args.model_dir) + hidden_size = get_hidden_size(paths) if hidden_size is None: logger.warning( "Model config not contain 'hidden_size', " "using 4096 by default to calculate recommended shm size." From f21e144f9454aad3834f8a19ef4bf1e995290a4c Mon Sep 17 00:00:00 2001 From: zhangtaoshan Date: Mon, 15 Jun 2026 19:01:02 +0800 Subject: [PATCH 4/5] Normalize 'get_model_config' --- lightllm/common/basemodel/basemodel.py | 9 +- .../meta_weights/mm_weight/mm_slicer.py | 5 +- lightllm/server/api_openai.py | 9 +- lightllm/server/api_start.py | 23 ++++- lightllm/server/build_prompt.py | 10 ++- lightllm/server/detokenization/manager.py | 10 ++- lightllm/server/httpserver/manager.py | 9 +- .../httpserver_for_pd_master/manager.py | 10 ++- .../model_infer/mode_backend/base_backend.py | 4 +- .../impl_for_outlines_constraint_mode.py | 11 ++- .../chunked_prefill/impl_for_token_healing.py | 9 +- .../chunked_prefill/impl_for_xgrammar_mode.py | 9 +- lightllm/server/tokenizer.py | 8 +- lightllm/server/visualserver/manager.py | 9 +- .../visualserver/visual_only_manager.py | 9 +- lightllm/utils/config_utils.py | 83 ++++++++----------- lightllm/utils/gguf_tokenizer_utils.py | 2 +- lightllm/utils/shm_size_check.py | 10 ++- 18 files changed, 148 insertions(+), 91 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 8d2db60c99..ac89b7050f 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -16,6 +16,7 @@ from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager from lightllm.common.infer_utils import init_req_to_token_indexes +from lightllm.common.build_utils import repair_config from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.basemodel.cuda_graph import CudaGraph @@ -24,7 +25,7 @@ from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token from lightllm.utils.config_utils import ( apply_gguf_quant_type, - _create_model_paths, + create_model_paths, get_model_config, ) from lightllm.utils.log_utils import init_logger @@ -62,7 +63,7 @@ def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] - self.model_paths_ = _create_model_paths(self.weight_dir_) + self.model_paths_ = create_model_paths(self.weight_dir_) self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") @@ -153,6 +154,10 @@ def _wait_other_modules_ready(self): def _init_config(self): self.config = get_model_config(self.model_paths_) + # rename keys + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) if self.finetune_config: self.config["vocab_size"] = self.finetune_config.vocab_size return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index fb670fafef..312d3dd90d 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -148,9 +148,10 @@ def get_row_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) - if quant_method_name == "none" or quant_method_name.startswith("gguf"): + elif quant_method_name == "none" or quant_method_name.startswith("gguf"): return RowSliceMixin(tp_rank, tp_world_size, repeat_times) - return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) + else: + return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) def get_col_slice_mixin( diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 16b7f42e12..7a238012a4 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -32,7 +32,7 @@ from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.error_utils import ClientDisconnected -from lightllm.utils.config_utils import ModelPaths, has_vision_module +from lightllm.utils.config_utils import create_model_paths, has_vision_module from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name @@ -233,7 +233,12 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req created_time = int(time.time()) - supports_vl_chat = has_vision_module(ModelPaths.from_args(g_objs.args)) + supports_vl_chat = has_vision_module(create_model_paths( + g_objs.args.model_dir, + config_path=g_objs.args.config_path, + tokenizer_dir=g_objs.args.tokenizer_dir, + mmproj_path=g_objs.args.mmproj_path, + )) multimodal_params_dict = {"images": [], "audios": []} for message in request.messages: if isinstance(message.content, list): diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 094db272d0..319b0a96bf 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -18,7 +18,7 @@ from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import ( - ModelPaths, + create_model_paths, apply_gguf_quant_type, has_audio_module, has_vision_module, @@ -77,7 +77,12 @@ def normal_or_p_d_start(args): args: StartArgs = args - paths = ModelPaths.from_args(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) auto_set_max_req_total_len(args, paths) set_unique_server_name(args) @@ -543,7 +548,12 @@ def pd_master_start(args): if args.run_mode != "pd_master": return - paths = ModelPaths.from_args(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) auto_set_max_req_total_len(args, paths) # when use config_server to support multi pd_master node, we @@ -610,7 +620,12 @@ def visual_only_start(args): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args - paths = ModelPaths.from_args(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 5f34166cf4..160e9a70db 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,6 +1,6 @@ import os import json -from lightllm.utils.config_utils import ModelPaths +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -12,9 +12,13 @@ def init_tokenizer(args): global tokenizer - paths = ModelPaths.from_args(args) tokenizer = get_tokenizer( - paths, + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 588c78201d..c3a8a95bb5 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -12,8 +12,7 @@ from .decode import decode_token from .decode_mode_fix import decode_mode_fix from .decode_req import DecodeReq -from lightllm.utils.config_utils import ModelPaths -from ..tokenizer import get_tokenizer +from ..tokenizer import create_model_paths, get_tokenizer import pickle import time from lightllm.utils.log_utils import init_logger @@ -36,7 +35,12 @@ def __init__( self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") self.tokenizer = get_tokenizer( - ModelPaths.from_args(args), + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e8e0070ffa..b6a4cba27d 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -32,7 +32,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage -from lightllm.utils.config_utils import ModelPaths, get_vocab_size +from lightllm.utils.config_utils import create_model_paths, get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import ClientDisconnected, NixlPrefillNodeStopGenToken from rpyc.utils.classic import obtain @@ -102,7 +102,12 @@ def __init__( self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") - paths = ModelPaths.from_args(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) self.tokenizer = get_tokenizer( paths, diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 3c317abe68..38013f785f 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -12,8 +12,7 @@ from ..pd_io_struct import PD_Client_Obj, UpKVStatus, NixlUpKVStatus, ObjType, NodeRole, NIXLDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams -from lightllm.utils.config_utils import ModelPaths -from ..tokenizer import get_tokenizer +from ..tokenizer import create_model_paths, get_tokenizer from ..req_id_generator import ReqIDGenerator, convert_sub_id_to_group_id from fastapi import Request from lightllm.utils.log_utils import init_logger @@ -44,7 +43,12 @@ def __init__( self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 self.tokenizer = get_tokenizer( - ModelPaths.from_args(args), + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 1d83245bfa..82094eefe2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -19,7 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify -from lightllm.utils.config_utils import _create_model_paths, get_model_config +from lightllm.utils.config_utils import create_model_paths, get_model_config from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -151,7 +151,7 @@ def init_model(self, kvargs): if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() - model_cfg = get_model_config(_create_model_paths(self.weight_dir)) + model_cfg = get_model_config(create_model_paths(self.weight_dir)) model_kvargs = { "weight_dir": self.weight_dir, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index c9314abafa..21341e0589 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -6,8 +6,8 @@ from .impl import ChunkedPrefillBackend from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.config_utils import ModelPaths from lightllm.server.tokenizer import get_tokenizer +from lightllm.utils.config_utils import create_model_paths from typing import List, Tuple from lightllm.utils.log_utils import init_logger @@ -37,10 +37,15 @@ def init_custom(self): self.tokenizer = TransformerTokenizer( get_tokenizer( - ModelPaths.from_args(self.args), + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, - ) + ), ) eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index 887f89c2f4..94fec89193 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -2,7 +2,7 @@ from .impl import ChunkedPrefillBackend from typing import List from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.config_utils import ModelPaths +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -22,7 +22,12 @@ def init_custom(self): 初始化tokenizer 词表相关的的操作 """ self.tokenizer = get_tokenizer( - ModelPaths.from_args(self.args), + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, ) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index 09fde5cf75..f3e987868e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -5,7 +5,7 @@ from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.utils.config_utils import ModelPaths +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -24,7 +24,12 @@ def init_custom(self): import xgrammar as xgr self.tokenizer = get_tokenizer( - ModelPaths.from_args(self.args), + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), tokenizer_mode=self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code, ) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index a8ebe96665..1b2ae639cd 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -21,7 +21,7 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import convert_slow_tokenizer from transformers.configuration_utils import PretrainedConfig -from lightllm.utils.config_utils import ModelPaths, get_model_config +from lightllm.utils.config_utils import ModelPaths, get_model_config, create_model_paths from lightllm.utils.gguf_tokenizer_utils import load_tokenizer_from_gguf from lightllm.utils.log_utils import init_logger from ..models.tarsier2.model import Tarsier2Tokenizer @@ -69,7 +69,6 @@ def _load_base_tokenizer( kwargs["use_fast"] = False if from_gguf: - logger.info(f"Loading tokenizer from GGUF file: {load_path}") return load_tokenizer_from_gguf(load_path, model_cfg, *args, **kwargs) if "llama" in load_path.lower() and kwargs.get("use_fast", True): @@ -158,13 +157,14 @@ def _wrap_tokenizer( def get_tokenizer( - paths: ModelPaths, + paths: Union[str, ModelPaths], tokenizer_mode: str = "auto", trust_remote_code: bool = False, *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Load base tokenizer (HF or GGUF), then wrap for model-specific behavior if needed.""" + """ Load base tokenizer (HF or GGUF), then wrap for model-specific behavior if needed. """ + paths = create_model_paths(paths) model_cfg = get_model_config(paths) load_path, from_gguf = paths.tokenizer_load_path diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 274a8e7478..b548339332 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -17,7 +17,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend -from lightllm.utils.config_utils import ModelPaths +from lightllm.utils.config_utils import create_model_paths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -51,7 +51,12 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - self.visual_weight_dir, self.processor_dir = ModelPaths.from_args(args).resolve_visual_dirs() + self.visual_weight_dir, self.processor_dir = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ).resolve_visual_dirs() self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp # image 最大推理 batch size diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 79ab281c4c..aa489ff51c 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -21,7 +21,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend -from lightllm.utils.config_utils import ModelPaths +from lightllm.utils.config_utils import create_model_paths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -40,7 +40,12 @@ def __init__( args: StartArgs, ): self.args = args - self.visual_weight_dir, self.processor_dir = ModelPaths.from_args(args).resolve_visual_dirs() + self.visual_weight_dir, self.processor_dir = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ).resolve_visual_dirs() self.vit_dp = args.visual_dp assert self.vit_dp == 1 self.vit_tp = args.visual_tp diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index b47c07fbd1..e4502c8dda 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -19,21 +19,6 @@ class ModelPaths: mmproj_path: Optional[str] = None _gguf_path: Optional[str] = field(default=None, init=False, repr=False, compare=False) - @classmethod - def from_args(cls, args) -> "ModelPaths": - return _fill_paths_from_env( - cls( - model_dir=args.model_dir, - tokenizer_dir=getattr(args, "tokenizer_dir", None), - config_path=getattr(args, "config_path", None), - mmproj_path=getattr(args, "mmproj_path", None), - ) - ) - - @classmethod - def from_env(cls) -> "ModelPaths": - return cls.from_args(get_env_start_args()) - @property def is_gguf(self) -> bool: return self._gguf_path is not None @@ -117,12 +102,22 @@ def _fill_paths_from_env(paths: ModelPaths) -> ModelPaths: ) -def _create_model_paths( - model_dir_or_paths: Union[str, ModelPaths], +def create_model_paths( + model_dir_or_paths: Union[str, ModelPaths, None] = None, + *, config_path: Optional[str] = None, tokenizer_dir: Optional[str] = None, mmproj_path: Optional[str] = None, ) -> ModelPaths: + if model_dir_or_paths is None: + start_args = get_env_start_args() + return create_model_paths( + start_args.model_dir, + config_path=getattr(start_args, "config_path", None), + tokenizer_dir=getattr(start_args, "tokenizer_dir", None), + mmproj_path=getattr(start_args, "mmproj_path", None), + ) + if isinstance(model_dir_or_paths, ModelPaths): paths = model_dir_or_paths else: @@ -132,12 +127,13 @@ def _create_model_paths( tokenizer_dir=tokenizer_dir, mmproj_path=mmproj_path, ) + return _fill_paths_from_env(paths) @lru_cache(maxsize=1) def get_model_paths() -> ModelPaths: - return ModelPaths.from_env() + return create_model_paths() def _load_config_from_path(config_path: str) -> dict: @@ -178,8 +174,7 @@ def _find_gguf_path_cached(model_dir: Optional[str]) -> Optional[str]: def apply_gguf_quant_type(paths: Union[str, ModelPaths], quant_type: str) -> str: """Align quant_type for GGUF models and log when overriding.""" - if not isinstance(paths, ModelPaths): - paths = _create_model_paths(paths) + paths = create_model_paths(paths) aligned = paths.align_quant_type(quant_type) if aligned != quant_type: logger.warning( @@ -188,12 +183,7 @@ def apply_gguf_quant_type(paths: Union[str, ModelPaths], quant_type: str) -> str return aligned -def normalize_model_config(config: dict) -> dict: - from lightllm.common.build_utils import repair_config - - repair_config(config, same_names=["num_attention_heads", "n_head"]) - repair_config(config, same_names=["hidden_size", "n_embd", "n_embed"]) - repair_config(config, same_names=["num_hidden_layers", "n_layer"]) +def _normalize_gguf_model_config(config: dict) -> dict: if config.get("head_dim") is None: hidden_size = config.get("hidden_size") or config.get("n_embd") or config.get("n_embed") num_heads = config.get("num_attention_heads") or config.get("n_head") @@ -203,14 +193,12 @@ def normalize_model_config(config: dict) -> dict: @lru_cache(maxsize=None) -def _get_model_config_cached(paths: ModelPaths) -> dict: - return normalize_model_config(paths.load_config()) - - def get_model_config(paths: Union[str, ModelPaths]) -> dict: - if isinstance(paths, str): - paths = _create_model_paths(paths) - return _get_model_config_cached(paths) + paths = create_model_paths(paths) + config = paths.load_config() + if paths.is_gguf: + config = _normalize_gguf_model_config(config) + return config def check_gguf_multimodal_paths(paths: ModelPaths, enable_multimodal: bool = False) -> None: @@ -351,7 +339,7 @@ def _find_rope_scaling() -> dict: return None -def auto_set_max_req_total_len(args, paths: Optional[ModelPaths] = None) -> None: +def auto_set_max_req_total_len(args, paths: ModelPaths) -> None: """ Ensure `args.max_req_total_len` is an int. @@ -363,9 +351,6 @@ def auto_set_max_req_total_len(args, paths: Optional[ModelPaths] = None) -> None if args.max_req_total_len is not None: return - if paths is None: - paths = ModelPaths.from_args(args) - if not paths.model_dir: logger.warning("model_dir is empty; fallback max_req_total_len=16384") args.max_req_total_len = default_fallback @@ -408,7 +393,7 @@ def _get_config_llm_keyvalue(paths: ModelPaths, key_name: list[str]): def get_hidden_size(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[int]: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) hidden_size = _get_config_llm_keyvalue(paths, key_name=["hidden_size", "n_embd", "n_embed"]) if isinstance(hidden_size, int): return hidden_size @@ -417,7 +402,7 @@ def get_hidden_size(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[int] @lru_cache(maxsize=None) def get_num_key_value_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) num_key_value_heads = _get_config_llm_keyvalue(paths, key_name=["num_key_value_heads"]) if isinstance(num_key_value_heads, int): return num_key_value_heads @@ -426,7 +411,7 @@ def get_num_key_value_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: @lru_cache(maxsize=None) def get_num_attention_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) num_attention_heads = _get_config_llm_keyvalue(paths, key_name=["num_attention_heads"]) if isinstance(num_attention_heads, int): return num_attention_heads @@ -435,7 +420,7 @@ def get_num_attention_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: @lru_cache(maxsize=None) def get_head_dim(model_dir_or_paths: Union[str, ModelPaths]) -> int: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) head_dim = _get_config_llm_keyvalue(paths, key_name=["head_dim"]) if isinstance(head_dim, int): return head_dim @@ -447,7 +432,7 @@ def get_head_dim(model_dir_or_paths: Union[str, ModelPaths]) -> int: @lru_cache(maxsize=None) def get_layer_num(model_dir_or_paths: Union[str, ModelPaths]) -> int: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) num_hidden_layers = _get_config_llm_keyvalue(paths, key_name=["num_hidden_layers"]) if isinstance(num_hidden_layers, int): return num_hidden_layers @@ -455,7 +440,7 @@ def get_layer_num(model_dir_or_paths: Union[str, ModelPaths]) -> int: def get_eos_token_ids(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[List[int]]: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) try: config_json = get_model_config(paths) assert config_json["architectures"][0] == "Qwen3OmniMoeForConditionalGeneration" @@ -489,7 +474,7 @@ def get_eos_token_ids(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[Li def get_model_architectures(model_dir_or_paths: Union[str, ModelPaths]): - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) try: config_json = get_model_config(paths) arch = config_json["architectures"][0] @@ -504,7 +489,7 @@ def get_vocab_size(paths: Optional[Union[str, ModelPaths]] = None) -> int: if paths is None: paths = get_model_paths() elif not isinstance(paths, ModelPaths): - paths = _create_model_paths(paths) + paths = create_model_paths(paths) config_json = get_model_config(paths) # qwen3-omini special if "thinker_config" in config_json: @@ -525,7 +510,7 @@ def get_vocab_size(paths: Optional[Union[str, ModelPaths]] = None) -> int: def get_dtype(model_dir_or_paths: Union[str, ModelPaths]): - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) torch_dtype = _get_config_llm_keyvalue(paths, key_name=["torch_dtype", "dtype", "model_dtype"]) if torch_dtype is None: logger.warning("torch_dtype not in config.json, use float16 as default") @@ -545,7 +530,7 @@ def get_fixed_kv_len(): @lru_cache(maxsize=None) def has_vision_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) try: model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] @@ -592,7 +577,7 @@ def has_vision_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: @lru_cache(maxsize=None) def has_audio_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) try: model_cfg = get_model_config(paths) if model_cfg.get("thinker_config") is not None: @@ -614,7 +599,7 @@ def has_audio_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: @lru_cache(maxsize=None) def is_linear_att_mixed_model(model_dir_or_paths: Union[str, ModelPaths]) -> bool: - paths = _create_model_paths(model_dir_or_paths) + paths = create_model_paths(model_dir_or_paths) try: model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] diff --git a/lightllm/utils/gguf_tokenizer_utils.py b/lightllm/utils/gguf_tokenizer_utils.py index 8fafc20cec..6d67a7d31d 100644 --- a/lightllm/utils/gguf_tokenizer_utils.py +++ b/lightllm/utils/gguf_tokenizer_utils.py @@ -185,7 +185,7 @@ def load_tokenizer_from_gguf( ) logger.info( - f"loading tokenizer from GGUF ReaderField metadata: {gguf_path} " + f"Loading tokenizer from GGUF ReaderField metadata: {gguf_path} " f"(architecture={architecture}, vocab_size={len(tokenizer_dict['tokens'])})" ) diff --git a/lightllm/utils/shm_size_check.py b/lightllm/utils/shm_size_check.py index 0b3e9b0280..ad7f5915a4 100644 --- a/lightllm/utils/shm_size_check.py +++ b/lightllm/utils/shm_size_check.py @@ -6,7 +6,7 @@ from lightllm.server.core.objs.req import ChunkedPrefillReq, TokenHealingReq from lightllm.server.multimodal_params import ImageItem from lightllm.server.tokenizer import get_tokenizer -from lightllm.utils.config_utils import ModelPaths, get_hidden_size +from lightllm.utils.config_utils import create_model_paths, get_hidden_size from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -86,9 +86,13 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_ """ 获取所需的 /dev/shm 大小(以GB为单位)。 """ - paths = ModelPaths.from_args(args) tokenizer = get_tokenizer( - paths, + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), tokenizer_mode=args.tokenizer_mode, trust_remote_code=args.trust_remote_code, ) From a8101f14a068fdb9a8bde1706167e1a8a4331d70 Mon Sep 17 00:00:00 2001 From: zhangtaoshan Date: Mon, 15 Jun 2026 20:47:25 +0800 Subject: [PATCH 5/5] misc --- lightllm/common/basemodel/basemodel.py | 16 ++-------------- .../meta_weights/mm_weight/mm_slicer.py | 5 +++-- lightllm/server/httpserver/manager.py | 5 +---- lightllm/utils/config_utils.py | 7 +------ lightllm/utils/llm_utils.py | 4 ++-- 5 files changed, 9 insertions(+), 28 deletions(-) diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ac89b7050f..f87a78f2d5 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -23,11 +23,7 @@ from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token -from lightllm.utils.config_utils import ( - apply_gguf_quant_type, - create_model_paths, - get_model_config, -) +from lightllm.utils.config_utils import create_model_paths, get_model_config from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -109,9 +105,8 @@ def __init__(self, kvargs): self._verify_must() self._verify_params() self._init_quant() - self._align_quant_type_for_gguf_weights() - # read gguf and get quant shape + # Read GGUF weights mapping self._init_gguf() self._init_weights() self._init_req_manager() @@ -176,13 +171,6 @@ def _init_quant(self): self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}") - def _align_quant_type_for_gguf_weights(self): - if self.model_paths_.gguf_path is None: - return - aligned = apply_gguf_quant_type(self.model_paths_, self.quant_cfg.quant_type) - if aligned != self.quant_cfg.quant_type: - self.quant_cfg.quant_type = aligned - def _init_gguf(self): if self.model_paths_.gguf_path is None: return diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index 312d3dd90d..5b7173882f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -159,6 +159,7 @@ def get_col_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) - if quant_method_name == "none" or quant_method_name.startswith("gguf"): + elif quant_method_name == "none" or quant_method_name.startswith("gguf"): return ColSliceMixin(tp_rank, tp_world_size, repeat_times) - return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) + else: + return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b6a4cba27d..9917daf8f6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -127,10 +127,7 @@ def __init__( self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() # 有的模型的vocab size 读取tokenizer和config.json中不一致 - self.vocab_size = max( - get_vocab_size(paths), - self.tokenizer.vocab_size, - ) + self.vocab_size = max(get_vocab_size(paths), self.tokenizer.vocab_size) # The timemark of the latest inference(prefill/decode) which is used to check the health status of the system. # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index e4502c8dda..7ea8792491 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -235,11 +235,6 @@ def check_gguf_multimodal_paths(paths: ModelPaths, enable_multimodal: bool = Fal raise FileNotFoundError(f"mmproj_path {paths.mmproj_path} is not found") -@lru_cache(maxsize=1) -def get_start_args_model_config() -> dict: - return get_model_config(get_model_paths()) - - def _derive_max_req_total_len_from_model_config(paths: ModelPaths) -> Optional[int]: """ Derive `max_req_total_len` from model config.json. @@ -521,7 +516,7 @@ def get_dtype(model_dir_or_paths: Union[str, ModelPaths]): @lru_cache(maxsize=None) def get_fixed_kv_len(): - model_cfg = get_start_args_model_config() + model_cfg = get_model_config(get_model_paths()) if "prompt_cache_token_ids" in model_cfg: return len(model_cfg["prompt_cache_token_ids"]) else: diff --git a/lightllm/utils/llm_utils.py b/lightllm/utils/llm_utils.py index 298352e72d..dfaccaca38 100644 --- a/lightllm/utils/llm_utils.py +++ b/lightllm/utils/llm_utils.py @@ -8,8 +8,8 @@ @lru_cache(maxsize=None) def get_llm_model_class(): from lightllm.models import get_model_class - from lightllm.utils.config_utils import get_start_args_model_config + from lightllm.utils.config_utils import get_model_config, get_model_paths - model_cfg = get_start_args_model_config() + model_cfg = get_model_config(get_model_paths()) model_class = get_model_class(model_cfg=model_cfg) return model_class