1717 #define f16vec4 vec4
1818#endif
1919
// Accumulator element type selection.
// C_TYPE is the element type used for the C0/C1 cooperative-matrix
// accumulators and for the g_mat_staging_C0/g_mat_staging_C1 arrays.
// When C_TYPE is not supplied from outside, it falls back to float16_t and
// C_CONVERT is 0. When C_TYPE is already defined, it is kept as-is and
// C_CONVERT is 1, which selects the output path in main() that stores the
// accumulators to the staging arrays first and casts each element to
// float16_t while writing g_out_buf.
20+ #ifndef C_TYPE
21+ #define C_TYPE float16_t
22+ #define C_CONVERT 0
23+ #else
24+ #define C_CONVERT 1
25+ #endif
26+
// Push-constant block exposing one Params instance as g_params.
// The visible parts of main() read g_params.output_stride and
// g_params.out_dims from it; the Params type itself is declared outside
// this excerpt.
2027layout (push_constant) uniform UniformParams {
2128 Params g_params;
2229};
@@ -232,6 +239,11 @@ shared float16_t g_mat_staging1[16 * 16];
232239shared float16_t g_mat_staging2[16 * 16 ];
233240shared float16_t g_mat_staging3[16 * 16 ];
234241
// Shared staging arrays that hold accumulator values in C_TYPE.
// Each array has COOP_M * 16 elements and is addressed with a row stride
// of 16, both by the coopMatLoad/coopMatStore calls and by the
// jj * 16 + ii indexing in main().
//   g_mat_staging_C0 - filled with zeros/biases and loaded into both C0 and
//                      C1; also the store target for C0 before write-out.
//   g_mat_staging_C1 - store target for C1 in the OUT_IMG and C_CONVERT
//                      output paths.
// NOTE(review): the arrays only exist when COOP_M is defined, but the
// visible hunks of main() use them without a matching guard - presumably
// COOP_M is always defined when this shader is compiled; confirm.
// NOTE(review): the fixed row stride of 16 presumably requires
// COOP_N <= 16 - confirm against the supported matrix sizes.
242+ #ifdef COOP_M
243+ shared C_TYPE g_mat_staging_C0[COOP_M * 16 ];
244+ shared C_TYPE g_mat_staging_C1[COOP_M * 16 ];
245+ #endif
246+
235247void main() {
236248 ivec3 tile_id = ivec3 (gl_WorkGroupID), li = ivec3 (gl_LocalInvocationID);
237249
@@ -248,22 +260,22 @@ void main() {
248260 return ;
249261 }
250262
251- coopmat< float16_t , gl_ScopeSubgroup, COOP_M, COOP_N, gl_MatrixUseAccumulator> C0[C_ROWS][C_COLS], C1[C_ROWS][C_COLS];
263+ coopmat< C_TYPE , gl_ScopeSubgroup, COOP_M, COOP_N, gl_MatrixUseAccumulator> C0[C_ROWS][C_COLS], C1[C_ROWS][C_COLS];
252264 for (int i = 0 ; i < C_COLS; ++ i) {
253265 const int ii = int (gl_LocalInvocationIndex);
254266 for (int jj = 0 ; jj < COOP_M && ii < COOP_N; ++ jj) {
255- g_mat_staging0 [jj * 16 + ii] = float16_t (0.0 );
267+ g_mat_staging_C0 [jj * 16 + ii] = C_TYPE (0.0 );
256268 if (ii < OUT_CHANNELS) {
257- g_mat_staging0 [jj * 16 + ii] = g_biases[c + i * COOP_N + ii];
269+ g_mat_staging_C0 [jj * 16 + ii] = C_TYPE( g_biases[c + i * COOP_N + ii]) ;
258270 }
259271 // zero out shared memory to avoid NANs later
260- g_mat_staging1[jj * 16 + ii] = g_mat_staging2[jj * 16 + ii] = g_mat_staging3[jj * 16 + ii] = float16_t(0.0 );
272+ g_mat_staging0[jj * 16 + ii] = g_mat_staging1[jj * 16 + ii] = g_mat_staging2[jj * 16 + ii] = g_mat_staging3[jj * 16 + ii] = float16_t(0.0 );
261273 }
262274 groupMemoryBarrier(); barrier();
263275
264276 for (int j = 0 ; j < C_ROWS; ++ j) {
265- coopMatLoad(C0[j][i], g_mat_staging0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
266- coopMatLoad(C1[j][i], g_mat_staging0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
277+ coopMatLoad(C0[j][i], g_mat_staging_C0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
278+ coopMatLoad(C1[j][i], g_mat_staging_C0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
267279 }
268280 }
269281
@@ -501,34 +513,34 @@ void main() {
501513 for (int j = 0 ; j < C_ROWS; ++ j) {
502514 for (int i = 0 ; i < cols_count; ++ i) {
503515 for (int k = 0 ; k < C0[j][i].length (); ++ k) {
504- C0[j][i][k] = max (max (C0[j][i][k], C1[j][i][k]), float16_t (0.0 ));
516+ C0[j][i][k] = max (max (C0[j][i][k], C1[j][i][k]), C_TYPE (0.0 ));
505517 }
506518
507- coopMatStore(C0[j][i], g_mat_staging0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
519+ coopMatStore(C0[j][i], g_mat_staging_C0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
508520 groupMemoryBarrier(); barrier();
509521
510522 for (int jj = 0 ; jj < COOP_M; jj += 2 ) {
511523 for (int ii = 0 ; ii < COOP_N; ++ ii) {
512- const float16_t out_val = max (g_mat_staging0 [(jj + 0 ) * 16 + ii], g_mat_staging0 [(jj + 1 ) * 16 + ii]);
513- g_out_buf[OUT_CHANNELS * ((y / 2 + 1 ) * g_params.output_stride + x / 2 + j * COOP_M / 2 + jj / 2 + 1 ) + c + i * COOP_N + ii] = max (out_val, float16_t(0.0 ) );
524+ const C_TYPE out_val = max (max (g_mat_staging_C0 [(jj + 0 ) * 16 + ii], g_mat_staging_C0 [(jj + 1 ) * 16 + ii]), C_TYPE( 0.0 ) );
525+ g_out_buf[OUT_CHANNELS * ((y / 2 + 1 ) * g_params.output_stride + x / 2 + j * COOP_M / 2 + jj / 2 + 1 ) + c + i * COOP_N + ii] = float16_t(out_val );
514526 }
515527 }
516528 }
517529 }
518530#elif OUT_IMG
519531 for (int j = 0 ; j < rows_count && c == 0 ; ++ j) {
520532 for (int k = 0 ; k < C0[j][0 ].length (); ++ k) {
521- C0[j][0 ][k] = max (C0[j][0 ][k], float16_t (0.0 ));
522- C1[j][0 ][k] = max (C1[j][0 ][k], float16_t (0.0 ));
533+ C0[j][0 ][k] = max (C0[j][0 ][k], C_TYPE (0.0 ));
534+ C1[j][0 ][k] = max (C1[j][0 ][k], C_TYPE (0.0 ));
523535 }
524536
525- coopMatStore(C0[j][0 ], g_mat_staging0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
526- coopMatStore(C1[j][0 ], g_mat_staging1 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
537+ coopMatStore(C0[j][0 ], g_mat_staging_C0 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
538+ coopMatStore(C1[j][0 ], g_mat_staging_C1 , 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
527539 groupMemoryBarrier(); barrier();
528540
529541 for (int jj = 0 ; jj < COOP_M; ++ jj) {
530- vec4 val0 = vec4 (g_mat_staging0 [jj * 16 + 0 ], g_mat_staging0 [jj * 16 + 1 ], g_mat_staging0 [jj * 16 + 2 ], 1.0 ),
531- val1 = vec4 (g_mat_staging1 [jj * 16 + 0 ], g_mat_staging1 [jj * 16 + 1 ], g_mat_staging1 [jj * 16 + 2 ], 1.0 );
542+ vec4 val0 = vec4 (g_mat_staging_C0 [jj * 16 + 0 ], g_mat_staging_C0 [jj * 16 + 1 ], g_mat_staging_C0 [jj * 16 + 2 ], 1.0 ),
543+ val1 = vec4 (g_mat_staging_C1 [jj * 16 + 0 ], g_mat_staging_C1 [jj * 16 + 1 ], g_mat_staging_C1 [jj * 16 + 2 ], 1.0 );
532544 val0.xyz = transfer_output(val0.xyz);
533545 val1.xyz = transfer_output(val1.xyz);
534546 imageStore(g_out_img, ivec2 (x + j * COOP_M + jj, y), val0);
@@ -554,14 +566,29 @@ void main() {
554566 for (int j = 0 ; j < rows_count; ++ j) {
555567 for (int i = 0 ; i < cols_count; ++ i) {
556568 for (int k = 0 ; k < C0[j][i].length (); ++ k) {
557- C0[j][i][k] = max (C0[j][i][k], float16_t (0.0 ));
558- C1[j][i][k] = max (C1[j][i][k], float16_t (0.0 ));
569+ C0[j][i][k] = max (C0[j][i][k], C_TYPE (0.0 ));
570+ C1[j][i][k] = max (C1[j][i][k], C_TYPE (0.0 ));
559571 }
560572
573+ #if C_CONVERT
574+ coopMatStore(C0[j][i], g_mat_staging_C0, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
575+ coopMatStore(C1[j][i], g_mat_staging_C1, 0u, 16u, gl_CooperativeMatrixLayoutRowMajor);
576+ groupMemoryBarrier(); barrier();
577+
578+ for (int jj = 0 ; jj < COOP_M; ++ jj) {
579+ for (int ii = 0 ; ii < COOP_N; ++ ii) {
580+ g_out_buf[OUT_CHANNELS * ((y + 0 + 1 ) * g_params.output_stride + x + j * COOP_M + jj + 1 ) + c + i * COOP_N + ii] = float16_t(g_mat_staging_C0[jj * 16 + ii]);
581+ if (y + 1 < int (g_params.out_dims[1 ])) {
582+ g_out_buf[OUT_CHANNELS * ((y + 1 + 1 ) * g_params.output_stride + x + j * COOP_M + jj + 1 ) + c + i * COOP_N + ii] = float16_t(g_mat_staging_C1[jj * 16 + ii]);
583+ }
584+ }
585+ }
586+ #else // C_CONVERT
561587 coopMatStore(C0[j][i], g_out_buf, OUT_CHANNELS * ((y + 0 + 1 ) * g_params.output_stride + x + j * COOP_M + 1 ) + c + i * COOP_N, OUT_CHANNELS, gl_CooperativeMatrixLayoutRowMajor);
562588 if (y + 1 < int (g_params.out_dims[1 ])) {
563589 coopMatStore(C1[j][i], g_out_buf, OUT_CHANNELS * ((y + 1 + 1 ) * g_params.output_stride + x + j * COOP_M + 1 ) + c + i * COOP_N, OUT_CHANNELS, gl_CooperativeMatrixLayoutRowMajor);
564590 }
591+ #endif // C_CONVERT
565592 }
566593 }
567594#endif // OUT_IMG
0 commit comments