diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 05aaaadca8..f87a78f2d5 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -23,6 +23,7 @@ from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph from lightllm.common.quantization import Quantcfg from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token +from lightllm.utils.config_utils import create_model_paths, get_model_config from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_dp_world_size from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num @@ -58,6 +59,7 @@ def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] self.weight_dir_ = kvargs["weight_dir"] + self.model_paths_ = create_model_paths(self.weight_dir_) self.max_total_token_num = kvargs["max_total_token_num"] self.batch_max_tokens = kvargs.get("batch_max_tokens", None) self.load_way = kvargs.get("load_way", "HF") @@ -104,6 +106,8 @@ def __init__(self, kvargs): self._verify_params() self._init_quant() + # Read GGUF weights mapping + self._init_gguf() self._init_weights() self._init_req_manager() self._init_mem_manager() @@ -144,8 +148,7 @@ def _wait_other_modules_ready(self): return def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - self.config = json.load(json_file) + self.config = get_model_config(self.model_paths_) # rename keys repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) @@ -168,6 +171,15 @@ def _init_quant(self): self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path) logger.info(f"Initial quantization. 
" f"The default quantization method is {self.quant_cfg.quant_type}") + def _init_gguf(self): + if self.model_paths_.gguf_path is None: + return + + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader + + reader = get_gguf_reader(self.model_paths_.gguf_path) + self.quant_cfg.gguf_quant_meta_map = reader.build_quant_meta_map(self.config) + def _init_weights(self, start_layer_index=0): self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config) self.trans_layers_weight = [ @@ -182,13 +194,24 @@ def _init_weights(self, start_layer_index=0): return def _load_hf_weights(self): - load_hf_weights( - self.data_type, - weight_dir=self.weight_dir_, - pre_post_layer=self.pre_post_weight, - transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict, - ) + if self.model_paths_.gguf_path is not None: + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader + + reader = get_gguf_reader(self.model_paths_.gguf_path) + reader.load_weights( + data_type=self.data_type, + config=self.config, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + ) + else: + load_hf_weights( + self.data_type, + weight_dir=self.weight_dir_, + pre_post_layer=self.pre_post_weight, + transformer_layer_list=self.trans_layers_weight, + weight_dict=self.weight_dict, + ) self.pre_post_weight.verify_load() [weight.verify_load() for weight in self.trans_layers_weight] return diff --git a/lightllm/common/basemodel/layer_weights/gguf_load_utils.py b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py new file mode 100644 index 0000000000..a19ae3f3b0 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/gguf_load_utils.py @@ -0,0 +1,502 @@ +import gguf +import numpy as np +import os +import torch +from functools import lru_cache +from gguf import GGMLQuantizationType, dequantize, quant_shape_to_byte_shape +from gguf.gguf_reader import GGUFReader, 
ReaderTensor +from transformers import AutoModelForCausalLM, AutoModelForImageTextToText +from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple + +from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.log_utils import init_logger +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES + +try: + from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES +except ImportError: + MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = {} + +logger = init_logger(__name__) + +DEQUANT_KEYS = frozenset(["token_embd.weight", "output.weight"]) + +# Native float tensors in GGUF are stored as plain arrays, not byte-packed blocks. +UNQUANTIZED_GGML_TYPES = {GGMLQuantizationType.F32, GGMLQuantizationType.F16} + + +class LightLLMGGUFReader: + + def __init__(self, gguf_path: str): + self.gguf_path = gguf_path + self._reader: Optional[GGUFReader] = None + self._config: Optional[dict] = None + self._gguf_to_hf: Optional[Dict[str, str]] = None + self._quant_meta_map: Optional[Dict[str, Any]] = None + + @property + def reader(self) -> GGUFReader: + if self._reader is None: + self._reader = GGUFReader(self.gguf_path) + return self._reader + + def read_field(self, field_name: str) -> Any: + """ Read a field from the GGUF reader. """ + field = self.reader.fields.get(field_name) + if field is None: + return None + return field.contents() + + def load_config(self) -> dict: + """ Load model config from the GGUF reader. """ + if self._config is not None: + return self._config + + try: + from transformers.modeling_gguf_pytorch_utils import load_gguf_checkpoint + except ImportError as e: + raise ImportError( + "Loading config from GGUF requires transformers with GGUF support and the gguf package." 
+ ) from e + config: dict = load_gguf_checkpoint(self.gguf_path, return_tensors=False)["config"] + config["architectures"] = self._resolve_hf_architectures(config) + + self._config = config + return config + + def get_gguf_to_hf_mapping(self, config: Optional[dict] = None) -> Dict[str, str]: + """ Build a mapping from gguf tensor name to hf tensor name. """ + if self._gguf_to_hf is not None: + return self._gguf_to_hf + + if config is None: + config = self.load_config() + + self._gguf_to_hf = build_gguf_to_hf_mapping(config) + return self._gguf_to_hf + + def build_quant_meta_map(self, config: Optional[dict] = None) -> Dict[str, Any]: + """ Build a quant metadata map for GGUF quantized weights. """ + if self._quant_meta_map is not None: + return self._quant_meta_map + + from lightllm.common.quantization.quantize_method import GGUFWeightMeta + + if config is None: + config = self.load_config() + gguf_to_hf = self.get_gguf_to_hf_mapping(config) + gguf_quant_meta_map = {} + for t in self.reader.tensors: + if t.name in DEQUANT_KEYS: + continue + hf_name = gguf_to_hf.get(t.name) + if hf_name is None: + continue + logical_shape = _gguf_logical_shape(t) + if len(logical_shape) != 2: + continue + np_data = np.asarray(t.data) + gguf_quant_meta_map[hf_name] = GGUFWeightMeta( + shape=logical_shape, + dtype=_numpy_dtype_to_torch(np_data.dtype), + quant_type=t.tensor_type, + ) + + self._quant_meta_map = gguf_quant_meta_map + return gguf_quant_meta_map + + def load_weights( + self, + data_type: str, + config: dict, + pre_post_layer: Any = None, + transformer_layer_list: Any = None, + ) -> None: + """ Load GGUF weights into model layers, then release the reader. 
""" + if isinstance(data_type, str): + data_type = torch.float16 if data_type == "fp16" else torch.float32 + + if pre_post_layer is not None: + assert pre_post_layer.data_type_ == data_type, "type is not right" + + if transformer_layer_list is not None: + assert transformer_layer_list[0].data_type_ == data_type, "type is not right" + + try: + gguf_to_hf = self.get_gguf_to_hf_mapping(config) + gguf_tensor_names = {t.name for t in self.reader.tensors} + validate_gguf_weight_mapping( + gguf_to_hf=gguf_to_hf, + gguf_tensor_names=gguf_tensor_names, + ) + torch.cuda.set_device(get_current_device_id()) + gguf_weights = _gguf_reader_to_weight_dict(self.reader) + hf_weights = rename_weights(gguf_weights, gguf_to_hf) + del gguf_weights + if pre_post_layer is not None: + pre_post_layer.load_hf_weights(hf_weights) + if transformer_layer_list is not None: + for layer in transformer_layer_list: + layer.load_hf_weights(hf_weights) + del hf_weights + finally: + self.close() + + def close(self) -> None: + self._reader = None + + def __enter__(self) -> "LightLLMGGUFReader": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def _resolve_hf_architectures(self, config: Dict[str, Any]) -> List[str]: + """ Resolve HuggingFace architectures from model_type. 
""" + model_type = config.get("model_type") + assert model_type is not None, "model_type is not found in config" + assert model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, ( + f"model_type {model_type!r} is not found in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES" + ) + + return [MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[model_type]] + + +@lru_cache(maxsize=10) +def get_gguf_reader(gguf_path: str) -> LightLLMGGUFReader: + gguf_path = os.path.abspath(gguf_path) + return LightLLMGGUFReader(gguf_path) + + +def _numpy_dtype_to_torch(np_dtype: np.dtype) -> torch.dtype: + return torch.from_numpy(np.zeros((), dtype=np_dtype)).dtype + + +def _gguf_logical_shape(rt: ReaderTensor) -> Tuple[int, ...]: + return tuple(reversed([int(x) for x in rt.shape.tolist()])) + + +def _resolve_model_type(model_type: str) -> str: + MODEL_TYPE_MAPPING = { + "qwen2_vl": "qwen2vl", + "qwen2_5_vl": "qwen2vl", + } + + return MODEL_TYPE_MAPPING.get(model_type, model_type) + + +def _normalize_hf_weight_name(hf_name: str) -> str: + multimodal_lm_prefix = "model.language_model." + if hf_name.startswith(multimodal_lm_prefix): + return "model." + hf_name[len(multimodal_lm_prefix) :] + + return hf_name + + +def _is_multimodal(config: Dict[str, Any]) -> bool: + if config.get("vision_config") is not None: + return True + model_type = config.get("model_type") + return model_type in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES + + +def _dummy_hf_model_from_config(config: Dict[str, Any], hf_config: Any): + with torch.device("meta"): + if _is_multimodal(config): + return AutoModelForImageTextToText.from_config(hf_config) + return AutoModelForCausalLM.from_config(hf_config) + + +def build_gguf_to_hf_mapping(config: Dict[str, Any]) -> Dict[str, str]: + """ + Build a mapping from gguf tensor name to hf tensor name, + e.g., 'token_embd.weight' -> 'model.embed_tokens.weight'. 
+ """ + num_layers = config.get("num_hidden_layers") + assert num_layers is not None, "num_hidden_layers is not found in config" + model_type = config.get("model_type") + assert model_type is not None, "model_type is not found in config" + arch = None + resolve_model_type = _resolve_model_type(model_type) + for gguf_arch, hf_arch in gguf.MODEL_ARCH_NAMES.items(): + if hf_arch == resolve_model_type: + arch = gguf_arch + break + assert arch is not None, f"model_type {model_type!r} is not found in gguf.MODEL_ARCH_NAMES" + tensor_name_map = gguf.get_tensor_name_map(arch, num_layers) + + config_cls = CONFIG_MAPPING[model_type] + assert config_cls is not None, f"config_cls is not found in CONFIG_MAPPING for model_type={model_type}" + hf_config = config_cls(**config) + dummy_model = _dummy_hf_model_from_config(config, hf_config) + + gguf_to_hf_name_mapping: Dict[str, str] = {} + invalid_hf_names: List[str] = [] + for hf_name in dummy_model.state_dict(): + lightllm_hf_name = _normalize_hf_weight_name(hf_name) + name, extension = lightllm_hf_name.rsplit(".", 1) + gguf_name = tensor_name_map.get_name(name) + if gguf_name is None: + invalid_hf_names.append(hf_name) + continue + gguf_key = f"{gguf_name}.{extension}" + if gguf_key in gguf_to_hf_name_mapping: + raise ValueError( + f"duplicate GGUF tensor key {gguf_key!r} while mapping HF weights; " + f"existing={gguf_to_hf_name_mapping[gguf_key]!r}, new={lightllm_hf_name!r}" + ) + gguf_to_hf_name_mapping[gguf_key] = lightllm_hf_name + + if invalid_hf_names: + logger.warning( + "skipped %d HF weight(s) with no GGUF tensor name mapping (first 5: %s)", + len(invalid_hf_names), + invalid_hf_names[:5], + ) + return gguf_to_hf_name_mapping + + +def validate_gguf_weight_mapping( + gguf_to_hf: Dict[str, str], + gguf_tensor_names: Set[str], +) -> None: + """ Validate GGUF↔HF mapping and log tensors that will not be loaded. 
""" + # Check if all gguf tensors are mapped to hf tensors + mapped_gguf_names = set(gguf_to_hf.keys()) + unmapped_gguf = sorted(gguf_tensor_names - mapped_gguf_names) + if unmapped_gguf: + logger.warning( + "GGUF file contains %d tensor(s) without HF mapping; they will be skipped (first 10: %s)", + len(unmapped_gguf), + unmapped_gguf[:10], + ) + + # Check if all hf tensors are mapped to gguf tensors + missing_in_gguf = sorted(mapped_gguf_names - gguf_tensor_names) + if missing_in_gguf: + logger.warning( + "HF mapping expects %d GGUF tensor(s) that are absent in the file (first 10: %s)", + len(missing_in_gguf), + missing_in_gguf[:10], + ) + + +def rename_weights( + weights: Dict[str, torch.Tensor], + gguf_to_hf: Dict[str, str], +) -> Dict[str, torch.Tensor]: + """ Rename GGUF weights to HF names. """ + renamed: Dict[str, torch.Tensor] = {} + for gguf_name, tensor in weights.items(): + hf_name = gguf_to_hf.get(gguf_name) + if hf_name is None: + continue + renamed[hf_name] = tensor + + return renamed + + +def _normalize_visual_module_key(module_key: str) -> str: + """ LightLLM visual modules omit the HF 'visual.' prefix used by GGUF MMPROJ. """ + if module_key.startswith("visual."): + return module_key + return f"visual.{module_key}" + + +def build_mmproj_to_hf_mapping( + module_keys: Iterable[str], + depth: int, +) -> Tuple[Dict[str, str], List[str]]: + """ Build gguf tensor name -> LightLLM visual module state dict name mapping. 
""" + tensor_name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, depth) + gguf_to_hf: Dict[str, str] = {} + skipped: List[str] = [] + for hf_name in module_keys: + prefix, extension = hf_name.rsplit(".", 1) + lookup_prefix = _normalize_visual_module_key(prefix) + gguf_prefix = tensor_name_map.get_name(lookup_prefix) + if gguf_prefix is None: + skipped.append(hf_name) + continue + gguf_key = f"{gguf_prefix}.{extension}" + if gguf_key in gguf_to_hf: + raise ValueError( + f"duplicate GGUF tensor key {gguf_key!r} while mapping mmproj weights; " + f"existing={gguf_to_hf[gguf_key]!r}, new={hf_name!r}" + ) + gguf_to_hf[gguf_key] = hf_name + + if skipped: + logger.warning( + "skipped %d visual module weight(s) with no GGUF MMPROJ mapping (first 5: %s)", + len(skipped), + skipped[:5], + ) + return gguf_to_hf, skipped + + +def _merge_mmproj_patch_embed(gguf_weights: Dict[str, torch.Tensor]) -> None: + """ Merge split Conv2d patch-embed slices back into a Conv3d weight. """ + patch_key = "v.patch_embd.weight" + w0 = gguf_weights.pop(patch_key, None) + w1 = gguf_weights.pop(f"{patch_key}.1", None) + if w0 is None: + return + if w1 is not None: + # GGUF stores temporal_patch_size=2 as two [out, in, H, W] slices. + gguf_weights[patch_key] = torch.stack([w0, w1], dim=2) + else: + gguf_weights[patch_key] = w0 + + +def _merge_mmproj_attn_qkv(gguf_weights: Dict[str, torch.Tensor], depth: int) -> None: + """ Merge split q/k/v tensors in mmproj into fused qkv weights. 
""" + for i in range(depth): + prefix = f"v.blk.{i}" + qw = gguf_weights.pop(f"{prefix}.attn_q.weight", None) + kw = gguf_weights.pop(f"{prefix}.attn_k.weight", None) + vw = gguf_weights.pop(f"{prefix}.attn_v.weight", None) + if qw is not None: + gguf_weights[f"{prefix}.attn_qkv.weight"] = torch.cat([qw, kw, vw], dim=0) + qb = gguf_weights.pop(f"{prefix}.attn_q.bias", None) + kb = gguf_weights.pop(f"{prefix}.attn_k.bias", None) + vb = gguf_weights.pop(f"{prefix}.attn_v.bias", None) + if qb is not None: + gguf_weights[f"{prefix}.attn_qkv.bias"] = torch.cat([qb, kb, vb], dim=0) + + +def _supplement_mmproj_qkv_mapping( + gguf_to_hf: Dict[str, str], + gguf_weights: Dict[str, torch.Tensor], + depth: int, +) -> None: + """ Add gguf->module mappings for qkv tensors created after q/k/v merge. """ + for i in range(depth): + prefix = f"v.blk.{i}" + for extension in ("weight", "bias"): + gguf_key = f"{prefix}.attn_qkv.{extension}" + if gguf_key in gguf_weights: + gguf_to_hf[gguf_key] = f"blocks.{i}.attn.qkv.{extension}" + + +def load_mmproj_gguf_weights( + mmproj_path: str, + module_keys: Iterable[str], + depth: int, + dtype: torch.dtype, +) -> Dict[str, torch.Tensor]: + """ Load and dequantize mmproj GGUF weights into LightLLM visual module key names. """ + gguf_to_hf, _ = build_mmproj_to_hf_mapping(module_keys, depth) + reader = get_gguf_reader(mmproj_path) + try: + gguf_weights = _gguf_reader_to_weight_dict(reader.reader, dequant_all=True) + _merge_mmproj_patch_embed(gguf_weights) + _merge_mmproj_attn_qkv(gguf_weights, depth) + _supplement_mmproj_qkv_mapping(gguf_to_hf, gguf_weights, depth) + weight_dict = rename_weights(gguf_weights, gguf_to_hf) + return {name: tensor.to(dtype) for name, tensor in weight_dict.items()} + finally: + reader.close() + + +def _reader_tensor_to_torch_cpu( + rt: ReaderTensor, + dequant: bool = False, + logical_shape: Tuple[int, ...] = None, +) -> torch.Tensor: + """ Read a GGUF tensor to a CPU torch tensor. 
""" + assert rt.shape.ndim <= 2, "GGUF tensor must be 2D or less" + if dequant: + d_tensor = dequantize(rt.data, rt.tensor_type) + if logical_shape is not None: + d_tensor = d_tensor.reshape(logical_shape) + + return torch.from_numpy(np.array(d_tensor, copy=True)) + + arr = np.array(rt.data, copy=True) + if logical_shape is None: + logical_shape = _gguf_logical_shape(rt) + + # Norm/bias vectors and native float weights are plain arrays, not block-quantized bytes. + if rt.tensor_type in UNQUANTIZED_GGML_TYPES or len(logical_shape) == 1: + if arr.size != np.prod(logical_shape): + raise ValueError( + f"GGUF tensor {rt.name} has size {arr.size} but expected {np.prod(logical_shape)}" + ) + return torch.from_numpy(arr.reshape(logical_shape)) + + byte_shape = quant_shape_to_byte_shape(logical_shape, rt.tensor_type) + if arr.size != np.prod(byte_shape): + raise ValueError(f"GGUF tensor {rt.name} has size {arr.size} but expected {np.prod(byte_shape)}") + + arr = arr.reshape(byte_shape) + + return torch.from_numpy(arr) + + +def _gguf_reader_to_weight_dict( + reader: GGUFReader, + *, + dequant_all: bool = False, +) -> Dict[str, torch.Tensor]: + """ Read GGUF reader to torch tensor dictionary. """ + gguf_weights = {} + for t in reader.tensors: + if dequant_all or t.name in DEQUANT_KEYS: + gguf_weights[t.name] = _reader_tensor_to_torch_cpu( + t, dequant=True, logical_shape=_gguf_logical_shape(t) + ) + else: + gguf_weights[t.name] = _reader_tensor_to_torch_cpu(t, dequant=False) + return gguf_weights + + +def dequant_gguf_weight( + byte_weight: torch.Tensor, + quant_type: gguf.GGMLQuantizationType, + logical_shape: Tuple[int, int], + device: torch.device, + dtype: torch.dtype +) -> torch.Tensor: + """ + Dequantize a GGUF weight to a CPU torch tensor during loading for GGUF. 
+ """ + expected_byte_shape = quant_shape_to_byte_shape(logical_shape, quant_type) + if tuple(byte_weight.shape) != expected_byte_shape: + raise ValueError( + f"byte shard shape {byte_weight.shape} != expected {expected_byte_shape}" + ) + fp32 = dequantize(byte_weight.contiguous().numpy(), quant_type) + arr = np.asarray(fp32, dtype=np.float32).reshape(logical_shape) + + return torch.from_numpy(arr).to(device=device, dtype=dtype) + + +def load_weight_shard( + *, + raw_weight: torch.Tensor, + param_name: str, + weight_pack: Any, + slicer: Any, + quant_method: Any, +) -> None: + """ + Load a GGUF linear weight shard with TP slicing. + For predquant weights, dequantize the full tensor first, then slice. + """ + if weight_pack.gguf_load_predquant: + meta = quant_method.gguf_quant_meta_map[param_name] + full_fp = dequant_gguf_weight( + byte_weight=raw_weight, + quant_type=meta.quant_type, + logical_shape=meta.shape, + device=weight_pack.weight.device, + dtype=weight_pack.weight.dtype, + ) + shard = slicer._slice_weight(full_fp) + weight_pack.weight.copy_(shard) + weight_pack.load_ok[0] = True + else: + shard = slicer._slice_weight(raw_weight) + quant_method.load_weight(shard, weight_pack) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..75b210c744 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -1,6 +1,7 @@ import torch import threading from typing import Dict, Any, Optional, Tuple, List +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_weight_shard from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization.quantize_method import WeightPack from 
lightllm.common.basemodel.layer_weights.meta_weights.mm_weight.mm_slicer import ( @@ -332,14 +333,21 @@ def _load_expert( w1_weight = f"{self.weight_prefix}.{expert_idx}.{self.w1_weight_name}.{self.quant_method.weight_suffix}" w2_weight = f"{self.weight_prefix}.{expert_idx}.{self.w2_weight_name}.{self.quant_method.weight_suffix}" w3_weight = f"{self.weight_prefix}.{expert_idx}.{self.w3_weight_name}.{self.quant_method.weight_suffix}" - row_slice_func = self.row_slicer._slice_weight - col_slice_func = self.col_slicer._slice_weight - if w1_weight in weights: - self.quant_method.load_weight(row_slice_func(weights[w1_weight]), self.w1_list[local_expert_idx]) - if w3_weight in weights: - self.quant_method.load_weight(row_slice_func(weights[w3_weight]), self.w3_list[local_expert_idx]) - if w2_weight in weights: - self.quant_method.load_weight(col_slice_func(weights[w2_weight]), self.w2_list[local_expert_idx]) + expert_weights = ( + (w1_weight, self.w1_list[local_expert_idx], self.row_slicer), + (w3_weight, self.w3_list[local_expert_idx], self.row_slicer), + (w2_weight, self.w2_list[local_expert_idx], self.col_slicer), + ) + for param_name, weight_pack, slicer in expert_weights: + if param_name not in weights: + continue + load_weight_shard( + raw_weight=weights[param_name], + param_name=param_name, + weight_pack=weight_pack, + slicer=slicer, + quant_method=self.quant_method, + ) def _load_expert_scale( self, diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py index ddbf98a866..5b7173882f 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_slicer.py @@ -148,7 +148,7 @@ def get_row_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) - elif quant_method_name == 
"none": + elif quant_method_name == "none" or quant_method_name.startswith("gguf"): return RowSliceMixin(tp_rank, tp_world_size, repeat_times) else: return QuantizedRowSliceMixin(tp_rank, tp_world_size, repeat_times) @@ -159,7 +159,7 @@ def get_col_slice_mixin( ) -> SliceMixinTpl: if quant_method_name.startswith("awq"): return AwqQuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) - elif quant_method_name == "none": + elif quant_method_name == "none" or quant_method_name.startswith("gguf"): return ColSliceMixin(tp_rank, tp_world_size, repeat_times) else: return QuantizedColSliceMixin(tp_rank, tp_world_size, repeat_times) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..dd7c36cc58 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -6,6 +6,7 @@ from typing import Optional, Tuple, List, Dict, Union, Type from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.common.quantization.quantize_method import QuantizationMethod, WeightPack +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_weight_shard from lightllm.common.basemodel.layer_weights.meta_weights.base_weight import BaseWeightTpl from lightllm.common.quantization import Quantcfg from lightllm.common.quantization.no_quant import NoQuantization @@ -100,7 +101,11 @@ def _create_weight(self): self.mm_param: WeightPack = None self.mm_param_list: List[WeightPack] = None self.mm_param, self.mm_param_list = self.quant_method.create_weight( - in_dim=self.in_dim, out_dims=self.out_dims, dtype=self.data_type_, device_id=get_current_device_id() + in_dim=self.in_dim, + out_dims=self.out_dims, + dtype=self.data_type_, + device_id=get_current_device_id(), + weight_names=self.weight_names, ) return 
@@ -123,9 +128,13 @@ def _load_weight( if quanted_param_name in weights: param_name = quanted_param_name if param_name in weights: - slicer = self._get_param_slicer(sub_child_index) - weight = slicer._slice_weight(weights[param_name]) - self.quant_method.load_weight(weight, self.mm_param_list[sub_child_index]) + load_weight_shard( + raw_weight=weights[param_name], + param_name=param_name, + weight_pack=self.mm_param_list[sub_child_index], + slicer=self._get_param_slicer(sub_child_index), + quant_method=self.quant_method, + ) return def _load_bias( diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py index fb50398368..9fda4967df 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/rowmm_weight.py @@ -22,6 +22,8 @@ def __init__( ) -> None: self.tp_rank_ = tp_rank if tp_rank is not None else get_current_rank_in_dp() self.tp_world_size_ = tp_world_size if tp_world_size is not None else get_dp_world_size() + if isinstance(out_dims, int): + out_dims = [out_dims] out_dims = [self._get_tp_dim(out_dim) for out_dim in out_dims] super().__init__( in_dim=in_dim, diff --git a/lightllm/common/gguf_kernel/__init__.py b/lightllm/common/gguf_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/gguf_kernel/dequantization.py b/lightllm/common/gguf_kernel/dequantization.py new file mode 100644 index 0000000000..40387f2006 --- /dev/null +++ b/lightllm/common/gguf_kernel/dequantization.py @@ -0,0 +1,276 @@ +""" +ref: https://github.com/ggml-org/llama.cpp/discussions/17393 +""" +import torch +import triton +import triton.language as tl +from gguf import GGMLQuantizationType +from typing import Callable, Optional + +_GGUF_DEQUANT_REGISTRY: dict[GGMLQuantizationType, Callable[..., torch.Tensor]] = {} + + +def 
register_gguf_dequant(quant_type: GGMLQuantizationType): + + def _wrap(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + if quant_type in _GGUF_DEQUANT_REGISTRY: + raise ValueError( + f"Duplicate GGUF dequant registration for {quant_type}: " + f"{_GGUF_DEQUANT_REGISTRY[quant_type]!r} vs {fn!r}") + _GGUF_DEQUANT_REGISTRY[quant_type] = fn + return fn + + return _wrap + + +def get_gguf_dequant_fn( + quant_type: GGMLQuantizationType, +) -> Optional[Callable[..., torch.Tensor]]: + return _GGUF_DEQUANT_REGISTRY.get(quant_type) + + +QK5_0 = 32 +BLOCK_Q5_0_BYTES = 22 +""" +# Each block is represented by 22 consecutive values in the final ndarray, will be dequantized into 32 floats. +typedef struct { + ggml_half d; // scale, total 16 bits + uint8_t qh[4]; // each 1 bit high bits of quants, total 8*4=32 bits + uint8_t qs[QK5_0 / 2]; // each 4 bits low bits of quants, total 4*2*(32/2)=128 bits +} block_q5_0; // total 176 bits = 22 bytes +""" + + +@triton.jit +def dequantize_q5_0_kernel( + w_ptr, + out_ptr, + k, + QK: tl.constexpr, + BLOCK_BYTES: tl.constexpr, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + block_idx = offs // QK + pos_in_block = offs % QK + + half_k = QK // 2 + first_half = pos_in_block < half_k + # [0, 15] + qs_low_bits_idx = tl.where(first_half, pos_in_block, pos_in_block - half_k) + + block_base = block_idx * BLOCK_BYTES + w_f16 = tl.cast(w_ptr, tl.pointer_type(tl.float16)) + d = tl.load(w_f16 + block_base // 2, mask=mask, other=0).to(tl.float32) + + qh = (tl.load(w_ptr + block_base + 2, mask=mask, other=0).to(tl.uint32) + | (tl.load(w_ptr + block_base + 3, mask=mask, other=0).to(tl.uint32) + << 8) + | (tl.load(w_ptr + block_base + 4, mask=mask, other=0).to(tl.uint32) + << 16) + | (tl.load(w_ptr + block_base + 5, mask=mask, other=0).to(tl.uint32) + << 24)) + + qsb = tl.load(w_ptr + block_base + 6 + qs_low_bits_idx, mask=mask, + 
other=0).to(tl.int32) + nib_lo = qsb & 0xF + nib_hi = (qsb >> 4) & 0xF + + qs_low_bits_idx_u32 = qs_low_bits_idx.to(tl.uint32) + xh_0 = ((qh >> (qs_low_bits_idx_u32 + 0)) << 4) & 0x10 + xh_1 = (qh >> (qs_low_bits_idx_u32 + 12)) & 0x10 + + val = tl.where(first_half, nib_lo | xh_0, nib_hi | xh_1) + y = (val.to(tl.float32) - 16.0) * d + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.Q5_0) +def dequantize_q5_0( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + if k % QK5_0 != 0: + raise ValueError(f"element count {k} must be a multiple of {QK5_0}") + + nb = k // QK5_0 + expected_bytes = nb * BLOCK_Q5_0_BYTES + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_q5_0_kernel[grid]( + weight_uint8, + out.view(-1), + k, + QK=QK5_0, + BLOCK_BYTES=BLOCK_Q5_0_BYTES, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + ) + return out + + +QK8_0 = 32 +BLOCK_Q8_0_BYTES = 34 +""" +# Each block is represented by 34 consecutive values in the final ndarray, will be dequantized into 32 floats. 
+typedef struct { + ggml_half d; // scale, total 16 bits + int8_t qs[QK8_0]; // each 8 bits of quants, total 8*32=256 bits +} block_q8_0; // total 272 bits = 34 bytes +""" + + +@triton.jit +def dequantize_q8_0_kernel( + w_ptr, + out_ptr, + k, + QK: tl.constexpr, + BLOCK_BYTES: tl.constexpr, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + block_idx = offs // QK + pos_in_block = offs % QK + + block_base = block_idx * BLOCK_BYTES + w_f16 = tl.cast(w_ptr, tl.pointer_type(tl.float16)) + d = tl.load(w_f16 + block_base // 2, mask=mask, other=0).to(tl.float32) + + qb = tl.load(w_ptr + block_base + 2 + pos_in_block, mask=mask, + other=0).to(tl.int32) + qs = tl.where(qb > 127, qb - 256, qb) + y = qs.to(tl.float32) * d + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.Q8_0) +def dequantize_q8_0( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + if k % QK8_0 != 0: + raise ValueError(f"element count {k} must be a multiple of {QK8_0}") + + nb = k // QK8_0 + expected_bytes = nb * BLOCK_Q8_0_BYTES + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_q8_0_kernel[grid]( + weight_uint8, + out.view(-1), + k, + QK=QK8_0, + BLOCK_BYTES=BLOCK_Q8_0_BYTES, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + 
) + return out + + +""" +Two uint8 values form a uint16 value, which is then converted to bfloat16. +""" + + +@triton.jit +def dequantize_bf16_kernel( + w_ptr, + out_ptr, + k, + OUT_IS_BF16: tl.constexpr, + BLOCK: tl.constexpr, +): + pid = tl.program_id(0) + offs = pid * BLOCK + tl.arange(0, BLOCK) + mask = offs < k + + w_bf16 = tl.cast(w_ptr, tl.pointer_type(tl.bfloat16)) + y = tl.load(w_bf16 + offs, mask=mask, other=0).to(tl.float32) + + if OUT_IS_BF16: + tl.store(out_ptr + offs, y.to(tl.bfloat16), mask=mask) + else: + tl.store(out_ptr + offs, y.to(tl.float16), mask=mask) + + +@register_gguf_dequant(GGMLQuantizationType.BF16) +def dequantize_bf16( + weight_uint8: torch.Tensor, + m: int, + n: int, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if dtype not in (torch.float16, torch.bfloat16): + raise ValueError("The output dtype must be float16 or bfloat16") + if weight_uint8.dtype != torch.uint8: + raise ValueError("The weight must be uint8 packed weights") + + k = m * n + expected_bytes = k * 2 + if weight_uint8.numel() < expected_bytes: + raise ValueError( + f"W has {weight_uint8.numel()} bytes, need at least {expected_bytes}" + ) + + if out is None: + out = torch.empty((m, n), device=weight_uint8.device, dtype=dtype) + BLOCK = 1024 + grid = (triton.cdiv(k, BLOCK), ) + dequantize_bf16_kernel[grid]( + weight_uint8, + out.view(-1), + k, + OUT_IS_BF16=dtype == torch.bfloat16, + BLOCK=BLOCK, + ) + return out diff --git a/lightllm/common/quantization/__init__.py b/lightllm/common/quantization/__init__.py index 1f08432c6a..34801955d3 100644 --- a/lightllm/common/quantization/__init__.py +++ b/lightllm/common/quantization/__init__.py @@ -6,6 +6,7 @@ from .deepgemm import * from .awq import * from .no_quant import * +from .gguf import * from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -16,6 +17,7 @@ def __init__(self, network_config, quant_type="none", custom_cfg_path=None): self.layer_num = 
network_config["n_layer"] self.quant_type = quant_type self.network_config_ = network_config + self.gguf_quant_meta_map = None self._parse_custom_cfg(custom_cfg_path) self._parse_network_config(network_config) @@ -80,4 +82,5 @@ def get_quant_method(self, layer_num, name): quant_type = self.get_quant_type(layer_num, name) quant_method = QUANTMETHODS.get(quant_type) quant_method.hf_quantization_config = self.hf_quantization_config + quant_method.gguf_quant_meta_map = self.gguf_quant_meta_map return quant_method diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index f3c7623975..dd917ddc88 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -109,7 +109,13 @@ def apply( return out def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) group_size = self.hf_quantization_config["group_size"] @@ -206,7 +212,13 @@ def apply( return out def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) self.n = out_dim diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..779559db76 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -98,7 +98,13 @@ def apply( return out def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], 
+ in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims weight_scale_out_dims = [(_out_dim + self.block_size - 1) // self.block_size for _out_dim in out_dims] diff --git a/lightllm/common/quantization/gguf.py b/lightllm/common/quantization/gguf.py new file mode 100644 index 0000000000..fd9e22888b --- /dev/null +++ b/lightllm/common/quantization/gguf.py @@ -0,0 +1,192 @@ +import torch +from gguf import GGMLQuantizationType, quant_shape_from_byte_shape +from gguf.quants import quant_shape_to_byte_shape +from typing import List, Optional, Tuple + +from .registry import QUANTMETHODS +from .quantize_method import GGUFWeightMeta, QuantizationMethod, WeightPack +from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager +from lightllm.common.basemodel.layer_weights.gguf_load_utils import UNQUANTIZED_GGML_TYPES +from lightllm.common.gguf_kernel.dequantization import get_gguf_dequant_fn +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +# To avoid logging the warning multiple times +_predquant_warned = False + + +def _warn_predquant_once(unsupported_quant_type: str) -> None: + global _predquant_warned + if _predquant_warned: + return + _predquant_warned = True + logger.warning( + "The current GGUF model contains quantization formats that do not support runtime " + f"dequantization (e.g., {unsupported_quant_type}). These weights will be dequantized during model loading, which " + "may increase GPU memory usage. 
To add support, register a dequantization implementation via " + "register_gguf_dequant in lightllm/common/gguf_kernel/dequantization.py.", + ) + + +def _linear( + input_tensor: torch.Tensor, + weight: torch.Tensor, + out: torch.Tensor, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if bias is None: + return torch.mm(input_tensor, weight, out=out) + return torch.addmm(bias, input_tensor, weight, out=out) + + +@QUANTMETHODS.register("gguf", platform="cuda") +class GGUFQuantizationMethod(QuantizationMethod): + + def __init__(self): + super().__init__() + + def quantize(self, weight: torch.Tensor, output: WeightPack) -> None: + raise NotImplementedError("GGUF online quantization is not supported") + + def apply( + self, + input_tensor: torch.Tensor, + weight_pack: WeightPack, + out: Optional[torch.Tensor] = None, + workspace: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Allocate output tensor based on the shape of the weight pack + assert weight_pack.gguf_quant_type is not None, "gguf_quant_type must be set on WeightPack" + if out is None: + if ( + weight_pack.gguf_quant_type in UNQUANTIZED_GGML_TYPES + or weight_pack.gguf_load_predquant + ): + out_features = weight_pack.weight.shape[-2] + else: + out_features, _ = quant_shape_from_byte_shape( + weight_pack.weight.shape, + weight_pack.gguf_quant_type, + ) + shape = (input_tensor.shape[0], out_features) + if use_custom_tensor_mananger: + out = g_cache_manager.alloc_tensor(shape, input_tensor.dtype, device=input_tensor.device) + else: + out = torch.empty(shape, dtype=input_tensor.dtype, device=input_tensor.device) + # Unquantized types and load-time dequantized weights are directly used + if ( + weight_pack.gguf_quant_type in UNQUANTIZED_GGML_TYPES + or weight_pack.gguf_load_predquant + ): + weight = weight_pack.weight.t() + return _linear(input_tensor, weight, out, bias) + # For quantized types, we need to 
dequantize the weight and then use it + dequant_fn = get_gguf_dequant_fn(weight_pack.gguf_quant_type) + if dequant_fn is None: + raise ValueError( + f"Unsupported GGUF quantization type: {weight_pack.gguf_quant_type}" + ) + m, n = quant_shape_from_byte_shape(weight_pack.weight.shape, + weight_pack.gguf_quant_type) + alloc_func = torch.empty if not use_custom_tensor_mananger else g_cache_manager.empty + dequantized_weight = alloc_func((m, n), + dtype=input_tensor.dtype, + device=input_tensor.device) + dequant_fn( + weight_pack.weight, + m, + n, + input_tensor.dtype, + out=dequantized_weight, + ) + weight = dequantized_weight.t() + + return _linear(input_tensor, weight, out, bias) + + @property + def method_name(self): + return "gguf" + + def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: + assert not weight_pack.gguf_load_predquant, ( + "predquant GGUF weights must be loaded via load_weight_shard" + ) + if weight.shape != weight_pack.weight.shape: + raise ValueError( + f"GGUF weight shape {weight.shape} does not match weight pack shape {weight_pack.weight.shape}" + ) + device = weight_pack.weight.device + weight_pack.weight.copy_( + weight.contiguous().to(device=device, dtype=weight_pack.weight.dtype) + ) + weight_pack.load_ok[0] = True + + def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: + return False + + def _create_weight( + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, + ) -> Tuple[WeightPack, List[WeightPack]]: + assert weight_names is not None and len(weight_names) > 0, "weight_names must be provided" + assert len(weight_names) == len(out_dims), "weight_names and out_dims must align" + + weight_dtype = None + shard_quant_types: List[GGMLQuantizationType] = [] + for weight_name in weight_names: + meta: GGUFWeightMeta = self.gguf_quant_meta_map[weight_name] + shard_quant_types.append(meta.quant_type) + 
quant_shape = meta.shape + assert len( + quant_shape + ) == 2, f"GGUF linear weight must be 2D, got {quant_shape} for {weight_name}" + if weight_dtype is None: + weight_dtype = meta.dtype + else: + assert weight_dtype == meta.dtype, f"merged GGUF weights must share dtype, got {weight_dtype} vs {meta.dtype}" + + # If there are mixed quant types, we need to dequant each shard at load time + mixed_quant_types = len(set(shard_quant_types)) > 1 + gguf_quant_type = shard_quant_types[0] + gguf_load_predquant = mixed_quant_types + if not mixed_quant_types and gguf_quant_type not in UNQUANTIZED_GGML_TYPES: + if get_gguf_dequant_fn(gguf_quant_type) is None: + gguf_load_predquant = True + _warn_predquant_once(gguf_quant_type.name) + + logical_shape_rowmajor = (sum(out_dims), in_dim) + expert_prefix = (num_experts, ) if num_experts > 1 else () + if gguf_quant_type in UNQUANTIZED_GGML_TYPES or gguf_load_predquant: + full_shape = expert_prefix + logical_shape_rowmajor + storage_dtype = dtype if gguf_load_predquant else weight_dtype + else: + full_shape = expert_prefix + quant_shape_to_byte_shape( + logical_shape_rowmajor, gguf_quant_type) + storage_dtype = weight_dtype + + weight = torch.empty(full_shape, dtype=storage_dtype).cuda(device_id) + + mm_param = WeightPack( + weight=weight, + weight_scale=None, + weight_zero_point=None, + gguf_quant_type=gguf_quant_type, + gguf_load_predquant=gguf_load_predquant, + ) + mm_param_list = self._split_weight_pack( + mm_param, + weight_out_dims=out_dims, + weight_split_dim=-2, + ) + for pack, shard_quant_type in zip(mm_param_list, shard_quant_types): + pack.gguf_quant_type = shard_quant_type + + return mm_param, mm_param_list diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index fa926ad6f0..2ad5a29d74 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -35,7 +35,13 @@ def apply( return torch.addmm(bias, input_tensor, weight, out=out) def 
_create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 95d8d806f9..e2b1be0250 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -1,8 +1,17 @@ +import numpy as np import torch from abc import ABC, abstractmethod from dataclasses import dataclass +from gguf import GGMLQuantizationType from lightllm.utils.dist_utils import get_current_device_id -from typing import Optional, List, Tuple +from typing import Dict, Optional, List, Tuple + + +@dataclass(frozen=True) +class GGUFWeightMeta: + shape: Tuple[int, ...] 
+ dtype: torch.dtype + quant_type: GGMLQuantizationType @dataclass @@ -10,6 +19,9 @@ class WeightPack: weight: Optional[torch.Tensor] = None weight_scale: Optional[torch.Tensor] = None weight_zero_point: Optional[torch.Tensor] = None + gguf_quant_type: Optional[GGMLQuantizationType] = None + # Dequantize unsupported quantization formats during loading for GGUF + gguf_load_predquant: bool = False def __post_init__(self): self.load_ok = [False, self.weight_scale is None, self.weight_zero_point is None] @@ -19,7 +31,13 @@ def get_expert(self, expert_idx: int): weight = self.weight[expert_idx] weight_scale = self.weight_scale[expert_idx] if self.weight_scale is not None else None weight_zero_point = self.weight_zero_point[expert_idx] if self.weight_zero_point is not None else None - return WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + return WeightPack( + weight=weight, + weight_scale=weight_scale, + weight_zero_point=weight_zero_point, + gguf_quant_type=self.gguf_quant_type, + gguf_load_predquant=self.gguf_load_predquant, + ) class QuantizationMethod(ABC): @@ -37,6 +55,8 @@ def __init__(self): # 一些量化模式需要用到的额外量化参数,如awq量化 self.hf_quantization_config = None + # GGUF hf_name -> gguf storage meta info + self.gguf_quant_meta_map: Optional[Dict[str, GGUFWeightMeta]] = None @abstractmethod def quantize( @@ -64,17 +84,30 @@ def method_name(self): pass def create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: return self._create_weight( out_dims=out_dims, in_dim=in_dim, dtype=dtype, device_id=device_id, + num_experts=1, + weight_names=weight_names, ) def create_moe_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int + self, + out_dims: List[int], + in_dim: int, + dtype: 
torch.dtype, + device_id: int, + num_experts: int, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: return self._create_weight( out_dims=out_dims, @@ -82,6 +115,7 @@ def create_moe_weight( dtype=dtype, device_id=device_id, num_experts=num_experts, + weight_names=weight_names, ) def load_weight(self, weight: torch.Tensor, weight_pack: WeightPack) -> None: @@ -112,7 +146,13 @@ def _check_weight_need_quanted(self, weight: torch.Tensor) -> bool: return weight.dtype in [torch.bfloat16, torch.float16, torch.float32, torch.float64] def _create_weight( - self, out_dims: List[int], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: List[int], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[str] = None, ) -> Tuple[WeightPack, List[WeightPack]]: pass @@ -144,6 +184,12 @@ def _split_weight_pack( ) for weight, weight_scale, weight_zero_point in zip(weight, weight_scale, weight_zero_point): mm_param_list.append( - WeightPack(weight=weight, weight_scale=weight_scale, weight_zero_point=weight_zero_point) + WeightPack( + weight=weight, + weight_scale=weight_scale, + weight_zero_point=weight_zero_point, + gguf_quant_type=weight_pack.gguf_quant_type, + gguf_load_predquant=weight_pack.gguf_load_predquant, + ) ) return mm_param_list diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 65ec6cd145..0cf5c3bb22 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -49,7 +49,13 @@ def method_name(self): return "w8a8-base" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: raise 
NotImplementedError("Not implemented") @@ -99,7 +105,13 @@ def method_name(self): return "vllm-w8a8" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () @@ -163,7 +175,13 @@ def method_name(self): return "vllm-fp8w8a8" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () @@ -241,7 +259,13 @@ def method_name(self): return "vllm-fp8w8a8-b128" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/quantization/w8a8gx.py b/lightllm/common/quantization/w8a8gx.py index c25136697d..704a82c197 100644 --- a/lightllm/common/quantization/w8a8gx.py +++ b/lightllm/common/quantization/w8a8gx.py @@ -32,7 +32,13 @@ def method_name(self): return "w8a8gx-base" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: 
torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: raise NotImplementedError("Not implemented") @@ -119,7 +125,13 @@ def method_name(self): return f"triton-fp8w8a8g{self.act_quant_group_size}" def _create_weight( - self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1 + self, + out_dims: Union[int, List[int]], + in_dim: int, + dtype: torch.dtype, + device_id: int, + num_experts: int = 1, + weight_names: Optional[List[str]] = None, ) -> Tuple[WeightPack, List[WeightPack]]: out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims expert_prefix = (num_experts,) if num_experts > 1 else () diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 01e9c4ad35..fc94222b5b 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -8,7 +8,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args -from lightllm.utils.config_utils import get_vocab_size +from lightllm.utils.config_utils import get_model_paths, get_vocab_size from lightllm.server.router.model_infer.pin_mem_manager import g_pin_mem_manager from lightllm.common.linear_att_cache_manager.layer_cache import LayerCache from lightllm.common.linear_att_cache_manager.linear_att_buffer_manager import LinearAttCacheManager @@ -112,7 +112,7 @@ class ReqSamplingParamsManager: def __init__(self, max_request_num): # mode ["cpu_counter", "pin_mem_counter", "gpu_counter"] self.penalty_counter_mode = get_env_start_args().penalty_counter_mode - self.vocab_size = 
get_vocab_size(get_env_start_args().model_dir) + self.vocab_size = get_vocab_size(get_model_paths()) self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda") diff --git a/lightllm/models/gemma3/model.py b/lightllm/models/gemma3/model.py index 9931c31713..e49b3ff813 100644 --- a/lightllm/models/gemma3/model.py +++ b/lightllm/models/gemma3/model.py @@ -11,6 +11,7 @@ from lightllm.models.gemma3.layer_infer.transformer_layer_infer import Gemma3TransformerLayerInfer from lightllm.models.gemma3.layer_weights.pre_and_post_layer_weight import Gemma3PreAndPostLayerWeight from lightllm.models.gemma3.layer_weights.transformer_layer_weight import Gemma3TransformerLayerWeight +from lightllm.common.basemodel import TpPartBaseModel from lightllm.models.llama.model import LlamaTpPartModel from lightllm.server.multimodal_params import AudioItem, MultimodalParams, ImageItem from lightllm.server.core.objs import SamplingParams @@ -151,6 +152,7 @@ def _init_mem_manager(self): return def _init_config(self): + TpPartBaseModel._init_config(self) with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: self.config = json.load(json_file) # rename keys diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..7d51c2d001 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -4,11 +4,13 @@ import torch.nn.functional as F from PIL import Image from typing import List, Optional +from lightllm.common.basemodel.layer_weights.gguf_load_utils import load_mmproj_gguf_weights +from lightllm.utils.log_utils import init_logger from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO 
import torch.nn as nn from transformers.activations import ACT2FN -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor, resize_image from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.models.qwen2_vl.qwen2_visual import PatchEmbed, VisionRotaryEmbedding @@ -17,6 +19,8 @@ from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton +logger = init_logger(__name__) + class Qwen2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -154,8 +158,7 @@ def __init__( fullatt_block_indexes=[7, 15, 23, 31], **kwargs, ): - super().__init__() - self.weight_dir = kvargs["weight_dir"] + super().__init__() self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -204,10 +207,8 @@ def __init__( self.gradient_checkpointing = False - processor_config_path = os.path.join(self.weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + processor_dir = kvargs.get("processor_dir") + self.processor = load_image_processor(processor_dir, Qwen2VLImageProcessor) self._init_datatype() @@ -350,6 +351,22 @@ def load_image(self, img: List[ImageItem]): return pixel_values.to(dtype=self.data_type), image_grid_thw def load_model(self, weight_dir): + if weight_dir.endswith(".gguf"): + load_dtype = self.data_type + if not isinstance(load_dtype, torch.dtype): + load_dtype = torch.bfloat16 if load_dtype in ("bf16", "bfloat16") else torch.float16 + weight_dict = load_mmproj_gguf_weights( + mmproj_path=weight_dir, + module_keys=self.state_dict().keys(), + depth=self.depth, + dtype=load_dtype, + ) + missing, unexpected = 
self.load_state_dict(weight_dict, strict=False) + if missing: + logger.warning("mmproj missing keys (first 10): %s", missing[:10]) + if unexpected: + logger.warning("mmproj unexpected keys (first 10): %s", unexpected[:10]) + return bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen2_vl/model.py b/lightllm/models/qwen2_vl/model.py index 237c4ad897..13d2bd884b 100644 --- a/lightllm/models/qwen2_vl/model.py +++ b/lightllm/models/qwen2_vl/model.py @@ -8,6 +8,7 @@ from lightllm.models.registry import ModelRegistry from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer +from lightllm.utils.config_utils import get_model_config from .vision_process import smart_resize from lightllm.models.qwen2.model import Qwen2TpPartModel @@ -100,8 +101,7 @@ def __init__(self, kvargs): return def _init_config(self): - with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: - self.config = json.load(json_file) + self.config = get_model_config(self.model_paths_) # rename keys repair_config(self.config, same_names=["num_attention_heads", "n_head"]) repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 6076756043..17cb9cfa72 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -33,7 +33,7 @@ from safetensors import safe_open from lightllm.server.multimodal_params import ImageItem from lightllm.server.visualserver import get_vit_attn_backend -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor from lightllm.common.basemodel.layer_infer.cache_tensor_manager 
import g_cache_manager from lightllm.models.qwen2_vl.triton_kernel.rotary_pos_emb import apply_rotary_pos_emb_triton @@ -192,6 +192,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir") self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -239,11 +240,7 @@ def _init_datatype(self): return def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index bc313fe467..1f9cd5842e 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -1,5 +1,7 @@ from __future__ import annotations +import json import math +import os import torch import numpy as np from PIL import Image @@ -284,3 +286,14 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc image_grid_thw = torch.as_tensor(processed_grids) return pixel_values, image_grid_thw + + +def load_image_processor(processor_dir: str, processor_cls): + assert os.path.isdir(processor_dir), f"Processor directory not found at {processor_dir}" + config_path = os.path.join(processor_dir, "preprocessor_config.json") + + assert os.path.exists(config_path), f"Processor config file not found at {config_path} for {processor_dir}" + with open(config_path, "r") as f: + processor_config = json.load(f) + + return processor_cls(**processor_config) \ No newline at end of file diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py 
b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index 0276724749..bf582d49a3 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -27,7 +27,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention @@ -139,6 +139,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir", kvargs.get("weight_dir")) self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -220,11 +221,7 @@ def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature return all_img_embeds_ds, valid_ids def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index bed8898115..019137e36a 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -27,7 +27,7 @@ from lightllm.server.multimodal_params import ImageItem from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data -from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor +from lightllm.models.qwen2_vl.vision_process import 
Qwen2VLImageProcessor, load_image_processor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention from lightllm.utils.log_utils import init_logger @@ -135,6 +135,7 @@ def __init__( **kwargs, ): super().__init__() + self.processor_dir = kvargs.get("processor_dir", kvargs.get("weight_dir")) self.data_type = kvargs.get("data_type", "bfloat16") self.depth = depth @@ -215,11 +216,7 @@ def concat_img_embed_and_deepstack_features(self, image_embed, deepstack_feature return all_img_embeds_ds, valid_ids def load_model(self, weight_dir): - - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf08575..2fd221529f 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -16,7 +16,7 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.server.multimodal_params import ImageItem -from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, resize_image +from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor def add_split_tokens(image_features, image_newline_embed, image_new_embed): @@ -217,10 +217,7 @@ def forward( return image_features def load_model(self, weight_dir): - processor_config_path = os.path.join(weight_dir, "preprocessor_config.json") - with open(processor_config_path, "r") as f: - processor_config_dict = json.load(f) - 
self.processor = Qwen2VLImageProcessor(**processor_config_dict) + self.processor = load_image_processor(self.processor_dir, Qwen2VLImageProcessor) bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] if bin_weight_files: diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index f33f58b86d..b8ffd5ba59 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -121,6 +121,25 @@ def make_argument_parser() -> argparse.ArgumentParser: default=None, help="the model weight dir path, the app will load config, weights and tokenizer from this dir", ) + parser.add_argument( + "--config_path", + type=str, + default=None, + help="the config path (JSON). If not set, use model_dir/config.json when present; " + "otherwise derive config from GGUF metadata when model_dir is a .gguf file or contains one", + ) + parser.add_argument( + "--tokenizer_dir", + type=str, + default=None, + help="tokenizer directory; for gguf, load tokenizer from gguf is default, override it by providing this param", + ) + parser.add_argument( + "--mmproj_path", + type=str, + default=None, + help="The path of mmproj for multimodal mode, only supported for GGUF models", + ) parser.add_argument( "--tokenizer_mode", type=str, @@ -589,9 +608,10 @@ def make_argument_parser() -> argparse.ArgumentParser: "--quant_type", type=str, default="none", - help="""Quantization method: vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 - | deepgemm-fp8w8a8-b128 | triton-fp8w8a8-block128 | awq | awq_marlin | - | triton-fp8w8a8g128 (weight perchannel quant and act per group quant) | + help="""Quantization method: none | gguf (requires .gguf weights under model_dir) | + vllm-w8a8 | vllm-fp8w8a8 | vllm-fp8w8a8-b128 | deepgemm-fp8w8a8-b128 | + triton-fp8w8a8-block128 | awq | awq_marlin | + triton-fp8w8a8g128 (weight perchannel quant and act per group quant) | triton-fp8w8a8g64 (weight perchannel quantization with group size 64)""", ) parser.add_argument( diff --git 
a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index c324df19c8..7a238012a4 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -32,6 +32,7 @@ from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.error_utils import ClientDisconnected +from lightllm.utils.config_utils import create_model_paths, has_vision_module from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.envs_utils import get_unique_server_name @@ -232,6 +233,12 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req created_time = int(time.time()) + supports_vl_chat = has_vision_module(create_model_paths( + g_objs.args.model_dir, + config_path=g_objs.args.config_path, + tokenizer_dir=g_objs.args.tokenizer_dir, + mmproj_path=g_objs.args.mmproj_path, + )) multimodal_params_dict = {"images": [], "audios": []} for message in request.messages: if isinstance(message.content, list): @@ -276,6 +283,9 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req else: raise ValueError("Unrecognized audio input. 
Supports local path, http url, base64.") + if not supports_vl_chat: + message.content = "\n".join(texts) if texts else "" + tools = None if request.tools and request.tool_choice != "none": # request.skip_special_tokens = False diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 8c6af128c8..319b0a96bf 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -18,10 +18,13 @@ from lightllm.utils.redis_utils import start_redis_service from lightllm.utils.shm_size_check import check_recommended_shm_size from lightllm.utils.config_utils import ( + create_model_paths, + apply_gguf_quant_type, has_audio_module, has_vision_module, is_linear_att_mixed_model, auto_set_max_req_total_len, + check_gguf_multimodal_paths, ) from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args @@ -74,7 +77,14 @@ def normal_or_p_d_start(args): args: StartArgs = args - auto_set_max_req_total_len(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) + + auto_set_max_req_total_len(args, paths) set_unique_server_name(args) if args.enable_mps: @@ -86,16 +96,22 @@ def normal_or_p_d_start(args): return # 通过模型的参数判断是否是多模态模型,包含哪几种模态, 并设置是否启动相应得模块 - if args.disable_vision is None: - if has_vision_module(args.model_dir): - args.disable_vision = False - else: - args.disable_vision = True - if args.disable_audio is None: - if has_audio_module(args.model_dir): - args.disable_audio = False - else: + if paths.is_gguf: + if args.disable_vision is None: + args.disable_vision = not bool(args.mmproj_path) + if args.disable_audio is None: args.disable_audio = True + else: + if args.disable_vision is None: + if has_vision_module(paths): + args.disable_vision = False + else: + args.disable_vision = True + if args.disable_audio is None: + if has_audio_module(paths): + args.disable_audio = False + else: + args.disable_audio = True # pd 
分离模式下,不启动多模态的模块 if args.run_mode in ["decode", "nixl_decode"]: @@ -107,6 +123,9 @@ def normal_or_p_d_start(args): else: args.enable_multimodal = True + # Check if the tokenizer directory and mmproj path are provided for GGUF multimodal models + check_gguf_multimodal_paths(paths, enable_multimodal=args.enable_multimodal) + if args.enable_cpu_cache: # 生成一个用于创建cpu kv cache的共享内存id。 args.cpu_kv_cache_shm_id = uuid.uuid1().int % 123456789 @@ -142,6 +161,8 @@ def normal_or_p_d_start(args): if not args.disable_shm_warning: check_recommended_shm_size(args) + args.quant_type = apply_gguf_quant_type(paths, args.quant_type) + assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] # 确保单机上多实列不冲突 if args.zmq_mode == "ipc:///tmp/": @@ -291,7 +312,7 @@ def normal_or_p_d_start(args): # linear_att_cache_size 只会在 qwen3.5 等混合线性层模型中生效。 args.linear_att_cache_size = args.running_max_req_size * 2 - if args.enable_cpu_cache and is_linear_att_mixed_model(args.model_dir): + if args.enable_cpu_cache and is_linear_att_mixed_model(paths): args.cpu_cache_token_page_size = args.linear_att_hash_page_size * args.linear_att_page_block_num logger.info(f"set cpu_cache_token_page_size to {args.cpu_cache_token_page_size} for linear hybrid att model") @@ -305,13 +326,13 @@ def normal_or_p_d_start(args): if args.eos_id is None: from lightllm.utils.config_utils import get_eos_token_ids - args.eos_id = get_eos_token_ids(args.model_dir) + args.eos_id = get_eos_token_ids(paths) # 如果 tool_call_parser 是 None,尝试根据模型类型自动设置 if args.tool_call_parser is None: from lightllm.utils.config_utils import get_tool_call_parser_for_model - args.tool_call_parser = get_tool_call_parser_for_model(args.model_dir) + args.tool_call_parser = get_tool_call_parser_for_model(paths) if args.tool_call_parser: logger.info(f"Auto set tool_call_parser to {args.tool_call_parser} based on model type") @@ -319,14 +340,14 @@ def normal_or_p_d_start(args): if args.reasoning_parser is None: from lightllm.utils.config_utils import 
get_reasoning_parser_for_model - args.reasoning_parser = get_reasoning_parser_for_model(args.model_dir) + args.reasoning_parser = get_reasoning_parser_for_model(paths) if args.reasoning_parser: logger.info(f"Auto set reasoning_parser to {args.reasoning_parser} based on model type") if args.data_type is None: from lightllm.utils.config_utils import get_dtype - args.data_type = get_dtype(args.model_dir) + args.data_type = get_dtype(paths) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] already_uesd_ports = [args.port] @@ -527,7 +548,13 @@ def pd_master_start(args): if args.run_mode != "pd_master": return - auto_set_max_req_total_len(args) + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) + auto_set_max_req_total_len(args, paths) # when use config_server to support multi pd_master node, we # need generate unique node id for each pd_master node. @@ -552,6 +579,8 @@ def pd_master_start(args): set_env_start_args(args) + args.quant_type = apply_gguf_quant_type(paths, args.quant_type) + process_manager.start_submodule_processes( start_funcs=[ start_metric_manager, @@ -591,6 +620,12 @@ def visual_only_start(args): from lightllm.server.core.objs.start_args_type import StartArgs args: StartArgs = args + paths = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) if args.afs_image_embed_dir is not None: os.makedirs(args.afs_image_embed_dir, mode=0o777, exist_ok=True) os.chmod(args.afs_image_embed_dir, 0o777) @@ -609,7 +644,7 @@ def visual_only_start(args): if args.data_type is None: from lightllm.utils.config_utils import get_dtype - args.data_type = get_dtype(args.model_dir) + args.data_type = get_dtype(paths) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] logger.info(f"alloced ports: {can_use_ports}") diff --git 
a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 84044fccce..160e9a70db 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -1,5 +1,6 @@ import os import json +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -11,7 +12,16 @@ def init_tokenizer(args): global tokenizer - tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) chat_path = args.chat_template if chat_path is not None: with open(chat_path, "r", encoding="utf-8") as f: diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 954daa50fe..2e8000cf08 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -29,7 +29,10 @@ class StartArgs: select_p_d_node_strategy: str = field(default=None) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) + tokenizer_dir: Optional[str] = field(default=None) tokenizer_mode: str = field(default="slow") + config_path: Optional[str] = field(default=None) + mmproj_path: Optional[str] = field(default=None) load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) diff --git a/lightllm/server/detokenization/manager.py b/lightllm/server/detokenization/manager.py index 389171ba8a..c3a8a95bb5 100644 --- a/lightllm/server/detokenization/manager.py +++ b/lightllm/server/detokenization/manager.py @@ -12,7 +12,7 @@ from .decode import decode_token from .decode_mode_fix import 
decode_mode_fix from .decode_req import DecodeReq -from ..tokenizer import get_tokenizer +from ..tokenizer import create_model_paths, get_tokenizer import pickle import time from lightllm.utils.log_utils import init_logger @@ -34,7 +34,16 @@ def __init__( self.pub_to_httpserver = context.socket(zmq.PUB) self.pub_to_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") logger.info(f"pub_to_httpserver sendhwm {self.pub_to_httpserver.getsockopt(zmq.SNDHWM)}") - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.tokenizer = get_tokenizer( + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.all_special_ids = set(self.tokenizer.all_special_ids) self.req_id_to_out: Dict[int, DecodeReq] = {} self.eos_id = args.eos_id diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 4c049f77c0..9917daf8f6 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -32,7 +32,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.server.metrics.manager import MetricClient from lightllm.utils.statics_utils import MovingAverage -from lightllm.utils.config_utils import get_vocab_size +from lightllm.utils.config_utils import create_model_paths, get_vocab_size from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.error_utils import ClientDisconnected, NixlPrefillNodeStopGenToken from rpyc.utils.classic import obtain @@ -102,7 +102,18 @@ def __init__( self.zmq_recv_socket.connect(f"{args.zmq_mode}127.0.0.1:{args.http_server_port}") self.zmq_recv_socket.setsockopt(zmq.SUBSCRIBE, b"") - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + paths = 
create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ) + + self.tokenizer = get_tokenizer( + paths, + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) self.forwarding_queue: AsyncQueue = None # p d 分离模式使用的转发队列, 需要延迟初始化 @@ -116,7 +127,7 @@ def __init__( self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() # 有的模型的vocab size 读取tokenizer和config.json中不一致 - self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) + self.vocab_size = max(get_vocab_size(paths), self.tokenizer.vocab_size) # The timemark of the latest inference(prefill/decode) which is used to check the health status of the system. # If the timemark is not updated for a pre-set time, a prob request will be sent to the backend. diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index af7a1e29fa..38013f785f 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -12,7 +12,7 @@ from ..pd_io_struct import PD_Client_Obj, UpKVStatus, NixlUpKVStatus, ObjType, NodeRole, NIXLDecodeNodeInfo from lightllm.server.core.objs import SamplingParams, StartArgs from ..multimodal_params import MultimodalParams -from ..tokenizer import get_tokenizer +from ..tokenizer import create_model_paths, get_tokenizer from ..req_id_generator import ReqIDGenerator, convert_sub_id_to_group_id from fastapi import Request from lightllm.utils.log_utils import init_logger @@ -42,7 +42,16 @@ def __init__( self.req_id_to_out_inf: Dict[int, ReqStatus] = {} self.infos_queues = None # 这个需要延迟初始化,否则使用的loop不对 - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.tokenizer = 
get_tokenizer( + create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ), + tokenizer_mode=args.tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index ca982ec0f0..82094eefe2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -19,6 +19,7 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.basemodel.batch_objs import ModelOutput, ModelInput from lightllm.common.basemodel.triton_kernel.mtp_utils import mtp_verify +from lightllm.utils.config_utils import create_model_paths, get_model_config from lightllm.utils.dist_utils import init_distributed_env from lightllm.utils.envs_utils import get_unique_server_name from lightllm.server.core.objs import ShmReqManager, StartArgs @@ -150,7 +151,7 @@ def init_model(self, kvargs): if self.args.enable_multimodal: g_infer_context.init_cpu_embed_cache_client() - model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) + model_cfg = get_model_config(create_model_paths(self.weight_dir)) model_kvargs = { "weight_dir": self.weight_dir, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py index 9be5fcd1f5..21341e0589 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py @@ -7,6 +7,7 @@ from lightllm.server.core.objs import FinishStatus from 
lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.server.tokenizer import get_tokenizer +from lightllm.utils.config_utils import create_model_paths from typing import List, Tuple from lightllm.utils.log_utils import init_logger @@ -35,7 +36,16 @@ def init_custom(self): from outlines.models.transformers import TransformerTokenizer self.tokenizer = TransformerTokenizer( - get_tokenizer(self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code) + get_tokenizer( + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), + tokenizer_mode=self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, + ), ) eos_token_ids = [] eos_token_ids.append(self.tokenizer.eos_token_id) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py index d8de132cac..94fec89193 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py @@ -2,6 +2,7 @@ from .impl import ChunkedPrefillBackend from typing import List from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -21,7 +22,14 @@ def init_custom(self): 初始化tokenizer 词表相关的的操作 """ self.tokenizer = get_tokenizer( - self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), + 
tokenizer_mode=self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, ) vob_dict = self.tokenizer.get_vocab() self.token_to_token_id = vob_dict diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py index b159b98f25..f3e987868e 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py @@ -5,6 +5,7 @@ from lightllm.utils.infer_utils import calculate_time from lightllm.server.core.objs import FinishStatus from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.utils.config_utils import create_model_paths from lightllm.server.tokenizer import get_tokenizer from lightllm.utils.log_utils import init_logger @@ -23,7 +24,14 @@ def init_custom(self): import xgrammar as xgr self.tokenizer = get_tokenizer( - self.args.model_dir, self.args.tokenizer_mode, trust_remote_code=self.args.trust_remote_code + create_model_paths( + self.args.model_dir, + config_path=self.args.config_path, + tokenizer_dir=self.args.tokenizer_dir, + mmproj_path=self.args.mmproj_path, + ), + tokenizer_mode=self.args.tokenizer_mode, + trust_remote_code=self.args.trust_remote_code, ) self.tokenizer_info = xgr.TokenizerInfo.from_huggingface(self.tokenizer) diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..1b2ae639cd 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -21,6 +21,8 @@ from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast from transformers.convert_slow_tokenizer import convert_slow_tokenizer from transformers.configuration_utils import PretrainedConfig +from lightllm.utils.config_utils import ModelPaths, get_model_config, create_model_paths +from 
lightllm.utils.gguf_tokenizer_utils import load_tokenizer_from_gguf from lightllm.utils.log_utils import init_logger from ..models.tarsier2.model import Tarsier2Tokenizer @@ -37,98 +39,147 @@ _FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" -def get_tokenizer( - tokenizer_name: str, - tokenizer_mode: str = "auto", - trust_remote_code: bool = False, +def _load_hf_tokenizer( + path: str, + trust_remote_code: bool, *args, **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - """Gets a tokenizer for the given model name via Huggingface.""" + try: + return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code, *args, **kwargs) + except TypeError as e: + logger.warning(f"load fast tokenizer fail: {str(e)}") + kwargs["use_fast"] = False + return AutoTokenizer.from_pretrained(path, trust_remote_code=trust_remote_code, *args, **kwargs) + + +def _load_base_tokenizer( + load_path: str, + from_gguf: bool, + model_cfg: dict, + trust_remote_code: bool, + tokenizer_mode: str, + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Load the base tokenizer based on AutoTokenizer. """ if tokenizer_mode == "slow": if kwargs.get("use_fast", False): raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False + + if from_gguf: + return load_tokenizer_from_gguf(load_path, model_cfg, *args, **kwargs) - if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): + if "llama" in load_path.lower() and kwargs.get("use_fast", True): logger.info( "For some LLaMA-based models, initializing the fast tokenizer may " "take a long time. To eliminate the initialization time, consider " f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " "tokenizer." 
) - # tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name) - # tokenizer = convert_slow_tokenizer(tokenizer) - # return tokenizer - - try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) - except TypeError as e: - # The LLaMA tokenizer causes a protobuf error in some environments, using slow mode. - # you can try pip install protobuf==3.20.0 to try repair - logger.warning(f"load fast tokenizer fail: {str(e)}") - kwargs["use_fast"] = False - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs) - + + tokenizer = _load_hf_tokenizer(load_path, trust_remote_code, *args, **kwargs) if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.info( "Using a slow tokenizer. This might cause a significant " "slowdown. Consider using a fast tokenizer instead." ) - model_cfg, _ = PretrainedConfig.get_config_dict(tokenizer_name) + return tokenizer + + +def _wrap_tokenizer( + base_tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + model_cfg: dict, + processor_dir: str, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: model_type = model_cfg.get("model_type", "") # DeepSeek-V3.2 custom tokenizer mode: wraps the HF tokenizer with # a Python-based apply_chat_template that uses encoding_dsv32.py. 
if model_type == "deepseek_v32": from ..models.deepseek3_2.model import DeepSeekV32Tokenizer - hf_tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, trust_remote_code=trust_remote_code, *args, **kwargs - ) logger.info("Using DeepSeek-V3.2 tokenizer mode with Python-based chat template encoding.") - return DeepSeekV32Tokenizer(hf_tokenizer) + return DeepSeekV32Tokenizer(base_tokenizer) if model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor - - image_processor = Qwen2VLImageProcessor.from_pretrained(tokenizer_name) - tokenizer = Tarsier2Tokenizer(tokenizer=tokenizer, image_processor=image_processor, model_cfg=model_cfg) - elif model_type == "llava" or model_type == "internlmxcomposer2": - tokenizer = LlavaTokenizer(tokenizer, model_cfg) - elif model_type == "qwen" and "visual" in model_cfg: - tokenizer = QWenVLTokenizer(tokenizer, model_cfg) - elif model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg: - from transformers import AutoProcessor + from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor + + image_processor = load_image_processor(processor_dir, Qwen2VLImageProcessor) + return Tarsier2Tokenizer(tokenizer=base_tokenizer, image_processor=image_processor, model_cfg=model_cfg) + + if model_type == "llava" or model_type == "internlmxcomposer2": + return LlavaTokenizer(base_tokenizer, model_cfg) + + if model_type == "qwen" and "visual" in model_cfg: + return QWenVLTokenizer(base_tokenizer, model_cfg) - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen2VLTokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + if model_type in ["qwen2_vl", "qwen2_5_vl"] and "vision_config" in model_cfg: + from ..models.qwen2_vl.vision_process import Qwen2VLImageProcessor, load_image_processor + + image_processor = load_image_processor(processor_dir, 
Qwen2VLImageProcessor) + return QWen2VLTokenizer( + tokenizer=base_tokenizer, image_processor=image_processor, model_cfg=model_cfg ) - elif model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg: + + if model_type in ["qwen3_vl", "qwen3_vl_moe"] and "vision_config" in model_cfg: from transformers import AutoProcessor - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3VLTokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3VLTokenizer( + tokenizer=base_tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) - elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: + + if model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: from transformers import AutoProcessor from ..models.qwen3_5.model import QWen3_5Tokenizer - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3_5Tokenizer( - tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3_5Tokenizer( + tokenizer=base_tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) - elif model_cfg.get("thinker_config") is not None: + + if model_cfg.get("thinker_config") is not None: from transformers import AutoProcessor model_cfg = model_cfg["thinker_config"] - processor = AutoProcessor.from_pretrained(tokenizer_name) - tokenizer = QWen3OmniTokenizer(tokenizer, processor=processor, model_cfg=model_cfg) - elif model_type == "internvl_chat": - tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) - elif model_type == "gemma3": - tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + processor = AutoProcessor.from_pretrained(processor_dir) + return QWen3OmniTokenizer(base_tokenizer, processor=processor, model_cfg=model_cfg) - return 
tokenizer + if model_type == "internvl_chat": + return InternvlTokenizer(base_tokenizer, model_cfg, weight_dir=processor_dir) + + if model_type == "gemma3": + return Gemma3Tokenizer(base_tokenizer, model_cfg) + + return base_tokenizer + + +def get_tokenizer( + paths: Union[str, ModelPaths], + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Load base tokenizer (HF or GGUF), then wrap for model-specific behavior if needed. """ + paths = create_model_paths(paths) + model_cfg = get_model_config(paths) + load_path, from_gguf = paths.tokenizer_load_path + + base_tokenizer = _load_base_tokenizer( + load_path=load_path, + from_gguf=from_gguf, + model_cfg=model_cfg, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + *args, + **kwargs, + ) + + return _wrap_tokenizer( + base_tokenizer=base_tokenizer, + model_cfg=model_cfg, + processor_dir=paths.processor_dir, + ) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 1dffdaf681..b548339332 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -17,6 +17,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.config_utils import create_model_paths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -50,7 +51,12 @@ def __init__( self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}") self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True}) self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - 
self.model_weightdir = args.model_dir + self.visual_weight_dir, self.processor_dir = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ).resolve_visual_dirs() self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp # image 最大推理 batch size @@ -74,7 +80,8 @@ async def wait_to_model_ready(self): for tp_rank_id in range(self.vit_tp): device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { - "weight_dir": self.model_weightdir, + "weight_dir": self.visual_weight_dir, + "processor_dir": self.processor_dir, "device_id": device_id, "vit_tp": self.vit_tp, "cache_port": self.args.cache_port, diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 92ca2e3836..8b5aa15239 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -7,7 +7,7 @@ import time import torch.distributed as dist from typing import Dict, List, Tuple, Deque, Optional -from transformers.configuration_utils import PretrainedConfig +from lightllm.utils.config_utils import get_model_paths, get_model_config from rpyc.utils.classic import obtain from lightllm.models.qwen_vl.qwen_visual import QWenVisionTransformer from lightllm.models.llava.llava_visual import LlavaVisionModel @@ -52,6 +52,7 @@ def exposed_init_model(self, kvargs): # } weight_dir = kvargs["weight_dir"] + processor_dir = kvargs.get("processor_dir", weight_dir) self.infer_max_batch_size = kvargs["max_batch_size"] self.device_id = kvargs["device_id"] self.vit_tp = kvargs["vit_tp"] @@ -63,11 +64,12 @@ def exposed_init_model(self, kvargs): self.vit_attn_backend = kvargs["vit_attn_backend"] set_vit_att_backend(self.vit_attn_backend) init_vision_distributed_env(kvargs) - model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + model_cfg = get_model_config(get_model_paths()) try: kvargs 
= { "weight_dir": weight_dir, + "processor_dir": processor_dir, "data_type": self.data_type, "quant_type": kvargs["quant_type"], "quant_cfg": kvargs["quant_cfg"], diff --git a/lightllm/server/visualserver/visual_only_manager.py b/lightllm/server/visualserver/visual_only_manager.py index 15705a1140..aa489ff51c 100644 --- a/lightllm/server/visualserver/visual_only_manager.py +++ b/lightllm/server/visualserver/visual_only_manager.py @@ -21,6 +21,7 @@ from lightllm.server.multimodal_params import MultimodalParams, ImageItem from .model_infer import start_model_process, VisualModelRpcClient from lightllm.common.basemodel.attention_vit.create_utils import init_vit_att_backend +from lightllm.utils.config_utils import create_model_paths from lightllm.utils.log_utils import init_logger from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.process_check import start_parent_check_thread @@ -39,7 +40,12 @@ def __init__( args: StartArgs, ): self.args = args - self.model_weightdir = args.model_dir + self.visual_weight_dir, self.processor_dir = create_model_paths( + args.model_dir, + config_path=args.config_path, + tokenizer_dir=args.tokenizer_dir, + mmproj_path=args.mmproj_path, + ).resolve_visual_dirs() self.vit_dp = args.visual_dp assert self.vit_dp == 1 self.vit_tp = args.visual_tp @@ -106,7 +112,8 @@ async def wait_to_model_ready(self): for tp_rank_id in range(self.vit_tp): device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] kvargs = { - "weight_dir": self.model_weightdir, + "weight_dir": self.visual_weight_dir, + "processor_dir": self.processor_dir, "device_id": device_id, "vit_tp": self.vit_tp, "cache_port": None, # visual only 模式下不使用 embed cache diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c64e8a912b..7ea8792491 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -1,20 +1,241 @@ import json import os -from typing import Optional, List +from dataclasses 
import dataclass, field from functools import lru_cache -from .envs_utils import get_env_start_args +from typing import List, Optional, Union + +from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) -def get_config_json(model_path: str): - with open(os.path.join(model_path, "config.json"), "r") as file: - json_obj = json.load(file) - return json_obj +@dataclass(frozen=True) +class ModelPaths: + """ Resolved model-related paths for config / tokenizer / multimodal loading """ + model_dir: str + tokenizer_dir: Optional[str] = None + config_path: Optional[str] = None + mmproj_path: Optional[str] = None + _gguf_path: Optional[str] = field(default=None, init=False, repr=False, compare=False) + + @property + def is_gguf(self) -> bool: + return self._gguf_path is not None + + @property + def gguf_path(self) -> Optional[str]: + return self._gguf_path + + @property + def processor_dir(self) -> str: + return self.tokenizer_dir or self.model_dir + + @property + def tokenizer_load_path(self) -> tuple[str, bool]: + if self._gguf_path is not None and not self.tokenizer_dir: + return self._gguf_path, True + return self.processor_dir, False + + def resolve_visual_dirs(self) -> tuple[Optional[str], Optional[str]]: + if self.is_gguf: + return self.mmproj_path, self.tokenizer_dir + return self.model_dir, self.model_dir + + def load_config(self) -> dict: + if self.config_path is not None: + return _load_config_from_path(self.config_path) + + gguf_path = self.gguf_path + + if gguf_path is not None and self.tokenizer_dir is not None: + hf_config_path = os.path.join(self.tokenizer_dir, "config.json") + assert os.path.isfile(hf_config_path), f"config.json {hf_config_path} is not found" + return _load_config_from_path(hf_config_path) + + if gguf_path is None and self.model_dir: + config_json_path = os.path.join(self.model_dir, "config.json") + if os.path.isfile(config_json_path): + return 
_load_config_from_path(config_json_path) + + if gguf_path is not None: + return _load_config_from_gguf(gguf_path) + + raise FileNotFoundError( + f"no model config found (config_path={self.config_path!r}, model_dir={self.model_dir!r}). " + "Provide --config_path, place config.json under model_dir, or use a .gguf model path." + ) + + def align_quant_type(self, quant_type: str) -> str: + if self.is_gguf: + return "gguf" + + if quant_type == "gguf": + raise ValueError("--quant_type gguf is not supported for non-GGUF models") + + return quant_type + + def __post_init__(self): + object.__setattr__(self, "_gguf_path", _find_gguf_path_cached(self.model_dir)) + + +def _fill_paths_from_env(paths: ModelPaths) -> ModelPaths: + try: + start_args = get_env_start_args() + except KeyError: + return paths + + config_path = paths.config_path if paths.config_path is not None else getattr(start_args, "config_path", None) + tokenizer_dir = ( + paths.tokenizer_dir if paths.tokenizer_dir is not None else getattr(start_args, "tokenizer_dir", None) + ) + mmproj_path = paths.mmproj_path if paths.mmproj_path is not None else getattr(start_args, "mmproj_path", None) + + if config_path == paths.config_path and tokenizer_dir == paths.tokenizer_dir and mmproj_path == paths.mmproj_path: + return paths + + return ModelPaths( + model_dir=paths.model_dir, + config_path=config_path, + tokenizer_dir=tokenizer_dir, + mmproj_path=mmproj_path, + ) + + +def create_model_paths( + model_dir_or_paths: Union[str, ModelPaths, None] = None, + *, + config_path: Optional[str] = None, + tokenizer_dir: Optional[str] = None, + mmproj_path: Optional[str] = None, +) -> ModelPaths: + if model_dir_or_paths is None: + start_args = get_env_start_args() + return create_model_paths( + start_args.model_dir, + config_path=getattr(start_args, "config_path", None), + tokenizer_dir=getattr(start_args, "tokenizer_dir", None), + mmproj_path=getattr(start_args, "mmproj_path", None), + ) + + if isinstance(model_dir_or_paths, 
ModelPaths): + paths = model_dir_or_paths + else: + paths = ModelPaths( + model_dir=model_dir_or_paths, + config_path=config_path, + tokenizer_dir=tokenizer_dir, + mmproj_path=mmproj_path, + ) + + return _fill_paths_from_env(paths) + + +@lru_cache(maxsize=1) +def get_model_paths() -> ModelPaths: + return create_model_paths() + + +def _load_config_from_path(config_path: str) -> dict: + if not os.path.isfile(config_path): + raise FileNotFoundError(f"config file not found: {config_path}") + with open(config_path, "r") as file: + return json.load(file) + + +def _load_config_from_gguf(gguf_path: str) -> dict: + from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader + + return get_gguf_reader(gguf_path).load_config() + +@lru_cache(maxsize=128) +def _find_gguf_path_cached(model_dir: Optional[str]) -> Optional[str]: + if not model_dir: + return None + + if model_dir.endswith(".gguf") and os.path.isfile(model_dir): + return model_dir + + if os.path.isdir(model_dir): + gguf_files = sorted( + os.path.join(model_dir, name) for name in os.listdir(model_dir) if name.endswith(".gguf") + ) + if not gguf_files: + return None + if len(gguf_files) > 1: + raise ValueError( + f"multiple GGUF files found in {model_dir} is not supported, please specify the target gguf file." 
+ ) + return gguf_files[0] + + return None + + +def apply_gguf_quant_type(paths: Union[str, ModelPaths], quant_type: str) -> str: + """Align quant_type for GGUF models and log when overriding.""" + paths = create_model_paths(paths) + aligned = paths.align_quant_type(quant_type) + if aligned != quant_type: + logger.warning( + f"model_dir contains GGUF weights; overriding --quant_type {quant_type!r} -> {aligned!r}" + ) + return aligned + + +def _normalize_gguf_model_config(config: dict) -> dict: + if config.get("head_dim") is None: + hidden_size = config.get("hidden_size") or config.get("n_embd") or config.get("n_embed") + num_heads = config.get("num_attention_heads") or config.get("n_head") + if hidden_size and num_heads: + config["head_dim"] = hidden_size // num_heads + return config + + +@lru_cache(maxsize=None) +def get_model_config(paths: Union[str, ModelPaths]) -> dict: + paths = create_model_paths(paths) + config = paths.load_config() + if paths.is_gguf: + config = _normalize_gguf_model_config(config) + return config -def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int]: + +def check_gguf_multimodal_paths(paths: ModelPaths, enable_multimodal: bool = False) -> None: + if not enable_multimodal or not paths.is_gguf: + return + + if not paths.tokenizer_dir: + raise ValueError("tokenizer_dir is required when enable_multimodal is True for GGUF models") + if not os.path.isdir(paths.tokenizer_dir): + raise FileNotFoundError(f"tokenizer_dir {paths.tokenizer_dir} is not found") + + effective_config_path = paths.config_path + if effective_config_path is not None: + if not os.path.isfile(effective_config_path): + raise FileNotFoundError(f"config.json {effective_config_path} is not found") + else: + effective_config_path = os.path.join(paths.tokenizer_dir, "config.json") + if not os.path.isfile(effective_config_path): + raise FileNotFoundError( + f"config.json is not provided and not found in tokenizer_dir: " + f"{paths.tokenizer_dir} when 
enable_multimodal is True for GGUF models" + ) + + processor_path = os.path.join(paths.tokenizer_dir, "preprocessor_config.json") + if not os.path.isfile(processor_path): + raise FileNotFoundError( + f"preprocessor_config.json not found in tokenizer_dir: " + f"{paths.tokenizer_dir} when enable_multimodal is True for GGUF models" + ) + + if not paths.mmproj_path: + raise ValueError("mmproj_path is required when enable_multimodal is True for GGUF models") + if not os.path.isfile(paths.mmproj_path): + raise FileNotFoundError(f"mmproj_path {paths.mmproj_path} is not found") + + +def _derive_max_req_total_len_from_model_config(paths: ModelPaths) -> Optional[int]: """ Derive `max_req_total_len` from model config.json. @@ -24,7 +245,7 @@ def _derive_max_req_total_len_from_model_config(model_dir: str) -> Optional[int] """ try: - cfg = get_config_json(model_dir) + cfg = get_model_config(paths) except Exception as e: logger.warning(f"failed to load config.json for max_req_total_len derive: {e}") return None @@ -113,7 +334,7 @@ def _find_rope_scaling() -> dict: return None -def auto_set_max_req_total_len(args) -> None: +def auto_set_max_req_total_len(args, paths: ModelPaths) -> None: """ Ensure `args.max_req_total_len` is an int. 
@@ -125,14 +346,13 @@ def auto_set_max_req_total_len(args) -> None: if args.max_req_total_len is not None: return - model_dir = args.model_dir - if not model_dir: + if not paths.model_dir: logger.warning("model_dir is empty; fallback max_req_total_len=16384") args.max_req_total_len = default_fallback return try: - derived = _derive_max_req_total_len_from_model_config(model_dir) + derived = _derive_max_req_total_len_from_model_config(paths) except Exception as e: logger.warning(f"failed to derive max_req_total_len from model config: {e}") derived = None @@ -146,8 +366,8 @@ def auto_set_max_req_total_len(args) -> None: logger.info(f"auto derived max_req_total_len={args.max_req_total_len} from model config") -def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): - config_json = get_config_json(model_path) +def _get_config_llm_keyvalue(paths: ModelPaths, key_name: list[str]): + config_json = get_model_config(paths) for key in key_name: try: value = config_json[key] @@ -167,53 +387,57 @@ def _get_config_llm_keyvalue(model_path: str, key_name: list[str]): return None -def get_hidden_size(model_path: str) -> Optional[int]: - hidden_size = _get_config_llm_keyvalue(model_path=model_path, key_name=["hidden_size", "n_embd", "n_embed"]) +def get_hidden_size(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[int]: + paths = create_model_paths(model_dir_or_paths) + hidden_size = _get_config_llm_keyvalue(paths, key_name=["hidden_size", "n_embd", "n_embed"]) if isinstance(hidden_size, int): return hidden_size return None @lru_cache(maxsize=None) -def get_num_key_value_heads(model_path: str) -> int: - num_key_value_heads = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_key_value_heads"]) +def get_num_key_value_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = create_model_paths(model_dir_or_paths) + num_key_value_heads = _get_config_llm_keyvalue(paths, key_name=["num_key_value_heads"]) if isinstance(num_key_value_heads, int): 
return num_key_value_heads return None @lru_cache(maxsize=None) -def get_num_attention_heads(model_path: str) -> int: - num_attention_heads = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_attention_heads"]) +def get_num_attention_heads(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = create_model_paths(model_dir_or_paths) + num_attention_heads = _get_config_llm_keyvalue(paths, key_name=["num_attention_heads"]) if isinstance(num_attention_heads, int): return num_attention_heads return None @lru_cache(maxsize=None) -def get_head_dim(model_path: str) -> int: - head_dim = _get_config_llm_keyvalue(model_path=model_path, key_name=["head_dim"]) +def get_head_dim(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = create_model_paths(model_dir_or_paths) + head_dim = _get_config_llm_keyvalue(paths, key_name=["head_dim"]) if isinstance(head_dim, int): return head_dim - # calcu head_dim - head_dim = get_hidden_size(model_path=model_path) // get_num_attention_heads(model_path=model_path) + head_dim = get_hidden_size(paths) // get_num_attention_heads(paths) return head_dim @lru_cache(maxsize=None) -def get_layer_num(model_path: str) -> int: - num_hidden_layers = _get_config_llm_keyvalue(model_path=model_path, key_name=["num_hidden_layers"]) +def get_layer_num(model_dir_or_paths: Union[str, ModelPaths]) -> int: + paths = create_model_paths(model_dir_or_paths) + num_hidden_layers = _get_config_llm_keyvalue(paths, key_name=["num_hidden_layers"]) if isinstance(num_hidden_layers, int): return num_hidden_layers return None -def get_eos_token_ids(model_path: str) -> Optional[List[int]]: +def get_eos_token_ids(model_dir_or_paths: Union[str, ModelPaths]) -> Optional[List[int]]: + paths = create_model_paths(model_dir_or_paths) try: - # qwen3-omini special eos_token_id - config_json = get_config_json(model_path) + config_json = get_model_config(paths) assert config_json["architectures"][0] == "Qwen3OmniMoeForConditionalGeneration" return [151645] 
except: @@ -221,20 +445,20 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: # Qwen3.5 checkpoints can have an eos_token_id in config that differs from # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable - # stop id (<|im_end|>) for detokenization/stop behavior. + # stop id (<|im_end|>) try: - config_json = get_config_json(model_path) + config_json = get_model_config(paths) model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + tokenizer = AutoTokenizer.from_pretrained(paths.processor_dir, trust_remote_code=False) if tokenizer.eos_token_id is not None: return [int(tokenizer.eos_token_id)] except Exception: pass - eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) + eos_token_id = _get_config_llm_keyvalue(paths, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] if isinstance(eos_token_id, list): @@ -244,9 +468,10 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: return -def get_model_architectures(model_path: str): +def get_model_architectures(model_dir_or_paths: Union[str, ModelPaths]): + paths = create_model_paths(model_dir_or_paths) try: - config_json = get_config_json(model_path) + config_json = get_model_config(paths) arch = config_json["architectures"][0] return arch except: @@ -254,9 +479,13 @@ def get_model_architectures(model_path: str): return "unknown_architecture" -def get_vocab_size(model_path: str): +def get_vocab_size(paths: Optional[Union[str, ModelPaths]] = None) -> int: try: - config_json = get_config_json(model_path) + if paths is None: + paths = get_model_paths() + elif not isinstance(paths, ModelPaths): + paths = create_model_paths(paths) + config_json = 
get_model_config(paths) # qwen3-omini special if "thinker_config" in config_json: config_json = config_json["thinker_config"] @@ -275,8 +504,9 @@ def get_vocab_size(model_path: str): return 0 -def get_dtype(model_path: str): - torch_dtype = _get_config_llm_keyvalue(model_path=model_path, key_name=["torch_dtype", "dtype", "model_dtype"]) +def get_dtype(model_dir_or_paths: Union[str, ModelPaths]): + paths = create_model_paths(model_dir_or_paths) + torch_dtype = _get_config_llm_keyvalue(paths, key_name=["torch_dtype", "dtype", "model_dtype"]) if torch_dtype is None: logger.warning("torch_dtype not in config.json, use float16 as default") return "float16" @@ -286,8 +516,7 @@ def get_dtype(model_path: str): @lru_cache(maxsize=None) def get_fixed_kv_len(): - start_args = get_env_start_args() - model_cfg = get_config_json(start_args.model_dir) + model_cfg = get_model_config(get_model_paths()) if "prompt_cache_token_ids" in model_cfg: return len(model_cfg["prompt_cache_token_ids"]) else: @@ -295,11 +524,10 @@ def get_fixed_kv_len(): @lru_cache(maxsize=None) -def has_vision_module(model_path: str) -> bool: +def has_vision_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = create_model_paths(model_dir_or_paths) try: - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] if model_type == "qwen": # QWenVisionTransformer @@ -338,16 +566,15 @@ def has_vision_module(model_path: str) -> bool: else: raise Exception("unknown vision model type") except: - logger.info(f"model path: {model_path} does not has vision module") + logger.info(f"model path: {paths.model_dir} does not has vision module") return False @lru_cache(maxsize=None) -def has_audio_module(model_path: str) -> bool: +def has_audio_module(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = create_model_paths(model_dir_or_paths) try: - from 
transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config(paths) if model_cfg.get("thinker_config") is not None: model_cfg = model_cfg["thinker_config"] audio_config = model_cfg["audio_config"] @@ -361,40 +588,39 @@ def has_audio_module(model_path: str) -> bool: else: raise Exception("unknown audio model type") except: - logger.info(f"model path: {model_path} does not has audio module") + logger.info(f"model path: {paths.model_dir} does not has audio module") return False @lru_cache(maxsize=None) -def is_linear_att_mixed_model(model_path: str) -> bool: +def is_linear_att_mixed_model(model_dir_or_paths: Union[str, ModelPaths]) -> bool: + paths = create_model_paths(model_dir_or_paths) try: - from transformers.configuration_utils import PretrainedConfig - - model_cfg, _ = PretrainedConfig.get_config_dict(model_path) + model_cfg = get_model_config(paths) model_type = model_cfg["model_type"] if model_type in ["qwen3_5", "qwen3_5_moe", "qwen3_5_text", "qwen3_5_moe_text"]: return True else: return False except: - logger.info(f"model path: {model_path} does not has linear hybrid attention") + logger.info(f"model path: {paths.model_dir} does not has linear hybrid attention") return False -def get_model_type(model_path: str) -> Optional[str]: - """Get model type from config.json""" +def get_model_type(paths: Union[str, ModelPaths]) -> Optional[str]: + """Get model type from model config.""" try: - config_json = get_config_json(model_path) + config_json = get_model_config(paths) model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") return model_type except Exception as e: - logger.error(f"Failed to get model_type from {model_path}: {e}") + logger.error(f"Failed to get model_type (paths={paths!r}): {e}") return None -def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: +def get_tool_call_parser_for_model(paths: Union[str, 
ModelPaths]) -> Optional[str]: """Auto-detect tool_call_parser based on model type""" - model_type = get_model_type(model_path) + model_type = get_model_type(paths) if model_type is None: return None @@ -421,9 +647,8 @@ def get_tool_call_parser_for_model(model_path: str) -> Optional[str]: return None -def get_reasoning_parser_for_model(model_path: str) -> Optional[str]: - """Auto-detect reasoning_parser based on model type""" - model_type = get_model_type(model_path) +def get_reasoning_parser_for_model(paths: Union[str, ModelPaths]) -> Optional[str]: + model_type = get_model_type(paths) if model_type is None: return None diff --git a/lightllm/utils/gguf_tokenizer_utils.py b/lightllm/utils/gguf_tokenizer_utils.py new file mode 100644 index 0000000000..6d67a7d31d --- /dev/null +++ b/lightllm/utils/gguf_tokenizer_utils.py @@ -0,0 +1,207 @@ +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers.integrations.ggml import ( + GGUF_TO_FAST_CONVERTERS, + GGUF_TOKENIZER_MAPPING, + convert_gguf_tokenizer, +) +from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING, tokenizer_class_from_name +from typing import Any, Dict, Optional, Tuple, Union + +from lightllm.common.basemodel.layer_weights.gguf_load_utils import LightLLMGGUFReader, get_gguf_reader +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def _build_gguf_tokenizer_field_map() -> Dict[str, Tuple[str, str]]: + """ Build a mapping from GGUF tokenizer field name to HuggingFace tokenizer field name. 
""" + field_map: Dict[str, Tuple[str, str]] = {} + # Build mapping for tokenizer fields + for gguf_suffix, hf_key in GGUF_TOKENIZER_MAPPING["tokenizer"].items(): + # e.g., "ggml.model" -> ("tokenizer", "model_type") + field_map[f"tokenizer.{gguf_suffix}"] = ("tokenizer", hf_key) + # Build mapping for tokenizer_config fields + for gguf_suffix, hf_key in GGUF_TOKENIZER_MAPPING["tokenizer_config"].items(): + # e.g., "ggml.model" -> ("tokenizer_config", "model_type") + gguf_key = f"tokenizer.{gguf_suffix}" + if gguf_key not in field_map: + field_map[gguf_key] = ("tokenizer_config", hf_key) + # Add special case for add_bos_token + field_map["tokenizer.ggml.add_bos_token"] = ("tokenizer_config", "add_bos_token") + + return field_map + + +def _parse_gguf_tokenizer_from_reader(reader: LightLLMGGUFReader) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ Parse tokenizer metadata from a GGUFReader using ReaderField.contents(). """ + tokenizer: Dict[str, Any] = {} + tokenizer_config: Dict[str, Any] = {} + + tokenizer_field_map = _build_gguf_tokenizer_field_map() + for gguf_key, (bucket, hf_key) in tokenizer_field_map.items(): + value = reader.read_field(gguf_key) + if value is None: + continue + # Add value to tokenizer or tokenizer_config + if bucket == "tokenizer": + tokenizer[hf_key] = value + else: + tokenizer_config[hf_key] = value + + if "model_type" not in tokenizer_config and tokenizer.get("tokenizer_type") is not None: + tokenizer_config["model_type"] = tokenizer["tokenizer_type"] + + for id_key in ("bos_token_id", "eos_token_id", "unk_token_id", "pad_token_id"): + token_id = tokenizer.get(id_key) + if token_id is not None: + tokenizer_config[id_key] = token_id + + return tokenizer, tokenizer_config + + +def _architecture_to_converter_key(architecture: str) -> str: + """ Convert GGUF architecture to HuggingFace converter key. 
""" + arch = architecture.replace("-", "_") + if arch in GGUF_TO_FAST_CONVERTERS: + return arch + # e.g., qwen2moe / qwen3moe -> qwen2_moe / qwen3_moe + if arch.endswith("moe") and not arch.endswith("_moe"): + alt = f"{arch[:-3]}_moe" + if alt in GGUF_TO_FAST_CONVERTERS: + return alt + + return arch + + +def _resolve_gguf_tokenizer_architecture(model_config: Dict[str, Any], reader: LightLLMGGUFReader) -> str: + """ Resolve GGUF tokenizer architecture from GGUF reader or model config. """ + architecture = None + # Read architecture from GGUF reader + arch = reader.read_field("general.architecture") + if isinstance(arch, (list, tuple)): + arch = arch[0] + if isinstance(arch, str): + architecture = arch + + # Read architecture from model config + if not architecture: + architecture = model_config.get("model_type") + + if not architecture: + raise ValueError( + "Can't resolve GGUF tokenizer architecture: " + "missing general.architecture in GGUF and model_type in config" + ) + + return _architecture_to_converter_key(architecture) + + +def _build_tokenizer_init_kwargs( + tokenizer_dict: Dict[str, Any], + tokenizer_config: Dict[str, Any], + additional_kwargs: Dict[str, Any], +) -> Dict[str, Any]: + """ Build tokenizer initialization kwargs from tokenizer dictionary and configuration. 
""" + # Merge tokenizer configuration and additional kwargs + init_kwargs = {**tokenizer_config, **additional_kwargs} + + tokens = tokenizer_dict.get("tokens") or [] + for id_key, token_key in ( + ("bos_token_id", "bos_token"), + ("eos_token_id", "eos_token"), + ("unk_token_id", "unk_token"), + ("pad_token_id", "pad_token"), + ): + token_id = tokenizer_dict.get(id_key) + if token_id is None: + token_id = init_kwargs.get(id_key) + + if token_id is not None and token_id < len(tokens): + init_kwargs.setdefault(token_key, tokens[token_id]) + + return init_kwargs + + +def _get_hf_tokenizer_class_from_config( + model_config: Dict[str, Any], + use_fast: bool = True, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """ Resolve the same tokenizer class AutoTokenizer would use for this config. """ + # Get tokenizer class name from tokenizer_class key of model config + tokenizer_class_name: str = model_config.get("tokenizer_class") + if tokenizer_class_name: + candidates = [] + if use_fast and not tokenizer_class_name.endswith("Fast"): + candidates.append(f"{tokenizer_class_name}Fast") + candidates.append(tokenizer_class_name) + for candidate in candidates: + tokenizer_class = tokenizer_class_from_name(candidate) + if tokenizer_class is not None: + return tokenizer_class + + # Get HuggingFace tokenizer class from model_type key of model config + model_type = model_config.get("model_type") + # CONFIG_MAPPING: model_type -> config class + if not model_type or model_type not in CONFIG_MAPPING: + return PreTrainedTokenizerFast + + config_class = CONFIG_MAPPING[model_type] + # TOKENIZER_MAPPING: config class -> (tokenizer class, fast tokenizer class) + if config_class not in TOKENIZER_MAPPING: + return PreTrainedTokenizerFast + + tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[config_class] + if use_fast and tokenizer_class_fast is not None: + return tokenizer_class_fast + + if tokenizer_class_py is not None: + return tokenizer_class_py + + return 
def load_tokenizer_from_gguf(
    gguf_path: str,
    model_config: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
    """Build a HuggingFace tokenizer from the tokenizer metadata of a GGUF file.

    Args:
        gguf_path: Path of the GGUF file to read.
        model_config: Model configuration dict, used to resolve the tokenizer
            architecture and the tokenizer class. Required despite the
            ``None`` default.
        **kwargs: Overrides for the tokenizer init kwargs. ``use_fast``
            (default ``True``) only selects the tokenizer class and
            ``trust_remote_code`` is ignored; neither is forwarded to the
            tokenizer constructor.

    Returns:
        An instance of the resolved tokenizer class, constructed around the
        converted tokenizer object.

    Raises:
        ValueError: If ``model_config`` is ``None``, the parsed GGUF tokenizer
            data has no ``tokens`` entry, or the resolved architecture has no
            entry in ``GGUF_TO_FAST_CONVERTERS``.
    """
    if model_config is None:
        raise ValueError("model_config is required when loading tokenizer from GGUF")

    reader = get_gguf_reader(gguf_path)
    tokenizer_dict, tokenizer_config = _parse_gguf_tokenizer_from_reader(reader)
    architecture = _resolve_gguf_tokenizer_architecture(model_config, reader=reader)

    if "tokens" not in tokenizer_dict:
        raise ValueError(f"GGUF file does not contain tokenizer.ggml.tokens: {gguf_path}")

    if architecture not in GGUF_TO_FAST_CONVERTERS:
        supported = ", ".join(sorted(GGUF_TO_FAST_CONVERTERS))
        raise ValueError(
            f"unsupported GGUF tokenizer architecture {architecture!r} for {gguf_path}; "
            f"supported: {supported}"
        )

    logger.info(
        f"Loading tokenizer from GGUF ReaderField metadata: {gguf_path} "
        f"(architecture={architecture}, vocab_size={len(tokenizer_dict['tokens'])})"
    )

    # Convert the GGUF tokenizer data into a tokenizer object plus any
    # converter-provided kwargs.
    fast_tokenizer, additional_kwargs = convert_gguf_tokenizer(architecture, tokenizer_dict)
    init_kwargs = _build_tokenizer_init_kwargs(tokenizer_dict, tokenizer_config, additional_kwargs)

    # Consume ``use_fast`` BEFORE forwarding the remaining kwargs: it only
    # selects the tokenizer class and must not reach the constructor.
    use_fast = kwargs.pop("use_fast", True)
    init_kwargs.update({k: v for k, v in kwargs.items() if k != "trust_remote_code"})

    tokenizer_class = _get_hf_tokenizer_class_from_config(model_config, use_fast=use_fast)
    logger.info(f"GGUF tokenizer class: {tokenizer_class.__name__}")

    return tokenizer_class(tokenizer_object=fast_tokenizer, **init_kwargs)
@lru_cache(maxsize=None)
def get_llm_model_class():
    """Return the model class that matches the current model configuration.

    The result is memoized by ``lru_cache``, so the configuration is only
    looked up on the first call.
    """
    # Imports stay function-local, in the same order as before.
    from lightllm.models import get_model_class
    from lightllm.utils.config_utils import get_model_config, get_model_paths

    return get_model_class(model_cfg=get_model_config(get_model_paths()))
tokenizer.get_image_token_length(fake_image_item) # 估算图片 token 所需的资源 - hidden_size = get_hidden_size(args.model_dir) + hidden_size = get_hidden_size(paths) if hidden_size is None: logger.warning( "Model config not contain 'hidden_size', " "using 4096 by default to calculate recommended shm size."