From d1851fbae86d8483ad7d728a2db4a706fe0f90b7 Mon Sep 17 00:00:00 2001 From: gongchensu Date: Mon, 29 Jun 2026 06:37:02 +0000 Subject: [PATCH] fix: support static cache graph replay Update static cache writes with device-side KV caching so graph capture does not depend on host reads. Replay static decode through paged attention using one logical block per batch while keeping the static KV layout. Reset runtime state around capture and replay to keep graph inputs stable. --- .../compiler/static_batching_compiler.cpp | 29 +++++++++-- .../layers/attention/backends/static_attn.cpp | 52 +++++++++---------- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/csrc/engine/compiler/static_batching_compiler.cpp b/csrc/engine/compiler/static_batching_compiler.cpp index dcd7f7143..28827a185 100644 --- a/csrc/engine/compiler/static_batching_compiler.cpp +++ b/csrc/engine/compiler/static_batching_compiler.cpp @@ -1,6 +1,7 @@ #include "static_batching_compiler.hpp" #include "../../cache/cache.hpp" #include "../../global_state/global_state.hpp" +#include "../../utils.hpp" namespace infinilm::engine { StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr &model, RankBarrier *barrier) @@ -13,10 +14,22 @@ void StaticBatchingCompiler::compile() { InfinilmModel::Input input; input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice()); input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice()); - input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); - input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice()); - std::vector total_sequence_lengths_vec(b, 1); - infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false); + input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice()); + input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I32, infinicore::context::getDevice()); + input.block_tables = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I32, infinicore::context::getDevice()); + + set_zeros(input.input_ids.value()); + set_zeros(input.position_ids.value()); + set_zeros(input.past_sequence_lengths.value()); + + std::vector total_sequence_lengths_vec(b, 1); + infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int32_t), false); + + std::vector block_tables_vec(b); + for (size_t i = 0; i < b; ++i) { + block_tables_vec[i] = static_cast(i); + } + infinicore::context::memcpyH2D(input.block_tables.value()->data(), block_tables_vec.data(), b * sizeof(int32_t), false); // Attention reads attn_metadata from thread-local forward context. infinilm::global_state::get_forward_context().attn_metadata = { @@ -28,7 +41,14 @@ void StaticBatchingCompiler::compile() { input.slot_mapping, }; + model_->forward(input); + infinicore::context::syncStream(); + model_->reset_runtime_state(); + infinicore::context::syncStream(); + barrier_->wait(); + model_->reset_runtime_state(); + infinicore::context::syncStream(); infinicore::context::startGraphRecording(); auto output = model_->forward(input); auto graph = infinicore::context::stopGraphRecording(); @@ -54,6 +74,7 @@ StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled( graph_input.position_ids.value()->copy_from(input.position_ids.value()); graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value()); graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value()); + model_->reset_runtime_state(); auto graph = std::get<0>(result->second.compiled); auto shared_output = std::shared_ptr(new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()}); diff --git a/csrc/layers/attention/backends/static_attn.cpp b/csrc/layers/attention/backends/static_attn.cpp index b71f0b52e..86eca0e51 100644 --- a/csrc/layers/attention/backends/static_attn.cpp +++ b/csrc/layers/attention/backends/static_attn.cpp @@ -4,6 +4,7 @@ #include "infinicore/ops.hpp" #include "infinicore/ops/per_tensor_dequant_i8.hpp" #include "infinicore/ops/per_tensor_quant_i8.hpp" +#include namespace infinilm::layers::attention::backends { @@ -55,25 +56,32 @@ infinicore::Tensor StaticAttentionImpl::forward(const AttentionLayer &layer, // v_total : [bs, n_kv_head, max_seq_len, head_dim] auto [k_total, v_total] = do_kv_cache_update(layer, k_permuted, v_permuted, kv_cache, past_sequence_lengths.value()); + if (infinilm::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { + infinilm::KVQuantUtils::dequantize( + k_total, v_total, + this->kv_quant_scheme_, + k_scale, + v_scale, + q_reshaped); + } + infinicore::Tensor attn_output; - if (false) { - // experimental nineoothed flash attention - attn_output = infinicore::op::flash_attention(q_reshaped, k_total, v_total, total_sequence_lengths.value(), scale_, true); - attn_output = attn_output->permute({0, 2, 1, 3}) - ->contiguous() - ->view({batch_size, seq_len, num_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim] + if (attn_metadata.block_tables.has_value() && seq_len == 1) { + auto query = q_rope->contiguous()->view({batch_size * seq_len, num_heads_, head_dim_}); + auto out = infinicore::Tensor::empty({batch_size * seq_len, num_heads_, head_dim_}, query->dtype(), query->device()); + infinicore::op::paged_attention_( + out, + query, + k_total, + v_total, + attn_metadata.block_tables.value(), + total_sequence_lengths.value(), + std::nullopt, + scale_); + attn_output = out->view({batch_size, seq_len, num_heads_ * head_dim_}); } else { size_t total_seq_len = reinterpret_cast(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0]; - if (infinilm::quantization::KVQuantAlgo::NONE != this->kv_quant_scheme_) { - infinilm::KVQuantUtils::dequantize( - k_total, v_total, - this->kv_quant_scheme_, - k_scale, - v_scale, - q_reshaped); - } - k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim] @@ -107,25 +115,13 @@ std::tuple StaticAttentionImpl::do_kv_ca const infinicore::Tensor past_sequence_lengths) const { auto batch_size = key->size(0); - auto update_len = key->size(2); auto k_cache_layer = kv_cache->narrow({{0, 0, 1}})->squeeze(0); auto v_cache_layer = kv_cache->narrow({{0, 1, 1}})->squeeze(0); size_t max_batch_size = k_cache_layer->size(0); - size_t max_seq_len = k_cache_layer->size(2); - auto device = k_cache_layer->device(); - ASSERT_EQ(batch_size, max_batch_size); - size_t cache_pos = reinterpret_cast(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0]; - auto result_len = cache_pos + update_len; - ASSERT(result_len <= max_seq_len); - - auto k_cache_update = k_cache_layer->narrow({{2, cache_pos, update_len}}); - auto v_cache_update = v_cache_layer->narrow({{2, cache_pos, update_len}}); - - k_cache_update->copy_from(key); - v_cache_update->copy_from(value); + infinicore::op::kv_caching_(k_cache_layer, v_cache_layer, key, value, past_sequence_lengths); return {k_cache_layer, v_cache_layer}; }