@@ -118,6 +118,16 @@ StaticKVCache::update(size_t layer_idx,
118118// ==========================
119119// PagedKVCacheConfig
120120// ==========================
121+ PagedKVCacheConfig::PagedKVCacheConfig (
122+ size_t num_blocks,
123+ std::string kv_cache_dtype,
124+ size_t block_size)
125+ : num_blocks_(num_blocks),
126+ block_size_(block_size),
127+ kv_cache_dtype_(parse_dtype(kv_cache_dtype)) {
128+ kv_cache_dtype_set_ = true ;
129+ }
130+
121131PagedKVCacheConfig::PagedKVCacheConfig (
122132 size_t num_blocks,
123133 size_t block_size)
@@ -140,6 +150,15 @@ PagedKVCacheConfig::block_size() const {
140150 return block_size_;
141151}
142152
153+ infinicore::DataType
154+ PagedKVCacheConfig::kv_cache_dtype () const {
155+ return kv_cache_dtype_;
156+ }
157+
158+ void PagedKVCacheConfig::set_kv_cache_dtype (infinicore::DataType dtype) const {
159+ kv_cache_dtype_ = dtype;
160+ }
161+
143162// ==========================
144163// PagedKVCache
145164// ==========================
@@ -149,7 +168,6 @@ PagedKVCache::PagedKVCache(
149168 infinicore::Size num_k_heads,
150169 infinicore::Size num_v_heads,
151170 infinicore::Size num_layers,
152- infinicore::DataType dtype,
153171 const PagedKVCacheConfig &config,
154172 const engine::distributed::RankInfo &rank_info)
155173 : Cache(),
@@ -158,7 +176,7 @@ PagedKVCache::PagedKVCache(
158176 num_rank_k_heads_(num_k_heads / rank_info.tp_size),
159177 num_rank_v_heads_(num_v_heads / rank_info.tp_size),
160178 rank_num_layers_(num_layers),
161- dtype_(dtype ),
179+ dtype_(config.kv_cache_dtype() ),
162180 num_blocks_per_layer_(config.num_blocks()),
163181 block_size_(config.block_size()) {
164182 // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
0 commit comments