-
Notifications
You must be signed in to change notification settings - Fork 152
Expand file tree
/
Copy pathreference.py
More file actions
194 lines (165 loc) · 7.45 KB
/
reference.py
File metadata and controls
194 lines (165 loc) · 7.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import torch
from task import input_t, output_t
from utils import make_match_reference
# Scaling factor vector size: number of elements along K that share one
# scale factor (sf_k = ceil(k / sf_vec_size) in generate_input).
sf_vec_size = 16
def ceil_div(a, b):
    """Integer ceiling division: floor-divide after shifting up by b - 1."""
    shifted = a + b - 1
    return shifted // b
# Helper that rearranges a scale-factor matrix into the blocked layout.
def to_blocked(input_matrix):
    """Rearrange a (rows, cols) scale-factor matrix into the blocked
    (..., 32, 16) tile layout and flatten it to 1-D.

    Callers are expected to pass rows as a multiple of 128 and cols as a
    multiple of 4 (no padding is performed here).
    """
    rows, cols = input_matrix.shape
    row_blocks = (rows + 127) // 128
    col_blocks = (cols + 3) // 4
    # Split into 128x4 tiles, bring the two block axes to the front.
    tiles = input_matrix.view(row_blocks, 128, col_blocks, 4)
    tiles = tiles.permute(0, 2, 1, 3)
    # Within each tile: 4 groups of 32 rows, swapped so each output row
    # holds 32 lanes of 16 bytes.
    tiles = tiles.reshape(-1, 4, 32, 4).transpose(1, 2)
    return tiles.reshape(-1, 32, 16).flatten()
def ref_kernel(
    data: input_t,
) -> output_t:
    """
    PyTorch reference implementation of NVFP4 block-scaled dual GEMM with silu activation,
    C = silu(A @ B1) * (A @ B2).
    """
    a, b1, b2, sfa, sfb1, sfb2, _, _, _, out_ref = data
    # Output layout is (m, n, l); l is the batch dimension.
    m, n, l = out_ref.shape

    # Accumulators for the two GEMM results, batch-last like the output.
    def _alloc():
        return torch.empty(
            (l, m, n), dtype=torch.float32, device="cuda"
        ).permute(1, 2, 0)

    ref1 = _alloc()
    ref2 = _alloc()
    for batch in range(l):
        # Scale factors must be in the blocked layout for torch._scaled_mm.
        scale_a = to_blocked(sfa[:, :, batch]).cuda()
        a_slice = a[:, :, batch]
        # Run both GEMMs for this batch slice: (m, k) @ (n, k).T -> (m, n).
        for weights, scales, dest in ((b1, sfb1, ref1), (b2, sfb2, ref2)):
            scale_b = to_blocked(scales[:, :, batch]).cuda()
            dest[:, :, batch] = torch._scaled_mm(
                a_slice,
                weights[:, :, batch].transpose(0, 1),
                scale_a,
                scale_b,
                bias=None,
                out_dtype=torch.float32,
            )
    # silu-gate the first GEMM with the second, cast down to fp16.
    return (torch.nn.functional.silu(ref1) * ref2).to(torch.float16)
