Skip to content

Commit 8eed1ea

Browse files
committed
Add support for 8x8x16 cooperative matrices (Intel Arc)
1 parent 8732232 commit 8eed1ea

65 files changed

Lines changed: 358 additions & 269 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

internal/RendererVK.cpp

Lines changed: 88 additions & 34 deletions
Large diffs are not rendered by default.

internal/Vk/ContextVK.cpp

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,9 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
213213
return false;
214214
}
215215

216-
CheckVkPhysicalDeviceFeatures(api_, physical_device_, device_properties_, mem_properties_, graphics_family_index_,
217-
raytracing_supported_, ray_query_supported_, fp16_supported_, int64_supported_,
218-
int64_atomics_supported_, coop_matrix_size_, pageable_memory_supported_);
216+
CheckVkPhysicalDeviceFeatures(api_, physical_device_, device_properties_, mem_properties_, coop_mat_properties_,
217+
graphics_family_index_, raytracing_supported_, ray_query_supported_, fp16_supported_,
218+
int64_supported_, int64_atomics_supported_, pageable_memory_supported_);
219219

220220
// mask out unsupported stages
221221
if (!raytracing_supported_) {
@@ -227,7 +227,7 @@ bool Ray::Vk::Context::Init(ILog *log, const VulkanDevice &vk_device, const Vulk
227227

228228
if (!external_ && !InitVkDevice(api_, device_, physical_device_, graphics_family_index_, raytracing_supported_,
229229
ray_query_supported_, fp16_supported_, int64_supported_, int64_atomics_supported_,
230-
coop_matrix_size_[0] != -1, pageable_memory_supported_, log)) {
230+
coop_mat_properties_.MSize != 0, pageable_memory_supported_, log)) {
231231
return false;
232232
}
233233

@@ -559,14 +559,12 @@ bool Ray::Vk::Context::ChooseVkPhysicalDevice(const Api &api, VkPhysicalDevice &
559559
return true;
560560
}
561561

562-
void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalDevice &physical_device,
563-
VkPhysicalDeviceProperties &out_device_properties,
564-
VkPhysicalDeviceMemoryProperties &out_mem_properties,
565-
uint32_t &out_graphics_family_index,
566-
bool &out_raytracing_supported, bool &out_ray_query_supported,
567-
bool &out_shader_fp16_supported, bool &out_shader_int64_supported,
568-
bool &out_int64_atomics_supported, int out_coop_matrix_size[3],
569-
bool &out_pageable_memory_supported) {
562+
void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(
563+
const Api &api, VkPhysicalDevice &physical_device, VkPhysicalDeviceProperties &out_device_properties,
564+
VkPhysicalDeviceMemoryProperties &out_mem_properties, VkCooperativeMatrixPropertiesKHR &out_coop_mat_properties,
565+
uint32_t &out_graphics_family_index, bool &out_raytracing_supported, bool &out_ray_query_supported,
566+
bool &out_shader_fp16_supported, bool &out_shader_int64_supported, bool &out_int64_atomics_supported,
567+
bool &out_pageable_memory_supported) {
570568
api.vkGetPhysicalDeviceProperties(physical_device, &out_device_properties);
571569
api.vkGetPhysicalDeviceMemoryProperties(physical_device, &out_mem_properties);
572570

@@ -594,7 +592,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
594592
shader_buf_int64_atomics_supported = false, memory_priority_supported = false,
595593
pageable_memory_supported = false;
596594

597-
int coop_matrix_size[3] = {-1, -1, -1};
595+
VkCooperativeMatrixPropertiesKHR coop_mat_properties = {};
598596

599597
{ // check for features support
600598
uint32_t extension_count;
@@ -671,18 +669,28 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
671669
coop_matrix_props.data());
672670

673671
bool found = false;
672+
// We try to find 16x16x16 size and F16 accumulator first (NV and AMD)
674673
for (const VkCooperativeMatrixPropertiesKHR &p : coop_matrix_props) {
675674
if (p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
676675
p.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
677676
p.MSize == 16 && p.NSize == 16 && p.KSize == 16 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
678-
coop_matrix_size[0] = 16;
679-
coop_matrix_size[1] = 16;
680-
coop_matrix_size[2] = 16;
677+
coop_mat_properties = p;
681678
found = true;
682679
break;
683680
}
684681
}
685-
coop_matrix_supported &= found;
682+
if (!found) {
683+
// Try to find 8x8x16 size and F32 accumulator (Intel)
684+
for (const VkCooperativeMatrixPropertiesKHR &p : coop_matrix_props) {
685+
if (p.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && p.BType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
686+
p.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && p.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
687+
p.MSize == 8 && p.NSize == 8 && p.KSize == 16 && p.scope == VK_SCOPE_SUBGROUP_KHR) {
688+
coop_mat_properties = p;
689+
found = true;
690+
break;
691+
}
692+
}
693+
}
686694
}
687695
}
688696

