Skip to content

Commit 788532e

Browse files
issue/253 refine static kv cache init
1 parent 7796a76 commit 788532e

6 files changed

Lines changed: 69 additions & 60 deletions

File tree

csrc/cache/kv_cache.cpp

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,13 @@ namespace infinilm::cache {
1010
// StaticKVCacheConfig
1111
// ==========================
1212

13-
StaticKVCacheConfig::StaticKVCacheConfig(
14-
infinicore::Size _max_batch_size,
15-
infinicore::Size _max_cache_len)
16-
: max_batch_size_(_max_batch_size),
17-
max_cache_len_(_max_cache_len) {
18-
}
19-
2013
StaticKVCacheConfig::StaticKVCacheConfig(
2114
infinicore::Size _max_batch_size,
2215
infinicore::Size _max_cache_len,
23-
std::string kv_cache_dtype)
16+
std::optional<infinicore::DataType> kv_cache_dtype)
2417
: max_batch_size_(_max_batch_size),
25-
max_cache_len_(_max_cache_len) {
26-
if (!kv_cache_dtype.empty()) {
27-
this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
28-
}
18+
max_cache_len_(_max_cache_len),
19+
kv_cache_dtype_(kv_cache_dtype) {
2920
}
3021

3122
std::unique_ptr<CacheConfig>
@@ -143,20 +134,11 @@ void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
143134
// ==========================
144135
PagedKVCacheConfig::PagedKVCacheConfig(
145136
size_t num_blocks,
146-
std::string kv_cache_dtype,
147-
size_t block_size)
148-
: num_blocks_(num_blocks),
149-
block_size_(block_size) {
150-
if (!kv_cache_dtype.empty()) {
151-
this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
152-
}
153-
}
154-
155-
PagedKVCacheConfig::PagedKVCacheConfig(
156-
size_t num_blocks,
157-
size_t block_size)
137+
size_t block_size,
138+
std::optional<infinicore::DataType> kv_cache_dtype)
158139
: num_blocks_(num_blocks),
159-
block_size_(block_size) {
140+
block_size_(block_size),
141+
kv_cache_dtype_(kv_cache_dtype) {
160142
}
161143

162144
std::unique_ptr<CacheConfig>

csrc/cache/kv_cache.hpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,8 @@ class StaticKVCacheConfig final : public CacheConfig {
2222
public:
2323
StaticKVCacheConfig(
2424
infinicore::Size _max_batch_size = 1,
25-
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
26-
27-
StaticKVCacheConfig(
28-
infinicore::Size _max_batch_size,
29-
infinicore::Size _max_cache_len,
30-
std::string kv_cache_dtype);
25+
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max(),
26+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
3127

3228
std::unique_ptr<CacheConfig> unique_copy() const override;
3329
infinicore::Size max_batch_size() const;
@@ -40,7 +36,7 @@ class StaticKVCacheConfig final : public CacheConfig {
4036
infinicore::Size max_batch_size_;
4137
infinicore::Size max_cache_len_;
4238

43-
std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
39+
std::optional<infinicore::DataType> kv_cache_dtype_;
4440
};
4541

4642
class StaticKVCache final : public Cache {
@@ -97,12 +93,8 @@ class PagedKVCacheConfig final : public CacheConfig {
9793
public:
9894
PagedKVCacheConfig(
9995
size_t num_blocks,
100-
size_t block_size = 256);
101-
102-
PagedKVCacheConfig(
103-
size_t num_blocks,
104-
std::string kv_cache_dtype,
105-
size_t block_size = 16);
96+
size_t block_size = 256,
97+
std::optional<infinicore::DataType> kv_cache_dtype = std::nullopt);
10698

10799
std::unique_ptr<CacheConfig> unique_copy() const override;
108100
size_t num_blocks() const;

csrc/pybind11/cache/cache.hpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@ inline void bind_cache(py::module &m) {
1919
infinilm::cache::CacheConfig,
2020
std::shared_ptr<infinilm::cache::StaticKVCacheConfig>>(m, "StaticKVCacheConfig")
2121
.def(
22-
py::init<infinicore::Size, infinicore::Size>(),
23-
py::arg("max_batch_size") = 1,
24-
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max())
25-
.def(
26-
py::init<infinicore::Size, infinicore::Size, std::string>(),
22+
py::init<infinicore::Size, infinicore::Size, std::optional<infinicore::DataType>>(),
2723
py::arg("max_batch_size") = 1,
2824
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max(),
29-
py::arg("kv_cache_dtype"))
25+
py::arg("kv_cache_dtype") = std::nullopt)
3026
.def(
3127
"max_batch_size",
3228
&infinilm::cache::StaticKVCacheConfig::max_batch_size)
@@ -43,14 +39,10 @@ inline void bind_cache(py::module &m) {
4339
infinilm::cache::CacheConfig,
4440
std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig")
4541
.def(
46-
py::init<size_t, size_t>(),
47-
py::arg("num_blocks"),
48-
py::arg("block_size") = 256)
49-
.def(
50-
py::init<size_t, std::string, size_t>(),
42+
py::init<size_t, size_t, std::optional<infinicore::DataType>>(),
5143
py::arg("num_blocks"),
52-
py::arg("kv_cache_dtype"),
53-
py::arg("block_size") = 16)
44+
py::arg("block_size") = 256,
45+
py::arg("kv_cache_dtype") = std::nullopt)
5446
.def(
5547
"num_blocks",
5648
&infinilm::cache::PagedKVCacheConfig::num_blocks)

python/infinilm/cache/cache.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from infinilm.lib import _infinilm
2+
import infinicore
3+
from ..modeling_utils import parse_dtype
24

35

46
class CacheConfig(_infinilm.CacheConfig):
@@ -9,22 +11,45 @@ def __init__(self):
911

1012

1113
class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
12-
def __init__(self, max_batch_size: int = 1, max_cache_len: int = 0, kv_cache_dtype: str | None = None):
13-
if kv_cache_dtype is None:
14-
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len)
14+
def __init__(
15+
self,
16+
max_batch_size: int = 1,
17+
max_cache_len: int = 0,
18+
kv_cache_dtype=None,
19+
):
20+
if isinstance(kv_cache_dtype, str):
21+
_infinilm.StaticKVCacheConfig.__init__(
22+
self,
23+
max_batch_size,
24+
max_cache_len,
25+
parse_dtype(kv_cache_dtype)._underlying,
26+
)
27+
elif isinstance(kv_cache_dtype, infinicore.dtype.dtype):
28+
_infinilm.StaticKVCacheConfig.__init__(
29+
self, max_batch_size, max_cache_len, kv_cache_dtype._underlying
30+
)
1531
else:
16-
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len, kv_cache_dtype)
32+
_infinilm.StaticKVCacheConfig.__init__(
33+
self, max_batch_size, max_cache_len, kv_cache_dtype
34+
)
35+
1736

1837
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
1938
def __init__(
2039
self,
2140
num_blocks: int,
2241
block_size: int = 256,
23-
kv_cache_dtype: str | None = None,
42+
kv_cache_dtype=None,
2443
):
25-
if kv_cache_dtype is None:
26-
_infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size)
44+
if isinstance(kv_cache_dtype, str):
45+
_infinilm.PagedKVCacheConfig.__init__(
46+
self, num_blocks, block_size, parse_dtype(kv_cache_dtype)._underlying
47+
)
48+
elif isinstance(kv_cache_dtype, infinicore.dtype.dtype):
49+
_infinilm.PagedKVCacheConfig.__init__(
50+
self, num_blocks, block_size, kv_cache_dtype._underlying
51+
)
2752
else:
2853
_infinilm.PagedKVCacheConfig.__init__(
29-
self, num_blocks, kv_cache_dtype, block_size
54+
self, num_blocks, block_size, kv_cache_dtype
3055
)

python/infinilm/infer_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
cache_config=None,
3131
enable_graph_compiling=False,
3232
attention_backend="default",
33-
kv_cache_dtype="",
33+
kv_cache_dtype=None,
3434
):
3535
self.config = AutoConfig.from_pretrained(model_path)
3636

python/infinilm/modeling_utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,24 @@
77
from tqdm import tqdm
88
import infinicore
99

10+
11+
def parse_dtype(dtype_str: str):
12+
if dtype_str == "float32":
13+
return infinicore.float32
14+
elif dtype_str == "float16":
15+
return infinicore.float16
16+
elif dtype_str == "bfloat16":
17+
return infinicore.bfloat16
18+
elif dtype_str == "int8":
19+
return infinicore.int8
20+
elif dtype_str == "int32":
21+
return infinicore.int32
22+
elif dtype_str == "int64":
23+
return infinicore.int64
24+
else:
25+
raise ValueError(f"Unknown dtype string: {dtype_str}")
26+
27+
1028
str_to_torch_dtype = {
1129
"BOOL": torch.bool,
1230
"U8": torch.uint8,

0 commit comments

Comments
 (0)