From 7418697ff07c94d289b446049f4002e128adeef5 Mon Sep 17 00:00:00 2001 From: LindseyMei <648816901@qq.com> Date: Tue, 30 Jun 2026 10:25:29 +0000 Subject: [PATCH] perf(metax): vectorize elementwise kernel for contiguous aligned tensors Add a 16-byte vector fast path to the shared MetaX elementwise template. When output and all inputs are contiguous, aligned, non-broadcasted, and share the same floating-point dtype, load/store packed values and apply the existing scalar Op functor per component. Falls back to the original scalar kernel for all other cases. Supported dtypes: float (float4), half (Pack<half, 8>), cuda_bfloat16 (Pack<cuda_bfloat16, 8>), double (double2). Integer/bool/fp8 and mixed-dtype ops continue to use the scalar path. Benchmark (silu, MetaX C500): - F32 16384x16384: ~59 Gelem/s -> ~177 Gelem/s (~3x) - F16 16384x16384: ~64 Gelem/s -> ~244 Gelem/s (~3.8x) - BF16 16384x16384: ~64 Gelem/s -> ~215 Gelem/s (~3.4x) Regression tests passed: silu, add, mul, reciprocal, gelu, swiglu, clip. Signed-off-by: LindseyMei <648816901@qq.com> --- .../elementwise/metax/elementwise_metax.h | 168 ++++++++++++++++++ 1 file changed, 168 insertions(+) diff --git a/src/infiniop/elementwise/metax/elementwise_metax.h b/src/infiniop/elementwise/metax/elementwise_metax.h index 084677ea7..52e4ddf54 100644 --- a/src/infiniop/elementwise/metax/elementwise_metax.h +++ b/src/infiniop/elementwise/metax/elementwise_metax.h @@ -5,6 +5,7 @@ #include "../../devices/metax/metax_common.h" #include "../../devices/metax/metax_kernel_common.h" #include "elementwise_metax_api.h" +#include namespace op::elementwise::metax { template @@ -12,11 +13,92 @@ __device__ __forceinline__ const T *typedInputPtr(const void *ptr) { return reinterpret_cast(ptr); } +// Generic aligned N-element pack used for vectorized load/store. +template +struct alignas(sizeof(T) * N) Pack { + T val[N]; +}; + +// Per-dtype vectorization info. Only floating-point types are enabled.
+// Integer/bool/fp8 types keep using the scalar fallback automatically. +template +struct VecInfo { + static constexpr bool enabled = false; +}; + +template <> +struct VecInfo { + static constexpr bool enabled = true; + static constexpr int pack_size = 4; + using Type = Pack; +}; + +template <> +struct VecInfo { + static constexpr bool enabled = true; + static constexpr int pack_size = 8; + using Type = Pack; +}; + +template <> +struct VecInfo { + static constexpr bool enabled = true; + static constexpr int pack_size = 8; + using Type = Pack; +}; + +template <> +struct VecInfo { + static constexpr bool enabled = true; + static constexpr int pack_size = 2; + using Type = Pack; +}; + +template +__device__ __forceinline__ void loadInputVectors(VecT *in_vecs, + const Tdata *const *typed_inputs, + size_t base) { +#pragma unroll + for (size_t i = 0; i < N; ++i) { + in_vecs[i] = *reinterpret_cast(typed_inputs[i] + base); + } +} + __device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim, const size_t *shape, const ptrdiff_t *strides) { return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides); } +template +bool canUseVecPath(const op::elementwise::ElementwiseInfo &info, + const void *output, + const std::vector &inputs) { + if constexpr (!VecInfo::enabled) { + return false; + } + if (!info.isOutputContiguous()) { + return false; + } + const bool *contiguous = info.getInputContiguous(); + const bool *broadcasted = info.getInputBroadcasted(); + for (size_t i = 0; i < N; ++i) { + if (!contiguous[i] || broadcasted[i]) { + return false; + } + } + // Require 16-byte alignment of output and every input pointer. 
+ constexpr std::uintptr_t mask = 0xF; + if ((reinterpret_cast(output) & mask) != 0) { + return false; + } + for (const void *p : inputs) { + if ((reinterpret_cast(p) & mask) != 0) { + return false; + } + } + return true; +} + struct InputIndexer { size_t idx; size_t ndim; @@ -38,6 +120,47 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence< f(std::integral_constant{}...); } +template +INFINIOP_METAX_KERNEL elementwiseVecKernel( + size_t output_size, + Tdata *__restrict__ output, + const void *const *__restrict__ inputs, + Args... args) { + + const Tdata *const *typed_inputs = reinterpret_cast(inputs); + const size_t num_packs = output_size / V; + const size_t tail_start = num_packs * V; + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + // Vectorized main loop over 16-byte packs + for (size_t pack_idx = tid; pack_idx < num_packs; pack_idx += stride) { + const size_t base = pack_idx * V; + VecT in_vecs[N]; + loadInputVectors(in_vecs, typed_inputs, base); + + VecT out_vec; +#pragma unroll + for (int k = 0; k < V; ++k) { + unpackInputsAndApply( + [&](auto... Is) { + out_vec.val[k] = Op{}(in_vecs[Is.value].val[k]..., std::forward(args)...); + }, + std::make_index_sequence{}); + } + *reinterpret_cast(output + base) = out_vec; + } + + // Scalar tail for remaining elements + for (size_t idx = tail_start + tid; idx < output_size; idx += stride) { + unpackInputsAndApply( + [&](auto... 
Is) { + output[idx] = Op{}(typed_inputs[Is.value][idx]..., std::forward(args)...); + }, + std::make_index_sequence{}); + } +} + template INFINIOP_METAX_KERNEL elementwiseKernel( size_t output_size, @@ -112,6 +235,16 @@ struct DeviceImpl::Opaque { const std::vector &inputs, hcStream_t stream, Args &&...args) { + if constexpr (VecInfo::enabled) { + if (canUseVecPath(info, output, inputs)) { + return launchElementwiseVecKernel::Type, + VecInfo::pack_size>( + info, workspace, + reinterpret_cast(output), inputs, stream, + std::forward(args)...); + } + } return launchElementwiseKernel( info, workspace, reinterpret_cast(output), inputs, @@ -136,6 +269,41 @@ struct DeviceImpl::Opaque { } private: + template + infiniStatus_t launchElementwiseVecKernel( + const op::elementwise::ElementwiseInfo &info, + void *workspace, + Tdata *output, + const std::vector &inputs, + hcStream_t stream, + Args &&...args) { + + const auto output_size = info.getOutputSize(); + if (output_size == 0) { + return INFINI_STATUS_SUCCESS; + } + + CHECK_METAX(hcMemcpyAsync(workspace, inputs.data(), N * sizeof(*inputs.data()), + hcMemcpyHostToDevice, stream)); + const void **d_inputs_arr = reinterpret_cast(workspace); + + dim3 blockDims(std::min(BLOCK_SIZE, static_cast(internal->maxThreadsPerBlock()))); + const size_t num_packs = output_size / V; + const size_t tail_size = output_size - num_packs * V; + const size_t grid_work = num_packs > 0 ? num_packs : tail_size; + dim3 gridDims(std::min(static_cast(CEIL_DIV(grid_work, blockDims.x)), + static_cast(internal->gridSizeX()))); + if (gridDims.x == 0) { + gridDims.x = 1; + } + + elementwiseVecKernel + <<>>( + output_size, output, d_inputs_arr, std::forward(args)...); + + return INFINI_STATUS_SUCCESS; + } + template infiniStatus_t infoToDevice( const op::elementwise::ElementwiseInfo &info,