Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 168 additions & 0 deletions src/infiniop/elementwise/metax/elementwise_metax.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,100 @@
#include "../../devices/metax/metax_common.h"
#include "../../devices/metax/metax_kernel_common.h"
#include "elementwise_metax_api.h"
#include <cstdint>

namespace op::elementwise::metax {
template <typename T>
__device__ __forceinline__ const T *typedInputPtr(const void *ptr) {
return reinterpret_cast<const T *>(ptr);
}

// Generic aligned N-element pack used for vectorized load/store.
template <typename T, int N>
struct alignas(sizeof(T) * N) Pack {
T val[N];
};

// Per-dtype vectorization info. Only floating-point types are enabled.
// Integer/bool/fp8 types keep using the scalar fallback automatically.
template <typename T>
struct VecInfo {
static constexpr bool enabled = false;
};

template <>
struct VecInfo<float> {
static constexpr bool enabled = true;
static constexpr int pack_size = 4;
using Type = Pack<float, pack_size>;
};

template <>
struct VecInfo<half> {
static constexpr bool enabled = true;
static constexpr int pack_size = 8;
using Type = Pack<half, pack_size>;
};

template <>
struct VecInfo<cuda_bfloat16> {
static constexpr bool enabled = true;
static constexpr int pack_size = 8;
using Type = Pack<cuda_bfloat16, pack_size>;
};

template <>
struct VecInfo<double> {
static constexpr bool enabled = true;
static constexpr int pack_size = 2;
using Type = Pack<double, pack_size>;
};

template <typename Tdata, typename VecT, size_t N>
__device__ __forceinline__ void loadInputVectors(VecT *in_vecs,
const Tdata *const *typed_inputs,
size_t base) {
#pragma unroll
for (size_t i = 0; i < N; ++i) {
in_vecs[i] = *reinterpret_cast<const VecT *>(typed_inputs[i] + base);
}
}

__device__ __forceinline__ size_t getOutputIndex(size_t idx, bool is_contiguous, size_t ndim,
const size_t *shape, const ptrdiff_t *strides) {
return is_contiguous ? idx : device::metax::indexToOffset(idx, ndim, shape, strides);
}

template <typename Tdata, size_t N>
bool canUseVecPath(const op::elementwise::ElementwiseInfo &info,
const void *output,
const std::vector<const void *> &inputs) {
if constexpr (!VecInfo<Tdata>::enabled) {
return false;
}
if (!info.isOutputContiguous()) {
return false;
}
const bool *contiguous = info.getInputContiguous();
const bool *broadcasted = info.getInputBroadcasted();
for (size_t i = 0; i < N; ++i) {
if (!contiguous[i] || broadcasted[i]) {
return false;
}
}
// Require 16-byte alignment of output and every input pointer.
constexpr std::uintptr_t mask = 0xF;
if ((reinterpret_cast<std::uintptr_t>(output) & mask) != 0) {
return false;
}
for (const void *p : inputs) {
if ((reinterpret_cast<std::uintptr_t>(p) & mask) != 0) {
return false;
}
}
return true;
}

struct InputIndexer {
size_t idx;
size_t ndim;
Expand All @@ -38,6 +120,47 @@ __device__ __forceinline__ void unpackInputsAndApply(F &&f, std::index_sequence<
f(std::integral_constant<size_t, Is>{}...);
}

template <size_t N, typename Op, typename Tdata, typename VecT, int V, typename... Args>
INFINIOP_METAX_KERNEL elementwiseVecKernel(
size_t output_size,
Tdata *__restrict__ output,
const void *const *__restrict__ inputs,
Args... args) {

const Tdata *const *typed_inputs = reinterpret_cast<const Tdata *const *>(inputs);
const size_t num_packs = output_size / V;
const size_t tail_start = num_packs * V;
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = blockDim.x * gridDim.x;

// Vectorized main loop over 16-byte packs
for (size_t pack_idx = tid; pack_idx < num_packs; pack_idx += stride) {
const size_t base = pack_idx * V;
VecT in_vecs[N];
loadInputVectors<Tdata, VecT, N>(in_vecs, typed_inputs, base);

VecT out_vec;
#pragma unroll
for (int k = 0; k < V; ++k) {
unpackInputsAndApply(
[&](auto... Is) {
out_vec.val[k] = Op{}(in_vecs[Is.value].val[k]..., std::forward<Args>(args)...);
},
std::make_index_sequence<N>{});
}
*reinterpret_cast<VecT *>(output + base) = out_vec;
}

// Scalar tail for remaining elements
for (size_t idx = tail_start + tid; idx < output_size; idx += stride) {
unpackInputsAndApply(
[&](auto... Is) {
output[idx] = Op{}(typed_inputs[Is.value][idx]..., std::forward<Args>(args)...);
},
std::make_index_sequence<N>{});
}
}

template <size_t N, typename Op, typename Tdata, typename... Args>
INFINIOP_METAX_KERNEL elementwiseKernel(
size_t output_size,
Expand Down Expand Up @@ -112,6 +235,16 @@ struct DeviceImpl::Opaque {
const std::vector<const void *> &inputs,
hcStream_t stream,
Args &&...args) {
if constexpr (VecInfo<Tdata>::enabled) {
if (canUseVecPath<Tdata, N>(info, output, inputs)) {
return launchElementwiseVecKernel<BLOCK_SIZE, N, Op, Tdata,
typename VecInfo<Tdata>::Type,
VecInfo<Tdata>::pack_size>(
info, workspace,
reinterpret_cast<Tdata *>(output), inputs, stream,
std::forward<Args>(args)...);
}
}
return launchElementwiseKernel<BLOCK_SIZE, N>(
info, workspace,
reinterpret_cast<Tdata *>(output), inputs,
Expand All @@ -136,6 +269,41 @@ struct DeviceImpl::Opaque {
}

private:
template <uint32_t BLOCK_SIZE, size_t N, typename Op, typename Tdata, typename VecT, int V, typename... Args>
infiniStatus_t launchElementwiseVecKernel(
const op::elementwise::ElementwiseInfo &info,
void *workspace,
Tdata *output,
const std::vector<const void *> &inputs,
hcStream_t stream,
Args &&...args) {

const auto output_size = info.getOutputSize();
if (output_size == 0) {
return INFINI_STATUS_SUCCESS;
}

CHECK_METAX(hcMemcpyAsync(workspace, inputs.data(), N * sizeof(*inputs.data()),
hcMemcpyHostToDevice, stream));
const void **d_inputs_arr = reinterpret_cast<const void **>(workspace);

dim3 blockDims(std::min(BLOCK_SIZE, static_cast<uint32_t>(internal->maxThreadsPerBlock())));
const size_t num_packs = output_size / V;
const size_t tail_size = output_size - num_packs * V;
const size_t grid_work = num_packs > 0 ? num_packs : tail_size;
dim3 gridDims(std::min(static_cast<uint32_t>(CEIL_DIV(grid_work, blockDims.x)),
static_cast<uint32_t>(internal->gridSizeX())));
if (gridDims.x == 0) {
gridDims.x = 1;
}

elementwiseVecKernel<N, Op, Tdata, VecT, V, Args...>
<<<gridDims, blockDims, 0, stream>>>(
output_size, output, d_inputs_arr, std::forward<Args>(args)...);

return INFINI_STATUS_SUCCESS;
}

template <size_t N>
infiniStatus_t infoToDevice(
const op::elementwise::ElementwiseInfo &info,
Expand Down