
Commit 1449203

feat: Add RMSNorm op in cambricon backend.
1 parent 71fc388 commit 1449203

19 files changed: 600 additions & 23 deletions


src/CMakeLists.txt

Lines changed: 38 additions & 1 deletion
@@ -104,11 +104,48 @@ if(WITH_METAX)
 endif()
 
 if(WITH_CAMBRICON)
-  target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1)
+  file(GLOB_RECURSE CAMBRICON_MLU_SOURCES CONFIGURE_DEPENDS "cambricon/*/*.mlu")
+  find_program(CNCC_COMPILER cncc HINTS "${NEUWARE_HOME}/bin" "$ENV{NEUWARE_HOME}/bin" /usr/local/neuware/bin)
+  if(CNCC_COMPILER)
+    message(STATUS "Found cncc: ${CNCC_COMPILER}")
+    set(MLU_COMPILE_OPTS
+      -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread
+      -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include
+      -idirafter /usr/local/neuware/lib/clang/11.1.0/include
+    )
+    function(compile_mlu_file src_file)
+      get_filename_component(name ${src_file} NAME_WE)
+      get_filename_component(path ${src_file} DIRECTORY)
+      set(out_file "${CMAKE_CURRENT_BINARY_DIR}/${path}/${name}.o")
+      file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${path}")
+      add_custom_command(OUTPUT ${out_file}
+        COMMAND ${CNCC_COMPILER} ${MLU_COMPILE_OPTS} -c ${src_file} -o ${out_file}
+        DEPENDS ${src_file}
+        COMMENT "Building MLU kernel: ${src_file}"
+      )
+      set_property(DIRECTORY APPEND PROPERTY CAMBRICON_OBJECTS ${out_file})
+    endfunction()
+    foreach(src ${CAMBRICON_MLU_SOURCES})
+      compile_mlu_file(${src})
+    endforeach()
+    get_directory_property(CAMBRICON_OBJECT_FILES CAMBRICON_OBJECTS)
+    if(CAMBRICON_OBJECT_FILES)
+      target_sources(infiniops PRIVATE ${CAMBRICON_OBJECT_FILES})
+    endif()
+  else()
+    message(WARNING "cncc compiler not found. MLU kernels will not be compiled.")
+  endif()
+  target_compile_definitions(infiniops PRIVATE WITH_CAMBRICON=1)
 
   target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include")
   target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB})
 
+  if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+    target_compile_options(infiniops PUBLIC
+      "$<$<COMPILE_LANGUAGE:CXX>:SHELL:-idirafter /usr/local/neuware/lib/clang/11.1.0/include>"
+    )
+  endif()
+
   list(APPEND DEVICE_LIST "cambricon")
 endif()
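The block above teaches CMake to compile every cambricon/*/*.mlu source with cncc for the mtp_592 architecture and to link the resulting objects into infiniops, falling back to a warning when cncc is absent. For orientation, a minimal sketch of the kind of translation unit that rule consumes (hypothetical path and kernel, not part of this commit; __mlu_global__ is the BANG entry-point qualifier cncc recognizes):

    // Hypothetical kernel source of the kind the glob picks up, e.g.
    // src/cambricon/<op>/<op>.mlu -- name and body are placeholders,
    // not files from this commit.
    __mlu_global__ void ExampleCopyKernel(float* dst, const float* src,
                                          int n) {
      for (int i = 0; i < n; ++i) {
        dst[i] = src[i];
      }
    }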

src/base/rms_norm.h

Lines changed: 6 additions & 4 deletions
@@ -12,15 +12,17 @@ namespace infini::ops {
 class RmsNorm : public Operator<RmsNorm> {
  public:
   RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out)
-      : eps_{eps},
+      : input_shape_{input.shape()},
         out_shape_{out.shape()},
-        input_shape_{input.shape()},
-        out_strides_{out.strides()},
         input_strides_{input.strides()},
+        out_strides_{out.strides()},
+        eps_{eps},
         dim_{out.size(-1)},
         ndim_{out.ndim()},
         batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)},
-        nhead_{ndim_ == 2 ? 1 : out.size(-2)} {}
+        nhead_{ndim_ == 2 ? 1 : out.size(-2)} {
+    assert(input.dtype() == out.dtype());
+  }
 
   RmsNorm(const Tensor input, const Tensor weight, Tensor out)
       : RmsNorm{input, weight, 1e-6f, out} {}
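Besides adding the dtype assertion, the diff reorders the initializer list to match the members' declaration order; C++ initializes members in declaration order regardless of how the list is written, so the new order is presumably meant to keep -Wreorder quiet under warnings-as-errors builds. For reference, the operator this class parameterizes is the standard RMSNorm over the last axis of size d = dim_, with weight w and epsilon eps_ (defaulted to 1e-6f by the delegating constructor):

    \mathrm{RMSNorm}(x)_i = w_i \cdot \frac{x_i}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_j^2 + \varepsilon}}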

src/cambricon/cast.h

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+#ifndef INFINI_OPS_COMMON_CAMBRICON_CAST_H_
+#define INFINI_OPS_COMMON_CAMBRICON_CAST_H_
+
+#include "bang_fp16.h"
+#include "bang_bf16.h"
+
+#include "data_type.h"
+
+namespace infini::ops {
+
+namespace detail {
+
+template <typename T>
+using PureType = std::remove_cv_t<std::remove_reference_t<T>>;
+
+template <typename T>
+__host__ __device__ constexpr float ToFloatHelper(T&& x) {
+  using PureSrc = PureType<T>;
+  if constexpr (IsBFloat16<PureSrc>) {
+    return __bfloat162float__(x);
+  } else if constexpr (IsFP16<PureSrc>) {
+    return __half2float(x);
+  } else {
+    return static_cast<float>(std::forward<T>(x));
+  }
+}
+
+template <typename Dst>
+__host__ __device__ constexpr Dst FromFloatHelper(float f) {
+  using PureDst = PureType<Dst>;
+  if constexpr (IsBFloat16<PureDst>) {
+    return __float2bfloat16__(f);
+  } else if constexpr (IsFP16<PureDst>) {
+    return __float2half__(f);
+  } else {
+    return static_cast<Dst>(f);
+  }
+}
+
+// Priority tags for overload resolution.
+struct PriorityLow {};
+
+struct PriorityHigh : PriorityLow {};
+
+// Fallback: lowest priority. This always matches if nothing else does.
+template <typename Dst, typename Src>
+__host__ __device__ constexpr Dst HardwareCast(Src&& x, PriorityLow) {
+  return FromFloatHelper<Dst>(ToFloatHelper(std::forward<Src>(x)));
+}
+
+// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`.
+#define DEFINE_DIRECT_CAST(INTRINSIC, ...)                            \
+  template <typename Dst, typename Src>                               \
+  __host__ __device__ auto HardwareCast(Src x, PriorityHigh)          \
+      -> std::enable_if_t<(__VA_ARGS__),                              \
+                          decltype(INTRINSIC(std::declval<Src>()))> { \
+    return INTRINSIC(x);                                              \
+  }
+
+DEFINE_DIRECT_CAST(
+    __bfloat162int_rz__,
+    std::is_same_v<PureType<Dst>, int> && IsBFloat16<PureType<Src>>)
+DEFINE_DIRECT_CAST(
+    __bfloat162short_rz__,
+    std::is_same_v<PureType<Dst>, short> && IsBFloat16<PureType<Src>>)
+DEFINE_DIRECT_CAST(
+    __int2bfloat16_rn__,
+    IsBFloat16<PureType<Dst>> && std::is_same_v<PureType<Src>, int>)
+DEFINE_DIRECT_CAST(__int2half_rn__,
+                   IsFP16<PureType<Dst>> && std::is_same_v<PureType<Src>, int>)
+DEFINE_DIRECT_CAST(
+    __float2bfloat16__,
+    IsBFloat16<PureType<Dst>> && std::is_same_v<PureType<Src>, double>)
+DEFINE_DIRECT_CAST(
+    __float2half__,
+    IsFP16<PureType<Dst>> && std::is_same_v<PureType<Src>, double>)
+DEFINE_DIRECT_CAST(__half, IsFP16<PureType<Dst>> && IsBFloat16<PureType<Src>>)
+#undef DEFINE_DIRECT_CAST
+
+}  // namespace detail
+
+template <typename Dst, typename Src>
+__host__ __device__ Dst Cast(Src&& x) {
+  static_assert(!std::is_reference_v<Dst>,
+                "`Cast` cannot return reference types");
+
+  using PureSrc = std::remove_cv_t<std::remove_reference_t<Src>>;
+  using PureDst = std::remove_cv_t<std::remove_reference_t<Dst>>;
+
+  if constexpr (std::is_same_v<PureSrc, PureDst>) {
+    return std::forward<Src>(x);
+  } else {
+    return detail::HardwareCast<PureDst>(std::forward<Src>(x),
+                                         detail::PriorityHigh{});
+  }
+}
+
+}  // namespace infini::ops
+
+#endif
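cast.h dispatches with the priority-tag idiom: PriorityHigh derives from PriorityLow, so whenever a DEFINE_DIRECT_CAST overload's enable_if condition holds, its exact PriorityHigh parameter beats the derived-to-base conversion the fallback needs; otherwise resolution silently slides down to the generic float round-trip. Callers just write Cast<Dst>(x) and never see the tags. A self-contained sketch of the same idiom in plain standard C++ (names illustrative, not from the commit):

    #include <iostream>
    #include <type_traits>

    struct PriorityLow {};
    struct PriorityHigh : PriorityLow {};

    // Fallback: viable for every T, but needs a derived-to-base
    // conversion on the tag, so it loses ties.
    template <typename T>
    int Convert(T x, PriorityLow) {
      return static_cast<int>(x);  // truncate
    }

    // Preferred: exact tag match, SFINAE'd out unless T is floating point.
    template <typename T>
    auto Convert(T x, PriorityHigh)
        -> std::enable_if_t<std::is_floating_point_v<T>, int> {
      return static_cast<int>(x + 0.5);  // round half up (non-negative x)
    }

    int main() {
      std::cout << Convert(3.7, PriorityHigh{}) << "\n";  // 4: preferred path
      std::cout << Convert(7, PriorityHigh{}) << "\n";    // 7: falls back
      return 0;
    }

The same mechanism lets new intrinsic casts be added without touching the fallback: each DEFINE_DIRECT_CAST instantiation is simply one more PriorityHigh candidate.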

src/cambricon/common.h

Lines changed: 54 additions & 0 deletions
@@ -2,24 +2,78 @@
 #define INFINI_OPS_CAMBRICON_COMMON_H_
 
 #include <cnnl.h>
+#include <cnrt.h>
 
 #include "data_type.h"
+#include "device.h"
+
+#define NRAM_MAX_SIZE (1024 * 240)
+
+#ifdef __BANG__
+
+namespace infini::ops::reduce {
+
+constexpr int batch_size = 128 / sizeof(float);
+
+__mlu_func__ void SumInternal(float *dst, float *src, int max_batch) {
+  const int width = max_batch / batch_size;
+
+  if (width >= 4) {
+    __bang_sumpool(
+        dst, src,
+        batch_size, 1, width,
+        1, width, 1, 1);
+    __bang_reduce_sum(dst, dst, batch_size);
+  } else {
+    float sum = 0.0f;
+    for (int i = 0; i < max_batch; ++i) {
+      sum += src[i];
+    }
+    dst[0] = sum;
+  }
+}
+
+}  // namespace infini::ops::reduce
+
+#endif  // __BANG__
 
 namespace infini::ops::cnnl_utils {
 
 inline cnnlDataType_t GetDataType(DataType dtype) {
   switch (dtype) {
+    case DataType::kInt8:
+      return CNNL_DTYPE_INT8;
+    case DataType::kUInt8:
+      return CNNL_DTYPE_UINT8;
     case DataType::kInt32:
       return CNNL_DTYPE_INT32;
+    case DataType::kInt64:
+      return CNNL_DTYPE_INT64;
     case DataType::kFloat16:
       return CNNL_DTYPE_HALF;
     case DataType::kFloat32:
       return CNNL_DTYPE_FLOAT;
+    case DataType::kBFloat16:
+      return CNNL_DTYPE_BFLOAT16;
+    case DataType::kFloat64:
+      return CNNL_DTYPE_DOUBLE;
     default:
       return CNNL_DTYPE_INVALID;
   }
 }
 
 }  // namespace infini::ops::cnnl_utils
 
+namespace infini::ops::cnrt_utils {
+
+inline void GetLaunchConfig(const Device& device, int* core_per_cluster,
+                            int* cluster_count) {
+  int device_id = device.index();
+  cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id);
+  cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster,
+                         device_id);
+}
+
+}  // namespace infini::ops::cnrt_utils
+
 #endif
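Two notes on the numbers above: NRAM_MAX_SIZE budgets 1024 * 240 = 240 KB of per-core NRAM scratch, and batch_size is 128 / sizeof(float) = 32. When width = max_batch / batch_size is at least 4, SumInternal treats src as a batch_size x width block: __bang_sumpool with a 1 x width window collapses the width dimension into 32 partial sums, and __bang_reduce_sum folds those so that dst[0] holds the total; shorter inputs take the scalar loop. A host-side reference of the intended result (illustrative, not from the commit):

    // Scalar reference for reduce::SumInternal: whichever device path
    // runs, dst[0] should end up equal to the sum of src[0..max_batch).
    void SumInternalRef(float* dst, const float* src, int max_batch) {
      float sum = 0.0f;
      for (int i = 0; i < max_batch; ++i) {
        sum += src[i];
      }
      dst[0] = sum;
    }

GetLaunchConfig wraps the two cnrtDeviceGetAttribute queries a kernel launcher needs. A hedged sketch of how a caller might turn them into a task shape; the helper name and the one-task-per-core policy are assumptions, not this commit's launcher:

    #include <cnrt.h>

    #include "common.h"  // infini::ops::cnrt_utils::GetLaunchConfig

    // Hypothetical helper (not from this commit): one task per MLU core,
    // cores along x, clusters along y.
    cnrtDim3_t MakeFullDeviceDim(const infini::ops::Device& device) {
      int core_per_cluster = 0;
      int cluster_count = 0;
      infini::ops::cnrt_utils::GetLaunchConfig(device, &core_per_cluster,
                                               &cluster_count);
      cnrtDim3_t dim;
      dim.x = static_cast<unsigned int>(core_per_cluster);
      dim.y = static_cast<unsigned int>(cluster_count);
      dim.z = 1;
      return dim;
    }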
