Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 33 additions & 30 deletions ggml/src/ggml-openvino/ggml-decoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,11 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
// TODO: The shape modification for the stateful model below is not validated for all supported models yet. A more generic solution might be needed
// to enable additional cases. Ideally, this could be removed from the decoder and done as part of a transformation later.
auto stateless_kv_shape = get_graph_input_shape(node, src);
assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 && stateless_kv_shape[1] == 1
&& stateless_kv_shape[2].is_dynamic() && stateless_kv_shape[3] == (m_model_params.n_heads_kv*m_model_params.head_size));
stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, m_model_params.head_size};
assert(stateless_kv_shape.size() == 4 && stateless_kv_shape[0] == 1 &&
stateless_kv_shape[1] == 1 && stateless_kv_shape[2].is_dynamic() &&
stateless_kv_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size));
stateful_kv_shape = {stateless_kv_shape[0], ov::Dimension::dynamic(),
m_model_params.n_heads_kv, m_model_params.head_size};
}
}
}
Expand All @@ -180,9 +182,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
}
m_inputs[src_name] = src;
assert(stateful_kv_shape.rank().is_static());
ov::PartialShape param_shape = (stateful_kv_shape.rank().get_length() != 0)
? stateful_kv_shape
: get_graph_input_shape(node, src);
ov::PartialShape param_shape =
(stateful_kv_shape.rank().get_length() != 0) ? stateful_kv_shape : get_graph_input_shape(node, src);
auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), param_shape);
param_node->set_friendly_name(src_name);
param_node->output(0).get_tensor().set_names({src_name});
Expand All @@ -197,7 +198,7 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
static std::set<std::string> debug_output_names = {};
// Workaround: the final tensor "result_output" does not have GGML_TENSOR_FLAG_OUTPUT flag set in cgraph
if (node->op == GGML_OP_SET_ROWS || node->flags & GGML_TENSOR_FLAG_OUTPUT ||
node_output_name.find("output") != std::string::npos || debug_output_names.count(node_output_name)) {
debug_output_names.count(node_output_name)) {
if (m_model_outputs.find(node_output_name) == m_model_outputs.end()) {
m_model_outputs[node_output_name] = node_output;
}
Expand Down Expand Up @@ -312,6 +313,11 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
auto * node = cgraph->nodes[i];
std::string name = std::string(node->name);
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
model_params.n_heads = node->src[0]->ne[2];
model_params.n_heads_kv = node->src[1]->ne[2];
model_params.head_size = node->src[0]->ne[0];
compute_params.input_len = node->src[0]->ne[1];

auto * cache_k_perm = node->src[1];
if (cache_k_perm->op == GGML_OP_CPY) {
cache_k_perm = cache_k_perm->src[0];
Expand All @@ -324,9 +330,8 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
int layer = extract_layer_from_name(cache_k->name);
auto * mask = node->src[3];
std::string mask_name(mask->name);
assert(mask_name.find("self_kq_mask") == 0);

if (std::string(node->src[3]->name).find("swa") != std::string::npos) {
if (mask_name.find("swa") != std::string::npos) {
model_params.swa_layers.push_back(layer);
model_params.ctx_per_seq_swa = cache_k->ne[1];
} else {
Expand All @@ -351,25 +356,18 @@ std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgr
compute_params.attention_size_swa = model_params.ctx_per_seq_swa;
compute_params.token_len_per_seq = 1;
}

} else if (node->op == GGML_OP_ROPE) {
if (name.find("Qcur-0") == 0 || std::string(node->src[0]->name).find("Qcur-0") == 0) {
model_params.head_size = node->ne[0];
model_params.n_heads = node->ne[1];
model_params.rope_params = node->op_params;
auto * inp_pos = node->src[1];
compute_params.input_len = inp_pos->ne[0];
} else if (name.find("Kcur-0") == 0 || std::string(node->src[0]->name).find("Kcur-0") == 0) {
model_params.n_heads_kv = node->ne[1];
}
} else if (node->op == GGML_OP_GET_ROWS && std::string(node->src[1]->name) == "inp_out_ids") {
// for static case, output_len is always 1 except for llama-perplexity
compute_params.output_len = node->src[1]->ne[0];
if (is_static && compute_params.output_len == 0) {
compute_params.output_len = 1;
}
break;
}
if (node->op == GGML_OP_ROPE) {
model_params.rope_params = node->op_params;
}
}
auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1];
compute_params.output_len = output_tensor->ne[1];
// for NPU, output_len is always 1 except for llama-perplexity
if (is_static && compute_params.output_len == 0) {
compute_params.output_len = 1;
}
model_params.ctx = model_params.ctx_per_seq * model_params.n_seq;
model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq;
return {model_params, compute_params};
Expand All @@ -385,14 +383,17 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
auto name = std::string(input->name);
ov::PartialShape input_shape;

