From 262a634ba27231c949ceaf27a4b2789464d5c930 Mon Sep 17 00:00:00 2001 From: wangpengcheng Date: Tue, 31 Mar 2026 11:46:23 +0000 Subject: [PATCH] issue/276 - code refactoring --- csrc/backends/attention_backends.hpp | 40 +++- csrc/cache/kv_cache.cpp | 76 +++++++ csrc/cache/kv_cache.hpp | 17 ++ csrc/config/config_factory.cpp | 29 +++ csrc/config/config_factory.hpp | 14 ++ csrc/config/model_config.cpp | 30 ++- csrc/config/model_config.hpp | 12 +- csrc/engine/compiler/paged_compiler.cpp | 12 ++ .../compiler/static_batching_compiler.cpp | 13 +- .../distributed/communication_group.cpp | 9 +- csrc/engine/infer_engine.cpp | 20 +- csrc/engine/infer_engine.hpp | 1 + csrc/engine/rank_worker.cpp | 48 +++-- csrc/engine/rank_worker.hpp | 9 +- csrc/global_state/forward_context.hpp | 57 +++++ csrc/global_state/global_state.cpp | 69 ++++++ csrc/global_state/global_state.hpp | 5 + csrc/global_state/infinilm_config.hpp | 34 +++ csrc/global_state/parallel_state.hpp | 24 +++ csrc/layers/attention/attention.cpp | 193 +++++++++++++++++ csrc/layers/attention/attention.hpp | 59 ++++++ .../attention/backends/attention_layer.cpp | 51 +++++ .../attention/backends/attention_layer.hpp | 53 +++++ csrc/layers/attention/backends/flash_attn.cpp | 104 +++++++++ csrc/layers/attention/backends/flash_attn.hpp | 55 +++++ csrc/layers/attention/backends/paged_attn.cpp | 83 ++++++++ csrc/layers/attention/backends/paged_attn.hpp | 54 +++++ .../layers/attention/backends/static_attn.cpp | 133 ++++++++++++ .../layers/attention/backends/static_attn.hpp | 46 ++++ .../causal_lm_templates/text_causal_lm.hpp | 56 +++++ .../text_decoder_layer.hpp | 77 +++++++ .../layers/causal_lm_templates/text_model.hpp | 96 +++++++++ csrc/layers/common_modules.hpp | 21 ++ csrc/layers/{ => linear}/fused_linear.cpp | 4 +- csrc/layers/{ => linear}/fused_linear.hpp | 18 +- csrc/layers/linear/linear.hpp | 12 ++ csrc/layers/mlp/mlp.cpp | 60 ++++++ csrc/layers/mlp/mlp.hpp | 52 +++++ csrc/layers/mlp/moe_mlp.cpp | 46 ++++ 
csrc/layers/mlp/moe_mlp.hpp | 29 +++ csrc/layers/{ => quantization}/kv_quant.cpp | 0 csrc/layers/{ => quantization}/kv_quant.hpp | 0 csrc/models/fm9g/fm9g_for_causal_lm.cpp | 50 +++++ csrc/models/fm9g/fm9g_for_causal_lm.hpp | 24 +++ csrc/models/infinilm_model.cpp | 94 ++++++++ csrc/models/infinilm_model.hpp | 24 ++- csrc/models/llama/llama_attention.cpp | 4 +- csrc/models/llama/llama_attention.hpp | 6 +- csrc/models/llama/llama_for_causal_lm.cpp | 2 + csrc/models/llama/llama_mlp.hpp | 4 +- ...minicpm_sala_allocate_kv_cache_tensors.cpp | 104 +++++++++ .../minicpm_sala/minicpm_sala_attention.cpp | 116 ++++++++++ .../minicpm_sala/minicpm_sala_attention.hpp | 87 ++++++++ .../minicpm_sala_decoderLayer.cpp | 66 ++++++ .../minicpm_sala_decoderLayer.hpp | 33 +++ .../minicpm_sala_for_causal_lm.cpp | 56 +++++ .../minicpm_sala_for_causal_lm.hpp | 30 +++ csrc/models/model_factory.cpp | 26 ++- csrc/models/model_factory.hpp | 12 +- csrc/models/models_registry.cpp | 32 +++ csrc/models/models_registry.hpp | 79 +++++++ csrc/models/qwen3/qwen3_attention.cpp | 200 ++++++++++++++++++ csrc/models/qwen3/qwen3_attention.hpp | 49 +++++ csrc/models/qwen3/qwen3_for_causal_lm.cpp | 23 ++ csrc/models/qwen3/qwen3_for_causal_lm.hpp | 23 ++ .../qwen3_moe/qwen3_moe_for_causal_lm.cpp | 24 +++ .../qwen3_moe/qwen3_moe_for_causal_lm.hpp | 18 ++ .../qwen3_moe/qwen3_moe_sparse_moe_block.cpp | 31 +++ .../qwen3_moe/qwen3_moe_sparse_moe_block.hpp | 22 ++ .../qwen3_next_allocate_kv_cache_tensors.cpp | 98 +++++++++ .../qwen3_next/qwen3_next_attention.cpp | 78 +++++++ .../qwen3_next/qwen3_next_attention.hpp | 42 ++++ .../qwen3_next/qwen3_next_decoderLayer.cpp | 73 +++++++ .../qwen3_next/qwen3_next_decoderLayer.hpp | 38 ++++ .../qwen3_next/qwen3_next_for_causal_lm.cpp | 73 +++++++ .../qwen3_next/qwen3_next_for_causal_lm.hpp | 31 +++ .../qwen3_next/qwen3_next_gated_deltanet.cpp | 59 ++++++ .../qwen3_next/qwen3_next_gated_deltanet.hpp | 46 ++++ .../qwen3_vl_for_conditional_generation.cpp | 83 ++++++++ 
.../qwen3_vl_for_conditional_generation.hpp | 36 ++++ csrc/pybind11/engine/engine.hpp | 12 +- examples/bench.py | 8 +- examples/jiuge.py | 5 +- python/infinilm/auto_config.py | 2 + xmake.lua | 10 + 85 files changed, 3687 insertions(+), 67 deletions(-) create mode 100644 csrc/config/config_factory.cpp create mode 100644 csrc/config/config_factory.hpp create mode 100644 csrc/global_state/forward_context.hpp create mode 100644 csrc/global_state/global_state.cpp create mode 100644 csrc/global_state/global_state.hpp create mode 100644 csrc/global_state/infinilm_config.hpp create mode 100644 csrc/global_state/parallel_state.hpp create mode 100644 csrc/layers/attention/attention.cpp create mode 100644 csrc/layers/attention/attention.hpp create mode 100644 csrc/layers/attention/backends/attention_layer.cpp create mode 100644 csrc/layers/attention/backends/attention_layer.hpp create mode 100644 csrc/layers/attention/backends/flash_attn.cpp create mode 100644 csrc/layers/attention/backends/flash_attn.hpp create mode 100644 csrc/layers/attention/backends/paged_attn.cpp create mode 100644 csrc/layers/attention/backends/paged_attn.hpp create mode 100644 csrc/layers/attention/backends/static_attn.cpp create mode 100644 csrc/layers/attention/backends/static_attn.hpp create mode 100644 csrc/layers/causal_lm_templates/text_causal_lm.hpp create mode 100644 csrc/layers/causal_lm_templates/text_decoder_layer.hpp create mode 100644 csrc/layers/causal_lm_templates/text_model.hpp create mode 100644 csrc/layers/common_modules.hpp rename csrc/layers/{ => linear}/fused_linear.cpp (99%) rename csrc/layers/{ => linear}/fused_linear.hpp (95%) create mode 100644 csrc/layers/linear/linear.hpp create mode 100644 csrc/layers/mlp/mlp.cpp create mode 100644 csrc/layers/mlp/mlp.hpp create mode 100644 csrc/layers/mlp/moe_mlp.cpp create mode 100644 csrc/layers/mlp/moe_mlp.hpp rename csrc/layers/{ => quantization}/kv_quant.cpp (100%) rename csrc/layers/{ => quantization}/kv_quant.hpp (100%) create mode 
100644 csrc/models/fm9g/fm9g_for_causal_lm.cpp create mode 100644 csrc/models/fm9g/fm9g_for_causal_lm.hpp create mode 100644 csrc/models/infinilm_model.cpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_allocate_kv_cache_tensors.cpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_attention.cpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_attention.hpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp create mode 100644 csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp create mode 100644 csrc/models/models_registry.cpp create mode 100644 csrc/models/models_registry.hpp create mode 100644 csrc/models/qwen3/qwen3_attention.cpp create mode 100644 csrc/models/qwen3/qwen3_attention.hpp create mode 100644 csrc/models/qwen3/qwen3_for_causal_lm.cpp create mode 100644 csrc/models/qwen3/qwen3_for_causal_lm.hpp create mode 100644 csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.cpp create mode 100644 csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.hpp create mode 100644 csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.cpp create mode 100644 csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.hpp create mode 100644 csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp create mode 100644 csrc/models/qwen3_next/qwen3_next_attention.cpp create mode 100644 csrc/models/qwen3_next/qwen3_next_attention.hpp create mode 100644 csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp create mode 100644 csrc/models/qwen3_next/qwen3_next_decoderLayer.hpp create mode 100644 csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp create mode 100644 csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp create mode 100644 csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp create mode 100644 csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp create mode 100644 
csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.cpp create mode 100644 csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.hpp diff --git a/csrc/backends/attention_backends.hpp b/csrc/backends/attention_backends.hpp index 5cf66305..b274aacc 100644 --- a/csrc/backends/attention_backends.hpp +++ b/csrc/backends/attention_backends.hpp @@ -1,25 +1,57 @@ #pragma once +#include #include #include namespace infinilm::backends { +/** + * @brief Enumeration of all supported attention backends. + */ enum class AttentionBackend { - Default, - FlashAttn, + STATIC_ATTN, + PAGED_ATTN, + FLASH_ATTN, + FLASHINFER, + Default = STATIC_ATTN }; +inline std::ostream &operator<<(std::ostream &os, AttentionBackend backend) { + switch (backend) { + case AttentionBackend::STATIC_ATTN: + return os << "AttentionBackend::STATIC_ATTN"; + case AttentionBackend::PAGED_ATTN: + return os << "AttentionBackend::PAGED_ATTN"; + case AttentionBackend::FLASH_ATTN: + return os << "AttentionBackend::FLASH_ATTN"; + case AttentionBackend::FLASHINFER: + return os << "AttentionBackend::FLASHINFER"; + default: + throw std::invalid_argument("infinilm::backends: invalid attention backend: " + std::to_string(static_cast(backend))); + break; + } +} + inline AttentionBackend parse_attention_backend(const std::string &backend) { if (backend == "default") { return AttentionBackend::Default; } + if (backend == "static-attn") { + return AttentionBackend::STATIC_ATTN; + } + if (backend == "paged-attn") { + return AttentionBackend::PAGED_ATTN; + } if (backend == "flash-attn") { - return AttentionBackend::FlashAttn; + return AttentionBackend::FLASH_ATTN; + } + if (backend == "flashinfer") { + return AttentionBackend::FLASHINFER; } throw std::invalid_argument( - "Invalid attention_backend: " + backend + ". Valid options are: default, flash-attn"); + "Invalid attention_backend: " + backend + ". 
Valid options are: static-attn, paged-attn, flash-attn, flashinfer"); } } // namespace infinilm::backends diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 5f0de647..ab5f213e 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -1,5 +1,6 @@ #include "kv_cache.hpp" +#include "../global_state/global_state.hpp" #include "../utils.hpp" #include "infinicore/ops.hpp" #include @@ -76,6 +77,44 @@ StaticKVCache::StaticKVCache( rank_info.device); } +std::tuple StaticKVCache::create_layer_kv_cache( + const infinicore::Size k_dim, + const infinicore::Size v_dim, + const infinicore::Size num_k_heads, + const infinicore::Size num_v_heads, + const infinicore::Size max_positional_embedding, + const infinicore::DataType dtype, + const StaticKVCacheConfig &config) { + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + + size_t rank_batch_size = (config.max_batch_size()); + size_t num_rank_k_heads = (num_k_heads / rank_info.tp_size); + size_t num_rank_v_heads = (num_v_heads / rank_info.tp_size); + + size_t cache_len = (config.max_cache_len() == std::numeric_limits::max() || config.max_cache_len() == 0 ? 
max_positional_embedding : config.max_cache_len()); + + // Allocate K cache + infinicore::Tensor k_caches = infinicore::Tensor::empty( + {rank_batch_size, + num_rank_k_heads, + cache_len, + k_dim}, + dtype, + rank_info.device); + + // Allocate V cache + infinicore::Tensor v_caches = infinicore::Tensor::empty( + {rank_batch_size, + num_rank_v_heads, + cache_len, + v_dim}, + dtype, + rank_info.device); + + return {k_caches, v_caches}; +} + std::tuple StaticKVCache::update(size_t layer_idx, const infinicore::Tensor &k, @@ -182,6 +221,43 @@ PagedKVCache::PagedKVCache( rank_info.device); } +std::tuple PagedKVCache::create_layer_kv_cache( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::DataType dtype, + const PagedKVCacheConfig &config) { + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + + size_t num_rank_k_heads(num_k_heads / rank_info.tp_size); + size_t num_rank_v_heads(num_v_heads / rank_info.tp_size); + + size_t num_blocks_per_layer = config.num_blocks(); + size_t block_size = config.block_size(); + + // [ num_blocks, num_rank_k_heads, block_size, k_dim] + infinicore::Tensor k_caches = infinicore::Tensor::empty( + {num_blocks_per_layer, + num_rank_k_heads, + block_size, + k_dim}, + dtype, + rank_info.device); + + // [ num_blocks, num_rank_v_heads, block_size, v_dim] + infinicore::Tensor v_caches = infinicore::Tensor::empty( + {num_blocks_per_layer, + num_rank_v_heads, + block_size, + v_dim}, + dtype, + rank_info.device); + + return {k_caches, v_caches}; +} + std::tuple PagedKVCache::update( size_t layer_idx, const infinicore::Tensor &k, diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index 164751e8..8fd4c5f6 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -45,6 +45,15 @@ class StaticKVCache final : public Cache { const StaticKVCacheConfig &config, const engine::distributed::RankInfo 
&rank_info); + static std::tuple create_layer_kv_cache( + const infinicore::Size k_dim, + const infinicore::Size v_dim, + const infinicore::Size num_k_heads, + const infinicore::Size num_v_heads, + const infinicore::Size max_positional_embedding, + const infinicore::DataType dtype, + const StaticKVCacheConfig &config); + /** * @brief Update KV cache at a given layer and cache position. * @@ -109,6 +118,14 @@ class PagedKVCache final : public Cache { const PagedKVCacheConfig &config, const engine::distributed::RankInfo &rank_info); + static std::tuple create_layer_kv_cache( + infinicore::Size k_dim, + infinicore::Size v_dim, + infinicore::Size num_k_heads, + infinicore::Size num_v_heads, + infinicore::DataType dtype, + const PagedKVCacheConfig &config); + /** * @brief Update Paged KV cache at a given layer given slot info for each token. * diff --git a/csrc/config/config_factory.cpp b/csrc/config/config_factory.cpp new file mode 100644 index 00000000..c822983e --- /dev/null +++ b/csrc/config/config_factory.cpp @@ -0,0 +1,29 @@ +#include "config_factory.hpp" +#include "../models/models_registry.hpp" +#include + +namespace infinilm::config { + +std::shared_ptr ConfigFactory::createConfig(const std::string &model_path) { + auto model_config = std::make_shared(model_path + "/config.json"); + if (nullptr == model_config) { + throw std::runtime_error("infinilm::config::ConfigFactory::createConfig: model_config is not initialized"); + } + + const std::string model_type = model_config->get("model_type"); + const auto &config_map = models::get_model_config_map(); + auto it = config_map.find(model_type); + if (it != config_map.end()) { + it->second(model_config); + } else { + std::vector classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"}; + const std::string &model_type = model_config->get("model_type"); + if (std::find(classic_models.begin(), classic_models.end(), model_type) == classic_models.end()) { + throw 
std::invalid_argument("infinilm::config::ConfigFactory::createConfig: Unsupported model config type: " + model_type); + } + } + + return model_config; +} + +} // namespace infinilm::config diff --git a/csrc/config/config_factory.hpp b/csrc/config/config_factory.hpp new file mode 100644 index 00000000..5b5231ce --- /dev/null +++ b/csrc/config/config_factory.hpp @@ -0,0 +1,14 @@ +#pragma once + +#include "model_config.hpp" +#include +#include + +namespace infinilm::config { + +class ConfigFactory { +public: + static std::shared_ptr createConfig(const std::string &model_path); +}; + +} // namespace infinilm::config diff --git a/csrc/config/model_config.cpp b/csrc/config/model_config.cpp index 5b3c6bd6..957c2ec7 100644 --- a/csrc/config/model_config.cpp +++ b/csrc/config/model_config.cpp @@ -1,6 +1,10 @@ #include "model_config.hpp" namespace infinilm::config { +ModelConfig::ModelConfig(const nlohmann::json &json) : config_json(json) { + this->quant_config = QuantConfig(config_json["quantization_config"]); +}; + ModelConfig::ModelConfig(const std::string &path) { std::ifstream file(path); if (file.is_open()) { @@ -32,11 +36,15 @@ ModelConfig::get_rope_scaling() const { throw std::runtime_error("rope_scaling must be an object"); } - if (!rope_scaling.contains("type")) { - throw std::runtime_error("rope_scaling must contain 'type' field"); + std::string type_str; + if (rope_scaling.contains("type")) { + type_str = rope_scaling["type"].get(); + } else if (rope_scaling.contains("rope_type")) { + type_str = rope_scaling["rope_type"].get(); + } else { + throw std::runtime_error("rope_scaling must contain 'type' or 'rope_type' field"); } - std::string type_str = rope_scaling["type"].get(); if (type_str == "longrope") { // Required fields for LongRopeConfig if (!rope_scaling.contains("short_factor") || !rope_scaling.contains("long_factor") || !rope_scaling.contains("original_max_position_embeddings")) { @@ -67,7 +75,21 @@ ModelConfig::get_rope_scaling() const { } 
infinicore::DataType ModelConfig::get_dtype() const { - std::string dtype_str = this->get("torch_dtype"); + std::string dtype_str; + if (config_json.contains("dtype")) { + dtype_str = this->get("dtype"); + } else if (config_json.contains("torch_dtype")) { + dtype_str = this->get("torch_dtype"); + } else { + throw std::runtime_error("ModelConfig::get_dtype(): No dtype or torch_dtype found in config"); + } + return parse_dtype(dtype_str); } + +std::ostream &operator<<(std::ostream &os, const ModelConfig &config) { + os << config.config_json.dump(4); + return os; +} + } // namespace infinilm::config diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index 97049d6a..fc982173 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -5,7 +5,9 @@ #include "infinicore/ops.hpp" #include "quant_config.hpp" #include +#include #include +#include namespace infinilm::config { class ModelConfig { @@ -14,10 +16,13 @@ class ModelConfig { // and passed through the InferEngine during inference. 
public: ModelConfig() = default; - // Not Implemented - // ModelConfig(const nlohmann::json &json) : config_json(json) {}; + ModelConfig(const nlohmann::json &json); ModelConfig(const std::string &path); + nlohmann::json &get_config_json() { + return config_json; + } + // Template Function to get a value by key with type safety template T get(const std::string &key) const { @@ -78,6 +83,9 @@ class ModelConfig { } } + // Stream output operator + friend std::ostream &operator<<(std::ostream &os, const ModelConfig &config); + private: nlohmann::json config_json; QuantConfig quant_config; diff --git a/csrc/engine/compiler/paged_compiler.cpp b/csrc/engine/compiler/paged_compiler.cpp index 8f47f749..daca9ed2 100644 --- a/csrc/engine/compiler/paged_compiler.cpp +++ b/csrc/engine/compiler/paged_compiler.cpp @@ -1,4 +1,5 @@ #include "paged_compiler.hpp" +#include "../../global_state/global_state.hpp" namespace { // Todo: replace with Tensor::zeros when it is available @@ -59,6 +60,17 @@ void PagedCompiler::compile() { input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); set_zeros(input.slot_mapping.value()); + // Attention reads attn_metadata from thread-local forward context. 
+ infinilm::global_state::get_forward_context().attn_metadata = { + input.position_ids, + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + barrier_->wait(); infinicore::context::startGraphRecording(); auto output = model_->forward(input); diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index 34873038..49aef1b1 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -1,6 +1,6 @@ #include "static_batching_compiler.hpp" - #include "../../cache/cache.hpp" +#include "../../global_state/global_state.hpp" namespace infinilm::engine { StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr &model, RankBarrier *barrier) @@ -18,6 +18,17 @@ void StaticBatchingCompiler::compile() { std::vector total_sequence_lengths_vec(b, 1); infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); + // Attention reads attn_metadata from thread-local forward context. 
+ infinilm::global_state::get_forward_context().attn_metadata = { + input.position_ids, + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + barrier_->wait(); infinicore::context::startGraphRecording(); auto output = model_->forward(input); diff --git a/csrc/engine/distributed/communication_group.cpp b/csrc/engine/distributed/communication_group.cpp index 92000a9f..782faa9e 100644 --- a/csrc/engine/distributed/communication_group.cpp +++ b/csrc/engine/distributed/communication_group.cpp @@ -6,10 +6,17 @@ namespace infinilm::engine::distributed { CommunicationGroup::CommunicationGroup(const DistConfig &dist_config, infinicore::Device::Type device_type) : dist_config_(dist_config), device_type_(device_type), communicators_(std::vector(dist_config.tp_device_ids.size(), nullptr)) { + + size_t world_size = dist_config_.tp_device_ids.size(); + size_t device_count = infinicore::context::getDeviceCount(device_type); + if (device_count < world_size) { + throw std::runtime_error("infinilm::engine::distributed::CommunicationGroup error, world size is larger than the number of available GPUs. 
world size: " + std::to_string(world_size) + ", device count: " + std::to_string(device_count)); + } + if (infinicore::context::getDevice().getType() != device_type_) { infinicore::context::setDevice(infinicore::Device(device_type_, 0)); } - if (dist_config_.tp_device_ids.size() > 1) { + if (world_size > 1) { RUN_INFINI(infinicclCommInitAll( (infiniDevice_t)infinicore::context::getDevice().getType(), communicators_.data(), diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 80aa1005..50dd8e92 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -1,4 +1,5 @@ #include "infer_engine.hpp" +#include "../config/config_factory.hpp" #include "spdlog/spdlog.h" namespace infinilm::engine { @@ -63,7 +64,9 @@ InferEngine::InferEngine( } // Load model config if model_path is provided, model_path must be valid, and config.json exists - this->model_config_ = std::make_shared(model_path + "/config.json"); + this->model_config_ = infinilm::config::ConfigFactory::createConfig(model_path); + auto infinilm_config = std::make_shared(attention_backend, this->model_config_); + // Only support offline int8 kv cache quantization in this version if (kv_cache_dtype.has_value()) { this->model_config_->set_kv_quant_scheme(kv_cache_dtype.value()); @@ -74,7 +77,7 @@ InferEngine::InferEngine( workers_.reserve(world_size); for (int r = 0; r < world_size; ++r) { workers_.emplace_back(std::make_unique( - model_config_, + infinilm_config, communication_group_.get_rank_info(r), cache_config_ != nullptr ? cache_config_.get() : nullptr, barrier_.get(), @@ -121,7 +124,7 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { return t.has_value() ? 
t.value()->to(device) : t; }; - return { + infinilm::InfinilmModel::Input input = { to_device(input_ids), // @todo: on device in the future to_device(position_ids), to_device(past_sequence_lengths), // @todo: on device in the future @@ -131,6 +134,17 @@ InferEngine::Input::to_model_input(infinicore::Device device) const { to_device(block_tables), to_device(slot_mapping), }; + + infinilm::global_state::get_forward_context().attn_metadata = { + input.position_ids, + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping, + }; + return input; } InferEngine::Output InferEngine::forward(const InferEngine::Input &input) { diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 5e8030a2..45728abd 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -1,6 +1,7 @@ #pragma once #include "../config/model_config.hpp" +#include "../global_state/global_state.hpp" #include "../models/infinilm_model.hpp" #include "../models/llama/llama_config.hpp" #include "distributed/distributed.hpp" diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index c9cdc97f..1542c1e0 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -1,9 +1,9 @@ #include "rank_worker.hpp" +#include "../global_state/global_state.hpp" #include "../models/model_factory.hpp" - +#include "../models/models_registry.hpp" #include "infinicore/ops.hpp" - #include #include #include @@ -51,13 +51,14 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config, } RankWorker::RankWorker( - std::shared_ptr model_config, + std::shared_ptr infinilm_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, bool enable_graph_compiling, backends::AttentionBackend attention_backend) - : model_config_(model_config), + : infinilm_config_(infinilm_config), + 
model_config_(infinilm_config->model_config), rank_info_(rank_info), attention_backend_(attention_backend), enable_graph_compiling_(enable_graph_compiling), @@ -236,20 +237,40 @@ void RankWorker::thread_loop() { // Initialize device & model outside of holding the main mutex to avoid blocking callers. infinicore::context::setDevice(rank_info_.device); + // Initialize global enviromnet. + infinilm::global_state::initialize_model_parallel(rank_info_); + infinilm::global_state::initialize_forward_context(forward_context_); + infinilm::global_state::initialize_infinilm_config(infinilm_config_); + // Create model using factory (may be expensive) if (model_config_ == nullptr) { - model_ = InfinilmModelFactory::createModel( - legacy_model_config_, - rank_info_, - pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, - attention_backend_); + // model_ = InfinilmModelFactory::createModel( + // legacy_model_config_, + // rank_info_, + // pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, + // attention_backend_); + throw std::runtime_error("RankWorker::thread_loop(): the way of creating models using LlamaConfig is no longer supported !!!"); + } - } else { + const std::string &model_type = model_config_->get("model_type"); + const auto &model_map = models::get_causal_lm_model_map(); + auto it = model_map.find(model_type); + if (it != model_map.end()) { model_ = InfinilmModelFactory::createModel( model_config_, - rank_info_, - pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr, - attention_backend_); + rank_info_.device, + pending_cache_config_ != nullptr ? pending_cache_config_.get() : nullptr); + } else { + std::vector classic_models = {"llama", "qwen2", "minicpm", "fm9g", "fm9g7b"}; + if ((std::find(classic_models.begin(), classic_models.end(), model_type) != classic_models.end())) { + model_ = InfinilmModelFactory::createModel( + model_config_, + rank_info_, + pending_cache_config_ != nullptr ? 
pending_cache_config_.get() : nullptr, + attention_backend_); + } else { + throw std::runtime_error("RankWorker::thread_loop(): Unsupported model config type: " + model_type); + } } if (!model_) { @@ -431,7 +452,6 @@ void RankWorker::thread_loop() { // Shouldn't reach here (no-op) } } // while - // Some clean up should be done before exiting the thread compiler_.reset(); } catch (const std::exception &e) { diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 719c1cb1..1a38bb98 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -3,6 +3,7 @@ #include "../backends/attention_backends.hpp" #include "../cache/cache.hpp" #include "../config/model_config.hpp" +#include "../global_state/global_state.hpp" #include "../models/model_factory.hpp" #include "compiler/general_compiler.hpp" #include "distributed/distributed.hpp" @@ -18,6 +19,8 @@ namespace infinilm::engine { +using ForwardContext = infinilm::global_state::ForwardContext; + class RankWorker { enum class Command { INIT, @@ -67,7 +70,7 @@ class RankWorker { bool enable_graph_compiling, backends::AttentionBackend attention_backend); - RankWorker(std::shared_ptr model_config, + RankWorker(std::shared_ptr infinilm_config, const distributed::RankInfo &rank_info, const cache::CacheConfig *cache_config, RankBarrier *barrier, @@ -107,8 +110,10 @@ class RankWorker { private: // Worker properties const InfinilmModel::Config &legacy_model_config_ = InfinilmModel::Config(); + std::shared_ptr infinilm_config_; std::shared_ptr model_config_; - distributed::RankInfo rank_info_; + engine::distributed::RankInfo rank_info_; + ForwardContext forward_context_; std::shared_ptr model_; std::shared_ptr cache_; diff --git a/csrc/global_state/forward_context.hpp b/csrc/global_state/forward_context.hpp new file mode 100644 index 00000000..6333a52a --- /dev/null +++ b/csrc/global_state/forward_context.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include "../models/infinilm_model.hpp" + +namespace 
infinilm::global_state { + +struct AttentionMetadata { + /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`. + std::optional position_ids; + /// Past Lengths of cached sequence for each request, of shape `[num_requests]`. + std::optional past_sequence_lengths; + /// ToTal Lengths for each request sequence, of shape `[num_requests]`. + std::optional total_sequence_lengths; + /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`. + std::optional input_offsets; + /// Cumulative total sequence lengths for each request, of shape `[num_requests + 1]`. + std::optional cu_seqlens; + /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache. + std::optional block_tables; + /// Slot ids for each token `[seq]`. Used for paged cache. + std::optional slot_mapping; + + AttentionMetadata() = default; + + AttentionMetadata(std::optional position_ids, + std::optional past_sequence_lengths, + std::optional total_sequence_lengths, + std::optional input_offsets, + std::optional cu_seqlens, + std::optional block_tables, + std::optional slot_mapping) : position_ids(position_ids), + past_sequence_lengths(past_sequence_lengths), + total_sequence_lengths(total_sequence_lengths), + input_offsets(input_offsets), + cu_seqlens(cu_seqlens), + block_tables(block_tables), + slot_mapping(slot_mapping) {} + + AttentionMetadata(const infinilm::InfinilmModel::Input &input) : AttentionMetadata(input.position_ids, + input.past_sequence_lengths, + input.total_sequence_lengths, + input.input_offsets, + input.cu_seqlens, + input.block_tables, + input.slot_mapping) {} +}; + +struct ForwardContext { + AttentionMetadata attn_metadata; + std::vector> kv_cache_vec; +}; + +void initialize_forward_context(ForwardContext &forward_context); + +ForwardContext &get_forward_context(); + +} // namespace infinilm::global_state diff --git a/csrc/global_state/global_state.cpp b/csrc/global_state/global_state.cpp new file mode 100644 index 
00000000..a40a20af --- /dev/null +++ b/csrc/global_state/global_state.cpp @@ -0,0 +1,69 @@ +#include "global_state.hpp" +#include "../utils.hpp" +#include +#include + +namespace infinilm::global_state { + +namespace { + +thread_local ForwardContext *_forward_context{nullptr}; + +thread_local InfinilmConfig *_infinilm_config{nullptr}; + +thread_local const engine::distributed::RankInfo *_TP{nullptr}; + +} // namespace +} // namespace infinilm::global_state + +namespace infinilm::global_state { + +void initialize_forward_context(ForwardContext &forward_context) { + ASSERT(nullptr == _forward_context && "Forward context is already initialized, cannot be initialized again."); + _forward_context = &forward_context; +} + +ForwardContext &get_forward_context() { + return *_forward_context; +} + +} // namespace infinilm::global_state + +namespace infinilm::global_state { + +void initialize_infinilm_config(const std::shared_ptr &config) { + ASSERT(nullptr == _infinilm_config); + ASSERT(nullptr != config); + _infinilm_config = config.get(); +} + +const InfinilmConfig &get_infinilm_config() { + ASSERT(nullptr != _infinilm_config && "Current Infinilm config is not set."); + return *_infinilm_config; +} + +} // namespace infinilm::global_state + +namespace infinilm::global_state { + +void initialize_model_parallel(const engine::distributed::RankInfo &rank_info) { + ASSERT(nullptr == _TP && "Tensor model parallel state is already initialized, cannot be initialized again."); + _TP = &rank_info; +} + +const size_t get_tensor_model_parallel_world_size() { + ASSERT(nullptr != _TP && "Tensor model parallel state is not initialized."); + return _TP->tp_size; +} + +const size_t get_tensor_model_parallel_rank() { + ASSERT(nullptr != _TP && "Tensor model parallel state is not initialized."); + return _TP->tp_rank; +} + +const engine::distributed::RankInfo &get_tensor_model_parallel_rank_info() { + ASSERT(nullptr != _TP && "Tensor model parallel state is not initialized."); + return *_TP; 
+} + +} // namespace infinilm::global_state diff --git a/csrc/global_state/global_state.hpp b/csrc/global_state/global_state.hpp new file mode 100644 index 00000000..19161c89 --- /dev/null +++ b/csrc/global_state/global_state.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "forward_context.hpp" +#include "infinilm_config.hpp" +#include "parallel_state.hpp" diff --git a/csrc/global_state/infinilm_config.hpp b/csrc/global_state/infinilm_config.hpp new file mode 100644 index 00000000..9b80706c --- /dev/null +++ b/csrc/global_state/infinilm_config.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "../backends/attention_backends.hpp" +#include "../config/model_config.hpp" +#include + +namespace infinilm::global_state { + +/** + * @brief Dataclass which contains all infinilm-related configuration. + * This simplifies passing around the distinct configurations in the codebase. + */ +struct InfinilmConfig { +public: + InfinilmConfig() = default; + InfinilmConfig(const infinilm::backends::AttentionBackend &backend, + const std::shared_ptr &model_config) + : attention_backend(backend), + model_config(model_config) {} + +public: + infinilm::backends::AttentionBackend attention_backend; + std::shared_ptr model_config; +}; + +/** + * @brief save the current Infinilm config in a global variable, + * so that all modules can access it, e.g. custom ops can access the Infinilm config to determine how to dispatch. 
+ */ +void initialize_infinilm_config(const std::shared_ptr &config); + +const InfinilmConfig &get_infinilm_config(); + +} // namespace infinilm::global_state diff --git a/csrc/global_state/parallel_state.hpp b/csrc/global_state/parallel_state.hpp new file mode 100644 index 00000000..f5f7fdc8 --- /dev/null +++ b/csrc/global_state/parallel_state.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "../engine/distributed/distributed.hpp" + +namespace infinilm::global_state { + +void initialize_model_parallel(const engine::distributed::RankInfo &rank_info); + +/** + * @brief get the world size of the tensor model parallel group. + */ +const size_t get_tensor_model_parallel_world_size(); + +/** + * @brief get the rank of the current process in the tensor model parallel group. + */ +const size_t get_tensor_model_parallel_rank(); + +/** + * @brief get the rank_info of the current process in the tensor model parallel group. + */ +const engine::distributed::RankInfo &get_tensor_model_parallel_rank_info(); + +} // namespace infinilm::global_state diff --git a/csrc/layers/attention/attention.cpp b/csrc/layers/attention/attention.cpp new file mode 100644 index 00000000..f291fb77 --- /dev/null +++ b/csrc/layers/attention/attention.cpp @@ -0,0 +1,193 @@ +#include "attention.hpp" +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" + +namespace infinilm::layers::attention { +using infinilm::global_state::AttentionMetadata; + +Attention::Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + + const auto &dtype{model_config->get_dtype()}; + num_attention_heads_ = model_config->get("num_attention_heads"); + num_key_value_heads_ = model_config->get("num_key_value_heads"); + hidden_size_ = model_config->get("hidden_size"); + head_dim_ = model_config->get("head_dim"); + + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + + bool use_bias = model_config->get_or("attention_bias", true); + 
bool use_output_bias = model_config->get_or("attention_output_bias", false); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size(); + + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: { + INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::AWQ_W4A16: { + INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + 
default: { + throw std::runtime_error("infinilm::layers::attention::Attention: unsupported quantization scheme"); + break; + } + } + + auto kv_quant_scheme = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme(); + switch (kv_quant_scheme) { + case (infinicore::quantization::KVQuantAlgo::NONE): { + break; + } + case (infinicore::quantization::KVQuantAlgo::INT8): { + INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + break; + } + default: { + throw std::runtime_error("infinilm::layers::attention: unsupported kv_quant_scheme"); + break; + } + } + + if ((num_key_value_heads_ < tp_size) || (0 != (num_key_value_heads_ % tp_size))) { + throw std::runtime_error("infinilm::layers::attention::Attention: num_key_value_heads must be divisible by tp_size"); + } + + size_t num_attention_heads_rank = num_attention_heads_ / tp_size; + size_t num_key_value_heads_rank = num_key_value_heads_ / tp_size; + attn_ = std::make_shared(num_attention_heads_rank, head_dim_, scaling, num_key_value_heads_rank, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + + num_attention_heads_ = num_attention_heads_rank; + num_key_value_heads_ = num_key_value_heads_rank; +} + +infinicore::Tensor Attention::forward(const infinicore::Tensor &hidden_states) const { + if (!rotary_emb_) { + throw std::runtime_error("infinilm::layers::attention::Attention: rotary_emb not configured"); + } + + auto &forward_context = infinilm::global_state::get_forward_context(); + AttentionMetadata &attn_metadata = forward_context.attn_metadata; + std::tuple &kv_cache = forward_context.kv_cache_vec[layer_idx_]; + + if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) { + return forward_static_(hidden_states, attn_metadata, kv_cache); + } + return forward_paged_(hidden_states, attn_metadata, kv_cache); 
+} + +infinicore::Tensor Attention::forward_static_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const { + auto position_ids = attn_metadata.position_ids.value(); + + // hidden_states shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + // 2. Reshape for multi-head attention + auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + + // 3. Prepare position_ids for RoPE + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids->contiguous(); + } else { + throw std::runtime_error("infinilm::layers::attention::Attention: Unexpected position_ids shape"); + } + + // 4. Apply RoPE to QK + auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3}); + rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); + + // 5. Attn Backend calculate + auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped, kv_cache, attn_metadata); + + // 7. 
Project output + auto output = o_proj_->forward(attn_output); + return output; +} + +infinicore::Tensor Attention::forward_paged_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const { + auto position_ids = attn_metadata.position_ids.value(); + + // hidden_states shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // Only support batchsize==1, all requests should be flattened along seqlen dimension + ASSERT_EQ(batch_size, 1); + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + // 2. Reshape for multi-head attention + auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); + + // 3. Prepare position_ids for RoPE + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids; + } else { + throw std::runtime_error("Unexpected position_ids shape"); + } + + // 4. Apply RoPE to QK + rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); + + // 5. Attn Backend calculate + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped, kv_cache, attn_metadata); + + // 6. 
Project output + auto output = o_proj_->forward(attn_output); + return output; +} +} // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/attention.hpp b/csrc/layers/attention/attention.hpp new file mode 100644 index 00000000..ba6c595e --- /dev/null +++ b/csrc/layers/attention/attention.hpp @@ -0,0 +1,59 @@ +#pragma once + +#include "../../backends/attention_backends.hpp" +#include "../../config/model_config.hpp" +#include "../../global_state/global_state.hpp" +#include "../linear/linear.hpp" +#include "backends/attention_layer.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/nn/rope.hpp" +#include "infinicore/tensor.hpp" +#include +#include + +namespace infinilm::layers::attention { +class Attention : public infinicore::nn::Module { +public: + Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + size_t layer_idx() const { return layer_idx_; } + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + void set_rotary_emb(const std::shared_ptr &rotary_emb) { rotary_emb_ = rotary_emb; } + +private: + infinicore::Tensor forward_static_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const; + + infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const; + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + 
INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + + std::shared_ptr rotary_emb_; + std::shared_ptr attn_; + ::infinilm::backends::AttentionBackend attention_backend_; + size_t layer_idx_; + size_t num_attention_heads_; + size_t num_key_value_heads_; + size_t hidden_size_; + size_t head_dim_; + + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); +}; +} // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/attention_layer.cpp b/csrc/layers/attention/backends/attention_layer.cpp new file mode 100644 index 00000000..c4206cbc --- /dev/null +++ b/csrc/layers/attention/backends/attention_layer.cpp @@ -0,0 +1,51 @@ +#include "attention_layer.hpp" + +namespace infinilm::layers::attention { + +AttentionLayer::AttentionLayer(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx, + infinicore::Tensor k_scale, + infinicore::Tensor v_scale, + ::infinilm::backends::AttentionBackend attn_backend) : k_scale_(k_scale), v_scale_(v_scale), attn_backend_(attn_backend) { + switch (attn_backend) { + case ::infinilm::backends::AttentionBackend::STATIC_ATTN: + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + break; + case ::infinilm::backends::AttentionBackend::PAGED_ATTN: + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + break; + case ::infinilm::backends::AttentionBackend::FLASH_ATTN: + attn_backend_impl_ = std::make_shared(num_heads, head_size, scale, num_kv_heads, layer_idx); + break; + default: + throw std::runtime_error("infinilm::layers::attention::AttentionLayer: unsupported attention backend"); + } +} + +infinicore::Tensor AttentionLayer::forward(infinicore::Tensor &query, + infinicore::Tensor &key, + infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + // switch 
(attn_backend_) { + // case ::infinilm::backends::AttentionBackend::STATIC_ATTN: + // return std::get>(attn_backend_impl_)->forward(*this, query, key, value, kv_cache, attn_metadata); + // case ::infinilm::backends::AttentionBackend::PAGED_ATTN: + // return std::get>(attn_backend_impl_)->forward(*this, query, key, value, kv_cache, attn_metadata); + // case ::infinilm::backends::AttentionBackend::FLASH_ATTN: + // return std::get>(attn_backend_impl_)->forward(*this, query, key, value, kv_cache, attn_metadata); + // default: + // throw std::runtime_error("infinilm::layers::attention::AttentionLayer::forward: unsupported attention backend"); + // } + + return std::visit( + [&](auto &impl_ptr) -> infinicore::Tensor { + return impl_ptr->forward(*this, query, key, value, kv_cache, attn_metadata); + }, + attn_backend_impl_); +} + +} // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/attention_layer.hpp b/csrc/layers/attention/backends/attention_layer.hpp new file mode 100644 index 00000000..50dd9ac5 --- /dev/null +++ b/csrc/layers/attention/backends/attention_layer.hpp @@ -0,0 +1,53 @@ +#pragma once + +#include "../../../backends/attention_backends.hpp" +#include "../../../global_state/global_state.hpp" +#include "flash_attn.hpp" +#include "infinicore/tensor.hpp" +#include "paged_attn.hpp" +#include "static_attn.hpp" +#include +#include +#include + +namespace infinilm::layers::attention { +using AttentionImpl = std::variant, std::shared_ptr, std::shared_ptr>; + +/** + * @brief Attention layer. + * This class takes query, key, and value tensors as input. + * The input tensors can either contain prompt tokens or generation tokens. + * + * The class does the following: + * - Update the KV cache. + * - Perform (multi-head/multi-query/grouped-query) attention. + * - Return the output tensor. 
+ */ +class AttentionLayer { +public: + AttentionLayer(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx, + infinicore::Tensor k_scale, + infinicore::Tensor v_scale, + ::infinilm::backends::AttentionBackend attention_backend); + + infinicore::Tensor forward(infinicore::Tensor &query, + infinicore::Tensor &key, + infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + inline infinicore::Tensor get_k_scale() const { return k_scale_; } + inline infinicore::Tensor get_v_scale() const { return v_scale_; } + +private: + infinicore::Tensor k_scale_; + infinicore::Tensor v_scale_; + + AttentionImpl attn_backend_impl_; + ::infinilm::backends::AttentionBackend attn_backend_; +}; +} // namespace infinilm::layers::attention diff --git a/csrc/layers/attention/backends/flash_attn.cpp b/csrc/layers/attention/backends/flash_attn.cpp new file mode 100644 index 00000000..f9f7dbe3 --- /dev/null +++ b/csrc/layers/attention/backends/flash_attn.cpp @@ -0,0 +1,104 @@ +#include "flash_attn.hpp" + +#include "../../../global_state/global_state.hpp" +#include "../../../utils.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/mha_kvcache.hpp" +#include "infinicore/ops/mha_varlen.hpp" + +namespace infinilm::layers::attention::backends { + +FlashAttentionImpl::FlashAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale), + num_kv_heads_(num_kv_heads), + layer_idx_(layer_idx), + head_dim_(head_size) { + + const infinilm::global_state::InfinilmConfig &infinilm_config = infinilm::global_state::get_infinilm_config(); + if (!infinilm_config.model_config) { + throw std::runtime_error("infinilm::layers::attention::backends::FlashAttentionImpl: model_config is null"); + } + max_position_embeddings_ = 
infinilm_config.model_config->get("max_position_embeddings"); +} + +infinicore::Tensor FlashAttentionImpl::forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + auto total_sequence_lengths = attn_metadata.total_sequence_lengths; + auto input_offsets = attn_metadata.input_offsets; + auto block_tables = attn_metadata.block_tables; + auto slot_mapping = attn_metadata.slot_mapping; + auto cu_seqlens = attn_metadata.cu_seqlens; + + ASSERT(block_tables.has_value()); + ASSERT(slot_mapping.has_value()); + + // 1. update paged kv cache + auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value()); + + size_t seq_len = query->shape()[0]; + bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); + + // 2. Compute attention + infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + if (is_prefill) { + infinicore::op::mha_varlen_( + attn_output, + query, + k_total->permute({0, 2, 1, 3}), + v_total->permute({0, 2, 1, 3}), + input_offsets.value(), + cu_seqlens.value(), + block_tables.value(), + max_position_embeddings_, + max_position_embeddings_, + std::nullopt, + scale_); + } else { + // FA2 decode path: flash::mha_fwd_kvcache + // In paged-attn mode, seq_len = actual batch_size (one query token per sequence). 
+ // q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim] + // k/v cache: [num_blocks, num_kv_heads, block_size, head_dim] + // → permute {0,2,1,3} → [num_blocks, block_size, num_kv_heads, head_dim] + auto q_for_fa = query->view({seq_len, 1, num_heads_, head_dim_}); + auto attn_out_4d = infinicore::op::mha_kvcache( + q_for_fa, + k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] + v_total->permute({0, 2, 1, 3}), + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + std::nullopt, + scale_); + attn_output = attn_out_4d->view({seq_len, num_heads_, head_dim_}); + } + attn_output = attn_output->view({1, seq_len, num_heads_ * head_dim_}); + return attn_output; +} + +std::tuple FlashAttentionImpl::do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor slot_mapping) const { + auto k_cache_layer = std::get<0>(kv_cache); + auto v_cache_layer = std::get<1>(kv_cache); + infinicore::op::paged_caching_( + k_cache_layer, + v_cache_layer, + key, + value, + slot_mapping); + + return {k_cache_layer, v_cache_layer}; +} + +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/flash_attn.hpp b/csrc/layers/attention/backends/flash_attn.hpp new file mode 100644 index 00000000..9368726a --- /dev/null +++ b/csrc/layers/attention/backends/flash_attn.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "../../../global_state/global_state.hpp" +#include "infinicore/tensor.hpp" +#include +#include + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::layers::attention::backends { + +class FlashAttentionImpl { +public: + FlashAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx); + + /** + * @brief Forward 
pass with FlashAttention. + * + * @param layer: The `AttentionLayer` instance. + * @param query: Query tensor, shape `[num_tokens, num_heads, head_dim]`. + * @param key: Key tensor, shape `[num_tokens, num_kv_heads, head_dim]`. + * @param value: Value tensor, shape `[num_tokens, num_kv_heads, head_dim]`. + * @param kv_cache: `(k_cache, v_cache)` paged KV tensors for this layer. + * @param attn_metadata: Attention metadata. + * @return Attention output, shape `[num_tokens, num_heads * head_dim]`. + */ + infinicore::Tensor forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + std::tuple do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor slot_mapping) const; + +private: + size_t num_heads_; + size_t head_size_; + float scale_; + size_t num_kv_heads_; + size_t layer_idx_; + size_t head_dim_; // Note: head_dim equals to head_size + size_t max_position_embeddings_; +}; +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.cpp b/csrc/layers/attention/backends/paged_attn.cpp new file mode 100644 index 00000000..a04bca0e --- /dev/null +++ b/csrc/layers/attention/backends/paged_attn.cpp @@ -0,0 +1,83 @@ +#include "paged_attn.hpp" + +#include "../../../utils.hpp" +#include "infinicore/ops.hpp" + +namespace infinilm::layers::attention::backends { + +PagedAttentionImpl::PagedAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale), + num_kv_heads_(num_kv_heads), + layer_idx_(layer_idx), + head_dim_(head_size) {} + +infinicore::Tensor PagedAttentionImpl::forward(const AttentionLayer &layer, + const 
infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + auto total_sequence_lengths = attn_metadata.total_sequence_lengths; + auto input_offsets = attn_metadata.input_offsets; + auto block_tables = attn_metadata.block_tables; + auto slot_mapping = attn_metadata.slot_mapping; + ASSERT(block_tables.has_value()); + ASSERT(slot_mapping.has_value()); + + // 1. update paged kv cache + auto [k_total, v_total] = do_kv_cache_update(layer, key, value, kv_cache, slot_mapping.value()); + + size_t seq_len = query->shape()[0]; + bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]); + + // 2. Compute attention + infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + if (is_prefill) { + infinicore::op::paged_attention_prefill_( + attn_output, + query, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + input_offsets.value(), + std::nullopt, + scale_); + } else { + infinicore::op::paged_attention_( + attn_output, + query, + k_total, + v_total, + block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scale_); + } + attn_output = attn_output->view({1, seq_len, num_heads_ * head_dim_}); + return attn_output; +} + +std::tuple PagedAttentionImpl::do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor slot_mapping) const { + auto k_cache_layer = std::get<0>(kv_cache); + auto v_cache_layer = std::get<1>(kv_cache); + infinicore::op::paged_caching_( + k_cache_layer, + v_cache_layer, + key, + value, + slot_mapping); + + return {k_cache_layer, v_cache_layer}; +} +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/paged_attn.hpp 
b/csrc/layers/attention/backends/paged_attn.hpp new file mode 100644 index 00000000..a15f73e8 --- /dev/null +++ b/csrc/layers/attention/backends/paged_attn.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "../../../global_state/global_state.hpp" +#include "infinicore/tensor.hpp" +#include +#include + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::layers::attention::backends { + +class PagedAttentionImpl { +public: + PagedAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx); + + /** + * @brief Forward pass with PagedAttention. + * + * @param layer: The `AttentionLayer` instance. + * @param query: Query tensor, shape `[num_tokens, num_heads, head_dim]`. + * @param key: Key tensor, shape `[num_tokens, num_kv_heads, head_dim]`. + * @param value: Value tensor, shape `[num_tokens, num_kv_heads, head_dim]`. + * @param kv_cache: `(k_cache, v_cache)` paged KV tensors for this layer. + * @param attn_metadata: Attention metadata. + * @return Attention output, shape `[num_tokens, num_heads * head_dim]`. 
+ */ + infinicore::Tensor forward(const AttentionLayer &layer, + const infinicore::Tensor &query, + const infinicore::Tensor &key, + const infinicore::Tensor &value, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + std::tuple do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor slot_mapping) const; + +private: + size_t num_heads_; + size_t head_size_; + float scale_; + size_t num_kv_heads_; + size_t layer_idx_; + size_t head_dim_; // Note: head_dim equals to head_size +}; +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp new file mode 100644 index 00000000..ceb7cc3c --- /dev/null +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -0,0 +1,133 @@ +#include "static_attn.hpp" +#include "../../../global_state/global_state.hpp" +#include "../../../utils.hpp" +#include "attention_layer.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/ops/per_tensor_dequant_i8.hpp" +#include "infinicore/ops/per_tensor_quant_i8.hpp" + +namespace infinilm::layers::attention::backends { + +StaticAttentionImpl::StaticAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx) + : num_heads_(num_heads), + head_size_(head_size), + scale_(scale), + num_kv_heads_(num_kv_heads), + layer_idx_(layer_idx), + head_dim_(head_size) { + kv_quant_scheme_ = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme(); +} + +infinicore::Tensor StaticAttentionImpl::forward(const AttentionLayer &layer, + infinicore::Tensor &q_rope, + infinicore::Tensor &k_reshaped, + infinicore::Tensor &v_reshaped, + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const { + + auto k_scale = layer.get_k_scale(); + auto v_scale = 
layer.get_v_scale(); + if (infinicore::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { + infinilm::KVQuantUtils::quantize( + k_reshaped, v_reshaped, + this->kv_quant_scheme_, + k_scale, + v_scale); + } + + auto q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim] + auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] + auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim] + + // Prepare Attn + auto shape = q_reshaped->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[2]; + + auto past_sequence_lengths = attn_metadata.past_sequence_lengths; + auto total_sequence_lengths = attn_metadata.total_sequence_lengths; + + // update static kv cache + // k_total: [bs, n_kv_head, max_seq_len, head_dim] + // v_total : [bs, n_kv_head, max_seq_len, head_dim] + auto [k_total, v_total] = do_kv_cache_update(layer, k_permuted, v_permuted, kv_cache, past_sequence_lengths.value()); + + infinicore::Tensor attn_output; + if (false) { + // experimental nineoothed flash attention + attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scale_, true); + attn_output = attn_output->permute({0, 2, 1, 3}) + ->contiguous() + ->view({batch_size, seq_len, num_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] + } else { + size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; + + if (infinicore::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { + infinilm::KVQuantUtils::dequantize( + k_total, v_total, + this->kv_quant_scheme_, + k_scale, + v_scale, + q_reshaped); + } + + k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] + v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] + + // Compute attention + size_t ngroup = num_heads_ / num_kv_heads_; + auto 
Q = q_reshaped->view({batch_size * num_kv_heads_, ngroup * seq_len, head_dim_}); + auto K = k_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_}); + auto V = v_total->view({batch_size * num_kv_heads_, total_seq_len, head_dim_}); + + auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len] + + auto attn_weight = infinicore::op::matmul(Q, K_transposed, scale_); // [bs * n_kv_head, ng * seq_len, total_seq_len] + + auto attn_weight_softmax = attn_weight->view({batch_size * num_heads_, seq_len, total_seq_len}); + infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax); + + auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim] + + attn_output = out->view({batch_size, num_heads_, seq_len, head_dim_}) + ->permute({0, 2, 1, 3}) + ->contiguous() + ->view({batch_size, seq_len, num_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] + } + return attn_output; +} + +std::tuple StaticAttentionImpl::do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor past_sequence_lengths) const { + auto batch_size = key->size(0); + auto update_len = key->size(2); + auto k_cache_layer = std::get<0>(kv_cache); + auto v_cache_layer = std::get<1>(kv_cache); + + size_t max_batch_size = k_cache_layer->size(0); + size_t max_seq_len = k_cache_layer->size(2); + auto device = k_cache_layer->device(); + + ASSERT_EQ(batch_size, max_batch_size); + + size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; + auto result_len = cache_pos + update_len; + ASSERT(result_len <= max_seq_len); + + auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); + auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); + + k_cache_update->copy_from(key); + v_cache_update->copy_from(value); + + return {k_cache_layer, 
v_cache_layer}; +} + +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/attention/backends/static_attn.hpp b/csrc/layers/attention/backends/static_attn.hpp new file mode 100644 index 00000000..c1de9f7e --- /dev/null +++ b/csrc/layers/attention/backends/static_attn.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "../../../global_state/global_state.hpp" +#include "../../../layers/quantization/kv_quant.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::layers::attention::backends { + +class StaticAttentionImpl { +public: + StaticAttentionImpl(size_t num_heads, + size_t head_size, + float scale, + size_t num_kv_heads, + size_t layer_idx); + + infinicore::Tensor forward(const AttentionLayer &layer, + infinicore::Tensor &q_reshaped, // query + infinicore::Tensor &k_permuted, // key + infinicore::Tensor &v_permuted, // value + std::tuple kv_cache, + const infinilm::global_state::AttentionMetadata &attn_metadata) const; + + std::tuple do_kv_cache_update(const AttentionLayer &layer, + const infinicore::Tensor key, + const infinicore::Tensor value, + std::tuple kv_cache, + const infinicore::Tensor past_sequence_lengths) const; + +private: + size_t num_heads_; + size_t head_size_; + float scale_; + size_t num_kv_heads_; + size_t layer_idx_; + size_t head_dim_; // Note: head_dim equals to head_size + + infinicore::quantization::KVQuantAlgo kv_quant_scheme_; +}; +} // namespace infinilm::layers::attention::backends diff --git a/csrc/layers/causal_lm_templates/text_causal_lm.hpp b/csrc/layers/causal_lm_templates/text_causal_lm.hpp new file mode 100644 index 00000000..f1589a63 --- /dev/null +++ b/csrc/layers/causal_lm_templates/text_causal_lm.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include "../../models/infinilm_model.hpp" +#include "../linear/linear.hpp" +#include "infinicore/device.hpp" + +namespace 
infinilm::layers::causal_lm_templates { + +/** + * @brief Text Causal Language Modeling class + * + * A generic template class for Causal Language Modeling. + * + * @tparam Model The base model type (e.g., Qwen3Model, Qwen3MoeModel) + * + * Usage example: + * @code + * using Qwen3CausalLM = TextCausalLM; + * @endcode + */ +template +class TextCausalLM : public InfinilmModel { +public: + /** + * @brief Construct TextCausalLM module + * + * @param model_config: Model configuration. + * @param device: Device to create tensors on + */ + TextCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + + size_t hidden_size = model_config->get("hidden_size"); + size_t vocab_size = model_config->get("vocab_size"); + const auto &dtype{model_config->get_dtype()}; + + model_ = this->register_module("model", model_config, device); + lm_head_ = this->register_module("lm_head", hidden_size, vocab_size, false, dtype, device); + } + + /** + * @brief Forward pass: compute language modeling logits + */ + Output forward(const Input &input) const override { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; + } + +protected: + INFINICORE_NN_MODULE(Model, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +} // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_decoder_layer.hpp b/csrc/layers/causal_lm_templates/text_decoder_layer.hpp new file mode 100644 index 00000000..0c409163 --- /dev/null +++ b/csrc/layers/causal_lm_templates/text_decoder_layer.hpp @@ -0,0 +1,77 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "infinicore/device.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/tensor.hpp" +#include +#include +namespace infinilm::layers::causal_lm_templates { + +/** + * @brief 
Generic Text decoder layer (transformer block) class. + * + * @tparam Attention: The attention module type (e.g., Qwen3Attention) + * @tparam MLP: The MLP module type (e.g., Qwen3MLP) + */ +template +class TextDecoderLayer : public infinicore::nn::Module { +public: + TextDecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx) { + + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + input_layernorm_ = this->register_module("input_layernorm", hidden_size, rms_norm_eps, dtype, device); + post_attention_layernorm_ = this->register_module("post_attention_layernorm", hidden_size, rms_norm_eps, dtype, device); + + self_attn_ = this->register_module("self_attn", model_config, layer_idx, device); + mlp_ = this->register_module("mlp", model_config, device); + } + + std::tuple forward(infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) { + input_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = self_attn_->forward(hidden_states); + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = mlp_->forward(hidden_states); + return std::make_tuple(hidden_states, residual); + } + + infinicore::Tensor forward(infinicore::Tensor &hidden_states) { + auto residual = hidden_states; + hidden_states = input_layernorm_->forward(hidden_states); + hidden_states = self_attn_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + + residual = hidden_states; + hidden_states = post_attention_layernorm_->forward(hidden_states); + hidden_states = mlp_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + return hidden_states; + } + + size_t layer_idx() const { return layer_idx_; } + + void set_rotary_emb(const std::shared_ptr &rotary_emb) { + if (self_attn_) { + 
self_attn_->set_rotary_emb(rotary_emb); + } + } + +protected: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(Attention, self_attn); + INFINICORE_NN_MODULE(MLP, mlp); + + size_t layer_idx_; +}; + +} // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/causal_lm_templates/text_model.hpp b/csrc/layers/causal_lm_templates/text_model.hpp new file mode 100644 index 00000000..15e79796 --- /dev/null +++ b/csrc/layers/causal_lm_templates/text_model.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "../../models/infinilm_model.hpp" +#include "infinicore/nn/embedding.hpp" +#include "infinicore/nn/rmsnorm.hpp" +#include "infinicore/tensor.hpp" +#include +#include + +namespace infinilm::layers::causal_lm_templates { + +/** + * @brief Text model architecture (without language modeling head). + * + * Generic transformer model consisting of: + * - Token embeddings + * - Multiple decoder layers + * - Final layer normalization + * - Rotary Position Embeddings + * + * @tparam DecoderLayer The decoder layer type (e.g., Qwen3DecoderLayer) + */ +template +class TextModel : public infinicore::nn::Module { +public: + TextModel(std::shared_ptr model_config, + const infinicore::Device &device) { + const auto &dtype{model_config->get_dtype()}; + size_t vocab_size = model_config->get("vocab_size"); + size_t hidden_size = model_config->get("hidden_size"); + size_t max_position_embeddings = model_config->get("max_position_embeddings"); + size_t num_hidden_layers = model_config->get("num_hidden_layers"); + double rope_theta = model_config->get("rope_theta"); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + embed_tokens_ = this->register_module("embed_tokens", vocab_size, hidden_size, std::nullopt, dtype, device); + + layers_.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; ++i) { + 
layers_.push_back(this->register_module("layers." + std::to_string(i), model_config, i, device)); + } + + norm_ = this->register_module("norm", hidden_size, rms_norm_eps, dtype, device); + + rotary_emb_ = this->register_module( + "rotary_emb", model_config->get_head_dim(), max_position_embeddings, + rope_theta, infinicore::nn::RoPE::Algo::GPT_NEOX, dtype, device, + model_config->get_rope_scaling()); + + for (auto &layer : layers_) { + if (layer) { + layer->set_rotary_emb(rotary_emb_); + } + } + } + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + // 1. Embed tokens: input_ids -> [batch, seq_len, hidden_size] + auto hidden_states = embed_tokens_->forward(input_ids); + + // 2. Process through all decoder layers + size_t num_layers = layers_.size(); + infinicore::Tensor residual; + for (size_t i = 0; i < num_layers; ++i) { + layers_.at(i)->forward( + hidden_states, + residual); + } + + norm_->forward_inplace(hidden_states, residual); + return hidden_states; + } + + infinicore::Tensor forward_naive(const infinilm::InfinilmModel::Input &input) const { + auto input_ids = input.input_ids.value(); + auto hidden_states = embed_tokens_->forward(input_ids); + + size_t num_layers = layers_.size(); + for (size_t i = 0; i < num_layers; ++i) { + hidden_states = layers_.at(i)->forward(hidden_states); + } + + hidden_states = norm_->forward(hidden_states); + return hidden_states; + } + +protected: + INFINICORE_NN_MODULE(infinicore::nn::Embedding, embed_tokens); + INFINICORE_NN_MODULE_VEC(DecoderLayer, layers); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, norm); + INFINICORE_NN_MODULE(infinicore::nn::RoPE, rotary_emb); +}; + +} // namespace infinilm::layers::causal_lm_templates diff --git a/csrc/layers/common_modules.hpp b/csrc/layers/common_modules.hpp new file mode 100644 index 00000000..81c07de6 --- /dev/null +++ b/csrc/layers/common_modules.hpp @@ -0,0 +1,21 @@ +#pragma once +#include "mlp/mlp.hpp" 
+#include "mlp/moe_mlp.hpp" + +#include "attention/attention.hpp" +#include "causal_lm_templates/text_causal_lm.hpp" +#include "causal_lm_templates/text_decoder_layer.hpp" +#include "causal_lm_templates/text_model.hpp" +#include "linear/linear.hpp" + +namespace infinilm::layers { + +using MLP = infinilm::layers::mlp::MLP; +using MoeMLP = infinilm::layers::moe_mlp::MoeMLP; + +namespace attention { + +using AttentionLayer = infinilm::layers::attention::AttentionLayer; + +} // namespace attention +} // namespace infinilm::layers diff --git a/csrc/layers/fused_linear.cpp b/csrc/layers/linear/fused_linear.cpp similarity index 99% rename from csrc/layers/fused_linear.cpp rename to csrc/layers/linear/fused_linear.cpp index 2ff5ffbb..2291a2b8 100644 --- a/csrc/layers/fused_linear.cpp +++ b/csrc/layers/linear/fused_linear.cpp @@ -2,7 +2,7 @@ #include -namespace infinilm::layers { +namespace infinilm::layers::linear { // --------------------------------------------------------- // QKV Parallel Linear // --------------------------------------------------------- @@ -397,4 +397,4 @@ infinicore::nn::Parameter GateUpParallelLinear::get_up_weight_zeros_awq() const return infinicore::nn::Parameter(weight_zeros_->narrow({{1, weight_zeros_->size(1) / 2, weight_zeros_->size(1) / 2}}), 1, tp_rank_, tp_size_); } -} // namespace infinilm::layers +} // namespace infinilm::layers::linear diff --git a/csrc/layers/fused_linear.hpp b/csrc/layers/linear/fused_linear.hpp similarity index 95% rename from csrc/layers/fused_linear.hpp rename to csrc/layers/linear/fused_linear.hpp index 2e1217b6..d5c9adb8 100644 --- a/csrc/layers/fused_linear.hpp +++ b/csrc/layers/linear/fused_linear.hpp @@ -2,9 +2,9 @@ #include "infinicore/nn/linear.hpp" #include "infinicore/quantization.hpp" -#include "../engine/distributed/communication_group.hpp" +#include "../../engine/distributed/communication_group.hpp" -namespace infinilm::layers { +namespace infinilm::layers::linear { class QKVParallelLinear : public 
infinicore::nn::ColumnParallelLinear { public: explicit QKVParallelLinear(size_t hidden_size, @@ -169,7 +169,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { }; #define INFINILM_QKV_LINEAR_INIT(name, q_name, k_name, v_name, ...) \ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \ this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \ this->register_parameter(std::string(v_name) + ".weight", name##_->get_v_weight()); \ @@ -181,7 +181,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias()); #define INFINILM_GATE_UP_LINEAR_INIT(name, gate_name, up_name, ...) \ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \ this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \ if (name##_->has_gate_bias()) \ @@ -191,7 +191,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { // ========================= QKV Quantization ================================== #define INFINILM_QKV_LINEAR_W8A8_INIT(name, q_name, k_name, v_name, ...) 
\ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ this->register_parameter(std::string(q_name) + ".weight", name##_->get_q_weight()); \ this->register_parameter(std::string(q_name) + ".weight_scale", name##_->get_q_weight_scale()); \ this->register_parameter(std::string(k_name) + ".weight", name##_->get_k_weight()); \ @@ -206,7 +206,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { this->register_parameter(std::string(v_name) + ".bias", name##_->get_v_bias()); #define INFINILM_QKV_LINEAR_W4A16AWQ_INIT(name, q_name, k_name, v_name, ...) \ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ auto awq_ptr = std::static_pointer_cast(name##_->get_quantization()); \ int packing_num = awq_ptr->get_packing_num(); \ this->register_parameter(std::string(q_name) + ".qweight", name##_->get_q_weight_awq(packing_num)); \ @@ -227,7 +227,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { // ========================= Gate-Up Quantization ============================== #define INFINILM_GATE_UP_LINEAR_W8A8_INIT(name, gate_name, up_name, ...) \ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ this->register_parameter(std::string(gate_name) + ".weight", name##_->get_gate_weight()); \ this->register_parameter(std::string(gate_name) + ".weight_scale", name##_->get_gate_weight_scale()); \ this->register_parameter(std::string(up_name) + ".weight", name##_->get_up_weight()); \ @@ -238,7 +238,7 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); #define INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(name, gate_name, up_name, ...) 
\ - name##_ = std::make_shared(__VA_ARGS__); \ + name##_ = std::make_shared(__VA_ARGS__); \ this->register_parameter(std::string(gate_name) + ".qweight", name##_->get_gate_weight_awq()); \ this->register_parameter(std::string(gate_name) + ".qzeros", name##_->get_gate_weight_zeros_awq()); \ this->register_parameter(std::string(gate_name) + ".scales", name##_->get_gate_weight_scale_awq()); \ @@ -249,4 +249,4 @@ class GateUpParallelLinear : public infinicore::nn::ColumnParallelLinear { this->register_parameter(std::string(gate_name) + ".bias", name##_->get_gate_bias()); \ if (name##_->has_up_bias()) \ this->register_parameter(std::string(up_name) + ".bias", name##_->get_up_bias()); -} // namespace infinilm::layers +} // namespace infinilm::layers::linear diff --git a/csrc/layers/linear/linear.hpp b/csrc/layers/linear/linear.hpp new file mode 100644 index 00000000..4cfab257 --- /dev/null +++ b/csrc/layers/linear/linear.hpp @@ -0,0 +1,12 @@ +#pragma once +#include "fused_linear.hpp" + +namespace infinilm::layers::linear { + +using QKVParallelLinear = infinilm::layers::linear::QKVParallelLinear; +using ReplicatedLinear = infinicore::nn::Linear; +using ColumnParallelLinear = infinicore::nn::ColumnParallelLinear; +using RowParallelLinear = infinicore::nn::RowParallelLinear; +using GateUpParallelLinear = infinilm::layers::linear::GateUpParallelLinear; + +} // namespace infinilm::layers::linear diff --git a/csrc/layers/mlp/mlp.cpp b/csrc/layers/mlp/mlp.cpp new file mode 100644 index 00000000..d2eed7ad --- /dev/null +++ b/csrc/layers/mlp/mlp.cpp @@ -0,0 +1,60 @@ +#include "mlp.hpp" +#include "../../global_state/global_state.hpp" +#include "infinicore/ops.hpp" + +namespace infinilm::layers::mlp { + +MLP::MLP(std::shared_ptr model_config, + const infinicore::Device &device) { + + const auto &dtype{model_config->get_dtype()}; + hidden_size_ = model_config->get("hidden_size"); + intermediate_size_ = model_config->get("intermediate_size"); + use_bias_ = 
model_config->get_or("mlp_bias", false); + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = rank_info.tp_rank; + int tp_size = rank_info.tp_size; + + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: { + INFINILM_GATE_UP_LINEAR_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, quantization_method, + use_bias_, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINILM_GATE_UP_LINEAR_W8A8_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, quantization_method, + use_bias_, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::AWQ_W4A16: { + INFINILM_GATE_UP_LINEAR_W4A16AWQ_INIT(gate_up_proj, "gate_proj", "up_proj", hidden_size_, intermediate_size_, quantization_method, + use_bias_, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(down_proj, intermediate_size_, hidden_size_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + default: { + throw std::runtime_error("infinilm::layers::mlp::MLP: unsupported quantization scheme"); + break; + } + } +} + +infinicore::Tensor MLP::forward(const infinicore::Tensor &hidden_states) const { + // 1. Project to gate and up + auto hidden_states_mutable = hidden_states; + auto [gate, up] = gate_up_proj_->forward_split(hidden_states_mutable); + // 2. 
Apply SwiGLU: silu(gate) * up + auto intermediate = infinicore::op::swiglu(up, gate); + // 3. Project down + auto output = down_proj_->forward(intermediate); + return output; +} +} // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/mlp.hpp b/csrc/layers/mlp/mlp.hpp new file mode 100644 index 00000000..d0ba5c09 --- /dev/null +++ b/csrc/layers/mlp/mlp.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include "../../config/model_config.hpp" +#include "../linear/linear.hpp" +#include "infinicore/nn/module.hpp" + +namespace infinilm::layers::mlp { + +/** + * @brief MLP (Feed-Forward Network) module. + * + * Implements the MLP block with: + * - Gate projection + * - Up projection + * - Down projection + * - SiLU activation function + * + * Formula: down_proj(SiLU(gate_proj(x)) * up_proj(x)) + */ +class MLP : public infinicore::nn::Module { +public: + /** + * @brief Construct MLP module + * + * @param model_config: Model configuration. + * @param device Device to create tensors on + */ + MLP(std::shared_ptr model_config, + const infinicore::Device &device); + + /** + * @brief Forward pass: compute MLP output + * + * @param hidden_states Input tensor of shape [batch, seq_len, hidden_size] + * @return Output tensor of shape [batch, seq_len, hidden_size] + */ + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + // Module information + size_t hidden_size() const { return hidden_size_; } + size_t intermediate_size() const { return intermediate_size_; } + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::GateUpParallelLinear, gate_up_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, down_proj); + + size_t hidden_size_; + size_t intermediate_size_; + bool use_bias_; +}; + +} // namespace infinilm::layers::mlp diff --git a/csrc/layers/mlp/moe_mlp.cpp b/csrc/layers/mlp/moe_mlp.cpp new file mode 100644 index 00000000..2d338404 --- /dev/null +++ b/csrc/layers/mlp/moe_mlp.cpp @@ -0,0 +1,46 @@ +#include "moe_mlp.hpp" 
+#include "../../global_state/global_state.hpp" +#include "infinicore/ops.hpp" + +namespace infinilm::layers::moe_mlp { + +MoeMLP::MoeMLP(std::shared_ptr model_config, + const infinicore::Device &device) { + + const auto &dtype{model_config->get_dtype()}; + hidden_size_ = model_config->get("hidden_size"); + moe_intermediate_size_ = model_config->get("moe_intermediate_size"); + use_bias_ = model_config->get_or("mlp_bias", false); + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = rank_info.tp_rank; + int tp_size = rank_info.tp_size; + + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: { + INFINICORE_NN_MODULE_INIT(gate_proj, hidden_size_, moe_intermediate_size_, false, + dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(up_proj, hidden_size_, moe_intermediate_size_, false, + dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(down_proj, moe_intermediate_size_, hidden_size_, false, + dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + default: { + throw std::runtime_error("infinilm::layers::moe_mlp::MoeMLP: unsupported quantization scheme"); + break; + } + } +} + +infinicore::Tensor MoeMLP::forward(const infinicore::Tensor &hidden_states) const { + auto hidden_states_mutable = hidden_states; + auto gate = gate_proj_->forward(hidden_states_mutable); + auto up = up_proj_->forward(hidden_states_mutable); + auto intermediate = infinicore::op::swiglu(up, gate); + auto output = down_proj_->forward(intermediate); + return output; +} +} // namespace infinilm::layers::moe_mlp diff --git a/csrc/layers/mlp/moe_mlp.hpp b/csrc/layers/mlp/moe_mlp.hpp new file mode 100644 index 00000000..17cbfdb8 --- /dev/null +++ b/csrc/layers/mlp/moe_mlp.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include "../../config/model_config.hpp" 
+#include "../linear/linear.hpp" +#include "infinicore/nn/module.hpp" + +namespace infinilm::layers::moe_mlp { + +class MoeMLP : public infinicore::nn::Module { +public: + MoeMLP(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + size_t hidden_size() const { return hidden_size_; } + size_t moe_intermediate_size() const { return moe_intermediate_size_; } + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, gate_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, up_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, down_proj); + + size_t hidden_size_; + size_t moe_intermediate_size_; + bool use_bias_; +}; + +} // namespace infinilm::layers::moe_mlp diff --git a/csrc/layers/kv_quant.cpp b/csrc/layers/quantization/kv_quant.cpp similarity index 100% rename from csrc/layers/kv_quant.cpp rename to csrc/layers/quantization/kv_quant.cpp diff --git a/csrc/layers/kv_quant.hpp b/csrc/layers/quantization/kv_quant.hpp similarity index 100% rename from csrc/layers/kv_quant.hpp rename to csrc/layers/quantization/kv_quant.hpp diff --git a/csrc/models/fm9g/fm9g_for_causal_lm.cpp b/csrc/models/fm9g/fm9g_for_causal_lm.cpp new file mode 100644 index 00000000..f6b25bce --- /dev/null +++ b/csrc/models/fm9g/fm9g_for_causal_lm.cpp @@ -0,0 +1,50 @@ +#include "fm9g_for_causal_lm.hpp" +#include "../models_registry.hpp" + +namespace infinilm::models::fm9g { + +std::shared_ptr create_fm9g_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + + nlohmann::json &config_json = model_config->get_config_json(); + if (!config_json.contains("head_dim")) { + size_t head_dim = model_config->get("hidden_size") / model_config->get("num_attention_heads"); + config_json["head_dim"] = head_dim; + } + return model_config; +} + +} // namespace infinilm::models::fm9g + 
+namespace { + +#ifndef USE_CLASSIC_LLAMA + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + llama, + infinilm::models::fm9g::FM9GForCausalLM, + infinilm::models::fm9g::create_fm9g_model_config); + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen2, + infinilm::models::fm9g::FM9GForCausalLM, + infinilm::models::fm9g::create_fm9g_model_config); + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + fm9g, + infinilm::models::fm9g::FM9GForCausalLM, + infinilm::models::fm9g::create_fm9g_model_config); + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + fm9g7b, + infinilm::models::fm9g::FM9GForCausalLM, + infinilm::models::fm9g::create_fm9g_model_config); + +INFINILM_REGISTER_CAUSAL_LM_MODEL( + minicpm, + infinilm::models::fm9g::FM9GForCausalLM, + infinilm::models::fm9g::create_fm9g_model_config); + +#endif + +} // namespace diff --git a/csrc/models/fm9g/fm9g_for_causal_lm.hpp b/csrc/models/fm9g/fm9g_for_causal_lm.hpp new file mode 100644 index 00000000..d60cfa8b --- /dev/null +++ b/csrc/models/fm9g/fm9g_for_causal_lm.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include "../../layers/common_modules.hpp" +#include + +namespace infinilm::models::fm9g { + +using FM9GMLP = infinilm::layers::MLP; + +using FM9GAttention = infinilm::layers::attention::Attention; + +using FM9GDecoderLayer = infinilm::layers::causal_lm_templates::TextDecoderLayer; + +using FM9GModel = infinilm::layers::causal_lm_templates::TextModel; + +using FM9GForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM; + +} // namespace infinilm::models::fm9g + +namespace infinilm::models::fm9g { + +std::shared_ptr create_fm9g_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::fm9g diff --git a/csrc/models/infinilm_model.cpp b/csrc/models/infinilm_model.cpp new file mode 100644 index 00000000..02d2a66f --- /dev/null +++ b/csrc/models/infinilm_model.cpp @@ -0,0 +1,94 @@ +#include "infinilm_model.hpp" +#include "../backends/attention_backends.hpp" +#include "../cache/kv_cache.hpp" +#include "../global_state/global_state.hpp" 
+ +#include + +namespace infinilm { + +void InfinilmModel::reset_cache(const cache::CacheConfig *cache_config) { + if (cache_config == nullptr) { + cache_config_.reset(); + global_state::get_forward_context().kv_cache_vec.clear(); + return; + } + cache_config_ = cache_config->unique_copy(); + auto &kv_cache_vec = global_state::get_forward_context().kv_cache_vec; + kv_cache_vec.clear(); + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; + kv_cache_vec = std::move(default_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend)); +} + +std::vector> InfinilmModel::default_allocate_kv_cache_tensors( + const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend) { + if (nullptr == cache_config) { + return {}; + } + if (nullptr == text_config) { + throw std::runtime_error("infinilm::InfinilmModel::default_allocate_kv_cache_tensors: text_config is null"); + } + + std::vector> kv_cache_vec; + switch (attention_backend) { + case backends::AttentionBackend::STATIC_ATTN: { + auto static_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == static_kv_cache_config) { + throw std::runtime_error("infinilm::InfinilmModel::default_allocate_kv_cache_tensors: invalid static kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + size_t head_dim = text_config->get("head_dim"); + size_t num_key_value_heads = text_config->get("num_key_value_heads"); + size_t max_position_embeddings = text_config->get("max_position_embeddings"); + const auto &dtype = model_config_->get_kv_cache_dtype(); + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + auto kv_cache = cache::StaticKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + max_position_embeddings, + dtype, + 
*static_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } + break; + } + case backends::AttentionBackend::FLASH_ATTN: { + ; + } + case backends::AttentionBackend::PAGED_ATTN: { + auto paged_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == paged_kv_cache_config) { + throw std::runtime_error( + "infinilm::InfinilmModel::default_allocate_kv_cache_tensors: invalid paged kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + size_t head_dim = text_config->get("head_dim"); + size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const auto &dtype = model_config_->get_kv_cache_dtype(); + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + dtype, + *paged_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } + break; + } + default: + throw std::runtime_error("infinilm::InfinilmModel::default_allocate_kv_cache_tensors: Unsupported attention backend: " + std::to_string(static_cast(attention_backend))); + } + return kv_cache_vec; +} + +} // namespace infinilm diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index 550bf1aa..b8a38032 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -1,12 +1,14 @@ #pragma once +#include "../backends/attention_backends.hpp" #include "../cache/cache.hpp" +#include "../config/model_config.hpp" #include "infinicore/nn/module.hpp" -#include "nlohmann/json.hpp" - -#include +#include "infinicore/tensor.hpp" #include +#include +#include namespace infinilm { class InfinilmModel : public infinicore::nn::Module { @@ -42,8 +44,18 @@ class InfinilmModel : public infinicore::nn::Module { virtual ~InfinilmModel() = default; virtual Output forward(const Input &input) const = 0; - - virtual void reset_cache(const 
cache::CacheConfig *cache_config) = 0; - virtual const cache::CacheConfig *get_cache_config() const = 0; + virtual void reset_cache(const cache::CacheConfig *cache_config); + virtual const cache::CacheConfig *get_cache_config() const { + return cache_config_.get(); + } + +protected: + std::vector> default_allocate_kv_cache_tensors( + const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend); + + std::unique_ptr cache_config_; + std::shared_ptr model_config_; }; } // namespace infinilm diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index b4ce1f4d..c7895adb 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -334,7 +334,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd infinicore::Tensor attn_output = infinicore::Tensor::empty({seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device()); if (is_prefill) { - if (attention_backend_ == backends::AttentionBackend::FlashAttn) { + if (attention_backend_ == backends::AttentionBackend::FLASH_ATTN) { infinicore::op::mha_varlen_( attn_output, q_reshaped, @@ -360,7 +360,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd scaling_); } } else { - if (attention_backend_ == backends::AttentionBackend::FlashAttn) { + if (attention_backend_ == backends::AttentionBackend::FLASH_ATTN) { // FA2 decode path: flash::mha_fwd_kvcache // In paged-attn mode, seq_len = actual batch_size (one query token per sequence). 
// q_reshaped: [seq_len, num_heads, head_dim] → [seq_len, 1, num_heads, head_dim] diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 39493e76..3c128783 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -4,8 +4,8 @@ #include "../../cache/kv_cache.hpp" #include "../../config/model_config.hpp" #include "../../engine/distributed/distributed.hpp" -#include "../../layers/fused_linear.hpp" -#include "../../layers/kv_quant.hpp" +#include "../../layers/linear/fused_linear.hpp" +#include "../../layers/quantization/kv_quant.hpp" #include "llama_config.hpp" #include "infinicore/nn/linear.hpp" @@ -115,7 +115,7 @@ class LlamaAttention : public infinicore::nn::Module { protected: // Projection layers - INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj); INFINICORE_NN_MODULE(infinicore::nn::RowParallelLinear, o_proj); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index 9596e668..4cd0b6da 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -19,6 +19,7 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, const infinicore::Device &device, engine::distributed::RankInfo rank_info, backends::AttentionBackend attention_backend) { + spdlog::warn("infinilm::models::llama: LlamaForCausalLM is no longer supported, please use the new model instead."); // Initialize module's device_ member device_ = device; @@ -37,6 +38,7 @@ LlamaForCausalLM::LlamaForCausalLM(std::shared_ptr +#include +#include + +namespace infinilm::models::minicpm_sala { + +std::vector> minicpm_sala_allocate_kv_cache_tensors(const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const 
backends::AttentionBackend &attention_backend) { + if (nullptr == cache_config) { + return {}; + } + if (nullptr == text_config) { + throw std::runtime_error("infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: text_config is null"); + } + + std::vector> kv_cache_vec; + + switch (attention_backend) { + case backends::AttentionBackend::STATIC_ATTN: { + auto static_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == static_kv_cache_config) { + throw std::runtime_error("infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: invalid static kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const size_t max_position_embeddings = text_config->get("max_position_embeddings"); + + const auto &dtype{text_config->get_dtype()}; + std::vector mixer_types = text_config->get>("mixer_types"); + size_t current_layer_head_dim, current_layer_num_key_value_heads; + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + std::string mixer_type = mixer_types[layer_idx]; + if ("minicpm4" == mixer_type) { + current_layer_head_dim = head_dim; + current_layer_num_key_value_heads = num_key_value_heads; + } else if ("lightning" == mixer_type || "lightning_attn" == mixer_type || "lightning-attn" == mixer_type) { + current_layer_head_dim = text_config->get("lightning_head_dim"); + current_layer_num_key_value_heads = text_config->get("lightning_nkv"); + } else { + throw std::runtime_error("infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: unsupported mixer_type '" + mixer_type + "' for layer " + std::to_string(layer_idx)); + } + auto kv_cache = cache::StaticKVCache::create_layer_kv_cache( + current_layer_head_dim, + current_layer_head_dim, + current_layer_num_key_value_heads, + 
current_layer_num_key_value_heads, + max_position_embeddings, + dtype, + *static_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } + break; + } + case backends::AttentionBackend::PAGED_ATTN: { + auto paged_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == paged_kv_cache_config) { + throw std::runtime_error( + "infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: invalid paged kv cache config type"); + } + + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const auto &dtype{text_config->get_dtype()}; + std::vector mixer_types = text_config->get>("mixer_types"); + size_t current_layer_head_dim, current_layer_num_key_value_heads; + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + std::string mixer_type = mixer_types[layer_idx]; + if ("minicpm4" == mixer_type) { + current_layer_head_dim = head_dim; + current_layer_num_key_value_heads = num_key_value_heads; + } else if ("lightning" == mixer_type || "lightning_attn" == mixer_type || "lightning-attn" == mixer_type) { + current_layer_head_dim = text_config->get("lightning_head_dim"); + current_layer_num_key_value_heads = text_config->get("lightning_nkv"); + } else { + throw std::runtime_error("infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: unsupported mixer_type '" + mixer_type + "' for layer " + std::to_string(layer_idx)); + } + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + current_layer_head_dim, + current_layer_head_dim, + current_layer_num_key_value_heads, + current_layer_num_key_value_heads, + dtype, + *paged_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } + break; + } + default: + throw std::runtime_error("infinilm::models::minicpm_sala::minicpm_sala_allocate_kv_cache_tensors: Unsupported attention 
backend: " + std::to_string(static_cast(attention_backend))); + } + return kv_cache_vec; +} + +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.cpp b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp new file mode 100644 index 00000000..ac4ac69a --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_attention.cpp @@ -0,0 +1,116 @@ +#include "minicpm_sala_attention.hpp" +#include "../../global_state/global_state.hpp" +#include + +namespace infinilm::models::minicpm_sala { + +AttentionBase::AttentionBase(std::shared_ptr model_config, + size_t num_attention_heads, + size_t num_key_value_heads, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx), + num_attention_heads_(num_attention_heads), + num_key_value_heads_(num_key_value_heads), + hidden_size_(model_config->get("hidden_size")), + head_dim_(model_config->get("head_dim")) { + + const auto &dtype{model_config->get_dtype()}; + + use_bias_ = model_config->get_or("attention_bias", true); + use_output_bias_ = model_config->get_or("attention_output_bias", false); + double rms_norm_eps = model_config->get("rms_norm_eps"); + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size(); + + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: + INFINICORE_NN_MODULE_INIT(q_proj, hidden_size_, num_attention_heads * head_dim_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(k_proj, hidden_size_, 
num_key_value_heads * head_dim_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(v_proj, hidden_size_, num_key_value_heads * head_dim_, quantization_method, + use_bias_, dtype, device, tp_rank, tp_size); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads * head_dim_, hidden_size_, quantization_method, + use_output_bias_, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + default: + throw std::runtime_error("infinilm::models::minicpm_sala::AttentionBase: unsupported quantization scheme"); + break; + } + + if ((num_key_value_heads_ < tp_size) || (0 != (num_key_value_heads_ % tp_size))) { + throw std::runtime_error("infinilm::models::minicpm_sala::AttentionBase: num_key_value_heads must be divisible by tp_size"); + } + + size_t num_attention_heads_rank = num_attention_heads_ / tp_size; + size_t num_key_value_heads_rank = num_key_value_heads_ / tp_size; + attn_ = std::make_shared(num_attention_heads_rank, head_dim_, scaling, + num_key_value_heads_rank, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + + num_attention_heads_ = num_attention_heads_rank; + num_key_value_heads_ = num_key_value_heads_rank; +} + +InfLLMv2Attention::InfLLMv2Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : AttentionBase(model_config, + model_config->get("num_attention_heads"), + model_config->get("num_key_value_heads"), + layer_idx, device) { + use_output_gate_ = model_config->get_or("use_output_gate", false); + const auto &dtype{model_config->get_dtype()}; + size_t num_attention_heads = model_config->get("num_attention_heads"); + if (use_output_gate_) { + INFINICORE_NN_MODULE_INIT(o_gate, hidden_size_, num_attention_heads * head_dim_, + model_config->get_quantization_method(), use_bias_, dtype, device); + } +} + +infinicore::Tensor InfLLMv2Attention::forward(const infinicore::Tensor &hidden_states) const { + spdlog::error("InfLLMv2Attention is not 
implemented"); + return hidden_states; +} + +LightningAttention::LightningAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : AttentionBase(model_config, + model_config->get("num_attention_heads"), + model_config->get("lightning_nkv"), + layer_idx, device) { + + qk_norm_ = model_config->get_or("qk_norm", false); + use_output_norm_ = model_config->get_or("use_output_norm", false); + use_output_gate_ = model_config->get_or("use_output_gate", false); + const auto &dtype{model_config->get_dtype()}; + double rms_norm_eps = model_config->get("rms_norm_eps"); + size_t num_attention_heads = model_config->get("num_attention_heads"); + + if (qk_norm_) { + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + } + if (use_output_norm_) { + INFINICORE_NN_MODULE_INIT(o_norm, num_attention_heads * head_dim_, rms_norm_eps, dtype, device); + } + if (use_output_gate_) { + INFINICORE_NN_MODULE_INIT(z_proj, hidden_size_, num_attention_heads * head_dim_, + model_config->get_quantization_method(), use_bias_, dtype, device); + } +} + +infinicore::Tensor LightningAttention::forward(const infinicore::Tensor &hidden_states) const { + spdlog::error("LightningAttention is not implemented"); + return hidden_states; +} + +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/minicpm_sala/minicpm_sala_attention.hpp b/csrc/models/minicpm_sala/minicpm_sala_attention.hpp new file mode 100644 index 00000000..f24ddeff --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_attention.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include "../../layers/common_modules.hpp" + +namespace infinilm::layers::attention { +class AttentionLayer; +} + +namespace infinilm::models::minicpm_sala { + +class AttentionBase : public infinicore::nn::Module { +protected: + AttentionBase(std::shared_ptr model_config, + size_t num_attention_heads, + size_t 
num_key_value_heads, + size_t layer_idx, + const infinicore::Device &device); + +public: + size_t layer_idx() const { return layer_idx_; } + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + void set_rotary_emb(const std::shared_ptr &rotary_emb) { rotary_emb_ = rotary_emb; } + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, q_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, k_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::ColumnParallelLinear, v_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + + std::shared_ptr attn_; + ::infinilm::backends::AttentionBackend attention_backend_; + std::shared_ptr rotary_emb_; + + size_t layer_idx_; + size_t hidden_size_; + size_t num_attention_heads_; + size_t num_key_value_heads_; + size_t head_dim_; + bool use_bias_; + bool use_output_bias_; + + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); +}; + +/** + * @brief InfLLMv2 attention with optional output gate + */ +class InfLLMv2Attention : public AttentionBase { +public: + InfLLMv2Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +protected: + bool use_output_gate_; + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, o_gate); +}; + +/** + * @brief Lightning attention with optional output norm and gate + */ +class LightningAttention : public AttentionBase { +public: + LightningAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +protected: + bool qk_norm_; + 
bool use_output_norm_; + bool use_output_gate_; + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, o_norm); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, z_proj); +}; + +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp new file mode 100644 index 00000000..5bceaec3 --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.cpp @@ -0,0 +1,66 @@ +#include "minicpm_sala_decoderLayer.hpp" + +#include "infinicore/ops.hpp" +#include +#include +#include + +namespace infinilm::models::minicpm_sala { + +MiniCPMSALADecoderLayer::MiniCPMSALADecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx) { + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + size_t intermediate_size = model_config->get("intermediate_size"); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, model_config, device); + + std::vector mixer_types = model_config->get>("mixer_types"); + std::string mixer_type = mixer_types[layer_idx]; + if ("minicpm4" == mixer_type) { + self_attn_ = std::make_shared(this->register_module("self_attn", model_config, layer_idx, device)); + } else if ("lightning" == mixer_type || "lightning_attn" == mixer_type || "lightning-attn" == mixer_type) { + self_attn_ = std::make_shared(this->register_module("self_attn", model_config, layer_idx, device)); + } else { + throw std::runtime_error("infinilm::models::minicpm_sala::MiniCPMSALADecoderLayer: unsupported mixer_type '" + 
mixer_type + "' for layer " + std::to_string(layer_idx)); + } +} + +std::tuple MiniCPMSALADecoderLayer::forward(infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) { + input_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = std::visit( + [&](auto &attn_ptr) { return attn_ptr->forward(hidden_states); }, *self_attn_); + + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = mlp_->forward(hidden_states); + return std::make_tuple(hidden_states, residual); +} + +infinicore::Tensor MiniCPMSALADecoderLayer::forward(infinicore::Tensor &hidden_states) { + auto residual = hidden_states; + hidden_states = input_layernorm_->forward(hidden_states); + hidden_states = std::visit( + [&](auto &attn_ptr) { return attn_ptr->forward(hidden_states); }, *self_attn_); + + hidden_states = infinicore::op::add(residual, hidden_states); + + residual = hidden_states; + hidden_states = post_attention_layernorm_->forward(hidden_states); + hidden_states = mlp_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + return hidden_states; +} + +void MiniCPMSALADecoderLayer::set_rotary_emb(const std::shared_ptr &rotary_emb) { + if (self_attn_) { + std::visit([&](auto &attn_ptr) { attn_ptr->set_rotary_emb(rotary_emb); }, *self_attn_); + } +} + +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp new file mode 100644 index 00000000..a0411b25 --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_decoderLayer.hpp @@ -0,0 +1,33 @@ +#pragma once + +#include "../../layers/mlp/mlp.hpp" +#include "minicpm_sala_attention.hpp" +#include +#include + +namespace infinilm::models::minicpm_sala { +using MiniCPMMLP = infinilm::layers::MLP; +using MiniCPMSALAAttention = std::variant, std::shared_ptr>; + +class MiniCPMSALADecoderLayer : public infinicore::nn::Module { +public: + 
MiniCPMSALADecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + std::tuple forward(infinicore::Tensor &hidden_states, infinicore::Tensor &residual); + + infinicore::Tensor forward(infinicore::Tensor &hidden_states); + + void set_rotary_emb(const std::shared_ptr &rotary_emb); + +protected: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(MiniCPMSALAAttention, self_attn); + INFINICORE_NN_MODULE(MiniCPMMLP, mlp); + + size_t layer_idx_; +}; + +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp new file mode 100644 index 00000000..793f86bd --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.cpp @@ -0,0 +1,56 @@ +#include "minicpm_sala_for_causal_lm.hpp" +#include "../../global_state/global_state.hpp" +#include "../models_registry.hpp" +#include +#include + +namespace infinilm::models::minicpm_sala { + +MiniCPMSALAForCausalLM::MiniCPMSALAForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + size_t hidden_size = model_config->get("hidden_size"); + size_t vocab_size = model_config->get("vocab_size"); + const auto &dtype{model_config->get_dtype()}; + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +infinilm::InfinilmModel::Output MiniCPMSALAForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +void MiniCPMSALAForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { + if (nullptr == cache_config) { + InfinilmModel::reset_cache(nullptr); + return; + } + cache_config_ 
= cache_config->unique_copy(); + + auto &kv_cache_vec = infinilm::global_state::get_forward_context().kv_cache_vec; + kv_cache_vec.clear(); + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; + + auto new_kv_cache_vec = minicpm_sala_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend); + kv_cache_vec = std::move(new_kv_cache_vec); +} + +std::shared_ptr create_minicpm_sala_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if ("minicpm_sala" != model_type) { + throw std::runtime_error("infinilm::models::minicpm_sala::create_minicpm_sala_model_config: model_type is not minicpm_sala"); + } + return model_config; +} + +} // namespace infinilm::models::minicpm_sala + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + minicpm_sala, + infinilm::models::minicpm_sala::MiniCPMSALAForCausalLM, + infinilm::models::minicpm_sala::create_minicpm_sala_model_config); +} // namespace diff --git a/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp new file mode 100644 index 00000000..720814dc --- /dev/null +++ b/csrc/models/minicpm_sala/minicpm_sala_for_causal_lm.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include "minicpm_sala_decoderLayer.hpp" +#include + +namespace infinilm::models::minicpm_sala { + +using MiniCPMSALAModel = infinilm::layers::causal_lm_templates::TextModel; + +class MiniCPMSALAForCausalLM : public InfinilmModel { +public: + MiniCPMSALAForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + + void reset_cache(const cache::CacheConfig *cache_config) override; + +protected: + INFINICORE_NN_MODULE(MiniCPMSALAModel, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_minicpm_sala_model_config(std::shared_ptr model_config); 
+ +/** Implemented in `minicpm_sala_allocate_kv_cache_tensors.cpp`. */ +std::vector> minicpm_sala_allocate_kv_cache_tensors(const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend); +} // namespace infinilm::models::minicpm_sala diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 319a7baa..03734ac9 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -1,5 +1,6 @@ #include "model_factory.hpp" -#include "llama/llama.hpp" +#include "llama/llama_for_causal_lm.hpp" +#include "models_registry.hpp" namespace infinilm { /** @@ -40,7 +41,6 @@ std::shared_ptr InfinilmModelFactory::createModel( engine::distributed::RankInfo rank_info, const cache::CacheConfig *cache, backends::AttentionBackend attention_backend) { - std::shared_ptr model; if (true) { model = std::make_shared( @@ -55,4 +55,26 @@ std::shared_ptr InfinilmModelFactory::createModel( return model; } + +std::shared_ptr InfinilmModelFactory::createModel( + std::shared_ptr model_config, + const infinicore::Device &device, + const cache::CacheConfig *cache) { + const std::string model_type = model_config->get("model_type"); + std::shared_ptr model; + const auto &model_map = models::get_causal_lm_model_map(); + auto it = model_map.find(model_type); + if (it != model_map.end()) { + // create model + auto &model_creator = it->second; + model = model_creator(model_config, device); + } else { + throw std::invalid_argument("InfinilmModelFactory::createModel: Unsupported model_type"); + } + + if (cache) { + model->reset_cache(cache); + } + return model; +} } // namespace infinilm diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index 3c3c2e38..787a6406 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -1,10 +1,8 @@ #pragma once -#include "../config/model_config.hpp" -#include "infinilm_model.hpp" - #include 
"../backends/attention_backends.hpp" #include "../engine/distributed/distributed.hpp" +#include "infinilm_model.hpp" namespace infinilm { class InfinilmModelFactory { @@ -27,10 +25,18 @@ class InfinilmModelFactory { const cache::CacheConfig *cache = nullptr, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); + /** + * @deprecated This function is deprecated and will be REMOVED in the next major release. + */ static std::shared_ptr createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), const cache::CacheConfig *cache = nullptr, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); + + static std::shared_ptr createModel( + std::shared_ptr model_config, + const infinicore::Device &device, + const cache::CacheConfig *cache = nullptr); }; } // namespace infinilm diff --git a/csrc/models/models_registry.cpp b/csrc/models/models_registry.cpp new file mode 100644 index 00000000..030267c3 --- /dev/null +++ b/csrc/models/models_registry.cpp @@ -0,0 +1,32 @@ +#include "models_registry.hpp" +#include + +namespace infinilm::models { +namespace { +std::map &causal_lm_models_map() { + static std::map m; + return m; +} +std::map &model_configs_map() { + static std::map m; + return m; +} +} // namespace + +void register_causal_lm_model(const std::string &model_type, ModelCreator creator) { + causal_lm_models_map()[model_type] = std::move(creator); +} + +void register_model_config(const std::string &model_type, ConfigCreator creator) { + model_configs_map()[model_type] = std::move(creator); +} + +const std::map &get_causal_lm_model_map() { + return causal_lm_models_map(); +} + +const std::map &get_model_config_map() { + return model_configs_map(); +} + +} // namespace infinilm::models diff --git a/csrc/models/models_registry.hpp b/csrc/models/models_registry.hpp new file mode 100644 index 00000000..d73e2e41 --- /dev/null +++ 
b/csrc/models/models_registry.hpp @@ -0,0 +1,79 @@ +#pragma once + +#include "infinicore/device.hpp" +#include +#include +#include +#include + +namespace infinilm::config { +class ModelConfig; +} + +namespace infinilm { +class InfinilmModel; +} + +namespace infinilm::models { + +/** + * @brief Factory that builds a causal LM instance from config and device. + */ +using ModelCreator = std::function(std::shared_ptr, + const infinicore::Device &)>; + +/** + * @brief post-processor for `ModelConfig`. + */ +using ConfigCreator = std::function(std::shared_ptr)>; + +/** + * @brief Register one causal LM constructor for a `model_type` string. + */ +void register_causal_lm_model(const std::string &model_type, ModelCreator creator); + +/** + * @brief Register one config post-processor for a `model_type` string. + */ +void register_model_config(const std::string &model_type, ConfigCreator creator); + +/** + * @brief Snapshot of all registered causal LM factories. + * + * @return Map from `model_type` to `ModelCreator`. + */ +const std::map &get_causal_lm_model_map(); + +/** + * @brief Snapshot of all registered config post-processors. + * + * @return Map from `model_type` to `ConfigCreator`. + */ +const std::map &get_model_config_map(); + +/** + * @brief Used by `INFINILM_REGISTER_CAUSAL_LM_MODEL`: registers model factory + config handler at static init. + * + * @tparam ModelT Causal LM type constructible as `std::make_shared(config, device)`. + * @tparam ConfigCreatorFn Type of a function like `create_qwen3_model_config` (for `decltype`). + */ +template +struct CausalLmRegistrar { + /** + * @brief Calls `register_causal_lm_model` and `register_model_config` for `model_type`. 
+ */ + explicit CausalLmRegistrar(const char *model_type, ConfigCreatorFn config_creator) { + infinilm::models::register_causal_lm_model( + model_type, + [](std::shared_ptr config, const infinicore::Device &device) { + return std::make_shared(std::move(config), device); + }); + + infinilm::models::register_model_config(model_type, config_creator); + } +}; + +} // namespace infinilm::models + +#define INFINILM_REGISTER_CAUSAL_LM_MODEL(model_type, ModelT, ConfigCreatorFn) \ + auto g_##model_type##_registry = infinilm::models::CausalLmRegistrar(#model_type, ConfigCreatorFn) diff --git a/csrc/models/qwen3/qwen3_attention.cpp b/csrc/models/qwen3/qwen3_attention.cpp new file mode 100644 index 00000000..0a004d0e --- /dev/null +++ b/csrc/models/qwen3/qwen3_attention.cpp @@ -0,0 +1,200 @@ +#include "qwen3_attention.hpp" +#include "../../global_state/global_state.hpp" +#include "../../utils.hpp" + +namespace infinilm::models::qwen3 { +using infinilm::global_state::AttentionMetadata; + +Qwen3Attention::Qwen3Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + + const auto &dtype{model_config->get_dtype()}; + num_attention_heads_ = model_config->get("num_attention_heads"); + num_key_value_heads_ = model_config->get("num_key_value_heads"); + hidden_size_ = model_config->get("hidden_size"); + head_dim_ = model_config->get("head_dim"); + + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + + bool use_bias = model_config->get_or("attention_bias", true); + bool use_output_bias = model_config->get_or("attention_output_bias", false); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + int tp_size = 
infinilm::global_state::get_tensor_model_parallel_world_size(); + + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: { + INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::AWQ_W4A16: { + INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, num_attention_heads_, num_key_value_heads_, + quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, num_attention_heads_ * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + default: { + throw std::runtime_error("infinilm::models::qwen3::Qwen3Attention: unsupported quantization scheme"); + break; + } + } + + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + + auto kv_quant_scheme = infinilm::global_state::get_infinilm_config().model_config->get_kv_quant_scheme(); + switch (kv_quant_scheme) { + 
case (infinicore::quantization::KVQuantAlgo::NONE): { + break; + } + case (infinicore::quantization::KVQuantAlgo::INT8): { + INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + break; + } + default: { + throw std::runtime_error("infinilm::layers::attention: unsupported kv_quant_scheme"); + break; + } + } + + if ((num_key_value_heads_ < tp_size) || (0 != (num_key_value_heads_ % tp_size))) { + throw std::runtime_error("infinilm::models::qwen3::Qwen3Attention: num_key_value_heads must be divisible by tp_size"); + } + + size_t num_attention_heads_rank = num_attention_heads_ / tp_size; + size_t num_key_value_heads_rank = num_key_value_heads_ / tp_size; + attn_ = std::make_shared(num_attention_heads_rank, head_dim_, scaling, num_key_value_heads_rank, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + + num_attention_heads_ = num_attention_heads_rank; + num_key_value_heads_ = num_key_value_heads_rank; +} + +infinicore::Tensor Qwen3Attention::forward(const infinicore::Tensor &hidden_states) const { + if (!rotary_emb_) { + throw std::runtime_error("infinilm::models::qwen3::Qwen3Attention: rotary_emb not configured"); + } + + auto &forward_context = infinilm::global_state::get_forward_context(); + AttentionMetadata &attn_metadata = forward_context.attn_metadata; + std::tuple &kv_cache = forward_context.kv_cache_vec[layer_idx_]; + + if (::infinilm::backends::AttentionBackend::STATIC_ATTN == attention_backend_) { + return forward_static_(hidden_states, attn_metadata, kv_cache); + } + return forward_paged_(hidden_states, attn_metadata, kv_cache); +} + +infinicore::Tensor Qwen3Attention::forward_static_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const { + auto position_ids = attn_metadata.position_ids.value(); + + // 
hidden_states shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + q = q_norm_->forward(q->view({batch_size * seq_len, num_attention_heads_, head_dim_})); + k = k_norm_->forward(k->view({batch_size * seq_len, num_key_value_heads_, head_dim_})); + + // 2. Reshape for multi-head attention + auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_}); + + // 3. Prepare position_ids for RoPE + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids->contiguous(); + } else { + throw std::runtime_error("infinilm::models::qwen3::Qwen3Attention: Unexpected position_ids shape"); + } + + // 4. Apply RoPE to QK + auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3}); + rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); + + // 6. Attn Backend calculate + auto attn_output = attn_->forward(q_rope, k_reshaped, v_reshaped, kv_cache, attn_metadata); + + // 7. 
Project output + auto output = o_proj_->forward(attn_output); + return output; +} + +infinicore::Tensor Qwen3Attention::forward_paged_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const { + auto position_ids = attn_metadata.position_ids.value(); + + // hidden_states shape: [batch, seq_len, hidden_size] + auto hidden_states_mutable = hidden_states; + auto shape = hidden_states->shape(); + size_t batch_size = shape[0]; + size_t seq_len = shape[1]; + + // Only support batchsize==1, all requests should be flattened along seqlen dimension + ASSERT_EQ(batch_size, 1); + + // 1. Project Q, K, V + auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable); + + // 2. Reshape for multi-head attention + auto q_reshaped = q->view({seq_len, num_attention_heads_, head_dim_}); + auto k_reshaped = k->view({seq_len, num_key_value_heads_, head_dim_}); + auto v_reshaped = v->view({seq_len, num_key_value_heads_, head_dim_}); + q_reshaped = q_norm_->forward(q_reshaped); + k_reshaped = k_norm_->forward(k_reshaped); + + // 3. Prepare position_ids for RoPE + auto pos_shape = position_ids->shape(); + infinicore::Tensor pos_ids_for_rope = position_ids; + if (pos_shape.size() == 2) { + auto pos_narrowed = position_ids->narrow({{0, 0, 1}}); + pos_ids_for_rope = pos_narrowed->view({pos_shape[1]}); + } else if (pos_shape.size() == 1) { + pos_ids_for_rope = position_ids; + } else { + throw std::runtime_error("Unexpected position_ids shape"); + } + + // 4. Apply RoPE to QK + rotary_emb_->forward(q_reshaped, pos_ids_for_rope, true); + rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); + + // 5. Attn Backend calculate + auto attn_output = attn_->forward(q_reshaped, k_reshaped, v_reshaped, kv_cache, attn_metadata); + + // 6. 
Project output + return o_proj_->forward(attn_output); +} +} // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_attention.hpp b/csrc/models/qwen3/qwen3_attention.hpp new file mode 100644 index 00000000..3d329ef1 --- /dev/null +++ b/csrc/models/qwen3/qwen3_attention.hpp @@ -0,0 +1,49 @@ +#pragma once + +#include "../../layers/common_modules.hpp" + +namespace infinilm::models::qwen3 { +class Qwen3Attention : public infinicore::nn::Module { +public: + Qwen3Attention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + size_t layer_idx() const { return layer_idx_; } + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + void set_rotary_emb(const std::shared_ptr &rotary_emb) { rotary_emb_ = rotary_emb; } + +private: + infinicore::Tensor forward_static_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const; + + infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states, + const infinilm::global_state::AttentionMetadata &attn_metadata, + std::tuple &kv_cache) const; + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + + std::shared_ptr rotary_emb_; + std::shared_ptr attn_; + ::infinilm::backends::AttentionBackend attention_backend_; + size_t layer_idx_; + size_t num_attention_heads_; + size_t num_key_value_heads_; + size_t hidden_size_; + size_t head_dim_; + + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); 
+ INFINICORE_NN_PARAMETER(kv_cache_v_scale); +}; +} // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3/qwen3_for_causal_lm.cpp b/csrc/models/qwen3/qwen3_for_causal_lm.cpp new file mode 100644 index 00000000..d06f7680 --- /dev/null +++ b/csrc/models/qwen3/qwen3_for_causal_lm.cpp @@ -0,0 +1,23 @@ +#include "qwen3_for_causal_lm.hpp" +#include "../models_registry.hpp" +#include +#include + +namespace infinilm::models::qwen3 { + +std::shared_ptr create_qwen3_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if ("qwen3" != model_type) { + throw std::runtime_error("infinilm::models::qwen3::create_qwen3_model_config: model_type is not qwen3"); + } + return model_config; +} + +} // namespace infinilm::models::qwen3 + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen3, + infinilm::models::qwen3::Qwen3ForCausalLM, + infinilm::models::qwen3::create_qwen3_model_config); +} // namespace diff --git a/csrc/models/qwen3/qwen3_for_causal_lm.hpp b/csrc/models/qwen3/qwen3_for_causal_lm.hpp new file mode 100644 index 00000000..0a35170b --- /dev/null +++ b/csrc/models/qwen3/qwen3_for_causal_lm.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include "qwen3_attention.hpp" + +namespace infinilm::models::qwen3 { + +using Qwen3MLP = infinilm::layers::MLP; + +using Qwen3Attention = infinilm::models::qwen3::Qwen3Attention; + +using Qwen3DecoderLayer = infinilm::layers::causal_lm_templates::TextDecoderLayer; + +using Qwen3Model = infinilm::layers::causal_lm_templates::TextModel; + +using Qwen3ForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM; + +} // namespace infinilm::models::qwen3 + +namespace infinilm::models::qwen3 { + +std::shared_ptr create_qwen3_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::qwen3 diff --git a/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.cpp b/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.cpp new file mode 100644 index 00000000..14109941 --- 
/dev/null +++ b/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.cpp @@ -0,0 +1,24 @@ +#include "qwen3_moe_for_causal_lm.hpp" +#include "../models_registry.hpp" + +#include +#include + +namespace infinilm::models::qwen3_moe { + +std::shared_ptr create_qwen3_moe_model_config(std::shared_ptr model_config) { + const std::string model_type = model_config->get("model_type"); + if ("qwen3_moe" != model_type) { + throw std::runtime_error("create_qwen3_moe_model_config: model_type is not qwen3_moe"); + } + return model_config; +} + +} // namespace infinilm::models::qwen3_moe + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen3_moe, + infinilm::models::qwen3_moe::Qwen3MoeForCausalLM, + infinilm::models::qwen3_moe::create_qwen3_moe_model_config); +} // namespace diff --git a/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.hpp b/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.hpp new file mode 100644 index 00000000..36e05bbf --- /dev/null +++ b/csrc/models/qwen3_moe/qwen3_moe_for_causal_lm.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../qwen3/qwen3_attention.hpp" +#include "qwen3_moe_sparse_moe_block.hpp" +#include + +namespace infinilm::models::qwen3_moe { + +using Qwen3MoeAttention = qwen3::Qwen3Attention; + +using Qwen3MoeDecoderLayer = infinilm::layers::causal_lm_templates::TextDecoderLayer; + +using Qwen3MoeModel = infinilm::layers::causal_lm_templates::TextModel; + +using Qwen3MoeForCausalLM = infinilm::layers::causal_lm_templates::TextCausalLM; + +std::shared_ptr create_qwen3_moe_model_config(std::shared_ptr model_config); +} // namespace infinilm::models::qwen3_moe diff --git a/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.cpp b/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.cpp new file mode 100644 index 00000000..63d90a01 --- /dev/null +++ b/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.cpp @@ -0,0 +1,31 @@ +#include "qwen3_moe_sparse_moe_block.hpp" +#include + +namespace infinilm::models::qwen3_moe { + 
+Qwen3MoeSparseMoeBlock::Qwen3MoeSparseMoeBlock(std::shared_ptr model_config, + const infinicore::Device &device) { + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + size_t moe_intermediate_size = model_config->get("moe_intermediate_size"); + size_t shared_expert_intermediate_size = model_config->get_or("shared_expert_intermediate_size", 0); + size_t num_experts = model_config->get("num_experts"); + + INFINICORE_NN_MODULE_INIT(gate, hidden_size, num_experts, false, dtype, device); + experts_.reserve(num_experts); + for (size_t i = 0; i < num_experts; ++i) { + experts_.push_back(this->register_module("experts." + std::to_string(i), model_config, device)); + } + + if (shared_expert_intermediate_size > 0) { + INFINICORE_NN_MODULE_INIT(shared_expert, model_config, device); + INFINICORE_NN_MODULE_INIT(shared_expert_gate, hidden_size, 1, false, dtype, device); + } +} + +infinicore::Tensor Qwen3MoeSparseMoeBlock::forward(const infinicore::Tensor &hidden_states) const { + spdlog::error("Qwen3MoeSparseMoeBlock: forward not implemented"); + return hidden_states; +} + +} // namespace infinilm::models::qwen3_moe diff --git a/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.hpp b/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.hpp new file mode 100644 index 00000000..bfa8e1cf --- /dev/null +++ b/csrc/models/qwen3_moe/qwen3_moe_sparse_moe_block.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include "../../layers/common_modules.hpp" + +namespace infinilm::models::qwen3_moe { +using Qwen3MoeMLP = infinilm::layers::MoeMLP; + +class Qwen3MoeSparseMoeBlock : public infinicore::nn::Module { +public: + Qwen3MoeSparseMoeBlock(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, gate); + INFINICORE_NN_MODULE_VEC(Qwen3MoeMLP, experts); + 
INFINICORE_NN_MODULE(Qwen3MoeMLP, shared_expert); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, shared_expert_gate); +}; + +} // namespace infinilm::models::qwen3_moe diff --git a/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp new file mode 100644 index 00000000..38aa9fa3 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_allocate_kv_cache_tensors.cpp @@ -0,0 +1,98 @@ +#include "../../backends/attention_backends.hpp" +#include "../../cache/kv_cache.hpp" +#include "qwen3_next_for_causal_lm.hpp" +#include +#include +#include + +namespace infinilm::models::qwen3_next { + +std::vector> qwen3_next_allocate_kv_cache_tensors( + const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend) { + if (nullptr == cache_config) { + return {}; + } + if (nullptr == text_config) { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: text_config is null"); + } + + std::vector> kv_cache_vec; + switch (attention_backend) { + case backends::AttentionBackend::STATIC_ATTN: { + auto static_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == static_kv_cache_config) { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid static kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const size_t max_position_embeddings = text_config->get("max_position_embeddings"); + const auto &dtype{text_config->get_dtype()}; + const std::vector layer_types = text_config->get>("layer_types"); + + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + const std::string &layer_type = 
layer_types[layer_idx]; + if ("linear_attention" == layer_type) { + kv_cache_vec.emplace_back(); + } else if ("full_attention" == layer_type) { + auto kv_cache = cache::StaticKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + max_position_embeddings, + dtype, + *static_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } else { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: unsupported layer_type '" + layer_type + "' for layer " + std::to_string(layer_idx)); + } + } + break; + } + case backends::AttentionBackend::FLASH_ATTN: { + ; + } + case backends::AttentionBackend::PAGED_ATTN: { + auto paged_kv_cache_config = dynamic_cast(cache_config); + if (nullptr == paged_kv_cache_config) { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: invalid paged kv cache config type"); + } + const size_t num_hidden_layers = text_config->get("num_hidden_layers"); + kv_cache_vec.reserve(num_hidden_layers); + + const size_t head_dim = text_config->get("head_dim"); + const size_t num_key_value_heads = text_config->get("num_key_value_heads"); + const auto &dtype{text_config->get_dtype()}; + const std::vector layer_types = text_config->get>("layer_types"); + + for (size_t layer_idx = 0; layer_idx < num_hidden_layers; ++layer_idx) { + const std::string &layer_type = layer_types[layer_idx]; + if ("linear_attention" == layer_type) { + kv_cache_vec.emplace_back(); + } else if ("full_attention" == layer_type) { + auto kv_cache = cache::PagedKVCache::create_layer_kv_cache( + head_dim, + head_dim, + num_key_value_heads, + num_key_value_heads, + dtype, + *paged_kv_cache_config); + kv_cache_vec.push_back(kv_cache); + } else { + throw std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: unsupported layer_type '" + layer_type + "' for layer " + std::to_string(layer_idx)); + } + } + break; + } + default: + throw 
std::runtime_error("infinilm::models::qwen3_next::qwen3_next_allocate_kv_cache_tensors: Unsupported attention backend: " + std::to_string(static_cast(attention_backend))); + } + return kv_cache_vec; +} + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_attention.cpp b/csrc/models/qwen3_next/qwen3_next_attention.cpp new file mode 100644 index 00000000..b3dc474d --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_attention.cpp @@ -0,0 +1,78 @@ +#include "qwen3_next_attention.hpp" + +#include "../../global_state/global_state.hpp" +#include +#include +#include + +namespace infinilm::models::qwen3_next { + +Qwen3NextAttention::Qwen3NextAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + const auto &dtype{model_config->get_dtype()}; + + num_attention_heads_ = model_config->get("num_attention_heads"); + num_key_value_heads_ = model_config->get("num_key_value_heads"); + hidden_size_ = model_config->get("hidden_size"); + head_dim_ = model_config->get("head_dim"); + size_t total_num_heads = num_attention_heads_; + float scaling = 1.0f / std::sqrt(static_cast(head_dim_)); + + bool use_bias = model_config->get_or("attention_bias", true); + bool use_output_bias = model_config->get_or("attention_output_bias", false); + double rms_norm_eps = model_config->get("rms_norm_eps"); + bool attn_output_gate = model_config->get_or("attn_output_gate", true); + + const engine::distributed::RankInfo &rank_info = infinilm::global_state::get_tensor_model_parallel_rank_info(); + int tp_rank = infinilm::global_state::get_tensor_model_parallel_rank(); + int tp_size = infinilm::global_state::get_tensor_model_parallel_world_size(); + auto quant_scheme = model_config->get_quant_scheme(); + auto quantization_method = model_config->get_quantization_method(); + switch (quant_scheme) { + case infinicore::quantization::QuantScheme::NONE: { + INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", 
"k_proj", "v_proj", hidden_size_, head_dim_, total_num_heads * (1 + attn_output_gate), num_key_value_heads_, quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::COMPRESSED_TENSOR_W8A8I8: { + INFINILM_QKV_LINEAR_W8A8_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, total_num_heads * (1 + attn_output_gate), num_key_value_heads_, quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + case infinicore::quantization::QuantScheme::AWQ_W4A16: { + INFINILM_QKV_LINEAR_W4A16AWQ_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, total_num_heads * (1 + attn_output_gate), num_key_value_heads_, quantization_method, use_bias, dtype, device, rank_info); + INFINICORE_NN_MODULE_INIT(o_proj, total_num_heads * head_dim_, hidden_size_, quantization_method, + use_output_bias, dtype, device, tp_rank, tp_size, rank_info.comm); + break; + } + default: { + throw std::runtime_error("infinilm::models::qwen3_next::Qwen3NextAttention: unsupported quantization scheme"); + } + } + + INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, rms_norm_eps, dtype, device); + if ((num_key_value_heads_ < tp_size) || (0 != (num_key_value_heads_ % tp_size))) { + throw std::runtime_error("infinilm::models::qwen3_next::Qwen3NextAttention: num_key_value_heads must be divisible by tp_size"); + } + + attention_backend_ = infinilm::global_state::get_infinilm_config().attention_backend; + size_t num_attention_heads_rank = num_attention_heads_ / tp_size; + size_t num_key_value_heads_rank = 
num_key_value_heads_ / tp_size; + attn_ = std::make_shared(num_attention_heads_rank, head_dim_, scaling, num_key_value_heads_rank, layer_idx_, + kv_cache_k_scale_, kv_cache_v_scale_, attention_backend_); + + num_attention_heads_ = num_attention_heads_rank; + num_key_value_heads_ = num_key_value_heads_rank; +} + +infinicore::Tensor Qwen3NextAttention::forward(const infinicore::Tensor &hidden_states) const { + spdlog::error("infinilm::models::qwen3_next::Qwen3NextAttention: forward not implemented"); + return hidden_states; +} + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_attention.hpp b/csrc/models/qwen3_next/qwen3_next_attention.hpp new file mode 100644 index 00000000..84395b97 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_attention.hpp @@ -0,0 +1,42 @@ +#pragma once + +#include "../../layers/common_modules.hpp" + +namespace infinilm::models::qwen3_next { + +class Qwen3NextAttention : public infinicore::nn::Module { +public: + Qwen3NextAttention(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + + size_t layer_idx() const { return layer_idx_; } + size_t num_heads() const { return num_attention_heads_; } + size_t num_kv_heads() const { return num_key_value_heads_; } + size_t head_dim() const { return head_dim_; } + size_t hidden_size() const { return hidden_size_; } + void set_rotary_emb(const std::shared_ptr &rotary_emb) { rotary_emb_ = rotary_emb; } + +protected: + INFINICORE_NN_MODULE(infinilm::layers::linear::QKVParallelLinear, qkv_proj); + INFINICORE_NN_MODULE(infinilm::layers::linear::RowParallelLinear, o_proj); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, q_norm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, k_norm); + std::shared_ptr rotary_emb_; + + std::shared_ptr attn_; + ::infinilm::backends::AttentionBackend attention_backend_; + size_t layer_idx_; + size_t 
num_attention_heads_; + size_t num_key_value_heads_; + size_t hidden_size_; + size_t head_dim_; + + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); +}; + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp b/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp new file mode 100644 index 00000000..df139f35 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_decoderLayer.cpp @@ -0,0 +1,73 @@ +#include "qwen3_next_decoderLayer.hpp" +#include "infinicore/ops.hpp" +#include +#include +#include +#include + +namespace infinilm::models::qwen3_next { + +Qwen3NextDecoderLayer::Qwen3NextDecoderLayer(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) + : layer_idx_(layer_idx) { + + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + size_t intermediate_size = model_config->get("intermediate_size"); + double rms_norm_eps = model_config->get("rms_norm_eps"); + + INFINICORE_NN_MODULE_INIT(input_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(post_attention_layernorm, hidden_size, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(mlp, model_config, device); + + const std::vector layer_types = model_config->get>("layer_types"); + layer_type_ = layer_types[layer_idx]; + if ("linear_attention" == layer_type_) { + INFINICORE_NN_MODULE_INIT(linear_attn, model_config, layer_idx, device); + } else if ("full_attention" == layer_type_) { + INFINICORE_NN_MODULE_INIT(self_attn, model_config, layer_idx, device); + } else { + throw std::runtime_error("infinilm::models::qwen3_next::Qwen3NextDecoderLayer: unsupported layer_type '" + layer_type_ + "' for layer " + std::to_string(layer_idx)); + } +} + +std::tuple Qwen3NextDecoderLayer::forward(infinicore::Tensor &hidden_states, + infinicore::Tensor &residual) { + 
input_layernorm_->forward_inplace(hidden_states, residual); + if ("linear_attention" == layer_type_) { + hidden_states = linear_attn_->forward(hidden_states); + } else if ("full_attention" == layer_type_) { + hidden_states = self_attn_->forward(hidden_states); + } + + post_attention_layernorm_->forward_inplace(hidden_states, residual); + hidden_states = mlp_->forward(hidden_states); + return std::make_tuple(hidden_states, residual); +} + +infinicore::Tensor Qwen3NextDecoderLayer::forward(infinicore::Tensor &hidden_states) { + + auto residual = hidden_states; + hidden_states = input_layernorm_->forward(hidden_states); + if ("linear_attention" == layer_type_) { + hidden_states = linear_attn_->forward(hidden_states); + } else if ("full_attention" == layer_type_) { + hidden_states = self_attn_->forward(hidden_states); + } + hidden_states = infinicore::op::add(residual, hidden_states); + + residual = hidden_states; + hidden_states = post_attention_layernorm_->forward(hidden_states); + hidden_states = mlp_->forward(hidden_states); + hidden_states = infinicore::op::add(residual, hidden_states); + return hidden_states; +} + +void Qwen3NextDecoderLayer::set_rotary_emb(const std::shared_ptr &rotary_emb) { + if (self_attn_) { + self_attn_->set_rotary_emb(rotary_emb); + } +} + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_decoderLayer.hpp b/csrc/models/qwen3_next/qwen3_next_decoderLayer.hpp new file mode 100644 index 00000000..518a1537 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_decoderLayer.hpp @@ -0,0 +1,38 @@ +#pragma once + +#include "../qwen3_moe/qwen3_moe_sparse_moe_block.hpp" +#include "qwen3_next_attention.hpp" +#include "qwen3_next_gated_deltanet.hpp" +#include +#include + +namespace infinilm::models::qwen3_next { +using Qwen3NextSparseMoeBlock = qwen3_moe::Qwen3MoeSparseMoeBlock; + +class Qwen3NextDecoderLayer : public infinicore::nn::Module { +public: + Qwen3NextDecoderLayer(std::shared_ptr model_config, + 
size_t layer_idx, + const infinicore::Device &device); + + std::tuple forward(infinicore::Tensor &hidden_states, infinicore::Tensor &residual); + + infinicore::Tensor forward(infinicore::Tensor &hidden_states); + + size_t layer_idx() const { return layer_idx_; } + + void set_rotary_emb(const std::shared_ptr &rotary_emb); + +protected: + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, input_layernorm); + INFINICORE_NN_MODULE(infinicore::nn::RMSNorm, post_attention_layernorm); + INFINICORE_NN_MODULE(Qwen3NextAttention, self_attn); + INFINICORE_NN_MODULE(Qwen3NextGatedDeltaNet, linear_attn); + INFINICORE_NN_MODULE(Qwen3NextSparseMoeBlock, mlp); + +private: + size_t layer_idx_; + std::string layer_type_; +}; + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp new file mode 100644 index 00000000..7d2f8a6e --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.cpp @@ -0,0 +1,73 @@ +#include "qwen3_next_for_causal_lm.hpp" +#include "../../global_state/global_state.hpp" +#include "../models_registry.hpp" +#include +#include +#include + +namespace infinilm::models::qwen3_next { + +Qwen3NextForCausalLM::Qwen3NextForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + size_t hidden_size = model_config->get("hidden_size"); + size_t vocab_size = model_config->get("vocab_size"); + const auto &dtype{model_config->get_dtype()}; + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +infinilm::InfinilmModel::Output Qwen3NextForCausalLM::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +void Qwen3NextForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { + if (nullptr 
== cache_config) { + InfinilmModel::reset_cache(nullptr); + return; + } + cache_config_ = cache_config->unique_copy(); + + auto &kv_cache_vec = infinilm::global_state::get_forward_context().kv_cache_vec; + kv_cache_vec.clear(); + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; + + auto new_kv_cache_vec = qwen3_next_allocate_kv_cache_tensors(cache_config, model_config_, attention_backend); + kv_cache_vec = std::move(new_kv_cache_vec); +} + +std::shared_ptr create_qwen3_next_model_config(std::shared_ptr model_config) { + const std::string model_type = model_config->get("model_type"); + if ("qwen3_next" != model_type) { + throw std::runtime_error("infinilm::models::qwen3_next::create_qwen3_next_model_config: model_type is not qwen3_next"); + } + + nlohmann::json &config_json = model_config->get_config_json(); + if (!config_json.contains("layer_types")) { + size_t full_attention_interval = model_config->get("full_attention_interval"); + size_t num_hidden_layers = model_config->get("num_hidden_layers"); + std::vector layer_types; + layer_types.reserve(num_hidden_layers); + for (size_t i = 0; i < num_hidden_layers; i++) { + layer_types.push_back(bool((i + 1) % full_attention_interval) ? 
"linear_attention" : "full_attention"); + } + config_json["layer_types"] = layer_types; + } + + if (!config_json.contains("attention_bias")) { + config_json["attention_bias"] = false; + } + return model_config; +} + +} // namespace infinilm::models::qwen3_next + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen3_next, + infinilm::models::qwen3_next::Qwen3NextForCausalLM, + infinilm::models::qwen3_next::create_qwen3_next_model_config); +} // namespace diff --git a/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp new file mode 100644 index 00000000..fd2ea9a0 --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_for_causal_lm.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include "qwen3_next_decoderLayer.hpp" +#include + +namespace infinilm::models::qwen3_next { + +using Qwen3NextModel = infinilm::layers::causal_lm_templates::TextModel; + +class Qwen3NextForCausalLM : public InfinilmModel { +public: + Qwen3NextForCausalLM(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + + void reset_cache(const cache::CacheConfig *cache_config) override; + +protected: + INFINICORE_NN_MODULE(Qwen3NextModel, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_qwen3_next_model_config(std::shared_ptr model_config); + +/** Implemented in `qwen3_next_allocate_kv_cache_tensors.cpp`. 
*/ +std::vector> qwen3_next_allocate_kv_cache_tensors( + const cache::CacheConfig *cache_config, + const std::shared_ptr &text_config, + const backends::AttentionBackend &attention_backend); +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp new file mode 100644 index 00000000..261e3d7e --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.cpp @@ -0,0 +1,59 @@ +#include "qwen3_next_gated_deltanet.hpp" +#include + +namespace infinilm::models::qwen3_next { + +FakeConv1d::FakeConv1d(size_t in_channels, + size_t out_channels, + size_t kernel_size, + size_t stride, + size_t padding, + size_t dilation, + size_t groups, + bool bias, + const infinicore::DataType dtype, + const infinicore::Device device) { + + INFINICORE_NN_PARAMETER_INIT(weight, ({out_channels, 1, kernel_size}, dtype, device)); +} + +Qwen3NextGatedDeltaNet::Qwen3NextGatedDeltaNet(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device) { + layer_idx_ = layer_idx; + const auto &dtype{model_config->get_dtype()}; + size_t hidden_size = model_config->get("hidden_size"); + size_t linear_num_value_heads = model_config->get("linear_num_value_heads"); + size_t linear_num_key_heads = model_config->get("linear_num_key_heads"); + size_t linear_key_head_dim = model_config->get("linear_key_head_dim"); + size_t linear_value_head_dim = model_config->get("linear_value_head_dim"); + + size_t key_dim = linear_key_head_dim * linear_num_key_heads; + size_t value_dim = linear_value_head_dim * linear_num_value_heads; + + size_t linear_conv_kernel_dim = model_config->get("linear_conv_kernel_dim"); + + double rms_norm_eps = model_config->get("rms_norm_eps"); + + size_t conv_dim = key_dim * 2 + value_dim; + INFINICORE_NN_MODULE_INIT(conv1d, conv_dim, conv_dim, linear_conv_kernel_dim, 1, linear_conv_kernel_dim - 1, 1, 1, false, dtype, device); + + size_t projection_size_qkvz 
= key_dim * 2 + value_dim * 2; + size_t projection_size_ba = linear_num_value_heads * 2; + + INFINICORE_NN_MODULE_INIT(in_proj_qkvz, hidden_size, projection_size_qkvz, false, dtype, device); + INFINICORE_NN_MODULE_INIT(in_proj_ba, hidden_size, projection_size_ba, false, dtype, device); + + INFINICORE_NN_PARAMETER_INIT(dt_bias, ({linear_num_value_heads}, dtype, device)); + INFINICORE_NN_PARAMETER_INIT(A_log, ({linear_num_value_heads}, dtype, device)); + + INFINICORE_NN_MODULE_INIT(norm, linear_value_head_dim, rms_norm_eps, dtype, device); + INFINICORE_NN_MODULE_INIT(out_proj, value_dim, hidden_size, false, dtype, device); +} + +infinicore::Tensor Qwen3NextGatedDeltaNet::forward(const infinicore::Tensor &hidden_states) const { + spdlog::error("Qwen3NextGatedDeltaNet: forward not implemented"); + return hidden_states; +} + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp new file mode 100644 index 00000000..e8c36d4e --- /dev/null +++ b/csrc/models/qwen3_next/qwen3_next_gated_deltanet.hpp @@ -0,0 +1,46 @@ +#pragma once + +#include "../../layers/common_modules.hpp" + +namespace infinilm::models::qwen3_next { +using Qwen3Next_Fake_RMSNormGated = infinicore::nn::RMSNorm; + +class FakeConv1d : public infinicore::nn::Module { +public: + FakeConv1d(size_t in_channels, + size_t out_channels, + size_t kernel_size, + size_t stride, + size_t padding, + size_t dilation, + size_t groups, + bool bias, + const infinicore::DataType dtype, + const infinicore::Device device); + +private: + size_t layer_idx_; + INFINICORE_NN_PARAMETER(weight); +}; + +class Qwen3NextGatedDeltaNet : public infinicore::nn::Module { +public: + Qwen3NextGatedDeltaNet(std::shared_ptr model_config, + size_t layer_idx, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinicore::Tensor &hidden_states) const; + +private: + 
INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, in_proj_qkvz); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, in_proj_ba); + INFINICORE_NN_MODULE(FakeConv1d, conv1d); + INFINICORE_NN_PARAMETER(dt_bias); + INFINICORE_NN_PARAMETER(A_log); + INFINICORE_NN_MODULE(Qwen3Next_Fake_RMSNormGated, norm); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, out_proj); + + size_t layer_idx_; +}; + +} // namespace infinilm::models::qwen3_next diff --git a/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.cpp b/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.cpp new file mode 100644 index 00000000..512fe307 --- /dev/null +++ b/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.cpp @@ -0,0 +1,83 @@ +#include "qwen3_vl_for_conditional_generation.hpp" +#include "../../global_state/global_state.hpp" +#include "../models_registry.hpp" +#include +#include + +namespace infinilm::models::qwen3_vl { + +Qwen3VLModel::Qwen3VLModel(std::shared_ptr model_config, + const infinicore::Device &device) { + nlohmann::json &config_json = model_config->get_config_json(); + nlohmann::json &text_config_json = config_json["text_config"]; + std::shared_ptr text_config = std::make_shared(text_config_json); + + INFINICORE_NN_MODULE_INIT(language_model, text_config, device); +} + +infinicore::Tensor Qwen3VLModel::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = language_model_->forward(input); + return hidden_states; +} + +Qwen3VLForConditionalGeneration::Qwen3VLForConditionalGeneration(std::shared_ptr model_config, + const infinicore::Device &device) { + model_config_ = model_config; + const nlohmann::json &config_json = model_config->get_config_json(); + const nlohmann::json &text_config_json = config_json["text_config"]; + auto text_config = std::make_shared(text_config_json); + + size_t hidden_size = text_config->get("hidden_size"); + size_t vocab_size = text_config->get("vocab_size"); + const 
auto &dtype{model_config->get_dtype()}; + + INFINICORE_NN_MODULE_INIT(model, model_config, device); + INFINICORE_NN_MODULE_INIT(lm_head, hidden_size, vocab_size, false, dtype, device); +} + +infinilm::InfinilmModel::Output Qwen3VLForConditionalGeneration::forward(const infinilm::InfinilmModel::Input &input) const { + auto hidden_states = model_->forward(input); + auto logits = lm_head_->forward(hidden_states); + return {logits}; +} + +void Qwen3VLForConditionalGeneration::reset_cache(const cache::CacheConfig *cache_config) { + if (nullptr == cache_config) { + InfinilmModel::reset_cache(nullptr); + return; + } + cache_config_ = cache_config->unique_copy(); + + const nlohmann::json &config_json = model_config_->get_config_json(); + const nlohmann::json &text_config_json = config_json["text_config"]; + auto text_model_config = std::make_shared(text_config_json); + + auto &kv_cache_vec = infinilm::global_state::get_forward_context().kv_cache_vec; + kv_cache_vec.clear(); + const backends::AttentionBackend attention_backend = infinilm::global_state::get_infinilm_config().attention_backend; + kv_cache_vec = std::move(default_allocate_kv_cache_tensors(cache_config, text_model_config, attention_backend)); +} + +std::shared_ptr create_qwen3_vl_model_config(std::shared_ptr model_config) { + const std::string &model_type = model_config->get("model_type"); + if ("qwen3_vl" != model_type) { + throw std::runtime_error("infinilm::models::qwen3_vl::create_qwen3_vl_model_config: model_type is not qwen3_vl"); + } + + nlohmann::json &config_json = model_config->get_config_json(); + nlohmann::json &text_config_json = config_json["text_config"]; + if (!config_json.contains("torch_dtype")) { + std::string dtype = text_config_json["dtype"]; + config_json["torch_dtype"] = dtype; + } + return model_config; +} + +} // namespace infinilm::models::qwen3_vl + +namespace { +INFINILM_REGISTER_CAUSAL_LM_MODEL( + qwen3_vl, + infinilm::models::qwen3_vl::Qwen3VLForConditionalGeneration, + 
infinilm::models::qwen3_vl::create_qwen3_vl_model_config); +} // namespace diff --git a/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.hpp b/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.hpp new file mode 100644 index 00000000..dae5bc39 --- /dev/null +++ b/csrc/models/qwen3_vl/qwen3_vl_for_conditional_generation.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include "../../models/qwen3/qwen3_for_causal_lm.hpp" + +namespace infinilm::models::qwen3_vl { + +using Qwen3VLTextModel = infinilm::models::qwen3::Qwen3Model; + +class Qwen3VLModel : public infinicore::nn::Module { +public: + Qwen3VLModel(std::shared_ptr model_config, + const infinicore::Device &device); + + infinicore::Tensor forward(const infinilm::InfinilmModel::Input &input) const; + +protected: + INFINICORE_NN_MODULE(Qwen3VLTextModel, language_model); +}; + +class Qwen3VLForConditionalGeneration : public InfinilmModel { +public: + Qwen3VLForConditionalGeneration(std::shared_ptr model_config, + const infinicore::Device &device); + + Output forward(const Input &input) const override; + + void reset_cache(const cache::CacheConfig *cache_config) override; + +protected: + INFINICORE_NN_MODULE(Qwen3VLModel, model); + INFINICORE_NN_MODULE(infinilm::layers::linear::ReplicatedLinear, lm_head); +}; + +std::shared_ptr create_qwen3_vl_model_config(std::shared_ptr model_config); + +} // namespace infinilm::models::qwen3_vl diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 2bf7658d..a784d69c 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -66,8 +66,10 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) @@ -112,8 +114,10 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/examples/bench.py b/examples/bench.py index 2c44ab19..9316ca26 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -441,7 +441,8 @@ def run( output_len = args.output_len enable_paged_attn = args.enable_paged_attn enable_graph = args.enable_graph - + attn_backend = args.attn + if isinstance(batch_size, int): batch_size = [batch_size] @@ -471,6 +472,9 @@ def run( else: cache_config = None + if enable_paged_attn and attn_backend == "default": + attn_backend = "paged-attn" + test = TestModel( model_path, infini_device=infini_device, @@ -478,7 +482,7 @@ def run( skip_load=skip_load, cache_config=cache_config, enable_graph=enable_graph, - attn_backend=args.attn, + attn_backend=attn_backend, ) # ---------------------------------------------------------------------------- # diff --git a/examples/jiuge.py b/examples/jiuge.py index 87caf662..7b5bf3c2 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -177,6 +177,9 @@ def test( # ---------------------------------------------------------------------------- # # Create Model # ---------------------------------------------------------------------------- # + if enable_paged_attn and attn_backend == "default": + attn_backend = "paged-attn" + model = InferEngine( model_path, device=infini_device, @@ -267,7 +270,7 @@ def test( ) model.reset_cache(cache_config) - + # ---------------------------------------------------------------------------- # # Generate # ---------------------------------------------------------------------------- # diff --git a/python/infinilm/auto_config.py b/python/infinilm/auto_config.py index ec3a896f..7e2d4afd 100644 --- a/python/infinilm/auto_config.py +++ b/python/infinilm/auto_config.py @@ -31,5 +31,7 @@ def from_pretrained(model_path): return LlamaConfig(**config_dict) 
elif config_dict["model_type"] == "fm9g7b": return LlamaConfig(**config_dict) + elif config_dict["model_type"] in ["qwen3_next", "minicpm_sala", "qwen3_vl", "qwen3_moe"]: + return LlamaConfig(**config_dict) raise ValueError(f"Unsupported model type `{config_dict['model_type']}`.") diff --git a/xmake.lua b/xmake.lua index c29875aa..2b1b51d3 100644 --- a/xmake.lua +++ b/xmake.lua @@ -18,6 +18,16 @@ if has_config("use-kv-caching") then add_defines("ENABLE_KV_CACHING") end +option("use-classic-llama") + set_default(false) + set_showmenu(true) + set_description("Whether to use the classic LlamaForCausalLM") +option_end() + +if has_config("use-classic-llama") then + add_defines("USE_CLASSIC_LLAMA") +end + target("infinicore_infer") set_kind("shared")