diff --git a/custom_ops/xpu_ops/src/ops/block_attn.cc b/custom_ops/xpu_ops/src/ops/block_attn.cc
index c055bfb873d..7bcc9e814e4 100644
--- a/custom_ops/xpu_ops/src/ops/block_attn.cc
+++ b/custom_ops/xpu_ops/src/ops/block_attn.cc
@@ -158,13 +158,17 @@ std::vector<paddle::Tensor> BlockAttnKernel(
     rope_head_dim = rotary_embs.dims()[4];
   }
   std::string pos_emb_type;
-  if (use_neox_rotary_style == true) {
+  if (use_neox_rotary_style) {
     pos_emb_type = "NEOX";
   } else if (rope_head_dim == head_dim / 2) {
     pos_emb_type = "HALF_HEAD_DIM";
   } else {
     pos_emb_type = "NORMAL";
   }
+  float partial_rotary_factor = 1.0;
+  if (use_neox_rotary_style) {
+    partial_rotary_factor = static_cast<float>(rope_head_dim) / head_dim;
+  }
   auto block_attn_out =
       paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place());
@@ -362,7 +366,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
           nullptr,  // intx_v_pc_scale
           nullptr,  // intx_k_pc_zero
           nullptr,  // intx_v_pc_zero
-          rope_3d);
+          rope_3d,
+          nullptr,
+          nullptr,
+          partial_rotary_factor);
       PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
     } else {
       ret = infer_ops::split_rope_cache_kv_encoder
@@ ... @@ std::vector<paddle::Tensor> BlockAttnKernel(
           nullptr,  // intx_v_pc_scale
           nullptr,  // intx_k_pc_zero
           nullptr,  // intx_v_pc_zero
-          rope_3d);
+          rope_3d,
+          nullptr,
+          nullptr,
+          partial_rotary_factor);
       PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
     } else {
       ret = infer_ops::split_rope_cache_kv_encoder
@@ ... @@ std::vector<paddle::Tensor> BlockAttnKernel(
           reinterpret_cast(quant_v_scale),  // v_cache_scale_inv
           reinterpret_cast(quant_k_zp),     // k_cache_zp
           reinterpret_cast(quant_v_zp),     // v_cache_zp
-          rope_3d);
+          rope_3d,
+          nullptr,
+          nullptr,
+          partial_rotary_factor);
       PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
     } else {
       ret = infer_ops::split_rope_cache_kv_decoder
[... remainder of this hunk and the diff header for the rotary-embedding Python file lost in source ...]
         freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
-        if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
+        if current_platform.is_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
             # shape: [B, S, D]
             rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
             emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
@@ -89,9 +89,14 @@ def __call__(self, position_ids):
         bsz, max_seq_len = position_ids.shape[:2]
         inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
         freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
-        # shape: [B, S, D/2]
-        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
-        emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
+        if current_platform.is_xpu():
+            # shape: [B, S, D]
+            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
+            emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
+        else:
+            # shape: [B, S, D/2]
+            rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
+            emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
         # shape: [B, S, 1, D]
         emb = paddle.unsqueeze(emb, 2)
         rot_emb[0] = paddle.cos(emb)
diff --git a/fastdeploy/model_executor/models/glm4_moe.py b/fastdeploy/model_executor/models/glm4_moe.py
index 3f45e9df614..e49fe6d59f1 100644
--- a/fastdeploy/model_executor/models/glm4_moe.py
+++ b/fastdeploy/model_executor/models/glm4_moe.py
@@ -71,7 +71,7 @@ def __init__(
             fd_config=fd_config,
             prefix=f"{prefix}.up_gate_proj",
             input_size=fd_config.model_config.hidden_size,
-            output_size=[intermediate_size, intermediate_size],
+            output_sizes=[intermediate_size, intermediate_size],
             with_bias=False,
         )
diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py
index a984e8788c4..fa04e2b7983 100644
--- a/fastdeploy/worker/xpu_model_runner.py
+++ b/fastdeploy/worker/xpu_model_runner.py
@@ -990,6 +990,7 @@ def _init_share_inputs(self, max_num_seqs: int):
             position_ids=tmp_position_ids,
             base=self.model_config.rope_theta,
             model_config=self.model_config,
+            partial_rotary_factor=self.model_config.partial_rotary_factor,
         )
         # Set block tables
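
Reviewer note: below is a minimal sketch (not part of the patch) contrasting the two
cos/sin cache layouts the rotary-embedding hunks distinguish, and the ratio the C++ side
computes as partial_rotary_factor. It assumes only that paddle is installed; the sizes
(bsz=1, max_seq_len=8, rotary_dim=64, head_dim=128, base=10000.0) are illustrative.

    import paddle

    bsz, max_seq_len, rotary_dim, base = 1, 8, 64, 10000.0
    head_dim = 128

    # Same frequency construction as the patched __call__: one inverse
    # frequency per rotated dimension pair, outer product with positions.
    position_ids = paddle.arange(max_seq_len, dtype="float32").reshape((bsz, max_seq_len))
    inv_freq = base ** (-paddle.arange(0, rotary_dim, 2, dtype="float32") / rotary_dim)
    freqs = paddle.einsum("ij,k->ijk", position_ids, inv_freq)  # [B, S, D/2]

    # XPU branch added by the patch: full-width table, NEOX-style, with the
    # frequency block repeated back-to-back via concat([freqs, freqs]).
    emb_xpu = paddle.concat([freqs, freqs], axis=-1)  # [B, S, D]

    # Pre-existing default branch: half-width table, one entry per pair.
    emb_half = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, rotary_dim // 2))

    cos_xpu, sin_xpu = paddle.cos(emb_xpu), paddle.sin(emb_xpu)

    # block_attn.cc derives the analogous ratio for NEOX rope:
    # partial_rotary_factor = rope_head_dim / head_dim (1.0 == rotate the whole head).
    partial_rotary_factor = rotary_dim / head_dim  # 0.5 in this illustrative setup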