How to Add New Model Support

1. Introduction to the Current Inference Architecture

The lightllm/common/basemodel directory contains the base classes of the entire inference architecture:

├── basemodel.py   # Model framework class
├── infer_struct.py # Inference state class
├── __init__.py
├── layer_infer # Base class implementation of inference layers
│   ├── base_layer_infer.py
│   ├── __init__.py
│   ├── post_layer_infer.py
│   ├── pre_layer_infer.py
│   ├── template # Template implementation of inference layers, inheriting from templates can reduce development effort and duplicate code
│   │   ├── __init__.py
│   │   ├── post_layer_infer_template.py
│   │   ├── pre_layer_infer_template.py
│   │   └── transformer_layer_infer_template.py
│   └── transformer_layer_infer.py
├── layer_weights # Weight base class implementation
│   ├── base_layer_weight.py
│   ├── hf_load_utils.py
│   ├── __init__.py
│   ├── pre_and_post_layer_weight.py
│   └── transformer_layer_weight.py
└── triton_kernel # Some commonly used triton kernel operators
    ├── apply_penalty.py
    ├── destindex_copy_kv.py
    └── __init__.py

As shown above, the model inference architecture has two main parts: weights and inference.

Weights

The layer_weights directory holds the weight-related code. To add a new model, you normally subclass PreAndPostLayerWeight (in pre_and_post_layer_weight.py) and TransformerLayerWeight (in transformer_layer_weight.py) and implement weight loading in each.

| Weight base class | Responsibilities |
|---|---|
| PreAndPostLayerWeight | Loads the weights of the first Embedding layer and the final post-processing layer of the LLM, and splits them according to the tp (tensor parallel) setting in use |
| TransformerLayerWeight | Loads the weights of the transformer layers of the LLM, and splits them according to the tp setting in use |
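
As a rough illustration of what splitting weights by the tp setting means, the sketch below computes the contiguous range of a dimension that each rank keeps. The helper name tp_slice is hypothetical and not part of lightllm; it assumes the dimension divides evenly by the world size, which is what the Bloom code later in this document asserts for heads and vocabulary.

```python
from typing import Tuple


def tp_slice(total: int, tp_rank: int, world_size: int) -> Tuple[int, int]:
    """Return the [start, end) range of a dimension owned by one rank.

    Hypothetical helper for illustration; assumes an even split.
    """
    if total % world_size != 0:
        raise ValueError("dimension must divide evenly across ranks")
    shard = total // world_size
    return shard * tp_rank, shard * (tp_rank + 1)


# Example: 8 rows split over 4 ranks.
ranges = [tp_slice(8, rank, 4) for rank in range(4)]
print(ranges)  # [(0, 2), (2, 4), (4, 6), (6, 8)]
```

The same start and end arithmetic appears inline in the Bloom weight classes below, for example when slicing word_embeddings.weight by vocabulary rows.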

Inference

The layer_infer directory contains the base classes for inference, and the template subdirectory provides ready-made templates. Inheriting from a template class avoids duplicate code and simplifies the implementation. Three inference classes in this directory need to be subclassed and implemented.

| Inference base class | Responsibilities |
|---|---|
| PreLayerInfer | Inference for the Embedding layer |
| TransformerLayerInfer | Inference for the transformer layers |
| PostLayerInfer | Converts the final hidden-layer output of the network to logits |

BaseLayerInfer, the base class of the three classes above, exposes the two most important external interfaces. All inference enters through these two interfaces.

| Interface | Responsibilities |
|---|---|
| def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BaseLayerWeight): | First inference pass over a batch (also called prefill in the code) |
| def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BaseLayerWeight): | Single-step inference in the decode stage |
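
To make the split between the two entry points concrete, here is a minimal sketch with a toy layer. ToyLayerInfer, run_layer, and the dict used as a state object are assumptions made for illustration; how the real framework chooses between the two methods is not shown in this document.

```python
class ToyLayerInfer:
    """Stand-in for a BaseLayerInfer subclass (illustration only)."""

    def context_forward(self, input_ids, infer_state, layer_weight):
        # Prefill: all prompt tokens of the batch are processed at once.
        return [f"prefill:{tok}" for tok in input_ids]

    def token_forward(self, input_ids, infer_state, layer_weight):
        # Decode: one new token per request is processed per step.
        return [f"decode:{tok}" for tok in input_ids]


def run_layer(layer, input_ids, infer_state, layer_weight=None):
    """Pick the entry point from a hypothetical is_prefill flag."""
    if infer_state["is_prefill"]:
        return layer.context_forward(input_ids, infer_state, layer_weight)
    return layer.token_forward(input_ids, infer_state, layer_weight)


layer = ToyLayerInfer()
first = run_layer(layer, [11, 12, 13], {"is_prefill": True})
step = run_layer(layer, [14], {"is_prefill": False})
print(first, step)
```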

Operators

The triton_kernel directory contains operators needed for inference, implemented with OpenAI Triton.

State Class

The InferStateInfo class in infer_struct.py is a state class that carries important information between layers during a single model inference pass. A model can subclass it to add whatever extra state it needs to pass along. InferStateInfo provides an overridable init_some_extra_state interface for initializing that extra state.

    def init_some_extra_state(self, 
            model, 
            batch_size, 
            total_token_num,
            max_len_in_batch,
            input_ids : torch.Tensor,
            b_loc : torch.Tensor,
            b_start_loc : torch.Tensor,
            b_seq_len : torch.Tensor,
            is_prefill):
        pass
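
A minimal sketch of overriding this hook follows. The InferStateInfo class here is a stripped-down stand-in for the real one, ToyInferStateInfo and its position_ids attribute are hypothetical, and plain Python lists replace the torch tensors the real framework passes in.

```python
class InferStateInfo:
    """Minimal stand-in for the real class in infer_struct.py."""

    def init_some_extra_state(self, model, batch_size, total_token_num,
                              max_len_in_batch, input_ids, b_loc,
                              b_start_loc, b_seq_len, is_prefill):
        pass


class ToyInferStateInfo(InferStateInfo):
    """Hypothetical subclass that records per-token positions."""

    def init_some_extra_state(self, model, batch_size, total_token_num,
                              max_len_in_batch, input_ids, b_loc,
                              b_start_loc, b_seq_len, is_prefill):
        if is_prefill:
            # Every prompt token gets its position inside its own request.
            self.position_ids = [p for n in b_seq_len for p in range(n)]
        else:
            # Decode adds one token per request, at position seq_len - 1.
            self.position_ids = [n - 1 for n in b_seq_len]


state = ToyInferStateInfo()
state.init_some_extra_state(None, 2, 5, 3, [7, 8, 9, 4, 5], None,
                            [0, 3], [3, 2], True)
print(state.position_ids)  # [0, 1, 2, 0, 1]
```

A model class would then point its infer_state_class attribute at such a subclass, as described in the next section.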

Model Framework Class

The TpPartBaseModel class in basemodel.py is the entry point of the whole model, and every model type must subclass it. It assembles the inference, weight, and state classes like building blocks to load the model and run inference. Many of its interfaces can be overridden to implement model-specific behavior.

class TpPartBaseModel:
    # weight class
    pre_and_post_weight_class = None
    transformer_weight_class = None

    # infer class
    pre_layer_infer_class = None
    post_layer_infer_class = None
    transformer_layer_infer_class = None

    # infer state class
    infer_state_class = InferStateInfo

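    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
        # tp_rank and world_size match the arguments that subclasses such as
        # BloomTpPartModel pass to super().__init__() later in this document
        self.tp_rank_ = tp_rank
        self.world_size_ = world_size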
        self.weight_dir_ = weight_dir
        self.max_total_token_num = max_total_token_num
        self.load_way = load_way
        self.mode = mode

        self._init_config()
        self._verify_must()
        self._verify_params()
        self._init_weights()
        self._init_mem_manager()
        self._init_infer_layer()
        self._init_some_value()
        self._init_custom()
        return
    ...

Commonly overridden interfaces

| Interface | Function |
|---|---|
| def _init_config(self): | Reads the config.json used to initialize the model and normalizes some key names |
| def _verify_params(self): | Validates parameters |
| def _init_mem_manager(self): | Initializes the mem manager object used by token attention |
| def _init_some_value(self): | Initializes member variables that the inference framework will use |
| def _init_custom(self): | Model-specific initialization, such as llama initializing its own Rotary values |
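
The constructor shown above calls these hooks in a fixed order, so a subclass customizes behavior by overriding individual hooks and not by rewriting __init__. The sketch below shows that pattern with a toy base class that only records the call order; ToyBaseModel and ToyRotaryModel are hypothetical stand-ins, and only the hooks listed in the table are modeled.

```python
class ToyBaseModel:
    """Stand-in for TpPartBaseModel that only records hook order."""

    def __init__(self):
        self.calls = []
        self._init_config()
        self._verify_params()
        self._init_mem_manager()
        self._init_some_value()
        self._init_custom()

    def _init_config(self):
        self.calls.append("config")

    def _verify_params(self):
        self.calls.append("params")

    def _init_mem_manager(self):
        self.calls.append("mem_manager")

    def _init_some_value(self):
        self.calls.append("some_value")

    def _init_custom(self):
        self.calls.append("custom")


class ToyRotaryModel(ToyBaseModel):
    """Overrides a single hook, leaving the rest of the pipeline alone."""

    def _init_custom(self):
        super()._init_custom()
        # A llama-like model would build its rotary tables at this point.
        self.calls.append("rotary tables")


model = ToyRotaryModel()
print(model.calls)
```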

2. Example: Adding the Bloom Model

The implementation lives in the lightllm/models/bloom directory; read the code snippets below alongside the source. The triton_kernel directory there contains kernels used by the inference classes and is not covered in detail in this article. Bloom uses the default state class because it does not need to pass any special state information. For a deeper understanding of the framework, see the llama and llama2 integrations in the source code.

(1) Implement the Weight Classes

pre_and_post_layer_weight.py

import torch
import numpy as np
from lightllm.common.basemodel import PreAndPostLayerWeight

class BloomPreAndPostLayerWeight(PreAndPostLayerWeight):
    def __init__(self, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(tp_rank, world_size, data_type, network_config, mode)

    def load_hf_weights(self, weights):

        if "word_embeddings_layernorm.weight" in weights:
            self.pre_norm_weight_ = self._cuda(weights['word_embeddings_layernorm.weight'])
        if "word_embeddings_layernorm.bias" in weights:
            self.pre_norm_bias_ = self._cuda(weights['word_embeddings_layernorm.bias'])
        if "ln_f.weight" in weights:
            self.final_norm_weight_ = self._cuda(weights['ln_f.weight'])
        if "ln_f.bias" in weights:
            self.final_norm_bias_ = self._cuda(weights["ln_f.bias"])
        if "word_embeddings.weight" in weights:
            vob_size = self.network_config_["vocab_size"]
            split_vob_size = vob_size // self.tp_world_size_
            self.wte_weight_ = self._cuda(weights["word_embeddings.weight"][split_vob_size *
                                                                 self.tp_rank_: split_vob_size * (self.tp_rank_ + 1), :])
            self.lm_head_weight_ = self.wte_weight_
        return

transformer_layer_weight.py

import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight


class BloomTransformerLayerWeight(TransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        return

    def init_static_params(self):
        head_num = self.network_config_["num_attention_heads"]
        # the head count must divide evenly across tensor-parallel ranks
        assert head_num % self.tp_world_size_ == 0
        tp_head_num = head_num // self.tp_world_size_
        tmp_alibi = self._generate_alibi(head_num, dtype=torch.float32)
        # keep only the ALiBi slopes of the heads owned by this rank
        self.tp_alibi = tmp_alibi[self.tp_rank_ * tp_head_num: (self.tp_rank_ + 1) * tp_head_num].contiguous().cuda()
        return
    
    def load_hf_weights(self, weights):

        self._load_qkvo_weights(weights)
        self._load_ffn_weights(weights)
        return

    def _load_qkvo_weights(self, weights):
        if f"h.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"h.{self.layer_num_}.input_layernorm.weight"])
        if f"h.{self.layer_num_}.input_layernorm.bias" in weights:
            self.att_norm_bias_ = self._cuda(weights[f"h.{self.layer_num_}.input_layernorm.bias"])

        if f"h.{self.layer_num_}.self_attention.query_key_value.weight" in weights:
            n_embed = self.network_config_["n_embed"]
            split_n_embed = n_embed // self.tp_world_size_
            head_num = self.network_config_["num_attention_heads"]
            att_qkv_dense_weight = weights[f"h.{self.layer_num_}.self_attention.query_key_value.weight"].reshape(head_num, 3, -1, n_embed)
            self.q_weight_ = self._cuda(att_qkv_dense_weight[:,
                                                  0,
                                                  :,
                                                  :].reshape(-1,
                                                             n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1),
                                                                      :].transpose(0,
                                                                                   1))
            self.k_weight_ = self._cuda(att_qkv_dense_weight[:,
                                                  1,
                                                  :,
                                                  :].reshape(-1,
                                                             n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1),
                                                                      :].transpose(0,
                                                                                   1))
            self.v_weight_ = self._cuda(att_qkv_dense_weight[:,
                                                  2,
                                                  :,
                                                  :].reshape(-1,
                                                             n_embed)[split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1),
                                                                      :].transpose(0,
                                                                                   1))
        if f"h.{self.layer_num_}.self_attention.query_key_value.bias" in weights:
            n_embed = self.network_config_["n_embed"]
            split_n_embed = n_embed // self.tp_world_size_
            head_num = self.network_config_["num_attention_heads"]
            att_qkv_dense_bias = weights[f"h.{self.layer_num_}.self_attention.query_key_value.bias"].reshape(head_num, 3, -1)
            self.q_bias_ = self._cuda(att_qkv_dense_bias[:, 0, :].reshape(-1)[split_n_embed *
                                                                   self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)])
            self.k_bias_ = self._cuda(att_qkv_dense_bias[:, 1, :].reshape(-1)[split_n_embed *
                                                                   self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)])
            self.v_bias_ = self._cuda(att_qkv_dense_bias[:, 2, :].reshape(-1)[split_n_embed *
                                                                   self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)])

        if f"h.{self.layer_num_}.self_attention.dense.weight" in weights:
            n_embed = self.network_config_["n_embed"]
            split_n_embed = n_embed // self.tp_world_size_
            self.o_weight_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.weight"][:,
                                                                                                     split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)].transpose(0, 1))
        if f"h.{self.layer_num_}.self_attention.dense.bias" in weights:
            self.o_bias_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"])
        return 

    def _load_ffn_weights(self, weights):
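        if f"h.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(weights[f"h.{self.layer_num_}.post_attention_layernorm.weight"])
        # guard the bias separately, like every other parameter in this class
        if f"h.{self.layer_num_}.post_attention_layernorm.bias" in weights:
            self.ffn_norm_bias_ = self._cuda(weights[f"h.{self.layer_num_}.post_attention_layernorm.bias"])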

        # ffn params
        if f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight" in weights:
            n_embed = self.network_config_["n_embed"] * 4
            split_n_embed = n_embed // self.tp_world_size_
            self.ffn_1_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight"]
            self.ffn_1_weight_ = self._cuda(self.ffn_1_weight_[split_n_embed * self.tp_rank_: split_n_embed *
                                                    (self.tp_rank_ + 1), :].transpose(0, 1))

        if f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias" in weights:
            n_embed = self.network_config_["n_embed"] * 4
            split_n_embed = n_embed // self.tp_world_size_
            self.ffn_1_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias"][split_n_embed *
                                                                                      self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)])
        if f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight" in weights:
            n_embed = self.network_config_["n_embed"] * 4
            split_n_embed = n_embed // self.tp_world_size_
            self.ffn_2_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight"]
            self.ffn_2_weight_ = self._cuda(self.ffn_2_weight_[:, split_n_embed *
                                                    self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)].transpose(0, 1))

        if f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias" in weights:
            self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"])
        return 

    def _generate_alibi(self, n_head, dtype=torch.float16):
        def get_slopes(n):
            def get_slopes_power_of_2(n):
                start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                ratio = start
                return [start * ratio**i for i in range(n)]

            if math.log2(n).is_integer():
                return get_slopes_power_of_2(n)
            else:
                closest_power_of_2 = 2 ** math.floor(math.log2(n))
                return (
                    get_slopes_power_of_2(closest_power_of_2)
                    + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
                )

        slopes = torch.Tensor(get_slopes(n_head))
        head_alibi = slopes.to(dtype)
        return head_alibi
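
To see what _generate_alibi produces, the following sketch mirrors its slope computation in plain Python (no torch) and then takes the slice one rank would keep, as init_static_params does. The function name alibi_slopes and the chosen head count and world size are illustrative only.

```python
import math


def alibi_slopes(n_head):
    """Per-head ALiBi slopes, mirroring _generate_alibi above (no torch)."""

    def power_of_2_slopes(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n_head).is_integer():
        return power_of_2_slopes(n_head)
    # Otherwise: slopes of the closest lower power of two, plus every
    # second slope of the next power of two to fill the remaining heads.
    closest = 2 ** math.floor(math.log2(n_head))
    extra = alibi_slopes(2 * closest)[0::2][: n_head - closest]
    return power_of_2_slopes(closest) + extra


slopes = alibi_slopes(8)
print(slopes[:3])  # [0.5, 0.25, 0.125]

# With 2 tensor-parallel ranks, rank 1 keeps the slopes of heads 4..7.
tp_rank, world_size = 1, 2
per_rank = len(slopes) // world_size
rank_slopes = slopes[tp_rank * per_rank:(tp_rank + 1) * per_rank]
print(rank_slopes)
```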

(2) Implement the Inference Classes

pre_layer_infer.py

import torch
import torch.distributed as dist
from lightllm.common.basemodel import PreLayerInferTpl
from lightllm.common.basemodel import InferStateInfo
from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward

class BloomPreLayerInfer(PreLayerInferTpl):
    """
    Embedding lookup over this rank's vocabulary shard, followed by the embedding LayerNorm.
    """
    def __init__(self, tp_rank, world_size, network_config, mode):
        super().__init__(tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["layer_norm_epsilon"]
        tp_vocab_size_ = network_config["vocab_size"] // self.tp_world_size_
        self.vob_start_id_ = tp_vocab_size_ * self.tp_rank_
        self.vob_end_id_ = tp_vocab_size_ * (self.tp_rank_ + 1)
        return
    
    def _norm(self, input, infer_state, layer_weight : BloomPreAndPostLayerWeight) -> torch.Tensor:
        return layernorm_forward(input, layer_weight.pre_norm_weight_, layer_weight.pre_norm_bias_, eps=self.eps_)

    def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight):
        total_token_num = infer_state.total_token_num
        input_ids = input_ids[0:total_token_num]

        input_mask = torch.logical_or(self.vob_start_id_ > input_ids, input_ids >= self.vob_end_id_)
        tmp_input_ids = (input_ids - self.vob_start_id_)
        tmp_input_ids[input_mask] = 0
        input_embdings = torch.embedding(layer_weight.wte_weight_, tmp_input_ids, padding_idx=-1)
        input_embdings[input_mask] = 0.0
        if self.tp_world_size_ > 1:
            dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings = self._norm(input_embdings, infer_state, layer_weight)
        return input_embdings

    def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight):
        input_mask = torch.logical_or(self.vob_start_id_ > input_ids, input_ids >= self.vob_end_id_)
        tmp_input_ids = (input_ids - self.vob_start_id_)
        tmp_input_ids[input_mask] = 0
        input_embdings = torch.embedding(layer_weight.wte_weight_, tmp_input_ids, padding_idx=-1)
        input_embdings[input_mask] = 0.0
        if self.tp_world_size_ > 1:
            dist.all_reduce(input_embdings, op=dist.ReduceOp.SUM, async_op=False)
        input_embdings = self._norm(input_embdings, infer_state, layer_weight)
        return input_embdings
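
Both forward methods rely on the same trick: each rank holds only a slice of the vocabulary, zeroes the rows for token ids outside its slice, and a summing all-reduce then recovers the full embedding. The sketch below reproduces that logic with plain lists; rank_embedding and the toy table are illustrative, and the element-wise sum stands in for dist.all_reduce.

```python
def rank_embedding(input_ids, shard, vob_start, vob_end):
    """One rank's partial lookup: zeros for ids outside its vocab shard."""
    dim = len(shard[0])
    out = []
    for tok in input_ids:
        if vob_start <= tok < vob_end:
            out.append(list(shard[tok - vob_start]))
        else:
            out.append([0.0] * dim)
    return out


# Toy table: vocab of 4 tokens, embedding dim 2, split over 2 ranks.
table = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
input_ids = [3, 0, 2]
partials = [
    rank_embedding(input_ids, table[0:2], 0, 2),  # rank 0 owns ids 0..1
    rank_embedding(input_ids, table[2:4], 2, 4),  # rank 1 owns ids 2..3
]

# Element-wise sum across ranks stands in for dist.all_reduce(SUM).
summed = [
    [a + b for a, b in zip(row0, row1)]
    for row0, row1 in zip(*partials)
]
print(summed)  # [[7.0, 8.0], [1.0, 2.0], [5.0, 6.0]]
```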

transformer_layer_infer.py

import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from lightllm.common.basemodel import TransformerLayerInferTpl
from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight
from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd
from lightllm.models.bloom.triton_kernel.token_flashattention_nopad import token_attention_fwd
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward
from lightllm.common.basemodel import InferStateInfo
from lightllm.utils.infer_utils import mark_cost_time

class BloomTransformerLayerInfer(TransformerLayerInferTpl):
    """
    Bloom transformer layer: LayerNorm, ALiBi attention, and a GELU feed-forward block.
    """

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        self.eps_ = network_config["layer_norm_epsilon"]
        self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_
        self.tp_k_head_num_ = self.tp_q_head_num_
        self.tp_v_head_num_ = self.tp_q_head_num_
        self.tp_o_head_num_ = self.tp_q_head_num_
        self.head_dim_ = network_config["n_embed"] // network_config["num_attention_heads"]
        self.embed_dim_ = network_config["n_embed"]
        return
    
    def _att_norm(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        return layernorm_forward(
            input.view(-1, self.embed_dim_),
            weight=layer_weight.att_norm_weight_,
            bias=layer_weight.att_norm_bias_,
            eps=self.eps_)
    
    def _ffn_norm(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        return layernorm_forward(
            input.view(-1, self.embed_dim_),
            weight=layer_weight.ffn_norm_weight_,
            bias=layer_weight.ffn_norm_bias_,
            eps=self.eps_)
    
    def _get_qkv(self, input, cache_k, cache_v, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0)
        torch.addmm(layer_weight.k_bias_, input.view(-1, self.embed_dim_), layer_weight.k_weight_, beta=1.0,
                    alpha=1.0, out=cache_k.view(-1, self.tp_k_head_num_ * self.head_dim_))
        torch.addmm(layer_weight.v_bias_, input.view(-1, self.embed_dim_), layer_weight.v_weight_, beta=1.0,
                    alpha=1.0, out=cache_v.view(-1, self.tp_v_head_num_ * self.head_dim_))
        return q
    
    def _context_attention_kernel(self, q, kv, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        o_tensor = torch.empty_like(q)
        context_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
                              kv[:, 0: self.tp_k_head_num_, :],
                              kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
                              o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                              layer_weight.tp_alibi,
                              infer_state.b_start_loc,
                              infer_state.b_seq_len,
                              infer_state.max_len_in_batch)
        return o_tensor
    
    def _token_attention_kernel(self, q, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        o_tensor = torch.empty_like(q)
        token_attention_fwd(q.view(-1, self.tp_q_head_num_, self.head_dim_),
                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0: self.tp_k_head_num_, :],
                            infer_state.mem_manager.kv_buffer[self.layer_num_][:, self.tp_k_head_num_: self.tp_k_head_num_+ self.tp_v_head_num_, :],
                            o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
                            layer_weight.tp_alibi,
                            infer_state.b_loc,
                            infer_state.b_start_loc,
                            infer_state.b_seq_len,
                            infer_state.max_len_in_batch)
        return o_tensor
    
    def _get_o(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        o = torch.addmm(layer_weight.o_bias_,
                        input.view(-1, self.tp_q_head_num_ * self.head_dim_),
                        layer_weight.o_weight_,
                        beta=1.0 / self.tp_world_size_)
        return o
    
    def _ffn(self, input, infer_state:InferStateInfo, layer_weight: BloomTransformerLayerWeight)->torch.Tensor:
        ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_)
        input = None
        gelu_out = torch.nn.functional.gelu(ffn1_out, approximate='tanh')
        ffn1_out = None
        ffn2_out = torch.addmm(layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.tp_world_size_)
        gelu_out = None
        return ffn2_out
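
Note the beta=1.0 / self.tp_world_size_ argument in _get_o and _ffn. Each rank multiplies only its slice of the input features by the matching slice of the weight, so the per-rank outputs are partial sums. Assuming the framework sums those outputs across ranks afterwards (that all-reduce is not shown in this document), scaling the bias by 1 / world_size makes it count exactly once in the total. A plain-Python sketch with made-up numbers:

```python
def partial_linear(x_part, w_part, bias, world_size):
    """One rank's output: (bias / world_size) + x_part @ w_part."""
    out = [b / world_size for b in bias]
    for x_i, w_row in zip(x_part, w_part):
        for j, w_ij in enumerate(w_row):
            out[j] += x_i * w_ij
    return out


# Full problem: y = bias + x @ w, with x of length 4 and w of shape (4, 2).
x = [1.0, 2.0, 3.0, 4.0]
w = [[1.0, 0.0], [0.0, 1.0], [2.0, 0.0], [0.0, 2.0]]
bias = [4.0, 8.0]
world_size = 2

# Each rank owns two input features and the matching two rows of w.
parts = [
    partial_linear(x[0:2], w[0:2], bias, world_size),
    partial_linear(x[2:4], w[2:4], bias, world_size),
]

# Summing the partial outputs stands in for the all-reduce.
y = [sum(vals) for vals in zip(*parts)]
print(y)  # [11.0, 18.0]
```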

post_layer_infer.py

import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight
from einops import rearrange
from lightllm.common.basemodel import InferStateInfo, PostLayerInferTpl
from lightllm.models.bloom.triton_kernel.layernorm import layernorm_forward


class BloomPostLayerInfer(PostLayerInferTpl):
    """
    Final LayerNorm and LM head; returns probabilities or logits for the last token of each request.
    """

    def __init__(self, tp_rank, world_size, network_config, mode):
        super().__init__(tp_rank, world_size, network_config, mode)
        assert (network_config["vocab_size"] % self.tp_world_size_ == 0)
        self.eps_ = network_config["layer_norm_epsilon"]
        self.vocab_size_ = network_config["vocab_size"]
        self.embed_dim_ = network_config["n_embed"]
        return
    
    def _norm(self, input, infer_state, layer_weight : BloomPreAndPostLayerWeight) -> torch.Tensor:
        return layernorm_forward(input, layer_weight.final_norm_weight_, layer_weight.final_norm_bias_, eps=self.eps_)

    def soft_max(self, data):
        return torch.softmax(data.permute(1, 0).float(), dim=-1)

    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight, return_logics=False):
        batch_size = infer_state.batch_size
        last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
        if infer_state.is_prefill:
            last_index = torch.cumsum(infer_state.b_seq_len, dim=0, dtype=torch.long) - 1
            last_input[:, :] = input_embdings[last_index, :]
        else:
            last_input[:, :] = input_embdings[-batch_size:, :]

        input_embdings = None
        last_input = self._norm(last_input, infer_state, layer_weight)
        last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, batch_size)
        logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input)
        last_input = None
        if self.tp_world_size_ == 1:
            gather_data = logic_batch
        else:
            gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=torch.float16)
            split_size = self.vocab_size_ // self.tp_world_size_
            dist.all_gather([gather_data[i * split_size: (i + 1) * split_size, :]
                            for i in range(self.tp_world_size_)], logic_batch, group=None, async_op=False)
        logic_batch = None

        if not return_logics:
            prob_out = self.soft_max(gather_data)
            gather_data = None
            return prob_out
        else:
            ans_logics = gather_data.permute(1, 0).float()
            gather_data = None
            return ans_logics
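
In the prefill branch, the tokens of all requests are packed back to back, so the last token of each request sits at the running total of sequence lengths minus one. The sketch below shows that index computation with itertools.accumulate in place of torch.cumsum; the sequence lengths are made up.

```python
from itertools import accumulate

# Three requests with prompt lengths 3, 1 and 4, packed back to back.
b_seq_len = [3, 1, 4]

# Running total minus one = index of each request's final token.
last_index = [end - 1 for end in accumulate(b_seq_len)]
print(last_index)  # [2, 3, 7]

# Gather one hidden state per request, as input_embdings[last_index, :] does.
hidden = [f"h{i}" for i in range(sum(b_seq_len))]
last_hidden = [hidden[i] for i in last_index]
print(last_hidden)  # ['h2', 'h3', 'h7']
```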

(3) Implement the Model Class

model.py

import os
import json
from lightllm.models.bloom.layer_infer.pre_layer_infer import BloomPreLayerInfer
from lightllm.models.bloom.layer_infer.post_layer_infer import BloomPostLayerInfer
from lightllm.models.bloom.layer_infer.transformer_layer_infer import BloomTransformerLayerInfer
from lightllm.models.bloom.layer_weights.pre_and_post_layer_weight import BloomPreAndPostLayerWeight
from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight
from lightllm.common.basemodel import InferStateInfo, TpPartBaseModel

from lightllm.common.build_utils import repair_config

class BloomTpPartModel(TpPartBaseModel):
    # weight class
    pre_and_post_weight_class = BloomPreAndPostLayerWeight
    transformer_weight_class = BloomTransformerLayerWeight

    # infer class
    pre_layer_infer_class = BloomPreLayerInfer
    post_layer_infer_class = BloomPostLayerInfer
    transformer_layer_infer_class = BloomTransformerLayerInfer

    # infer state class
    infer_state_class = InferStateInfo

    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
        return

    def _init_config(self):
        super()._init_config()
        # rename key
        # repair_config()
        return 

(4) Add Support for the Model in the Server Service Layer

lightllm/server/router/model_infer/model_rpc.py

import asyncio
import rpyc
import torch
import traceback
from datetime import timedelta
from typing import Dict, List, Tuple
from transformers.configuration_utils import PretrainedConfig
from lightllm.server.router.model_infer.infer_batch import InferBatch
from rpyc.utils.classic import obtain

from lightllm.models.bloom.model import BloomTpPartModel
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.common.configs.config import setting
from .post_process import sample

class ModelRpcServer(rpyc.Service):

    def exposed_init_model(self, rank_id, world_size, weight_dir, max_total_token_num, load_way, mode):
        import torch
        import torch.distributed as dist
        if world_size != 1:
            trans_list = [obtain(e) for e in (rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)]
            rank_id, world_size, weight_dir, max_total_token_num, load_way, mode = trans_list

        self.tp_rank = rank_id
        self.tp_world_size_ = world_size
        self.load_way = load_way
        self.mode = mode
        self.cache = {}

        dist.init_process_group('nccl', init_method=f'tcp://127.0.0.1:{setting["nccl_port"]}', rank=rank_id, world_size=world_size)
        torch.cuda.set_device(rank_id)

        model_cfg, _ = PretrainedConfig.get_config_dict(
            weight_dir
        )
        try:
            self.model_type = model_cfg["model_type"]
            if self.model_type == "bloom":
                self.model = BloomTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
            else:
                # reject model types that have no registered implementation
                raise Exception(f"can not support {self.model_type} now")
        except Exception as e:
            print("#" * 16)
            print("load model error:", str(e), e, type(e))
            raise e
        
        set_random_seed(2147483647)
        return
    ...
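
As more models are added, the branch on model_type grows. One possible alternative, offered as a design sketch and not as lightllm's actual code, is a dictionary from model_type to model class. ToyBloomModel, MODEL_REGISTRY, and build_model are hypothetical names, and the constructor is reduced to a single argument for brevity.

```python
class ToyBloomModel:
    """Placeholder for BloomTpPartModel (illustration only)."""

    def __init__(self, weight_dir):
        self.weight_dir = weight_dir


# Hypothetical registry: model_type from config.json -> model class.
MODEL_REGISTRY = {"bloom": ToyBloomModel}


def build_model(model_type, weight_dir):
    """Look up the model class and construct it, or fail clearly."""
    try:
        model_cls = MODEL_REGISTRY[model_type]
    except KeyError:
        raise ValueError(f"can not support {model_type} now") from None
    return model_cls(weight_dir)


model = build_model("bloom", "/path/to/weights")
print(type(model).__name__)  # ToyBloomModel
```

With this layout, supporting a new model type means adding one registry entry and leaving the dispatch function untouched.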