@@ -76,7 +76,7 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
7676 ComputeParams c_params;
7777 std::tie (m_params, c_params) = GgmlOvDecoder::compute_llm_params (cgraph, is_static);
7878
79- const auto key = compute_graph_key (cgraph);
79+ graph_key key (cgraph);
8080 bool cache_hit;
8181
8282 int64_t decoder_end_time;
@@ -90,19 +90,22 @@ enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, const std::strin
9090 auto it = decoder_cache.find (key);
9191
9292 cache_hit = it != decoder_cache.end ();
93+ ModelParams old_m_params;
9394 if (cache_hit) {
9495 ggml_decoder = it->second ;
95- cache_hit = ggml_decoder->get_model_params ().can_reuse_dynamically (m_params);
96+ old_m_params = ggml_decoder->get_model_params ();
97+ cache_hit = old_m_params.can_reuse_dynamically (m_params);
9698 }
9799
98100 if (cache_hit) {
99101 std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
100- ggml_decoder = decoder_cache[key];
101102 ggml_decoder->set_compute_params (c_params);
102103 ggml_decoder->set_model_params (m_params);
104+ if (old_m_params.kv_buffer_changed (m_params)) {
105+ ggml_decoder->update_io (cgraph);
106+ }
103107 ggml_decoder->add_extra_inputs ();
104- infer_request = infer_request_cache[key];
105-
108+ infer_request = infer_request_cache.at (key);
106109 if (stateful) {
107110 const auto * inp_pos = get_inp_pos_tensor (cgraph);
108111 int32_t * pos_data = (int32_t *) inp_pos->data ;
@@ -240,7 +243,7 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
240243
241244 const auto * inp_pos = get_inp_pos_tensor (cgraph);
242245 const auto is_prefill = get_is_prefill (inp_pos);
243- const auto key = compute_graph_key (cgraph);
246+ graph_key key (cgraph);
244247 bool cache_hit;
245248
246249 int64_t decoder_end_time;
@@ -254,19 +257,23 @@ enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph) {
254257 auto it = decoder_cache.find (key);
255258
256259 cache_hit = it != decoder_cache.end ();
260+ ModelParams old_m_params;
257261 if (cache_hit) {
258262 ggml_decoder = it->second ;
259- cache_hit = ggml_decoder->get_model_params ().can_reuse_statically (m_params);
263+ old_m_params = ggml_decoder->get_model_params ();
264+ cache_hit = old_m_params.can_reuse_statically (m_params);
260265 }
261266
262267 if (cache_hit) {
263268 std::map<std::string, std::shared_ptr<ov::Node>> model_weights;
264- ggml_decoder = decoder_cache[key];
265269 ggml_decoder->m_is_prefill = is_prefill;
266270 ggml_decoder->set_model_params (m_params);
267271 ggml_decoder->set_compute_params (c_params);
272+ if (old_m_params.kv_buffer_changed (m_params)) {
273+ ggml_decoder->update_io (cgraph);
274+ }
268275 ggml_decoder->add_extra_inputs ();
269- infer_request = is_prefill ? infer_request_cache_prefill[ key] : infer_request_cache[ key] ;
276+ infer_request = is_prefill ? infer_request_cache_prefill. at ( key) : infer_request_cache. at ( key) ;
270277
271278 decoder_end_time = ggml_time_us ();
272279 conversion_end_time = decoder_end_time;
@@ -761,17 +768,4 @@ bool get_is_prefill(const ggml_tensor * inp_pos) {
761768 return inp_pos->ne [0 ] > 1 ;
762769}
763770
764- graph_key compute_graph_key (ggml_cgraph * cgraph) {
765- graph_key key;
766- key.n_nodes = cgraph->n_nodes ;
767-
768- for (int i = 0 ; i < cgraph->n_nodes ; ++i) {
769- const auto * node = cgraph->nodes [i];
770- if (node->op == GGML_OP_SET_ROWS && strncmp (node->src [2 ]->name , " cache_k_l0" , 10 ) == 0 ) {
771- key.cache_k_l0 = node->src [2 ];
772- }
773- }
774- return key;
775- }
776-
777771#pragma GCC diagnostic pop
0 commit comments