-
Notifications
You must be signed in to change notification settings - Fork 719
Expand file tree
/
Copy pathpybind.cpp
More file actions
201 lines (173 loc) · 10.1 KB
/
pybind.cpp
File metadata and controls
201 lines (173 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "cgemm_helper.h"
#include "common/util/cuda_runtime.h"
namespace transformer_engine {
namespace jax {
// Packs the address of an XLA FFI handler into a pybind11 capsule named
// "xla._CUSTOM_CALL_TARGET".
//
// T must be invocable with an XLA_FFI_CallFrame* and produce a result convertible to
// XLA_FFI_Error*; any other T is rejected at compile time by the static_assert below.
// The capsule is created without a destructor, so it only carries the pointer value.
template <typename T>
pybind11::capsule EncapsulateFFI(T *fn) {
  constexpr bool kIsFfiHandler =
      std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>;
  static_assert(kIsFfiHandler, "Encapsulated function must be an XLA FFI handler");

  // The capsule stores an untyped pointer, hence the cast away from the handler type.
  void *const handler_address = reinterpret_cast<void *>(fn);
  return pybind11::capsule(handler_address, "xla._CUSTOM_CALL_TARGET");
}
// Builds the table of XLA FFI handlers exported by this extension.
//
// Each key is a target name ending in "_ffi". The value is either a single handler capsule
// or a nested dict of capsules keyed by stage name ("prepare", "initialize", "execute").
// NOTE(review): presumably the Python side walks this dict to register the targets with
// XLA -- confirm against the caller.
pybind11::dict Registrations() {
  namespace py = pybind11;
  using namespace pybind11::literals;  // "name"_a is shorthand for py::arg("name")

  py::dict targets;

  // Activation
  targets["te_act_lu_ffi"] = py::dict("initialize"_a = EncapsulateFFI(ActLuInitializeHandler),
                                      "execute"_a = EncapsulateFFI(ActLuHandler));
  targets["te_dact_dbias_quantize_ffi"] =
      py::dict("initialize"_a = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler),
               "execute"_a = EncapsulateFFI(DActLuDBiasQuantizeHandler));

  // Quantization
  targets["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler);
  targets["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler);
  targets["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);

  // Softmax
  targets["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
  targets["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
  targets["te_scaled_masked_softmax_forward_ffi"] =
      EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
  targets["te_scaled_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
  targets["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
  targets["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

  // Normalization
  targets["te_norm_forward_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CudnnHandleInitHandler),
               "initialize"_a = EncapsulateFFI(NormForwardInitializeHandler),
               "execute"_a = EncapsulateFFI(NormForwardHandler));
  targets["te_norm_backward_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CudnnHandleInitHandler),
               "initialize"_a = EncapsulateFFI(NormBackwardInitializeHandler),
               "execute"_a = EncapsulateFFI(NormBackwardHandler));

  // Attention
  targets["te_fused_attn_forward_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CudnnHandleInitHandler),
               "execute"_a = EncapsulateFFI(FusedAttnForwardHandler));
  targets["te_fused_attn_backward_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CudnnHandleInitHandler),
               "execute"_a = EncapsulateFFI(FusedAttnBackwardHandler));

  // GEMM
  targets["te_gemm_ffi"] = py::dict("prepare"_a = EncapsulateFFI(CollectiveGemmInitHandler),
                                    "execute"_a = EncapsulateFFI(GemmHandler));

  // Grouped GEMM
  targets["te_grouped_gemm_d2h_group_sizes_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CublasHandleInitHandler),
               "execute"_a = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler));
  targets["te_grouped_gemm_ffi"] =
      py::dict("prepare"_a = EncapsulateFFI(CublasHandleInitHandler),
               "execute"_a = EncapsulateFFI(GroupedGemmHandler));

  // Amax
  targets["te_rht_amax_ffi"] =
      py::dict("initialize"_a = EncapsulateFFI(RHTAmaxCalculationInitializeHandler),
               "execute"_a = EncapsulateFFI(RHTAmaxCalculationHandler));

  // Inspect
  targets["te_inspect_ffi"] = py::dict("execute"_a = EncapsulateFFI(InspectHandler));

  return targets;
}
// Entry point of the `transformer_engine_jax` Python extension module.
//
// Binds the FFI registration table, a set of free-function queries, and the enums listed
// below. Enums whose chain ends in export_values() also publish their entries as
// attributes of the module itself; the others are reachable only through the enum type.
PYBIND11_MODULE(transformer_engine_jax, m) {
  namespace py = pybind11;

  // Every enum in this module is registered with the module_local attribute.
  const auto local_only = py::module_local();

  // FFI handler table plus backend / version / device queries.
  m.def("registrations", &Registrations)
      .def("get_fused_attn_backend", &GetFusedAttnBackend)
      .def("get_cuda_version", &GetCudaRuntimeVersion)
      .def("get_cudnn_version", &GetCudnnRuntimeVersion)
      .def("get_device_compute_capability", &GetDeviceComputeCapability)
      .def("get_num_compute_streams", &nvte_get_num_compute_streams)
      .def("get_cublasLt_version", &cublasLtGetVersion);

  // Workspace size queries.
  m.def("get_dact_dbias_quantize_workspace_sizes", &GetDActDBiasQuantizeWorkspaceSizes)
      .def("get_dbias_quantize_workspace_sizes", &GetDBiasQuantizeWorkspaceSizes)
      .def("get_norm_fwd_workspace_sizes", &GetNormForwardWorkspaceSizes)
      .def("get_norm_bwd_workspace_sizes", &GetNormBackwardWorkspaceSizes)
      .def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes)
      .def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes);

  // Remaining helpers.
  // NOTE(review): the Python-facing name below reads "non_nt" while the bound function
  // reads "non_tn". Kept unchanged so existing callers keep working -- confirm whether the
  // spelling is intentional.
  m.def("nvte_get_qkv_format", &nvte_get_qkv_format)
      .def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported)
      .def("initialize_cgemm_communicator", &InitializeCgemmCommunicator)
      .def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams);

  // Data types.
  py::enum_<DType>(m, "DType", local_only)
      .value("kByte", DType::kByte)
      .value("kInt32", DType::kInt32)
      .value("kInt64", DType::kInt64)
      .value("kFloat32", DType::kFloat32)
      .value("kFloat16", DType::kFloat16)
      .value("kBFloat16", DType::kBFloat16)
      .value("kFloat8E4M3", DType::kFloat8E4M3)
      .value("kFloat8E5M2", DType::kFloat8E5M2)
      .value("kFloat8E8M0", DType::kFloat8E8M0)
      .value("kFloat4E2M1", DType::kFloat4E2M1);

  // Attention: bias, mask, QKV layout/format and softmax variants.
  py::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", local_only)
      .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS)
      .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS)
      .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);

  py::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", local_only)
      .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK)
      .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK)
      .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK)
      .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
      .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)
      .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK",
             NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK);

  py::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", local_only)
      .value("NVTE_BS3HD", NVTE_QKV_Layout::NVTE_BS3HD)
      .value("NVTE_BSHD_BS2HD", NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
      .value("NVTE_BSHD_BSHD_BSHD", NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)
      .value("NVTE_T3HD", NVTE_QKV_Layout::NVTE_T3HD)
      .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD)
      .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD);

  py::enum_<NVTE_QKV_Format>(m, "NVTE_QKV_Format", local_only)
      .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD)
      .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD)
      .value("NVTE_THD", NVTE_QKV_Format::NVTE_THD);

  py::enum_<NVTE_Softmax_Type>(m, "NVTE_Softmax_Type", local_only)
      .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)
      .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX)
      .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX);

  // Activation functions (entries also exported to module scope).
  py::enum_<NVTE_Activation_Type>(m, "NVTE_Activation_Type", local_only)
      .value("GELU", NVTE_Activation_Type::GELU)
      .value("GEGLU", NVTE_Activation_Type::GEGLU)
      .value("SILU", NVTE_Activation_Type::SILU)
      .value("SWIGLU", NVTE_Activation_Type::SWIGLU)
      .value("RELU", NVTE_Activation_Type::RELU)
      .value("REGLU", NVTE_Activation_Type::REGLU)
      .value("QGELU", NVTE_Activation_Type::QGELU)
      .value("QGEGLU", NVTE_Activation_Type::QGEGLU)
      .value("SRELU", NVTE_Activation_Type::SRELU)
      .value("SREGLU", NVTE_Activation_Type::SREGLU)
      .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU)
      .export_values();

  // Fused attention backends.
  py::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", local_only)
      .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend)
      .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)
      .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)
      .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8);

  // Normalization kinds (entries also exported to module scope).
  py::enum_<NVTE_Norm_Type>(m, "NVTE_Norm_Type", local_only)
      .value("LayerNorm", NVTE_Norm_Type::LayerNorm)
      .value("RMSNorm", NVTE_Norm_Type::RMSNorm)
      .export_values();

  // JAX-side scaling, quantization layout and collective enums
  // (entries also exported to module scope).
  py::enum_<JAXX_Scaling_Mode>(m, "JAXX_Scaling_Mode", local_only)
      .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
      .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
      .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
      .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
      .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING)
      .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING)
      .export_values();

  py::enum_<JAXX_Quantize_Layout>(m, "JAXX_Quantize_Layout", local_only)
      .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE)
      .value("COLWISE", JAXX_Quantize_Layout::COLWISE)
      .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE)
      .export_values();

  py::enum_<JAXX_Collective_Op>(m, "JAXX_Collective_Op", local_only)
      .value("NONE", JAXX_Collective_Op::NONE)
      .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER)
      .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER)
      .export_values();
}
} // namespace jax
} // namespace transformer_engine