@@ -15,6 +15,17 @@ namespace {
 using namespace infini_train::common::cuda;
 constexpr int kWarpSize = 32;
 
+// Aligned vector type for vectorized loads/stores (128-bit).
+template <typename T, int N>
+struct __align__(sizeof(T) * N) aligned_vector {
+    T val[N];
+};
+
+// Elements per vectorized load/store: 128-bit / sizeof(T).
+// float → 4, bf16/half → 8, double → 2.
+template <typename T>
+constexpr int kVecSize = 16 / sizeof(T);
+
 template <typename T, typename Func>
 __global__ void UnaryForwardKernel(T *output, Func fn, size_t num_elements, size_t offset, const T *input) {
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x + offset;
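For reference, a compile-time sketch of what these helpers resolve to (illustrative only, not part of the patch; it assumes the definitions above are in scope and pulls in <cuda_fp16.h> for __half):

    #include <cuda_fp16.h> // __half, assumed here only for the fp16 case

    static_assert(kVecSize<float> == 4, "a 128-bit pack holds 4 floats");
    static_assert(kVecSize<__half> == 8, "a 128-bit pack holds 8 halves");
    static_assert(kVecSize<double> == 2, "a 128-bit pack holds 2 doubles");
    static_assert(alignof(aligned_vector<float, kVecSize<float>>) == 16, "packs are 16-byte aligned");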
@@ -88,6 +99,60 @@ __global__ void BinaryBackwardKernelNoBroadcastFast(T *__restrict__ outA, T *__r
     }
 }
 
+// Vectorized fast path backward: no broadcast, contiguous.
+// Each thread processes VecSize elements using 128-bit loads/stores.
+template <typename T, int VecSize, typename FuncA, typename FuncB>
+__global__ void BinaryBackwardKernelNoBroadcastVectorized(T *__restrict__ outA, T *__restrict__ outB, FuncA fn_a,
+                                                          FuncB fn_b, size_t numel, const T *__restrict__ grad_out,
+                                                          const T *__restrict__ inA, const T *__restrict__ inB) {
+    using VecT = aligned_vector<T, VecSize>;
+    const size_t num_vecs = numel / VecSize;
+    const size_t grid_stride = static_cast<size_t>(gridDim.x) * blockDim.x;
+
+    for (size_t vid = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; vid < num_vecs;
+         vid += grid_stride) {
+        const size_t base = vid * VecSize;
+
+        // 128-bit vectorized loads
+        VecT g_vec = *reinterpret_cast<const VecT *>(&grad_out[base]);
+        VecT a_vec, b_vec;
+        if (inA) {
+            a_vec = *reinterpret_cast<const VecT *>(&inA[base]);
+        } else {
+#pragma unroll
+            for (int i = 0; i < VecSize; ++i) a_vec.val[i] = T(0);
+        }
+        if (inB) {
+            b_vec = *reinterpret_cast<const VecT *>(&inB[base]);
+        } else {
+#pragma unroll
+            for (int i = 0; i < VecSize; ++i) b_vec.val[i] = T(0);
+        }
+
+        // Element-wise computation
+        VecT outA_vec, outB_vec;
+#pragma unroll
+        for (int i = 0; i < VecSize; ++i) {
+            outA_vec.val[i] = Mul<T>(g_vec.val[i], fn_a(a_vec.val[i], b_vec.val[i]));
+            outB_vec.val[i] = Mul<T>(g_vec.val[i], fn_b(a_vec.val[i], b_vec.val[i]));
+        }
+
+        // 128-bit vectorized stores
+        *reinterpret_cast<VecT *>(&outA[base]) = outA_vec;
+        *reinterpret_cast<VecT *>(&outB[base]) = outB_vec;
+    }
+
+    // Handle tail elements (numel % VecSize != 0)
+    const size_t tail_start = num_vecs * VecSize;
+    for (size_t idx = tail_start + static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < numel;
+         idx += grid_stride) {
+        const T a = inA ? inA[idx] : T(0);
+        const T b = inB ? inB[idx] : T(0);
+        outA[idx] = Mul<T>(grad_out[idx], fn_a(a, b));
+        outB[idx] = Mul<T>(grad_out[idx], fn_b(a, b));
+    }
+}
+
 // Helper to choose optimal block size based on tensor size
 inline size_t ChooseBlockSize(size_t num_elements) {
     if (num_elements < 1024) return 64;
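As a worked example of the main/tail split in the vectorized kernel above: with T = float (VecSize = 4) and numel = 1003, num_vecs = 1003 / 4 = 250, so the grid-stride loop covers elements 0-999 with 128-bit accesses; tail_start = 250 * 4 = 1000, and the scalar tail loop handles elements 1000-1002. When numel is a multiple of VecSize, the tail loop does nothing.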
@@ -582,10 +647,28 @@ void LaunchBackward(FuncA fun_a, FuncB fun_b, const std::shared_ptr<Tensor> &out
             = [](const auto &...ts) { return std::make_tuple(static_cast<const T *>(ts ? ts->DataPtr() : nullptr)...); };
         auto [input_a_ptr, input_b_ptr] = extract_ptrs(inputs...);
 
-        dim3 block_dims(std::min(BLOCK_SIZE, static_cast<size_t>(1024)));
-        dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast<size_t>(65535)));
-        BinaryBackwardKernelNoBroadcastFast<<<grid_dims, block_dims, 0, stream>>>(
-            output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
+        constexpr int VecSize = kVecSize<T>;
+        // Use vectorized kernel if all pointers are 16-byte aligned and numel is large enough
+        const bool can_vectorize
+            = (num_elements >= static_cast<size_t>(VecSize))
+              && (reinterpret_cast<uintptr_t>(output_a_ptr) % (sizeof(T) * VecSize) == 0)
+              && (reinterpret_cast<uintptr_t>(output_b_ptr) % (sizeof(T) * VecSize) == 0)
+              && (reinterpret_cast<uintptr_t>(grad_output_ptr) % (sizeof(T) * VecSize) == 0)
+              && (!input_a_ptr || reinterpret_cast<uintptr_t>(input_a_ptr) % (sizeof(T) * VecSize) == 0)
+              && (!input_b_ptr || reinterpret_cast<uintptr_t>(input_b_ptr) % (sizeof(T) * VecSize) == 0);
+
+        if (can_vectorize) {
+            const size_t num_vecs = num_elements / VecSize;
+            dim3 block_dims(std::min(static_cast<size_t>(256), std::min(num_vecs, static_cast<size_t>(1024))));
+            dim3 grid_dims(std::min(CEIL_DIV(num_vecs, block_dims.x), static_cast<size_t>(65535)));
+            BinaryBackwardKernelNoBroadcastVectorized<T, VecSize><<<grid_dims, block_dims, 0, stream>>>(
+                output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
+        } else {
+            dim3 block_dims(std::min(BLOCK_SIZE, static_cast<size_t>(1024)));
+            dim3 grid_dims(std::min(CEIL_DIV(num_elements, block_dims.x), static_cast<size_t>(65535)));
+            BinaryBackwardKernelNoBroadcastFast<<<grid_dims, block_dims, 0, stream>>>(
+                output_a_ptr, output_b_ptr, fun_a, fun_b, num_elements, grad_output_ptr, input_a_ptr, input_b_ptr);
+        }
         return;
     }
 }
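The vectorization guard above is the same 16-byte check applied to each pointer, with null inputs allowed to skip it. A standalone sketch of that predicate (IsVecAligned is an illustrative name, not something the patch defines):

    #include <cstdint> // uintptr_t

    // True when ptr can be loaded/stored as aligned_vector<T, kVecSize<T>>,
    // i.e. its address is a multiple of 16 bytes. A tensor view that starts
    // mid-row can fail this even if the underlying allocation is aligned.
    template <typename T>
    inline bool IsVecAligned(const T *ptr) {
        return reinterpret_cast<uintptr_t>(ptr) % (sizeof(T) * kVecSize<T>) == 0;
    }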