Skip to content

Commit 51d0a81

Browse files
issue/253 refine static kv cache init
1 parent 7796a76 commit 51d0a81

12 files changed

Lines changed: 100 additions & 85 deletions

File tree

csrc/cache/kv_cache.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,13 @@ namespace infinilm::cache {
1010
// StaticKVCacheConfig
1111
// ==========================
1212

13-
StaticKVCacheConfig::StaticKVCacheConfig(
14-
infinicore::Size _max_batch_size,
15-
infinicore::Size _max_cache_len)
16-
: max_batch_size_(_max_batch_size),
17-
max_cache_len_(_max_cache_len) {
18-
}
19-
2013
StaticKVCacheConfig::StaticKVCacheConfig(
2114
infinicore::Size _max_batch_size,
2215
infinicore::Size _max_cache_len,
23-
std::string kv_cache_dtype)
16+
std::optional<infinicore::DataType> kv_cache_dtype)
2417
: max_batch_size_(_max_batch_size),
25-
max_cache_len_(_max_cache_len) {
26-
if (!kv_cache_dtype.empty()) {
27-
this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
28-
}
18+
max_cache_len_(_max_cache_len),
19+
kv_cache_dtype_(kv_cache_dtype) {
2920
}
3021

3122
std::unique_ptr<CacheConfig>
@@ -143,20 +134,11 @@ void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
143134
// ==========================
144135
PagedKVCacheConfig::PagedKVCacheConfig(
145136
size_t num_blocks,
146-
std::string kv_cache_dtype,
147-
size_t block_size)
148-
: num_blocks_(num_blocks),
149-
block_size_(block_size) {
150-
if (!kv_cache_dtype.empty()) {
151-
this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
152-
}
153-
}
154-
155-
PagedKVCacheConfig::PagedKVCacheConfig(
156-
size_t num_blocks,
157-
size_t block_size)
137+
size_t block_size,
138+
std::optional<infinicore::DataType> kv_cache_dtype)
158139
: num_blocks_(num_blocks),
159-
block_size_(block_size) {
140+
block_size_(block_size),
141+
kv_cache_dtype_(kv_cache_dtype) {
160142
}
161143

162144
std::unique_ptr<CacheConfig>

csrc/cache/kv_cache.hpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@ class StaticKVCacheConfig final : public CacheConfig {
2222
public:
2323
StaticKVCacheConfig(
2424
infinicore::Size _max_batch_size = 1,
25-
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
26-
27-
StaticKVCacheConfig(
28-
infinicore::Size _max_batch_size,
29-
infinicore::Size _max_cache_len,
30-
std::string kv_cache_dtype);
25+
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max(),
26+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
3127

3228
std::unique_ptr<CacheConfig> unique_copy() const override;
3329
infinicore::Size max_batch_size() const;
@@ -40,7 +36,7 @@ class StaticKVCacheConfig final : public CacheConfig {
4036
infinicore::Size max_batch_size_;
4137
infinicore::Size max_cache_len_;
4238

43-
std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
39+
std::optional<infinicore::DataType> kv_cache_dtype_;
4440
};
4541

4642
class StaticKVCache final : public Cache {
@@ -97,12 +93,8 @@ class PagedKVCacheConfig final : public CacheConfig {
9793
public:
9894
PagedKVCacheConfig(
9995
size_t num_blocks,
100-
size_t block_size = 256);
101-
102-
PagedKVCacheConfig(
103-
size_t num_blocks,
104-
std::string kv_cache_dtype,
105-
size_t block_size = 16);
96+
size_t block_size = 256,
97+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
10698

10799
std::unique_ptr<CacheConfig> unique_copy() const override;
108100
size_t num_blocks() const;

csrc/config/model_config.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class ModelConfig {
6464
infinicore::DataType get_dtype() const;
6565
infinicore::quantization::QuantScheme get_quant_scheme() const;
6666
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
67-
void set_kv_quant_scheme(std::string kv_cache_dtype) {
67+
void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) {
6868
this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
6969
}
7070
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {

csrc/config/quant_config.hpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,21 @@ class QuantConfig {
2323
}
2424
}
2525

26-
void set_kv_quant_scheme(std::string kv_cache_dtype) {
27-
if (kv_cache_dtype.empty()) {
28-
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
29-
// spdlog::debug("kv_cache_dtype is empty, using default NONE");
30-
return;
31-
}
32-
26+
void set_kv_quant_scheme(infinicore::DataType kv_cache_dtype) {
3327
try {
34-
switch (parse_dtype(kv_cache_dtype)) {
28+
switch (kv_cache_dtype) {
3529
case infinicore::DataType::I8: {
3630
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8;
3731
break;
3832
}
3933
default: {
40-
spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", kv_cache_dtype);
34+
spdlog::warn("Unsupported kv_cache_dtype: '{}', fallback to NONE", infinicore::toString(kv_cache_dtype));
4135
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
4236
break;
4337
}
4438
}
4539
} catch (const std::exception &e) {
46-
spdlog::error("Failed to parse kv_cache_dtype '{}': {}", kv_cache_dtype, e.what());
40+
spdlog::error("Failed to parse kv_cache_dtype '{}': {}", infinicore::toString(kv_cache_dtype), e.what());
4741
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
4842
}
4943
}

csrc/engine/infer_engine.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ InferEngine::InferEngine(
5656
const cache::CacheConfig *cache_config,
5757
bool enable_graph_compiling,
5858
backends::AttentionBackend attention_backend,
59-
const std::string &kv_cache_dtype) // Changed parameter
59+
std::optional<infinicore::DataType> kv_cache_dtype) // Changed parameter
6060
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
6161
if (cache_config != nullptr) {
6262
cache_config_ = cache_config->unique_copy();
@@ -65,7 +65,9 @@ InferEngine::InferEngine(
6565
// Load the model config if model_path is provided; model_path must be valid and config.json must exist
6666
this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
6767
// Only offline int8 KV cache quantization is supported in this version
68-
this->model_config_->set_kv_quant_scheme(kv_cache_dtype);
68+
if (kv_cache_dtype.has_value()) {
69+
this->model_config_->set_kv_quant_scheme(kv_cache_dtype.value());
70+
}
6971
// Create one RankWorker per rank
7072
int world_size = communication_group_.get_world_size();
7173
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);

csrc/engine/infer_engine.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class InferEngine {
4747
const cache::CacheConfig *cache_config = nullptr,
4848
bool enable_graph_compiling = false,
4949
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
50-
const std::string &kv_cache_dtype = "");
50+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
5151

5252
// Load a parameter to all workers (each can extract its shard inside RankWorker)
5353
void load_param(const std::string &name, const infinicore::Tensor &param);

csrc/pybind11/cache/cache.hpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@ inline void bind_cache(py::module &m) {
1919
infinilm::cache::CacheConfig,
2020
std::shared_ptr<infinilm::cache::StaticKVCacheConfig>>(m, "StaticKVCacheConfig")
2121
.def(
22-
py::init<infinicore::Size, infinicore::Size>(),
23-
py::arg("max_batch_size") = 1,
24-
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max())
25-
.def(
26-
py::init<infinicore::Size, infinicore::Size, std::string>(),
22+
py::init<infinicore::Size, infinicore::Size, std::optional<infinicore::DataType>>(),
2723
py::arg("max_batch_size") = 1,
2824
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max(),
29-
py::arg("kv_cache_dtype"))
25+
py::arg("kv_cache_dtype") = std::nullopt)
3026
.def(
3127
"max_batch_size",
3228
&infinilm::cache::StaticKVCacheConfig::max_batch_size)
@@ -43,14 +39,10 @@ inline void bind_cache(py::module &m) {
4339
infinilm::cache::CacheConfig,
4440
std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig")
4541
.def(
46-
py::init<size_t, size_t>(),
47-
py::arg("num_blocks"),
48-
py::arg("block_size") = 256)
49-
.def(
50-
py::init<size_t, std::string, size_t>(),
42+
py::init<size_t, size_t, std::optional<infinicore::DataType>>(),
5143
py::arg("num_blocks"),
52-
py::arg("kv_cache_dtype"),
53-
py::arg("block_size") = 16)
44+
py::arg("block_size") = 256,
45+
py::arg("kv_cache_dtype") = std::nullopt)
5446
.def(
5547
"num_blocks",
5648
&infinilm::cache::PagedKVCacheConfig::num_blocks)

csrc/pybind11/engine/engine.hpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ inline void bind_infer_engine(py::module &m) {
6666
}
6767
return state_dict_tp_all;
6868
})
69-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
70-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
69+
.def(
70+
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
71+
.def(
72+
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
7173
.def("get_cache_config", [](const InferEngine &self) -> std::shared_ptr<cache::CacheConfig> {
7274
auto cfg = self.get_cache_config();
7375
return cfg ? std::shared_ptr<cache::CacheConfig>(cfg->unique_copy()) : nullptr; })
@@ -81,7 +83,7 @@ inline void bind_infer_engine(py::module &m) {
8183
std::shared_ptr<infinilm::cache::CacheConfig> cache_cfg,
8284
bool enable_graph_compiling,
8385
const std::string &attention_backend,
84-
const std::string &kv_cache_dtype) {
86+
std::optional<infinicore::DataType> kv_cache_dtype) {
8587
return std::make_shared<InferEngine>(
8688
model_path,
8789
dist,
@@ -97,7 +99,7 @@ inline void bind_infer_engine(py::module &m) {
9799
py::arg("cache_config") = py::none(),
98100
py::arg("enable_graph_compiling") = false,
99101
py::arg("attention_backend") = "default",
100-
py::arg("kv_cache_dtype") = "")
102+
py::arg("kv_cache_dtype") = py::none())
101103
.def("load_param", &InferEngine::load_param,
102104
py::arg("name"), py::arg("param"),
103105
"Load a parameter tensor into all workers (each worker picks its shard)")
@@ -112,8 +114,10 @@ inline void bind_infer_engine(py::module &m) {
112114
}
113115
return state_dict_tp_all;
114116
})
115-
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
116-
.def("reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
117+
.def(
118+
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
119+
.def(
120+
"reset_cache", [](InferEngine &self, std::shared_ptr<cache::CacheConfig> cfg) { self.reset_cache(cfg ? cfg.get() : nullptr); }, py::arg("cache_config") = py::none())
117121
.def("get_cache_config", [](const InferEngine &self) {
118122
auto cfg = self.get_cache_config();
119123
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy())); })

examples/bench.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ def get_args():
262262
parser.add_argument(
263263
"--kv-cache-dtype",
264264
type=str,
265-
default="",
266-
choices=["", "int8"],
265+
default=None,
266+
choices=["int8"],
267267
)
268268

269269
return parser.parse_args()
@@ -305,7 +305,7 @@ def __init__(
305305
cache_config=cache_config,
306306
enable_graph_compiling=enable_graph,
307307
attention_backend=attn_backend,
308-
kv_cache_dtype=args.kv_cache_dtype
308+
kv_cache_dtype=args.kv_cache_dtype,
309309
)
310310

311311
# ---------------------------------------------------------------------------- #
@@ -544,7 +544,9 @@ def run(
544544
initial_capacity = input_len + output_len
545545
test.model.reset_cache(
546546
StaticKVCacheConfig(
547-
max_batch_size=batch_size, max_cache_len=initial_capacity, kv_cache_dtype=args.kv_cache_dtype
547+
max_batch_size=batch_size,
548+
max_cache_len=initial_capacity,
549+
kv_cache_dtype=args.kv_cache_dtype,
548550
)
549551
)
550552

python/infinilm/cache/cache.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from infinilm.lib import _infinilm
2+
import infinicore
3+
from ..modeling_utils import parse_dtype
24

35

46
class CacheConfig(_infinilm.CacheConfig):
@@ -9,22 +11,45 @@ def __init__(self):
911

1012

1113
class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
12-
def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0, kv_cache_dtype: str | None = None):
13-
if kv_cache_dtype is None:
14-
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len)
14+
def __init__(
15+
self,
16+
max_batch_size: int = 1,
17+
max_cache_len: int = 0,
18+
kv_cache_dtype=None,
19+
):
20+
if isinstance(kv_cache_dtype, str):
21+
_infinilm.StaticKVCacheConfig.__init__(
22+
self,
23+
max_batch_size,
24+
max_cache_len,
25+
parse_dtype(kv_cache_dtype)._underlying,
26+
)
27+
elif isinstance(kv_cache_dtype, infinicore.dtype):
28+
_infinilm.StaticKVCacheConfig.__init__(
29+
self, max_batch_size, max_cache_len, kv_cache_dtype._underlying
30+
)
1531
else:
16-
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len, kv_cache_dtype)
32+
_infinilm.StaticKVCacheConfig.__init__(
33+
self, max_batch_size, max_cache_len, kv_cache_dtype
34+
)
35+
1736

1837
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
1938
def __init__(
2039
self,
2140
num_blocks: int,
2241
block_size: int = 256,
23-
kv_cache_dtype: str | None = None,
42+
kv_cache_dtype=None,
2443
):
25-
if kv_cache_dtype is None:
26-
_infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size)
44+
if isinstance(kv_cache_dtype, str):
45+
_infinilm.PagedKVCacheConfig.__init__(
46+
self, num_blocks, block_size, parse_dtype(kv_cache_dtype)._underlying
47+
)
48+
elif isinstance(kv_cache_dtype, infinicore.dtype):
49+
_infinilm.PagedKVCacheConfig.__init__(
50+
self, num_blocks, block_size, kv_cache_dtype._underlying
51+
)
2752
else:
2853
_infinilm.PagedKVCacheConfig.__init__(
29-
self, num_blocks, kv_cache_dtype, block_size
54+
self, num_blocks, block_size, kv_cache_dtype
3055
)

0 commit comments

Comments
 (0)