@@ -198,16 +198,11 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
198198 rotary_emb_->forward (q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
199199 rotary_emb_->forward (k_reshaped, pos_ids_for_rope, true ); // [bs, seq_len, n_kv_head, head_dim]
200200
201- switch (this ->model_config_ ->get_kv_quant_scheme ()) {
202- case (infinicore::quantization::KVQuantAlgo::INT8): {
203- k_reshaped = infinicore::op::per_tensor_quant_i8 (k_reshaped, this ->kv_cache_k_scale (), infinicore::Tensor::zeros ({1 }, k_reshaped->dtype (), k_reshaped->device ()), true );
204- v_reshaped = infinicore::op::per_tensor_quant_i8 (v_reshaped, this ->kv_cache_v_scale (), infinicore::Tensor::zeros ({1 }, k_reshaped->dtype (), k_reshaped->device ()), true );
205- break ;
206- }
207- default : {
208- break ;
209- }
210- }
201+ infinilm::KVQuantUtils::quantize (
202+ k_reshaped, v_reshaped,
203+ this ->model_config_ ->get_kv_quant_scheme (),
204+ this ->kv_cache_k_scale (),
205+ this ->kv_cache_v_scale ());
211206
212207 // 5. Prepare KV caches
213208 // Convert to [batch, n_head, seq_len, head_dim] for cache
@@ -238,20 +233,13 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
238233 } else {
239234 size_t total_seq_len = reinterpret_cast <int32_t *>(total_sequence_lengths.value ()->to (infinicore::Device::cpu ())->data ())[0 ];
240235
241- switch (this ->model_config_ ->get_kv_quant_scheme ()) {
242- case (infinicore::quantization::KVQuantAlgo::INT8): {
243- auto k_total_dequant = infinicore::Tensor::strided_empty (k_total->shape (), k_total->strides (), q_reshaped->dtype (), q_reshaped->device ());
244- auto v_total_dequant = infinicore::Tensor::strided_empty (v_total->shape (), v_total->strides (), q_reshaped->dtype (), q_reshaped->device ());
245- infinicore::op::per_tensor_dequant_i8_ (k_total_dequant, k_total, this ->kv_cache_k_scale (), infinicore::Tensor::zeros ({1 }, k_reshaped->dtype (), k_reshaped->device ()));
246- infinicore::op::per_tensor_dequant_i8_ (v_total_dequant, v_total, this ->kv_cache_v_scale (), infinicore::Tensor::zeros ({1 }, k_reshaped->dtype (), k_reshaped->device ()));
247- k_total = k_total_dequant;
248- v_total = v_total_dequant;
249- break ;
250- }
251- default : {
252- break ;
253- }
254- }
236+ infinilm::KVQuantUtils::dequantize (
237+ k_total, v_total,
238+ this ->model_config_ ->get_kv_quant_scheme (),
239+ this ->kv_cache_k_scale (),
240+ this ->kv_cache_v_scale (),
241+ q_reshaped);
242+
255243 k_total = k_total->narrow ({{2 , 0 , total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
256244 v_total = v_total->narrow ({{2 , 0 , total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
257245
0 commit comments