#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"
+ #include "infinicore/ops/per_tensor_dequant_i8.hpp"
+ #include "infinicore/ops/per_tensor_quant_i8.hpp"

#include <algorithm>
#include <cmath>
#include <cstring>
+ #include <iostream>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
@@ -137,6 +140,17 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
        INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
        INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
    }
+
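+     // Register one per-tensor scale parameter each for the K and V caches when
+     // the model config selects INT8 KV-cache quantization.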
+     switch (this->model_config_->get_kv_quant_scheme()) {
+     case (infinicore::quantization::KVQuantAlgo::INT8): {
+         INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+         INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
+         break;
+     }
+     default: {
+         break;
+     }
+     }
}

infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
@@ -184,6 +198,17 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);   // [bs, seq_len, n_kv_head, head_dim]

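+     // Quantize the rotated K/V projections to INT8 before they are written to
+     // the KV cache; the zero-point passed here is a zero tensor, i.e. symmetric
+     // per-tensor quantization using the registered scale parameters.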
+     switch (this->model_config_->get_kv_quant_scheme()) {
+     case (infinicore::quantization::KVQuantAlgo::INT8): {
+         k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
+         v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
+         break;
+     }
+     default: {
+         break;
+     }
+     }
+
    // 5. Prepare KV caches
    // Convert to [batch, n_head, seq_len, head_dim] for cache
    // Ensure contiguous after permute for F16 compatibility with cache operations
@@ -212,6 +237,21 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
            ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
    } else {
        size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];
+
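+         // If the cache holds INT8 data, dequantize the gathered K/V back to the
+         // query dtype into new buffers that keep the cache strides, then swap
+         // them in before running attention.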
+         switch (this->model_config_->get_kv_quant_scheme()) {
+         case (infinicore::quantization::KVQuantAlgo::INT8): {
+             auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device());
+             auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device());
+             infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
+             infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
+             k_total = k_total_dequant;
+             v_total = v_total_dequant;
+             break;
+         }
+         default: {
+             break;
+         }
+         }
        k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
        v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]

@@ -342,10 +382,10 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
        auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
        auto attn_out_4d = infinicore::op::mha_kvcache(
            q_for_fa,
-             k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
+             k_total->permute({0, 2, 1, 3}),     // [num_blocks, block_size, num_kv_heads, head_dim]
            v_total->permute({0, 2, 1, 3}),
-             total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
-             block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32
+             total_sequence_lengths.value(),     // [seq_len] int32 (one entry per sequence)
+             block_tables.value(),               // [seq_len, max_num_blocks_per_seq] int32
            std::nullopt,
            scaling_);
        attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
@@ -361,7 +401,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
                scaling_);
        }
    }
-

    // 7. Project output
    attn_output
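For reference, a minimal sketch of the symmetric per-tensor INT8 round trip these ops appear to perform, assuming a zero zero-point as passed above; the helper names and signatures below are illustrative only and are not part of the infinicore API:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Hypothetical helpers, for illustration only.
    // Symmetric per-tensor quantization: q = clamp(round(x / scale), -127, 127).
    std::vector<int8_t> quant_i8(const std::vector<float> &x, float scale) {
        std::vector<int8_t> q(x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            float r = std::round(x[i] / scale);
            q[i] = static_cast<int8_t>(std::clamp(r, -127.0f, 127.0f));
        }
        return q;
    }

    // Dequantization back to float: x_hat = q * scale.
    std::vector<float> dequant_i8(const std::vector<int8_t> &q, float scale) {
        std::vector<float> x(q.size());
        for (size_t i = 0; i < q.size(); ++i) {
            x[i] = static_cast<float>(q[i]) * scale;
        }
        return x;
    }

Clamping to [-127, 127] keeps the range symmetric around zero; the actual ops may differ in rounding or saturation behavior.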