Skip to content

Commit 554a5b1

Browse files
committed
Naive matrix multiplication
1 parent cba69fc commit 554a5b1

2 files changed

Lines changed: 289 additions & 0 deletions

File tree

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
### Dumb matrix multiplication
2+
### Simulate the CPU-style matrix multiplication with 1 GPU thread per row
3+
4+
from gpu.host import DeviceContext, HostBuffer
5+
from gpu import thread_idx, block_idx, block_dim
6+
import random
7+
from layout import Layout, LayoutTensor
8+
from memory import UnsafePointer, memcpy
9+
from python import Python, PythonObject
10+
from testing import assert_true
11+
12+
# Problem sizes: A is (64 x 16), B is (16 x 8), so C = A @ B is (64 x 8).
alias ROWS_A = 64
alias COLS_A = 16
alias ROWS_B = 16
alias COLS_B = 8
alias ROWS_C = ROWS_A
alias COLS_C = COLS_B

# Random-fill range for matrix elements (see fill_buffer).
alias MATRIX_MIN_ELEM = -5.0
alias MATRIX_MAX_ELEM = 5.0

alias dtype = DType.float32
# Num threads per block — one thread per row of C.
alias THREADS = ROWS_C
# Total number of blocks in the grid
alias BLOCKS = 1

# Row-major layouts describing how each matrix is stored in its buffer.
alias layout_a = Layout.row_major(ROWS_A, COLS_A)
alias layout_b = Layout.row_major(ROWS_B, COLS_B)
alias layout_c = Layout.row_major(ROWS_C, COLS_C)


