diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index dcd7f7143..28827a185 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -1,6 +1,7 @@ #include "static_batching_compiler.hpp" #include "../../cache/cache.hpp" #include "../../global_state/global_state.hpp" +#include "../../utils.hpp" namespace infinilm::engine { StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr &model, RankBarrier *barrier) @@ -13,10 +14,22 @@ void StaticBatchingCompiler::compile() { InfinilmModel::Input input; input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice()); input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice()); - input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); - input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); - std::vector total_sequence_lengths_vec(b, 1); - infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); + input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice()); + input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice()); + input.block_tables = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + + set_zeros(input.input_ids.value()); + set_zeros(input.position_ids.value()); + set_zeros(input.past_sequence_lengths.value()); + + std::vector total_sequence_lengths_vec(b, 1); + infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false); + + std::vector block_tables_vec(b); + for (size_t i = 0; i < b; ++i) { + block_tables_vec[i] = static_cast(i); + } + infinicore::context::memcpyH2D(input.block_tables.value()->data(), block_tables_vec.data(), b * sizeof(int32_t), false); // Attention reads attn_metadata from thread-local forward context. infinilm::global_state::get_forward_context().attn_metadata = { @@ -28,7 +41,14 @@ void StaticBatchingCompiler::compile() { input.slot_mapping, }; + model_->forward(input); + infinicore::context::syncStream(); + model_->reset_runtime_state(); + infinicore::context::syncStream(); + barrier_->wait(); + model_->reset_runtime_state(); + infinicore::context::syncStream(); infinicore::context::startGraphRecording(); auto output = model_->forward(input); auto graph = infinicore::context::stopGraphRecording(); @@ -54,6 +74,7 @@ StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled( graph_input.position_ids.value()->copy_from(input.position_ids.value()); graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value()); graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + model_->reset_runtime_state(); auto graph = std::get<0>(result->second.compiled); auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index b71f0b52e..86eca0e51 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -4,6 +4,7 @@ #include "infinicore/ops.hpp" #include "infinicore/ops/per_tensor_dequant_i8.hpp" #include "infinicore/ops/per_tensor_quant_i8.hpp" +#include namespace infinilm::layers::attention::backends { @@ -55,25 +56,32 @@ infinicore::Tensor StaticAttentionImpl::forward(const AttentionLayer &layer, // v_total : [bs, n_kv_head, max_seq_len, head_dim] auto [k_total, v_total] = do_kv_cache_update(layer, k_permuted, v_permuted, kv_cache, past_sequence_lengths.value()); + if (infinilm::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { + infinilm::KVQuantUtils::dequantize( + k_total, v_total, + this->kv_quant_scheme_, + k_scale, + v_scale, + q_reshaped); + } + infinicore::Tensor attn_output; - if (false) { - // experimental nineoothed flash attention - attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scale_, true); - attn_output = attn_output->permute({0, 2, 1, 3}) - ->contiguous() - ->view({batch_size, seq_len, num_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] + if (attn_metadata.block_tables.has_value() && seq_len == 1) { + auto query = q_rope->contiguous()->view({batch_size * seq_len, num_heads_, head_dim_}); + auto out = infinicore::Tensor::empty({batch_size * seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::op::paged_attention_( + out, + query, + k_total, + v_total, + attn_metadata.block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scale_); + attn_output = out->view({batch_size, seq_len, num_heads_ * head_dim_}); } else { size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; - if (infinilm::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { - infinilm::KVQuantUtils::dequantize( - k_total, v_total, - this->kv_quant_scheme_, - k_scale, - v_scale, - q_reshaped); - } - k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] @@ -107,25 +115,13 @@ std::tuple StaticAttentionImpl::do_kv_ca const infinicore::Tensor past_sequence_lengths) const { auto batch_size = key->size(0); - auto update_len = key->size(2); auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0); auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0); size_t max_batch_size = k_cache_layer->size(0); - size_t max_seq_len = k_cache_layer->size(2); - auto device = k_cache_layer->device(); - ASSERT_EQ(batch_size, max_batch_size); - size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; - auto result_len = cache_pos + update_len; - ASSERT(result_len <= max_seq_len); - - auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); - auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); - - k_cache_update->copy_from(key); - v_cache_update->copy_from(value); + infinicore::op::kv_caching_(k_cache_layer, v_cache_layer, key, value, past_sequence_lengths); return {k_cache_layer, v_cache_layer}; }