Skip to content

Commit d6dc85f

Browse files
committed
Improve KI tests
1 parent 04492ac commit d6dc85f

1 file changed

Lines changed: 33 additions & 30 deletions

File tree

test/intrinsics.jl

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
import KernelAbstractions.KernelIntrinsics as KI
22

3+
struct KernelData
4+
global_size::Int
5+
global_id::Int
6+
local_size::Int
7+
local_id::Int
8+
num_groups::Int
9+
group_id::Int
10+
end
311
function test_intrinsics_kernel(results)
4-
# Test all intrinsics return NamedTuples with x, y, z fields
5-
global_size = KI.get_global_size()
6-
global_id = KI.get_global_id()
7-
local_size = KI.get_local_size()
8-
local_id = KI.get_local_id()
9-
num_groups = KI.get_num_groups()
10-
group_id = KI.get_group_id()
11-
12-
if UInt32(global_id.x) <= UInt32(global_size.x)
13-
results[1, global_id.x] = global_id.x
14-
results[2, global_id.x] = local_id.x
15-
results[3, global_id.x] = group_id.x
16-
results[4, global_id.x] = global_size.x
17-
results[5, global_id.x] = local_size.x
18-
results[6, global_id.x] = num_groups.x
12+
i = KI.get_global_id().x
13+
14+
if i <= length(results)
15+
@inbounds results[i] = KernelData(
16+
KI.get_global_size().x,
17+
KI.get_global_id().x,
18+
KI.get_local_size().x,
19+
KI.get_local_id().x,
20+
KI.get_num_groups().x,
21+
KI.get_group_id().x
22+
)
1923
end
2024
return
2125
end
@@ -82,41 +86,40 @@ function intrinsics_testsuite(backend, AT)
8286
@test KI.multiprocessor_count(backend()) isa Int
8387

8488
# Test with small kernel
85-
N = 16
86-
results = AT(zeros(Int, 6, N))
89+
workgroupsize = 4
90+
numworkgroups = 4
91+
N = workgroupsize * numworkgroups
92+
results = AT(Vector{KernelData}(undef, N))
8793
kernel = KI.@kernel backend() launch = false test_intrinsics_kernel(results)
8894

8995
@test KI.kernel_max_work_group_size(kernel) isa Int
9096
@test KI.kernel_max_work_group_size(kernel; max_work_items = 1) == 1
9197

92-
kernel(results; workgroupsize = 4, numworkgroups = 4)
98+
kernel(results; workgroupsize, numworkgroups)
9399
KernelAbstractions.synchronize(backend())
94100

95101
host_results = Array(results)
96102

97103
# Verify results make sense
98-
for i in 1:N
99-
global_id_x, local_id_x, group_id_x, global_size_x, local_size_x, num_groups_x = host_results[:, i]
104+
for (i, k_data) in enumerate(host_results)
100105

101106
# Global IDs should be 1-based and sequential
102-
@test global_id_x == i
107+
@test k_data.global_id == i
103108

104109
# Global size should match our ndrange
105-
@test global_size_x == N
110+
@test k_data.global_size == N
106111

107-
# Local size should be 4 (our workgroupsize)
108-
@test local_size_x == 4
112+
@test k_data.local_size == workgroupsize
109113

110-
# Number of groups should be ceil(N/4) = 4
111-
@test num_groups_x == 4
114+
@test k_data.num_groups == numworkgroups
112115

113116
# Group ID should be 1-based
114-
expected_group = div(i - 1, 4) + 1
115-
@test group_id_x == expected_group
117+
expected_group = div(i - 1, numworkgroups) + 1
118+
@test k_data.group_id == expected_group
116119

117120
# Local ID should be 1-based within group
118-
expected_local = ((i - 1) % 4) + 1
119-
@test local_id_x == expected_local
121+
expected_local = ((i - 1) % workgroupsize) + 1
122+
@test k_data.local_id == expected_local
120123
end
121124
end
122125
end

0 commit comments

Comments
 (0)