
Commit 672623f

[gpu] Support ahead of time offloading of tensors
Add support for ahead-of-time copies of tensors to the GPU target, relying on the MLIR runner runtime functions for GPUs.
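
The new GPUDevice code drives the GPU through the MLIR runner runtime library, resolving symbols such as mgpuStreamCreate, mgpuMemAlloc, mgpuMemcpy and mgpuStreamSynchronize with ctypes. The snippet below is a minimal sketch of that binding pattern in isolation; the library name is an assumed placeholder, since GPUDevice actually resolves the path from get_mlir_prefix() and loads it through its LibLoader helper.

import ctypes

# Placeholder path; GPUDevice builds it from get_mlir_prefix() and cuda_runtime_lib.
lib = ctypes.CDLL("libmlir_cuda_runtime.so")

# mgpuStreamCreate() -> opaque stream handle
stream_create = lib.mgpuStreamCreate
stream_create.argtypes = []
stream_create.restype = ctypes.c_void_p
stream = stream_create()

# mgpuMemAlloc(size_bytes, stream, is_host_shared) -> device pointer,
# matching the argtypes set in this commit.
mem_alloc = lib.mgpuMemAlloc
mem_alloc.argtypes = [ctypes.c_uint64, ctypes.c_void_p, ctypes.c_bool]
mem_alloc.restype = ctypes.c_void_p
device_ptr = mem_alloc(1024, stream, True)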

3 files changed

Lines changed: 250 additions & 13 deletions


src/xtc/runtimes/accelerator/gpu/GPUDevice.py

Lines changed: 73 additions & 7 deletions
@@ -51,6 +51,16 @@ def __init_once__(self):
             f"{get_mlir_prefix()}/lib/{cuda_runtime_lib}"
         )
         self.loaded_kernels: dict[Module, LibLoader] = {}
+        create_stream_func_name = "mgpuStreamCreate"
+        create_stream_func = getattr(
+            self._mlir_runtime_lib.lib, create_stream_func_name
+        )
+        assert create_stream_func is not None, (
+            f"Cannot find symbol {create_stream_func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        create_stream_func.argtypes = []
+        create_stream_func.restype = ctypes.c_voidp
+        self._custream = create_stream_func()

     def __get_runtime_func(self, name: str) -> Callable:
         if name in runtime_funcs:
@@ -120,33 +130,89 @@ def unload_module(self, module: Module) -> None:

     @override
     def memory_allocate(self, size_bytes: int) -> Any:
-        raise NotImplementedError("memory_allocate is not implemented for GPU device")
+        func_name = "mgpuMemAlloc"
+        func = getattr(self._mlir_runtime_lib.lib, func_name)
+        assert func is not None, (
+            f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        func.argtypes = [ctypes.c_uint64, ctypes.c_voidp, ctypes.c_bool]
+        func.restype = ctypes.c_voidp
+        return func(size_bytes, self._custream, True)

     @override
     def memory_free(self, handle: Any) -> None:
-        raise NotImplementedError("memory_free is not implemented for GPU device")
+        func_name = "mgpuMemFree"
+        func = getattr(self._mlir_runtime_lib.lib, func_name)
+        assert func is not None, (
+            f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        func.argtypes = [ctypes.c_voidp, ctypes.c_voidp]
+        func.restype = None
+        func(handle, self._custream)

     @override
     def memory_copy_to(
         self, acc_handle: Any, src: ctypes.c_void_p, size_bytes: int
     ) -> None:
-        raise NotImplementedError("memory_copy_to is not implemented for GPU device")
+        # Copy memory to accelerator device
+        func_name = "mgpuMemcpy"
+        func = getattr(self._mlir_runtime_lib.lib, func_name)
+        assert func is not None, (
+            f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        func.argtypes = [
+            ctypes.c_voidp,
+            ctypes.c_voidp,
+            ctypes.c_uint64,
+            ctypes.c_voidp,
+        ]
+        func.restype = None
+        func(acc_handle, src, size_bytes, self._custream)
+        # Synchronize stream
+        sync_stream_func_name = "mgpuStreamSynchronize"
+        sync_stream_func = getattr(self._mlir_runtime_lib.lib, sync_stream_func_name)
+        assert sync_stream_func is not None, (
+            f"Cannot find symbol {sync_stream_func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        sync_stream_func.argtypes = [ctypes.c_voidp]
+        sync_stream_func.restype = None
+        sync_stream_func(self._custream)

     @override
     def memory_copy_from(
         self, acc_handle: Any, dst: ctypes.c_void_p, size_bytes: int
     ) -> None:
-        raise NotImplementedError("memory_copy_from is not implemented for GPU device")
+        # Copy memory from accelerator device to host
+        func_name = "mgpuMemcpy"
+        func = getattr(self._mlir_runtime_lib.lib, func_name)
+        assert func is not None, (
+            f"Cannot find symbol {func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        func.argtypes = [
+            ctypes.c_voidp,
+            ctypes.c_voidp,
+            ctypes.c_uint64,
+            ctypes.c_voidp,
+        ]
+        func.restype = None
+        func(dst, acc_handle, size_bytes, self._custream)
+        # Synchronize stream
+        sync_stream_func_name = "mgpuStreamSynchronize"
+        sync_stream_func = getattr(self._mlir_runtime_lib.lib, sync_stream_func_name)
+        assert sync_stream_func is not None, (
+            f"Cannot find symbol {sync_stream_func_name} in lib {self._mlir_runtime_lib.lib}"
+        )
+        sync_stream_func.argtypes = [ctypes.c_voidp]
+        sync_stream_func.restype = None
+        sync_stream_func(self._custream)

     @override
     def memory_fill_zero(self, acc_handle: Any, size_bytes: int) -> None:
         raise NotImplementedError("memory_fill_zero is not implemented for GPU device")

     @override
     def memory_data_pointer(self, acc_handle: Any) -> ctypes.c_void_p:
-        raise NotImplementedError(
-            "memory_data_pointer is not implemented for GPU device"
-        )
+        return ctypes.cast(acc_handle, ctypes.c_void_p)

     @override
     def evaluate(
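
Taken together, these methods give the runtime an explicit allocate/copy/free path for device-resident tensors. The following is a minimal sketch, assuming GPUDevice is importable as in the new test below, of how the methods could be combined to offload a NumPy array ahead of time and read a result back; this round trip is purely illustrative and is not part of the commit.

import ctypes
import numpy as np

from xtc.runtimes.accelerator.gpu import GPUDevice

gpu = GPUDevice()

# Host tensor to offload ahead of time.
host = np.arange(16, dtype=np.float32)
nbytes = host.size * host.dtype.itemsize

# Allocate device memory and copy the tensor to the accelerator.
handle = gpu.memory_allocate(nbytes)
gpu.memory_copy_to(handle, host.ctypes.data_as(ctypes.c_void_p), nbytes)

# ... run a compiled module that reads/writes this device buffer ...

# Copy the (possibly updated) contents back to the host and free the device memory.
result = np.empty_like(host)
gpu.memory_copy_from(handle, result.ctypes.data_as(ctypes.c_void_p), nbytes)
gpu.memory_free(handle)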

src/xtc/targets/accelerator/gpu/GPUEvaluator.py

Lines changed: 16 additions & 6 deletions
@@ -66,10 +66,16 @@ def evaluate(self) -> tuple[list[float], int, str]:

         # Map the buffers
         # TODO Replace memory mapping of buffers by explicit transfers
-        for buffer in parameters[0] + parameters[1]:
-            self._device._register_buffer(
-                buffer.data, buffer.size * buffer.dtype.itemsize
-            )
+        for i, buffer in enumerate(parameters[0]):
+            if self._np_inputs_spec()[i]["device"] is None:
+                self._device._register_buffer(
+                    buffer.data, buffer.size * buffer.dtype.itemsize
+                )
+        for i, buffer in enumerate(parameters[1]):
+            if self._np_outputs_spec()[i]["device"] is None:
+                self._device._register_buffer(
+                    buffer.data, buffer.size * buffer.dtype.itemsize
+                )

         # Check the correctness of the outputs
         if self._validate:
@@ -89,8 +95,12 @@ def evaluate(self) -> tuple[list[float], int, str]:
             )

         # Unmap the buffers
-        for buffer in parameters[0] + parameters[1]:
-            self._device._unregister_buffer(buffer.data)
+        for i, buffer in enumerate(parameters[0]):
+            if self._np_inputs_spec()[i]["device"] is None:
+                self._device._unregister_buffer(buffer.data)
+        for i, buffer in enumerate(parameters[1]):
+            if self._np_outputs_spec()[i]["device"] is None:
+                self._device._unregister_buffer(buffer.data)

         # Unload the module
         self._device.unload_module(self._module)
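
With this change the evaluator only registers (memory-maps) buffers whose spec reports no device; tensors declared with device=gpu are left to the ahead-of-time offload path instead. A small self-contained sketch of that filtering pattern follows, with a hypothetical spec list standing in for the evaluator's _np_inputs_spec().

# Hypothetical stand-ins for the evaluator's buffer specs: entries whose
# "device" is None stay on the host and get registered (mapped), while
# device-resident entries are skipped.
inputs_spec = [
    {"name": "A", "device": None},   # host tensor -> register/map
    {"name": "B", "device": "gpu"},  # device tensor -> offloaded ahead of time
]

def buffers_to_register(buffers, spec):
    # Mirror of the loop added in GPUEvaluator.evaluate(): keep only
    # host-resident buffers.
    return [buf for i, buf in enumerate(buffers) if spec[i]["device"] is None]

print(buffers_to_register(["buf_A", "buf_B"], inputs_spec))  # ['buf_A']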
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
+# RUN: python %s 2>&1 | filecheck %s
+# REQUIRES: mlir-target=nvgpu
+
+import xtc.graphs.xtc.op as O
+from xtc.backends.mlir.MlirGraphBackend import MlirGraphBackend as Backend
+
+from xtc.runtimes.accelerator.gpu import GPUDevice
+
+# Create device
+gpu = GPUDevice()
+
+I, J, K, dtype = 4, 32, 512, "float32"
+a = O.tensor((I, K), dtype, name="A")  # A lives on the host
+b = O.tensor((K, J), dtype, name="B", device=gpu)  # B lives on the accelerator
+
+with O.graph(name="matmul") as gb:
+    O.matmul(a, b, name="C", device=gpu)  # C must live on the accelerator
+
+graph = gb.graph
+print(graph)
+
+impl = Backend(graph)
+
+sch = impl.get_scheduler()
+sch.tile("i", {"i1": 2})
+sch.tile("j", {"j1": 16})
+sch.unroll({"i1": 2})
+sch.parallelize(["i"])
+sched = sch.schedule()
+
+comp = impl.get_compiler(
+    target=gpu,
+    shared_lib=True,
+    dump_file="gpu_matmul_mlir_offload_tensor",
+    print_source_ir=True,
+    print_transformed_ir=True,
+)
+module = comp.compile(sched)
+executor = module.get_executor(validate=True)
+res = executor.execute()
+print(f"CODE: {res}")
+# CHECK: // -----// IR Dump Before transform //----- //
+# CHECK-NEXT: module attributes {transform.with_named_sequence} {
+# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) {
+# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%arg2 : memref<4x32xf32>)
+# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%arg0, %arg1 : memref<4x512xf32>, memref<512x32xf32>) outs(%arg2 : memref<4x32xf32>)
+# CHECK-NEXT: return
+# CHECK-NEXT: }
+# CHECK-NEXT: transform.named_sequence @_vecto(%arg0: !transform.any_op {transform.consumed}) {
+# CHECK-NEXT: transform.structured.vectorize %arg0 : !transform.any_op
+# CHECK-NEXT: transform.yield
+# CHECK-NEXT: }
+# CHECK-NEXT: transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+# CHECK-NEXT: %0 = transform.structured.match attributes {__xtc_id_C_0_} in %arg0 : (!transform.any_op) -> !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op, %loops = transform.structured.tile_using_for %0 tile_sizes [1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops "./i" : !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op_0, %loops_1 = transform.structured.tile_using_for %tiled_linalg_op tile_sizes [0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops_1 "./j" : !transform.any_op
+# CHECK-NEXT: %1 = transform.structured.match attributes {__xtc_id_C_} in %arg0 : (!transform.any_op) -> !transform.any_op
+# CHECK-NEXT: %tiled_op, %forall_op = transform.structured.tile_using_forall %1 tile_sizes [2, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %forall_op "./i" : !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op_2, %loops_3 = transform.structured.tile_using_for %tiled_op tile_sizes [0, 16, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops_3 "./j" : !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op_4, %loops_5 = transform.structured.tile_using_for %tiled_linalg_op_2 tile_sizes [0, 0, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops_5 "./k" : !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op_6, %loops_7 = transform.structured.tile_using_for %tiled_linalg_op_4 tile_sizes [1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops_7 "./i1" : !transform.any_op
+# CHECK-NEXT: %tiled_linalg_op_8, %loops_9 = transform.structured.tile_using_for %tiled_linalg_op_6 tile_sizes [0, 1, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+# CHECK-NEXT: transform.annotate %loops_9 "./j1" : !transform.any_op
+# CHECK-NEXT: transform.loop.unroll %loops_7 {factor = 2 : i64} : !transform.any_op
+# CHECK-NEXT: transform.yield
+# CHECK-NEXT: }
+# CHECK-NEXT: }
+# CHECK-NEXT:
+# CHECK-NEXT: // -----// IR Dump After transform //----- //
+# CHECK-NEXT: #map = affine_map<(d0) -> (d0 * 2)>
+# CHECK-NEXT: module attributes {transform.with_named_sequence} {
+# CHECK-NEXT: func.func @matmul(%arg0: memref<4x512xf32> {llvm.noalias}, %arg1: memref<512x32xf32> {llvm.noalias, memref.on_device}, %arg2: memref<4x32xf32> {llvm.noalias, memref.on_device}) {
+# CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f32
+# CHECK-NEXT: %c0 = arith.constant 0 : index
+# CHECK-NEXT: %c4 = arith.constant 4 : index
+# CHECK-NEXT: %c1 = arith.constant 1 : index
+# CHECK-NEXT: scf.for %arg3 = %c0 to %c4 step %c1 {
+# CHECK-NEXT: %subview = memref.subview %arg2[%arg3, 0] [1, 32] [1, 1] : memref<4x32xf32> to memref<1x32xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_0 = arith.constant 0 : index
+# CHECK-NEXT: %c32 = arith.constant 32 : index
+# CHECK-NEXT: %c1_1 = arith.constant 1 : index
+# CHECK-NEXT: scf.for %arg4 = %c0_0 to %c32 step %c1_1 {
+# CHECK-NEXT: %subview_2 = memref.subview %subview[0, %arg4] [1, 1] [1, 1] : memref<1x32xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: linalg.fill {__xtc_id_C_0_} ins(%cst : f32) outs(%subview_2 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
+# CHECK-NEXT: } {"./j"}
+# CHECK-NEXT: } {"./i"}
+# CHECK-NEXT: scf.forall (%arg3) in (2) {
+# CHECK-NEXT: %0 = affine.apply #map(%arg3)
+# CHECK-NEXT: %subview = memref.subview %arg0[%0, 0] [2, 512] [1, 1] : memref<4x512xf32> to memref<2x512xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_0 = memref.subview %arg1[0, 0] [512, 32] [1, 1] : memref<512x32xf32> to memref<512x32xf32, strided<[32, 1]>>
+# CHECK-NEXT: %subview_1 = memref.subview %arg2[%0, 0] [2, 32] [1, 1] : memref<4x32xf32> to memref<2x32xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_2 = arith.constant 0 : index
+# CHECK-NEXT: %c32 = arith.constant 32 : index
+# CHECK-NEXT: %c16 = arith.constant 16 : index
+# CHECK-NEXT: scf.for %arg4 = %c0_2 to %c32 step %c16 {
+# CHECK-NEXT: %subview_3 = memref.subview %subview[0, 0] [2, 512] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x512xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_4 = memref.subview %subview_0[0, %arg4] [512, 16] [1, 1] : memref<512x32xf32, strided<[32, 1]>> to memref<512x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_5 = memref.subview %subview_1[0, %arg4] [2, 16] [1, 1] : memref<2x32xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_6 = arith.constant 0 : index
+# CHECK-NEXT: %c512 = arith.constant 512 : index
+# CHECK-NEXT: %c1_7 = arith.constant 1 : index
+# CHECK-NEXT: scf.for %arg5 = %c0_6 to %c512 step %c1_7 {
+# CHECK-NEXT: %subview_8 = memref.subview %subview_3[0, %arg5] [2, 1] [1, 1] : memref<2x512xf32, strided<[512, 1], offset: ?>> to memref<2x1xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_9 = memref.subview %subview_4[%arg5, 0] [1, 16] [1, 1] : memref<512x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_10 = memref.subview %subview_5[0, 0] [2, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<2x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_11 = arith.constant 0 : index
+# CHECK-NEXT: %c2 = arith.constant 2 : index
+# CHECK-NEXT: %c1_12 = arith.constant 1 : index
+# CHECK-NEXT: %c2_13 = arith.constant 2 : index
+# CHECK-NEXT: %subview_14 = memref.subview %subview_8[%c0_11, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_15 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_16 = memref.subview %subview_10[%c0_11, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_17 = arith.constant 0 : index
+# CHECK-NEXT: %c16_18 = arith.constant 16 : index
+# CHECK-NEXT: %c1_19 = arith.constant 1 : index
+# CHECK-NEXT: scf.for %arg6 = %c0_17 to %c16_18 step %c1_19 {
+# CHECK-NEXT: %subview_27 = memref.subview %subview_14[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_28 = memref.subview %subview_15[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_29 = memref.subview %subview_16[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
+# CHECK-NEXT: } {"./j1"}
+# CHECK-NEXT: %c1_20 = arith.constant 1 : index
+# CHECK-NEXT: %1 = arith.muli %c1_12, %c1_20 : index
+# CHECK-NEXT: %2 = arith.addi %c0_11, %1 : index
+# CHECK-NEXT: %subview_21 = memref.subview %subview_8[%2, 0] [1, 1] [1, 1] : memref<2x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_22 = memref.subview %subview_9[0, 0] [1, 16] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_23 = memref.subview %subview_10[%2, 0] [1, 16] [1, 1] : memref<2x16xf32, strided<[32, 1], offset: ?>> to memref<1x16xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %c0_24 = arith.constant 0 : index
+# CHECK-NEXT: %c16_25 = arith.constant 16 : index
+# CHECK-NEXT: %c1_26 = arith.constant 1 : index
+# CHECK-NEXT: scf.for %arg6 = %c0_24 to %c16_25 step %c1_26 {
+# CHECK-NEXT: %subview_27 = memref.subview %subview_21[0, 0] [1, 1] [1, 1] : memref<1x1xf32, strided<[512, 1], offset: ?>> to memref<1x1xf32, strided<[512, 1], offset: ?>>
+# CHECK-NEXT: %subview_28 = memref.subview %subview_22[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: %subview_29 = memref.subview %subview_23[0, %arg6] [1, 1] [1, 1] : memref<1x16xf32, strided<[32, 1], offset: ?>> to memref<1x1xf32, strided<[32, 1], offset: ?>>
+# CHECK-NEXT: linalg.matmul {__xtc_id_C_} ins(%subview_27, %subview_28 : memref<1x1xf32, strided<[512, 1], offset: ?>>, memref<1x1xf32, strided<[32, 1], offset: ?>>) outs(%subview_29 : memref<1x1xf32, strided<[32, 1], offset: ?>>)
+# CHECK-NEXT: } {"./j1"}
+# CHECK-NEXT: } {"./k"}
+# CHECK-NEXT: } {"./j"}
+# CHECK-NEXT: } {"./i"}
+# CHECK-NEXT: return
+# CHECK-NEXT: }
+# CHECK-NEXT: }
+# CHECK-NEXT:
+# CHECK-NEXT: graph:
+# CHECK-NEXT: name: matmul
+# CHECK-NEXT: inputs:
+# CHECK-NEXT: - %0 : 4x512xfloat32
+# CHECK-NEXT: - %1 : 512x32xfloat32
+# CHECK-NEXT: outputs:
+# CHECK-NEXT: - %2 : 4x32xfloat32
+# CHECK-NEXT: nodes:
+# CHECK-NEXT: - %2: matmul(%0, %1) {name = 'C'} : [4x512xfloat32, 512x32xfloat32] -> [4x32xfloat32]
+# CHECK-NEXT:
+# CHECK-NEXT: CODE: 0
