diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 93ceee3..9ba5093 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -210,24 +210,26 @@ def _parse_runtime_functions(runtime_header): _Function( return_type, name, - tuple(_parse_param(param) for param in params.split(", ") if param), + tuple( + _parse_param(param) + for param in re.split(r",\s*", params.strip()) + if param + ), ) for return_type, name, params in re.findall( - r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE + r"^(int|void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE ) ) -def _abort_statement(message): - return f""" assert(false && "{message}"); - std::abort();""" +def _unsupported_statement(): + return " return 1;" def _dispatch_cases(devices, statements): return "\n".join( f""" case {_DEVICE_TYPES[device]}: {{ {statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])} - return; }}""" for device in devices ) @@ -243,19 +245,21 @@ def _selector(function): return "current_device.type()" -def _runtime_arg(param): +def _runtime_arg(function, param): if param.type == "Device": - return f"{param.name}.index()" + if function.name in {"SetDevice", "GetDeviceResourceSnapshot"}: + return f"{param.name}.index()" + return None if param.type == "Device::Type": return None - if param.type == "MemcpyKind": + if param.type == "infinirtMemcpyKind": return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" return param.name def _runtime_args(function): - args = (_runtime_arg(param) for param in function.params) + args = (_runtime_arg(function, param) for param in function.params) return ", ".join(arg for arg in args if arg is not None) @@ -270,14 +274,18 @@ def _preconditions(function): if param.type.endswith("**") or param.name in required_pointer_names.get( function.name, set() ): - checks.append(f" assert({param.name} != nullptr);") + checks.append(f" if ({param.name} == nullptr) {{") + checks.append(" return 1;") + checks.append(" }") return "\n".join(checks) def _post_dispatch(function): if function.name == "SetDevice": - return "\n current_device = device;" + return ( + "\n if (rt_status == 0) {\n current_device = Device{current_device.type(), device};\n }" + ) return "" @@ -294,19 +302,23 @@ def _write_get_device(function, devices): device_param = function.params[0].name cases = _dispatch_cases( devices, - f""" int index = current_device.index(); - CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }}); - current_device = Device{{current_device.type(), index}}; - *{device_param} = current_device;""", + f""" int status = CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice({device_param}); }}); + if (status != 0) {{ + return status; + }} + current_device = Device{{current_device.type(), *{device_param}}}; + return 0;""", ) - return f"""void GetDevice(Device* {device_param}) {{ - assert({device_param} != nullptr); + return f"""{function.return_type} GetDevice(int* {device_param}) {{ + if ({device_param} == nullptr) {{ + return 1; + }} switch (current_device.type()) {{ {cases} default: -{_abort_statement("runtime device is not enabled")} +{_unsupported_statement()} }} }} """ @@ -318,7 +330,8 @@ def _write_dispatch_function(function, devices): cases = _dispatch_cases( devices, - f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""", + f""" int rt_status = CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)} + return rt_status;""", ) preconditions = _preconditions(function) if preconditions: @@ -328,25 +341,42 @@ def _write_dispatch_function(function, devices): {preconditions} switch ({_selector(function)}) {{ {cases} default: -{_abort_statement("runtime device is not enabled")} +{_unsupported_statement()} }} }} """ +def _runtime_header_for_device(source_root, device): + return source_root / _RUNTIME_HEADERS[device] + + +def _devices_for_function(function, devices, source_root): + enabled = [] + pattern = re.compile(r"\b" + re.escape(function.name) + r"\b") + for device in devices: + runtime_header = _runtime_header_for_device(source_root, device) + if runtime_header.exists() and pattern.search(runtime_header.read_text()): + enabled.append(device) + return enabled + + def _write_runtime_dispatch(source_path, runtime_header, devices): first_device_type = _DEVICE_TYPES[devices[0]] includes = ['#include "runtime.h"'] includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices) functions = _parse_runtime_functions(runtime_header) + source_root = pathlib.Path(runtime_header).parent dispatch_functions = "\n".join( - _write_dispatch_function(function, devices) for function in functions + _write_dispatch_function( + function, _devices_for_function(function, devices, source_root) + ) + for function in functions ) source_path.parent.mkdir(parents=True, exist_ok=True) source_path.write_text( f"""#include -#include #include #include @@ -358,35 +388,32 @@ def _write_runtime_dispatch(source_path, runtime_header, devices): thread_local Device current_device{{{first_device_type}, 0}}; template -void CheckCall(Func&& func) {{ +int CheckCall(Func&& func) {{ using ReturnType = decltype(std::forward(func)()); if constexpr (std::is_void_v) {{ std::forward(func)(); + return 0; }} else {{ ReturnType status = std::forward(func)(); - if (status != ReturnType{{}}) {{ - assert(false && "runtime call failed"); - std::abort(); - }} + return status == ReturnType{{}} ? 0 : 1; }} }} template -auto RuntimeMemcpyKind(MemcpyKind kind) {{ +auto RuntimeMemcpyKind(infinirtMemcpyKind kind) {{ switch (kind) {{ - case MemcpyKind::kHostToHost: + case infinirtMemcpyKind::kHostToHost: return Runtime::MemcpyHostToHost; - case MemcpyKind::kHostToDevice: + case infinirtMemcpyKind::kHostToDevice: return Runtime::MemcpyHostToDevice; - case MemcpyKind::kDeviceToHost: + case infinirtMemcpyKind::kDeviceToHost: return Runtime::MemcpyDeviceToHost; - case MemcpyKind::kDeviceToDevice: + case infinirtMemcpyKind::kDeviceToDevice: return Runtime::MemcpyDeviceToDevice; }} assert(false && "unsupported memcpy kind"); - std::abort(); return Runtime::MemcpyHostToHost; }} diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index bf5a81c..41fa7c7 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,9 +1,12 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ -#include +#include +#include +#include #include #include +#include #include "runtime.h" @@ -13,35 +16,161 @@ template <> struct Runtime : RuntimeBase> { static constexpr Device::Type kDeviceType = Device::Type::kCpu; - static void SetDevice(int index) { - if (index != 0) { - assert(false && "CPU device index must be 0"); - std::abort(); - } - } + static int SetDevice(int index) { return index == 0 ? 0 : 1; } - static void GetDevice(int* index) { - assert(index != nullptr); + static int GetDevice(int* index) { + if (index == nullptr) { + return 1; + } *index = 0; + return 0; } - static void GetDeviceCount(int* count) { - assert(count != nullptr); + static int GetDeviceCount(int* count) { + if (count == nullptr) { + return 1; + } *count = 1; + return 0; } - static void DeviceSynchronize() {} + static int DeviceSynchronize() { return 0; } - static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } + static int Malloc(void** ptr, std::size_t size) { + if (ptr == nullptr) { + return 1; + } + *ptr = std::malloc(size); + return size != 0 && *ptr == nullptr ? 1 : 0; + } - static void Free(void* ptr) { std::free(ptr); } + static int Free(void* ptr) { + std::free(ptr); + return 0; + } - static void Memcpy(void* dst, const void* src, std::size_t size, int) { + static int Memcpy(void* dst, const void* src, std::size_t size, int) { + if ((dst == nullptr || src == nullptr) && size != 0) { + return 1; + } std::memcpy(dst, src, size); + return 0; } - static void Memset(void* ptr, int value, std::size_t count) { + static int Memset(void* ptr, int value, std::size_t count) { + if (ptr == nullptr && count != 0) { + return 1; + } std::memset(ptr, value, count); + return 0; + } + + static int MemGetInfo(std::size_t* free_bytes, std::size_t* total_bytes) { + if (free_bytes == nullptr || total_bytes == nullptr) { + return 1; + } + *free_bytes = 0; + *total_bytes = 0; + +#ifndef _WIN32 + FILE* fp = std::fopen("/proc/meminfo", "r"); + if (fp == nullptr) { + return 1; + } + + char label[64]; + std::size_t value = 0; + while (std::fscanf(fp, "%63s %zu %*s", label, &value) == 2) { + if (std::strcmp(label, "MemTotal:") == 0) { + *total_bytes = value * 1024; + } else if (std::strcmp(label, "MemAvailable:") == 0) { + *free_bytes = value * 1024; + } + } + std::fclose(fp); +#endif + if (*total_bytes == 0) { + return 1; + } + return 0; + } + + static int StreamCreate(void** stream) { + if (stream == nullptr) { + return 1; + } + *stream = nullptr; + return 0; + } + + static int StreamDestroy(void*) { return 0; } + + static int StreamSynchronize(void*) { return 0; } + + static int StreamWaitEvent(void*, void*, std::uint32_t) { return 0; } + + using Event = std::chrono::steady_clock::time_point; + + static int EventCreate(void** event) { + if (event == nullptr) { + return 1; + } + *event = new (std::nothrow) Event(std::chrono::steady_clock::now()); + return *event == nullptr ? 1 : 0; + } + + static int EventCreateWithFlags(void** event, std::uint32_t) { + return EventCreate(event); + } + + static int EventRecord(void* event, void*) { + if (event == nullptr) { + return 1; + } + *static_cast(event) = std::chrono::steady_clock::now(); + return 0; + } + + static int EventQuery(void* event) { return event == nullptr ? 1 : 0; } + + static int EventSynchronize(void* event) { return event == nullptr ? 1 : 0; } + + static int EventDestroy(void* event) { + delete static_cast(event); + return 0; + } + + static int EventElapsedTime(float* ms, void* start, void* end) { + if (ms == nullptr || start == nullptr || end == nullptr) { + return 1; + } + const auto* start_time = static_cast(start); + const auto* end_time = static_cast(end); + const auto duration = std::chrono::duration_cast( + *end_time - *start_time); + *ms = static_cast(duration.count()) / 1000.0f; + return 0; + } + + static int MallocHost(void** ptr, std::size_t size) { + return Malloc(ptr, size); + } + + static int FreeHost(void* ptr) { return Free(ptr); } + + static int MemcpyAsync(void* dst, const void* src, std::size_t size, int kind, + void*) { + return Memcpy(dst, src, size, kind); + } + + static int MallocAsync(void** ptr, std::size_t size, void*) { + return Malloc(ptr, size); + } + + static int FreeAsync(void* ptr, void*) { return Free(ptr); } + + static int MemsetAsync(void* ptr, int value, std::size_t count, void*) { + return Memset(ptr, value, count); } static constexpr int MemcpyHostToHost = 0; diff --git a/src/runtime.h b/src/runtime.h index ebc2698..7c78bea 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -2,6 +2,7 @@ #define INFINI_RT_RUNTIME_H_ #include +#include #include #include "device.h" @@ -51,28 +52,65 @@ struct DeviceRuntime : RuntimeBase { } }; -enum class MemcpyKind { +enum class infinirtMemcpyKind { kHostToHost, kHostToDevice, kDeviceToHost, kDeviceToDevice, }; -void SetDevice(Device device); +int SetDevice(int device); -void GetDevice(Device* device); +int GetDevice(int* device); -void GetDeviceCount(int* count, Device::Type type); +int GetDeviceCount(int* count); -void DeviceSynchronize(); +int DeviceSynchronize(); -void Malloc(void** ptr, std::size_t size); +int Malloc(void** ptr, std::size_t size); -void Free(void* ptr); +int Free(void* ptr); -void Memset(void* ptr, int value, std::size_t count); +int Memset(void* ptr, int value, std::size_t count); -void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); +int Memcpy(void* dst, const void* src, std::size_t count, infinirtMemcpyKind kind); + +int MallocHost(void** ptr, std::size_t size); + +int FreeHost(void* ptr); + +int MemcpyAsync(void* dst, const void* src, std::size_t count, infinirtMemcpyKind kind, + void* stream); + +int MallocAsync(void** ptr, std::size_t size, void* stream); + +int FreeAsync(void* ptr, void* stream); + +int MemsetAsync(void* ptr, int value, std::size_t count, void* stream); + +int MemGetInfo(std::size_t* free_bytes, std::size_t* total_bytes); + +int StreamCreate(void** stream); + +int StreamDestroy(void* stream); + +int StreamSynchronize(void* stream); + +int StreamWaitEvent(void* stream, void* event, std::uint32_t flags); + +int EventCreate(void** event); + +int EventCreateWithFlags(void** event, std::uint32_t flags); + +int EventRecord(void* event, void* stream); + +int EventQuery(void* event); + +int EventSynchronize(void* event); + +int EventDestroy(void* event); + +int EventElapsedTime(float* ms, void* start, void* end); } // namespace infini::rt