Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 50 additions & 25 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,21 +213,19 @@ def _parse_runtime_functions(runtime_header):
tuple(_parse_param(param) for param in params.split(", ") if param),
)
for return_type, name, params in re.findall(
r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE
r"^(int|void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE
)
)


def _abort_statement(message):
return f""" assert(false && "{message}");
std::abort();"""
def _unsupported_statement():
return " return 1;"


def _dispatch_cases(devices, statements):
return "\n".join(
f""" case {_DEVICE_TYPES[device]}: {{
{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])}
return;
}}"""
for device in devices
)
Expand All @@ -243,9 +241,11 @@ def _selector(function):
return "current_device.type()"


def _runtime_arg(param):
def _runtime_arg(function, param):
if param.type == "Device":
return f"{param.name}.index()"
if function.name in {"SetDevice", "GetDeviceResourceSnapshot"}:
return f"{param.name}.index()"
return None
if param.type == "Device::Type":
return None
if param.type == "MemcpyKind":
Expand All @@ -255,7 +255,7 @@ def _runtime_arg(param):


def _runtime_args(function):
args = (_runtime_arg(param) for param in function.params)
args = (_runtime_arg(function, param) for param in function.params)

return ", ".join(arg for arg in args if arg is not None)

Expand All @@ -270,14 +270,18 @@ def _preconditions(function):
if param.type.endswith("**") or param.name in required_pointer_names.get(
function.name, set()
):
checks.append(f" assert({param.name} != nullptr);")
checks.append(f" if ({param.name} == nullptr) {{")
checks.append(" return 1;")
checks.append(" }")

return "\n".join(checks)


def _post_dispatch(function):
if function.name == "SetDevice":
return "\n current_device = device;"
return (
"\n if (rt_status == 0) {\n current_device = device;\n }"
)

return ""

