From c5ee00436b61b9ee54a933a46718edb3f1b896ce Mon Sep 17 00:00:00 2001
From: rubik
Date: Tue, 12 May 2026 16:43:42 +0800
Subject: [PATCH] issue/349 - Support GLM4 model
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Refactored according to the review comments in
https://github.com/InfiniTensor/InfiniLM/pull/352 and the review comments
from this PR's committer. Points of change:

1. A new model should not modify existing model code; do not change the
   code in the llama_legacy files.
2. Drop the changes to config_factory.cpp and rank_worker.cpp.
3. Follow the existing implementations (outside the llama_legacy folder);
   the mlp, model and causal_lm should be able to reuse existing modules.
4. Add the following files under glm4: glm4_decoder_layer.cpp/hpp and
   glm4_for_causal_lm.cpp/hpp.
5. In csrc/models/glm4/glm4_for_causal_lm.cpp, define a dedicated
   Glm4ForCausalLM class instead of using
   infinilm::models::llama::LlamaForCausalLM.
6. RoPE handling: extend the get_rope function in
   csrc/layers/rotary_embedding/rotary_embedding.cpp to handle the GPT_J
   algorithm and the "partial_rotary_factor" hyperparameter.
7. Design the weights remap with a dictionary pattern.
8. Support both attention_static and attention_paged.
9. Replace Chinese comments with English comments.
10. using Glm4ForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM;
    reuse existing modules.
11. Remove the reset_cache overload.
12. Verify TP (tensor-parallel) execution.
---
 .../rotary_embedding/rotary_embedding.cpp |  43 +++-
 .../rotary_embedding/rotary_embedding.hpp |  10 +-
 csrc/models/glm4/glm4_attention.cpp       | 200 ++++++++++++++++++
 csrc/models/glm4/glm4_attention.hpp       |  50 +++++
 csrc/models/glm4/glm4_decoder_layer.cpp   |  75 +++++++
 csrc/models/glm4/glm4_decoder_layer.hpp   |  38 ++++
 csrc/models/glm4/glm4_for_causal_lm.cpp   |  33 +++
 csrc/models/glm4/glm4_for_causal_lm.hpp   |  16 ++
 python/infinilm/modeling_utils.py         | 104 ++++++++-
 9 files changed, 560 insertions(+), 9 deletions(-)
 create mode 100644 csrc/models/glm4/glm4_attention.cpp
 create mode 100644 csrc/models/glm4/glm4_attention.hpp
 create mode 100644 csrc/models/glm4/glm4_decoder_layer.cpp
 create mode 100644 csrc/models/glm4/glm4_decoder_layer.hpp
 create mode 100644 csrc/models/glm4/glm4_for_causal_lm.cpp
 create mode 100644 csrc/models/glm4/glm4_for_causal_lm.hpp

diff --git a/csrc/layers/rotary_embedding/rotary_embedding.cpp b/csrc/layers/rotary_embedding/rotary_embedding.cpp
index 39254cb6..07c8eb3f 100644
--- a/csrc/layers/rotary_embedding/rotary_embedding.cpp
+++ b/csrc/layers/rotary_embedding/rotary_embedding.cpp
@@ -1,17 +1,48 @@
 #include "rotary_embedding.hpp"
 #include <memory>
 #include <unordered_map>
+#include <cmath>     // std::llround
+#include <algorithm> // std::clamp
 
 namespace infinilm::layers::rotary_embedding {
 
 namespace {
 thread_local std::unordered_map<std::string, std::shared_ptr<infinicore::nn::RoPE>> _ROPE_DICT;
-
 } // namespace
 
+size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor) {
+    if (partial_rotary_factor <= 0.0 || partial_rotary_factor >= 1.0) {
+        return head_dim;
+    }
+
+    size_t rotary_dim = static_cast<size_t>(std::llround(
+        static_cast<double>(head_dim) * partial_rotary_factor));
+    rotary_dim = std::clamp(rotary_dim, static_cast<size_t>(2), head_dim);
+
+    // RoPE operates on complex pairs, so the rotary dimension must be even
+    if (rotary_dim % 2 != 0) {
+        rotary_dim -= 1;
+    }
+    return std::max(rotary_dim, static_cast<size_t>(2));
+}
+
 std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr &model_config,
-                                               const infinicore::Device &device) {
+                                               const infinicore::Device &device,
+                                               infinicore::nn::RoPE::Algo algo) {
+    // 1. Get head dimension
+    size_t head_dim = model_config->get_head_dim();
+
+    // 2. Safely get partial_rotary_factor, defaulting to 1.0 (full rotation)
+    double partial_rotary_factor = model_config->get_or("partial_rotary_factor", 1.0);
+
+    // 3. Compute the actual rotary dimension
+    size_t rotary_dim = get_rotary_dim(head_dim, partial_rotary_factor);
+
+    // 4. Cache key must include rotary_dim to avoid reusing the same RoPE instance
+    //    across models with different partial_rotary_factor values
     const std::string scaling_type = "default";
-    auto it = _ROPE_DICT.find(scaling_type);
+    std::string cache_key = scaling_type + "_rope_dim_" + std::to_string(rotary_dim);
+    auto it = _ROPE_DICT.find(cache_key);
     if (it != _ROPE_DICT.end()) {
         return it->second;
     }
@@ -19,11 +50,11 @@ std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr &model_con
     auto dtype = model_config->get_dtype();
     size_t max_position_embeddings = model_config->get("max_position_embeddings");
     double rope_theta = model_config->get("rope_theta");
-    auto rope = std::make_shared<infinicore::nn::RoPE>(model_config->get_head_dim(), max_position_embeddings, rope_theta,
-                                                       infinicore::nn::RoPE::Algo::GPT_NEOX, dtype, device,
+    auto rope = std::make_shared<infinicore::nn::RoPE>(rotary_dim, max_position_embeddings, rope_theta,
+                                                       algo, dtype, device,
                                                        model_config->get_rope_scaling());
-    _ROPE_DICT.emplace(scaling_type, rope);
+    _ROPE_DICT.emplace(cache_key, rope);
 
     return rope;
 }
diff --git a/csrc/layers/rotary_embedding/rotary_embedding.hpp b/csrc/layers/rotary_embedding/rotary_embedding.hpp
index 2d07b346..d6d7852e 100644
--- a/csrc/layers/rotary_embedding/rotary_embedding.hpp
+++ b/csrc/layers/rotary_embedding/rotary_embedding.hpp
@@ -1,11 +1,17 @@
 #pragma once
 
-#include "../../config/model_config.hpp"
 #include <memory>
+#include "infinicore/nn/rope.hpp"
+#include "../../config/model_config.hpp"
 
 namespace infinilm::layers::rotary_embedding {
 
+// Compute the actual number of dimensions involved in rotary position embedding.
+// For partial rotation, the dimension is clamped to [2, head_dim] and must be even.
+size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor);
+
 std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr &model_config,
-                                               const infinicore::Device &device);
+                                               const infinicore::Device &device,
+                                               infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX);
 
 } // namespace infinilm::layers::rotary_embedding
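For reviewers, a quick illustration of the partial-rotary logic introduced above. This is not part of the diff; it is a minimal Python sketch that mirrors get_rotary_dim under the same rounding, clamping and evenness rules (GLM4 additionally selects the GPT_J interleaved pairing through the new algo parameter):

def get_rotary_dim(head_dim: int, partial_rotary_factor: float = 1.0) -> int:
    # Factors outside (0, 1) mean full rotation.
    if partial_rotary_factor <= 0.0 or partial_rotary_factor >= 1.0:
        return head_dim
    rotary_dim = round(head_dim * partial_rotary_factor)
    rotary_dim = max(2, min(rotary_dim, head_dim))
    # RoPE rotates complex pairs, so the dimension must stay even.
    if rotary_dim % 2 != 0:
        rotary_dim -= 1
    return max(rotary_dim, 2)

print(get_rotary_dim(128, 0.5))   # 64: half of each head is rotated, half passes through
print(get_rotary_dim(128, 1.0))   # 128: full rotation
print(get_rotary_dim(96, 0.34))   # 32: 32.64 rounds to 33, then drops to the nearest even value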
diff --git a/csrc/models/glm4/glm4_attention.cpp b/csrc/models/glm4/glm4_attention.cpp
new file mode 100644
index 00000000..cf3854cc
--- /dev/null
+++ b/csrc/models/glm4/glm4_attention.cpp
@@ -0,0 +1,200 @@
+#include "glm4_attention.hpp"
+#include "../../global_state/global_state.hpp"
+#include "../../layers/rotary_embedding/rotary_embedding.hpp"
+#include "../../utils.hpp"
+#include "infinicore/nn/linear.hpp"
+#include "infinicore/nn/rope.hpp"
+#include "infinicore/ops.hpp"
+
+#include <cmath>
+#include <memory>
+#include <stdexcept>
+
+namespace infinilm::models::glm4 {
+
+Glm4Attention::Glm4Attention(std::shared_ptr model_config,
+                             size_t layer_idx,
+                             const infinicore::Device &device)
+    : model_config_(model_config),
+      layer_idx_(layer_idx),
+      hidden_size_(model_config->get("hidden_size")),
+      num_attention_heads_(model_config->get("num_attention_heads")),
+      num_key_value_heads_(model_config->get("num_key_value_heads")),
+      head_dim_(model_config->get_head_dim()),
+      rotary_dim_(infinilm::layers::rotary_embedding::get_rotary_dim(model_config->get_head_dim(), model_config->get_or("partial_rotary_factor", 1.0))),
+      use_bias_(model_config->get_or("attention_bias", true)),
+      use_output_bias_(model_config->get_or("attention_output_bias", false)) {
+
+    const auto &dtype{model_config_->get_dtype()};
+
+    const engine::distributed::RankInfo &g_rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
+    int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
+    int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();
+
+    if ((num_key_value_heads_ >= tp_size) && (0 == (num_key_value_heads_ % tp_size))) {
+        num_attention_heads_ /= tp_size;
+        num_key_value_heads_ /= tp_size;
+    } else {
+        throw std::runtime_error("Glm4Attention: num_key_value_heads must be divisible by tp_size");
+    }
+    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));
+
+    // Linear layer initialization
+    auto quant_scheme = this->model_config_->get_quant_scheme();
+    switch (quant_scheme) {
+    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
+        INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
+        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
+        break;
+    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
+        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
+        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
+        break;
+    }
+    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
+        INFINILM_QKV_LINEAR_W4A16GPTQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
+        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
+        break;
+    }
+    default:
+        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get("num_attention_heads"), model_config_->get("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
+        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
+        break;
+    }
+
+    // RoPE initialization
+    attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;
+    rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device, infinicore::nn::RoPE::Algo::GPT_J);
+
+    attn_ = std::make_shared<infinilm::layers::attention::AttentionLayer>(
+        num_attention_heads_, head_dim_, scaling_,
+        num_key_value_heads_, layer_idx_,
+        kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);
+
+    // KV Cache quantization scale initialization
+    auto kv_quant_scheme = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme();
+    if (kv_quant_scheme == infinicore::quantization::KVQuantAlgo::INT8) {
+        INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+        INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+    }
+}
+
+infinicore::Tensor Glm4Attention::forward(const infinicore::Tensor &positions,
+                                          infinicore::Tensor &hidden_states) {
+    if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) {
+        return forward_static_(positions, hidden_states);
+    }
+    return forward_paged_(positions, hidden_states);
+}
+
+infinicore::Tensor Glm4Attention::forward_paged_(const infinicore::Tensor &positions,
+                                                 infinicore::Tensor &hidden_states) {
+    if (!rotary_emb_) {
+        throw std::runtime_error("Glm4Attention: rotary_emb not configured");
+    }
+
+    auto shape = hidden_states->shape();
+    size_t batch_size = shape[0];
+    size_t seq_len = shape[1];
+
+    // Paged mode strictly requires batch_size == 1 due to continuous batching scheduler flattening
+    if (batch_size != 1) {
+        throw std::runtime_error("Glm4Attention::forward_paged_ expects batch_size == 1");
+    }
+
+    // 1. QKV Projection
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states);
+
+    // Reshape to 3D [sl, nh, hd] - Native layout required by Paged KV Cache update
+    // DO NOT transpose K and V, otherwise do_kv_cache_update will cause a segmentation fault!
+    q = q->view({seq_len, num_attention_heads_, head_dim_});
+    k = k->view({seq_len, num_key_value_heads_, head_dim_});
+    v = v->view({seq_len, num_key_value_heads_, head_dim_});
+
+    // 2. Position IDs
+    infinicore::Tensor pos_ids_for_rope;
+    if (positions->shape().size() == 2) {
+        // Squeeze the batch dimension directly to 1D [seq_len]
+        pos_ids_for_rope = positions->narrow({{0, 0, 1}})->view({seq_len});
+    } else if (positions->shape().size() == 1) {
+        pos_ids_for_rope = positions;
+    } else {
+        throw std::runtime_error("Unexpected position_ids shape in forward_paged_");
+    }
+
+    // 3. Rotary Position Embedding (RoPE)
+    // Apply in-place on 3D tensors. The rotated dimension index is now 2 (hd) instead of 3 as in 4D tensors.
+    rotary_emb_->forward(q->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
+    rotary_emb_->forward(k->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
+
+    // 4. Attention computation
+    // Pass 3D Q, K, V directly. The PagedAttention backend handles the KV cache update internally.
+    auto attn_output = attn_->forward(q, k, v);
+
+    // 5. Output Projection
+    // Reshape attention output to 2D [seq_len, hidden_size] for the linear projection
+    auto output_2d = attn_output->view({seq_len, num_attention_heads_ * head_dim_});
+    auto output = o_proj_->forward(output_2d);
+
+    // Restore 3D [1, seq_len, hidden_size] to match the input hidden_states shape
+    return output->view({1, seq_len, hidden_size_});
+}
+
+infinicore::Tensor Glm4Attention::forward_static_(const infinicore::Tensor &positions,
+                                                  infinicore::Tensor &hidden_states) {
+    if (!rotary_emb_) {
+        throw std::runtime_error("Glm4Attention: rotary_emb not configured");
+    }
+
+    auto shape = hidden_states->shape();
+    size_t batch_size = shape[0];
+    size_t seq_len = shape[1];
+    size_t num_tokens = batch_size * seq_len;
+
+    // 1. QKV Projection -> [bs, sl, nh, hd]
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states);
+    q = q->contiguous()->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+    k = k->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+    v = v->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+
+    // 2. Position IDs
+    infinicore::Tensor pos_ids_for_rope;
+    if (positions->shape().size() == 2) {
+        pos_ids_for_rope = positions->narrow({{0, 0, 1}})->contiguous()->view({seq_len});
+    } else if (positions->shape().size() == 1) {
+        pos_ids_for_rope = positions->contiguous();
+    } else {
+        throw std::runtime_error("Unexpected position_ids shape");
+    }
+
+    // 3. Rotary Position Embedding (RoPE)
+    // Use `true` to perform in-place modification on non-contiguous narrow views
+    rotary_emb_->forward(q->narrow({{3, 0, rotary_dim_}}), pos_ids_for_rope, true);
+    rotary_emb_->forward(k->narrow({{3, 0, rotary_dim_}}), pos_ids_for_rope, true);
+
+    // Trick: Create tensors with the target physical layout [bs, nh, sl, hd],
+    // permute to the logical layout [bs, sl, nh, hd], and copy the data.
+    // This satisfies the attention backend's memory format requirement without realigning data.
+    auto q_in = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q->dtype(), q->device())
+                    ->permute({0, 2, 1, 3});
+    q_in->copy_from(q);
+
+    auto k_in = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, k->dtype(), k->device())
+                    ->permute({0, 2, 1, 3});
+    k_in->copy_from(k);
+
+    auto v_in = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, v->dtype(), v->device())
+                    ->permute({0, 2, 1, 3});
+    v_in->copy_from(v);
+
+    // 4. Attention computation (q, k, v)
+    auto attn_output = attn_->forward(q_in, k_in, v_in);
+
+    // 5. Output Projection
+    auto output_2d = attn_output->contiguous()->view({num_tokens, num_attention_heads_ * head_dim_});
+    auto output = o_proj_->forward(output_2d);
+    return output->view({batch_size, seq_len, hidden_size_});
+}
+
+} // namespace infinilm::models::glm4
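Not part of the diff: a small PyTorch sketch of the layout trick used in Glm4Attention::forward_static_ above. A buffer is allocated with the physical layout [bs, nh, sl, hd], permuted so it is indexed as [bs, sl, nh, hd], and the projected tensor is copied into that view; the attention backend then sees its expected memory format without a separate re-layout pass:

import torch

bs, sl, nh, hd = 1, 3, 2, 4
q = torch.randn(bs, sl, nh, hd)                           # logical layout coming out of the QKV projection

q_in = torch.empty(bs, nh, sl, hd).permute(0, 2, 1, 3)    # physical [bs, nh, sl, hd], logical [bs, sl, nh, hd]
q_in.copy_(q)                                             # copy_ resolves the stride mismatch

print(q_in.shape)      # torch.Size([1, 3, 2, 4]) - same logical shape as q
print(q_in.stride())   # (24, 4, 12, 1) - strides of the [bs, nh, sl, hd] buffer
assert torch.equal(q_in, q)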
diff --git a/csrc/models/glm4/glm4_attention.hpp b/csrc/models/glm4/glm4_attention.hpp
new file mode 100644
index 00000000..8ca87945
--- /dev/null
+++ b/csrc/models/glm4/glm4_attention.hpp
@@ -0,0 +1,50 @@
+#pragma once
+
+#include "../../layers/common_modules.hpp"
+
+namespace infinilm::layers::attention {
+class AttentionLayer;
+}
+
+namespace infinilm::models::glm4 {
+
+class Glm4Attention : public infinicore::nn::Module {
+public:
+    Glm4Attention(std::shared_ptr model_config,
+                  size_t layer_idx,
+                  const infinicore::Device &device);
+
+    infinicore::Tensor forward(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);
+
+    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);
+
+protected:
+    // Operator layers
+    INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj);
+    INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj);
+    std::shared_ptr<infinilm::layers::attention::AttentionLayer> attn_;
+    ::infinilm::backends::AttentionBackend attention_backend_;
+    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
+    std::shared_ptr model_config_;
+
+    // Model parameters
+    size_t layer_idx_;
+    size_t hidden_size_;
+    size_t num_attention_heads_;
+    size_t num_key_value_heads_;
+    size_t head_dim_;
+    size_t rotary_dim_;
+    bool use_bias_;
+    bool use_output_bias_;
+    float scaling_;
+
+    // KV Cache quantization
+    INFINICORE_NN_PARAMETER(kv_cache_k_scale);
+    INFINICORE_NN_PARAMETER(kv_cache_v_scale);
+
+private:
+    infinicore::Tensor forward_static_(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);
+    infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);
+};
+
+} // namespace infinilm::models::glm4
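Also for illustration only: what the narrow({{2, 0, rotary_dim_}}) / narrow({{3, 0, rotary_dim_}}) calls in glm4_attention.cpp do. Only the first rotary_dim channels of each head are rotated in place; the remaining head_dim - rotary_dim channels are left untouched. A torch sketch with a stand-in kernel (the real kernel is rotary_emb_->forward with the in-place flag set to true):

import torch

def rope_kernel_inplace(x, pos_ids):
    # Stand-in for the real RoPE kernel; it just flips the sign so the
    # rotated region is visible in the output.
    x.mul_(-1.0)

seq_len, num_heads, head_dim, rotary_dim = 4, 2, 8, 4
q = torch.ones(seq_len, num_heads, head_dim)
pos_ids = torch.arange(seq_len)

# narrow(...) in the patch corresponds to a strided view of the leading rotary_dim channels
rope_kernel_inplace(q[..., :rotary_dim], pos_ids)

print(q[0, 0])   # tensor([-1., -1., -1., -1.,  1.,  1.,  1.,  1.])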
diff --git a/csrc/models/glm4/glm4_decoder_layer.cpp b/csrc/models/glm4/glm4_decoder_layer.cpp
new file mode 100644
index 00000000..0a72ab7b
--- /dev/null
+++ b/csrc/models/glm4/glm4_decoder_layer.cpp
@@ -0,0 +1,75 @@
+#include "glm4_decoder_layer.hpp"
+#include "infinicore/ops.hpp" // provides the add operator and friends
+
+namespace infinilm::models::glm4 {
+
+Glm4DecoderLayer::Glm4DecoderLayer(std::shared_ptr model_config,
+                                   size_t layer_idx,
+                                   const infinicore::Device &device) {
+    const auto &dtype = model_config->get_dtype();
+    size_t hidden_size = model_config->get("hidden_size");
+    double rms_norm_eps = model_config->get("rms_norm_eps");
+
+    self_attn_ = this->register_module<Glm4Attention>(
+        "self_attn", model_config, layer_idx, device);
+
+    mlp_ = this->register_module<infinilm::layers::mlp::MLP>(
+        "mlp", model_config, device);
+
+    input_layernorm_ = this->register_module<infinicore::nn::RMSNorm>(
+        "input_layernorm", hidden_size, rms_norm_eps, dtype, device);
+
+    post_attention_layernorm_ = this->register_module<infinicore::nn::RMSNorm>(
+        "post_attention_layernorm", hidden_size, rms_norm_eps, dtype, device);
+
+    post_self_attn_layernorm_ = this->register_module<infinicore::nn::RMSNorm>(
+        "post_self_attn_layernorm", hidden_size, rms_norm_eps, dtype, device);
+
+    post_mlp_layernorm_ = this->register_module<infinicore::nn::RMSNorm>(
+        "post_mlp_layernorm", hidden_size, rms_norm_eps, dtype, device);
+}
+
+std::tuple<infinicore::Tensor, infinicore::Tensor> Glm4DecoderLayer::forward(
+    const infinicore::Tensor &positions,
+    infinicore::Tensor &hidden_states,
+    infinicore::Tensor &residual) {
+
+    // 1. Attention Block
+    residual = hidden_states;
+    hidden_states = input_layernorm_->forward(hidden_states);
+    hidden_states = self_attn_->forward(positions, hidden_states);
+    hidden_states = post_self_attn_layernorm_->forward(hidden_states);
+    hidden_states = infinicore::op::add(residual, hidden_states);
+
+    // 2. MLP Block
+    residual = hidden_states;
+    hidden_states = post_attention_layernorm_->forward(hidden_states);
+    hidden_states = mlp_->forward(hidden_states);
+    hidden_states = post_mlp_layernorm_->forward(hidden_states);
+    hidden_states = infinicore::op::add(residual, hidden_states);
+
+    return std::make_tuple(hidden_states, residual);
+}
+
+infinicore::Tensor Glm4DecoderLayer::forward(
+    const infinicore::Tensor &positions,
+    infinicore::Tensor &hidden_states) {
+
+    auto residual = hidden_states;
+    hidden_states = input_layernorm_->forward(hidden_states);
+    hidden_states = self_attn_->forward(positions, hidden_states);
+    hidden_states = post_self_attn_layernorm_->forward(hidden_states);
+    hidden_states = infinicore::op::add(residual, hidden_states);
+
+    residual = hidden_states;
+    hidden_states = post_attention_layernorm_->forward(hidden_states);
+    hidden_states = mlp_->forward(hidden_states);
+    hidden_states = post_mlp_layernorm_->forward(hidden_states);
+    hidden_states = infinicore::op::add(residual, hidden_states);
+
+    return hidden_states;
+}
+
+} // namespace infinilm::models::glm4
diff --git a/csrc/models/glm4/glm4_decoder_layer.hpp b/csrc/models/glm4/glm4_decoder_layer.hpp
new file mode 100644
index 00000000..0ba0086a
--- /dev/null
+++ b/csrc/models/glm4/glm4_decoder_layer.hpp
@@ -0,0 +1,38 @@
+#pragma once
+
+#include "../../config/model_config.hpp"
+#include "../../backends/attention_backends.hpp"
+#include "../../engine/distributed/distributed.hpp"
+#include "glm4_attention.hpp"
+#include "infinicore/nn/module.hpp"
+#include "infinicore/nn/rmsnorm.hpp"
+#include "../../layers/mlp/mlp.hpp"
+
+namespace infinilm::models::glm4 {
+
+class Glm4DecoderLayer : public infinicore::nn::Module {
+public:
+    Glm4DecoderLayer(std::shared_ptr model_config,
+                     size_t layer_idx,
+                     const infinicore::Device &device);
+
+    std::tuple<infinicore::Tensor, infinicore::Tensor> forward(
+        const infinicore::Tensor &positions,
+        infinicore::Tensor &hidden_states,
+        infinicore::Tensor &residual);
+
+    infinicore::Tensor forward(
+        const infinicore::Tensor &positions,
+        infinicore::Tensor &hidden_states);
+
+private:
+    INFINICORE_NN_MODULE(Glm4Attention, self_attn);
+    INFINICORE_NN_MODULE(infinilm::layers::mlp::MLP, mlp);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_self_attn_layernorm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm);
+    INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_mlp_layernorm);
+};
+
+} // namespace infinilm::models::glm4
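For readers comparing this with the Llama-style blocks elsewhere in the repo (illustration only, a hypothetical PyTorch re-statement): GLM-4 adds an extra RMSNorm after the attention output and after the MLP output, before each residual add, on top of the usual pre-norms.

import torch.nn as nn

class Glm4BlockSketch(nn.Module):
    # Mirrors the dataflow of Glm4DecoderLayer::forward above.
    def __init__(self, hidden_size, attn, mlp, eps=1e-5):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.input_layernorm = nn.RMSNorm(hidden_size, eps=eps)
        self.post_self_attn_layernorm = nn.RMSNorm(hidden_size, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(hidden_size, eps=eps)
        self.post_mlp_layernorm = nn.RMSNorm(hidden_size, eps=eps)

    def forward(self, positions, hidden_states):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attn(positions, hidden_states)
        hidden_states = self.post_self_attn_layernorm(hidden_states)  # extra norm vs. a Llama block
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.post_mlp_layernorm(hidden_states)        # extra norm vs. a Llama block
        return residual + hidden_states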
diff --git a/csrc/models/glm4/glm4_for_causal_lm.cpp b/csrc/models/glm4/glm4_for_causal_lm.cpp
new file mode 100644
index 00000000..6181a5c8
--- /dev/null
+++ b/csrc/models/glm4/glm4_for_causal_lm.cpp
@@ -0,0 +1,33 @@
+#include "glm4_for_causal_lm.hpp"
+#include "../models_registry.hpp"
+#include <stdexcept>
+
+namespace infinilm::models::glm4 {
+
+std::shared_ptr create_glm4_model_config(
+    std::shared_ptr model_config) {
+    const std::string &model_type = model_config->get("model_type");
+    if ("glm4" != model_type) {
+        throw std::runtime_error(
+            "infinilm::models::glm4::create_glm4_model_config: model_type is not glm4");
+    }
+
+    nlohmann::json &config_json = model_config->get_config_json();
+
+    if (!config_json.contains("attention_bias")) {
+        config_json["attention_bias"] = false;
+    }
+
+    return model_config;
+}
+
+} // namespace infinilm::models::glm4
+
+namespace {
+
+INFINILM_REGISTER_CAUSAL_LM_MODEL(
+    glm4,
+    infinilm::models::glm4::Glm4ForCausalLM,
+    infinilm::models::glm4::create_glm4_model_config);
+
+} // namespace
diff --git a/csrc/models/glm4/glm4_for_causal_lm.hpp b/csrc/models/glm4/glm4_for_causal_lm.hpp
new file mode 100644
index 00000000..466a93b1
--- /dev/null
+++ b/csrc/models/glm4/glm4_for_causal_lm.hpp
@@ -0,0 +1,16 @@
+#pragma once
+
+#include "glm4_decoder_layer.hpp"
+#include <memory>
+
+namespace infinilm::models::glm4 {
+
+using Glm4Model = infinilm::layers::causal_lm_templates::TextModel;
+
+using Glm4ForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM;
+
+std::shared_ptr create_glm4_model_config(
+    std::shared_ptr model_config);
+
+} // namespace infinilm::models::glm4
diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py
index 7edc7cfe..d63d7634 100644
--- a/python/infinilm/modeling_utils.py
+++ b/python/infinilm/modeling_utils.py
@@ -1,6 +1,6 @@
 import os
 import json
-from typing import Dict, Union
+from typing import Dict, Union, Optional, List
 import time
 import torch
 from safetensors import safe_open
@@ -170,6 +170,8 @@ def load_model_state_dict_by_file(
     print(" load weights ......")
     t1 = time.time()
 
+    model_type = model.hf_config.get("model_type", "")
+
     torch_device = "cpu"
     torch_dtype = infinicore.utils.to_torch_dtype(dtype)
     model_keys = model.state_dict_keyname()
@@ -189,6 +191,12 @@ def load_model_state_dict_by_file(
             model_param = load_state_dict(
                 file_path, device=torch_device, dtype=torch_dtype
             )
+
+            # Apply model-specific weight remapping
+            remapper = _WEIGHT_REMAPPER.get(model_type)
+            if remapper is not None:
+                model_param = remapper(model_param)
+
             already_loaded_keys.extend(model_param.keys())
 
             # --------------------------------------------------------- #
@@ -214,6 +222,11 @@ def load_model_state_dict_by_file(
         file_path = os.path.join(model_path, "pytorch_model.bin")
         model_params = torch.load(file_path, weights_only=True, map_location="cpu")
 
+        # Apply model-specific weight remapping
+        remapper = _WEIGHT_REMAPPER.get(model_type)
+        if remapper is not None:
+            model_params = remapper(model_params)
+
         # Scale embed_tokens on torch side before converting
         if "model.embed_tokens.weight" in model_params:
             embed_tokens_torch_unscaled = model_params["model.embed_tokens.weight"].to(dtype=torch_dtype)
@@ -312,3 +325,92 @@ def load_model_state_dict_by_tensor(
 
     t2 = time.time()
     print(f" load weights over! {(t2 - t1) * 1000} ms \n")
+
+
+# ============================================================================
+# Common weight transformation utilities
+# ============================================================================
+
+def split_fused_weight(
+    state_dict: Dict[str, torch.Tensor],
+    fused_key: str,
+    output_names: List[str],
+    split_dim: int = 0,
+    split_ratios: Optional[List[float]] = None,
+) -> Dict[str, torch.Tensor]:
+    """Split fused weight tensors into separate weights.
+
+    Args:
+        state_dict: Original state dict from HuggingFace safetensors.
+        fused_key: Substring to match in key names (e.g. "gate_up_proj").
+        output_names: Names of the split outputs (e.g. ["gate_proj", "up_proj"]).
+        split_dim: Dimension along which to split. Default 0.
+        split_ratios: Optional ratios. If None, split equally.
+
+    Returns:
+        New state dict with fused keys replaced by split keys.
+    """
+    result = {}
+    for key, tensor in state_dict.items():
+        if fused_key not in key:
+            result[key] = tensor
+            continue
+
+        base_key = key.replace(f".{fused_key}.weight", "")
+        dim_size = tensor.shape[split_dim]
+        num_splits = len(output_names)
+
+        if split_ratios is not None:
+            total_ratio = sum(split_ratios)
+            sizes = [int(dim_size * r / total_ratio) for r in split_ratios[:-1]]
+            sizes.append(dim_size - sum(sizes))
+        else:
+            chunk = dim_size // num_splits
+            sizes = [chunk] * (num_splits - 1)
+            sizes.append(dim_size - chunk * (num_splits - 1))
+
+        splits = torch.split(tensor, sizes, dim=split_dim)
+        for name, split_tensor in zip(output_names, splits):
+            result[f"{base_key}.{name}.weight"] = split_tensor
+
+    return result
+
+
+def rename_keys(
+    state_dict: Dict[str, torch.Tensor],
+    mapping: Dict[str, str],
+) -> Dict[str, torch.Tensor]:
+    """Rename weight keys according to a substring mapping."""
+    result = {}
+    for key, tensor in state_dict.items():
+        new_key = key
+        for old_str, new_str in mapping.items():
+            new_key = new_key.replace(old_str, new_str)
+        result[new_key] = tensor
+    return result
+
+
+# ============================================================================
+# Model-specific remap functions
+# ============================================================================
+
+def _remap_glm4(state_dict):
+    """Split GLM-4 fused gate_up_proj into gate_proj + up_proj."""
+    return split_fused_weight(
+        state_dict,
+        fused_key="gate_up_proj",
+        output_names=["gate_proj", "up_proj"],
+    )
+
+
+# Add more model remap functions here as needed:
+#
+# def _remap_qwen3(state_dict):
+#     state_dict = split_fused_weight(state_dict, "gate_up_proj", ["gate_proj", "up_proj"])
+#     state_dict = rename_keys(state_dict, {"model.layers": "decoder.layers"})
+#     return state_dict
+
+
+# Model type → remap function mapping
+_WEIGHT_REMAPPER = {
+    "glm4": _remap_glm4,
+    # "qwen3": _remap_qwen3,
+}
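A usage sketch for the dictionary-based remapping above (hypothetical tensors and key names; in practice the keys come from the GLM-4 checkpoint and the lookup happens inside load_model_state_dict_by_file):

import torch

state_dict = {
    "model.layers.0.mlp.gate_up_proj.weight": torch.randn(8, 4),     # fused gate/up weight
    "model.layers.0.self_attn.q_proj.weight": torch.randn(4, 4),
}

remapped = _WEIGHT_REMAPPER["glm4"](state_dict)
print(sorted(remapped.keys()))
# ['model.layers.0.mlp.gate_proj.weight',      shape (4, 4)
#  'model.layers.0.mlp.up_proj.weight',        shape (4, 4)
#  'model.layers.0.self_attn.q_proj.weight']   unchanged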