Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 32 additions & 9 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.config_utils import create_model_paths, get_model_config
from lightllm.utils.log_utils import init_logger
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type, get_added_mtp_kv_layer_num
Expand Down Expand Up @@ -58,6 +59,7 @@ def __init__(self, kvargs):
self.args = get_env_start_args()
self.run_mode = kvargs["run_mode"]
self.weight_dir_ = kvargs["weight_dir"]
self.model_paths_ = create_model_paths(self.weight_dir_)
self.max_total_token_num = kvargs["max_total_token_num"]
self.batch_max_tokens = kvargs.get("batch_max_tokens", None)
self.load_way = kvargs.get("load_way", "HF")
Expand Down Expand Up @@ -104,6 +106,8 @@ def __init__(self, kvargs):
self._verify_params()
self._init_quant()

# Read GGUF weights mapping
self._init_gguf()
self._init_weights()
self._init_req_manager()
self._init_mem_manager()
Expand Down Expand Up @@ -144,8 +148,7 @@ def _wait_other_modules_ready(self):
return

def _init_config(self):
with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
self.config = json.load(json_file)
self.config = get_model_config(self.model_paths_)
# rename keys
repair_config(self.config, same_names=["num_attention_heads", "n_head"])
repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
Expand All @@ -168,6 +171,15 @@ def _init_quant(self):
self.quant_cfg = Quantcfg(self.config, self.quant_type, self.quant_cfg_path)
logger.info(f"Initial quantization. " f"The default quantization method is {self.quant_cfg.quant_type}")

def _init_gguf(self):
if self.model_paths_.gguf_path is None:
return

from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader

reader = get_gguf_reader(self.model_paths_.gguf_path)
self.quant_cfg.gguf_quant_meta_map = reader.build_quant_meta_map(self.config)

def _init_weights(self, start_layer_index=0):
self.pre_post_weight = self.pre_and_post_weight_class(self.data_type, network_config=self.config)
self.trans_layers_weight = [
Expand All @@ -182,13 +194,24 @@ def _init_weights(self, start_layer_index=0):
return

def _load_hf_weights(self):
    """Load the model weights from exactly one source, then verify them.

    When ``self.model_paths_.gguf_path`` is set, the weights are loaded through
    the GGUF reader for that path. Otherwise ``load_hf_weights`` is called with
    ``self.weight_dir_`` and ``self.weight_dict``. After loading,
    ``verify_load()`` is invoked on the pre/post weight object and on every
    entry of ``self.trans_layers_weight``.
    """
    gguf_path = self.model_paths_.gguf_path
    if gguf_path is not None:
        # Function-scope import: only executed when a GGUF path is configured.
        from lightllm.common.basemodel.layer_weights.gguf_load_utils import get_gguf_reader

        reader = get_gguf_reader(gguf_path)
        reader.load_weights(
            data_type=self.data_type,
            config=self.config,
            pre_post_layer=self.pre_post_weight,
            transformer_layer_list=self.trans_layers_weight,
        )
    else:
        # The HF loader runs only on this branch; an unconditional call ahead of
        # the branch would load HF weights even for GGUF models and twice for
        # HF models.
        load_hf_weights(
            self.data_type,
            weight_dir=self.weight_dir_,
            pre_post_layer=self.pre_post_weight,
            transformer_layer_list=self.trans_layers_weight,
            weight_dict=self.weight_dict,
        )

    self.pre_post_weight.verify_load()
    # Plain loop: verify_load() is called for its side effect, so no throwaway
    # list is built.
    for layer_weight in self.trans_layers_weight:
        layer_weight.verify_load()
Expand Down
Loading