Skip to content

Commit d3be4cc

Browse files
committed
Issue/253: feat: support custom KV cache dtype for quantization
1 parent 6ae4832 commit d3be4cc

8 files changed

Lines changed: 74 additions & 28 deletions

File tree

csrc/cache/kv_cache.cpp

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -118,6 +118,16 @@ StaticKVCache::update(size_t layer_idx,
118118
// ==========================
119119
// PagedKVCacheConfig
120120
// ==========================
121+
PagedKVCacheConfig::PagedKVCacheConfig(
122+
size_t num_blocks,
123+
std::string kv_cache_dtype,
124+
size_t block_size)
125+
: num_blocks_(num_blocks),
126+
block_size_(block_size),
127+
kv_cache_dtype_(parse_dtype(kv_cache_dtype)) {
128+
kv_cache_dtype_set_ = true;
129+
}
130+
121131
PagedKVCacheConfig::PagedKVCacheConfig(
122132
size_t num_blocks,
123133
size_t block_size)
@@ -140,6 +150,15 @@ PagedKVCacheConfig::block_size() const {
140150
return block_size_;
141151
}
142152

153+
infinicore::DataType
154+
PagedKVCacheConfig::kv_cache_dtype() const {
155+
return kv_cache_dtype_;
156+
}
157+
158+
void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
159+
kv_cache_dtype_ = dtype;
160+
}
161+
143162
// ==========================
144163
// PagedKVCache
145164
// ==========================
@@ -149,7 +168,6 @@ PagedKVCache::PagedKVCache(
149168
infinicore::Size num_k_heads,
150169
infinicore::Size num_v_heads,
151170
infinicore::Size num_layers,
152-
infinicore::DataType dtype,
153171
const PagedKVCacheConfig &config,
154172
const engine::distributed::RankInfo &rank_info)
155173
: Cache(),
@@ -158,7 +176,7 @@ PagedKVCache::PagedKVCache(
158176
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
159177
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
160178
rank_num_layers_(num_layers),
161-
dtype_(dtype),
179+
dtype_(config.kv_cache_dtype()),
162180
num_blocks_per_layer_(config.num_blocks()),
163181
block_size_(config.block_size()) {
164182
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]

csrc/cache/kv_cache.hpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include "base_cache.hpp"
44

5+
#include "../utils.hpp"
56
#include "infinicore/context/context.hpp"
67
#include "infinicore/device.hpp"
78
#include "infinicore/tensor.hpp"
@@ -88,13 +89,24 @@ class PagedKVCacheConfig final : public CacheConfig {
8889
size_t num_blocks,
8990
size_t block_size = 16);
9091

92+
PagedKVCacheConfig(
93+
size_t num_blocks,
94+
std::string kv_cache_dtype,
95+
size_t block_size = 16);
96+
9197
std::unique_ptr<CacheConfig> unique_copy() const override;
9298
size_t num_blocks() const;
9399
size_t block_size() const;
100+
infinicore::DataType kv_cache_dtype() const;
101+
void set_kv_cache_dtype(infinicore::DataType dtype) const;
102+
bool kv_cache_dtype_set() const { return kv_cache_dtype_set_; }
94103

95104
private:
96105
size_t num_blocks_;
97106
size_t block_size_;
107+
108+
bool kv_cache_dtype_set_ = false;
109+
mutable infinicore::DataType kv_cache_dtype_;
98110
};
99111

100112
class PagedKVCache final : public Cache {
@@ -106,7 +118,6 @@ class PagedKVCache final : public Cache {
106118
infinicore::Size num_k_heads,
107119
infinicore::Size num_v_heads,
108120
infinicore::Size num_layers,
109-
infinicore::DataType dtype,
110121
const PagedKVCacheConfig &config,
111122
const engine::distributed::RankInfo &rank_info);
112123

csrc/config/model_config.cpp

Lines changed: 3 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -66,23 +66,8 @@ ModelConfig::get_rope_scaling() const {
6666
}
6767
}
6868

69-
infinicore::DataType
70-
ModelConfig::get_dtype() const {
71-
try {
72-
std::string dtype_str = this->get<std::string>("torch_dtype");
73-
if (dtype_str == "float32") {
74-
return infinicore::DataType::F32;
75-
} else if (dtype_str == "float16") {
76-
return infinicore::DataType::F16;
77-
} else if (dtype_str == "bfloat16") {
78-
return infinicore::DataType::BF16;
79-
} else if (dtype_str == "int8") {
80-
return infinicore::DataType::I8;
81-
} else {
82-
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
83-
}
84-
} catch (const std::exception &e) {
85-
throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
86-
}
69+
infinicore::DataType ModelConfig::get_dtype() const {
70+
std::string dtype_str = this->get<std::string>("torch_dtype");
71+
return parse_dtype(dtype_str);
8772
}
8873
} // namespace infinilm::config

