Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated files
build/
build-*/
cmake-build-*/
generated/

# Prerequisites
Expand Down
74 changes: 71 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ set(PYBIND11_ENABLE_EXTRAS ON)
# Backend selection switches. Every backend is opt-in (default OFF).
option(WITH_CPU "Enable CPU backend" OFF)
option(WITH_NVIDIA "Enable CUDA backend" OFF)
option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF)
option(WITH_HYGON "Enable Hygon GPU backend" OFF)
option(WITH_METAX "Enable MetaX backend" OFF)
option(WITH_CAMBRICON "Enable Cambricon backend" OFF)
option(WITH_MOORE "Enable Moore backend" OFF)

# Build-behaviour switches: environment probing and Python binding generation.
option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)

# Fallback DTK installation prefix for Hygon. It is probed during
# auto-detection and used as `DTK_ROOT` when the `DTK_ROOT` environment
# variable does not supply one.
set(_DEFAULT_HYGON_DTK_ROOT "/opt/dtk")

if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detecting available devices...")

Expand All @@ -37,6 +40,13 @@ if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detected Iluvatar environment.")
endif()

# Hygon is auto-detected when either
#   * the `DTK_ROOT` environment variable is set to a non-empty value, or
#   * DTK's `nvcc` wrapper exists under the default DTK prefix (both known
#     layouts, `cuda/bin` and `cuda/cuda/bin`, are probed).
# An exported-but-empty `DTK_ROOT` is treated as unset, mirroring how
# `HYGON_ARCH` is read, so that it cannot switch the backend on by accident.
if((DEFINED ENV{DTK_ROOT} AND NOT "$ENV{DTK_ROOT}" STREQUAL "") OR
   EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/bin/nvcc" OR
   EXISTS "${_DEFAULT_HYGON_DTK_ROOT}/cuda/cuda/bin/nvcc")
  set(WITH_HYGON ON)
  message(STATUS "Auto-detected Hygon environment.")
endif()

if(DEFINED ENV{MACA_PATH})
set(WITH_METAX ON)
message(STATUS "Auto-detected MetaX environment from MACA_PATH")
Expand Down Expand Up @@ -77,14 +87,14 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

# Only one CUDA-like GPU backend can be enabled at a time.
set(_gpu_backend_count 0)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_METAX WITH_MOORE)
foreach(_gpu_backend WITH_NVIDIA WITH_ILUVATAR WITH_HYGON WITH_METAX WITH_MOORE)
if(${_gpu_backend})
math(EXPR _gpu_backend_count "${_gpu_backend_count} + 1")
endif()
endforeach()

if(_gpu_backend_count GREATER 1)
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.")
message(FATAL_ERROR "`WITH_NVIDIA`, `WITH_ILUVATAR`, `WITH_HYGON`, `WITH_METAX`, and `WITH_MOORE` are mutually exclusive. Build one GPU backend at a time.")
endif()

if(WITH_NVIDIA)
Expand All @@ -111,6 +121,64 @@ if(WITH_ILUVATAR)
find_package(CUDAToolkit REQUIRED)
endif()

# Hygon DCU backend: builds the CUDA-style sources with the `nvcc` wrapper
# shipped inside a DTK installation.
#
# Inputs:
#   ENV{DTK_ROOT}    DTK installation prefix (falls back to the built-in default).
#   ENV{HYGON_ARCH}  Default GPU architecture, if set and non-empty.
#   HYGON_ARCH       Cache variable; the architecture actually used.
if(WITH_HYGON)
  add_compile_definitions(WITH_HYGON=1)

  # Resolve the DTK prefix. The expansion is kept unquoted: when the
  # environment variable is unset, `set(DTK_ROOT)` unsets the normal variable,
  # so a `-DDTK_ROOT=...` cache entry is still honoured by the check below.
  set(DTK_ROOT $ENV{DTK_ROOT})
  if(NOT DTK_ROOT)
    set(DTK_ROOT "${_DEFAULT_HYGON_DTK_ROOT}")
  endif()
  if(NOT EXISTS "${DTK_ROOT}")
    message(FATAL_ERROR "`WITH_HYGON` is `ON` but `DTK_ROOT` (`${DTK_ROOT}`) does not exist.")
  endif()

  # Pick the default architecture: ENV{HYGON_ARCH} first, then the first
  # `gfx<digits>` token printed by `rocminfo`, and finally `gfx906`.
  set(_HYGON_ARCH_DEFAULT "gfx906")
  if(DEFINED ENV{HYGON_ARCH} AND NOT "$ENV{HYGON_ARCH}" STREQUAL "")
    set(_HYGON_ARCH_DEFAULT "$ENV{HYGON_ARCH}")
  else()
    find_program(HYGON_ROCMINFO_EXECUTABLE NAMES rocminfo HINTS "${DTK_ROOT}/bin")
    if(HYGON_ROCMINFO_EXECUTABLE)
      execute_process(
        COMMAND "${HYGON_ROCMINFO_EXECUTABLE}"
        OUTPUT_VARIABLE _HYGON_ROCMINFO_OUTPUT
        ERROR_QUIET
        OUTPUT_STRIP_TRAILING_WHITESPACE
      )
      string(REGEX MATCH "gfx[0-9]+" _HYGON_ARCH_AUTO "${_HYGON_ROCMINFO_OUTPUT}")
      if(_HYGON_ARCH_AUTO)
        set(_HYGON_ARCH_DEFAULT "${_HYGON_ARCH_AUTO}")
      endif()
    endif()
  endif()

  # Not forced: a value already in the cache (e.g. from `-DHYGON_ARCH=...`) wins.
  set(HYGON_ARCH "${_HYGON_ARCH_DEFAULT}" CACHE STRING "Hygon GPU architecture")

  # DTK ships its CUDA-compatible toolkit either directly under `cuda/` or one
  # level deeper under `cuda/cuda/`; prefer the nested layout when present.
  set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda")
  if(EXISTS "${DTK_ROOT}/cuda/cuda/bin/nvcc")
    set(HYGON_CUDA_ROOT "${DTK_ROOT}/cuda/cuda")
  endif()

  if(NOT EXISTS "${HYGON_CUDA_ROOT}/bin/nvcc")
    message(FATAL_ERROR "`WITH_HYGON` is `ON` but `${HYGON_CUDA_ROOT}/bin/nvcc` was not found. Checked `${DTK_ROOT}/cuda/bin/nvcc` and `${DTK_ROOT}/cuda/cuda/bin/nvcc`.")
  endif()

  set(CMAKE_CUDA_COMPILER "${HYGON_CUDA_ROOT}/bin/nvcc" CACHE FILEPATH "Hygon CUDA compiler (DTK nvcc)")
  set(CUDAToolkit_ROOT "${HYGON_CUDA_ROOT}" CACHE PATH "Hygon CUDA toolkit root")
  set(CMAKE_CUDA_ARCHITECTURES OFF CACHE STRING "Disable default CUDA arch flags for Hygon" FORCE)
  set(CMAKE_CUDA_FLAGS "-std=c++17 -fPIC -arch=${HYGON_ARCH} -Wno-return-type -Wno-error=unused-private-field" CACHE STRING "Hygon CUDA flags")

  # The line above is a no-op once `CMAKE_CUDA_FLAGS` is in the cache, so a
  # later `-DHYGON_ARCH=<other>` would leave a stale `-arch=` behind. Rewrite
  # only the `-arch=` token so that `HYGON_ARCH` stays the single source of
  # truth; every other (possibly user-edited) flag is left untouched.
  string(REGEX REPLACE "-arch=[^ ]*" "-arch=${HYGON_ARCH}"
    _hygon_synced_cuda_flags "$CACHE{CMAKE_CUDA_FLAGS}")
  if(NOT "${_hygon_synced_cuda_flags}" STREQUAL "$CACHE{CMAKE_CUDA_FLAGS}")
    set(CMAKE_CUDA_FLAGS "${_hygon_synced_cuda_flags}" CACHE STRING "Hygon CUDA flags" FORCE)
  endif()

  set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Hygon")

  # DTK's nvcc wrapper may invoke `nvcc` by name during compiler checks.
  set(ENV{PATH} "${HYGON_CUDA_ROOT}/bin:$ENV{PATH}")

  # Directory-scoped: applies to every target defined after this point.
  include_directories("${DTK_ROOT}/include")
  include_directories("${HYGON_CUDA_ROOT}/include")
  link_directories("${DTK_ROOT}/lib")
  link_directories("${HYGON_CUDA_ROOT}/lib64")

  message(STATUS "Hygon: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${HYGON_ARCH}, DTK root ${DTK_ROOT}")
  enable_language(CUDA)
  find_package(CUDAToolkit REQUIRED)
endif()

if(WITH_METAX)
add_compile_definitions(WITH_METAX=1)

Expand Down Expand Up @@ -179,7 +247,7 @@ if(WITH_CAMBRICON)
endif()

# If all other platforms are not enabled, CPU is enabled by default.
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE)
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_HYGON AND NOT WITH_METAX AND NOT WITH_MOORE)
add_compile_definitions(WITH_CPU=1)
endif()

Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,17 @@ For the `<OPTIONS>`:
|----------------------------------------|------------------------------------|:-:
| `-DWITH_CPU=[ON\|OFF]` | Compile the CPU implementation | n
| `-DWITH_NVIDIA=[ON\|OFF]` | Compile the NVIDIA implementation | n
| `-DWITH_ILUVATAR=[ON\|OFF]` | Compile the Iluvatar implementation| n
| `-DWITH_HYGON=[ON\|OFF]` | Compile the Hygon implementation | n
| `-DWITH_METAX=[ON\|OFF]` | Compile the MetaX implementation | n
| `-DGENERATE_PYTHON_BINDINGS=[ON\|OFF]` | Generate Python bindings | n

*Note: If no accelerator options are provided, `WITH_CPU` is enabled by default.*

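For Hygon builds, set the `DTK_ROOT` environment variable to the DTK
installation root if DTK is not installed at `/opt/dtk`. The DCU arch defaults
to the `HYGON_ARCH` environment variable if it is set, otherwise to the first
`gfx*` target reported by `rocminfo`, falling back to `gfx906`. You can
override it with `-DHYGON_ARCH=<arch>` when configuring CMake.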

## 🚀 Running Examples
After a successful build, the executables are located in the `build/examples` directory.

Expand Down
3 changes: 3 additions & 0 deletions examples/gemm/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
#if WITH_ILUVATAR
#include "iluvatar/gemm/cublas.h"
#endif
#if WITH_HYGON
#include "hygon/gemm/cublas.h"
#endif
#if WITH_METAX
#include "metax/gemm/mcblas.h"
#endif
Expand Down
9 changes: 9 additions & 0 deletions examples/runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@
#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kIluvatar
#elif WITH_HYGON
#include <cuda_runtime.h>
#define DEVICE_MALLOC cudaMalloc
#define DEVICE_FREE cudaFree
#define DEVICE_MEMCPY cudaMemcpy
#define DEVICE_MEMSET cudaMemset
#define DEVICE_MEMCPY_HOST_TO_DEVICE cudaMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST cudaMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kHygon
#elif WITH_METAX
#include <mcr/mc_runtime.h>
#define DEVICE_MALLOC mcMalloc
Expand Down
30 changes: 29 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,34 @@ if(WITH_ILUVATAR)
list(APPEND DEVICE_LIST "iluvatar")
endif()

# Hygon backend: compile the shared CUDA-style sources together with the
# Hygon-specific ones into `infiniops` and link the CUDA runtime and cuBLAS.
if(WITH_HYGON)
  enable_language(CUDA)
  find_package(CUDAToolkit REQUIRED)

  # Build the glob list `<dir>/*.<ext>` for both source trees, in the order
  # cuda/{cc,cpp,cu} followed by hygon/{cc,cpp,cu}.
  set(HYGON_PATTERNS "")
  foreach(_hygon_dir IN ITEMS cuda hygon)
    foreach(_hygon_ext IN ITEMS cc cpp cu)
      list(APPEND HYGON_PATTERNS "${_hygon_dir}/*.${_hygon_ext}")
    endforeach()
  endforeach()

  file(GLOB_RECURSE HYGON_SOURCES CONFIGURE_DEPENDS ${HYGON_PATTERNS})

  target_sources(infiniops PRIVATE ${HYGON_SOURCES})
  target_compile_definitions(infiniops PUBLIC WITH_HYGON=1)
  target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas)

  # Require C++17 for the CUDA translation units of this target.
  set_property(TARGET infiniops PROPERTY CUDA_STANDARD 17)
  set_property(TARGET infiniops PROPERTY CUDA_STANDARD_REQUIRED ON)

  list(APPEND DEVICE_LIST "hygon")
endif()

if(WITH_METAX)
set(METAX_PATTERNS
"cuda/*.cc"
Expand Down Expand Up @@ -154,7 +182,7 @@ if(GENERATE_PYTHON_BINDINGS)
set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")

# TODO: There might be a better solution.
if(WITH_NVIDIA OR WITH_ILUVATAR)
if(WITH_NVIDIA OR WITH_ILUVATAR OR WITH_HYGON)
set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA)
endif()

Expand Down
4 changes: 2 additions & 2 deletions src/common/cast.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#ifndef INFINI_OPS_COMMON_CAST_H_
#define INFINI_OPS_COMMON_CAST_H_

#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) || \
defined(WITH_MOORE)
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON) || \
defined(WITH_METAX) || defined(WITH_MOORE)
#include "common/cuda/cast.h"
#else
#include "common/cpu/cast.h"
Expand Down
4 changes: 1 addition & 3 deletions src/common/cuda/cast.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#ifndef INFINI_OPS_COMMON_CUDA_CAST_H_
#define INFINI_OPS_COMMON_CUDA_CAST_H_

#ifdef WITH_NVIDIA
#include <cuda_runtime.h>
#elif defined(WITH_ILUVATAR)
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
#include <cuda_runtime.h>
#elif defined(WITH_METAX)
#include <mcr/mc_runtime.h>
Expand Down
11 changes: 2 additions & 9 deletions src/common/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
#ifndef INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_
#define INFINI_OPS_COMMON_CUDA_KERNEL_COMMONS_H_

#ifdef WITH_NVIDIA
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
#elif defined(WITH_ILUVATAR)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
using cuda_bfloat16 = nv_bfloat16;
using cuda_bfloat162 = nv_bfloat162;
Expand Down Expand Up @@ -39,7 +32,7 @@ constexpr int CUDA_BLOCK_SIZE_512 = 512;
constexpr int CUDA_BLOCK_SIZE_1024 = 1024;
constexpr int CUDA_BLOCK_SIZE_2048 = 2048;

#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
// Cache `cudaDeviceProp` per device, initialized once at first access.
class DevicePropertyCache {
public:
Expand Down
16 changes: 10 additions & 6 deletions src/cuda/gemm/blas.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ class Blas : public Gemm {
// TODO: Check constraints.
}

~Blas() { Backend::blasDestroy(handle_); }
// Releases the BLAS handle through the backend. The call is skipped when
// `handle_` compares equal to `nullptr`, i.e. when no handle is held.
~Blas() {
if (handle_ != nullptr) {
Backend::blasDestroy(handle_);
}
}

Blas(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, Tensor c)
Expand Down Expand Up @@ -69,7 +73,6 @@ class Blas : public Gemm {
return &beta;
}

private:
auto GetOpA(int trans_a, int trans_b) const {
if (swap_a_and_b_) {
return (b_is_col_major_ == trans_b) ? Backend::BLAS_OP_T
Expand All @@ -88,13 +91,14 @@ class Blas : public Gemm {
: Backend::BLAS_OP_N;
}

bool a_is_col_major_{false};
bool swap_a_and_b_{false};

bool b_is_col_major_{false};
mutable typename Backend::blasHandle_t handle_{};

bool swap_a_and_b_{false};
private:
bool a_is_col_major_{false};

typename Backend::blasHandle_t handle_;
bool b_is_col_major_{false};
};

} // namespace infini::ops
Expand Down
7 changes: 2 additions & 5 deletions src/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
#include <cstring>
#include <string>

#ifdef WITH_NVIDIA
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#elif defined(WITH_ILUVATAR)
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#elif defined(WITH_METAX)
Expand Down Expand Up @@ -198,7 +195,7 @@ DEFINE_DATA_TYPE_MAPPING(kInt64, std::int64_t)
DEFINE_DATA_TYPE_MAPPING(kFloat32, float)
DEFINE_DATA_TYPE_MAPPING(kFloat64, double)

#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_HYGON)
DEFINE_DATA_TYPE_MAPPING(kFloat16, half)
DEFINE_DATA_TYPE_MAPPING(kBFloat16, __nv_bfloat16)
#elif defined(WITH_METAX)
Expand Down
42 changes: 42 additions & 0 deletions src/hygon/add/kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#ifndef INFINI_OPS_HYGON_ADD_KERNEL_H_
#define INFINI_OPS_HYGON_ADD_KERNEL_H_

#include <utility>

// clang-format off
#include <cuda_runtime.h>
// clang-format on

#include "cuda/add/kernel.h"

namespace infini::ops {

namespace add {

// Backend traits for the Hygon `Add` operator. This struct is passed as the
// template argument of `CudaAdd` and maps its runtime hooks onto the CUDA
// runtime API declared in <cuda_runtime.h>.
struct HygonBackend {
// Stream handle type used by this backend.
using stream_t = cudaStream_t;

// Allocation hook: perfect-forwards every argument to `cudaMalloc` and
// returns its result.
// NOTE(review): presumably a generic lambda rather than a function pointer
// because `cudaMalloc` is overloaded in <cuda_runtime.h> — confirm.
static constexpr auto malloc = [](auto&&... args) {
return cudaMalloc(std::forward<decltype(args)>(args)...);
};

// Copy hook: direct reference to `cudaMemcpy`.
static constexpr auto memcpy = cudaMemcpy;

// Deallocation hook: perfect-forwards every argument to `cudaFree` and
// returns its result.
static constexpr auto free = [](auto&&... args) {
return cudaFree(std::forward<decltype(args)>(args)...);
};

// Copy-direction constant for host-to-device transfers.
static constexpr auto memcpyH2D = cudaMemcpyHostToDevice;
};

} // namespace add

// `Add` operator for Hygon devices. It adds no behaviour of its own: it
// derives from the shared `CudaAdd` implementation instantiated with
// `add::HygonBackend`.
template <>
class Operator<Add, Device::Type::kHygon> : public CudaAdd<add::HygonBackend> {
public:
// Inherit every `CudaAdd` constructor unchanged.
using CudaAdd<add::HygonBackend>::CudaAdd;
};

} // namespace infini::ops

#endif
Loading