From 1fc301ffb4f7a86173fc1583620c3271d1cb2ba6 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Wed, 4 Mar 2026 15:17:29 +0800 Subject: [PATCH 1/5] Issue/253: feat: support custom KV cache dtype for quantization --- csrc/cache/kv_cache.cpp | 22 ++++++++++++++++++++-- csrc/cache/kv_cache.hpp | 13 ++++++++++++- csrc/config/model_config.cpp | 21 +++------------------ csrc/config/model_config.hpp | 1 + csrc/models/llama/llama_model.cpp | 5 +++-- csrc/pybind11/cache/cache.hpp | 8 ++++++++ csrc/utils.hpp | 20 ++++++++++++++++++++ python/infinilm/cache/cache.py | 12 +++++++----- 8 files changed, 74 insertions(+), 28 deletions(-) diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 5f0de647..8bb785ba 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -118,6 +118,16 @@ StaticKVCache::update(size_t layer_idx, // ========================== // PagedKVCacheConfig // ========================== +PagedKVCacheConfig::PagedKVCacheConfig( + size_t num_blocks, + std::string kv_cache_dtype, + size_t block_size) + : num_blocks_(num_blocks), + block_size_(block_size), + kv_cache_dtype_(parse_dtype(kv_cache_dtype)) { + kv_cache_dtype_set_ = true; +} + PagedKVCacheConfig::PagedKVCacheConfig( size_t num_blocks, size_t block_size) @@ -140,6 +150,15 @@ PagedKVCacheConfig::block_size() const { return block_size_; } +infinicore::DataType +PagedKVCacheConfig::kv_cache_dtype() const { + return kv_cache_dtype_; +} + +void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const { + kv_cache_dtype_ = dtype; +} + // ========================== // PagedKVCache // ========================== @@ -149,7 +168,6 @@ PagedKVCache::PagedKVCache( infinicore::Size num_k_heads, infinicore::Size num_v_heads, infinicore::Size num_layers, - infinicore::DataType dtype, const PagedKVCacheConfig &config, const engine::distributed::RankInfo &rank_info) : Cache(), @@ -158,7 +176,7 @@ PagedKVCache::PagedKVCache( num_rank_k_heads_(num_k_heads / rank_info.tp_size), 
num_rank_v_heads_(num_v_heads / rank_info.tp_size), rank_num_layers_(num_layers), - dtype_(dtype), + dtype_(config.kv_cache_dtype()), num_blocks_per_layer_(config.num_blocks()), block_size_(config.block_size()) { // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim] diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index a594008a..ec639aeb 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -2,6 +2,7 @@ #include "base_cache.hpp" +#include "../utils.hpp" #include "infinicore/context/context.hpp" #include "infinicore/device.hpp" #include "infinicore/tensor.hpp" @@ -88,13 +89,24 @@ class PagedKVCacheConfig final : public CacheConfig { size_t num_blocks, size_t block_size = 256); + PagedKVCacheConfig( + size_t num_blocks, + std::string kv_cache_dtype, + size_t block_size = 16); + std::unique_ptr unique_copy() const override; size_t num_blocks() const; size_t block_size() const; + infinicore::DataType kv_cache_dtype() const; + void set_kv_cache_dtype(infinicore::DataType dtype) const; + bool kv_cache_dtype_set() const { return kv_cache_dtype_set_; } private: size_t num_blocks_; size_t block_size_; + + bool kv_cache_dtype_set_ = false; + mutable infinicore::DataType kv_cache_dtype_; }; class PagedKVCache final : public Cache { @@ -106,7 +118,6 @@ class PagedKVCache final : public Cache { infinicore::Size num_k_heads, infinicore::Size num_v_heads, infinicore::Size num_layers, - infinicore::DataType dtype, const PagedKVCacheConfig &config, const engine::distributed::RankInfo &rank_info); diff --git a/csrc/config/model_config.cpp b/csrc/config/model_config.cpp index 70b41ff0..5b3c6bd6 100644 --- a/csrc/config/model_config.cpp +++ b/csrc/config/model_config.cpp @@ -66,23 +66,8 @@ ModelConfig::get_rope_scaling() const { } } -infinicore::DataType -ModelConfig::get_dtype() const { - try { - std::string dtype_str = this->get("torch_dtype"); - if (dtype_str == "float32") { - return infinicore::DataType::F32; - } else if (dtype_str == 
"float16") { - return infinicore::DataType::F16; - } else if (dtype_str == "bfloat16") { - return infinicore::DataType::BF16; - } else if (dtype_str == "int8") { - return infinicore::DataType::I8; - } else { - throw std::runtime_error("Unsupported dtype string: " + dtype_str); - } - } catch (const std::exception &e) { - throw std::runtime_error("Error getting dtype from config: " + std::string(e.what())); - } +infinicore::DataType ModelConfig::get_dtype() const { + std::string dtype_str = this->get("torch_dtype"); + return parse_dtype(dtype_str); } } // namespace infinilm::config diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index a4600304..d376aeb7 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -1,5 +1,6 @@ #pragma once +#include "../utils.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" #include "quant_config.hpp" diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 81e8fd04..27d9aff2 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -147,7 +147,6 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.num_key_value_heads, config_.num_key_value_heads, config_.num_hidden_layers, - config_.dtype, *paged_kv_cache_config, rank_info_); } else if (auto kv_cache_config = dynamic_cast(cache_config)) { @@ -162,13 +161,15 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { *kv_cache_config, rank_info_); } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { + if (!paged_kv_cache_config->kv_cache_dtype_set()) { + paged_kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); + } kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), model_config_->get("num_key_value_heads"), model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), - model_config_->get_dtype(), *paged_kv_cache_config, 
rank_info_); } else { diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index 46f97ea4..37244b34 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -1,4 +1,5 @@ #include "../../cache/cache.hpp" +#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include #include @@ -38,12 +39,19 @@ inline void bind_cache(py::module &m) { py::init(), py::arg("num_blocks"), py::arg("block_size") = 256) + .def( + py::init(), + py::arg("num_blocks"), + py::arg("kv_cache_dtype"), + py::arg("block_size") = 16) .def( "num_blocks", &infinilm::cache::PagedKVCacheConfig::num_blocks) .def( "block_size", &infinilm::cache::PagedKVCacheConfig::block_size) + .def("kv_cache_dtype", + &infinilm::cache::PagedKVCacheConfig::kv_cache_dtype) .def("__repr__", [](const infinilm::cache::PagedKVCacheConfig &) { return ""; }); diff --git a/csrc/utils.hpp b/csrc/utils.hpp index 8af9759f..805e1254 100644 --- a/csrc/utils.hpp +++ b/csrc/utils.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include @@ -123,3 +124,22 @@ inline uint16_t f32_to_bf16(float val) { inline void hash_combine(size_t &seed, size_t value) { seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); } + +inline infinicore::DataType parse_dtype(const std::string &dtype_str) { + static const std::unordered_map dtype_map = { + {"float32", infinicore::DataType::F32}, + {"float16", infinicore::DataType::F16}, + {"bfloat16", infinicore::DataType::BF16}, + {"int8", infinicore::DataType::I8}, + // 可根据需要扩展 + {"int32", infinicore::DataType::I32}, + {"int64", infinicore::DataType::I64}, + }; + + auto it = dtype_map.find(dtype_str); + if (it != dtype_map.end()) { + return it->second; + } + + throw std::runtime_error("Unsupported dtype string: " + dtype_str); +} diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index 354309e1..fcd4460e 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -18,9 +18,11 @@ def __init__( self, 
num_blocks: int, block_size: int = 256, + kv_cache_dtype: str | None = None, ): - _infinilm.PagedKVCacheConfig.__init__( - self, - num_blocks, - block_size, - ) + if kv_cache_dtype is None: + _infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size) + else: + _infinilm.PagedKVCacheConfig.__init__( + self, num_blocks, kv_cache_dtype, block_size + ) From a2a2dac837f011fb0517d1344a4c07aed04a7183 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Wed, 18 Mar 2026 17:10:00 +0800 Subject: [PATCH 2/5] Issue/253: Support offline int8 inference with calibrated models --- csrc/cache/kv_cache.cpp | 26 +++++++++++++-- csrc/cache/kv_cache.hpp | 13 +++++++- csrc/config/model_config.hpp | 8 +++++ csrc/config/quant_config.hpp | 20 +++++++++++- csrc/engine/infer_engine.cpp | 5 ++- csrc/engine/infer_engine.hpp | 3 +- csrc/models/llama/llama_attention.cpp | 47 ++++++++++++++++++++++++--- csrc/models/llama/llama_attention.hpp | 11 +++++++ csrc/models/llama/llama_model.cpp | 5 +-- csrc/pybind11/cache/cache.hpp | 7 ++++ csrc/pybind11/engine/engine.hpp | 24 +++++++------- examples/bench.py | 10 +++++- examples/jiuge.py | 10 +++++- python/infinilm/cache/cache.py | 8 +++-- python/infinilm/infer_engine.py | 2 ++ 15 files changed, 169 insertions(+), 30 deletions(-) diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 8bb785ba..5b4e8c7f 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -16,6 +16,20 @@ StaticKVCacheConfig::StaticKVCacheConfig( max_cache_len_(_max_cache_len) { } +StaticKVCacheConfig::StaticKVCacheConfig( + infinicore::Size _max_batch_size, + infinicore::Size _max_cache_len, + std::string kv_cache_dtype) + : max_batch_size_(_max_batch_size), + max_cache_len_(_max_cache_len) { + if (kv_cache_dtype.empty()) { + kv_cache_dtype_set_ = false; + } else { + this->kv_cache_dtype_ = parse_dtype(kv_cache_dtype); + kv_cache_dtype_set_ = true; + } +} + std::unique_ptr StaticKVCacheConfig::unique_copy() const { return std::make_unique(*this); @@ -42,7 
+56,6 @@ StaticKVCache::StaticKVCache( infinicore::Size num_v_heads, infinicore::Size num_layers, infinicore::Size max_positional_embedding, - infinicore::DataType dtype, const StaticKVCacheConfig &config, const engine::distributed::RankInfo &rank_info) : Cache(), @@ -53,7 +66,7 @@ StaticKVCache::StaticKVCache( rank_batch_size_(config.max_batch_size()), cache_len_(config.max_cache_len() == std::numeric_limits::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()), rank_num_layers_(num_layers), - dtype_(dtype) { + dtype_(config.kv_cache_dtype()) { // Allocate K cache k_caches_ = infinicore::Tensor::empty( @@ -115,6 +128,15 @@ StaticKVCache::update(size_t layer_idx, return {k_cache_layer, v_cache_layer}; } +infinicore::DataType +StaticKVCacheConfig::kv_cache_dtype() const { + return kv_cache_dtype_; +} + +void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const { + kv_cache_dtype_ = dtype; +} + // ========================== // PagedKVCacheConfig // ========================== diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index ec639aeb..65d0270a 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -23,13 +23,25 @@ class StaticKVCacheConfig final : public CacheConfig { infinicore::Size _max_batch_size = 1, infinicore::Size _max_cache_len = std::numeric_limits::max()); + StaticKVCacheConfig( + infinicore::Size _max_batch_size, + infinicore::Size _max_cache_len, + std::string kv_cache_dtype); + std::unique_ptr unique_copy() const override; infinicore::Size max_batch_size() const; infinicore::Size max_cache_len() const; + infinicore::DataType kv_cache_dtype() const; + void set_kv_cache_dtype(infinicore::DataType dtype) const; + bool kv_cache_dtype_is_set() const { return kv_cache_dtype_set_; } + private: infinicore::Size max_batch_size_; infinicore::Size max_cache_len_; + + bool kv_cache_dtype_set_ = false; + mutable infinicore::DataType kv_cache_dtype_; }; class StaticKVCache 
final : public Cache { @@ -42,7 +54,6 @@ class StaticKVCache final : public Cache { infinicore::Size num_v_heads, infinicore::Size num_layers, infinicore::Size max_positional_embedding, - infinicore::DataType dtype, const StaticKVCacheConfig &config, const engine::distributed::RankInfo &rank_info); diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index d376aeb7..4db3c264 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -64,6 +64,14 @@ class ModelConfig { infinicore::DataType get_dtype() const; infinicore::quantization::QuantScheme get_quant_scheme() const; std::shared_ptr get_rope_scaling() const; + void set_kv_quant_scheme(std::string kv_cache_dtype) { + if (kv_cache_dtype == "int8") { + this->quant_config.set_kv_quant_scheme(kv_cache_dtype); + } + } + infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { + return quant_config.get_kv_quant_scheme(); + } private: nlohmann::json config_json; diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp index 480df067..4771b6f1 100644 --- a/csrc/config/quant_config.hpp +++ b/csrc/config/quant_config.hpp @@ -1,5 +1,5 @@ #pragma once -// #include "../quantization/quantization.hpp" +#include "../utils.hpp" #include "infinicore/quantization.hpp" #include "nlohmann/json.hpp" @@ -22,9 +22,27 @@ class QuantConfig { } } + void set_kv_quant_scheme(std::string kv_cache_dtype) { + switch (parse_dtype(kv_cache_dtype)) { + case infinicore::DataType::I8: { + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8; + break; + } + default: { + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + break; + } + } + } + + infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { + return kv_quant_scheme; + } + private: nlohmann::json quantization_config; std::shared_ptr quantization_method; + infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; }; } // namespace 
infinilm::config diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 7dd76eb8..7856fa90 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -55,7 +55,8 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, - backends::AttentionBackend attention_backend) // Changed parameter + backends::AttentionBackend attention_backend, + const std::string &kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -63,6 +64,8 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = std::make_shared(model_path + "/config.json"); + // Only support offline int8 kv cache quantization in this version + this->model_config_->set_kv_quant_scheme(kv_cache_dtype); // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); barrier_ = std::make_unique((size_t)world_size); diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index d191cfa1..a232ee4f 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -46,7 +46,8 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, - backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, + const std::string &kv_cache_dtype = ""); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); diff --git a/csrc/models/llama/llama_attention.cpp 
b/csrc/models/llama/llama_attention.cpp index 57cad6ec..fc4e51a3 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -7,10 +7,13 @@ #include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" #include "infinicore/ops/mul.hpp" +#include "infinicore/ops/per_tensor_dequant_i8.hpp" +#include "infinicore/ops/per_tensor_quant_i8.hpp" #include #include #include +#include #include #include #include @@ -137,6 +140,17 @@ LlamaAttention::LlamaAttention(std::shared_ptr mo INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get("rms_norm_eps"), dtype, device); INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get("rms_norm_eps"), dtype, device); } + + switch (this->model_config_->get_kv_quant_scheme()) { + case (infinicore::quantization::KVQuantAlgo::INT8): { + INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + break; + } + default: { + break; + } + } } infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states, @@ -184,6 +198,17 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + switch (this->model_config_->get_kv_quant_scheme()) { + case (infinicore::quantization::KVQuantAlgo::INT8): { + k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true); + v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true); + break; + } + default: { + break; + } + } + // 5. 
Prepare KV caches // Convert to [batch, n_head, seq_len, head_dim] for cache // Ensure contiguous after permute for F16 compatibility with cache operations @@ -212,6 +237,21 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] } else { size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; + + switch (this->model_config_->get_kv_quant_scheme()) { + case (infinicore::quantization::KVQuantAlgo::INT8): { + auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device()); + auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device()); + infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device())); + infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device())); + k_total = k_total_dequant; + v_total = v_total_dequant; + break; + } + default: { + break; + } + } k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] @@ -342,10 +382,10 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_}); auto attn_out_4d = infinicore::op::mha_kvcache( q_for_fa, - k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] + k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] v_total->permute({0, 2, 1, 3}), - total_sequence_lengths.value(), // 
[seq_len] int32 (one entry per sequence) - block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 std::nullopt, scaling_); attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_}); @@ -361,7 +401,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd scaling_); } } - // 7. Project output attn_output diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 4fe369d6..9924f4b8 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -112,6 +112,13 @@ class LlamaAttention : public infinicore::nn::Module { std::optional block_tables, std::optional slot_mapping) const; + infinicore::Tensor kv_cache_k_scale() const { + return kv_cache_k_scale_; + } + infinicore::Tensor kv_cache_v_scale() const { + return kv_cache_v_scale_; + } + protected: // Projection layers INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj); @@ -123,6 +130,10 @@ class LlamaAttention : public infinicore::nn::Module { // Shared Rotary Position Embeddings (RoPE) std::shared_ptr rotary_emb_; + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); + private: std::shared_ptr model_config_ = std::make_shared(); size_t layer_idx_; // Layer index for cache access diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 27d9aff2..5a224470 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -136,7 +136,6 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.num_key_value_heads, config_.num_hidden_layers, config_.max_position_embeddings, - config_.dtype, *kv_cache_config, rank_info_); } else if (auto paged_kv_cache_config = 
dynamic_cast(cache_config); @@ -150,6 +149,9 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { *paged_kv_cache_config, rank_info_); } else if (auto kv_cache_config = dynamic_cast(cache_config)) { + if (!kv_cache_config->kv_cache_dtype_is_set()) { + kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); + } kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), @@ -157,7 +159,6 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), model_config_->get("max_position_embeddings"), - model_config_->get_dtype(), *kv_cache_config, rank_info_); } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index 37244b34..73fcf78f 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -22,12 +22,19 @@ inline void bind_cache(py::module &m) { py::init(), py::arg("max_batch_size") = 1, py::arg("max_cache_len") = std::numeric_limits::max()) + .def( + py::init(), + py::arg("max_batch_size") = 1, + py::arg("max_cache_len") = std::numeric_limits::max(), + py::arg("kv_cache_dtype")) .def( "max_batch_size", &infinilm::cache::StaticKVCacheConfig::max_batch_size) .def( "max_cache_len", &infinilm::cache::StaticKVCacheConfig::max_cache_len) + .def("kv_cache_dtype", + &infinilm::cache::StaticKVCacheConfig::kv_cache_dtype) .def("__repr__", [](const infinilm::cache::StaticKVCacheConfig &) { return ""; }); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 4aeec8af..abb1b8d5 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -66,14 +66,11 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); 
}, "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); - return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; - }) + return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) .def("__repr__", [](const InferEngine &self) { return ""; }); infer_engine @@ -83,21 +80,24 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, - const std::string &attention_backend) { + const std::string &attention_backend, + const std::string &kv_cache_dtype) { return std::make_shared( model_path, dist, dev, cache_cfg ? 
cache_cfg.get() : nullptr, enable_graph_compiling, - infinilm::backends::parse_attention_backend(attention_backend)); + infinilm::backends::parse_attention_backend(attention_backend), + kv_cache_dtype); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, - py::arg("attention_backend") = "default") + py::arg("attention_backend") = "default", + py::arg("kv_cache_dtype") = "") .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -112,10 +112,8 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/examples/bench.py b/examples/bench.py index 9ac2b11e..35c0c3f4 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -259,6 +259,13 @@ def get_args(): choices=["default", "flash-attn"], help="attention backend to use: 'default' or 'flash-attn'", ) + parser.add_argument( + "--kv_cache_dtype", + type=str, + default="", + choices=["", "int8"], + ) + return parser.parse_args() @@ -298,6 +305,7 @@ def __init__( cache_config=cache_config, enable_graph_compiling=enable_graph, attention_backend=attn_backend, + kv_cache_dtype=args.kv_cache_dtype ) # ---------------------------------------------------------------------------- # @@ -536,7 +544,7 @@ def run( initial_capacity = input_len + output_len test.model.reset_cache( StaticKVCacheConfig( - max_batch_size=batch_size, max_cache_len=initial_capacity + max_batch_size=batch_size, max_cache_len=initial_capacity, kv_cache_dtype=args.kv_cache_dtype ) ) diff --git a/examples/jiuge.py b/examples/jiuge.py index 2cb99d66..72515de5 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -149,6 +149,13 @@ def get_args(): choices=["default", "flash-attn"], help="attention backend to use: 'default' or 'flash-attn'", ) + + parser.add_argument( + "--kv_cache_dtype", + type=str, + default="", + choices=["", "int8"], + ) return parser.parse_args() @@ -176,6 +183,7 @@ def test( distributed_config=DistConfig(tp), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + kv_cache_dtype=args.kv_cache_dtype, ) # ---------------------------------------------------------------------------- # # Load Weights @@ -255,7 +263,7 @@ def test( batch_size = 1 if prompts is str else len(prompts) initial_capacity = max_new_tokens + len(input_ids_list[0]) cache_config = StaticKVCacheConfig( - max_batch_size=batch_size, 
max_cache_len=initial_capacity + max_batch_size=batch_size, max_cache_len=initial_capacity, kv_cache_dtype=args.kv_cache_dtype ) model.reset_cache(cache_config) diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index fcd4460e..0ce984b7 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -9,9 +9,11 @@ def __init__(self): class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig): - def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0): - _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len) - + def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0, kv_cache_dtype: str | None = None): + if kv_cache_dtype is None: + _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len) + else: + _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len, kv_cache_dtype) class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 8b614980..a4b9a3ae 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -30,6 +30,7 @@ def __init__( cache_config=None, enable_graph_compiling=False, attention_backend="default", + kv_cache_dtype="", ): self.config = AutoConfig.from_pretrained(model_path) @@ -43,6 +44,7 @@ def __init__( cache_config, enable_graph_compiling, attention_backend, + kv_cache_dtype, ) self.use_cache = False From 7796a7674962575df062f224bf667e04476ff095 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Thu, 19 Mar 2026 16:31:19 +0800 Subject: [PATCH 3/5] Issue/253: (1) Refactor attention KV cache quantization to layers/kv_quant.cpp; (2)update kv_cache_dtype handling; (3)Update Python test scripts --- csrc/cache/kv_cache.cpp | 36 +++++++++------- csrc/cache/kv_cache.hpp | 13 +++--- csrc/config/model_config.hpp | 4 +- csrc/config/quant_config.hpp | 27 +++++++++--- csrc/engine/infer_engine.cpp | 2 
+- csrc/engine/infer_engine.hpp | 2 +- csrc/engine/rank_worker.cpp | 2 +- csrc/engine/rank_worker.hpp | 2 +- csrc/layers/kv_quant.cpp | 52 +++++++++++++++++++++++ csrc/layers/kv_quant.hpp | 44 +++++++++++++++++++ csrc/models/infinilm_model.hpp | 2 +- csrc/models/llama/llama_attention.cpp | 36 ++++++---------- csrc/models/llama/llama_attention.hpp | 8 +--- csrc/models/llama/llama_for_causal_lm.cpp | 2 +- csrc/models/llama/llama_for_causal_lm.hpp | 2 +- csrc/models/llama/llama_model.cpp | 18 +++----- csrc/models/llama/llama_model.hpp | 2 +- csrc/models/model_factory.cpp | 4 +- csrc/models/model_factory.hpp | 4 +- csrc/pybind11/engine/engine.hpp | 8 ++-- examples/bench.py | 4 +- examples/jiuge.py | 8 ++-- 22 files changed, 185 insertions(+), 97 deletions(-) create mode 100644 csrc/layers/kv_quant.cpp create mode 100644 csrc/layers/kv_quant.hpp diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 5b4e8c7f..b28156e6 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -2,6 +2,7 @@ #include "../utils.hpp" #include "infinicore/ops.hpp" +#include #include namespace infinilm::cache { @@ -22,11 +23,8 @@ StaticKVCacheConfig::StaticKVCacheConfig( std::string kv_cache_dtype) : max_batch_size_(_max_batch_size), max_cache_len_(_max_cache_len) { - if (kv_cache_dtype.empty()) { - kv_cache_dtype_set_ = false; - } else { - this->kv_cache_dtype_ = parse_dtype(kv_cache_dtype); - kv_cache_dtype_set_ = true; + if (!kv_cache_dtype.empty()) { + this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype)); } } @@ -130,11 +128,14 @@ StaticKVCache::update(size_t layer_idx, infinicore::DataType StaticKVCacheConfig::kv_cache_dtype() const { - return kv_cache_dtype_; + return kv_cache_dtype_.value(); } - -void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const { - kv_cache_dtype_ = dtype; +void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) { + if (!this->kv_cache_dtype_.has_value()) { + this->kv_cache_dtype_ 
= std::make_optional(dtype); + } else { + return; + } } // ========================== @@ -145,9 +146,10 @@ PagedKVCacheConfig::PagedKVCacheConfig( std::string kv_cache_dtype, size_t block_size) : num_blocks_(num_blocks), - block_size_(block_size), - kv_cache_dtype_(parse_dtype(kv_cache_dtype)) { - kv_cache_dtype_set_ = true; + block_size_(block_size) { + if (!kv_cache_dtype.empty()) { + this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype)); + } } PagedKVCacheConfig::PagedKVCacheConfig( @@ -174,11 +176,15 @@ PagedKVCacheConfig::block_size() const { infinicore::DataType PagedKVCacheConfig::kv_cache_dtype() const { - return kv_cache_dtype_; + return kv_cache_dtype_.value(); } -void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const { - kv_cache_dtype_ = dtype; +void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) { + if (!this->kv_cache_dtype_.has_value()) { + this->kv_cache_dtype_ = std::make_optional(dtype); + } else { + return; + } } // ========================== diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index 65d0270a..8e3e6f1a 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -33,15 +34,13 @@ class StaticKVCacheConfig final : public CacheConfig { infinicore::Size max_cache_len() const; infinicore::DataType kv_cache_dtype() const; - void set_kv_cache_dtype(infinicore::DataType dtype) const; - bool kv_cache_dtype_is_set() const { return kv_cache_dtype_set_; } + void set_kv_cache_dtype(infinicore::DataType dtype); private: infinicore::Size max_batch_size_; infinicore::Size max_cache_len_; - bool kv_cache_dtype_set_ = false; - mutable infinicore::DataType kv_cache_dtype_; + std::optional kv_cache_dtype_ = std::nullopt; }; class StaticKVCache final : public Cache { @@ -109,15 +108,13 @@ class PagedKVCacheConfig final : public CacheConfig { size_t num_blocks() const; size_t block_size() const; 
infinicore::DataType kv_cache_dtype() const; - void set_kv_cache_dtype(infinicore::DataType dtype) const; - bool kv_cache_dtype_set() const { return kv_cache_dtype_set_; } + void set_kv_cache_dtype(infinicore::DataType dtype); private: size_t num_blocks_; size_t block_size_; - bool kv_cache_dtype_set_ = false; - mutable infinicore::DataType kv_cache_dtype_; + std::optional kv_cache_dtype_ = std::nullopt; }; class PagedKVCache final : public Cache { diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index 4db3c264..4de50b31 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -65,9 +65,7 @@ class ModelConfig { infinicore::quantization::QuantScheme get_quant_scheme() const; std::shared_ptr get_rope_scaling() const; void set_kv_quant_scheme(std::string kv_cache_dtype) { - if (kv_cache_dtype == "int8") { - this->quant_config.set_kv_quant_scheme(kv_cache_dtype); - } + this->quant_config.set_kv_quant_scheme(kv_cache_dtype); } infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { return quant_config.get_kv_quant_scheme(); diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp index 4771b6f1..99ccff0f 100644 --- a/csrc/config/quant_config.hpp +++ b/csrc/config/quant_config.hpp @@ -2,6 +2,7 @@ #include "../utils.hpp" #include "infinicore/quantization.hpp" #include "nlohmann/json.hpp" +#include namespace infinilm::config { @@ -23,15 +24,27 @@ class QuantConfig { } void set_kv_quant_scheme(std::string kv_cache_dtype) { - switch (parse_dtype(kv_cache_dtype)) { - case infinicore::DataType::I8: { - this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8; - break; - } - default: { + if (kv_cache_dtype.empty()) { this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; - break; + // spdlog::debug("kv_cache_dtype is empty, using default NONE"); + return; } + + try { + switch (parse_dtype(kv_cache_dtype)) { + case infinicore::DataType::I8: { + this->kv_quant_scheme = 
infinicore::quantization::KVQuantAlgo::INT8; + break; + } + default: { + spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", kv_cache_dtype); + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + break; + } + } + } catch (const std::exception &e) { + spdlog::error("Failed to parse kv_cache_dtype '{}': {}", kv_cache_dtype, e.what()); + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; } } diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 7856fa90..2e804d07 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -171,7 +171,7 @@ const distributed::DistConfig &InferEngine::get_dist_config() const { //------------------------------------------------------ // reset_cache (overloaded with CacheConfig) //------------------------------------------------------ -void InferEngine::reset_cache(const cache::CacheConfig *new_config) { +void InferEngine::reset_cache(cache::CacheConfig *new_config) { for (auto &worker : workers_) { worker->reset_cache(new_config); } diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index a232ee4f..14a878db 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -60,7 +60,7 @@ class InferEngine { void compile(); - void reset_cache(const cache::CacheConfig *new_config); + void reset_cache(cache::CacheConfig *new_config); ~InferEngine(); diff --git a/csrc/engine/rank_worker.cpp b/csrc/engine/rank_worker.cpp index c9cdc97f..8a640fe9 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -186,7 +186,7 @@ void RankWorker::wait() { } } -void RankWorker::reset_cache(const cache::CacheConfig *new_config) { +void RankWorker::reset_cache(cache::CacheConfig *new_config) { std::lock_guard lock(mutex_); if (should_exit_) { throw std::runtime_error("RankWorker is closing; cannot reset_cache"); diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 719c1cb1..32cdd5ad 
100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -85,7 +85,7 @@ class RankWorker { void run(const Input &args); // Reset the internal cache with a new configuration - void reset_cache(const cache::CacheConfig *new_config); + void reset_cache(cache::CacheConfig *new_config); // Compile the model graph if enabled. void compile(); diff --git a/csrc/layers/kv_quant.cpp b/csrc/layers/kv_quant.cpp new file mode 100644 index 00000000..15114d18 --- /dev/null +++ b/csrc/layers/kv_quant.cpp @@ -0,0 +1,52 @@ +#include "kv_quant.hpp" +#include "infinicore/ops/per_tensor_dequant_i8.hpp" +#include "infinicore/ops/per_tensor_quant_i8.hpp" + +namespace infinilm { + +void KVQuantUtils::quantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale) { + + if (algo == infinicore::quantization::KVQuantAlgo::NONE) { + return; + } + + auto device = k->device(); + auto dtype = k->dtype(); + auto zero_point = infinicore::Tensor::zeros({1}, dtype, device); + + k = infinicore::op::per_tensor_quant_i8(k, k_scale, zero_point, true); + v = infinicore::op::per_tensor_quant_i8(v, v_scale, zero_point, true); +} + +void KVQuantUtils::dequantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale, + const infinicore::Tensor &reference) { + + if (algo == infinicore::quantization::KVQuantAlgo::NONE) { + return; // 无需反量化 + } + + auto zero_point = infinicore::Tensor::zeros({1}, reference->dtype(), reference->device()); + + auto k_dequant = infinicore::Tensor::strided_empty( + k->shape(), k->strides(), reference->dtype(), reference->device()); + auto v_dequant = infinicore::Tensor::strided_empty( + v->shape(), v->strides(), reference->dtype(), reference->device()); + + infinicore::op::per_tensor_dequant_i8_(k_dequant, k, k_scale, zero_point); + 
infinicore::op::per_tensor_dequant_i8_(v_dequant, v, v_scale, zero_point); + + k = std::move(k_dequant); + v = std::move(v_dequant); +} + +} // namespace infinilm \ No newline at end of file diff --git a/csrc/layers/kv_quant.hpp b/csrc/layers/kv_quant.hpp new file mode 100644 index 00000000..146c8f65 --- /dev/null +++ b/csrc/layers/kv_quant.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "infinicore/quantization.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm { + +class KVQuantUtils { +public: + /** + * @brief 量化 K/V(写入缓存前)- 原地修改 k 和 v + * @param k 原始 K 张量 + * @param v 原始 V 张量 + * @param algo 量化算法 + * @param k_scale K 的 scale + * @param v_scale V 的 scale + */ + static void quantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale); + + /** + * @brief 反量化 K/V(读取缓存后)- 原地修改 k 和 v + * @param k 量化后的 K 张量 + * @param v 量化后的 V 张量 + * @param algo 量化算法 + * @param k_scale K 的 scale + * @param v_scale V 的 scale + * @param reference 参考张量(用于获取 dtype/device) + */ + static void dequantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale, + const infinicore::Tensor &reference); +}; + +} // namespace infinilm diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index 550bf1aa..c4c5296f 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -43,7 +43,7 @@ class InfinilmModel : public infinicore::nn::Module { virtual ~InfinilmModel() = default; virtual Output forward(const Input &input) const = 0; - virtual void reset_cache(const cache::CacheConfig *cache_config) = 0; + virtual void reset_cache(cache::CacheConfig *cache_config) = 0; virtual const cache::CacheConfig *get_cache_config() const = 0; }; } // namespace infinilm diff --git a/csrc/models/llama/llama_attention.cpp 
b/csrc/models/llama/llama_attention.cpp index fc4e51a3..5737b886 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -198,16 +198,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] - switch (this->model_config_->get_kv_quant_scheme()) { - case (infinicore::quantization::KVQuantAlgo::INT8): { - k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true); - v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true); - break; - } - default: { - break; - } - } + infinilm::KVQuantUtils::quantize( + k_reshaped, v_reshaped, + this->model_config_->get_kv_quant_scheme(), + this->kv_cache_k_scale_, + this->kv_cache_v_scale_); // 5. 
Prepare KV caches // Convert to [batch, n_head, seq_len, head_dim] for cache @@ -238,20 +233,13 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta } else { size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; - switch (this->model_config_->get_kv_quant_scheme()) { - case (infinicore::quantization::KVQuantAlgo::INT8): { - auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device()); - auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device()); - infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device())); - infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device())); - k_total = k_total_dequant; - v_total = v_total_dequant; - break; - } - default: { - break; - } - } + infinilm::KVQuantUtils::dequantize( + k_total, v_total, + this->model_config_->get_kv_quant_scheme(), + this->kv_cache_k_scale_, + this->kv_cache_v_scale_, + q_reshaped); + k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 9924f4b8..39493e76 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -5,6 +5,7 @@ #include "../../config/model_config.hpp" #include "../../engine/distributed/distributed.hpp" #include "../../layers/fused_linear.hpp" +#include "../../layers/kv_quant.hpp" #include "llama_config.hpp" #include "infinicore/nn/linear.hpp" @@ -112,13 +113,6 @@ class 
LlamaAttention : public infinicore::nn::Module { std::optional block_tables, std::optional slot_mapping) const; - infinicore::Tensor kv_cache_k_scale() const { - return kv_cache_k_scale_; - } - infinicore::Tensor kv_cache_v_scale() const { - return kv_cache_v_scale_; - } - protected: // Projection layers INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj); diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index 9596e668..1fc0708f 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -71,7 +71,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { return {logits}; } -void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { +void LlamaForCausalLM::reset_cache(cache::CacheConfig *cache_config) { cache_config_ = cache_config->unique_copy(); model_->reset_cache(cache_config_.get()); } diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp index 5cc79dfe..5727aea4 100644 --- a/csrc/models/llama/llama_for_causal_lm.hpp +++ b/csrc/models/llama/llama_for_causal_lm.hpp @@ -58,7 +58,7 @@ class LlamaForCausalLM : public InfinilmModel { */ Output forward(const Input &input) const; - void reset_cache(const cache::CacheConfig *cache_config) override; + void reset_cache(cache::CacheConfig *cache_config) override; const cache::CacheConfig *get_cache_config() const override; diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 5a224470..233c7b04 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -122,12 +122,12 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, return hidden_states; } -void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { +void LlamaModel::reset_cache(cache::CacheConfig *cache_config) { if (cache_config == nullptr) { kv_cache_ = nullptr; 
return; } - if (auto kv_cache_config = dynamic_cast(cache_config); + if (auto kv_cache_config = dynamic_cast(cache_config); kv_cache_config && model_config_ == nullptr) { kv_cache_ = std::make_shared( config_.head_dim, @@ -138,7 +138,7 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.max_position_embeddings, *kv_cache_config, rank_info_); - } else if (auto paged_kv_cache_config = dynamic_cast(cache_config); + } else if (auto paged_kv_cache_config = dynamic_cast(cache_config); paged_kv_cache_config && model_config_ == nullptr) { kv_cache_ = std::make_shared( config_.head_dim, @@ -148,10 +148,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.num_hidden_layers, *paged_kv_cache_config, rank_info_); - } else if (auto kv_cache_config = dynamic_cast(cache_config)) { - if (!kv_cache_config->kv_cache_dtype_is_set()) { - kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); - } + } else if (auto kv_cache_config = dynamic_cast(cache_config)) { + kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), @@ -161,10 +159,8 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { model_config_->get("max_position_embeddings"), *kv_cache_config, rank_info_); - } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { - if (!paged_kv_cache_config->kv_cache_dtype_set()) { - paged_kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); - } + } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { + paged_kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index 416e1a5c..7d504408 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -79,7 
+79,7 @@ class LlamaModel : public infinicore::nn::Module { std::optional block_tables, std::optional slot_mapping) const; - void reset_cache(const cache::CacheConfig *cache_config); + void reset_cache(cache::CacheConfig *cache_config); // Module information size_t num_layers() const { return model_config_->get("num_hidden_layers"); } diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 319a7baa..0d4caa69 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -17,7 +17,7 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info, - const cache::CacheConfig *cache, + cache::CacheConfig *cache, backends::AttentionBackend attention_backend) { std::shared_ptr model; if (const auto llama_config_ptr = dynamic_cast(&config)) { @@ -38,7 +38,7 @@ std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info, - const cache::CacheConfig *cache, + cache::CacheConfig *cache, backends::AttentionBackend attention_backend) { std::shared_ptr model; diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index 3c3c2e38..8af25d84 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -24,13 +24,13 @@ class InfinilmModelFactory { static std::shared_ptr createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - const cache::CacheConfig *cache = nullptr, + cache::CacheConfig *cache = nullptr, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); static std::shared_ptr createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - const cache::CacheConfig *cache = nullptr, + cache::CacheConfig *cache = nullptr, 
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); }; } // namespace infinilm diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index abb1b8d5..2b604a3c 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -35,7 +35,7 @@ inline void bind_infer_engine(py::module &m) { const InfinilmModel::Config &cfg, const distributed::DistConfig &dist, infinicore::Device::Type dev, - std::shared_ptr cache_cfg, + std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend) { return std::make_shared( @@ -67,7 +67,7 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? 
std::shared_ptr(cfg->unique_copy()) : nullptr; }) @@ -78,7 +78,7 @@ inline void bind_infer_engine(py::module &m) { const std::string &model_path, const distributed::DistConfig &dist, infinicore::Device::Type dev, - std::shared_ptr cache_cfg, + std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, const std::string &kv_cache_dtype) { @@ -113,7 +113,7 @@ inline void bind_infer_engine(py::module &m) { return state_dict_tp_all; }) .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/examples/bench.py b/examples/bench.py index 35c0c3f4..ee7ecd92 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -237,7 +237,7 @@ def get_args(): help="use paged cache", ) parser.add_argument( - "--paged_kv_block_size", + "--paged-kv-block-size", type=int, default=256, help="num tokens each kv block can hold", @@ -260,7 +260,7 @@ def get_args(): help="attention backend to use: 'default' or 'flash-attn'", ) parser.add_argument( - "--kv_cache_dtype", + "--kv-cache-dtype", type=str, default="", choices=["", "int8"], diff --git a/examples/jiuge.py b/examples/jiuge.py index 72515de5..f0e64c9f 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -67,13 +67,13 @@ def get_args(): help="Run hygon test", ) parser.add_argument( - "--model_path", + "--model-path", type=str, required=True, help="model_path", ) parser.add_argument( - "--max_new_tokens", + 
"--max-new-tokens", type=int, default=100, help="max_new_tokens", @@ -109,7 +109,7 @@ def get_args(): ) parser.add_argument( - "--paged_kv_block_size", + "--paged-kv-block-size", type=int, default=256, help="num tokens each kv block can hold", @@ -151,7 +151,7 @@ def get_args(): ) parser.add_argument( - "--kv_cache_dtype", + "--kv-cache-dtype", type=str, default="", choices=["", "int8"], From 51d0a81533c8346993ef1941d8d645badb08d811 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Mon, 23 Mar 2026 01:08:52 +0000 Subject: [PATCH 4/5] issue/253 refine static kv cache init --- csrc/cache/kv_cache.cpp | 32 ++++++------------------ csrc/cache/kv_cache.hpp | 18 ++++---------- csrc/config/model_config.hpp | 2 +- csrc/config/quant_config.hpp | 14 +++-------- csrc/engine/infer_engine.cpp | 6 +++-- csrc/engine/infer_engine.hpp | 2 +- csrc/pybind11/cache/cache.hpp | 18 ++++---------- csrc/pybind11/engine/engine.hpp | 16 +++++++----- examples/bench.py | 10 +++++--- python/infinilm/cache/cache.py | 41 +++++++++++++++++++++++++------ python/infinilm/infer_engine.py | 8 ++++-- python/infinilm/modeling_utils.py | 18 ++++++++++++++ 12 files changed, 100 insertions(+), 85 deletions(-) diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index b28156e6..76cadb6b 100644 --- a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -10,22 +10,13 @@ namespace infinilm::cache { // StaticKVCacheConfig // ========================== -StaticKVCacheConfig::StaticKVCacheConfig( - infinicore::Size _max_batch_size, - infinicore::Size _max_cache_len) - : max_batch_size_(_max_batch_size), - max_cache_len_(_max_cache_len) { -} - StaticKVCacheConfig::StaticKVCacheConfig( infinicore::Size _max_batch_size, infinicore::Size _max_cache_len, - std::string kv_cache_dtype) + std::optional kv_cache_dtype) : max_batch_size_(_max_batch_size), - max_cache_len_(_max_cache_len) { - if (!kv_cache_dtype.empty()) { - this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype)); - } + 
max_cache_len_(_max_cache_len), + kv_cache_dtype_(kv_cache_dtype) { } std::unique_ptr @@ -143,20 +134,11 @@ void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) { // ========================== PagedKVCacheConfig::PagedKVCacheConfig( size_t num_blocks, - std::string kv_cache_dtype, - size_t block_size) - : num_blocks_(num_blocks), - block_size_(block_size) { - if (!kv_cache_dtype.empty()) { - this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype)); - } -} - -PagedKVCacheConfig::PagedKVCacheConfig( - size_t num_blocks, - size_t block_size) + size_t block_size, + std::optional kv_cache_dtype) : num_blocks_(num_blocks), - block_size_(block_size) { + block_size_(block_size), + kv_cache_dtype_(kv_cache_dtype) { } std::unique_ptr diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index 8e3e6f1a..89b73c75 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -22,12 +22,8 @@ class StaticKVCacheConfig final : public CacheConfig { public: StaticKVCacheConfig( infinicore::Size _max_batch_size = 1, - infinicore::Size _max_cache_len = std::numeric_limits::max()); - - StaticKVCacheConfig( - infinicore::Size _max_batch_size, - infinicore::Size _max_cache_len, - std::string kv_cache_dtype); + infinicore::Size _max_cache_len = std::numeric_limits::max(), + std::optional kv_cache_dtype = std::nullopt); std::unique_ptr unique_copy() const override; infinicore::Size max_batch_size() const; @@ -40,7 +36,7 @@ class StaticKVCacheConfig final : public CacheConfig { infinicore::Size max_batch_size_; infinicore::Size max_cache_len_; - std::optional kv_cache_dtype_ = std::nullopt; + std::optional kv_cache_dtype_; }; class StaticKVCache final : public Cache { @@ -97,12 +93,8 @@ class PagedKVCacheConfig final : public CacheConfig { public: PagedKVCacheConfig( size_t num_blocks, - size_t block_size = 256); - - PagedKVCacheConfig( - size_t num_blocks, - std::string kv_cache_dtype, - size_t block_size = 16); + size_t block_size = 256, 
+ std::optional kv_cache_dtype = std::nullopt); std::unique_ptr unique_copy() const override; size_t num_blocks() const; diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index 4de50b31..db69579b 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -64,7 +64,7 @@ class ModelConfig { infinicore::DataType get_dtype() const; infinicore::quantization::QuantScheme get_quant_scheme() const; std::shared_ptr get_rope_scaling() const; - void set_kv_quant_scheme(std::string kv_cache_dtype) { + void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { this->quant_config.set_kv_quant_scheme(kv_cache_dtype); } infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp index 99ccff0f..2bfb0a02 100644 --- a/csrc/config/quant_config.hpp +++ b/csrc/config/quant_config.hpp @@ -23,27 +23,21 @@ class QuantConfig { } } - void set_kv_quant_scheme(std::string kv_cache_dtype) { - if (kv_cache_dtype.empty()) { - this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; - // spdlog::debug("kv_cache_dtype is empty, using default NONE"); - return; - } - + void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { try { - switch (parse_dtype(kv_cache_dtype)) { + switch (kv_cache_dtype) { case infinicore::DataType::I8: { this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8; break; } default: { - spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", kv_cache_dtype); + spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", infinicore::toString(kv_cache_dtype)); this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; break; } } } catch (const std::exception &e) { - spdlog::error("Failed to parse kv_cache_dtype '{}': {}", kv_cache_dtype, e.what()); + spdlog::error("Failed to parse kv_cache_dtype '{}': {}", infinicore::toString(kv_cache_dtype), e.what()); this->kv_quant_scheme = 
infinicore::quantization::KVQuantAlgo::NONE; } } diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 2e804d07..4ef68faa 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -56,7 +56,7 @@ InferEngine::InferEngine( const cache::CacheConfig *cache_config, bool enable_graph_compiling, backends::AttentionBackend attention_backend, - const std::string &kv_cache_dtype) // Changed parameter + std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -65,7 +65,9 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = std::make_shared(model_path + "/config.json"); // Only support offline int8 kv cache quantization in this version - this->model_config_->set_kv_quant_scheme(kv_cache_dtype); + if (kv_cache_dtype.has_value()) { + this->model_config_->set_kv_quant_scheme(kv_cache_dtype.value()); + } // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); barrier_ = std::make_unique((size_t)world_size); diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 14a878db..00dfc32e 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -47,7 +47,7 @@ class InferEngine { const cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, - const std::string &kv_cache_dtype = ""); + std::optional kv_cache_dtype = std::nullopt); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor ¶m); diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index 73fcf78f..a5723a2f 100644 --- 
a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -19,14 +19,10 @@ inline void bind_cache(py::module &m) { infinilm::cache::CacheConfig, std::shared_ptr>(m, "StaticKVCacheConfig") .def( - py::init(), - py::arg("max_batch_size") = 1, - py::arg("max_cache_len") = std::numeric_limits::max()) - .def( - py::init(), + py::init>(), py::arg("max_batch_size") = 1, py::arg("max_cache_len") = std::numeric_limits::max(), - py::arg("kv_cache_dtype")) + py::arg("kv_cache_dtype") = std::nullopt) .def( "max_batch_size", &infinilm::cache::StaticKVCacheConfig::max_batch_size) @@ -43,14 +39,10 @@ inline void bind_cache(py::module &m) { infinilm::cache::CacheConfig, std::shared_ptr>(m, "PagedKVCacheConfig") .def( - py::init(), - py::arg("num_blocks"), - py::arg("block_size") = 256) - .def( - py::init(), + py::init>(), py::arg("num_blocks"), - py::arg("kv_cache_dtype"), - py::arg("block_size") = 16) + py::arg("block_size") = 256, + py::arg("kv_cache_dtype") = std::nullopt) .def( "num_blocks", &infinilm::cache::PagedKVCacheConfig::num_blocks) diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 2b604a3c..3e0c884c 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -66,8 +66,10 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) @@ -81,7 +83,7 @@ inline void bind_infer_engine(py::module &m) { std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, - const std::string &kv_cache_dtype) { + std::optional kv_cache_dtype) { return std::make_shared( model_path, dist, @@ -97,7 +99,7 @@ inline void bind_infer_engine(py::module &m) { py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, py::arg("attention_backend") = "default", - py::arg("kv_cache_dtype") = "") + py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -112,8 +114,10 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def( + "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def( + "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/examples/bench.py b/examples/bench.py index ee7ecd92..29ed14a8 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -262,8 +262,8 @@ def get_args(): parser.add_argument( "--kv-cache-dtype", type=str, - default="", - choices=["", "int8"], + default=None, + choices=["int8"], ) return parser.parse_args() @@ -305,7 +305,7 @@ def __init__( cache_config=cache_config, enable_graph_compiling=enable_graph, attention_backend=attn_backend, - kv_cache_dtype=args.kv_cache_dtype + kv_cache_dtype=args.kv_cache_dtype, ) # ---------------------------------------------------------------------------- # @@ -544,7 +544,9 @@ def run( initial_capacity = input_len + output_len test.model.reset_cache( StaticKVCacheConfig( - max_batch_size=batch_size, max_cache_len=initial_capacity, kv_cache_dtype=args.kv_cache_dtype + max_batch_size=batch_size, + max_cache_len=initial_capacity, + kv_cache_dtype=args.kv_cache_dtype, ) ) diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index 0ce984b7..f35dcd72 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -1,4 +1,6 @@ from infinilm.lib import _infinilm +import infinicore +from ..modeling_utils import parse_dtype class CacheConfig(_infinilm.CacheConfig): @@ -9,22 +11,45 @@ def __init__(self): class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig): - def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0, kv_cache_dtype: str | None = None): - if kv_cache_dtype is None: - _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len) + def __init__( + self, + max_batch_size: int = 1, + max_cache_len: int = 0, + kv_cache_dtype=None, + ): + if isinstance(kv_cache_dtype, str): + _infinilm.StaticKVCacheConfig.__init__( + self, + 
max_batch_size, + max_cache_len, + parse_dtype(kv_cache_dtype)._underlying, + ) + elif isinstance(kv_cache_dtype, infinicore.dtype): + _infinilm.StaticKVCacheConfig.__init__( + self, max_batch_size, max_cache_len, kv_cache_dtype._underlying + ) else: - _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len, kv_cache_dtype) + _infinilm.StaticKVCacheConfig.__init__( + self, max_batch_size, max_cache_len, kv_cache_dtype + ) + class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( self, num_blocks: int, block_size: int = 256, - kv_cache_dtype: str | None = None, + kv_cache_dtype=None, ): - if kv_cache_dtype is None: - _infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size) + if isinstance(kv_cache_dtype, str): + _infinilm.PagedKVCacheConfig.__init__( + self, num_blocks, block_size, parse_dtype(kv_cache_dtype)._underlying + ) + elif isinstance(kv_cache_dtype, infinicore.dtype): + _infinilm.PagedKVCacheConfig.__init__( + self, num_blocks, block_size, kv_cache_dtype._underlying + ) else: _infinilm.PagedKVCacheConfig.__init__( - self, num_blocks, kv_cache_dtype, block_size + self, num_blocks, block_size, kv_cache_dtype ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index a4b9a3ae..3ca1ea06 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -8,6 +8,8 @@ from infinilm.distributed import DistConfig from infinilm.lib import _infinilm +from .modeling_utils import parse_dtype + @dataclass class GenerationConfig: @@ -30,7 +32,7 @@ def __init__( cache_config=None, enable_graph_compiling=False, attention_backend="default", - kv_cache_dtype="", + kv_cache_dtype=None, ): self.config = AutoConfig.from_pretrained(model_path) @@ -44,7 +46,9 @@ def __init__( cache_config, enable_graph_compiling, attention_backend, - kv_cache_dtype, + parse_dtype(kv_cache_dtype)._underlying + if kv_cache_dtype is not None + else None, ) self.use_cache = False diff --git 
a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index d1b26dd9..1d21f2d9 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -7,6 +7,24 @@ from tqdm import tqdm import infinicore + +def parse_dtype(dtype_str: str): + if dtype_str == "float32": + return infinicore.float32 + elif dtype_str == "float16": + return infinicore.float16 + elif dtype_str == "bfloat16": + return infinicore.bfloat16 + elif dtype_str == "int8": + return infinicore.int8 + elif dtype_str == "int32": + return infinicore.int32 + elif dtype_str == "int64": + return infinicore.int64 + else: + raise ValueError(f"Unknown dtype string: {dtype_str}") + + str_to_torch_dtype = { "BOOL": torch.bool, "U8": torch.uint8, From 88e3889953bb585adb3f47febf610b76b5ae7a99 Mon Sep 17 00:00:00 2001 From: qinyiqun Date: Wed, 25 Mar 2026 16:23:29 +0800 Subject: [PATCH 5/5] Issue/253: kv data type is managed by model_config now --- csrc/cache/kv_cache.cpp | 43 +++++------------------ csrc/cache/kv_cache.hpp | 22 +++--------- csrc/config/model_config.hpp | 7 ++++ csrc/config/quant_config.hpp | 11 ++++++ csrc/engine/infer_engine.cpp | 2 +- csrc/engine/infer_engine.hpp | 2 +- csrc/engine/rank_worker.cpp | 2 +- csrc/engine/rank_worker.hpp | 2 +- csrc/layers/kv_quant.cpp | 2 +- csrc/models/infinilm_model.hpp | 2 +- csrc/models/llama/llama_attention.cpp | 1 - csrc/models/llama/llama_for_causal_lm.cpp | 2 +- csrc/models/llama/llama_for_causal_lm.hpp | 2 +- csrc/models/llama/llama_model.cpp | 16 +++++---- csrc/models/llama/llama_model.hpp | 2 +- csrc/models/model_factory.cpp | 4 +-- csrc/models/model_factory.hpp | 4 +-- csrc/pybind11/cache/cache.hpp | 16 +++------ csrc/pybind11/engine/engine.hpp | 16 ++++----- examples/bench.py | 3 +- examples/jiuge.py | 2 +- python/infinilm/cache/cache.py | 40 ++++----------------- 22 files changed, 72 insertions(+), 131 deletions(-) diff --git a/csrc/cache/kv_cache.cpp b/csrc/cache/kv_cache.cpp index 76cadb6b..d82ff7a2 100644 --- 
a/csrc/cache/kv_cache.cpp +++ b/csrc/cache/kv_cache.cpp @@ -12,11 +12,9 @@ namespace infinilm::cache { StaticKVCacheConfig::StaticKVCacheConfig( infinicore::Size _max_batch_size, - infinicore::Size _max_cache_len, - std::optional kv_cache_dtype) + infinicore::Size _max_cache_len) : max_batch_size_(_max_batch_size), - max_cache_len_(_max_cache_len), - kv_cache_dtype_(kv_cache_dtype) { + max_cache_len_(_max_cache_len) { } std::unique_ptr @@ -45,6 +43,7 @@ StaticKVCache::StaticKVCache( infinicore::Size num_v_heads, infinicore::Size num_layers, infinicore::Size max_positional_embedding, + infinicore::DataType dtype, const StaticKVCacheConfig &config, const engine::distributed::RankInfo &rank_info) : Cache(), @@ -55,7 +54,7 @@ StaticKVCache::StaticKVCache( rank_batch_size_(config.max_batch_size()), cache_len_(config.max_cache_len() == std::numeric_limits::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()), rank_num_layers_(num_layers), - dtype_(config.kv_cache_dtype()) { + dtype_(dtype) { // Allocate K cache k_caches_ = infinicore::Tensor::empty( @@ -117,28 +116,14 @@ StaticKVCache::update(size_t layer_idx, return {k_cache_layer, v_cache_layer}; } -infinicore::DataType -StaticKVCacheConfig::kv_cache_dtype() const { - return kv_cache_dtype_.value(); -} -void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) { - if (!this->kv_cache_dtype_.has_value()) { - this->kv_cache_dtype_ = std::make_optional(dtype); - } else { - return; - } -} - // ========================== // PagedKVCacheConfig // ========================== PagedKVCacheConfig::PagedKVCacheConfig( size_t num_blocks, - size_t block_size, - std::optional kv_cache_dtype) + size_t block_size) : num_blocks_(num_blocks), - block_size_(block_size), - kv_cache_dtype_(kv_cache_dtype) { + block_size_(block_size) { } std::unique_ptr @@ -156,19 +141,6 @@ PagedKVCacheConfig::block_size() const { return block_size_; } -infinicore::DataType 
-PagedKVCacheConfig::kv_cache_dtype() const { - return kv_cache_dtype_.value(); -} - -void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) { - if (!this->kv_cache_dtype_.has_value()) { - this->kv_cache_dtype_ = std::make_optional(dtype); - } else { - return; - } -} - // ========================== // PagedKVCache // ========================== @@ -178,6 +150,7 @@ PagedKVCache::PagedKVCache( infinicore::Size num_k_heads, infinicore::Size num_v_heads, infinicore::Size num_layers, + infinicore::DataType dtype, const PagedKVCacheConfig &config, const engine::distributed::RankInfo &rank_info) : Cache(), @@ -186,7 +159,7 @@ PagedKVCache::PagedKVCache( num_rank_k_heads_(num_k_heads / rank_info.tp_size), num_rank_v_heads_(num_v_heads / rank_info.tp_size), rank_num_layers_(num_layers), - dtype_(config.kv_cache_dtype()), + dtype_(dtype), num_blocks_per_layer_(config.num_blocks()), block_size_(config.block_size()) { // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim] diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index 89b73c75..164751e8 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -2,16 +2,15 @@ #include "base_cache.hpp" -#include "../utils.hpp" #include "infinicore/context/context.hpp" #include "infinicore/device.hpp" #include "infinicore/tensor.hpp" +#include #include #include #include #include -#include #include #include @@ -22,33 +21,27 @@ class StaticKVCacheConfig final : public CacheConfig { public: StaticKVCacheConfig( infinicore::Size _max_batch_size = 1, - infinicore::Size _max_cache_len = std::numeric_limits::max(), - std::optional kv_cache_dtype = std::nullopt); + infinicore::Size _max_cache_len = std::numeric_limits::max()); std::unique_ptr unique_copy() const override; infinicore::Size max_batch_size() const; infinicore::Size max_cache_len() const; - infinicore::DataType kv_cache_dtype() const; - void set_kv_cache_dtype(infinicore::DataType dtype); - private: infinicore::Size max_batch_size_; 
infinicore::Size max_cache_len_; - - std::optional kv_cache_dtype_; }; class StaticKVCache final : public Cache { public: StaticKVCache( - infinicore::Size k_dim, infinicore::Size v_dim, infinicore::Size num_k_heads, infinicore::Size num_v_heads, infinicore::Size num_layers, infinicore::Size max_positional_embedding, + infinicore::DataType dtype, const StaticKVCacheConfig &config, const engine::distributed::RankInfo &rank_info); @@ -93,31 +86,26 @@ class PagedKVCacheConfig final : public CacheConfig { public: PagedKVCacheConfig( size_t num_blocks, - size_t block_size = 256, - std::optional kv_cache_dtype = std::nullopt); + size_t block_size = 256); std::unique_ptr unique_copy() const override; size_t num_blocks() const; size_t block_size() const; - infinicore::DataType kv_cache_dtype() const; - void set_kv_cache_dtype(infinicore::DataType dtype); private: size_t num_blocks_; size_t block_size_; - - std::optional kv_cache_dtype_ = std::nullopt; }; class PagedKVCache final : public Cache { public: PagedKVCache( - infinicore::Size k_dim, infinicore::Size v_dim, infinicore::Size num_k_heads, infinicore::Size num_v_heads, infinicore::Size num_layers, + infinicore::DataType dtype, const PagedKVCacheConfig &config, const engine::distributed::RankInfo &rank_info); diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index db69579b..97049d6a 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -70,6 +70,13 @@ class ModelConfig { infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { return quant_config.get_kv_quant_scheme(); } + infinicore::DataType get_kv_cache_dtype() const { + if (this->quant_config.get_kv_cache_dtype().has_value()) { + return this->quant_config.get_kv_cache_dtype().value(); + } else { + return this->get_dtype(); + } + } private: nlohmann::json config_json; diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp index 2bfb0a02..d0b8cb24 100644 --- a/csrc/config/quant_config.hpp 
+++ b/csrc/config/quant_config.hpp @@ -2,6 +2,7 @@ #include "../utils.hpp" #include "infinicore/quantization.hpp" #include "nlohmann/json.hpp" +#include #include namespace infinilm::config { @@ -25,6 +26,7 @@ class QuantConfig { void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { try { + this->kv_cache_dtype_ = std::make_optional(kv_cache_dtype); switch (kv_cache_dtype) { case infinicore::DataType::I8: { this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8; @@ -46,10 +48,19 @@ class QuantConfig { return kv_quant_scheme; } + std::optional get_kv_cache_dtype() const { + if (this->kv_cache_dtype_.has_value()) { + return this->kv_cache_dtype_; + } + return std::nullopt; + } + private: nlohmann::json quantization_config; std::shared_ptr quantization_method; + infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + std::optional kv_cache_dtype_ = std::nullopt; }; } // namespace infinilm::config diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 4ef68faa..80aa1005 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -173,7 +173,7 @@ const distributed::DistConfig &InferEngine::get_dist_config() const { //------------------------------------------------------ // reset_cache (overloaded with CacheConfig) //------------------------------------------------------ -void InferEngine::reset_cache(cache::CacheConfig *new_config) { +void InferEngine::reset_cache(const cache::CacheConfig *new_config) { for (auto &worker : workers_) { worker->reset_cache(new_config); } diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index 00dfc32e..5e8030a2 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -60,7 +60,7 @@ class InferEngine { void compile(); - void reset_cache(cache::CacheConfig *new_config); + void reset_cache(const cache::CacheConfig *new_config); ~InferEngine(); diff --git a/csrc/engine/rank_worker.cpp 
b/csrc/engine/rank_worker.cpp index 8a640fe9..c9cdc97f 100644 --- a/csrc/engine/rank_worker.cpp +++ b/csrc/engine/rank_worker.cpp @@ -186,7 +186,7 @@ void RankWorker::wait() { } } -void RankWorker::reset_cache(cache::CacheConfig *new_config) { +void RankWorker::reset_cache(const cache::CacheConfig *new_config) { std::lock_guard lock(mutex_); if (should_exit_) { throw std::runtime_error("RankWorker is closing; cannot reset_cache"); diff --git a/csrc/engine/rank_worker.hpp b/csrc/engine/rank_worker.hpp index 32cdd5ad..719c1cb1 100644 --- a/csrc/engine/rank_worker.hpp +++ b/csrc/engine/rank_worker.hpp @@ -85,7 +85,7 @@ class RankWorker { void run(const Input &args); // Reset the internal cache with a new configuration - void reset_cache(cache::CacheConfig *new_config); + void reset_cache(const cache::CacheConfig *new_config); // Compile the model graph if enabled. void compile(); diff --git a/csrc/layers/kv_quant.cpp b/csrc/layers/kv_quant.cpp index 15114d18..5994f644 100644 --- a/csrc/layers/kv_quant.cpp +++ b/csrc/layers/kv_quant.cpp @@ -49,4 +49,4 @@ void KVQuantUtils::dequantize( v = std::move(v_dequant); } -} // namespace infinilm \ No newline at end of file +} // namespace infinilm diff --git a/csrc/models/infinilm_model.hpp b/csrc/models/infinilm_model.hpp index c4c5296f..550bf1aa 100644 --- a/csrc/models/infinilm_model.hpp +++ b/csrc/models/infinilm_model.hpp @@ -43,7 +43,7 @@ class InfinilmModel : public infinicore::nn::Module { virtual ~InfinilmModel() = default; virtual Output forward(const Input &input) const = 0; - virtual void reset_cache(cache::CacheConfig *cache_config) = 0; + virtual void reset_cache(const cache::CacheConfig *cache_config) = 0; virtual const cache::CacheConfig *get_cache_config() const = 0; }; } // namespace infinilm diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index 5737b886..b4ce1f4d 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -13,7 
+13,6 @@ #include #include #include -#include #include #include #include diff --git a/csrc/models/llama/llama_for_causal_lm.cpp b/csrc/models/llama/llama_for_causal_lm.cpp index 1fc0708f..9596e668 100644 --- a/csrc/models/llama/llama_for_causal_lm.cpp +++ b/csrc/models/llama/llama_for_causal_lm.cpp @@ -71,7 +71,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const { return {logits}; } -void LlamaForCausalLM::reset_cache(cache::CacheConfig *cache_config) { +void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) { cache_config_ = cache_config->unique_copy(); model_->reset_cache(cache_config_.get()); } diff --git a/csrc/models/llama/llama_for_causal_lm.hpp b/csrc/models/llama/llama_for_causal_lm.hpp index 5727aea4..5cc79dfe 100644 --- a/csrc/models/llama/llama_for_causal_lm.hpp +++ b/csrc/models/llama/llama_for_causal_lm.hpp @@ -58,7 +58,7 @@ class LlamaForCausalLM : public InfinilmModel { */ Output forward(const Input &input) const; - void reset_cache(cache::CacheConfig *cache_config) override; + void reset_cache(const cache::CacheConfig *cache_config) override; const cache::CacheConfig *get_cache_config() const override; diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 233c7b04..e30b803c 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -122,12 +122,12 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids, return hidden_states; } -void LlamaModel::reset_cache(cache::CacheConfig *cache_config) { +void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { if (cache_config == nullptr) { kv_cache_ = nullptr; return; } - if (auto kv_cache_config = dynamic_cast(cache_config); + if (auto kv_cache_config = dynamic_cast(cache_config); kv_cache_config && model_config_ == nullptr) { kv_cache_ = std::make_shared( config_.head_dim, @@ -136,9 +136,10 @@ void LlamaModel::reset_cache(cache::CacheConfig *cache_config) { 
config_.num_key_value_heads, config_.num_hidden_layers, config_.max_position_embeddings, + model_config_->get_kv_cache_dtype(), /* BUG(review): this branch's guard requires model_config_ == nullptr, so this call is a null-pointer dereference — the legacy-Config path must take its dtype from config_, not model_config_ */ *kv_cache_config, rank_info_); - } else if (auto paged_kv_cache_config = dynamic_cast(cache_config); + } else if (auto paged_kv_cache_config = dynamic_cast(cache_config); paged_kv_cache_config && model_config_ == nullptr) { kv_cache_ = std::make_shared( config_.head_dim, @@ -146,10 +147,10 @@ config_.num_key_value_heads, config_.num_key_value_heads, config_.num_hidden_layers, + model_config_->get_kv_cache_dtype(), /* BUG(review): same null-pointer dereference — model_config_ == nullptr is guaranteed in this branch */ *paged_kv_cache_config, rank_info_); - } else if (auto kv_cache_config = dynamic_cast(cache_config)) { - kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); + } else if (auto kv_cache_config = dynamic_cast(cache_config)) { kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), @@ -157,16 +158,17 @@ model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), model_config_->get("max_position_embeddings"), + model_config_->get_kv_cache_dtype(), *kv_cache_config, rank_info_); - } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { - paged_kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype()); + } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { kv_cache_ = std::make_shared( model_config_->get_head_dim(), model_config_->get_head_dim(), model_config_->get("num_key_value_heads"), model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), + model_config_->get_kv_cache_dtype(), *paged_kv_cache_config, rank_info_); } else { diff --git a/csrc/models/llama/llama_model.hpp b/csrc/models/llama/llama_model.hpp index 7d504408..416e1a5c 100644 --- a/csrc/models/llama/llama_model.hpp +++ b/csrc/models/llama/llama_model.hpp @@ -79,7 +79,7 @@ class LlamaModel : public
infinicore::nn::Module { std::optional block_tables, std::optional slot_mapping) const; - void reset_cache(cache::CacheConfig *cache_config); + void reset_cache(const cache::CacheConfig *cache_config); // Module information size_t num_layers() const { return model_config_->get("num_hidden_layers"); } diff --git a/csrc/models/model_factory.cpp b/csrc/models/model_factory.cpp index 0d4caa69..319a7baa 100644 --- a/csrc/models/model_factory.cpp +++ b/csrc/models/model_factory.cpp @@ -17,7 +17,7 @@ namespace infinilm { std::shared_ptr InfinilmModelFactory::createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info, - cache::CacheConfig *cache, + const cache::CacheConfig *cache, backends::AttentionBackend attention_backend) { std::shared_ptr model; if (const auto llama_config_ptr = dynamic_cast(&config)) { @@ -38,7 +38,7 @@ std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr InfinilmModelFactory::createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info, - cache::CacheConfig *cache, + const cache::CacheConfig *cache, backends::AttentionBackend attention_backend) { std::shared_ptr model; diff --git a/csrc/models/model_factory.hpp b/csrc/models/model_factory.hpp index 8af25d84..3c3c2e38 100644 --- a/csrc/models/model_factory.hpp +++ b/csrc/models/model_factory.hpp @@ -24,13 +24,13 @@ class InfinilmModelFactory { static std::shared_ptr createModel( const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - cache::CacheConfig *cache = nullptr, + const cache::CacheConfig *cache = nullptr, backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); static std::shared_ptr createModel( std::shared_ptr model_config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), - cache::CacheConfig *cache = nullptr, + const cache::CacheConfig *cache = nullptr, backends::AttentionBackend attention_backend = 
backends::AttentionBackend::Default); }; } // namespace infinilm diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index a5723a2f..492f6c30 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -1,5 +1,4 @@ #include "../../cache/cache.hpp" -#include "infinicore/dtype.hpp" #include "infinicore/tensor.hpp" #include #include @@ -7,7 +6,6 @@ namespace py = pybind11; namespace infinilm::cache { - inline void bind_cache(py::module &m) { py::class_>(m, "CacheConfig") @@ -19,18 +17,15 @@ inline void bind_cache(py::module &m) { infinilm::cache::CacheConfig, std::shared_ptr>(m, "StaticKVCacheConfig") .def( - py::init>(), + py::init(), py::arg("max_batch_size") = 1, - py::arg("max_cache_len") = std::numeric_limits::max(), - py::arg("kv_cache_dtype") = std::nullopt) + py::arg("max_cache_len") = std::numeric_limits::max()) .def( "max_batch_size", &infinilm::cache::StaticKVCacheConfig::max_batch_size) .def( "max_cache_len", &infinilm::cache::StaticKVCacheConfig::max_cache_len) - .def("kv_cache_dtype", - &infinilm::cache::StaticKVCacheConfig::kv_cache_dtype) .def("__repr__", [](const infinilm::cache::StaticKVCacheConfig &) { return ""; }); @@ -39,18 +34,15 @@ inline void bind_cache(py::module &m) { infinilm::cache::CacheConfig, std::shared_ptr>(m, "PagedKVCacheConfig") .def( - py::init>(), + py::init(), py::arg("num_blocks"), - py::arg("block_size") = 256, - py::arg("kv_cache_dtype") = std::nullopt) + py::arg("block_size") = 256) .def( "num_blocks", &infinilm::cache::PagedKVCacheConfig::num_blocks) .def( "block_size", &infinilm::cache::PagedKVCacheConfig::block_size) - .def("kv_cache_dtype", - &infinilm::cache::PagedKVCacheConfig::kv_cache_dtype) .def("__repr__", [](const infinilm::cache::PagedKVCacheConfig &) { return ""; }); diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index 3e0c884c..e69d0648 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -35,7 +35,7 @@ 
inline void bind_infer_engine(py::module &m) { const InfinilmModel::Config &cfg, const distributed::DistConfig &dist, infinicore::Device::Type dev, - std::shared_ptr cache_cfg, + std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend) { return std::make_shared( @@ -66,10 +66,8 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; }) @@ -80,7 +78,7 @@ inline void bind_infer_engine(py::module &m) { const std::string &model_path, const distributed::DistConfig &dist, infinicore::Device::Type dev, - std::shared_ptr cache_cfg, + const std::shared_ptr cache_cfg, bool enable_graph_compiling, const std::string &attention_backend, std::optional kv_cache_dtype) { @@ -114,10 +112,8 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/examples/bench.py b/examples/bench.py index 29ed14a8..2c44ab19 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -545,8 +545,7 @@ def run( test.model.reset_cache( StaticKVCacheConfig( max_batch_size=batch_size, - max_cache_len=initial_capacity, - kv_cache_dtype=args.kv_cache_dtype, + max_cache_len=initial_capacity ) ) diff --git a/examples/jiuge.py b/examples/jiuge.py index f0e64c9f..0ac55e89 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -263,7 +263,7 @@ def test( batch_size = 1 if prompts is str else len(prompts) initial_capacity = max_new_tokens + len(input_ids_list[0]) cache_config = StaticKVCacheConfig( - max_batch_size=batch_size, max_cache_len=initial_capacity, kv_cache_dtype=args.kv_cache_dtype + max_batch_size=batch_size, max_cache_len=initial_capacity ) model.reset_cache(cache_config) diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index f35dcd72..dc36ab00 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -1,7 +1,4 @@ from infinilm.lib import _infinilm -import infinicore -from ..modeling_utils import parse_dtype - class CacheConfig(_infinilm.CacheConfig): def __init__(self): @@ -15,41 +12,18 @@ def __init__( self, max_batch_size: int = 1, max_cache_len: int = 0, - kv_cache_dtype=None, ): - if isinstance(kv_cache_dtype, str): - _infinilm.StaticKVCacheConfig.__init__( - self, - max_batch_size, - max_cache_len, - 
parse_dtype(kv_cache_dtype)._underlying, - ) - elif isinstance(kv_cache_dtype, infinicore.dtype): - _infinilm.StaticKVCacheConfig.__init__( - self, max_batch_size, max_cache_len, kv_cache_dtype._underlying - ) - else: - _infinilm.StaticKVCacheConfig.__init__( - self, max_batch_size, max_cache_len, kv_cache_dtype - ) + _infinilm.StaticKVCacheConfig.__init__( + self, max_batch_size, max_cache_len + ) class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( self, num_blocks: int, - block_size: int = 256, - kv_cache_dtype=None, + block_size: int = 256 ): - if isinstance(kv_cache_dtype, str): - _infinilm.PagedKVCacheConfig.__init__( - self, num_blocks, block_size, parse_dtype(kv_cache_dtype)._underlying - ) - elif isinstance(kv_cache_dtype, infinicore.dtype): - _infinilm.PagedKVCacheConfig.__init__( - self, num_blocks, block_size, kv_cache_dtype._underlying - ) - else: - _infinilm.PagedKVCacheConfig.__init__( - self, num_blocks, block_size, kv_cache_dtype - ) + _infinilm.PagedKVCacheConfig.__init__( + self, num_blocks, block_size + )