Skip to content

Commit 3dd8c85

Browse files
committed
Improve KI tests
1 parent 748f932 commit 3dd8c85

1 file changed

Lines changed: 31 additions & 30 deletions

File tree

test/intrinsics.jl

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
11
import KernelAbstractions.KernelIntrinsics as KI
22

3+
struct KernelData
4+
global_size::Int
5+
global_id::Int
6+
local_size::Int
7+
local_id::Int
8+
num_groups::Int
9+
group_id::Int
10+
end
311
function test_intrinsics_kernel(results)
4-
# Test all intrinsics return NamedTuples with x, y, z fields
5-
global_size = KI.get_global_size()
6-
global_id = KI.get_global_id()
7-
local_size = KI.get_local_size()
8-
local_id = KI.get_local_id()
9-
num_groups = KI.get_num_groups()
10-
group_id = KI.get_group_id()
11-
12-
if UInt32(global_id.x) <= UInt32(global_size.x)
13-
results[1, global_id.x] = global_id.x
14-
results[2, global_id.x] = local_id.x
15-
results[3, global_id.x] = group_id.x
16-
results[4, global_id.x] = global_size.x
17-
results[5, global_id.x] = local_size.x
18-
results[6, global_id.x] = num_groups.x
12+
i = KI.get_global_id().x
13+
14+
if i <= length(results)
15+
@inbounds results[i] = KernelData(KI.get_global_size().x,
16+
KI.get_global_id().x,
17+
KI.get_local_size().x,
18+
KI.get_local_id().x,
19+
KI.get_num_groups().x,
20+
KI.get_group_id().x)
1921
end
2022
return
2123
end
@@ -82,41 +84,40 @@ function intrinsics_testsuite(backend, AT)
8284
@test KI.multiprocessor_count(backend()) isa Int
8385

8486
# Test with small kernel
85-
N = 16
86-
results = AT(zeros(Int, 6, N))
87+
workgroupsize = 4
88+
numworkgroups = 4
89+
N = workgroupsize * numworkgroups
90+
results = AT(Vector{KernelData}(undef, N))
8791
kernel = KI.@kernel backend() launch = false test_intrinsics_kernel(results)
8892

8993
@test KI.kernel_max_work_group_size(kernel) isa Int
9094
@test KI.kernel_max_work_group_size(kernel; max_work_items = 1) == 1
9195

92-
kernel(results; workgroupsize = 4, numworkgroups = 4)
96+
kernel(results; workgroupsize, numworkgroups)
9397
KernelAbstractions.synchronize(backend())
9498

9599
host_results = Array(results)
96100

97101
# Verify results make sense
98-
for i in 1:N
99-
global_id_x, local_id_x, group_id_x, global_size_x, local_size_x, num_groups_x = host_results[:, i]
102+
for (i, k_data) in enumerate(host_results)
100103

101104
# Global IDs should be 1-based and sequential
102-
@test global_id_x == i
105+
@test k_data.global_id == i
103106

104107
# Global size should match our ndrange
105-
@test global_size_x == N
108+
@test k_data.global_size == N
106109

107-
# Local size should be 4 (our workgroupsize)
108-
@test local_size_x == 4
110+
@test k_data.local_size == workgroupsize
109111

110-
# Number of groups should be ceil(N/4) = 4
111-
@test num_groups_x == 4
112+
@test k_data.num_groups == numworkgroups
112113

113114
# Group ID should be 1-based
114-
expected_group = div(i - 1, 4) + 1
115-
@test group_id_x == expected_group
115+
expected_group = div(i - 1, numworkgroups) + 1
116+
@test k_data.group_id == expected_group
116117

117118
# Local ID should be 1-based within group
118-
expected_local = ((i - 1) % 4) + 1
119-
@test local_id_x == expected_local
119+
expected_local = ((i - 1) % workgroupsize) + 1
120+
@test k_data.local_id == expected_local
120121
end
121122
end
122123
end

0 commit comments

Comments
 (0)