# Tensor views over the device buffers.
alias MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
alias MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
alias MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]
36+
37+
38+
fn naive_matmul_one_thread_per_row[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    """One GPU thread computes one full row of C = A @ B.

    Thread `tid` owns row `tid` of A and C; C is assumed zero-filled
    by the caller since the kernel accumulates with `+=`.
    """
    var tid = block_idx.x * block_dim.x + thread_idx.x

    # Guard clause: threads past the last row have no work.
    if tid >= ROWS_A:
        return

    for col in range(COLS_B):
        for inner in range(COLS_A):
            C[tid, col] += A[tid, inner] * B[inner, col]
47+
48+
49+
# Initialize the matrix buffer with uniform random values in
# [MATRIX_MIN_ELEM, MATRIX_MAX_ELEM] (the old comment's "0 to 100" was wrong).
fn fill_buffer(buffer: HostBuffer[dtype]):
    # Re-seed from the system so each run gets fresh random matrices.
    random.seed()
    for i in range(len(buffer)):
        # random_float64 returns a float64 SIMD value; cast to the
        # buffer dtype and take the scalar lane.
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]
57+
58+
59+
fn main():
    """Run the one-thread-per-row matmul on the GPU and verify with numpy."""
    try:
        ctx = DeviceContext()

        # Zero-initialized device buffers; C must start at zero because
        # the kernel accumulates into it.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Randomize the inputs on the host.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # One block of ROWS_C threads: thread i computes row i of C.
        ctx.enqueue_function[
            naive_matmul_one_thread_per_row[layout_a, layout_b, layout_c]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Map all three buffers back to the host and check C == A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        # Fixed the typo'd message ("Prininting here: ").
        print("Error: ", e)
106+
107+
108+
fn assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Verify on the host, via numpy, that C is close to A @ B.

    Each argument is a (rows, cols, host_buffer) tuple. Raises if the
    comparison fails; prints a confirmation otherwise.
    """
    np = Python.import_module("numpy")

    rows_a, cols_a, host_a = buff_a_with_dims
    rows_b, cols_b, host_b = buff_b_with_dims
    rows_c, cols_c, host_c = buff_c_with_dims

    nd_a = reshape(to_ndarray(host_a), rows_a, cols_a)
    nd_b = reshape(to_ndarray(host_b), rows_b, cols_b)
    nd_c = reshape(to_ndarray(host_c), rows_c, cols_c)

    expected = np.matmul(nd_a, nd_b)
    assert_true(np.allclose(expected, nd_c))
    print("Assertion was successful")
124+
125+
126+
fn to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy a host buffer into a new 1-D float32 numpy array."""
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Renamed local: the original shadowed the module-level `ndarray_ptr`
    # helper with a variable of the same name.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray
133+
134+
135+
fn reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return a (rows, cols) view of the given numpy array."""
    shaped = ndarray.reshape(rows, cols)
    return shaped
137+
138+
139+
fn ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Extract the raw data pointer of a numpy array.

    Uses numpy's `__array_interface__`, whose "data" entry is a
    (address, read_only_flag) pair.
    """
    data_pair = ndarray.__array_interface__["data"]
    return data_pair[0].unsafe_get_as_pointer[dtype]()
143+
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
### Dumb matrix multiplication
2+
### Simulate the CPU-style triple for-loop truly dumb matrix multiplication
3+
4+
from gpu.host import DeviceContext, HostBuffer
5+
from gpu import thread_idx, block_idx, block_dim
6+
import random
7+
from layout import Layout, LayoutTensor
8+
from memory import UnsafePointer, memcpy
9+
from python import Python, PythonObject
10+
from testing import assert_true
11+
12+
# Problem sizes: A is (64 x 16), B is (16 x 8), so C = A @ B is (64 x 8).
alias ROWS_A = 64
alias COLS_A = 16
alias ROWS_B = 16
alias COLS_B = 8
alias ROWS_C = ROWS_A
alias COLS_C = COLS_B

# Random-fill range for matrix elements (see fill_buffer).
alias MATRIX_MIN_ELEM = -5.0
alias MATRIX_MAX_ELEM = 5.0

alias dtype = DType.float32
# Num threads per block — a single thread does the whole matmul.
alias THREADS = 1
# Total number of blocks in the grid
alias BLOCKS = 1

# Row-major layouts describing how each matrix is stored in its buffer.
alias layout_a = Layout.row_major(ROWS_A, COLS_A)
alias layout_b = Layout.row_major(ROWS_B, COLS_B)
alias layout_c = Layout.row_major(ROWS_C, COLS_C)


# Tensor views over the device buffers.
alias MatrixA = LayoutTensor[dtype, layout_a, MutableAnyOrigin]
alias MatrixB = LayoutTensor[dtype, layout_b, MutableAnyOrigin]
alias MatrixC = LayoutTensor[dtype, layout_c, MutableAnyOrigin]
36+
37+
38+
fn naive_matmul_single_thread_layout_tensor[
    a: Layout, b: Layout, c: Layout
](A: MatrixA, B: MatrixB, C: MatrixC,):
    """A single GPU thread runs the classic CPU triple loop: C = A @ B.

    Deliberately serial: only thread 0 does any work. C is assumed
    zero-filled by the caller since the kernel accumulates with `+=`.
    """
    var tid = block_idx.x * block_dim.x + thread_idx.x

    # Guard clause: every thread except 0 exits immediately.
    if tid != 0:
        return

    for row in range(ROWS_A):
        for col in range(COLS_B):
            for inner in range(COLS_A):
                C[row, col] += A[row, inner] * B[inner, col]
48+
49+
50+
# Initialize the matrix buffer with uniform random values in
# [MATRIX_MIN_ELEM, MATRIX_MAX_ELEM] (the old comment's "0 to 100" was wrong).
fn fill_buffer(buffer: HostBuffer[dtype]):
    # Re-seed from the system so each run gets fresh random matrices.
    random.seed()
    for i in range(len(buffer)):
        # random_float64 returns a float64 SIMD value; cast to the
        # buffer dtype and take the scalar lane.
        buffer[i] = random.random_float64(
            MATRIX_MIN_ELEM, MATRIX_MAX_ELEM
        ).cast[dtype]()[0]
58+
59+
60+
fn main():
    """Run the single-thread GPU matmul and verify the result with numpy."""
    try:
        ctx = DeviceContext()

        # Zero-initialized device buffers; C must start at zero because
        # the kernel accumulates into it.
        buffer_a = ctx.enqueue_create_buffer[dtype](
            ROWS_A * COLS_A
        ).enqueue_fill(0.0)
        buffer_b = ctx.enqueue_create_buffer[dtype](
            ROWS_B * COLS_B
        ).enqueue_fill(0.0)
        buffer_c = ctx.enqueue_create_buffer[dtype](
            ROWS_C * COLS_C
        ).enqueue_fill(0.0)

        # Randomize the inputs on the host.
        with buffer_a.map_to_host() as h_buffer_a:
            fill_buffer(h_buffer_a)

        with buffer_b.map_to_host() as h_buffer_b:
            fill_buffer(h_buffer_b)

        matrix_a = MatrixA(buffer_a)
        matrix_b = MatrixB(buffer_b)
        matrix_c = MatrixC(buffer_c)

        # 1 block x 1 thread: the whole matmul runs on a single thread.
        ctx.enqueue_function[
            naive_matmul_single_thread_layout_tensor[
                layout_a, layout_b, layout_c
            ]
        ](
            matrix_a,
            matrix_b,
            matrix_c,
            grid_dim=BLOCKS,
            block_dim=THREADS,
        )

        ctx.synchronize()

        # Map all three buffers back to the host and check C == A @ B.
        with buffer_a.map_to_host() as h_buffer_a:
            with buffer_b.map_to_host() as h_buffer_b:
                with buffer_c.map_to_host() as h_buffer_c:
                    assert_allclose(
                        (ROWS_A, COLS_A, h_buffer_a),
                        (ROWS_B, COLS_B, h_buffer_b),
                        (ROWS_C, COLS_C, h_buffer_c),
                    )

    except e:
        # Fixed the typo'd message ("Prininting here: ").
        print("Error: ", e)
109+
110+
111+
fn assert_allclose(
    buff_a_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_b_with_dims: (Int, Int, HostBuffer[dtype]),
    buff_c_with_dims: (Int, Int, HostBuffer[dtype]),
) raises:
    """Verify on the host, via numpy, that C is close to A @ B.

    Each argument is a (rows, cols, host_buffer) tuple. Raises if the
    comparison fails; prints a confirmation otherwise.
    """
    np = Python.import_module("numpy")

    rows_a, cols_a, host_a = buff_a_with_dims
    rows_b, cols_b, host_b = buff_b_with_dims
    rows_c, cols_c, host_c = buff_c_with_dims

    nd_a = reshape(to_ndarray(host_a), rows_a, cols_a)
    nd_b = reshape(to_ndarray(host_b), rows_b, cols_b)
    nd_c = reshape(to_ndarray(host_c), rows_c, cols_c)

    expected = np.matmul(nd_a, nd_b)
    assert_true(np.allclose(expected, nd_c))
    print("Assertion was successful")
127+
128+
129+
fn to_ndarray(buffer: HostBuffer[dtype]) raises -> PythonObject:
    """Copy a host buffer into a new 1-D float32 numpy array."""
    np = Python.import_module("numpy")
    ndarray = np.zeros(len(buffer), dtype=np.float32)
    # Renamed local: the original shadowed the module-level `ndarray_ptr`
    # helper with a variable of the same name.
    dst_ptr = ndarray_ptr[dtype](ndarray)
    src_ptr = buffer.unsafe_ptr()
    memcpy(dst_ptr, src_ptr, len(buffer))
    return ndarray
136+
137+
138+
fn reshape(ndarray: PythonObject, rows: Int, cols: Int) raises -> PythonObject:
    """Return a (rows, cols) view of the given numpy array."""
    shaped = ndarray.reshape(rows, cols)
    return shaped
140+
141+
142+
fn ndarray_ptr[
    dtype: DType
](ndarray: PythonObject) raises -> UnsafePointer[Scalar[dtype]]:
    """Extract the raw data pointer of a numpy array.

    Uses numpy's `__array_interface__`, whose "data" entry is a
    (address, read_only_flag) pair.
    """
    data_pair = ndarray.__array_interface__["data"]
    return data_pair[0].unsafe_get_as_pointer[dtype]()
146+

0 commit comments

Comments
 (0)