diff --git a/.gitignore b/.gitignore index 99bca7d..48539b4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Generated files build/ +build-*/ generated/ # Prerequisites diff --git a/include/infini/rt.h b/include/infini/rt.h index 3adb6eb..e687976 100644 --- a/include/infini/rt.h +++ b/include/infini/rt.h @@ -1,6 +1,7 @@ #ifndef INFINI_RT_PUBLIC_H_ #define INFINI_RT_PUBLIC_H_ +#include #include #endif diff --git a/include/infini/rt/c_api.h b/include/infini/rt/c_api.h new file mode 100644 index 0000000..c5ad65f --- /dev/null +++ b/include/infini/rt/c_api.h @@ -0,0 +1,78 @@ +#ifndef INFINI_RT_C_API_H_ +#define INFINI_RT_C_API_H_ + +#if defined(_WIN32) +#define INFINI_RT_EXPORT __declspec(dllexport) +#elif defined(__GNUC__) && \ + ((__GNUC__ >= 4) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 3)) +#define INFINI_RT_EXPORT __attribute__((visibility("default"))) +#else +#define INFINI_RT_EXPORT +#endif + +#ifdef __cplusplus +#define INFINI_RT_EXTERN_C extern "C" +#else +#define INFINI_RT_EXTERN_C +#endif + +typedef enum { + INFINI_RT_STATUS_SUCCESS = 0, + INFINI_RT_STATUS_INVALID_ARGUMENT = 1, + INFINI_RT_STATUS_UNSUPPORTED_DEVICE = 2, + INFINI_RT_STATUS_RUNTIME_ERROR = 3, +} infiniRtStatus_t; + +typedef enum { + INFINI_RT_DEVICE_CPU = 0, + INFINI_RT_DEVICE_NVIDIA = 1, + INFINI_RT_DEVICE_CAMBRICON = 2, + INFINI_RT_DEVICE_ASCEND = 3, + INFINI_RT_DEVICE_METAX = 4, + INFINI_RT_DEVICE_MOORE = 5, + INFINI_RT_DEVICE_ILUVATAR = 6, + INFINI_RT_DEVICE_KUNLUN = 7, + INFINI_RT_DEVICE_HYGON = 8, + INFINI_RT_DEVICE_QY = 9, +} infiniRtDeviceType_t; + +typedef struct { + infiniRtDeviceType_t type; + int index; +} infiniRtDevice_t; + +typedef enum { + INFINI_RT_STREAM_CAPTURE_MODE_GLOBAL = 0, + INFINI_RT_STREAM_CAPTURE_MODE_THREAD_LOCAL = 1, + INFINI_RT_STREAM_CAPTURE_MODE_RELAXED = 2, +} infiniRtStreamCaptureMode_t; + +typedef void* infiniRtStream_t; +typedef void* infiniRtGraph_t; +typedef void* infiniRtGraphExec_t; + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtStreamWrap( + infiniRtDevice_t device, void* native_stream, infiniRtStream_t* stream); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t +infiniRtStreamDestroy(infiniRtStream_t stream); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtStreamBeginCapture( + infiniRtStream_t stream, infiniRtStreamCaptureMode_t mode); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t +infiniRtStreamEndCapture(infiniRtStream_t stream, infiniRtGraph_t* graph); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t +infiniRtGraphDestroy(infiniRtGraph_t graph); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t infiniRtGraphInstantiate( + infiniRtGraphExec_t* graph_exec, infiniRtGraph_t graph); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t +infiniRtGraphExecDestroy(infiniRtGraphExec_t graph_exec); + +INFINI_RT_EXTERN_C INFINI_RT_EXPORT infiniRtStatus_t +infiniRtGraphLaunch(infiniRtGraphExec_t graph_exec, infiniRtStream_t stream); + +#endif diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index 93ceee3..2ab58cc 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -101,7 +101,7 @@ def _rewrite_detail_include(match): _DETAIL_INCLUDE_PATTERN = re.compile( - r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|hash\.h|runtime\.h|tensor_view\.h)"' + r'#include "((?:common|native)/[^"]+|data_type\.h|device\.h|dispatcher\.h|graph\.h|hash\.h|runtime\.h|tensor_view\.h)"' ) @@ -133,6 +133,7 @@ def _write_detail_headers(include_root, source_root, devices): "data_type.h", "device.h", "dispatcher.h", + "graph.h", "hash.h", "runtime.h", "tensor_view.h", @@ -158,6 +159,7 @@ def _write_generated_header(include_root, devices): includes = [ f"#include {_detail_include('data_type.h')}", f"#include {_detail_include('device.h')}", + f"#include {_detail_include('graph.h')}", f"#include {_detail_include('hash.h')}", f"#include {_detail_include('runtime.h')}", f"#include {_detail_include('tensor_view.h')}", @@ -210,7 +212,11 @@ def _parse_runtime_functions(runtime_header): _Function( return_type, name, - tuple(_parse_param(param) for param in params.split(", ") if param), + tuple( + _parse_param(param) + for param in re.split(r",\s*", params.strip()) + if param + ), ) for return_type, name, params in re.findall( r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE @@ -239,6 +245,8 @@ def _selector(function): return f"{param.name}.type()" if param.type == "Device::Type": return param.name + if param.type in {"Stream", "Graph", "GraphExec"}: + return f"{param.name}.device_type()" return "current_device.type()" @@ -250,6 +258,13 @@ def _runtime_arg(param): return None if param.type == "MemcpyKind": return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" + if param.type == "StreamCaptureMode": + return f"RuntimeStreamCaptureMode<__DEVICE_TYPE__>({param.name})" + if param.type in {"Stream", "Graph", "GraphExec"}: + return ( + f"static_cast::{param.type}>" + f"({param.name}.raw())" + ) return param.name @@ -264,6 +279,9 @@ def _preconditions(function): required_pointer_names = { "GetDevice": {"device"}, "GetDeviceCount": {"count"}, + "StreamCreate": {"stream"}, + "StreamEndCapture": {"graph"}, + "GraphInstantiate": {"graph_exec"}, } checks = [] for param in function.params: @@ -290,6 +308,27 @@ def _runtime_call(function): return f"Runtime<__DEVICE_TYPE__>::{function.name}()" +def _write_stream_create(function, devices): + stream_param = function.params[0].name + cases = _dispatch_cases( + devices, + f""" typename Runtime<__DEVICE_TYPE__>::Stream raw_stream = {{}}; + CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamCreate(&raw_stream); }}); + *{stream_param} = Stream{{__DEVICE_TYPE__, static_cast(raw_stream)}};""", + ) + + return f"""void StreamCreate(Stream* {stream_param}) {{ + assert({stream_param} != nullptr); + + switch (current_device.type()) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + def _write_get_device(function, devices): device_param = function.params[0].name cases = _dispatch_cases( @@ -312,9 +351,81 @@ def _write_get_device(function, devices): """ +def _write_stream_end_capture(function, devices): + stream_param = function.params[0].name + graph_param = function.params[1].name + cases = _dispatch_cases( + devices, + f""" typename Runtime<__DEVICE_TYPE__>::Graph raw_graph = {{}}; + CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::StreamEndCapture(static_cast::Stream>({stream_param}.raw()), &raw_graph); }}); + *{graph_param} = Graph{{__DEVICE_TYPE__, static_cast(raw_graph)}};""", + ) + + return f"""void StreamEndCapture(Stream {stream_param}, Graph* {graph_param}) {{ + assert({graph_param} != nullptr); + + switch ({stream_param}.device_type()) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + +def _write_graph_instantiate(function, devices): + graph_exec_param = function.params[0].name + graph_param = function.params[1].name + cases = _dispatch_cases( + devices, + f""" typename Runtime<__DEVICE_TYPE__>::GraphExec raw_graph_exec = {{}}; + CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphInstantiate(&raw_graph_exec, static_cast::Graph>({graph_param}.raw())); }}); + *{graph_exec_param} = GraphExec{{__DEVICE_TYPE__, static_cast(raw_graph_exec)}};""", + ) + + return f"""void GraphInstantiate(GraphExec* {graph_exec_param}, Graph {graph_param}) {{ + assert({graph_exec_param} != nullptr); + + switch ({graph_param}.device_type()) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + +def _write_graph_launch(function, devices): + graph_exec_param = function.params[0].name + stream_param = function.params[1].name + cases = _dispatch_cases( + devices, + f""" CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GraphLaunch(static_cast::GraphExec>({graph_exec_param}.raw()), static_cast::Stream>({stream_param}.raw())); }});""", + ) + + return f"""void GraphLaunch(GraphExec {graph_exec_param}, Stream {stream_param}) {{ + assert({graph_exec_param}.device_type() == {stream_param}.device_type()); + + switch ({graph_exec_param}.device_type()) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + def _write_dispatch_function(function, devices): if function.name == "GetDevice": return _write_get_device(function, devices) + if function.name == "StreamCreate": + return _write_stream_create(function, devices) + if function.name == "StreamEndCapture": + return _write_stream_end_capture(function, devices) + if function.name == "GraphInstantiate": + return _write_graph_instantiate(function, devices) + if function.name == "GraphLaunch": + return _write_graph_launch(function, devices) cases = _dispatch_cases( devices, @@ -390,6 +501,22 @@ def _write_runtime_dispatch(source_path, runtime_header, devices): return Runtime::MemcpyHostToHost; }} +template +auto RuntimeStreamCaptureMode(StreamCaptureMode mode) {{ + switch (mode) {{ + case StreamCaptureMode::kGlobal: + return Runtime::StreamCaptureModeGlobal; + case StreamCaptureMode::kThreadLocal: + return Runtime::StreamCaptureModeThreadLocal; + case StreamCaptureMode::kRelaxed: + return Runtime::StreamCaptureModeRelaxed; + }} + + assert(false && "unsupported stream capture mode"); + std::abort(); + return Runtime::StreamCaptureModeGlobal; +}} + }} // namespace {dispatch_functions} diff --git a/src/c_api.cc b/src/c_api.cc new file mode 100644 index 0000000..178514a --- /dev/null +++ b/src/c_api.cc @@ -0,0 +1,306 @@ +#include + +#include +#include +#include + +#include "runtime.h" + +#if defined(WITH_NVIDIA) +#include "native/cuda/nvidia/runtime_.h" +#endif + +namespace { + +using infini::rt::Device; +using infini::rt::Graph; +using infini::rt::GraphExec; +using infini::rt::Runtime; +using infini::rt::Stream; + +struct CStream { + Stream stream; +}; + +struct CGraph { + Graph graph; +}; + +struct CGraphExec { + GraphExec graph_exec; +}; + +template +infiniRtStatus_t Guard(Func&& func) { + try { + return std::forward(func)(); + } catch (const std::bad_alloc&) { + return INFINI_RT_STATUS_RUNTIME_ERROR; + } catch (...) { + return INFINI_RT_STATUS_RUNTIME_ERROR; + } +} + +template +infiniRtStatus_t CheckBackendCall(Func&& func) { + using ReturnType = decltype(std::forward(func)()); + if constexpr (std::is_void_v) { + std::forward(func)(); + return INFINI_RT_STATUS_SUCCESS; + } else { + return std::forward(func)() == ReturnType{} + ? INFINI_RT_STATUS_SUCCESS + : INFINI_RT_STATUS_RUNTIME_ERROR; + } +} + +Device::Type ToCppDeviceType(infiniRtDeviceType_t type) { + switch (type) { + case INFINI_RT_DEVICE_CPU: + return Device::Type::kCpu; + case INFINI_RT_DEVICE_NVIDIA: + return Device::Type::kNvidia; + case INFINI_RT_DEVICE_CAMBRICON: + return Device::Type::kCambricon; + case INFINI_RT_DEVICE_ASCEND: + return Device::Type::kAscend; + case INFINI_RT_DEVICE_METAX: + return Device::Type::kMetax; + case INFINI_RT_DEVICE_MOORE: + return Device::Type::kMoore; + case INFINI_RT_DEVICE_ILUVATAR: + return Device::Type::kIluvatar; + case INFINI_RT_DEVICE_KUNLUN: + return Device::Type::kKunlun; + case INFINI_RT_DEVICE_HYGON: + return Device::Type::kHygon; + case INFINI_RT_DEVICE_QY: + return Device::Type::kQy; + } + return Device::Type::kCount; +} + +#if defined(WITH_NVIDIA) +auto ToNvidiaCaptureMode(infiniRtStreamCaptureMode_t mode) { + switch (mode) { + case INFINI_RT_STREAM_CAPTURE_MODE_GLOBAL: + return Runtime::StreamCaptureModeGlobal; + case INFINI_RT_STREAM_CAPTURE_MODE_THREAD_LOCAL: + return Runtime::StreamCaptureModeThreadLocal; + case INFINI_RT_STREAM_CAPTURE_MODE_RELAXED: + return Runtime::StreamCaptureModeRelaxed; + } + return Runtime::StreamCaptureModeRelaxed; +} + +auto RawNvidiaStream(Stream stream) { + return static_cast::Stream>( + stream.raw()); +} + +auto RawNvidiaGraph(Graph graph) { + return static_cast::Graph>( + graph.raw()); +} + +auto RawNvidiaGraphExec(GraphExec graph_exec) { + return static_cast::GraphExec>( + graph_exec.raw()); +} +#endif + +CStream* AsStream(infiniRtStream_t stream) { + return static_cast(stream); +} + +CGraph* AsGraph(infiniRtGraph_t graph) { return static_cast(graph); } + +CGraphExec* AsGraphExec(infiniRtGraphExec_t graph_exec) { + return static_cast(graph_exec); +} + +infiniRtStatus_t Unsupported() { return INFINI_RT_STATUS_UNSUPPORTED_DEVICE; } + +} // namespace + +infiniRtStatus_t infiniRtStreamWrap(infiniRtDevice_t device, + void* native_stream, + infiniRtStream_t* stream) { + if (native_stream == nullptr || stream == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + const auto device_type = ToCppDeviceType(device.type); + if (device_type == Device::Type::kCount) { + return INFINI_RT_STATUS_UNSUPPORTED_DEVICE; + } + *stream = new CStream{Stream{device_type, native_stream}}; + return INFINI_RT_STATUS_SUCCESS; + }); +} + +infiniRtStatus_t infiniRtStreamDestroy(infiniRtStream_t stream) { + if (stream == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + delete AsStream(stream); + return INFINI_RT_STATUS_SUCCESS; +} + +infiniRtStatus_t infiniRtStreamBeginCapture(infiniRtStream_t stream, + infiniRtStreamCaptureMode_t mode) { + if (stream == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* wrapped = AsStream(stream); + switch (wrapped->stream.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: + return CheckBackendCall([&] { + return Runtime::StreamBeginCapture( + RawNvidiaStream(wrapped->stream), ToNvidiaCaptureMode(mode)); + }); +#endif + default: + return Unsupported(); + } + }); +} + +infiniRtStatus_t infiniRtStreamEndCapture(infiniRtStream_t stream, + infiniRtGraph_t* graph) { + if (stream == nullptr || graph == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* wrapped = AsStream(stream); + switch (wrapped->stream.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: { + typename Runtime::Graph raw_graph = {}; + const auto status = CheckBackendCall([&] { + return Runtime::StreamEndCapture( + RawNvidiaStream(wrapped->stream), &raw_graph); + }); + if (status != INFINI_RT_STATUS_SUCCESS) { + return status; + } + *graph = new CGraph{ + Graph{Device::Type::kNvidia, static_cast(raw_graph)}}; + return INFINI_RT_STATUS_SUCCESS; + } +#endif + default: + return Unsupported(); + } + }); +} + +infiniRtStatus_t infiniRtGraphDestroy(infiniRtGraph_t graph) { + if (graph == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* wrapped = AsGraph(graph); + switch (wrapped->graph.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: { + const auto status = CheckBackendCall([&] { + return Runtime::GraphDestroy( + RawNvidiaGraph(wrapped->graph)); + }); + // The C wrapper owns only the wrapper object. The backend destroy call + // above owns the native graph handle. + delete wrapped; + return status; + } +#endif + default: + delete wrapped; + return Unsupported(); + } + }); +} + +infiniRtStatus_t infiniRtGraphInstantiate(infiniRtGraphExec_t* graph_exec, + infiniRtGraph_t graph) { + if (graph_exec == nullptr || graph == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* wrapped = AsGraph(graph); + switch (wrapped->graph.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: { + typename Runtime::GraphExec raw_exec = {}; + const auto status = CheckBackendCall([&] { + return Runtime::GraphInstantiate( + &raw_exec, RawNvidiaGraph(wrapped->graph)); + }); + if (status != INFINI_RT_STATUS_SUCCESS) { + return status; + } + *graph_exec = new CGraphExec{ + GraphExec{Device::Type::kNvidia, static_cast(raw_exec)}}; + return INFINI_RT_STATUS_SUCCESS; + } +#endif + default: + return Unsupported(); + } + }); +} + +infiniRtStatus_t infiniRtGraphExecDestroy(infiniRtGraphExec_t graph_exec) { + if (graph_exec == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* wrapped = AsGraphExec(graph_exec); + switch (wrapped->graph_exec.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: { + const auto status = CheckBackendCall([&] { + return Runtime::GraphExecDestroy( + RawNvidiaGraphExec(wrapped->graph_exec)); + }); + // The C wrapper owns only the wrapper object. The backend destroy call + // above owns the native executable graph handle. + delete wrapped; + return status; + } +#endif + default: + delete wrapped; + return Unsupported(); + } + }); +} + +infiniRtStatus_t infiniRtGraphLaunch(infiniRtGraphExec_t graph_exec, + infiniRtStream_t stream) { + if (graph_exec == nullptr || stream == nullptr) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + return Guard([&] { + auto* exec = AsGraphExec(graph_exec); + auto* wrapped_stream = AsStream(stream); + if (exec->graph_exec.device_type() != + wrapped_stream->stream.device_type()) { + return INFINI_RT_STATUS_INVALID_ARGUMENT; + } + switch (exec->graph_exec.device_type()) { +#if defined(WITH_NVIDIA) + case Device::Type::kNvidia: + return CheckBackendCall([&] { + return Runtime::GraphLaunch( + RawNvidiaGraphExec(exec->graph_exec), + RawNvidiaStream(wrapped_stream->stream)); + }); +#endif + default: + return Unsupported(); + } + }); +} diff --git a/src/graph.h b/src/graph.h new file mode 100644 index 0000000..4a972e3 --- /dev/null +++ b/src/graph.h @@ -0,0 +1,74 @@ +#ifndef INFINI_RT_GRAPH_H_ +#define INFINI_RT_GRAPH_H_ + +#include + +#include "device.h" + +namespace infini::rt { + +enum class StreamCaptureMode { + kGlobal = 0, + kThreadLocal = 1, + kRelaxed = 2, +}; + +// Public dispatch wrappers keep backend handles opaque while preserving +// enough device identity for cross-device graph dispatch. +class Stream { + public: + Stream() = default; + + Stream(Device::Type device_type, void* raw) + : device_type_{device_type}, raw_{raw} {} + + Device::Type device_type() const { return device_type_; } + + void* raw() const { return raw_; } + + explicit operator bool() const { return raw_ != nullptr; } + + private: + Device::Type device_type_{Device::Type::kCpu}; + void* raw_{nullptr}; +}; + +class Graph { + public: + Graph() = default; + + Graph(Device::Type device_type, void* raw) + : device_type_{device_type}, raw_{raw} {} + + Device::Type device_type() const { return device_type_; } + + void* raw() const { return raw_; } + + explicit operator bool() const { return raw_ != nullptr; } + + private: + Device::Type device_type_{Device::Type::kCpu}; + void* raw_{nullptr}; +}; + +class GraphExec { + public: + GraphExec() = default; + + GraphExec(Device::Type device_type, void* raw) + : device_type_{device_type}, raw_{raw} {} + + Device::Type device_type() const { return device_type_; } + + void* raw() const { return raw_; } + + explicit operator bool() const { return raw_ != nullptr; } + + private: + Device::Type device_type_{Device::Type::kCpu}; + void* raw_{nullptr}; +}; + +} // namespace infini::rt + +#endif diff --git a/src/native/ascend/runtime_.h b/src/native/ascend/runtime_.h index 8b33e54..c183714 100644 --- a/src/native/ascend/runtime_.h +++ b/src/native/ascend/runtime_.h @@ -18,6 +18,10 @@ struct Runtime : DeviceRuntime> { using Stream = aclrtStream; + using Graph = void*; + + using GraphExec = void*; + static constexpr Device::Type kDeviceType = Device::Type::kAscend; static constexpr auto SetDevice = aclrtSetDevice; @@ -45,6 +49,11 @@ struct Runtime return aclrtMemcpy(dst, count, src, count, kind); }; + static auto MemcpyAsync(void* dst, const void* src, size_t count, + aclrtMemcpyKind kind, Stream stream) { + return aclrtMemcpyAsync(dst, count, src, count, kind, stream); + } + static constexpr auto MemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; @@ -53,9 +62,35 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; - static constexpr auto Memset = [](void* ptr, int value, size_t count) { + static auto Memset(void* ptr, int value, size_t count) { return aclrtMemset(ptr, count, value, count); - }; + } + + static auto StreamCreate(Stream* stream) { + return aclrtCreateStreamWithConfig(stream, 0, ACL_STREAM_FAST_LAUNCH); + } + + static constexpr auto StreamDestroy = aclrtDestroyStream; + + static constexpr auto StreamSynchronize = aclrtSynchronizeStream; + + static constexpr int StreamCaptureModeGlobal = 0; + + static constexpr int StreamCaptureModeThreadLocal = 1; + + static constexpr int StreamCaptureModeRelaxed = 2; + + static int StreamBeginCapture(Stream, int) { return 1; } + + static int StreamEndCapture(Stream, Graph*) { return 1; } + + static int GraphDestroy(Graph) { return 1; } + + static int GraphInstantiate(GraphExec*, Graph) { return 1; } + + static int GraphExecDestroy(GraphExec) { return 1; } + + static int GraphLaunch(GraphExec, Stream) { return 1; } }; static_assert(Runtime::Validate()); diff --git a/src/native/cambricon/runtime_.h b/src/native/cambricon/runtime_.h index 4db4920..76e6a2d 100644 --- a/src/native/cambricon/runtime_.h +++ b/src/native/cambricon/runtime_.h @@ -16,6 +16,10 @@ struct Runtime : DeviceRuntime> { using Stream = cnrtQueue_t; + using Graph = void*; + + using GraphExec = void*; + static constexpr Device::Type kDeviceType = Device::Type::kCambricon; static constexpr auto SetDevice = cnrtSetDevice; @@ -41,6 +45,11 @@ struct Runtime return cnrtMemcpy(dst, const_cast(src), size, kind); }; + static auto MemcpyAsync(void* dst, const void* src, std::size_t size, + cnrtMemTransDir_t kind, Stream stream) { + return cnrtMemcpyAsync_V2(dst, const_cast(src), size, stream, kind); + } + static constexpr auto MemcpyHostToHost = cnrtMemcpyHostToHost; static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev; @@ -50,6 +59,31 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = cnrtMemcpyDevToDev; static constexpr auto Memset = cnrtMemset; + + static constexpr auto StreamCreate = cnrtQueueCreate; + + static constexpr auto StreamDestroy = cnrtQueueDestroy; + + static constexpr auto StreamSynchronize = cnrtQueueSync; + + static constexpr auto StreamCaptureModeGlobal = cnrtQueueCaptureModeGlobal; + + static constexpr auto StreamCaptureModeThreadLocal = + cnrtQueueCaptureModeThreadLocal; + + static constexpr auto StreamCaptureModeRelaxed = cnrtQueueCaptureModeRelaxed; + + static int StreamBeginCapture(Stream, cnrtQueueCaptureMode_t) { return 1; } + + static int StreamEndCapture(Stream, Graph*) { return 1; } + + static int GraphDestroy(Graph) { return 1; } + + static int GraphInstantiate(GraphExec*, Graph) { return 1; } + + static int GraphExecDestroy(GraphExec) { return 1; } + + static int GraphLaunch(GraphExec, Stream) { return 1; } }; static_assert(Runtime::Validate()); diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index bf5a81c..3946d95 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -11,6 +11,12 @@ namespace infini::rt { template <> struct Runtime : RuntimeBase> { + using Stream = void*; + + using Graph = void*; + + using GraphExec = void*; + static constexpr Device::Type kDeviceType = Device::Type::kCpu; static void SetDevice(int index) { @@ -40,8 +46,8 @@ struct Runtime : RuntimeBase> { std::memcpy(dst, src, size); } - static void Memset(void* ptr, int value, std::size_t count) { - std::memset(ptr, value, count); + static int MemcpyAsync(void*, const void*, std::size_t, int, Stream) { + return 1; } static constexpr int MemcpyHostToHost = 0; @@ -51,6 +57,34 @@ struct Runtime : RuntimeBase> { static constexpr int MemcpyDeviceToHost = 1; static constexpr int MemcpyDeviceToDevice = 0; + + static void Memset(void* ptr, int value, std::size_t count) { + std::memset(ptr, value, count); + } + + static int StreamCreate(Stream*) { return 1; } + + static int StreamDestroy(Stream) { return 1; } + + static int StreamSynchronize(Stream) { return 1; } + + static constexpr int StreamCaptureModeGlobal = 0; + + static constexpr int StreamCaptureModeThreadLocal = 1; + + static constexpr int StreamCaptureModeRelaxed = 2; + + static int StreamBeginCapture(Stream, int) { return 1; } + + static int StreamEndCapture(Stream, Graph*) { return 1; } + + static int GraphDestroy(Graph) { return 1; } + + static int GraphInstantiate(GraphExec*, Graph) { return 1; } + + static int GraphExecDestroy(GraphExec) { return 1; } + + static int GraphLaunch(GraphExec, Stream) { return 1; } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index 8a1b649..81a5cd2 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -17,6 +17,10 @@ struct Runtime : CudaRuntime> { using Stream = cudaStream_t; + using Graph = cudaGraph_t; + + using GraphExec = cudaGraphExec_t; + static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; static constexpr auto SetDevice = cudaSetDevice; @@ -44,6 +48,54 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreateWithFlags(std::forward(args)..., + cudaStreamNonBlocking); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return cudaMemcpyAsync(std::forward(args)...); + }; + + static constexpr auto StreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + + static constexpr auto StreamCaptureModeThreadLocal = + cudaStreamCaptureModeThreadLocal; + + static constexpr auto StreamCaptureModeRelaxed = cudaStreamCaptureModeRelaxed; + + static constexpr auto StreamBeginCapture = [](auto&&... args) { + return cudaStreamBeginCapture(std::forward(args)...); + }; + + static constexpr auto StreamEndCapture = [](auto&&... args) { + return cudaStreamEndCapture(std::forward(args)...); + }; + + static constexpr auto GraphDestroy = [](auto&&... args) { + return cudaGraphDestroy(std::forward(args)...); + }; + + static constexpr auto GraphInstantiate = [](auto&&... args) { + return cudaGraphInstantiate(std::forward(args)...); + }; + + static constexpr auto GraphExecDestroy = [](auto&&... args) { + return cudaGraphExecDestroy(std::forward(args)...); + }; + + static constexpr auto GraphLaunch = [](auto&&... args) { + return cudaGraphLaunch(std::forward(args)...); + }; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index 5785a51..3e7fb5c 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -3,6 +3,7 @@ #include +#include #include #include "native/cuda/metax/device_.h" @@ -15,6 +16,10 @@ struct Runtime : CudaRuntime> { using Stream = mcStream_t; + using Graph = void*; + + using GraphExec = void*; + static constexpr Device::Type kDeviceType = Device::Type::kMetax; static constexpr auto SetDevice = mcSetDevice; @@ -46,6 +51,35 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = mcMemcpyDeviceToDevice; static constexpr auto Memset = mcMemset; + + static int StreamCreate(Stream*) { return 1; } + + static int StreamDestroy(Stream) { return 1; } + + static int StreamSynchronize(Stream) { return 1; } + + static int MemcpyAsync(void*, const void*, std::size_t, + decltype(MemcpyHostToDevice), Stream) { + return 1; + } + + static constexpr int StreamCaptureModeGlobal = 0; + + static constexpr int StreamCaptureModeThreadLocal = 1; + + static constexpr int StreamCaptureModeRelaxed = 2; + + static int StreamBeginCapture(Stream, int) { return 1; } + + static int StreamEndCapture(Stream, Graph*) { return 1; } + + static int GraphDestroy(Graph) { return 1; } + + static int GraphInstantiate(GraphExec*, Graph) { return 1; } + + static int GraphExecDestroy(GraphExec) { return 1; } + + static int GraphLaunch(GraphExec, Stream) { return 1; } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index 8ced2ed..88702ff 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -3,6 +3,7 @@ #include +#include #include #include "native/cuda/moore/device_.h" @@ -15,6 +16,10 @@ struct Runtime : CudaRuntime> { using Stream = musaStream_t; + using Graph = void*; + + using GraphExec = void*; + static constexpr Device::Type kDeviceType = Device::Type::kMoore; static constexpr auto SetDevice = musaSetDevice; @@ -52,6 +57,35 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = musaMemcpyDeviceToDevice; static constexpr auto Memset = musaMemset; + + static int StreamCreate(Stream*) { return 1; } + + static int StreamDestroy(Stream) { return 1; } + + static int StreamSynchronize(Stream) { return 1; } + + static int MemcpyAsync(void*, const void*, std::size_t, + decltype(MemcpyHostToDevice), Stream) { + return 1; + } + + static constexpr int StreamCaptureModeGlobal = 0; + + static constexpr int StreamCaptureModeThreadLocal = 1; + + static constexpr int StreamCaptureModeRelaxed = 2; + + static int StreamBeginCapture(Stream, int) { return 1; } + + static int StreamEndCapture(Stream, Graph*) { return 1; } + + static int GraphDestroy(Graph) { return 1; } + + static int GraphInstantiate(GraphExec*, Graph) { return 1; } + + static int GraphExecDestroy(GraphExec) { return 1; } + + static int GraphLaunch(GraphExec, Stream) { return 1; } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index f6a9f2d..68651ad 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -17,6 +17,10 @@ struct Runtime : CudaRuntime> { using Stream = cudaStream_t; + using Graph = cudaGraph_t; + + using GraphExec = cudaGraphExec_t; + static constexpr Device::Type kDeviceType = Device::Type::kNvidia; static constexpr auto SetDevice = cudaSetDevice; @@ -44,6 +48,94 @@ struct Runtime static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; static constexpr auto Memset = cudaMemset; + + static constexpr auto StreamCreate = [](auto&&... args) { + return cudaStreamCreateWithFlags(std::forward(args)..., + cudaStreamNonBlocking); + }; + + static constexpr auto StreamDestroy = [](auto&&... args) { + return cudaStreamDestroy(std::forward(args)...); + }; + + static constexpr auto StreamSynchronize = [](auto&&... args) { + return cudaStreamSynchronize(std::forward(args)...); + }; + + static constexpr auto MemcpyAsync = [](auto&&... args) { + return cudaMemcpyAsync(std::forward(args)...); + }; + + static constexpr auto StreamCaptureModeGlobal = cudaStreamCaptureModeGlobal; + + static constexpr auto StreamCaptureModeThreadLocal = + cudaStreamCaptureModeThreadLocal; + + static constexpr auto StreamCaptureModeRelaxed = cudaStreamCaptureModeRelaxed; + + static constexpr auto StreamBeginCapture = [](auto&&... args) { + return cudaStreamBeginCapture(std::forward(args)...); + }; + + static constexpr auto StreamEndCapture = [](auto&&... args) { + return cudaStreamEndCapture(std::forward(args)...); + }; + + static constexpr auto GraphDestroy = [](auto&&... args) { + return cudaGraphDestroy(std::forward(args)...); + }; + + static constexpr auto GraphInstantiate = [](auto&&... args) { + return cudaGraphInstantiate(std::forward(args)...); + }; + + static constexpr auto GraphExecDestroy = [](auto&&... args) { + return cudaGraphExecDestroy(std::forward(args)...); + }; + + static constexpr auto GraphLaunch = [](auto&&... args) { + return cudaGraphLaunch(std::forward(args)...); + }; + + static constexpr bool Validate() { + CudaRuntime>::Validate(); + static_assert(sizeof(Graph) > 0, + "`Runtime` must define a `Graph` type alias."); + static_assert(sizeof(GraphExec) > 0, + "`Runtime` must define a `GraphExec` type alias."); + static_assert(std::is_invocable_v, + "`Runtime::StreamCreate` must be callable with `(Stream*)`."); + static_assert(std::is_invocable_v, + "`Runtime::StreamDestroy` must be callable with `(Stream)`."); + static_assert( + std::is_invocable_v, + "`Runtime::StreamSynchronize` must be callable with `(Stream)`."); + static_assert(std::is_invocable_v, + "`Runtime::MemcpyAsync` must be callable with " + "`(void*, const void*, size_t, cudaMemcpyKind, Stream)`."); + static_assert(std::is_invocable_v, + "`Runtime::StreamBeginCapture` must be callable with " + "`(Stream, cudaStreamCaptureMode)`."); + static_assert( + std::is_invocable_v, + "`Runtime::StreamEndCapture` must be callable with " + "`(Stream, Graph*)`."); + static_assert(std::is_invocable_v, + "`Runtime::GraphDestroy` must be callable with `(Graph)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphInstantiate` must be callable with " + "`(GraphExec*, Graph)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphExecDestroy` must be callable with `(GraphExec)`."); + static_assert( + std::is_invocable_v, + "`Runtime::GraphLaunch` must be callable with `(GraphExec, Stream)`."); + return true; + } }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/runtime_.h b/src/native/cuda/runtime_.h index 8765a05..1634d0b 100644 --- a/src/native/cuda/runtime_.h +++ b/src/native/cuda/runtime_.h @@ -1,6 +1,7 @@ #ifndef INFINI_RT_CUDA_RUNTIME_H_ #define INFINI_RT_CUDA_RUNTIME_H_ +#include #include #include "runtime.h" @@ -17,7 +18,7 @@ struct CudaRuntime : DeviceRuntime { DeviceRuntime::Validate(); static_assert( std::is_invocable_v, + std::size_t, decltype(Derived::MemcpyHostToDevice)>, "`Runtime::Memcpy` must be callable with " "`(void*, const void*, size_t, MemcpyHostToDevice)`."); return true; diff --git a/src/runtime.h b/src/runtime.h index ebc2698..b81ca92 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -5,6 +5,7 @@ #include #include "device.h" +#include "graph.h" namespace infini::rt { @@ -70,9 +71,30 @@ void Malloc(void** ptr, std::size_t size); void Free(void* ptr); +void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); + +void MemcpyAsync(void* dst, const void* src, std::size_t count, MemcpyKind kind, + Stream stream); + void Memset(void* ptr, int value, std::size_t count); -void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); +void StreamCreate(Stream* stream); + +void StreamDestroy(Stream stream); + +void StreamSynchronize(Stream stream); + +void StreamBeginCapture(Stream stream, StreamCaptureMode mode); + +void StreamEndCapture(Stream stream, Graph* graph); + +void GraphDestroy(Graph graph); + +void GraphInstantiate(GraphExec* graph_exec, Graph graph); + +void GraphExecDestroy(GraphExec graph_exec); + +void GraphLaunch(GraphExec graph_exec, Stream stream); } // namespace infini::rt diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ab54530..0d765c6 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -17,6 +17,8 @@ endif() if(WITH_NVIDIA) add_infini_rt_test(test_nvidia_runtime test_nvidia_runtime.cc) + add_infini_rt_test(test_nvidia_graph test_nvidia_graph.cc) + add_infini_rt_test(test_nvidia_graph_c_api test_nvidia_graph_c_api.cc) endif() set(INFINI_RT_TEST_INSTALL_PREFIX diff --git a/tests/test_nvidia_graph.cc b/tests/test_nvidia_graph.cc new file mode 100644 index 0000000..47f131b --- /dev/null +++ b/tests/test_nvidia_graph.cc @@ -0,0 +1,96 @@ +#include + +#include +#include +#include + +#include "test_helper.h" + +namespace { + +void FillPattern(std::array* input, std::uint8_t salt) { + for (std::size_t i = 0; i < input->size(); ++i) { + (*input)[i] = static_cast(i * 13 + salt); + } +} + +bool CopyDeviceToHostAndValidate(infini::rt::test::TestContext* context, + void* device_ptr, + const std::array& expected, + std::string_view message) { + std::array output{}; + infini::rt::Memcpy(output.data(), device_ptr, output.size(), + infini::rt::MemcpyKind::kDeviceToHost); + return context->ExpectEqual(output, expected, message); +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + const infini::rt::Device device{infini::rt::Device::Type::kNvidia, 0}; + + infini::rt::SetDevice(device); + + void* src = nullptr; + void* dst = nullptr; + infini::rt::Stream stream; + infini::rt::Graph graph; + infini::rt::GraphExec graph_exec; + + std::array capture_input{}; + FillPattern(&capture_input, 7); + + infini::rt::Malloc(&src, capture_input.size()); + infini::rt::Malloc(&dst, capture_input.size()); + infini::rt::StreamCreate(&stream); + + context.Expect(stream.device_type() == infini::rt::Device::Type::kNvidia, + "Stream should remember its NVIDIA device type."); + + infini::rt::Memcpy(src, capture_input.data(), capture_input.size(), + infini::rt::MemcpyKind::kHostToDevice); + infini::rt::Memset(dst, 0, capture_input.size()); + + infini::rt::StreamBeginCapture(stream, + infini::rt::StreamCaptureMode::kRelaxed); + infini::rt::MemcpyAsync(dst, src, capture_input.size(), + infini::rt::MemcpyKind::kDeviceToDevice, stream); + infini::rt::StreamEndCapture(stream, &graph); + + context.Expect(graph.device_type() == infini::rt::Device::Type::kNvidia, + "Graph should remember its NVIDIA device type."); + + infini::rt::GraphInstantiate(&graph_exec, graph); + context.Expect(graph_exec.device_type() == infini::rt::Device::Type::kNvidia, + "GraphExec should remember its NVIDIA device type."); + + std::array replay_input_1{}; + std::array replay_input_2{}; + FillPattern(&replay_input_1, 31); + FillPattern(&replay_input_2, 53); + + infini::rt::Memcpy(src, replay_input_1.data(), replay_input_1.size(), + infini::rt::MemcpyKind::kHostToDevice); + infini::rt::Memset(dst, 0, replay_input_1.size()); + infini::rt::GraphLaunch(graph_exec, stream); + infini::rt::StreamSynchronize(stream); + CopyDeviceToHostAndValidate(&context, dst, replay_input_1, + "First graph replay should copy D2D data."); + + infini::rt::Memcpy(src, replay_input_2.data(), replay_input_2.size(), + infini::rt::MemcpyKind::kHostToDevice); + infini::rt::Memset(dst, 0, replay_input_2.size()); + infini::rt::GraphLaunch(graph_exec, stream); + infini::rt::StreamSynchronize(stream); + CopyDeviceToHostAndValidate(&context, dst, replay_input_2, + "Second graph replay should copy D2D data."); + + infini::rt::GraphExecDestroy(graph_exec); + infini::rt::GraphDestroy(graph); + infini::rt::StreamDestroy(stream); + infini::rt::Free(dst); + infini::rt::Free(src); + + return context.ExitCode(); +} diff --git a/tests/test_nvidia_graph_c_api.cc b/tests/test_nvidia_graph_c_api.cc new file mode 100644 index 0000000..1a1b2cd --- /dev/null +++ b/tests/test_nvidia_graph_c_api.cc @@ -0,0 +1,133 @@ +#include +#include + +#include +#include +#include + +#include "test_helper.h" + +namespace { + +bool ExpectCudaSuccess(infini::rt::test::TestContext* context, + cudaError_t status, std::string_view message) { + return context->Expect(status == cudaSuccess, message); +} + +bool ExpectRtSuccess(infini::rt::test::TestContext* context, + infiniRtStatus_t status, std::string_view message) { + return context->Expect(status == INFINI_RT_STATUS_SUCCESS, message); +} + +void FillPattern(std::array* input, std::uint8_t salt) { + for (std::size_t i = 0; i < input->size(); ++i) { + (*input)[i] = static_cast(i * 13 + salt); + } +} + +bool CopyDeviceToHostAndValidate(infini::rt::test::TestContext* context, + void* device_ptr, + const std::array& expected, + std::string_view message) { + std::array output{}; + if (!ExpectCudaSuccess(context, + cudaMemcpy(output.data(), device_ptr, output.size(), + cudaMemcpyDeviceToHost), + "Failed to copy device output to host.")) { + return false; + } + return context->ExpectEqual(output, expected, message); +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + + cudaStream_t native_stream = nullptr; + void* src = nullptr; + void* dst = nullptr; + infiniRtStream_t stream = nullptr; + infiniRtGraph_t graph = nullptr; + infiniRtGraphExec_t graph_exec = nullptr; + + const auto device = infiniRtDevice_t{INFINI_RT_DEVICE_NVIDIA, 0}; + + ExpectCudaSuccess(&context, cudaSetDevice(device.index), + "Failed to set CUDA device."); + ExpectCudaSuccess( + &context, + cudaStreamCreateWithFlags(&native_stream, cudaStreamNonBlocking), + "Failed to create CUDA stream."); + + std::array capture_input{}; + FillPattern(&capture_input, 7); + + ExpectCudaSuccess(&context, cudaMalloc(&src, capture_input.size()), + "Failed to allocate source buffer."); + ExpectCudaSuccess(&context, cudaMalloc(&dst, capture_input.size()), + "Failed to allocate destination buffer."); + ExpectCudaSuccess(&context, + cudaMemcpy(src, capture_input.data(), capture_input.size(), + cudaMemcpyHostToDevice), + "Failed to initialize source buffer."); + ExpectCudaSuccess(&context, cudaMemset(dst, 0, capture_input.size()), + "Failed to initialize destination buffer."); + + ExpectRtSuccess(&context, infiniRtStreamWrap(device, native_stream, &stream), + "Failed to wrap native CUDA stream."); + ExpectRtSuccess( + &context, + infiniRtStreamBeginCapture(stream, INFINI_RT_STREAM_CAPTURE_MODE_RELAXED), + "Failed to begin graph capture through C API."); + ExpectCudaSuccess(&context, + cudaMemcpyAsync(dst, src, capture_input.size(), + cudaMemcpyDeviceToDevice, native_stream), + "Failed to record device-to-device copy."); + ExpectRtSuccess(&context, infiniRtStreamEndCapture(stream, &graph), + "Failed to end graph capture through C API."); + ExpectRtSuccess(&context, infiniRtGraphInstantiate(&graph_exec, graph), + "Failed to instantiate graph through C API."); + + std::array replay_input{}; + FillPattern(&replay_input, 31); + + ExpectCudaSuccess(&context, + cudaMemcpy(src, replay_input.data(), replay_input.size(), + cudaMemcpyHostToDevice), + "Failed to refresh source buffer."); + ExpectCudaSuccess(&context, cudaMemset(dst, 0, replay_input.size()), + "Failed to clear destination buffer."); + ExpectRtSuccess(&context, infiniRtGraphLaunch(graph_exec, stream), + "Failed to launch graph through C API."); + ExpectCudaSuccess(&context, cudaStreamSynchronize(native_stream), + "Failed to synchronize CUDA stream."); + CopyDeviceToHostAndValidate(&context, dst, replay_input, + "C API graph replay should copy D2D data."); + + if (graph_exec != nullptr) { + ExpectRtSuccess(&context, infiniRtGraphExecDestroy(graph_exec), + "Failed to destroy graph exec through C API."); + } + if (graph != nullptr) { + ExpectRtSuccess(&context, infiniRtGraphDestroy(graph), + "Failed to destroy graph through C API."); + } + if (stream != nullptr) { + ExpectRtSuccess(&context, infiniRtStreamDestroy(stream), + "Failed to destroy wrapped stream through C API."); + } + if (dst != nullptr) { + ExpectCudaSuccess(&context, cudaFree(dst), + "Failed to free destination buffer."); + } + if (src != nullptr) { + ExpectCudaSuccess(&context, cudaFree(src), "Failed to free source buffer."); + } + if (native_stream != nullptr) { + ExpectCudaSuccess(&context, cudaStreamDestroy(native_stream), + "Failed to destroy CUDA stream."); + } + + return context.ExitCode(); +}