def generate_input(
    m: int,
    n: int,
    k: int,
    l: int,
    seed: int,
):
    """
    Generate input tensors for NVFP4 block-scaled dual GEMM with silu activation,
    C = silu(A @ B1) * (A @ B2).
    Args:
        m: Number of rows in matrix A
        n: Number of columns in matrix B1 and B2
        k: Number of columns in A and rows of B1 and B2
        l: Batch size
        seed: Random seed for reproducibility
    Returns:
        Tuple of (a, b1, b2, scale_a, scale_b1, scale_b2, scale_a_permuted,
        scale_b1_permuted, scale_b2_permuted, c) where (sf_k = ceil(k / sf_vec_size)):
            a: [m, k//2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
               (each byte packs two FP4 values, hence k//2 storage columns)
            b1: [n, k//2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            b2: [n, k//2, l] - Input matrix in torch.float4_e2m1fn_x2 data type
            scale_a: [m, sf_k, l] - Scale factors in torch.float8_e4m3fn, on CUDA
            scale_b1: [n, sf_k, l] - Scale factors in torch.float8_e4m3fn, on CUDA
            scale_b2: [n, sf_k, l] - Scale factors in torch.float8_e4m3fn, on CUDA
            scale_a_permuted: [32, 4, rest_m, 4, rest_k, l] - Scale factors in the
                blocked layout used by the CuTe kernel
            scale_b1_permuted: [32, 4, rest_n, 4, rest_k, l] - same layout as above
            scale_b2_permuted: [32, 4, rest_n, 4, rest_k, l] - same layout as above
            c: [m, n, l] - Output matrix in torch.float16 data type
    """
    torch.manual_seed(seed)
    # Draw small int8 values, then reinterpret each byte as two packed FP4
    # values (float4_e2m1fn_x2). Tensors are built (l, rows, cols) and
    # permuted so the batch dimension comes last.
    a_ref = torch.randint(
        0, 4, (l, m, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b1_ref = torch.randint(
        0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    b2_ref = torch.randint(
        0, 4, (l, n, k // 2), dtype=torch.int8, device="cuda"
    ).permute(1, 2, 0)
    a_ref = a_ref.view(torch.float4_e2m1fn_x2)
    b1_ref = b1_ref.view(torch.float4_e2m1fn_x2)
    b2_ref = b2_ref.view(torch.float4_e2m1fn_x2)
    # Create float16 output tensor (contents are placeholder values;
    # ref_kernel recomputes the result).
    c_ref = torch.randn((l, m, n), dtype=torch.float16, device="cuda").permute(
        1, 2, 0
    )
    # Helper function to prepare the scale factor tensors for both the reference
    # kernel and the customized kernel. The customized data layout can be found in:
    # https://docs.nvidia.com/cuda/cublas/index.html?highlight=fp4#d-block-scaling-factors-layout
    def create_scale_factor_tensors(l, mn, sf_k):
        # Create the reference scale factor tensor (mn, sf_k, l); returned on CPU.
        ref_shape = (l, mn, sf_k)
        ref_permute_order = (1, 2, 0)
        # Init with small int8 values, then convert to float8_e4m3fn
        ref_f8_random_int = torch.randint(0, 3, ref_shape, dtype=torch.int8, device='cuda')
        ref_f8_torch_tensor = ref_f8_random_int.to(dtype=torch.float8_e4m3fn)
        # permute to match ref_permute_order
        ref_f8_torch_tensor_permuted = ref_f8_torch_tensor.permute(*ref_permute_order)
        # Atom of the blocked layout: 128 rows (as 32 x 4) by 4 scale columns.
        atom_m = (32, 4)
        atom_k = 4
        mma_shape = (
            l, # batch size
            ceil_div(mn, atom_m[0] * atom_m[1]), # rest_mn: number of 128-row blocks
            ceil_div(sf_k, atom_k), # rest_k: number of 4-column blocks
            atom_m[0],
            atom_m[1],
            atom_k,
        )
        # Reorder scale factor tensor to (32, 4, rest_m, 4, rest_k, l) layout
        # Which is needed by the CuTe customized kernel
        mma_permute_order = (3, 4, 1, 5, 2, 0)
        # NOTE(review): torch.empty is uninitialized, not random as the original
        # comment claimed; slots not written by the scatter below (padding when
        # mn/sf_k are not atom multiples) hold arbitrary bytes — confirm intended.
        rand_int_tensor = torch.empty(mma_shape, dtype=torch.int8, device='cuda')
        reordered_f8_torch_tensor = rand_int_tensor.to(dtype=torch.float8_e4m3fn)
        # Permute according to mma_permute_order
        reordered_f8_torch_tensor = reordered_f8_torch_tensor.permute(*mma_permute_order)
        # GPU-side vectorized reordering (replaces slow CPU nested loops)
        # Create index grids for all dimensions
        i_idx = torch.arange(mn, device='cuda')
        j_idx = torch.arange(sf_k, device='cuda')
        b_idx = torch.arange(l, device='cuda')
        # Create meshgrid for all combinations of (i, j, b)
        i_grid, j_grid, b_grid = torch.meshgrid(i_idx, j_idx, b_idx, indexing='ij')
        # Decompose each row index i into (block-of-128, row-in-32, group-of-32)
        # and each column index j into (block-of-4, column-in-4).
        mm = i_grid // (atom_m[0] * atom_m[1])
        mm32 = i_grid % atom_m[0]
        mm4 = (i_grid % 128) // atom_m[0]
        kk = j_grid // atom_k
        kk4 = j_grid % atom_k
        # Perform the reordering with advanced indexing (all on GPU)
        reordered_f8_torch_tensor[mm32, mm4, mm, kk4, kk, b_grid] = ref_f8_torch_tensor_permuted[i_grid, j_grid, b_grid]
        return ref_f8_torch_tensor_permuted.cpu(), reordered_f8_torch_tensor
    sf_k = ceil_div(k, sf_vec_size)
    sfa_ref_cpu, sfa_ref_permuted = create_scale_factor_tensors(l, m, sf_k)
    sfb1_ref_cpu, sfb1_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
    sfb2_ref_cpu, sfb2_ref_permuted = create_scale_factor_tensors(l, n, sf_k)
    # The "_cpu"-named reference scale factors are moved back to CUDA on return.
    return (a_ref, b1_ref, b2_ref, sfa_ref_cpu.to("cuda"), sfb1_ref_cpu.to("cuda"), sfb2_ref_cpu.to("cuda"), sfa_ref_permuted, sfb1_ref_permuted, sfb2_ref_permuted, c_ref)
check_implementation = make_match_reference(ref_kernel, rtol=1e-03, atol=1e-03)