@@ -168,10 +168,10 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
168168 if (is_stateful ()) {
169169 // TODO: The shape modification for stateful model below is not validated for all supported models yet. More generic solution might be needed
170170 // to enable additional cases. Ideally, this could be removed from decoder and done as part of a transformation later.
171- auto stateless_kv_shape = get_graph_input_shape (node, src);
172- assert (stateless_kv_shape.size () == 4 && stateless_kv_shape[0 ] == 1 && stateless_kv_shape[1 ] == 1
173- && stateless_kv_shape[2 ].is_dynamic () && stateless_kv_shape[3 ] == (m_model_params.n_heads_kv *m_model_params.head_size ));
174- stateful_kv_shape = {stateless_kv_shape[0 ], ov::Dimension::dynamic (), m_model_params.n_heads_kv , m_model_params.head_size };
171+ // auto stateless_kv_shape = get_graph_input_shape(node, src);
172+ // assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1
173+ // && stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size));
174+ // stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size};
175175 }
176176 }
177177 }
@@ -180,9 +180,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
180180 }
181181 m_inputs[src_name] = src;
182182 assert (stateful_kv_shape.rank ().is_static ());
183- ov::PartialShape param_shape = (stateful_kv_shape.rank ().get_length () != 0 )
184- ? stateful_kv_shape
185- : get_graph_input_shape (node, src);
183+ ov::PartialShape param_shape =
184+ (stateful_kv_shape.rank ().get_length () != 0 ) ? stateful_kv_shape : get_graph_input_shape (node, src);
186185 auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type (src), param_shape);
187186 param_node->set_friendly_name (src_name);
188187 param_node->output (0 ).get_tensor ().set_names ({src_name});
@@ -324,9 +323,8 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
324323 int layer = extract_layer_from_name (cache_k->name );
325324 auto * mask = node->src [3 ];
326325 std::string mask_name (mask->name );
327- assert (mask_name.find (" self_kq_mask" ) == 0 );
328326
329- if (std::string (node-> src [ 3 ]-> name ) .find (" swa" ) != std::string::npos) {
327+ if (mask_name .find (" swa" ) != std::string::npos) {
330328 model_params.swa_layers .push_back (layer);
331329 model_params.ctx_per_seq_swa = cache_k->ne [1 ];
332330 } else {
@@ -351,25 +349,20 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
351349 compute_params.attention_size_swa = model_params.ctx_per_seq_swa ;
352350 compute_params.token_len_per_seq = 1 ;
353351 }
354-
355- } else if (node->op == GGML_OP_ROPE) {
356- if (name.find (" Qcur-0" ) == 0 || std::string (node->src [0 ]->name ).find (" Qcur-0" ) == 0 ) {
357- model_params.head_size = node->ne [0 ];
358- model_params.n_heads = node->ne [1 ];
359- model_params.rope_params = node->op_params ;
360- auto * inp_pos = node->src [1 ];
361- compute_params.input_len = inp_pos->ne [0 ];
362- } else if (name.find (" Kcur-0" ) == 0 || std::string (node->src [0 ]->name ).find (" Kcur-0" ) == 0 ) {
363- model_params.n_heads_kv = node->ne [1 ];
364- }
365- } else if (node->op == GGML_OP_GET_ROWS && std::string (node->src [1 ]->name ) == " inp_out_ids" ) {
366- // for static case, output_len is always 1 except for llama-perplexity
367- compute_params.output_len = node->src [1 ]->ne [0 ];
368- if (is_static && compute_params.output_len == 0 ) {
369- compute_params.output_len = 1 ;
370- }
352+ break ;
353+ }
354+ if (node->op == GGML_OP_ROPE) {
355+ model_params.rope_params = node->op_params ;
356+ auto * inp_pos = node->src [1 ];
357+ compute_params.input_len = inp_pos->ne [0 ];
371358 }
372359 }
360+ auto * output_tensor = cgraph->nodes [cgraph->n_nodes - 1 ];
361+ compute_params.output_len = output_tensor->ne [1 ];
362+ // for NPU, output_len is always 1 except for llama-perplexity
363+ if (is_static && compute_params.output_len == 0 ) {
364+ compute_params.output_len = 1 ;
365+ }
373366 model_params.ctx = model_params.ctx_per_seq * model_params.n_seq ;
374367 model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq ;
375368 return {model_params, compute_params};
@@ -385,14 +378,17 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
385378 auto name = std::string (input->name );
386379 ov::PartialShape input_shape;
387380
388- if (name == " inp_tokens" || name == " inp_pos" ) {
381+ if ((op->op == GGML_OP_GET_ROWS && op->src [0 ]->op == GGML_OP_NONE) || op->op == GGML_OP_ROPE) {
382+ // tokens or positions
389383 int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1 ) : -1 ;
390384 input_shape = ov::PartialShape{1 , 1 , 1 , len};
391385
392- } else if (name == " inp_out_ids" ) {
386+ } else if (op->op == GGML_OP_GET_ROWS) {
387+ // output index
393388 input_shape = ov::PartialShape{1 , 1 , 1 , m_is_static ? m_compute_params.output_len : -1 };
394389
395- } else if (name.find (" self_kq_mask" ) == 0 ) {
390+ } else if (op->op == GGML_OP_CPY || op->op == GGML_OP_FLASH_ATTN_EXT) {
391+ // mask
396392 if (m_is_static) {
397393 input_shape = ov::PartialShape{1 , 1 , m_is_prefill ? m_prefill_chunk_size : 1 , m_model_params.ctx };
398394 } else if (m_is_stateful) {
@@ -401,14 +397,16 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
401397 input_shape = ov::PartialShape{-1 , 1 , -1 , -1 };
402398 }
403399
404- } else if (name.find (" cache_" ) == 0 ) {
400+ } else if (op && op->op == GGML_OP_SET_ROWS && op->src [2 ] == input) {
401+ // kvcache
405402 input_shape = ov::PartialShape{get_shape (input)};
406403 if (!m_is_static) {
407404 // do not fix ctx size to make llama-bench work
408405 input_shape[2 ] = -1 ;
409406 }
410407
411408 } else if (op && op->op == GGML_OP_SET_ROWS && op->src [1 ] == input) {
409+ // kv update index
412410 int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1 ) : -1 ;
413411 input_shape = ov::PartialShape{1 , 1 , 1 , len};
414412
0 commit comments