diff --git a/csrc/cache/kv_cache.hpp b/csrc/cache/kv_cache.hpp index a594008a..164751e8 100644 --- a/csrc/cache/kv_cache.hpp +++ b/csrc/cache/kv_cache.hpp @@ -5,6 +5,7 @@ #include "infinicore/context/context.hpp" #include "infinicore/device.hpp" #include "infinicore/tensor.hpp" +#include #include #include @@ -34,7 +35,6 @@ class StaticKVCacheConfig final : public CacheConfig { class StaticKVCache final : public Cache { public: StaticKVCache( - infinicore::Size k_dim, infinicore::Size v_dim, infinicore::Size num_k_heads, @@ -100,7 +100,6 @@ class PagedKVCacheConfig final : public CacheConfig { class PagedKVCache final : public Cache { public: PagedKVCache( - infinicore::Size k_dim, infinicore::Size v_dim, infinicore::Size num_k_heads, diff --git a/csrc/config/model_config.cpp b/csrc/config/model_config.cpp index 70b41ff0..5b3c6bd6 100644 --- a/csrc/config/model_config.cpp +++ b/csrc/config/model_config.cpp @@ -66,23 +66,8 @@ ModelConfig::get_rope_scaling() const { } } -infinicore::DataType -ModelConfig::get_dtype() const { - try { - std::string dtype_str = this->get("torch_dtype"); - if (dtype_str == "float32") { - return infinicore::DataType::F32; - } else if (dtype_str == "float16") { - return infinicore::DataType::F16; - } else if (dtype_str == "bfloat16") { - return infinicore::DataType::BF16; - } else if (dtype_str == "int8") { - return infinicore::DataType::I8; - } else { - throw std::runtime_error("Unsupported dtype string: " + dtype_str); - } - } catch (const std::exception &e) { - throw std::runtime_error("Error getting dtype from config: " + std::string(e.what())); - } +infinicore::DataType ModelConfig::get_dtype() const { + std::string dtype_str = this->get("torch_dtype"); + return parse_dtype(dtype_str); } } // namespace infinilm::config diff --git a/csrc/config/model_config.hpp b/csrc/config/model_config.hpp index a4600304..97049d6a 100644 --- a/csrc/config/model_config.hpp +++ b/csrc/config/model_config.hpp @@ -1,5 +1,6 @@ #pragma once +#include 
"../utils.hpp" #include "infinicore/nn/rope.hpp" #include "infinicore/ops.hpp" #include "quant_config.hpp" @@ -63,6 +64,19 @@ class ModelConfig { infinicore::DataType get_dtype() const; infinicore::quantization::QuantScheme get_quant_scheme() const; std::shared_ptr get_rope_scaling() const; + void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { + this->quant_config.set_kv_quant_scheme(kv_cache_dtype); + } + infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const { + return quant_config.get_kv_quant_scheme(); + } + infinicore::DataType get_kv_cache_dtype() const { + if (this->quant_config.get_kv_cache_dtype().has_value()) { + return this->quant_config.get_kv_cache_dtype().value(); + } else { + return this->get_dtype(); + } + } private: nlohmann::json config_json; diff --git a/csrc/config/quant_config.hpp b/csrc/config/quant_config.hpp index 480df067..d0b8cb24 100644 --- a/csrc/config/quant_config.hpp +++ b/csrc/config/quant_config.hpp @@ -1,7 +1,9 @@ #pragma once -// #include "../quantization/quantization.hpp" +#include "../utils.hpp" #include "infinicore/quantization.hpp" #include "nlohmann/json.hpp" +#include +#include namespace infinilm::config { @@ -22,9 +24,44 @@ class QuantConfig { } } + // Select the KV-cache quantization scheme for the requested cache dtype. + // Only I8 is supported: any other dtype disables KV quantization and clears the dtype override. + void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) { + try { + switch (kv_cache_dtype) { + case infinicore::DataType::I8: { + this->kv_cache_dtype_ = kv_cache_dtype; + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8; + break; + } + default: { + spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", infinicore::toString(kv_cache_dtype)); + this->kv_cache_dtype_ = std::nullopt; + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + break; + } + } + } catch (const std::exception &e) { + spdlog::error("Failed to parse kv_cache_dtype '{}': {}", infinicore::toString(kv_cache_dtype), e.what()); + this->kv_cache_dtype_ = std::nullopt; + this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + } + } + + infinicore::quantization::KVQuantAlgo 
get_kv_quant_scheme() const { + return kv_quant_scheme; + } + + std::optional get_kv_cache_dtype() const { + return this->kv_cache_dtype_; + } + private: nlohmann::json quantization_config; std::shared_ptr quantization_method; + + infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE; + std::optional kv_cache_dtype_ = std::nullopt; }; } // namespace infinilm::config diff --git a/csrc/engine/infer_engine.cpp b/csrc/engine/infer_engine.cpp index 7dd76eb8..80aa1005 100644 --- a/csrc/engine/infer_engine.cpp +++ b/csrc/engine/infer_engine.cpp @@ -55,7 +55,8 @@ InferEngine::InferEngine( infinicore::Device::Type device_type, const cache::CacheConfig *cache_config, bool enable_graph_compiling, - backends::AttentionBackend attention_backend) // Changed parameter + backends::AttentionBackend attention_backend, + std::optional kv_cache_dtype) // Changed parameter : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) { if (cache_config != nullptr) { cache_config_ = cache_config->unique_copy(); @@ -63,6 +64,10 @@ InferEngine::InferEngine( // Load model config if model_path is provided, model_path must be valid, and config.json exists this->model_config_ = std::make_shared(model_path + "/config.json"); + // Only support offline int8 kv cache quantization in this version + if (kv_cache_dtype.has_value()) { + this->model_config_->set_kv_quant_scheme(kv_cache_dtype.value()); + } // Create one RankWorker per rank int world_size = communication_group_.get_world_size(); barrier_ = std::make_unique((size_t)world_size); diff --git a/csrc/engine/infer_engine.hpp b/csrc/engine/infer_engine.hpp index d191cfa1..5e8030a2 100644 --- a/csrc/engine/infer_engine.hpp +++ b/csrc/engine/infer_engine.hpp @@ -46,7 +46,8 @@ class InferEngine { infinicore::Device::Type device_type = infinicore::context::getDevice().getType(), const 
cache::CacheConfig *cache_config = nullptr, bool enable_graph_compiling = false, - backends::AttentionBackend attention_backend = backends::AttentionBackend::Default); + backends::AttentionBackend attention_backend = backends::AttentionBackend::Default, + std::optional kv_cache_dtype = std::nullopt); // Load a parameter to all workers (each can extract its shard inside RankWorker) void load_param(const std::string &name, const infinicore::Tensor &param); diff --git a/csrc/layers/kv_quant.cpp b/csrc/layers/kv_quant.cpp new file mode 100644 index 00000000..5994f644 --- /dev/null +++ b/csrc/layers/kv_quant.cpp @@ -0,0 +1,52 @@ +#include "kv_quant.hpp" +#include "infinicore/ops/per_tensor_dequant_i8.hpp" +#include "infinicore/ops/per_tensor_quant_i8.hpp" + +namespace infinilm { + +void KVQuantUtils::quantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale) { + + if (algo == infinicore::quantization::KVQuantAlgo::NONE) { + return; + } + + auto device = k->device(); + auto dtype = k->dtype(); + auto zero_point = infinicore::Tensor::zeros({1}, dtype, device); + + k = infinicore::op::per_tensor_quant_i8(k, k_scale, zero_point, true); + v = infinicore::op::per_tensor_quant_i8(v, v_scale, zero_point, true); +} + +void KVQuantUtils::dequantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale, + const infinicore::Tensor &reference) { + + if (algo == infinicore::quantization::KVQuantAlgo::NONE) { + return; // no dequantization needed + } + + auto zero_point = infinicore::Tensor::zeros({1}, reference->dtype(), reference->device()); + + auto k_dequant = infinicore::Tensor::strided_empty( + k->shape(), k->strides(), reference->dtype(), reference->device()); + auto v_dequant = infinicore::Tensor::strided_empty( + v->shape(), v->strides(), reference->dtype(), 
reference->device()); + + infinicore::op::per_tensor_dequant_i8_(k_dequant, k, k_scale, zero_point); + infinicore::op::per_tensor_dequant_i8_(v_dequant, v, v_scale, zero_point); + + k = std::move(k_dequant); + v = std::move(v_dequant); +} + +} // namespace infinilm diff --git a/csrc/layers/kv_quant.hpp b/csrc/layers/kv_quant.hpp new file mode 100644 index 00000000..146c8f65 --- /dev/null +++ b/csrc/layers/kv_quant.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include "infinicore/quantization.hpp" +#include "infinicore/tensor.hpp" +#include + +namespace infinilm { + +class KVQuantUtils { +public: + /** + * @brief Quantize K/V (before writing to the cache) - k and v are rebound to the quantized tensors; no-op when algo is NONE + * @param k Original K tensor + * @param v Original V tensor + * @param algo Quantization algorithm + * @param k_scale Scale for K + * @param v_scale Scale for V + */ + static void quantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale); + + /** + * @brief Dequantize K/V (after reading from the cache) - k and v are rebound to the dequantized tensors; no-op when algo is NONE + * @param k Quantized K tensor + * @param v Quantized V tensor + * @param algo Quantization algorithm + * @param k_scale Scale for K + * @param v_scale Scale for V + * @param reference Reference tensor (used to obtain the output dtype/device) + */ + static void dequantize( + infinicore::Tensor &k, + infinicore::Tensor &v, + infinicore::quantization::KVQuantAlgo algo, + const infinicore::Tensor &k_scale, + const infinicore::Tensor &v_scale, + const infinicore::Tensor &reference); +}; + +} // namespace infinilm diff --git a/csrc/models/llama/llama_attention.cpp b/csrc/models/llama/llama_attention.cpp index 57cad6ec..b4ce1f4d 100644 --- a/csrc/models/llama/llama_attention.cpp +++ b/csrc/models/llama/llama_attention.cpp @@ -7,6 +7,8 @@ #include "infinicore/ops/mha_kvcache.hpp" #include "infinicore/ops/mha_varlen.hpp" #include "infinicore/ops/mul.hpp" +#include "infinicore/ops/per_tensor_dequant_i8.hpp" +#include "infinicore/ops/per_tensor_quant_i8.hpp" #include #include @@ -137,6 +139,17 @@ LlamaAttention::LlamaAttention(std::shared_ptr mo 
INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get("rms_norm_eps"), dtype, device); INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get("rms_norm_eps"), dtype, device); } + + switch (this->model_config_->get_kv_quant_scheme()) { + case (infinicore::quantization::KVQuantAlgo::INT8): { + INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1)); + break; + } + default: { + break; + } + } } infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states, @@ -184,6 +197,12 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim] rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim] + infinilm::KVQuantUtils::quantize( + k_reshaped, v_reshaped, + this->model_config_->get_kv_quant_scheme(), + this->kv_cache_k_scale_, + this->kv_cache_v_scale_); + // 5. 
Prepare KV caches // Convert to [batch, n_head, seq_len, head_dim] for cache // Ensure contiguous after permute for F16 compatibility with cache operations @@ -212,6 +231,14 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] } else { size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; + + infinilm::KVQuantUtils::dequantize( + k_total, v_total, + this->model_config_->get_kv_quant_scheme(), + this->kv_cache_k_scale_, + this->kv_cache_v_scale_, + q_reshaped); + k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] @@ -342,10 +369,10 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_}); auto attn_out_4d = infinicore::op::mha_kvcache( q_for_fa, - k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] + k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim] v_total->permute({0, 2, 1, 3}), - total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) - block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 + total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence) + block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32 std::nullopt, scaling_); attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_}); @@ -361,7 +388,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd scaling_); } } - // 7. 
Project output attn_output diff --git a/csrc/models/llama/llama_attention.hpp b/csrc/models/llama/llama_attention.hpp index 4fe369d6..39493e76 100644 --- a/csrc/models/llama/llama_attention.hpp +++ b/csrc/models/llama/llama_attention.hpp @@ -5,6 +5,7 @@ #include "../../config/model_config.hpp" #include "../../engine/distributed/distributed.hpp" #include "../../layers/fused_linear.hpp" +#include "../../layers/kv_quant.hpp" #include "llama_config.hpp" #include "infinicore/nn/linear.hpp" @@ -123,6 +124,10 @@ class LlamaAttention : public infinicore::nn::Module { // Shared Rotary Position Embeddings (RoPE) std::shared_ptr rotary_emb_; + // For off-line kv cache quantization + INFINICORE_NN_PARAMETER(kv_cache_k_scale); + INFINICORE_NN_PARAMETER(kv_cache_v_scale); + private: std::shared_ptr model_config_ = std::make_shared(); size_t layer_idx_; // Layer index for cache access diff --git a/csrc/models/llama/llama_model.cpp b/csrc/models/llama/llama_model.cpp index 81e8fd04..e30b803c 100644 --- a/csrc/models/llama/llama_model.cpp +++ b/csrc/models/llama/llama_model.cpp @@ -136,7 +136,7 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.num_key_value_heads, config_.num_hidden_layers, config_.max_position_embeddings, - config_.dtype, + model_config_->get_kv_cache_dtype(), *kv_cache_config, rank_info_); } else if (auto paged_kv_cache_config = dynamic_cast(cache_config); @@ -147,7 +147,7 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { config_.num_key_value_heads, config_.num_key_value_heads, config_.num_hidden_layers, - config_.dtype, + model_config_->get_kv_cache_dtype(), *paged_kv_cache_config, rank_info_); } else if (auto kv_cache_config = dynamic_cast(cache_config)) { @@ -158,7 +158,7 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), model_config_->get("max_position_embeddings"), - model_config_->get_dtype(), + 
model_config_->get_kv_cache_dtype(), *kv_cache_config, rank_info_); } else if (auto paged_kv_cache_config = dynamic_cast(cache_config)) { @@ -168,7 +168,7 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) { model_config_->get("num_key_value_heads"), model_config_->get("num_key_value_heads"), model_config_->get("num_hidden_layers"), - model_config_->get_dtype(), + model_config_->get_kv_cache_dtype(), *paged_kv_cache_config, rank_info_); } else { diff --git a/csrc/pybind11/cache/cache.hpp b/csrc/pybind11/cache/cache.hpp index 46f97ea4..492f6c30 100644 --- a/csrc/pybind11/cache/cache.hpp +++ b/csrc/pybind11/cache/cache.hpp @@ -6,7 +6,6 @@ namespace py = pybind11; namespace infinilm::cache { - inline void bind_cache(py::module &m) { py::class_>(m, "CacheConfig") diff --git a/csrc/pybind11/engine/engine.hpp b/csrc/pybind11/engine/engine.hpp index d84d223f..2bf7658d 100644 --- a/csrc/pybind11/engine/engine.hpp +++ b/csrc/pybind11/engine/engine.hpp @@ -66,18 +66,11 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { - py::gil_scoped_release release; - return self.forward(input); - }, - "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { py::gil_scoped_release release; /* drop the GIL while inference runs, as the replaced binding did */ return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr { auto cfg = self.get_cache_config(); - return cfg ? std::shared_ptr(cfg->unique_copy()) : nullptr; - }) + return cfg ? 
std::shared_ptr(cfg->unique_copy()) : nullptr; }) .def("__repr__", [](const InferEngine &self) { return ""; }); infer_engine @@ -87,21 +80,24 @@ inline void bind_infer_engine(py::module &m) { infinicore::Device::Type dev, std::shared_ptr cache_cfg, bool enable_graph_compiling, - const std::string &attention_backend) { + const std::string &attention_backend, + std::optional kv_cache_dtype) { return std::make_shared( model_path, dist, dev, cache_cfg ? cache_cfg.get() : nullptr, enable_graph_compiling, - infinilm::backends::parse_attention_backend(attention_backend)); + infinilm::backends::parse_attention_backend(attention_backend), + kv_cache_dtype); }), py::arg("model_path") = "", py::arg("distributed_config") = distributed::DistConfig(), py::arg("device_type") = infinicore::context::getDevice().getType(), py::arg("cache_config") = py::none(), py::arg("enable_graph_compiling") = false, - py::arg("attention_backend") = "default") + py::arg("attention_backend") = "default", + py::arg("kv_cache_dtype") = py::none()) .def("load_param", &InferEngine::load_param, py::arg("name"), py::arg("param"), "Load a parameter tensor into all workers (each worker picks its shard)") @@ -116,14 +112,8 @@ inline void bind_infer_engine(py::module &m) { } return state_dict_tp_all; }) - .def( - "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { - py::gil_scoped_release release; - return self.forward(input); - }, - "Run inference on all ranks with arbitrary arguments") - .def( - "reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) + .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { py::gil_scoped_release release; /* drop the GIL while inference runs, as the replaced binding did */ return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") + .def("reset_cache", [](InferEngine &self, std::shared_ptr cfg) { self.reset_cache(cfg ? 
cfg.get() : nullptr); }, py::arg("cache_config") = py::none()) .def("get_cache_config", [](const InferEngine &self) { auto cfg = self.get_cache_config(); return std::shared_ptr(std::move(cfg->unique_copy())); }) diff --git a/csrc/utils.hpp b/csrc/utils.hpp index 8af9759f..805e1254 100644 --- a/csrc/utils.hpp +++ b/csrc/utils.hpp @@ -1,4 +1,5 @@ #pragma once +#include #include #include @@ -123,3 +124,22 @@ inline uint16_t f32_to_bf16(float val) { inline void hash_combine(size_t &seed, size_t value) { seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); } + +inline infinicore::DataType parse_dtype(const std::string &dtype_str) { + static const std::unordered_map dtype_map = { + {"float32", infinicore::DataType::F32}, + {"float16", infinicore::DataType::F16}, + {"bfloat16", infinicore::DataType::BF16}, + {"int8", infinicore::DataType::I8}, + // Extend with further dtypes as needed + {"int32", infinicore::DataType::I32}, + {"int64", infinicore::DataType::I64}, + }; + + auto it = dtype_map.find(dtype_str); + if (it != dtype_map.end()) { + return it->second; + } + + throw std::runtime_error("Unsupported dtype string: " + dtype_str); +} diff --git a/examples/bench.py b/examples/bench.py index 9ac2b11e..2c44ab19 100644 --- a/examples/bench.py +++ b/examples/bench.py @@ -237,7 +237,7 @@ def get_args(): help="use paged cache", ) parser.add_argument( - "--paged_kv_block_size", + "--paged-kv-block-size", type=int, default=256, help="num tokens each kv block can hold", @@ -259,6 +259,13 @@ def get_args(): choices=["default", "flash-attn"], help="attention backend to use: 'default' or 'flash-attn'", ) + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=None, + choices=["int8"], + ) + return parser.parse_args() @@ -298,6 +305,7 @@ def __init__( cache_config=cache_config, enable_graph_compiling=enable_graph, attention_backend=attn_backend, + kv_cache_dtype=args.kv_cache_dtype, ) # ---------------------------------------------------------------------------- # @@ -536,7 +544,8 @@ def run( 
initial_capacity = input_len + output_len test.model.reset_cache( StaticKVCacheConfig( - max_batch_size=batch_size, max_cache_len=initial_capacity + max_batch_size=batch_size, + max_cache_len=initial_capacity ) ) diff --git a/examples/jiuge.py b/examples/jiuge.py index 2cb99d66..0ac55e89 100644 --- a/examples/jiuge.py +++ b/examples/jiuge.py @@ -67,13 +67,13 @@ def get_args(): help="Run hygon test", ) parser.add_argument( - "--model_path", + "--model-path", type=str, required=True, help="model_path", ) parser.add_argument( - "--max_new_tokens", + "--max-new-tokens", type=int, default=100, help="max_new_tokens", @@ -109,7 +109,7 @@ def get_args(): ) parser.add_argument( - "--paged_kv_block_size", + "--paged-kv-block-size", type=int, default=256, help="num tokens each kv block can hold", @@ -149,6 +149,13 @@ def get_args(): choices=["default", "flash-attn"], help="attention backend to use: 'default' or 'flash-attn'", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default=None, + choices=["int8"], + ) return parser.parse_args() @@ -176,6 +183,7 @@ def test( distributed_config=DistConfig(tp), enable_graph_compiling=enable_graph, attention_backend=attn_backend, + kv_cache_dtype=args.kv_cache_dtype, ) # ---------------------------------------------------------------------------- # # Load Weights diff --git a/python/infinilm/cache/cache.py b/python/infinilm/cache/cache.py index 354309e1..dc36ab00 100644 --- a/python/infinilm/cache/cache.py +++ b/python/infinilm/cache/cache.py @@ -1,6 +1,5 @@ from infinilm.lib import _infinilm - class CacheConfig(_infinilm.CacheConfig): def __init__(self): raise NotImplementedError( @@ -9,18 +8,22 @@ def __init__(self): class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig): - def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0): - _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len) + def __init__( + self, + max_batch_size: int = 1, + max_cache_len: int = 0, + ): + 
_infinilm.StaticKVCacheConfig.__init__( + self, max_batch_size, max_cache_len + ) class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): def __init__( self, num_blocks: int, - block_size: int = 256, + block_size: int = 256 ): _infinilm.PagedKVCacheConfig.__init__( - self, - num_blocks, - block_size, + self, num_blocks, block_size ) diff --git a/python/infinilm/infer_engine.py b/python/infinilm/infer_engine.py index 8b614980..3ca1ea06 100644 --- a/python/infinilm/infer_engine.py +++ b/python/infinilm/infer_engine.py @@ -8,6 +8,8 @@ from infinilm.distributed import DistConfig from infinilm.lib import _infinilm +from .modeling_utils import parse_dtype + @dataclass class GenerationConfig: @@ -30,6 +32,7 @@ def __init__( cache_config=None, enable_graph_compiling=False, attention_backend="default", + kv_cache_dtype=None, ): self.config = AutoConfig.from_pretrained(model_path) @@ -43,6 +46,9 @@ def __init__( cache_config, enable_graph_compiling, attention_backend, + parse_dtype(kv_cache_dtype)._underlying + if kv_cache_dtype + else None, ) self.use_cache = False diff --git a/python/infinilm/modeling_utils.py b/python/infinilm/modeling_utils.py index d1b26dd9..1d21f2d9 100644 --- a/python/infinilm/modeling_utils.py +++ b/python/infinilm/modeling_utils.py @@ -7,6 +7,24 @@ from tqdm import tqdm import infinicore + +def parse_dtype(dtype_str: str): + if dtype_str == "float32": + return infinicore.float32 + elif dtype_str == "float16": + return infinicore.float16 + elif dtype_str == "bfloat16": + return infinicore.bfloat16 + elif dtype_str == "int8": + return infinicore.int8 + elif dtype_str == "int32": + return infinicore.int32 + elif dtype_str == "int64": + return infinicore.int64 + else: + raise ValueError(f"Unknown dtype string: {dtype_str}") + + str_to_torch_dtype = { "BOOL": torch.bool, "U8": torch.uint8,