Skip to content

Commit a6eafbc

Browse files
committed
Stateful fix for shape errors after rebase
1 parent ff9bb1a commit a6eafbc

4 files changed

Lines changed: 21 additions & 6 deletions

File tree

ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,12 @@ OutputVector translate_glu_swiglu(const NodeContext & context) {
2626
src1 = context.get_input(1);
2727
} else {
2828
auto combined = context.get_input(0);
29-
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {3});
29+
// TODO: Will it work if we set it to "-1" for all cases?
30+
int32_t split_dim = 3;
31+
if (context.is_stateful()) {
32+
split_dim = -1;
33+
}
34+
auto split_axis = ov::op::v0::Constant::create(ov::element::i64, {}, {split_dim});
3035
auto split = std::make_shared<ov::op::v1::Split>(combined, split_axis, 2);
3136
src0 = split->output(0);
3237
src1 = split->output(1);

ggml/src/ggml-openvino/openvino/op/rope.cpp

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,12 @@ OutputVector translate_rope(const NodeContext & context) {
104104
ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]});
105105
res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false);
106106
} else if (mode == ROPE_TYPE_NEOX) {
107+
int32_t split_dim = 3;
108+
if (context.is_stateful()) {
109+
split_dim = 2;
110+
}
107111
auto data_split = std::make_shared<ov::op::v1::Split>(
108-
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}), 2);
112+
data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {split_dim}), 2);
109113
Output<Node> slice_data_node_0 = data_split->outputs()[0];
110114
Output<Node> slice_data_node_1 = data_split->outputs()[1];
111115

@@ -117,9 +121,10 @@ OutputVector translate_rope(const NodeContext & context) {
117121
std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
118122
std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
119123

124+
// TODO: Will it work if we set it to "-1" for all cases?
120125
int32_t concat_dim = 3;
121126
if (context.is_stateful()) {
122-
concat_dim = 2;
127+
concat_dim = -1;
123128
}
124129
res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
125130
}

ggml/src/ggml-openvino/openvino/utils.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -213,10 +213,15 @@ ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_i
213213
}
214214
int64_t slice_end = split_addr + slice_len;
215215

216+
int32_t axes_val = 3;
217+
if (context.is_stateful()) {
218+
axes_val = 2;
219+
}
220+
216221
auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr});
217222
auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end});
218223
auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
219-
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {3});
224+
auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {axes_val});
220225
auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes);
221226
return sliced;
222227
}

ggml/src/ggml-openvino/utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -614,10 +614,10 @@ ov::Tensor get_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, con
614614
auto output_type = ggml_decoder->get_ov_type(ggml_tensor);
615615
auto output_shape = ggml_decoder->get_shape(ggml_tensor);
616616

617-
if (ggml_decoder->is_static() && result_name == "result_output" && output_shape[2] == 0) {
617+
if (ggml_decoder->is_static() && (result_name == "result_output" || result_name == "result_norm") && output_shape[2] == 0) {
618618
output_shape[2] = 1;
619619
}
620-
if (ggml_decoder->is_stateful() && result_name == "result_output") {
620+
if (ggml_decoder->is_stateful() && (result_name == "result_output" || result_name == "result_norm")) {
621621
std::vector<long unsigned int> output_shape_3d;
622622
for (size_t i=1; i<output_shape.size(); i++) {
623623
output_shape_3d.push_back(output_shape[i]);

0 commit comments

Comments
 (0)