@@ -691,7 +699,7 @@ void Ray::Vk::Context::CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalD
691699
out_shader_fp16_supported = (shader_fp16_supported && storage_fp16_supported);
692700
out_shader_int64_supported = shader_int64_supported;
693701
out_int64_atomics_supported = shader_buf_int64_atomics_supported;
694-
memcpy(out_coop_matrix_size, coop_matrix_size, 3 * sizeof(int));
702+
out_coop_mat_properties = coop_mat_properties;
695703
out_pageable_memory_supported = (memory_priority_supported && pageable_memory_supported);
696704
}
697705

internal/Vk/ContextVK.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class Context {
2727
VkPhysicalDeviceLimits phys_device_limits_ = {};
2828
VkPhysicalDeviceProperties device_properties_ = {};
2929
VkPhysicalDeviceMemoryProperties mem_properties_ = {};
30+
VkCooperativeMatrixPropertiesKHR coop_mat_properties_ = {};
3031
uint32_t graphics_family_index_ = 0;
3132

3233
VkDevice device_ = {};
@@ -46,8 +47,6 @@ class Context {
4647

4748
bool subgroup_supported_ = false;
4849

49-
int coop_matrix_size_[3] = {-1, -1, -1};
50-
5150
bool pageable_memory_supported_ = false;
5251

5352
uint32_t supported_stages_mask_ = 0xffffffff;
@@ -94,14 +93,14 @@ class Context {
9493
bool int64_supported() const { return int64_supported_; }
9594
bool int64_atomics_supported() const { return int64_atomics_supported_; }
9695
bool subgroup_supported() const { return subgroup_supported_; }
97-
const int *coop_matrix_size() const { return coop_matrix_size_; }
9896

9997
uint32_t supported_stages_mask() const { return supported_stages_mask_; };
10098
bool image_blit_supported() const { return true; }
10199

102100
const VkPhysicalDeviceLimits &phys_device_limits() const { return phys_device_limits_; }
103101
const VkPhysicalDeviceProperties &device_properties() const { return device_properties_; }
104102
const VkPhysicalDeviceMemoryProperties &mem_properties() const { return mem_properties_; }
103+
const VkCooperativeMatrixPropertiesKHR &coop_mat_properties() const { return coop_mat_properties_; }
105104

106105
const VkPhysicalDeviceRayTracingPipelinePropertiesKHR &rt_props() const { return rt_props_; }
107106

@@ -153,10 +152,11 @@ class Context {
153152
static void CheckVkPhysicalDeviceFeatures(const Api &api, VkPhysicalDevice &physical_device,
154153
VkPhysicalDeviceProperties &device_properties,
155154
VkPhysicalDeviceMemoryProperties &mem_properties,
155+
VkCooperativeMatrixPropertiesKHR &coop_mat_properties,
156156
uint32_t &graphics_family_index, bool &out_raytracing_supported,
157157
bool &out_ray_query_supported, bool &out_shader_fp16_supported,
158158
bool &out_shader_int64_supported, bool &out_int64_atomics_supported,
159-
int out_coop_matrix_size[3], bool &out_pageable_memory_supported);
159+
bool &out_pageable_memory_supported);
160160
static bool InitVkDevice(const Api &api, VkDevice &device, VkPhysicalDevice physical_device,
161161
uint32_t graphics_family_index, bool enable_raytracing, bool enable_ray_query,
162162
bool enable_fp16, bool enable_int64, bool enable_int64_atomics, bool enable_coop_matrix,

internal/shaders/convolution.comp.glsl

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
#define f16vec4 vec4
1818
#endif
1919

20+
#ifndef C_TYPE
21+
#define C_TYPE float16_t
22+
#define C_CONVERT 0
23+
#else
24+
#define C_CONVERT 1
25+
#endif
26+
2027
layout(push_constant) uniform UniformParams {
2128
Params g_params;
2229
};
@@ -232,6 +239,11 @@ shared float16_t g_mat_staging1[16 * 16];
232239
shared float16_t g_mat_staging2[16 * 16];
233240
shared float16_t g_mat_staging3[16 * 16];
234241

242+
#ifdef COOP_M
243+
shared C_TYPE g_mat_staging_C0[COOP_M * 16];
244+
shared C_TYPE g_mat_staging_C1[COOP_M * 16];
245+
#endif
246+
235247
void main() {
236248
ivec3 tile_id = ivec3(gl_WorkGroupID), li = ivec3(gl_LocalInvocationID);
237249

@@ -248,22 +260,22 @@ void main() {
248260
return;
249261
}
250262

251-
coopmat<float16_t, gl_ScopeSubgroup, COOP_M, COOP_N, gl_MatrixUseAccumulator> C0[C_ROWS][C_COLS], C1[C_ROWS][C_COLS];
263+
coopmat<C_TYPE, gl_ScopeSubgroup, COOP_M, COOP_N, gl_MatrixUseAccumulator> C0[C_ROWS][C_COLS], C1[C_ROWS][C_COLS];
252264
for (int i = 0; i < C_COLS; ++i) {
253265
const int ii = int(gl_LocalInvocationIndex);
254266
for (int jj = 0; jj < COOP_M && ii < COOP_N; ++jj) {
255-
g_mat_staging0[jj * 16 + ii] = float16_t(0.0);
267+
g_mat_staging_C0[jj * 16 + ii] = C_TYPE(0.0);
256268
if (ii < OUT_CHANNELS) {
257-
g_mat_staging0[jj * 16 + ii] = g_biases[c + i * COOP_N + ii];
269+
g_mat_staging_C0[jj * 16 + ii] = C_TYPE(g_biases[c + i * COOP_N + ii]);
258270
}
259271
// zero out shared memory to avoid NANs later
260-
g_mat_staging1[jj * 16 + ii] = g_mat_staging2[jj * 16 + ii] = g_mat_staging3[jj * 16 + ii] = float16_t(0.0);
272+
g_mat_staging0[jj * 16 + ii] = g_mat_staging1[jj * 16 + ii] = g_mat_staging2[jj * 16 + ii] = g_mat_staging3[jj * 16 + ii] = float16_t(0.0);
261273
}
262274
groupMemoryBarrier(); barrier();
263275

264276
for (int j = 0; j < C_ROWS; ++j) {
265-
coopMatLoad(C0[j][i], g_mat_staging0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
266-
coopMatLoad(C1[j][i], g_mat_staging0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
277+
coopMatLoad(C0[j][i], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
278+
coopMatLoad(C1[j][i], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
267279
}
268280
}
269281

@@ -501,34 +513,34 @@ void main() {
501513
for (int j = 0; j < C_ROWS; ++j) {
502514
for (int i = 0; i < cols_count; ++i) {
503515
for (int k = 0; k < C0[j][i].length(); ++k) {
504-
C0[j][i][k] = max(max(C0[j][i][k], C1[j][i][k]), float16_t(0.0));
516+
C0[j][i][k] = max(max(C0[j][i][k], C1[j][i][k]), C_TYPE(0.0));
505517
}
506518

507-
coopMatStore(C0[j][i], g_mat_staging0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
519+
coopMatStore(C0[j][i], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
508520
groupMemoryBarrier(); barrier();
509521

510522
for (int jj = 0; jj < COOP_M; jj += 2) {
511523
for (int ii = 0; ii < COOP_N; ++ii) {
512-
const float16_t out_val = max(g_mat_staging0[(jj + 0) * 16 + ii], g_mat_staging0[(jj + 1) * 16 + ii]);
513-
g_out_buf[OUT_CHANNELS * ((y / 2 + 1) * g_params.output_stride + x / 2 + j * COOP_M / 2 + jj / 2 + 1) + c + i * COOP_N + ii] = max(out_val, float16_t(0.0));
524+
const C_TYPE out_val = max(max(g_mat_staging_C0[(jj + 0) * 16 + ii], g_mat_staging_C0[(jj + 1) * 16 + ii]), C_TYPE(0.0));
525+
g_out_buf[OUT_CHANNELS * ((y / 2 + 1) * g_params.output_stride + x / 2 + j * COOP_M / 2 + jj / 2 + 1) + c + i * COOP_N + ii] = float16_t(out_val);
514526
}
515527
}
516528
}
517529
}
518530
#elif OUT_IMG
519531
for (int j = 0; j < rows_count && c == 0; ++j) {
520532
for (int k = 0; k < C0[j][0].length(); ++k) {
521-
C0[j][0][k] = max(C0[j][0][k], float16_t(0.0));
522-
C1[j][0][k] = max(C1[j][0][k], float16_t(0.0));
533+
C0[j][0][k] = max(C0[j][0][k], C_TYPE(0.0));
534+
C1[j][0][k] = max(C1[j][0][k], C_TYPE(0.0));
523535
}
524536

525-
coopMatStore(C0[j][0], g_mat_staging0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
526-
coopMatStore(C1[j][0], g_mat_staging1, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
537+
coopMatStore(C0[j][0], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
538+
coopMatStore(C1[j][0], g_mat_staging_C1, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
527539
groupMemoryBarrier(); barrier();
528540

529541
for (int jj = 0; jj < COOP_M; ++jj) {
530-
vec4 val0 = vec4(g_mat_staging0[jj * 16 + 0], g_mat_staging0[jj * 16 + 1], g_mat_staging0[jj * 16 + 2], 1.0),
531-
val1 = vec4(g_mat_staging1[jj * 16 + 0], g_mat_staging1[jj * 16 + 1], g_mat_staging1[jj * 16 + 2], 1.0);
542+
vec4 val0 = vec4(g_mat_staging_C0[jj * 16 + 0], g_mat_staging_C0[jj * 16 + 1], g_mat_staging_C0[jj * 16 + 2], 1.0),
543+
val1 = vec4(g_mat_staging_C1[jj * 16 + 0], g_mat_staging_C1[jj * 16 + 1], g_mat_staging_C1[jj * 16 + 2], 1.0);
532544
val0.xyz = transfer_output(val0.xyz);
533545
val1.xyz = transfer_output(val1.xyz);
534546
imageStore(g_out_img, ivec2(x + j * COOP_M + jj, y), val0);
@@ -554,14 +566,29 @@ void main() {
554566
for (int j = 0; j < rows_count; ++j) {
555567
for (int i = 0; i < cols_count; ++i) {
556568
for (int k = 0; k < C0[j][i].length(); ++k) {
557-
C0[j][i][k] = max(C0[j][i][k], float16_t(0.0));
558-
C1[j][i][k] = max(C1[j][i][k], float16_t(0.0));
569+
C0[j][i][k] = max(C0[j][i][k], C_TYPE(0.0));
570+
C1[j][i][k] = max(C1[j][i][k], C_TYPE(0.0));
559571
}
560572

573+
#if C_CONVERT
574+
coopMatStore(C0[j][i], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
575+
coopMatStore(C1[j][i], g_mat_staging_C1, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
576+
groupMemoryBarrier(); barrier();
577+
578+
for (int jj = 0; jj < COOP_M; ++jj) {
579+
for (int ii = 0; ii < COOP_N; ++ii) {
580+
g_out_buf[OUT_CHANNELS * ((y + 0 + 1) * g_params.output_stride + x + j * COOP_M + jj + 1) + c + i * COOP_N + ii] = float16_t(g_mat_staging_C0[jj * 16 + ii]);
581+
if (y + 1 < int(g_params.out_dims[1])) {
582+
g_out_buf[OUT_CHANNELS * ((y + 1 + 1) * g_params.output_stride + x + j * COOP_M + jj + 1) + c + i * COOP_N + ii] = float16_t(g_mat_staging_C1[jj * 16 + ii]);
583+
}
584+
}
585+
}
586+
#else // C_CONVERT
561587
coopMatStore(C0[j][i], g_out_buf, OUT_CHANNELS * ((y + 0 + 1) * g_params.output_stride + x + j * COOP_M + 1) + c + i * COOP_N, OUT_CHANNELS, gl_CooperativeMatrixLayoutRowMajor);
562588
if (y + 1 < int(g_params.out_dims[1])) {
563589
coopMatStore(C1[j][i], g_out_buf, OUT_CHANNELS * ((y + 1 + 1) * g_params.output_stride + x + j * COOP_M + 1) + c + i * COOP_N, OUT_CHANNELS, gl_CooperativeMatrixLayoutRowMajor);
564590
}
591+
#endif // C_CONVERT
565592
}
566593
}
567594
#endif // OUT_IMG

internal/shaders/output/convolution_112_112_coop_16x16x16.comp.spv.inl

Lines changed: 0 additions & 5 deletions
This file was deleted.

internal/shaders/output/convolution_112_112_coop_16x16x16_CF16.comp.spv.inl

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/shaders/output/convolution_112_112_coop_16x8x8.comp.spv.inl

Lines changed: 0 additions & 5 deletions
This file was deleted.

internal/shaders/output/convolution_112_112_coop_8x8x16_CF32.comp.spv.inl

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/shaders/output/convolution_32_32_Downsample_coop_16x16x16.comp.spv.inl

Lines changed: 0 additions & 5 deletions
This file was deleted.

internal/shaders/output/convolution_32_32_Downsample_coop_16x16x16_CF16.comp.spv.inl

Lines changed: 5 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)