Skip to content

Commit aca7c53

Browse files
committed
Fix stateful shapes
1 parent e03a9b1 commit aca7c53

5 files changed

Lines changed: 11 additions & 19 deletions

File tree

ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ OutputVector translate_glu_geglu(const NodeContext & context) {
2626
src1 = context.get_input(1);
2727
} else {
2828
auto combined = context.get_input(0);
29-
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
29+
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
3030
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
3131
src0 = split->output(0);
3232
src1 = split->output(1);

ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
2626
src1 = context.get_input(1);
2727
} else {
2828
auto combined = context.get_input(0);
29-
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
29+
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {-1});
3030
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
3131
src0 = split->output(0);
3232
src1 = split->output(1);

ggml/src/ggml-openvino/openvino/op/rope.cpp

Lines changed: 6 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
7070
constexpr int ROPE_TYPE_NORM = 0;
7171

7272
if (mode == ROPE_TYPE_NORM) {
73+
auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1});
7374
auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
7475
auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
7576
auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2});
7677
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]});
7778
Output<Node> even_slice;
7879
Output<Node> odd_slice;
79-
int32_t unsqueeze_dim = 4;
80-
if (context.is_stateful()) {
81-
unsqueeze_dim = 3;
82-
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
83-
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
84-
} else {
85-
auto three = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
86-
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
87-
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
88-
}
80+
int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4;
81+
even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
82+
odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);
8983

9084
Output<Node> first_half =
9185
std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
@@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
10599
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
106100
} else if (mode == ROPE_TYPE_NEOX) {
107101
auto data_split = std::make_shared<ov::op::v1::Split>(
108-
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
102+
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);
109103
Output<Node> slice_data_node_0 = data_split->outputs()[0];
110104
Output<Node> slice_data_node_1 = data_split->outputs()[1];
111105

@@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
117111
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
118112
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
119113

120-
int32_t concat_dim = 3;
121-
if (context.is_stateful()) {
122-
concat_dim = 2;
123-
}
124-
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
114+
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1);
125115
}
126116

127117
return rename_outputs_with_suffix({res}, context.get_name());

ggml/src/ggml-openvino/openvino/utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,7 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
216216
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
217217
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
218218
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
219-
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
219+
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3});
220220
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
221221
return sliced;
222222
}

ggml/src/ggml-openvino/utils.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -497,6 +497,7 @@ ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, cons
497497

498498
ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
499499
const std::string & param_name) {
500+
// NPU decoding stage
500501
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
501502
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
502503

@@ -540,6 +541,7 @@ ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml
540541
ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder,
541542
const std::string & param_name,
542543
int chunk_index) {
544+
// NPU prompt processing stage
543545
const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name);
544546
const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor);
545547

0 commit comments

Comments
 (0)