Skip to content

Commit 1502b6f

Browse files
authored
Add instantiations for decoder RoPE enforce_fmul_rn=true (PaddlePaddle#7009)
1 parent 482f951 commit 1502b6f

2 files changed

Lines changed: 232 additions & 0 deletions

File tree

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,3 +1163,117 @@ template void DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
11631163
const paddle::optional<paddle::Tensor>& q_norm_weight,
11641164
const paddle::optional<paddle::Tensor>& k_norm_weight,
11651165
const float rms_norm_eps);
1166+
1167+
// Explicit instantiation: decoder RoPE + KV-cache write for bfloat16 input
// with an int cache element type (presumably a quantized cache path — the
// second template argument differs from the activation dtype; confirm against
// the primary template's declaration). The third template argument `true` is,
// per the commit title, the enforce-fmul-rn flag — TODO confirm its exact
// semantics in the primary template.
template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1194+
1195+
// Explicit instantiation: decoder RoPE + KV-cache write, bfloat16 activations
// with a bfloat16 (non-quantized) cache element type. Third template argument
// `true` is, per the commit title, the enforce-fmul-rn flag — TODO confirm its
// meaning against the primary template's declaration.
template void
DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1223+
1224+
// Explicit instantiation: decoder RoPE + KV-cache write, float16 activations
// with an int cache element type (presumably a quantized cache path — confirm
// against the primary template). Third template argument `true` is the
// enforce-fmul-rn flag per the commit title — TODO confirm.
template void DecoderWriteCacheWithRoPEKernel<paddle::float16, int, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1251+
1252+
// Explicit instantiation: decoder RoPE + KV-cache write, float16 activations
// with a float16 (non-quantized) cache element type. Third template argument
// `true` is the enforce-fmul-rn flag per the commit title — TODO confirm its
// semantics against the primary template's declaration.
template void
DecoderWriteCacheWithRoPEKernel<paddle::float16, paddle::float16, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // kv_num_heads, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

custom_ops/gpu_ops/append_attn/speculate_write_cache_with_rope_kernel.cu

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -979,3 +979,121 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
979979
const paddle::optional<paddle::Tensor>& q_norm_weight,
980980
const paddle::optional<paddle::Tensor>& k_norm_weight,
981981
const float rms_norm_eps);
982+
983+
// Explicit instantiation: speculative-decoding RoPE + KV-cache write for
// bfloat16 input with an int cache element type (presumably a quantized cache
// path — confirm against the primary template). Differs from the decoder
// variant by the extra batch_id_per_token argument. Third template argument
// `true` is the enforce-fmul-rn flag per the commit title — TODO confirm.
template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // gqa_group_size, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1011+
1012+
// Explicit instantiation: speculative-decoding RoPE + KV-cache write,
// bfloat16 activations with a bfloat16 (non-quantized) cache element type.
// Third template argument `true` is the enforce-fmul-rn flag per the commit
// title — TODO confirm against the primary template's declaration.
template void
SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // gqa_group_size, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1041+
1042+
// Explicit instantiation: speculative-decoding RoPE + KV-cache write, float16
// activations with an int cache element type (presumably a quantized cache
// path — confirm against the primary template). Third template argument
// `true` is the enforce-fmul-rn flag per the commit title — TODO confirm.
template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // gqa_group_size, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);
1070+
1071+
// Explicit instantiation: speculative-decoding RoPE + KV-cache write, float16
// activations with a float16 (non-quantized) cache element type. Third
// template argument `true` is the enforce-fmul-rn flag per the commit title —
// TODO confirm against the primary template's declaration.
template void
SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16, true>(
    const AppendAttnMetaData& meta_data,
    const paddle::Tensor&
        qkv,  // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
              // gqa_group_size, head_dim] if GQA)
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& rotary_embs,
    const paddle::optional<paddle::Tensor>& qkv_out_scales,
    const paddle::optional<paddle::Tensor>& qkv_biases,
    const paddle::optional<paddle::Tensor>& cache_k_scale,
    const paddle::optional<paddle::Tensor>& cache_v_scale,
    const paddle::optional<paddle::Tensor>& cache_k_zp,
    const paddle::optional<paddle::Tensor>& cache_v_zp,
    const std::string& cache_quant_type_str,
    const bool use_neox_rotary_style,
    const bool rope_3d,
    const int max_seq_len,
    cudaStream_t& stream,
    paddle::Tensor* qkv_out,
    paddle::Tensor* key_cache_out,
    paddle::Tensor* value_cache_out,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const float rms_norm_eps);

0 commit comments

Comments
 (0)