@@ -70,22 +70,16 @@ OutputVector translate_rope(const NodeContext & context) {
7070 constexpr int ROPE_TYPE_NORM = 0 ;
7171
7272 if (mode == ROPE_TYPE_NORM) {
73+ auto neg_one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {-1 });
7374 auto zero = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
7475 auto one = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
7576 auto two = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {2 });
7677 auto end = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {output_shape[3 ]});
7778 Output<Node> even_slice;
7879 Output<Node> odd_slice;
79- int32_t unsqueeze_dim = 4 ;
80- if (context.is_stateful ()) {
81- unsqueeze_dim = 3 ;
82- even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, two);
83- odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, two);
84- } else {
85- auto three = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {3 });
86- even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, three);
87- odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, three);
88- }
80+ int32_t unsqueeze_dim = context.is_stateful () ? 3 : 4 ;
81+ even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one);
82+ odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one);
8983
9084 Output<Node> first_half =
9185 std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node),
@@ -105,7 +99,7 @@ OutputVector translate_rope(const NodeContext & context) {
10599 res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false );
106100 } else if (mode == ROPE_TYPE_NEOX) {
107101 auto data_split = std::make_shared<ov::op::v1::Split>(
108- data_node, ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{}, {3 }), 2 );
102+ data_node, ov::op::v0::Constant::create (ov::element::i64 , ov::Shape{}, {- 1 }), 2 );
109103 Output<Node> slice_data_node_0 = data_split->outputs ()[0 ];
110104 Output<Node> slice_data_node_1 = data_split->outputs ()[1 ];
111105
@@ -117,11 +111,7 @@ OutputVector translate_rope(const NodeContext & context) {
117111 std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node),
118112 std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node));
119113
120- int32_t concat_dim = 3 ;
121- if (context.is_stateful ()) {
122- concat_dim = 2 ;
123- }
124- res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, concat_dim);
114+ res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1 );
125115 }
126116
127117 return rename_outputs_with_suffix ({res}, context.get_name ());
0 commit comments