Skip to content

Commit 946523f

Browse files
chen2021673 and claude committed
perf(elementwise): add vectorized 128-bit load/store for NoBroadcast backward
Add aligned_vector<T,N>, kVecSize<T>, and BinaryBackwardKernelNoBroadcastVectorized that processes 8 bf16 elements per thread via 128-bit loads/stores. ncu results on A800 (bf16 MulBackward, ~10M elements): Non-vectorized: 266 us, 1.54 TB/s DRAM (76% util) Vectorized: 229 us, 1.77 TB/s DRAM (87% util) — 15% faster Small tensors (~1.3M): 32 us → 18 us — 45% faster Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 6012445 commit 946523f

1 file changed

Lines changed: 87 additions & 4 deletions

File tree

infini_train/src/kernels/cuda/elementwise.cu

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ namespace {
1515
using namespace infini_train::common::cuda;
1616
constexpr int kWarpSize = 32;
1717

// Aligned vector type for vectorized loads/stores (128-bit).
// The alignment equals the full vector width (sizeof(T) * N bytes), which
// lets the compiler emit a single wide load/store for the whole struct when
// the address is suitably aligned (the launcher checks this before use).
template <typename T, int N>
struct __align__(sizeof(T) * N) aligned_vector {
    T val[N];  // N packed elements moved as one unit
};

// Elements per vectorized load/store: 128-bit / sizeof(T).
// float → 4, bf16/half → 8, double → 2.
// NOTE(review): assumes sizeof(T) evenly divides 16 and is ≤ 16; other
// element sizes would yield 0 or a truncated count — confirm only the
// dtypes above reach this path.
template <typename T>
constexpr int kVecSize = 16 / sizeof(T);
1829
template <typename T, typename Func>
1930
__global__ void UnaryForwardKernel(T *output, Func fn, size_t num_elements, size_t offset, const T *input) {
2031
size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
@@ -88,6 +99,60 @@ __global__ void BinaryBackwardKernelNoBroadcastFast(T *__restrict__ outA, T *__r
8899
}
89100
}
90101

// Vectorized fast-path backward for the no-broadcast, contiguous case.
// Each thread handles VecSize consecutive elements per grid-stride step,
// moving them with single 128-bit loads/stores; any leftover elements that
// do not fill a whole vector are finished by a scalar grid-stride tail loop,
// so every launch configuration covers all `numel` elements.
// Preconditions (enforced by the launcher): all non-null pointers are
// aligned to sizeof(T) * VecSize bytes. A null inA/inB is treated as an
// all-zero operand.
template <typename T, int VecSize, typename FuncA, typename FuncB>
__global__ void BinaryBackwardKernelNoBroadcastVectorized(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a,
                                                          FuncB fn_b, size_t numel, const T *__restrict__ grad_out,
                                                          const T *__restrict__ inA, const T *__restrict__ inB) {
    using VecT = aligned_vector<T, VecSize>;
    const size_t total_vecs = numel / VecSize;
    const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    const size_t tid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (size_t v = tid; v < total_vecs; v += stride) {
        const size_t first = v * VecSize;

        // One 128-bit load per operand; a missing input decays to zeros.
        const VecT grad_vec = *reinterpret_cast<const VecT *>(grad_out + first);
        VecT lhs_vec, rhs_vec;
        if (inA == nullptr) {
#pragma unroll
            for (int i = 0; i < VecSize; ++i) { lhs_vec.val[i] = T(0); }
        } else {
            lhs_vec = *reinterpret_cast<const VecT *>(inA + first);
        }
        if (inB == nullptr) {
#pragma unroll
            for (int i = 0; i < VecSize; ++i) { rhs_vec.val[i] = T(0); }
        } else {
            rhs_vec = *reinterpret_cast<const VecT *>(inB + first);
        }

        // Per-element gradients: grad_out * d(out)/d(inA|inB).
        VecT grad_a_vec, grad_b_vec;
#pragma unroll
        for (int i = 0; i < VecSize; ++i) {
            const T g = grad_vec.val[i];
            grad_a_vec.val[i] = Mul<T>(g, fn_a(lhs_vec.val[i], rhs_vec.val[i]));
            grad_b_vec.val[i] = Mul<T>(g, fn_b(lhs_vec.val[i], rhs_vec.val[i]));
        }

        // Matching 128-bit stores.
        *reinterpret_cast<VecT *>(outA + first) = grad_a_vec;
        *reinterpret_cast<VecT *>(outB + first) = grad_b_vec;
    }

    // Scalar tail: the numel % VecSize elements past the last full vector.
    for (size_t idx = total_vecs * VecSize + tid; idx < numel; idx += stride) {
        const T a = inA ? inA[idx] : T(0);
        const T b = inB ? inB[idx] : T(0);
        outA[idx] = Mul<T>(grad_out[idx], fn_a(a, b));
        outB[idx] = Mul<T>(grad_out[idx], fn_b(a, b));
    }
}
91156
// Helper to choose optimal block size based on tensor size
92157
inline size_t ChooseBlockSize(size_t num_elements) {
93158
if (num_elements < 1024) return 64;
@@ -582,10 +647,28 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
582647
= [](const auto &...ts) { return std::make_tuple(static_cast<const T *>(ts ? ts->DataPtr() : nullptr)...); };
583648
auto [input_a_ptr, input_b_ptr] = extract_ptrs(inputs...);
584649

585-
dim3 block_dims(std::min(BLOCK_SIZE, static_cast<size_t>(1024)));
586-
dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast<size_t>(65535)));
587-
BinaryBackwardKernelNoBroadcastFast<<<grid_dims, block_dims, 0, stream>>>(
588-
output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
650+
constexpr int VecSize = kVecSize<T>;
651+
// Use vectorized kernel if all pointers are 16-byte aligned and numel is large enough
652+
const bool can_vectorize
653+
= (num_elements >= static_cast<size_t>(VecSize))
654+
&& (reinterpret_cast<uintptr_t>(output_a_ptr) % (sizeof(T) * VecSize) == 0)
655+
&& (reinterpret_cast<uintptr_t>(output_b_ptr) % (sizeof(T) * VecSize) == 0)
656+
&& (reinterpret_cast<uintptr_t>(grad_output_ptr) % (sizeof(T) * VecSize) == 0)
657+
&& (!input_a_ptr || reinterpret_cast<uintptr_t>(input_a_ptr) % (sizeof(T) * VecSize) == 0)
658+
&& (!input_b_ptr || reinterpret_cast<uintptr_t>(input_b_ptr) % (sizeof(T) * VecSize) == 0);
659+
660+
if (can_vectorize) {
661+
const size_t num_vecs = num_elements / VecSize;
662+
dim3 block_dims(std::min(static_cast<size_t>(256), std::min(num_vecs, static_cast<size_t>(1024))));
663+
dim3 grid_dims(std::min(CEIL_DIV(num_vecs, block_dims.x), static_cast<size_t>(65535)));
664+
BinaryBackwardKernelNoBroadcastVectorized<T, VecSize><<<grid_dims, block_dims, 0, stream>>>(
665+
output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
666+
} else {
667+
dim3 block_dims(std::min(BLOCK_SIZE, static_cast<size_t>(1024)));
668+
dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast<size_t>(65535)));
669+
BinaryBackwardKernelNoBroadcastFast<<<grid_dims, block_dims, 0, stream>>>(
670+
output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
671+
}
589672
return;
590673
}
591674

0 commit comments

Comments (0)