-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathkv_quant.cpp
More file actions
52 lines (40 loc) · 1.66 KB
/
kv_quant.cpp
File metadata and controls
52 lines (40 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include "kv_quant.hpp"
#include "infinicore/ops/per_tensor_dequant_i8.hpp"
#include "infinicore/ops/per_tensor_quant_i8.hpp"
namespace infinilm {
void KVQuantUtils::quantize(
infinicore::Tensor &k,
infinicore::Tensor &v,
infinicore::quantization::KVQuantAlgo algo,
const infinicore::Tensor &k_scale,
const infinicore::Tensor &v_scale) {
if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
return;
}
auto device = k->device();
auto dtype = k->dtype();
auto zero_point = infinicore::Tensor::zeros({1}, dtype, device);
k = infinicore::op::per_tensor_quant_i8(k, k_scale, zero_point, true);
v = infinicore::op::per_tensor_quant_i8(v, v_scale, zero_point, true);
}
void KVQuantUtils::dequantize(
infinicore::Tensor &k,
infinicore::Tensor &v,
infinicore::quantization::KVQuantAlgo algo,
const infinicore::Tensor &k_scale,
const infinicore::Tensor &v_scale,
const infinicore::Tensor &reference) {
if (algo == infinicore::quantization::KVQuantAlgo::NONE) {
return; // 无需反量化
}
auto zero_point = infinicore::Tensor::zeros({1}, reference->dtype(), reference->device());
auto k_dequant = infinicore::Tensor::strided_empty(
k->shape(), k->strides(), reference->dtype(), reference->device());
auto v_dequant = infinicore::Tensor::strided_empty(
v->shape(), v->strides(), reference->dtype(), reference->device());
infinicore::op::per_tensor_dequant_i8_(k_dequant, k, k_scale, zero_point);
infinicore::op::per_tensor_dequant_i8_(v_dequant, v, v_scale, zero_point);
k = std::move(k_dequant);
v = std::move(v_dequant);
}
} // namespace infinilm