csrc/config/model_config.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include "../utils.hpp"
34
#include "infinicore/nn/rope.hpp"
45
#include "infinicore/ops.hpp"
56
#include "quant_config.hpp"

csrc/models/llama/llama_model.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,6 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
143143
config_.num_key_value_heads,
144144
config_.num_key_value_heads,
145145
config_.num_hidden_layers,
146-
config_.dtype,
147146
*paged_kv_cache_config,
148147
rank_info_);
149148
} else if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
@@ -158,13 +157,15 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
158157
*kv_cache_config,
159158
rank_info_);
160159
} else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {
160+
if (!paged_kv_cache_config->kv_cache_dtype_set()) {
161+
paged_kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype());
162+
}
161163
kv_cache_ = std::make_shared<cache::PagedKVCache>(
162164
model_config_->get_head_dim(),
163165
model_config_->get_head_dim(),
164166
model_config_->get<size_t>("num_key_value_heads"),
165167
model_config_->get<size_t>("num_key_value_heads"),
166168
model_config_->get<size_t>("num_hidden_layers"),
167-
model_config_->get_dtype(),
168169
*paged_kv_cache_config,
169170
rank_info_);
170171
} else {

csrc/pybind11/cache/cache.hpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include "../../cache/cache.hpp"
2+
#include "infinicore/dtype.hpp"
23
#include "infinicore/tensor.hpp"
34
#include <pybind11/pybind11.h>
45
#include <pybind11/stl.h>
@@ -38,12 +39,19 @@ inline void bind_cache(py::module &m) {
3839
py::init<size_t, size_t>(),
3940
py::arg("num_blocks"),
4041
py::arg("block_size") = 16)
42+
.def(
43+
py::init<size_t, std::string, size_t>(),
44+
py::arg("num_blocks"),
45+
py::arg("kv_cache_dtype"),
46+
py::arg("block_size") = 16)
4147
.def(
4248
"num_blocks",
4349
&infinilm::cache::PagedKVCacheConfig::num_blocks)
4450
.def(
4551
"block_size",
4652
&infinilm::cache::PagedKVCacheConfig::block_size)
53+
.def("kv_cache_dtype",
54+
&infinilm::cache::PagedKVCacheConfig::kv_cache_dtype)
4755
.def("__repr__", [](const infinilm::cache::PagedKVCacheConfig &) {
4856
return "<PagedKVCacheConfig>";
4957
});

csrc/utils.hpp

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#pragma once
2+
#include <infinicore/dtype.hpp>
23
#include <infinirt.h>
34

45
#include <cstring>
@@ -123,3 +124,22 @@ inline uint16_t f32_to_bf16(float val) {
123124
inline void hash_combine(size_t &seed, size_t value) {
124125
seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
125126
}
127+
128+
inline infinicore::DataType parse_dtype(const std::string &dtype_str) {
129+
static const std::unordered_map<std::string, infinicore::DataType> dtype_map = {
130+
{"float32", infinicore::DataType::F32},
131+
{"float16", infinicore::DataType::F16},
132+
{"bfloat16", infinicore::DataType::BF16},
133+
{"int8", infinicore::DataType::I8},
134+
// extend as needed
135+
{"int32", infinicore::DataType::I32},
136+
{"int64", infinicore::DataType::I64},
137+
};
138+
139+
auto it = dtype_map.find(dtype_str);
140+
if (it != dtype_map.end()) {
141+
return it->second;
142+
}
143+
144+
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
145+
}

python/infinilm/cache/cache.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,9 +18,11 @@ def __init__(
1818
self,
1919
num_blocks: int,
2020
block_size: int = 16,
21+
kv_cache_dtype: str | None = None,
2122
):
22-
_infinilm.PagedKVCacheConfig.__init__(
23-
self,
24-
num_blocks,
25-
block_size,
26-
)
23+
if kv_cache_dtype is None:
24+
_infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size)
25+
else:
26+
_infinilm.PagedKVCacheConfig.__init__(
27+
self, num_blocks, kv_cache_dtype, block_size
28+
)

0 commit comments

Comments (0)