diff --git a/backends/metax_gpu/cinn/compiler/compiler.cc b/backends/metax_gpu/cinn/compiler/compiler.cc index b65f73e6e4..ea128e09e6 100644 --- a/backends/metax_gpu/cinn/compiler/compiler.cc +++ b/backends/metax_gpu/cinn/compiler/compiler.cc @@ -59,6 +59,37 @@ typedef long long int64_t; // Compatible with __half references in CINN-generated code typedef __half float16; +// BFloat16 software type for CINN-generated code +struct bfloat16 { + unsigned short x; + __device__ bfloat16() {} + __device__ explicit bfloat16(float val) { + unsigned int ival = *(unsigned int*)&val; + // Round to nearest even + unsigned int lsb = (ival >> 16) & 1; + unsigned int rounding_bias = 0x7fff + lsb; + ival += rounding_bias; + x = (unsigned short)(ival >> 16); + } + __device__ explicit operator float() const { + unsigned int val = ((unsigned int)x) << 16; + return *(float*)&val; + } + __device__ bfloat16 operator+(const bfloat16& o) const { return bfloat16((float)*this + (float)o); } + __device__ bfloat16 operator-(const bfloat16& o) const { return bfloat16((float)*this - (float)o); } + __device__ bfloat16 operator*(const bfloat16& o) const { return bfloat16((float)*this * (float)o); } + __device__ bfloat16 operator/(const bfloat16& o) const { return bfloat16((float)*this / (float)o); } + __device__ bool operator<(const bfloat16& o) const { return (float)*this < (float)o; } + __device__ bool operator>(const bfloat16& o) const { return (float)*this > (float)o; } + __device__ bool operator==(const bfloat16& o) const { return x == o.x; } + __device__ bool operator!=(const bfloat16& o) const { return x != o.x; } + __device__ bool operator<=(const bfloat16& o) const { return (float)*this <= (float)o; } + __device__ bool operator>=(const bfloat16& o) const { return (float)*this >= (float)o; } +}; + +#define CINN_BF16_MIN bfloat16(-3.38953139e+38f) +#define CINN_BF16_MAX bfloat16(3.38953139e+38f) + #define CINN_UINT8_MIN 0 #define CINN_UINT8_MAX 255 #define CINN_INT16_MIN -32768 @@ -308,6 +339,51 @@ __device__ inline float16 FN_FP16(fma)(float16 a, float16 b, float16 c) { return __device__ inline float16 FN_FP16(max)(float16 a, float16 b) { return __hgt(a, b) ? a : b; } __device__ inline float16 FN_FP16(min)(float16 a, float16 b) { return __hlt(a, b) ? a : b; } +// =============================================================== +// BFloat16 Functions +// =============================================================== +#define FN_BF16(func) cinn_custom_device_##func##_bf16 + +__device__ inline bfloat16 FN_BF16(ceil)(bfloat16 x) { return bfloat16(ceilf((float)x)); } +__device__ inline bfloat16 FN_BF16(floor)(bfloat16 x) { return bfloat16(floorf((float)x)); } +__device__ inline bfloat16 FN_BF16(round)(bfloat16 x) { return bfloat16(roundf((float)x)); } +__device__ inline bfloat16 FN_BF16(trunc)(bfloat16 x) { return bfloat16(truncf((float)x)); } +__device__ inline bfloat16 FN_BF16(sin)(bfloat16 x) { return bfloat16(sinf((float)x)); } +__device__ inline bfloat16 FN_BF16(cos)(bfloat16 x) { return bfloat16(cosf((float)x)); } +__device__ inline bfloat16 FN_BF16(exp)(bfloat16 x) { return bfloat16(expf((float)x)); } +__device__ inline bfloat16 FN_BF16(log)(bfloat16 x) { return bfloat16(logf((float)x)); } +__device__ inline bfloat16 FN_BF16(log2)(bfloat16 x) { return bfloat16(log2f((float)x)); } +__device__ inline bfloat16 FN_BF16(log10)(bfloat16 x) { return bfloat16(log10f((float)x)); } +__device__ inline bfloat16 FN_BF16(sqrt)(bfloat16 x) { return bfloat16(sqrtf((float)x)); } +__device__ inline bfloat16 FN_BF16(rsqrt)(bfloat16 x) { return bfloat16(rsqrtf((float)x)); } +__device__ inline bfloat16 FN_BF16(cbrt)(bfloat16 x) { return bfloat16(cbrtf((float)x)); } +__device__ inline bfloat16 FN_BF16(abs)(bfloat16 x) { return bfloat16(fabsf((float)x)); } +__device__ inline bool FN_BF16(isnan)(bfloat16 x) { return isnan((float)x); } +__device__ inline bool FN_BF16(isinf)(bfloat16 x) { return isinf((float)x); } +__device__ inline bool FN_BF16(isfinite)(bfloat16 x) { return isfinite((float)x); } +__device__ inline bfloat16 FN_BF16(erf)(bfloat16 x) { return bfloat16(erff((float)x)); } +__device__ inline bfloat16 FN_BF16(tan)(bfloat16 x) { return bfloat16(tanf((float)x)); } +__device__ inline bfloat16 FN_BF16(sinh)(bfloat16 x) { return bfloat16(sinhf((float)x)); } +__device__ inline bfloat16 FN_BF16(cosh)(bfloat16 x) { return bfloat16(coshf((float)x)); } +__device__ inline bfloat16 FN_BF16(tanh)(bfloat16 x) { return bfloat16(tanhf((float)x)); } +__device__ inline bfloat16 FN_BF16(asin)(bfloat16 x) { return bfloat16(asinf((float)x)); } +__device__ inline bfloat16 FN_BF16(acos)(bfloat16 x) { return bfloat16(acosf((float)x)); } +__device__ inline bfloat16 FN_BF16(atan)(bfloat16 x) { return bfloat16(atanf((float)x)); } +__device__ inline bfloat16 FN_BF16(asinh)(bfloat16 x) { return bfloat16(asinhf((float)x)); } +__device__ inline bfloat16 FN_BF16(acosh)(bfloat16 x) { return bfloat16(acoshf((float)x)); } +__device__ inline bfloat16 FN_BF16(atanh)(bfloat16 x) { return bfloat16(atanhf((float)x)); } +__device__ inline bfloat16 FN_BF16(sigmoid)(bfloat16 x) { return bfloat16(1.0f / (1.0f + expf(-(float)x))); } +__device__ inline bfloat16 FN_BF16(mod)(bfloat16 a, bfloat16 b) { return bfloat16(fmodf((float)a, (float)b)); } +__device__ inline bfloat16 FN_BF16(pow)(bfloat16 a, bfloat16 b) { return bfloat16(powf((float)a, (float)b)); } +__device__ inline bfloat16 FN_BF16(add)(bfloat16 a, bfloat16 b) { return bfloat16((float)a + (float)b); } +__device__ inline bfloat16 FN_BF16(sub)(bfloat16 a, bfloat16 b) { return bfloat16((float)a - (float)b); } +__device__ inline bfloat16 FN_BF16(mul)(bfloat16 a, bfloat16 b) { return bfloat16((float)a * (float)b); } +__device__ inline bfloat16 FN_BF16(div)(bfloat16 a, bfloat16 b) { return bfloat16((float)a / (float)b); } +__device__ inline bfloat16 FN_BF16(neg)(bfloat16 a) { return bfloat16(-(float)a); } +__device__ inline bfloat16 FN_BF16(fma)(bfloat16 a, bfloat16 b, bfloat16 c) { return bfloat16(fmaf((float)a, (float)b, (float)c)); } +__device__ inline bfloat16 FN_BF16(max)(bfloat16 a, bfloat16 b) { return (float)a > (float)b ? a : b; } +__device__ inline bfloat16 FN_BF16(min)(bfloat16 a, bfloat16 b) { return (float)a < (float)b ? a : b; } + // =============================================================== // Warp Shuffle Functions (used by reduce operators) // =============================================================== @@ -348,6 +424,20 @@ __device__ inline __half FN_SHUFFLE(warp_shuffle_down_fp16)(__half v, int factor unsigned short res = (unsigned short)__shfl_down((int)val, factor); return __ushort_as_half(res); } + +// BFloat16 warp shuffle (bitcast through unsigned short) +__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_xor_bf16)(bfloat16 v, int factor) { + unsigned short res = (unsigned short)__shfl_xor((int)v.x, factor); + bfloat16 r; r.x = res; return r; +} +__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_up_bf16)(bfloat16 v, int factor) { + unsigned short res = (unsigned short)__shfl_up((int)v.x, factor); + bfloat16 r; r.x = res; return r; +} +__device__ inline bfloat16 FN_SHUFFLE(warp_shuffle_down_bf16)(bfloat16 v, int factor) { + unsigned short res = (unsigned short)__shfl_down((int)v.x, factor); + bfloat16 r; r.x = res; return r; +} } // extern "C" // =============================================================== @@ -459,11 +549,10 @@ __device__ inline float16 cinn_max_fp16(const float16 left, const float16 right) __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) { return __hlt(left, right) ? left : right; } // --- BF16 (BFloat16) --- -// [Note] If mxcc does not support __nv_bfloat16, this section should be commented out or produce an error -#if defined(__MACACC__) || defined(__CUDACC__) // Assuming support is available -// Placeholder: comment out the BF16 section if compilation errors occur -// __device__ inline __nv_bfloat16 cinn_sum_bf16(...) ... -#endif +__device__ inline bfloat16 cinn_sum_bf16(const bfloat16 left, const bfloat16 right) { return bfloat16((float)left + (float)right); } +__device__ inline bfloat16 cinn_prod_bf16(const bfloat16 left, const bfloat16 right) { return bfloat16((float)left * (float)right); } +__device__ inline bfloat16 cinn_max_bf16(const bfloat16 left, const bfloat16 right) { return (float)left > (float)right ? left : right; } +__device__ inline bfloat16 cinn_min_bf16(const bfloat16 left, const bfloat16 right) { return (float)left < (float)right ? left : right; } // =============================================================== // 3. Reduce Initialization Macros @@ -512,6 +601,13 @@ __device__ inline float16 cinn_min_fp16(const float16 left, const float16 right) MACRO(max_fp16, -65504.0, float16, ##__VA_ARGS__) \ MACRO(min_fp16, 65504.0, float16, ##__VA_ARGS__) +// BF16 initial values +#define EXPAND_REDUCE_BF16_MACRO(MACRO, ...) \ + MACRO(sum_bf16, 0.0, bfloat16, ##__VA_ARGS__) \ + MACRO(prod_bf16, 1.0, bfloat16, ##__VA_ARGS__) \ + MACRO(max_bf16, -3.38953139e+38f, bfloat16, ##__VA_ARGS__) \ + MACRO(min_bf16, 3.38953139e+38f, bfloat16, ##__VA_ARGS__) + // =============================================================== // 4. Warp Shuffle Wrappers (Using Legacy API & Full Down Strategy) @@ -559,6 +655,11 @@ __device__ inline float16 cinn_warp_shuffle_down_float16_wrapper(float16 v, int return __ushort_as_half((unsigned short)__shfl_down((int)val, factor)); } +__device__ inline bfloat16 cinn_warp_shuffle_down_bfloat16_wrapper(bfloat16 v, int factor) { + unsigned short res = (unsigned short)__shfl_down((int)v.x, factor); + bfloat16 r; r.x = res; return r; +} + __device__ inline welford_fp32 cinn_warp_shuffle_down_welford_fp32_wrapper(welford_fp32 v, int factor) { float m = __shfl_down(v.mean, factor); float m2 = __shfl_down(v.m2, factor); @@ -582,6 +683,11 @@ __device__ inline float16 cinn_warp_shuffle_idx_float16_wrapper(float16 v, int l return __ushort_as_half((unsigned short)__shfl((int)val, lane)); } +__device__ inline bfloat16 cinn_warp_shuffle_idx_bfloat16_wrapper(bfloat16 v, int lane) { + unsigned short res = (unsigned short)__shfl((int)v.x, lane); + bfloat16 r; r.x = res; return r; +} + __device__ inline double cinn_warp_shuffle_idx_double_wrapper(double v, int lane) { unsigned long long int val_u64 = *(unsigned long long int*)&v; int lo = __shfl((int)val_u64, lane); @@ -617,6 +723,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_FP64_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_BOOL_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) EXPAND_REDUCE_FP16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) +EXPAND_REDUCE_BF16_MACRO(CINN_WARP_SHUFFLE_INTERNAL_IMPL) // =============================================================== // 5. Block Reduce & Discrete Reduce & Grid Reduce @@ -673,6 +780,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_BLOCK_REDUCE_MACRO) EXPAND_REDUCE_FP64_MACRO(CINN_BLOCK_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_BLOCK_REDUCE_MACRO) EXPAND_REDUCE_FP16_MACRO(CINN_BLOCK_REDUCE_MACRO) +EXPAND_REDUCE_BF16_MACRO(CINN_BLOCK_REDUCE_MACRO) #define CINN_DISCRETE_REDUCE_IMPL(REDUCE_TYPE, value) \ int tid = threadIdx.y * blockDim.x + threadIdx.x; \ @@ -699,6 +807,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_FP64_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_DISCRETE_REDUCE_MACRO) EXPAND_REDUCE_FP16_MACRO(CINN_DISCRETE_REDUCE_MACRO) +EXPAND_REDUCE_BF16_MACRO(CINN_DISCRETE_REDUCE_MACRO) // =============================================================== // ArgMin/ArgMax Support (ArgIdx Structures & Combine Functions) @@ -800,6 +909,7 @@ EXPAND_REDUCE_FP32_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP64_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_BOOL_MACRO(CINN_GRID_REDUCE_MACRO) EXPAND_REDUCE_FP16_MACRO(CINN_GRID_REDUCE_MACRO) +EXPAND_REDUCE_BF16_MACRO(CINN_GRID_REDUCE_MACRO) __device__ inline bool cinn_grid_reduce_update_semaphore(int *semaphores) { __shared__ bool done; @@ -888,6 +998,7 @@ CINN_CUSTOM_DEVICE_LT_NUM(int16, int16_t) CINN_CUSTOM_DEVICE_LT_NUM(int32, int) CINN_CUSTOM_DEVICE_LT_NUM(int64, int64_t) CINN_CUSTOM_DEVICE_LT_NUM(fp16, float16) +CINN_CUSTOM_DEVICE_LT_NUM(bf16, bfloat16) #undef CINN_CUSTOM_DEVICE_LT_NUM #define CINN_CUSTOM_DEVICE_GT_NUM(TYPE_SUFFIX, TYPE) \ @@ -910,6 +1021,7 @@ CINN_CUSTOM_DEVICE_GT_NUM(int16, int16_t) CINN_CUSTOM_DEVICE_GT_NUM(int32, int) CINN_CUSTOM_DEVICE_GT_NUM(int64, int64_t) CINN_CUSTOM_DEVICE_GT_NUM(fp16, float16) +CINN_CUSTOM_DEVICE_GT_NUM(bf16, bfloat16) #undef CINN_CUSTOM_DEVICE_GT_NUM #define CINN_CUSTOM_DEVICE_INDEX_ADD(TYPE_SUFFIX, TYPE) \ @@ -939,6 +1051,7 @@ CINN_CUSTOM_DEVICE_INDEX_ADD(int64, int64_t) CINN_CUSTOM_DEVICE_INDEX_ADD(fp32, float) CINN_CUSTOM_DEVICE_INDEX_ADD(fp64, double) CINN_CUSTOM_DEVICE_INDEX_ADD(fp16, float16) +CINN_CUSTOM_DEVICE_INDEX_ADD(bf16, bfloat16) #undef CINN_CUSTOM_DEVICE_INDEX_ADD __device__ int cinn_custom_device_resize_bilinear(const int *buf,