issue/349 - Support GLM4 model #370
Merged
Changes from all commits
Rotary embedding layer header (`layers/rotary_embedding/rotary_embedding.hpp`, the file included by `glm4_attention.cpp` below):

```diff
@@ -1,11 +1,17 @@
 #pragma once
 
-#include "../../config/model_config.hpp"
 #include <memory>
+#include "infinicore/nn/rope.hpp"
+#include "../../config/model_config.hpp"
 
 namespace infinilm::layers::rotary_embedding {
 
+// Compute the actual number of dimensions involved in rotary position embedding.
+// For partial rotation, the dimension is clamped to [2, head_dim] and must be even.
+size_t get_rotary_dim(size_t head_dim, double partial_rotary_factor);
+
 std::shared_ptr<infinicore::nn::RoPE> get_rope(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
-                                               const infinicore::Device &device);
+                                               const infinicore::Device &device,
+                                               infinicore::nn::RoPE::Algo algo = infinicore::nn::RoPE::Algo::GPT_NEOX);
 
 } // namespace infinilm::layers::rotary_embedding
```
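For context, the comment above pins down the contract of `get_rotary_dim`: scale `head_dim` by `partial_rotary_factor`, clamp the result to `[2, head_dim]`, and keep it even. The actual definition lives in the corresponding `.cpp` file, which is not part of this excerpt; a minimal sketch of that contract (illustrative only, not the PR's implementation) could look like:

```cpp
#include <algorithm>
#include <cstddef>

// Sketch of the documented behaviour of get_rotary_dim (not the code shipped in this PR).
size_t get_rotary_dim_sketch(size_t head_dim, double partial_rotary_factor) {
    auto dim = static_cast<size_t>(static_cast<double>(head_dim) * partial_rotary_factor);
    dim = std::clamp<size_t>(dim, 2, head_dim); // clamp to [2, head_dim]
    return dim - (dim % 2);                     // round down to an even dimension
}
```

With `head_dim = 128` and `partial_rotary_factor = 0.5` (typical GLM4 values, assumed here), this yields a rotary dimension of 64, so only the first 64 channels of each head are rotated.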
New file `glm4_attention.cpp` (200 lines added) — the GLM4 attention module: constructor plus the paged and static forward paths:

```cpp
#include "glm4_attention.hpp"
#include "../../global_state/global_state.hpp"
#include "../../layers/rotary_embedding/rotary_embedding.hpp"
#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"

#include <cmath>
#include <spdlog/spdlog.h>
#include <stdexcept>

namespace infinilm::models::glm4 {

Glm4Attention::Glm4Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                             size_t layer_idx,
                             const infinicore::Device &device)
    : model_config_(model_config),
      layer_idx_(layer_idx),
      hidden_size_(model_config->get<size_t>("hidden_size")),
      num_attention_heads_(model_config->get<size_t>("num_attention_heads")),
      num_key_value_heads_(model_config->get<size_t>("num_key_value_heads")),
      head_dim_(model_config->get_head_dim()),
      rotary_dim_(infinilm::layers::rotary_embedding::get_rotary_dim(model_config->get_head_dim(), model_config->get_or<double>("partial_rotary_factor", 1.0))),
      use_bias_(model_config->get_or<bool>("attention_bias", true)),
      use_output_bias_(model_config->get_or<bool>("attention_output_bias", false)) {

    const auto &dtype{model_config_->get_dtype()};

    const engine::distributed::RankInfo &g_rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info();
    int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank();
    int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size();

    if ((num_key_value_heads_ >= tp_size) && (0 == (num_key_value_heads_ % tp_size))) {
        num_attention_heads_ /= tp_size;
        num_key_value_heads_ /= tp_size;
    } else {
        throw std::runtime_error("Glm4Attention: num_key_value_heads must be divisible by tp_size");
    }
    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    // Linear layer initialization
    auto quant_scheme = this->model_config_->get_quant_scheme();
    switch (quant_scheme) {
    case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8:
        INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
        break;
    case infinicore::quantization::QuantScheme::AWQ_W4A16: {
        INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
        break;
    }
    case infinicore::quantization::QuantScheme::GPTQ_W4A16_QY: {
        INFINILM_QKV_LINEAR_W4A16GPTQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
        break;
    }
    default:
        INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, model_config_->get<size_t>("num_attention_heads"), model_config_->get<size_t>("num_key_value_heads"), this->model_config_->get_quantization_method(), use_bias_, dtype, device, g_rank_info);
        INFINICORE_NN_MODULE_INIT(o_proj, model_config_->get<size_t>("num_attention_heads") * head_dim_, hidden_size_, this->model_config_->get_quantization_method(), use_output_bias_, dtype, device, tp_rank, tp_size, g_rank_info.comm);
        break;
    }

    // RoPE initialization
    attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend;
    rotary_emb_ = infinilm::layers::rotary_embedding::get_rope(model_config, device, infinicore::nn::RoPE::Algo::GPT_J);

    attn_ = std::make_shared<infinilm::layers::attention::AttentionLayer>(
        num_attention_heads_, head_dim_, scaling_,
        num_key_value_heads_, layer_idx_,
        kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_);

    // KV cache quantization scale initialization
    auto kv_quant_scheme = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme();
    if (kv_quant_scheme == infinicore::quantization::KVQuantAlgo::INT8) {
        INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
        INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
    }
}

infinicore::Tensor Glm4Attention::forward(const infinicore::Tensor &positions,
                                          infinicore::Tensor &hidden_states) {
    if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) {
        return forward_static_(positions, hidden_states);
    }
    return forward_paged_(positions, hidden_states);
}

infinicore::Tensor Glm4Attention::forward_paged_(const infinicore::Tensor &positions,
                                                 infinicore::Tensor &hidden_states) {
    if (!rotary_emb_) {
        throw std::runtime_error("Glm4Attention: rotary_emb not configured");
    }

    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];

    // Paged mode strictly requires batch_size == 1 because the continuous-batching scheduler flattens requests
    if (batch_size != 1) {
        throw std::runtime_error("Glm4Attention::forward_paged_ expects batch_size == 1");
    }

    // 1. QKV projection
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states);

    // Reshape to 3D [sl, nh, hd] - the native layout required by the paged KV cache update.
    // DO NOT transpose K and V, otherwise do_kv_cache_update will cause a segmentation fault!
    q = q->view({seq_len, num_attention_heads_, head_dim_});
    k = k->view({seq_len, num_key_value_heads_, head_dim_});
    v = v->view({seq_len, num_key_value_heads_, head_dim_});

    // 2. Position IDs
    infinicore::Tensor pos_ids_for_rope;
    if (positions->shape().size() == 2) {
        // Squeeze the batch dimension directly to 1D [seq_len]
        pos_ids_for_rope = positions->narrow({{0, 0, 1}})->view({seq_len});
    } else if (positions->shape().size() == 1) {
        pos_ids_for_rope = positions;
    } else {
        throw std::runtime_error("Unexpected position_ids shape in forward_paged_");
    }

    // 3. Rotary position embedding (RoPE)
    // Applied in place on the 3D tensors; the head dimension is axis 2 (hd) here instead of axis 3 as in the 4D case.
    rotary_emb_->forward(q->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);
    rotary_emb_->forward(k->narrow({{2, 0, rotary_dim_}}), pos_ids_for_rope, true);

    // 4. Attention computation
    // Pass 3D Q, K, V directly; the PagedAttention backend handles the KV cache update internally.
    auto attn_output = attn_->forward(q, k, v);

    // 5. Output projection
    // Reshape the attention output to 2D [seq_len, hidden_size] for the linear projection
    auto output_2d = attn_output->view({seq_len, num_attention_heads_ * head_dim_});
    auto output = o_proj_->forward(output_2d);

    // Restore 3D [1, seq_len, hidden_size] to match the input hidden_states shape
    return output->view({1, seq_len, hidden_size_});
}

infinicore::Tensor Glm4Attention::forward_static_(const infinicore::Tensor &positions,
                                                  infinicore::Tensor &hidden_states) {
    if (!rotary_emb_) {
        throw std::runtime_error("Glm4Attention: rotary_emb not configured");
    }

    auto shape = hidden_states->shape();
    size_t batch_size = shape[0];
    size_t seq_len = shape[1];
    size_t num_tokens = batch_size * seq_len;

    // 1. QKV projection -> [bs, sl, nh, hd]
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states);
    q = q->contiguous()->view({batch_size, seq_len, num_attention_heads_, head_dim_});
    k = k->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
    v = v->contiguous()->view({batch_size, seq_len, num_key_value_heads_, head_dim_});

    // 2. Position IDs
    infinicore::Tensor pos_ids_for_rope;
    if (positions->shape().size() == 2) {
        pos_ids_for_rope = positions->narrow({{0, 0, 1}})->contiguous()->view({seq_len});
    } else if (positions->shape().size() == 1) {
        pos_ids_for_rope = positions->contiguous();
    } else {
        throw std::runtime_error("Unexpected position_ids shape");
    }

    // 3. Rotary position embedding (RoPE)
    // Pass `true` to modify the non-contiguous narrowed views in place
    rotary_emb_->forward(q->narrow({{3, 0, rotary_dim_}}), pos_ids_for_rope, true);
    rotary_emb_->forward(k->narrow({{3, 0, rotary_dim_}}), pos_ids_for_rope, true);

    // Trick: create tensors with the target physical layout [bs, nh, sl, hd],
    // permute them to the logical layout [bs, sl, nh, hd], and copy the data in.
    // This satisfies the attention backend's memory-format requirement without realigning data.
    auto q_in = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q->dtype(), q->device())
                    ->permute({0, 2, 1, 3});
    q_in->copy_from(q);

    auto k_in = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, k->dtype(), k->device())
                    ->permute({0, 2, 1, 3});
    k_in->copy_from(k);

    auto v_in = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, v->dtype(), v->device())
                    ->permute({0, 2, 1, 3});
    v_in->copy_from(v);

    // 4. Attention computation (q, k, v)
    auto attn_output = attn_->forward(q_in, k_in, v_in);

    // 5. Output projection
    auto output_2d = attn_output->contiguous()->view({num_tokens, num_attention_heads_ * head_dim_});
    auto output = o_proj_->forward(output_2d);
    return output->view({batch_size, seq_len, hidden_size_});
}

} // namespace infinilm::models::glm4
```
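The constructor above requests `infinicore::nn::RoPE::Algo::GPT_J` for GLM4, overriding the `GPT_NEOX` default added to `get_rope` in the header change. The two layouts differ in which channels of each head are paired for rotation: GPT-NeoX style pairs channel `i` with channel `i + rotary_dim/2`, while GPT-J style rotates adjacent channel pairs `(2i, 2i + 1)`. A scalar reference sketch of that difference (illustrative only — this is not the infinicore kernel, and the inverse-frequency formula is the standard RoPE convention, assumed here):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// GPT-NeoX layout: channel i is rotated together with channel i + rotary_dim / 2.
void rope_gpt_neox(std::vector<float> &x, size_t rotary_dim, size_t pos, double base = 10000.0) {
    size_t half = rotary_dim / 2;
    for (size_t i = 0; i < half; ++i) {
        double theta = pos * std::pow(base, -2.0 * i / rotary_dim);
        float c = static_cast<float>(std::cos(theta)), s = static_cast<float>(std::sin(theta));
        float a = x[i], b = x[i + half];
        x[i] = a * c - b * s;
        x[i + half] = a * s + b * c;
    }
}

// GPT-J layout: adjacent channels (2i, 2i + 1) form a rotation pair.
void rope_gpt_j(std::vector<float> &x, size_t rotary_dim, size_t pos, double base = 10000.0) {
    for (size_t i = 0; i < rotary_dim / 2; ++i) {
        double theta = pos * std::pow(base, -2.0 * i / rotary_dim);
        float c = static_cast<float>(std::cos(theta)), s = static_cast<float>(std::sin(theta));
        float a = x[2 * i], b = x[2 * i + 1];
        x[2 * i] = a * c - b * s;
        x[2 * i + 1] = a * s + b * c;
    }
}
```

Only the first `rotary_dim_` channels are passed to `rotary_emb_->forward` via `narrow`, which is how the partial rotation configured by `partial_rotary_factor` is realized.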
New file `glm4_attention.hpp` (50 lines added) — the class declaration:

```cpp
#pragma once

#include "../../layers/common_modules.hpp"

namespace infinilm::layers::attention {
class AttentionLayer;
}

namespace infinilm::models::glm4 {

class Glm4Attention : public infinicore::nn::Module {
public:
    Glm4Attention(std::shared_ptr<infinilm::config::ModelConfig> model_config,
                  size_t layer_idx,
                  const infinicore::Device &device);

    infinicore::Tensor forward(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);

    void set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb);

protected:
    // Operator layers
    INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj);
    INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj);
    std::shared_ptr<infinilm::layers::attention::AttentionLayer> attn_;
    ::infinilm::backends::AttentionBackend attention_backend_;
    std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
    std::shared_ptr<infinilm::config::ModelConfig> model_config_;

    // Model parameters
    size_t layer_idx_;
    size_t hidden_size_;
    size_t num_attention_heads_;
    size_t num_key_value_heads_;
    size_t head_dim_;
    size_t rotary_dim_;
    bool use_bias_;
    bool use_output_bias_;
    float scaling_;

    // KV cache quantization
    INFINICORE_NN_PARAMETER(kv_cache_k_scale);
    INFINICORE_NN_PARAMETER(kv_cache_v_scale);

private:
    infinicore::Tensor forward_static_(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);
    infinicore::Tensor forward_paged_(const infinicore::Tensor &positions, infinicore::Tensor &hidden_states);
};

} // namespace infinilm::models::glm4
```
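For orientation, a hypothetical call site (not part of this PR) showing how a GLM4 decoder layer might own and invoke this module; config loading and tensor construction are elided, and the surrounding function and its name are assumptions:

```cpp
#include "glm4_attention.hpp"

#include <memory>

// Hypothetical wiring (not in this PR): one self-attention step inside a GLM4 decoder layer.
void decoder_layer_step(const std::shared_ptr<infinilm::config::ModelConfig> &model_config,
                        const infinicore::Device &device,
                        const infinicore::Tensor &positions,   // [seq_len] or [batch, seq_len] position ids
                        infinicore::Tensor &hidden_states) {   // [batch, seq_len, hidden_size]
    auto self_attn = std::make_shared<infinilm::models::glm4::Glm4Attention>(
        model_config, /*layer_idx=*/0, device);

    // Dispatches to forward_paged_ or forward_static_ depending on the configured attention backend.
    infinicore::Tensor attn_out = self_attn->forward(positions, hidden_states);

    (void)attn_out; // the residual add and MLP block would follow in a real decoder layer
}
```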