diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index d356450101..50b7170d28 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "mlx/backend/metal/kernels/bf16.h" #include "mlx/backend/metal/kernels/bf16_math.h" @@ -147,13 +148,13 @@ METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) { // Multiple Arrays with generic dims template -METAL_FUNC vec elem_to_loc_2_nd( +METAL_FUNC metal::vec elem_to_loc_2_nd( uint3 elem, constant const int* shape, constant const int64_t* a_strides, constant const int64_t* b_strides, int ndim) { - vec loc = { + metal::vec loc = { IdxT( elem.x * IdxT(a_strides[ndim - 1]) + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), @@ -170,14 +171,14 @@ METAL_FUNC vec elem_to_loc_2_nd( } template -METAL_FUNC vec elem_to_loc_3_nd( +METAL_FUNC metal::vec elem_to_loc_3_nd( uint3 elem, constant const int* shape, constant const int64_t* a_strides, constant const int64_t* b_strides, constant const int64_t* c_strides, int ndim) { - vec loc = { + metal::vec loc = { IdxT(elem.x * IdxT(a_strides[ndim - 1])) + IdxT(elem.y * IdxT(a_strides[ndim - 2])), IdxT(elem.x * IdxT(b_strides[ndim - 1])) +