Commit a2a2dac

Issue/253: Support offline int8 inference with calibrated models
1 parent: 1fc301f

15 files changed

Lines changed: 169 additions & 30 deletions

csrc/cache/kv_cache.cpp

Lines changed: 24 additions & 2 deletions
@@ -16,6 +16,20 @@ StaticKVCacheConfig::StaticKVCacheConfig(
       max_cache_len_(_max_cache_len) {
 }
 
+StaticKVCacheConfig::StaticKVCacheConfig(
+    infinicore::Size _max_batch_size,
+    infinicore::Size _max_cache_len,
+    std::string kv_cache_dtype)
+    : max_batch_size_(_max_batch_size),
+      max_cache_len_(_max_cache_len) {
+    if (kv_cache_dtype.empty()) {
+        kv_cache_dtype_set_ = false;
+    } else {
+        this->kv_cache_dtype_ = parse_dtype(kv_cache_dtype);
+        kv_cache_dtype_set_ = true;
+    }
+}
+
 std::unique_ptr<CacheConfig>
 StaticKVCacheConfig::unique_copy() const {
     return std::make_unique<StaticKVCacheConfig>(*this);
@@ -42,7 +56,6 @@ StaticKVCache::StaticKVCache(
     infinicore::Size num_v_heads,
     infinicore::Size num_layers,
     infinicore::Size max_positional_embedding,
-    infinicore::DataType dtype,
     const StaticKVCacheConfig &config,
     const engine::distributed::RankInfo &rank_info)
     : Cache(),
@@ -53,7 +66,7 @@ StaticKVCache::StaticKVCache(
       rank_batch_size_(config.max_batch_size()),
       cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
       rank_num_layers_(num_layers),
-      dtype_(dtype) {
+      dtype_(config.kv_cache_dtype()) {
 
     // Allocate K cache
     k_caches_ = infinicore::Tensor::empty(
@@ -115,6 +128,15 @@ StaticKVCache::update(size_t layer_idx,
     return {k_cache_layer, v_cache_layer};
 }
 
+infinicore::DataType
+StaticKVCacheConfig::kv_cache_dtype() const {
+    return kv_cache_dtype_;
+}
+
+void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
+    kv_cache_dtype_ = dtype;
+}
+
 // ==========================
 // PagedKVCacheConfig
 // ==========================
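
For orientation, here is a hedged usage sketch of the new three-argument constructor. The batch and cache sizes are illustrative, and it assumes parse_dtype maps "int8" to infinicore::DataType::I8, as the QuantConfig change below implies.

// Illustrative values only; not part of this commit.
infinilm::cache::StaticKVCacheConfig fp_cfg(/*max_batch_size=*/1, /*max_cache_len=*/4096);
// fp_cfg.kv_cache_dtype_is_set() == false; LlamaModel::reset_cache later backfills the model dtype.

infinilm::cache::StaticKVCacheConfig int8_cfg(/*max_batch_size=*/1, /*max_cache_len=*/4096, "int8");
// int8_cfg.kv_cache_dtype_is_set() == true; StaticKVCache allocates its K/V buffers in that dtype.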

csrc/cache/kv_cache.hpp

Lines changed: 12 additions & 1 deletion
@@ -23,13 +23,25 @@ class StaticKVCacheConfig final : public CacheConfig {
         infinicore::Size _max_batch_size = 1,
         infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());
 
+    StaticKVCacheConfig(
+        infinicore::Size _max_batch_size,
+        infinicore::Size _max_cache_len,
+        std::string kv_cache_dtype);
+
     std::unique_ptr<CacheConfig> unique_copy() const override;
     infinicore::Size max_batch_size() const;
     infinicore::Size max_cache_len() const;
 
+    infinicore::DataType kv_cache_dtype() const;
+    void set_kv_cache_dtype(infinicore::DataType dtype) const;
+    bool kv_cache_dtype_is_set() const { return kv_cache_dtype_set_; }
+
 private:
     infinicore::Size max_batch_size_;
     infinicore::Size max_cache_len_;
+
+    bool kv_cache_dtype_set_ = false;
+    mutable infinicore::DataType kv_cache_dtype_;
 };
 
 class StaticKVCache final : public Cache {
@@ -42,7 +54,6 @@ class StaticKVCache final : public Cache {
         infinicore::Size num_v_heads,
         infinicore::Size num_layers,
         infinicore::Size max_positional_embedding,
-        infinicore::DataType dtype,
         const StaticKVCacheConfig &config,
         const engine::distributed::RankInfo &rank_info);
 
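The combination of a mutable member and a const setter is deliberate: LlamaModel::reset_cache (changed later in this commit) receives the config through a const reference and still has to backfill the model dtype when the caller never picked one. A minimal sketch of that pattern, with a hypothetical helper name:

// Hypothetical helper; kv_cache_dtype_ is mutable, so the default can be
// backfilled on a config that is otherwise treated as read-only.
void backfill_kv_cache_dtype(const infinilm::cache::StaticKVCacheConfig &cfg,
                             infinicore::DataType model_dtype) {
    if (!cfg.kv_cache_dtype_is_set()) {
        cfg.set_kv_cache_dtype(model_dtype);
    }
}
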
csrc/config/model_config.hpp

Lines changed: 8 additions & 0 deletions
@@ -64,6 +64,14 @@ class ModelConfig {
     infinicore::DataType get_dtype() const;
     infinicore::quantization::QuantScheme get_quant_scheme() const;
     std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
+    void set_kv_quant_scheme(std::string kv_cache_dtype) {
+        if (kv_cache_dtype == "int8") {
+            this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
+        }
+    }
+    infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
+        return quant_config.get_kv_quant_scheme();
+    }
 
 private:
     nlohmann::json config_json;

csrc/config/quant_config.hpp

Lines changed: 19 additions & 1 deletion
@@ -1,5 +1,5 @@
 #pragma once
-// #include "../quantization/quantization.hpp"
+#include "../utils.hpp"
 #include "infinicore/quantization.hpp"
 #include "nlohmann/json.hpp"
 
@@ -22,9 +22,27 @@ class QuantConfig {
         }
     }
 
+    void set_kv_quant_scheme(std::string kv_cache_dtype) {
+        switch (parse_dtype(kv_cache_dtype)) {
+        case infinicore::DataType::I8: {
+            this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8;
+            break;
+        }
+        default: {
+            this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
+            break;
+        }
+        }
+    }
+
+    infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
+        return kv_quant_scheme;
+    }
+
 private:
     nlohmann::json quantization_config;
     std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
+    infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
 };
 
 } // namespace infinilm::config
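
Together with the ModelConfig change above, this threads the engine-level dtype string down to an enum the attention layers can branch on. A hedged sketch of that chain; the config path is a placeholder, and ModelConfig is assumed constructible from a config.json path, as InferEngine does below:

infinilm::config::ModelConfig cfg("/path/to/model/config.json"); // placeholder path
// Default: KVQuantAlgo::NONE (member initializer in QuantConfig).
cfg.set_kv_quant_scheme("float16"); // ignored: ModelConfig only forwards the exact string "int8"
cfg.set_kv_quant_scheme("int8");    // QuantConfig parses it via parse_dtype to DataType::I8
// cfg.get_kv_quant_scheme() == infinicore::quantization::KVQuantAlgo::INT8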

csrc/engine/infer_engine.cpp

Lines changed: 4 additions & 1 deletion
@@ -55,14 +55,17 @@ InferEngine::InferEngine(
     infinicore::Device::Type device_type,
     const cache::CacheConfig *cache_config,
     bool enable_graph_compiling,
-    backends::AttentionBackend attention_backend) // Changed parameter
+    backends::AttentionBackend attention_backend,
+    const std::string &kv_cache_dtype) // Changed parameter
     : communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
     if (cache_config != nullptr) {
         cache_config_ = cache_config->unique_copy();
     }
 
     // Load model config if model_path is provided, model_path must be valid, and config.json exists
     this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
+    // Only support offline int8 kv cache quantization in this version
+    this->model_config_->set_kv_quant_scheme(kv_cache_dtype);
     // Create one RankWorker per rank
     int world_size = communication_group_.get_world_size();
     barrier_ = std::make_unique<RankBarrier>((size_t)world_size);

csrc/engine/infer_engine.hpp

Lines changed: 2 additions & 1 deletion
@@ -46,7 +46,8 @@ class InferEngine {
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
         const cache::CacheConfig *cache_config = nullptr,
         bool enable_graph_compiling = false,
-        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
+        backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
+        const std::string &kv_cache_dtype = "");
 
     // Load a parameter to all workers (each can extract its shard inside RankWorker)
     void load_param(const std::string &name, const infinicore::Tensor &param);

csrc/models/llama/llama_attention.cpp

Lines changed: 43 additions & 4 deletions
@@ -7,10 +7,13 @@
 #include "infinicore/ops/mha_kvcache.hpp"
 #include "infinicore/ops/mha_varlen.hpp"
 #include "infinicore/ops/mul.hpp"
+#include "infinicore/ops/per_tensor_dequant_i8.hpp"
+#include "infinicore/ops/per_tensor_quant_i8.hpp"
 
 #include <algorithm>
 #include <cmath>
 #include <cstring>
+#include <iostream>
 #include <optional>
 #include <spdlog/spdlog.h>
 #include <stdexcept>
@@ -137,6 +140,17 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
         INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
         INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
     }
+
+    switch (this->model_config_->get_kv_quant_scheme()) {
+    case (infinicore::quantization::KVQuantAlgo::INT8): {
+        INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+        INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+        break;
+    }
+    default: {
+        break;
+    }
+    }
 }
 
 infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
@@ -184,6 +198,17 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
     rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
     rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);   // [bs, seq_len, n_kv_head, head_dim]
 
+    switch (this->model_config_->get_kv_quant_scheme()) {
+    case (infinicore::quantization::KVQuantAlgo::INT8): {
+        k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
+        v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
+        break;
+    }
+    default: {
+        break;
+    }
+    }
+
     // 5. Prepare KV caches
     // Convert to [batch, n_head, seq_len, head_dim] for cache
     // Ensure contiguous after permute for F16 compatibility with cache operations
@@ -212,6 +237,21 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
             ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
     } else {
         size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
+
+        switch (this->model_config_->get_kv_quant_scheme()) {
+        case (infinicore::quantization::KVQuantAlgo::INT8): {
+            auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device());
+            auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device());
+            infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
+            infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
+            k_total = k_total_dequant;
+            v_total = v_total_dequant;
+            break;
+        }
+        default: {
+            break;
+        }
+        }
         k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
         v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
 
@@ -342,10 +382,10 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
     auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
     auto attn_out_4d = infinicore::op::mha_kvcache(
         q_for_fa,
-        k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
+        k_total->permute({0, 2, 1, 3}),  // [num_blocks, block_size, num_kv_heads, head_dim]
         v_total->permute({0, 2, 1, 3}),
-        total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
-        block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32
+        total_sequence_lengths.value(),  // [seq_len] int32 (one entry per sequence)
+        block_tables.value(),            // [seq_len, max_num_blocks_per_seq] int32
         std::nullopt,
         scaling_);
     attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
@@ -361,7 +401,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
             scaling_);
     }
     }
-
 
     // 7. Project output
     attn_output
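
The kernels behind per_tensor_quant_i8 and per_tensor_dequant_i8_ are not part of this diff. As a mental model only (an assumption, not the InfiniCore implementation), the calls above amount to symmetric per-tensor int8 quantization with an offline-calibrated scale and a zero point of 0; the zeros tensor passed alongside the scale plays the zero-point role.

// Assumed numerics, plain C++ for illustration; the scale comes from offline
// calibration (kv_cache_k_scale / kv_cache_v_scale), nothing is measured at runtime.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

std::vector<int8_t> quant_per_tensor_i8(const std::vector<float> &x, float scale) {
    std::vector<int8_t> q(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        float v = std::round(x[i] / scale);                         // map onto the integer grid
        q[i] = static_cast<int8_t>(std::clamp(v, -128.0f, 127.0f)); // saturate to the int8 range
    }
    return q;
}

std::vector<float> dequant_per_tensor_i8(const std::vector<int8_t> &q, float scale) {
    std::vector<float> x(q.size());
    for (size_t i = 0; i < q.size(); ++i) {
        x[i] = static_cast<float>(q[i]) * scale; // approximate reconstruction
    }
    return x;
}

In the decode branch above, the cache therefore holds int8 K/V, and the dequantized copies (k_total_dequant / v_total_dequant) are materialized in the query dtype just before attention.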

csrc/models/llama/llama_attention.hpp

Lines changed: 11 additions & 0 deletions
@@ -112,6 +112,13 @@ class LlamaAttention : public infinicore::nn::Module {
         std::optional<infinicore::Tensor> block_tables,
         std::optional<infinicore::Tensor> slot_mapping) const;
 
+    infinicore::Tensor kv_cache_k_scale() const {
+        return kv_cache_k_scale_;
+    }
+    infinicore::Tensor kv_cache_v_scale() const {
+        return kv_cache_v_scale_;
+    }
+
 protected:
     // Projection layers
     INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
@@ -123,6 +130,10 @@ class LlamaAttention : public infinicore::nn::Module {
     // Shared Rotary Position Embeddings (RoPE)
     std::shared_ptr<infinicore::nn::RoPE> rotary_emb_;
 
+    // For off-line kv cache quantization
+    INFINICORE_NN_PARAMETER(kv_cache_k_scale);
+    INFINICORE_NN_PARAMETER(kv_cache_v_scale);
+
 private:
     std::shared_ptr<infinilm::config::ModelConfig> model_config_ = std::make_shared<infinilm::config::ModelConfig>();
     size_t layer_idx_; // Layer index for cache access

csrc/models/llama/llama_model.cpp

Lines changed: 3 additions & 2 deletions
@@ -136,7 +136,6 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
             config_.num_key_value_heads,
             config_.num_hidden_layers,
             config_.max_position_embeddings,
-            config_.dtype,
             *kv_cache_config,
             rank_info_);
     } else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config);
@@ -150,14 +149,16 @@ void LlamaModel::reset_cache(const cache::CacheConfig *cache_config) {
             *paged_kv_cache_config,
             rank_info_);
     } else if (auto kv_cache_config = dynamic_cast<const cache::StaticKVCacheConfig *>(cache_config)) {
+        if (!kv_cache_config->kv_cache_dtype_is_set()) {
+            kv_cache_config->set_kv_cache_dtype(model_config_->get_dtype());
+        }
         kv_cache_ = std::make_shared<cache::StaticKVCache>(
             model_config_->get_head_dim(),
             model_config_->get_head_dim(),
             model_config_->get<size_t>("num_key_value_heads"),
             model_config_->get<size_t>("num_key_value_heads"),
             model_config_->get<size_t>("num_hidden_layers"),
             model_config_->get<size_t>("max_position_embeddings"),
-            model_config_->get_dtype(),
             *kv_cache_config,
             rank_info_);
     } else if (auto paged_kv_cache_config = dynamic_cast<const cache::PagedKVCacheConfig *>(cache_config)) {

csrc/pybind11/cache/cache.hpp

Lines changed: 7 additions & 0 deletions
@@ -22,12 +22,19 @@ inline void bind_cache(py::module &m) {
             py::init<infinicore::Size, infinicore::Size>(),
             py::arg("max_batch_size") = 1,
             py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max())
+        .def(
+            py::init<infinicore::Size, infinicore::Size, std::string>(),
+            py::arg("max_batch_size") = 1,
+            py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max(),
+            py::arg("kv_cache_dtype"))
         .def(
             "max_batch_size",
             &infinilm::cache::StaticKVCacheConfig::max_batch_size)
         .def(
             "max_cache_len",
             &infinilm::cache::StaticKVCacheConfig::max_cache_len)
+        .def("kv_cache_dtype",
+             &infinilm::cache::StaticKVCacheConfig::kv_cache_dtype)
         .def("__repr__", [](const infinilm::cache::StaticKVCacheConfig &) {
             return "<StaticKVCacheConfig>";
         });