if (name == "inp_tokens" || name == "inp_pos") {
if ((op->op == GGML_OP_GET_ROWS && op->src[0]->op == GGML_OP_NONE) || op->op == GGML_OP_ROPE) {
// tokens or positions
int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
input_shape = ov::PartialShape{1, 1, 1, len};

} else if (name == "inp_out_ids") {
} else if (op->op == GGML_OP_GET_ROWS) {
// output index
input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1};

} else if (name.find("self_kq_mask") == 0) {
} else if (op->op == GGML_OP_CPY || op->op == GGML_OP_FLASH_ATTN_EXT) {
// mask
if (m_is_static) {
input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx};
} else if (m_is_stateful) {
Expand All @@ -401,14 +402,16 @@ ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, co
input_shape = ov::PartialShape{-1, 1, -1, -1};
}

} else if (name.find("cache_") == 0) {
} else if (op && op->op == GGML_OP_SET_ROWS && op->src[2] == input) {
// kvcache
input_shape = ov::PartialShape{get_shape(input)};
if (!m_is_static) {
// keep the ctx dimension dynamic (do not pin it to a fixed size) so that llama-bench works
input_shape[2] = -1;
}

} else if (op && op->op == GGML_OP_SET_ROWS && op->src[1] == input) {
// kv update index
int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1;
input_shape = ov::PartialShape{1, 1, 1, len};

Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-openvino/ggml-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct ModelParams {
int ctx_swa = -1;
int ctx_per_seq = -1;
int ctx_per_seq_swa = -1;
int n_seq = -1;
int n_seq = 1;
int n_heads = -1;
int n_heads_kv = -1;
int head_size = -1;
Expand All @@ -37,14 +37,14 @@ struct ModelParams {
};

struct ComputeParams {
int n_seq_active = -1;
int seq_active_start = -1;
int n_seq_active = 1;
int seq_active_start = 0;
int attention_size = -1;
int attention_size_swa = -1;
int input_len = -1;
int token_len_per_seq = -1;
int past_kv_len = -1;
int output_len = -1;
int output_len = 1;
};

class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder {
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
src1 = context.get_input(1);
} else {
auto combined = context.get_input(0);
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
src0 = split->output(0);
src1 = split->output(1);
Expand Down
22 changes: 6 additions & 16 deletions ggml/src/ggml-openvino/openvino/op/rope.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
constexpr int ROPE_TYPE_NORM = 0;

if (mode == ROPE_TYPE_NORM) {
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
Output<Node> even_slice;
Output<Node> odd_slice;
int32_t unsqueeze_dim = 4;
if (context.is_stateful()) {
unsqueeze_dim = 3;
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
} else {
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
}
int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);

Output<Node> first_half =
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
Expand All @@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
} else if (mode == ROPE_TYPE_NEOX) {
auto data_split = std::make_shared<ov::op::v1::Split>(
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
Output<Node> slice_data_node_0 = data_split->outputs()[0];
Output<Node> slice_data_node_1 = data_split->outputs()[1];

Expand All @@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));

int32_t concat_dim = 3;
if (context.is_stateful()) {
concat_dim = 2;
}
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
}

return rename_outputs_with_suffix({res}, context.get_name());
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-openvino/openvino/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
return sliced;
}
Expand Down
6 changes: 4 additions & 2 deletions ggml/src/ggml-openvino/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons

ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
const std::string & param_name) {
// NPU decoding stage
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);

Expand Down Expand Up @@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
const std::string & param_name,
int chunk_index) {
// NPU prompt processing stage
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);

Expand Down Expand Up @@ -614,10 +616,10 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, con
auto output_type = ggml_decoder->get_ov_type(ggml_tensor);
auto output_shape = ggml_decoder->get_shape(ggml_tensor);

if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) {
if (ggml_decoder->is_static() && output_shape[2] == 0) {
output_shape[2] = 1;
}
if (ggml_decoder->is_stateful() && result_name == "result_output") {
if (ggml_decoder->is_stateful() && ggml_tensor->flags & GGML_TENSOR_FLAG_OUTPUT) {
std::vector<long unsigned int> output_shape_3d;
for (size_t i=1; i<output_shape.size(); i++) {
output_shape_3d.push_back(output_shape[i]);
Expand Down