@@ -979,3 +979,121 @@ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16>(
979979 const paddle::optional<paddle::Tensor>& q_norm_weight,
980980 const paddle::optional<paddle::Tensor>& k_norm_weight,
981981 const float rms_norm_eps);
982+
// Explicit instantiation of SpeculateWriteCacheWithRoPEKernel for the template
// arguments <paddle::bfloat16, int, true>, so that this specialization is
// emitted in this translation unit.
// NOTE(review): the template definition is outside this hunk. The second
// argument (int here) is presumably the element type of `qkv`, and the meaning
// of the third argument (true) is not visible here -- confirm against the
// primary template. The parameter list must match that template exactly.
// NOTE(review): the leading "983+"-style tokens on each line are diff-view
// residue (line number plus added-line marker), not part of the C++ source.
983+ template void SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, int , true >(
984+ const AppendAttnMetaData& meta_data,
985+ const paddle::Tensor&
986+ qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
987+ // gqa_group_size, head_dim] if GQA)
988+ const paddle::Tensor& seq_lens,
989+ const paddle::Tensor& seq_lens_encoder,
990+ const paddle::Tensor& batch_id_per_token,
991+ const paddle::Tensor& cu_seqlens_q,
992+ const paddle::Tensor& block_tables,
993+ const paddle::optional<paddle::Tensor>& rotary_embs,
994+ const paddle::optional<paddle::Tensor>& qkv_out_scales,
995+ const paddle::optional<paddle::Tensor>& qkv_biases,
996+ const paddle::optional<paddle::Tensor>& cache_k_scale,
997+ const paddle::optional<paddle::Tensor>& cache_v_scale,
998+ const paddle::optional<paddle::Tensor>& cache_k_zp,
999+ const paddle::optional<paddle::Tensor>& cache_v_zp,
1000+ const std::string& cache_quant_type_str,
1001+ const bool use_neox_rotary_style,
1002+ const bool rope_3d,
1003+ const int max_seq_len,
1004+ cudaStream_t& stream,
1005+ paddle::Tensor* qkv_out,
1006+ paddle::Tensor* key_cache_out,
1007+ paddle::Tensor* value_cache_out,
1008+ const paddle::optional<paddle::Tensor>& q_norm_weight,
1009+ const paddle::optional<paddle::Tensor>& k_norm_weight,
1010+ const float rms_norm_eps);
1011+
// Explicit instantiation of SpeculateWriteCacheWithRoPEKernel for the template
// arguments <paddle::bfloat16, paddle::bfloat16, true>, so that this
// specialization is emitted in this translation unit.
// NOTE(review): the template definition is outside this hunk. The second
// argument (paddle::bfloat16 here) is presumably the element type of `qkv`,
// and the meaning of the third argument (true) is not visible here -- confirm
// against the primary template. The parameter list must match that template
// exactly.
// NOTE(review): the leading "1012+"-style tokens on each line are diff-view
// residue (line number plus added-line marker), not part of the C++ source.
1012+ template void
1013+ SpeculateWriteCacheWithRoPEKernel<paddle::bfloat16, paddle::bfloat16, true >(
1014+ const AppendAttnMetaData& meta_data,
1015+ const paddle::Tensor&
1016+ qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
1017+ // gqa_group_size, head_dim] if GQA)
1018+ const paddle::Tensor& seq_lens,
1019+ const paddle::Tensor& seq_lens_encoder,
1020+ const paddle::Tensor& batch_id_per_token,
1021+ const paddle::Tensor& cu_seqlens_q,
1022+ const paddle::Tensor& block_tables,
1023+ const paddle::optional<paddle::Tensor>& rotary_embs,
1024+ const paddle::optional<paddle::Tensor>& qkv_out_scales,
1025+ const paddle::optional<paddle::Tensor>& qkv_biases,
1026+ const paddle::optional<paddle::Tensor>& cache_k_scale,
1027+ const paddle::optional<paddle::Tensor>& cache_v_scale,
1028+ const paddle::optional<paddle::Tensor>& cache_k_zp,
1029+ const paddle::optional<paddle::Tensor>& cache_v_zp,
1030+ const std::string& cache_quant_type_str,
1031+ const bool use_neox_rotary_style,
1032+ const bool rope_3d,
1033+ const int max_seq_len,
1034+ cudaStream_t& stream,
1035+ paddle::Tensor* qkv_out,
1036+ paddle::Tensor* key_cache_out,
1037+ paddle::Tensor* value_cache_out,
1038+ const paddle::optional<paddle::Tensor>& q_norm_weight,
1039+ const paddle::optional<paddle::Tensor>& k_norm_weight,
1040+ const float rms_norm_eps);
1041+
// Explicit instantiation of SpeculateWriteCacheWithRoPEKernel for the template
// arguments <paddle::float16, int, true>, so that this specialization is
// emitted in this translation unit.
// NOTE(review): the template definition is outside this hunk. The second
// argument (int here) is presumably the element type of `qkv`, and the meaning
// of the third argument (true) is not visible here -- confirm against the
// primary template. The parameter list must match that template exactly.
// NOTE(review): the leading "1042+"-style tokens on each line are diff-view
// residue (line number plus added-line marker), not part of the C++ source.
1042+ template void SpeculateWriteCacheWithRoPEKernel<paddle::float16, int , true >(
1043+ const AppendAttnMetaData& meta_data,
1044+ const paddle::Tensor&
1045+ qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
1046+ // gqa_group_size, head_dim] if GQA)
1047+ const paddle::Tensor& seq_lens,
1048+ const paddle::Tensor& seq_lens_encoder,
1049+ const paddle::Tensor& batch_id_per_token,
1050+ const paddle::Tensor& cu_seqlens_q,
1051+ const paddle::Tensor& block_tables,
1052+ const paddle::optional<paddle::Tensor>& rotary_embs,
1053+ const paddle::optional<paddle::Tensor>& qkv_out_scales,
1054+ const paddle::optional<paddle::Tensor>& qkv_biases,
1055+ const paddle::optional<paddle::Tensor>& cache_k_scale,
1056+ const paddle::optional<paddle::Tensor>& cache_v_scale,
1057+ const paddle::optional<paddle::Tensor>& cache_k_zp,
1058+ const paddle::optional<paddle::Tensor>& cache_v_zp,
1059+ const std::string& cache_quant_type_str,
1060+ const bool use_neox_rotary_style,
1061+ const bool rope_3d,
1062+ const int max_seq_len,
1063+ cudaStream_t& stream,
1064+ paddle::Tensor* qkv_out,
1065+ paddle::Tensor* key_cache_out,
1066+ paddle::Tensor* value_cache_out,
1067+ const paddle::optional<paddle::Tensor>& q_norm_weight,
1068+ const paddle::optional<paddle::Tensor>& k_norm_weight,
1069+ const float rms_norm_eps);
1070+
// Explicit instantiation of SpeculateWriteCacheWithRoPEKernel for the template
// arguments <paddle::float16, paddle::float16, true>, so that this
// specialization is emitted in this translation unit.
// NOTE(review): the template definition is outside this hunk. The second
// argument (paddle::float16 here) is presumably the element type of `qkv`,
// and the meaning of the third argument (true) is not visible here -- confirm
// against the primary template. The parameter list must match that template
// exactly.
// NOTE(review): the leading "1071+"-style tokens on each line are diff-view
// residue (line number plus added-line marker), not part of the C++ source.
1071+ template void
1072+ SpeculateWriteCacheWithRoPEKernel<paddle::float16, paddle::float16, true >(
1073+ const AppendAttnMetaData& meta_data,
1074+ const paddle::Tensor&
1075+ qkv, // [token_num, 3, num_head, head_dim] ([token_num, num_head + 2 *
1076+ // gqa_group_size, head_dim] if GQA)
1077+ const paddle::Tensor& seq_lens,
1078+ const paddle::Tensor& seq_lens_encoder,
1079+ const paddle::Tensor& batch_id_per_token,
1080+ const paddle::Tensor& cu_seqlens_q,
1081+ const paddle::Tensor& block_tables,
1082+ const paddle::optional<paddle::Tensor>& rotary_embs,
1083+ const paddle::optional<paddle::Tensor>& qkv_out_scales,
1084+ const paddle::optional<paddle::Tensor>& qkv_biases,
1085+ const paddle::optional<paddle::Tensor>& cache_k_scale,
1086+ const paddle::optional<paddle::Tensor>& cache_v_scale,
1087+ const paddle::optional<paddle::Tensor>& cache_k_zp,
1088+ const paddle::optional<paddle::Tensor>& cache_v_zp,
1089+ const std::string& cache_quant_type_str,
1090+ const bool use_neox_rotary_style,
1091+ const bool rope_3d,
1092+ const int max_seq_len,
1093+ cudaStream_t& stream,
1094+ paddle::Tensor* qkv_out,
1095+ paddle::Tensor* key_cache_out,
1096+ paddle::Tensor* value_cache_out,
1097+ const paddle::optional<paddle::Tensor>& q_norm_weight,
1098+ const paddle::optional<paddle::Tensor>& k_norm_weight,
1099+ const float rms_norm_eps);
0 commit comments