
Commit 4aa8c3e

Issue/253: (1) Refactor attention KV cache quantization into layers/kv_quant.cpp; (2) update kv_cache_dtype handling; (3) update Python test scripts
1 parent a2a2dac commit 4aa8c3e

20 files changed

Lines changed: 164 additions & 80 deletions

csrc/cache/kv_cache.cpp

Lines changed: 21 additions & 15 deletions
@@ -2,6 +2,7 @@
 
 #include "../utils.hpp"
 #include "infinicore/ops.hpp"
+#include <iostream>
 #include <stdexcept>
 
 namespace infinilm::cache {
@@ -22,11 +23,8 @@ StaticKVCacheConfig::StaticKVCacheConfig(
     std::string kv_cache_dtype)
     : max_batch_size_(_max_batch_size),
       max_cache_len_(_max_cache_len) {
-    if (kv_cache_dtype.empty()) {
-        kv_cache_dtype_set_ = false;
-    } else {
-        this->kv_cache_dtype_ = parse_dtype(kv_cache_dtype);
-        kv_cache_dtype_set_ = true;
+    if (!kv_cache_dtype.empty()) {
+        this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
     }
 }
 
@@ -130,11 +128,14 @@ StaticKVCache::update(size_t layer_idx,
 
 infinicore::DataType
 StaticKVCacheConfig::kv_cache_dtype() const {
-    return kv_cache_dtype_;
+    return kv_cache_dtype_.value();
 }
-
-void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
-    kv_cache_dtype_ = dtype;
+void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
+    if (!this->kv_cache_dtype_.has_value()) {
+        this->kv_cache_dtype_ = std::make_optional(dtype);
+    } else {
+        return;
+    }
 }
 
 // ==========================
@@ -145,9 +146,10 @@ PagedKVCacheConfig::PagedKVCacheConfig(
     std::string kv_cache_dtype,
     size_t block_size)
     : num_blocks_(num_blocks),
-      block_size_(block_size),
-      kv_cache_dtype_(parse_dtype(kv_cache_dtype)) {
-    kv_cache_dtype_set_ = true;
+      block_size_(block_size) {
+    if (!kv_cache_dtype.empty()) {
+        this->kv_cache_dtype_ = std::make_optional(parse_dtype(kv_cache_dtype));
+    }
 }
 
 PagedKVCacheConfig::PagedKVCacheConfig(
@@ -174,11 +176,15 @@ PagedKVCacheConfig::block_size() const {
 
 infinicore::DataType
 PagedKVCacheConfig::kv_cache_dtype() const {
-    return kv_cache_dtype_;
+    return kv_cache_dtype_.value();
 }
 
-void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
-    kv_cache_dtype_ = dtype;
+void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) {
+    if (!this->kv_cache_dtype_.has_value()) {
+        this->kv_cache_dtype_ = std::make_optional(dtype);
+    } else {
+        return;
+    }
 }
 
 // ==========================

csrc/cache/kv_cache.hpp

Lines changed: 5 additions & 8 deletions
@@ -11,6 +11,7 @@
 #include <limits>
 #include <memory>
 #include <numeric>
+#include <optional>
 #include <stdexcept>
 #include <utility>
 
@@ -33,15 +34,13 @@ class StaticKVCacheConfig final : public CacheConfig {
     infinicore::Size max_cache_len() const;
 
     infinicore::DataType kv_cache_dtype() const;
-    void set_kv_cache_dtype(infinicore::DataType dtype) const;
-    bool kv_cache_dtype_is_set() const { return kv_cache_dtype_set_; }
+    void set_kv_cache_dtype(infinicore::DataType dtype);
 
 private:
     infinicore::Size max_batch_size_;
     infinicore::Size max_cache_len_;
 
-    bool kv_cache_dtype_set_ = false;
-    mutable infinicore::DataType kv_cache_dtype_;
+    std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
 };
 
 class StaticKVCache final : public Cache {
@@ -109,15 +108,13 @@ class PagedKVCacheConfig final : public CacheConfig {
     size_t num_blocks() const;
     size_t block_size() const;
     infinicore::DataType kv_cache_dtype() const;
-    void set_kv_cache_dtype(infinicore::DataType dtype) const;
-    bool kv_cache_dtype_set() const { return kv_cache_dtype_set_; }
+    void set_kv_cache_dtype(infinicore::DataType dtype);
 
 private:
     size_t num_blocks_;
     size_t block_size_;
 
-    bool kv_cache_dtype_set_ = false;
-    mutable infinicore::DataType kv_cache_dtype_;
+    std::optional<infinicore::DataType> kv_cache_dtype_ = std::nullopt;
 };
 
 class PagedKVCache final : public Cache {
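
Note: the two kv_cache files above replace the bool kv_cache_dtype_set_ flag plus mutable member with a set-once std::optional. The resulting behavior reduces to a small standalone pattern, sketched below with illustrative stand-in names (not the project's API): the first assignment wins, later assignments are ignored, and reading before anything is set makes optional::value() throw std::bad_optional_access.

#include <iostream>
#include <optional>
#include <string>

// Illustrative stand-in for the config classes above: the dtype is held in a
// std::optional and only the first assignment takes effect.
class DtypeHolder {
public:
    void set_dtype(std::string dtype) {
        if (!dtype_.has_value()) { // later calls are silently ignored
            dtype_ = std::move(dtype);
        }
    }

    const std::string &dtype() const {
        return dtype_.value(); // throws std::bad_optional_access if never set
    }

private:
    std::optional<std::string> dtype_;
};

int main() {
    DtypeHolder cfg;
    cfg.set_dtype("int8");
    cfg.set_dtype("float16");         // ignored: dtype is already set
    std::cout << cfg.dtype() << "\n"; // prints "int8"
}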

csrc/engine/infer_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ const distributed::DistConfig &InferEngine::get_dist_config() const {
 //------------------------------------------------------
 // reset_cache (overloaded with CacheConfig)
 //------------------------------------------------------
-void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
+void InferEngine::reset_cache(cache::CacheConfig *new_config) {
     for (auto &worker : workers_) {
         worker->reset_cache(new_config);
     }

csrc/engine/infer_engine.hpp

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ class InferEngine {
 
     void compile();
 
-    void reset_cache(const cache::CacheConfig *new_config);
+    void reset_cache( cache::CacheConfig *new_config);
 
     ~InferEngine();
 
csrc/engine/rank_worker.cpp

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ void RankWorker::wait() {
     }
 }
 
-void RankWorker::reset_cache(const cache::CacheConfig *new_config) {
+void RankWorker::reset_cache(cache::CacheConfig *new_config) {
     std::lock_guard<std::mutex> lock(mutex_);
     if (should_exit_) {
         throw std::runtime_error("RankWorker is closing; cannot reset_cache");

csrc/engine/rank_worker.hpp

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ class RankWorker {
     void run(const Input &args);
 
     // Reset the internal cache with a new configuration
-    void reset_cache(const cache::CacheConfig *new_config);
+    void reset_cache(cache::CacheConfig *new_config);
 
     // Compile the model graph if enabled.
     void compile();

csrc/layers/kv_quant.cpp

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+#include "kv_quant.hpp"
+#include "infinicore/ops/per_tensor_dequant_i8.hpp"
+#include "infinicore/ops/per_tensor_quant_i8.hpp"
+
+namespace infinilm {
+
+void KVQuantUtils::quantize(
+    infinicore::Tensor &k,
+    infinicore::Tensor &v,
+    infinicore::quantization::KVQuantAlgo algo,
+    const infinicore::Tensor &k_scale,
+    const infinicore::Tensor &v_scale) {
+
+    if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
+        return;
+    }
+
+    auto device = k->device();
+    auto dtype = k->dtype();
+    auto zero_point = infinicore::Tensor::zeros({1}, dtype, device);
+
+    k = infinicore::op::per_tensor_quant_i8(k, k_scale, zero_point, true);
+    v = infinicore::op::per_tensor_quant_i8(v, v_scale, zero_point, true);
+}
+
+void KVQuantUtils::dequantize(
+    infinicore::Tensor &k,
+    infinicore::Tensor &v,
+    infinicore::quantization::KVQuantAlgo algo,
+    const infinicore::Tensor &k_scale,
+    const infinicore::Tensor &v_scale,
+    const infinicore::Tensor &reference) {
+
+    if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
+        return; // no dequantization needed
+    }
+
+    auto zero_point = infinicore::Tensor::zeros({1}, reference->dtype(), reference->device());
+
+    auto k_dequant = infinicore::Tensor::strided_empty(
+        k->shape(), k->strides(), reference->dtype(), reference->device());
+    auto v_dequant = infinicore::Tensor::strided_empty(
+        v->shape(), v->strides(), reference->dtype(), reference->device());
+
+    infinicore::op::per_tensor_dequant_i8_(k_dequant, k, k_scale, zero_point);
+    infinicore::op::per_tensor_dequant_i8_(v_dequant, v, v_scale, zero_point);
+
+    k = std::move(k_dequant);
+    v = std::move(v_dequant);
+}
+
+} // namespace infinilm
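
The numerics of per_tensor_quant_i8 / per_tensor_dequant_i8_ live in infinicore and are not part of this commit. The standalone sketch below only illustrates the per-tensor INT8 round trip that such ops typically perform (one scale and one zero point for the whole tensor); it is an assumption for orientation, not the library's documented behavior.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical scalar versions of a per-tensor INT8 quant/dequant pair:
// q = clamp(round(x / scale) + zero_point, -128, 127); x' = (q - zero_point) * scale
int8_t quant_i8(float x, float scale, int zero_point) {
    int q = static_cast<int>(std::lround(x / scale)) + zero_point;
    return static_cast<int8_t>(std::clamp(q, -128, 127));
}

float dequant_i8(int8_t q, float scale, int zero_point) {
    return (static_cast<int>(q) - zero_point) * scale;
}

int main() {
    std::vector<float> k = {0.02f, -0.75f, 1.30f};
    float scale = 0.01f;  // one scale shared by the whole tensor ("per tensor")
    int zero_point = 0;   // the commit passes an all-zero tensor as the zero point

    for (float x : k) {
        int8_t q = quant_i8(x, scale, zero_point);
        // 1.30 saturates at q = 127 and comes back as 1.27
        std::printf("%6.2f -> %4d -> %6.2f\n", x, q, dequant_i8(q, scale, zero_point));
    }
}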

csrc/layers/kv_quant.hpp

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+#pragma once
+
+#include "infinicore/quantization.hpp"
+#include "infinicore/tensor.hpp"
+#include <utility>
+
+namespace infinilm {
+
+class KVQuantUtils {
+public:
+    /**
+     * @brief Quantize K/V (before writing to the cache); modifies k and v in place
+     * @param k original K tensor
+     * @param v original V tensor
+     * @param algo quantization algorithm
+     * @param k_scale scale for K
+     * @param v_scale scale for V
+     */
+    static void quantize(
+        infinicore::Tensor &k,
+        infinicore::Tensor &v,
+        infinicore::quantization::KVQuantAlgo algo,
+        const infinicore::Tensor &k_scale,
+        const infinicore::Tensor &v_scale);
+
+    /**
+     * @brief Dequantize K/V (after reading from the cache); modifies k and v in place
+     * @param k quantized K tensor
+     * @param v quantized V tensor
+     * @param algo quantization algorithm
+     * @param k_scale scale for K
+     * @param v_scale scale for V
+     * @param reference reference tensor (provides the target dtype/device)
+     */
+    static void dequantize(
+        infinicore::Tensor &k,
+        infinicore::Tensor &v,
+        infinicore::quantization::KVQuantAlgo algo,
+        const infinicore::Tensor &k_scale,
+        const infinicore::Tensor &v_scale,
+        const infinicore::Tensor &reference);
+};
+
+} // namespace infinilm
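
Taken together, the intended call pattern is to quantize K/V just before they enter the KV cache and dequantize just after they are read back, which is what the llama_attention.cpp hunks below do. A rough call-site sketch follows; the function and its tensor arguments are placeholders, not the actual LlamaAttention code.

#include "kv_quant.hpp"

// Hypothetical call site: k, v, k_scale, v_scale and reference are assumed to be
// valid infinicore::Tensor handles prepared by the attention layer.
void write_then_read(infinicore::Tensor &k, infinicore::Tensor &v,
                     const infinicore::Tensor &k_scale,
                     const infinicore::Tensor &v_scale,
                     const infinicore::Tensor &reference) {
    using infinicore::quantization::KVQuantAlgo;

    // Before writing to the KV cache: k/v are replaced by their INT8 forms.
    infinilm::KVQuantUtils::quantize(k, v, KVQuantAlgo::INT8, k_scale, v_scale);

    // ... k and v would be stored in, then fetched from, the cache here ...

    // After reading from the cache: k/v are replaced by dequantized copies that
    // take their dtype/device from `reference` (e.g. the query tensor).
    infinilm::KVQuantUtils::dequantize(k, v, KVQuantAlgo::INT8,
                                       k_scale, v_scale, reference);
}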

csrc/models/infinilm_model.hpp

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ class InfinilmModel : public infinicore::nn::Module {
     virtual ~InfinilmModel() = default;
     virtual Output forward(const Input &input) const = 0;
 
-    virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
+    virtual void reset_cache(cache::CacheConfig *cache_config) = 0;
     virtual const cache::CacheConfig *get_cache_config() const = 0;
 };
 } // namespace infinilm

csrc/models/llama/llama_attention.cpp

Lines changed: 12 additions & 24 deletions
@@ -198,16 +198,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
     rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
     rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]
 
-    switch (this->model_config_->get_kv_quant_scheme()) {
-    case (infinicore::quantization::KVQuantAlgo::INT8): {
-        k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
-        v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
-        break;
-    }
-    default: {
-        break;
-    }
-    }
+    infinilm::KVQuantUtils::quantize(
+        k_reshaped, v_reshaped,
+        this->model_config_->get_kv_quant_scheme(),
+        this->kv_cache_k_scale(),
+        this->kv_cache_v_scale());
 
     // 5. Prepare KV caches
     // Convert to [batch, n_head, seq_len, head_dim] for cache
@@ -238,20 +233,13 @@
     } else {
         size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
 
-        switch (this->model_config_->get_kv_quant_scheme()) {
-        case (infinicore::quantization::KVQuantAlgo::INT8): {
-            auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device());
-            auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device());
-            infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
-            infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
-            k_total = k_total_dequant;
-            v_total = v_total_dequant;
-            break;
-        }
-        default: {
-            break;
-        }
-        }
+        infinilm::KVQuantUtils::dequantize(
+            k_total, v_total,
+            this->model_config_->get_kv_quant_scheme(),
+            this->kv_cache_k_scale(),
+            this->kv_cache_v_scale(),
+            q_reshaped);
+
         k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
         v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
 
