Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,17 @@ std::vector<paddle::Tensor> BlockAttnKernel(
rope_head_dim = rotary_embs.dims()[4];
}
std::string pos_emb_type;
if (use_neox_rotary_style == true) {
if (use_neox_rotary_style) {
pos_emb_type = "NEOX";
} else if (rope_head_dim == head_dim / 2) {
pos_emb_type = "HALF_HEAD_DIM";
} else {
pos_emb_type = "NORMAL";
}
float partial_rotary_factor = 1.0;
if (use_neox_rotary_style) {
partial_rotary_factor = static_cast<float>(rope_head_dim) / head_dim;
}

auto block_attn_out =
paddle::empty({token_num, hidden_dim}, qkv.type(), qkv.place());
Expand Down Expand Up @@ -362,7 +366,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
nullptr, // intx_v_pc_scale
nullptr, // intx_k_pc_zero
nullptr, // intx_v_pc_zero
rope_3d);
rope_3d,
nullptr,
nullptr,
partial_rotary_factor);
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
Expand Down Expand Up @@ -620,7 +627,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
nullptr, // intx_v_pc_scale
nullptr, // intx_k_pc_zero
nullptr, // intx_v_pc_zero
rope_3d);
rope_3d,
nullptr,
nullptr,
partial_rotary_factor);
PD_CHECK(ret == api::SUCCESS, "split_neox_cache_kv_encoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_encoder<XPU_XType,
Expand Down Expand Up @@ -818,7 +828,10 @@ std::vector<paddle::Tensor> BlockAttnKernel(
reinterpret_cast<D_Scale*>(quant_v_scale), // v_cache_scale_inv
reinterpret_cast<D_Scale*>(quant_k_zp), // k_cache_zp
reinterpret_cast<D_Scale*>(quant_v_zp), // v_cache_zp
rope_3d);
rope_3d,
nullptr,
nullptr,
partial_rotary_factor);
PD_CHECK(ret == api::SUCCESS, "split_rope_cache_kv_decoder failed.");
} else {
ret = infer_ops::split_rope_cache_kv_decoder<XPU_XType,
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
)

if self.model_format == "torch" and "output_dim" in extra_weight_attrs:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]
if extra_weight_attrs["output_dim"] is not None:
extra_weight_attrs["output_dim"] = not extra_weight_attrs["output_dim"]

set_weight_attrs(
layer.weight,
Expand Down
13 changes: 9 additions & 4 deletions fastdeploy/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __call__(self, position_ids):
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
partial_rotary_position_ids = position_ids / self.partial_rotary_factor
freqs = paddle.einsum("ij,k->ijk", partial_rotary_position_ids.cast("float32"), inv_freq)
if paddle.is_compiled_with_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
if current_platform.is_xpu() or paddle.is_compiled_with_custom_device("iluvatar_gpu"):
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.stack([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
Expand Down Expand Up @@ -89,9 +89,14 @@ def __call__(self, position_ids):
bsz, max_seq_len = position_ids.shape[:2]
inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
if current_platform.is_xpu():
# shape: [B, S, D]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim), dtype="float32")
emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim))
else:
# shape: [B, S, D/2]
rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
# shape: [B, S, 1, D]
emb = paddle.unsqueeze(emb, 2)
rot_emb[0] = paddle.cos(emb)
Expand Down
2 changes: 1 addition & 1 deletion fastdeploy/model_executor/models/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(
fd_config=fd_config,
prefix=f"{prefix}.up_gate_proj",
input_size=fd_config.model_config.hidden_size,
output_size=[intermediate_size, intermediate_size],
output_sizes=[intermediate_size, intermediate_size],
with_bias=False,
)

Expand Down
1 change: 1 addition & 0 deletions fastdeploy/worker/xpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,7 @@ def _init_share_inputs(self, max_num_seqs: int):
position_ids=tmp_position_ids,
base=self.model_config.rope_theta,
model_config=self.model_config,
partial_rotary_factor=self.model_config.partial_rotary_factor,
)

# Set block tables
Expand Down
Loading