Expand All @@ -295,18 +299,24 @@ def _write_get_device(function, devices):
cases = _dispatch_cases(
devices,
f""" int index = current_device.index();
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }});
int status = CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }});
if (status != 0) {{
return status;
}}
current_device = Device{{current_device.type(), index}};
*{device_param} = current_device;""",
*{device_param} = current_device;
return 0;""",
)

return f"""void GetDevice(Device* {device_param}) {{
assert({device_param} != nullptr);
return f"""{function.return_type} GetDevice(Device* {device_param}) {{
if ({device_param} == nullptr) {{
return 1;
}}

switch (current_device.type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
{_unsupported_statement()}
}}
}}
"""
Expand All @@ -318,7 +328,8 @@ def _write_dispatch_function(function, devices):

cases = _dispatch_cases(
devices,
f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""",
f""" int rt_status = CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}
return rt_status;""",
)
preconditions = _preconditions(function)
if preconditions:
Expand All @@ -328,25 +339,42 @@ def _write_dispatch_function(function, devices):
{preconditions} switch ({_selector(function)}) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
{_unsupported_statement()}
}}
}}
"""


def _runtime_header_for_device(source_root, device):
    """Return the runtime header path for `device`, joined onto `source_root`.

    The relative header path is looked up in `_RUNTIME_HEADERS`.
    """
    relative_header = _RUNTIME_HEADERS[device]
    return source_root / relative_header


def _devices_for_function(function, devices, source_root):
    """Return the devices whose runtime header mentions `function.name`.

    A device is kept when its runtime header file exists and contains the
    function name as a whole word. The order of `devices` is preserved.
    """
    name_regex = re.compile(r"\b" + re.escape(function.name) + r"\b")

    def _mentions_function(device):
        header = _runtime_header_for_device(source_root, device)
        # Check existence first so a missing header is never read.
        return header.exists() and bool(name_regex.search(header.read_text()))

    return [device for device in devices if _mentions_function(device)]


def _write_runtime_dispatch(source_path, runtime_header, devices):
first_device_type = _DEVICE_TYPES[devices[0]]
includes = ['#include "runtime.h"']
includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices)
functions = _parse_runtime_functions(runtime_header)
source_root = pathlib.Path(runtime_header).parent
dispatch_functions = "\n".join(
_write_dispatch_function(function, devices) for function in functions
_write_dispatch_function(
function, _devices_for_function(function, devices, source_root)
)
for function in functions
)

source_path.parent.mkdir(parents=True, exist_ok=True)
source_path.write_text(
f"""#include <cassert>
#include <cstdlib>
#include <type_traits>
#include <utility>

Expand All @@ -358,17 +386,15 @@ def _write_runtime_dispatch(source_path, runtime_header, devices):
thread_local Device current_device{{{first_device_type}, 0}};

template <typename Func>
void CheckCall(Func&& func) {{
int CheckCall(Func&& func) {{
using ReturnType = decltype(std::forward<Func>(func)());

if constexpr (std::is_void_v<ReturnType>) {{
std::forward<Func>(func)();
return 0;
}} else {{
ReturnType status = std::forward<Func>(func)();
if (status != ReturnType{{}}) {{
assert(false && "runtime call failed");
std::abort();
}}
return status == ReturnType{{}} ? 0 : 1;
}}
}}

Expand All @@ -386,7 +412,6 @@ def _write_runtime_dispatch(source_path, runtime_header, devices):
}}

assert(false && "unsupported memcpy kind");
std::abort();
return Runtime<kDev>::MemcpyHostToHost;
}}

Expand Down
166 changes: 150 additions & 16 deletions src/native/cpu/runtime_.h

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

按照刚才讨论的内容,CPU 的先往后放一放,优先搞英伟达 GPU 的吧。

Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
#ifndef INFINI_RT_CPU_RUNTIME__H_
#define INFINI_RT_CPU_RUNTIME__H_

#include <cassert>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <new>

#include "runtime.h"

Expand All @@ -13,35 +15,167 @@ template <>
struct Runtime<Device::Type::kCpu> : RuntimeBase<Runtime<Device::Type::kCpu>> {
static constexpr Device::Type kDeviceType = Device::Type::kCpu;

static void SetDevice(int index) {
if (index != 0) {
assert(false && "CPU device index must be 0");
std::abort();
}
}
static int SetDevice(int index) { return index == 0 ? 0 : 1; }

static void GetDevice(int* index) {
assert(index != nullptr);
static int GetDevice(int* index) {
if (index == nullptr) {
return 1;
}
*index = 0;
return 0;
}

static void GetDeviceCount(int* count) {
assert(count != nullptr);
static int GetDeviceCount(int* count) {
if (count == nullptr) {
return 1;
}
*count = 1;
return 0;
}

static void DeviceSynchronize() {}
static int DeviceSynchronize() { return 0; }

static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); }
static int Malloc(void** ptr, std::size_t size) {
if (ptr == nullptr) {
return 1;
}
*ptr = std::malloc(size);
return size != 0 && *ptr == nullptr ? 1 : 0;
}

static void Free(void* ptr) { std::free(ptr); }
static int Free(void* ptr) {
std::free(ptr);
return 0;
}

static void Memcpy(void* dst, const void* src, std::size_t size, int) {
static int Memcpy(void* dst, const void* src, std::size_t size, int) {
if ((dst == nullptr || src == nullptr) && size != 0) {
return 1;
}
std::memcpy(dst, src, size);
return 0;
}

static void Memset(void* ptr, int value, std::size_t count) {
static int Memset(void* ptr, int value, std::size_t count) {
if (ptr == nullptr && count != 0) {
return 1;
}
std::memset(ptr, value, count);
return 0;
}

// Reports host memory figures read from /proc/meminfo.
//
// On success `*total_bytes` holds the `MemTotal:` value and `*free_bytes` the
// `MemAvailable:` value, each multiplied by 1024 (the file is presumably in
// kB -- confirm against proc(5)). Both outputs are zeroed before parsing.
//
// Returns 1 when either pointer is null, when the file cannot be opened, or
// when no non-zero `MemTotal:` entry was found (always the case when built
// with `_WIN32` defined, where no parsing is attempted). Returns 0 otherwise.
static int GetMemInfo(std::size_t* free_bytes, std::size_t* total_bytes) {
  if (free_bytes == nullptr || total_bytes == nullptr) {
    return 1;
  }
  *free_bytes = 0;
  *total_bytes = 0;

#ifndef _WIN32
  FILE* fp = std::fopen("/proc/meminfo", "r");
  if (fp == nullptr) {
    return 1;
  }

  // Parse one line at a time. Some entries carry no unit suffix, so a
  // stream-wide "%s %zu %*s" scan would swallow the following label and
  // desynchronise; per-line parsing keeps every entry independent.
  bool have_total = false;
  bool have_available = false;
  char line[256];
  while (!(have_total && have_available) &&
         std::fgets(line, sizeof(line), fp) != nullptr) {
    char label[64];
    std::size_t value = 0;
    if (std::sscanf(line, "%63s %zu", label, &value) != 2) {
      continue;  // Not a "label value" line; ignore it.
    }
    if (std::strcmp(label, "MemTotal:") == 0) {
      *total_bytes = value * 1024;
      have_total = true;
    } else if (std::strcmp(label, "MemAvailable:") == 0) {
      *free_bytes = value * 1024;
      have_available = true;
    }
  }
  std::fclose(fp);
#endif
  return *total_bytes == 0 ? 1 : 0;
}

// Stores a null stream handle in `*stream`.
// Returns 1 when `stream` itself is null, otherwise 0.
static int StreamCreate(void** stream) {
  const bool has_output = stream != nullptr;
  if (has_output) {
    *stream = nullptr;
  }
  return has_output ? 0 : 1;
}

// No-op: ignores the stream handle and always returns 0.
static int StreamDestroy(void* /*stream*/) {
  return 0;
}

// No-op: ignores the stream handle and always returns 0.
static int StreamSynchronize(void* /*stream*/) {
  return 0;
}

// No-op: ignores both the stream and the event handle and always returns 0.
static int StreamWaitEvent(void* /*stream*/, void* /*event*/) {
  return 0;
}

// An event is a heap-allocated steady-clock timestamp, handed to callers as an
// opaque `void*`.
using Event = std::chrono::steady_clock::time_point;

// Allocates an event stamped with the current steady-clock time.
// Returns 1 when `event` is null or the allocation fails, otherwise 0.
static int EventCreate(void** event) {
  if (event == nullptr) {
    return 1;
  }
  // nothrow: allocation failure is reported through the status code instead
  // of an exception.
  *event = new (std::nothrow) Event(std::chrono::steady_clock::now());
  return *event == nullptr ? 1 : 0;
}

// Same as `EventCreate`; the flags argument is ignored.
static int EventCreateWithFlags(void** event, std::uint32_t) {
  return EventCreate(event);
}

// Overwrites the event's timestamp with the current steady-clock time. The
// stream argument is ignored. Returns 1 when `event` is null, otherwise 0.
static int EventRecord(void* event, void*) {
  if (event == nullptr) {
    return 1;
  }
  *static_cast<Event*>(event) = std::chrono::steady_clock::now();
  return 0;
}

// Writes 0 to `*status`. Returns 1 when either pointer is null, otherwise 0.
static int EventQuery(void* event, int* status) {
  if (event == nullptr || status == nullptr) {
    return 1;
  }
  *status = 0;
  return 0;
}

// Performs no waiting. Returns 1 when `event` is null, otherwise 0.
static int EventSynchronize(void* event) { return event == nullptr ? 1 : 0; }

// Releases an event obtained from `EventCreate`. A null handle is accepted
// (deleting a null pointer is a no-op). Always returns 0.
static int EventDestroy(void* event) {
  delete static_cast<Event*>(event);
  return 0;
}

// Stores the time from `start` to `end` in `*ms`, in milliseconds.
// Returns 1 when any pointer is null, otherwise 0.
static int EventElapsedTime(float* ms, void* start, void* end) {
  if (ms == nullptr || start == nullptr || end == nullptr) {
    return 1;
  }
  const Event& start_time = *static_cast<const Event*>(start);
  const Event& end_time = *static_cast<const Event*>(end);
  // Convert straight to fractional milliseconds: going through an integral
  // microsecond count first would discard the sub-microsecond part of the
  // interval.
  const std::chrono::duration<double, std::milli> elapsed =
      end_time - start_time;
  *ms = static_cast<float>(elapsed.count());
  return 0;
}

// Forwards to `Malloc` with the same arguments and returns its status.
static int MallocHost(void** ptr, std::size_t size) {
  const int status = Malloc(ptr, size);
  return status;
}

// Forwards to `Free` and returns its status.
static int FreeHost(void* ptr) {
  const int status = Free(ptr);
  return status;
}

// Forwards to `Memcpy` and returns its status. The trailing stream argument
// is ignored.
static int MemcpyAsync(void* dst, const void* src, std::size_t size, int kind,
                       void* /*stream*/) {
  const int status = Memcpy(dst, src, size, kind);
  return status;
}

// Forwards to `Malloc` and returns its status. The trailing stream argument
// is ignored.
static int MallocAsync(void** ptr, std::size_t size, void* /*stream*/) {
  const int status = Malloc(ptr, size);
  return status;
}

// Forwards to `Free` and returns its status. The trailing stream argument is
// ignored.
static int FreeAsync(void* ptr, void* /*stream*/) {
  const int status = Free(ptr);
  return status;
}

// Forwards to `Memset` and returns its status. The trailing stream argument
// is ignored.
static int MemsetAsync(void* ptr, int value, std::size_t count,
                       void* /*stream*/) {
  const int status = Memset(ptr, value, count);
  return status;
}

static constexpr int MemcpyHostToHost = 0;
Expand Down
Loading
Loading