diff --git a/mlx/backend/cuda/CMakeLists.txt b/mlx/backend/cuda/CMakeLists.txt index fec5294865..fac31193bd 100644 --- a/mlx/backend/cuda/CMakeLists.txt +++ b/mlx/backend/cuda/CMakeLists.txt @@ -286,3 +286,9 @@ FetchContent_Declare( FetchContent_MakeAvailable(cutlass) target_include_directories( mlx SYSTEM PRIVATE $) + +# Install CUTLASS headers for JIT. +install(DIRECTORY ${cutlass_SOURCE_DIR}/include/cute + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) +install(DIRECTORY ${cutlass_SOURCE_DIR}/include/cutlass + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) diff --git a/mlx/backend/cuda/compiled.cpp b/mlx/backend/cuda/compiled.cpp index 402c8ce8b7..1cf6bf3182 100644 --- a/mlx/backend/cuda/compiled.cpp +++ b/mlx/backend/cuda/compiled.cpp @@ -239,7 +239,9 @@ void Compiled::eval_gpu( } int work_per_thread = 16 / max_size; - cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() { + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), lib_name(), [&]() { // Build source code. cu::FusedKernelBuilder builder{ g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_}; @@ -305,8 +307,6 @@ void Compiled::eval_gpu( } } - auto& encoder = cu::get_command_encoder(s); - // Put outputs. compiled_allocate_outputs( inputs, outputs, is_constant_, contiguous, [&](auto n) { diff --git a/mlx/backend/cuda/custom_kernel.cpp b/mlx/backend/cuda/custom_kernel.cpp index 9a6837acbb..0f102057c1 100644 --- a/mlx/backend/cuda/custom_kernel.cpp +++ b/mlx/backend/cuda/custom_kernel.cpp @@ -313,7 +313,7 @@ void CustomKernel::eval_gpu( std::string kernel_name = (is_precompiled_) ? name_ : "mlx::core::cu::" + name_; cu::JitModule& mod = cu::get_jit_module( - s.device, + encoder.device(), name_, [&]() { return std::make_tuple( diff --git a/mlx/backend/cuda/device.h b/mlx/backend/cuda/device.h index 99830e3cf1..06a728bf7d 100644 --- a/mlx/backend/cuda/device.h +++ b/mlx/backend/cuda/device.h @@ -57,7 +57,7 @@ class CommandEncoder { template void add_kernel_node_ex( - F* func, + F func, dim3 grid_dim, dim3 block_dim, dim3 cluster_dim, @@ -69,13 +69,18 @@ class CommandEncoder { ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( std::forward(params)), ...); - add_kernel_node_raw( - reinterpret_cast(func), - grid_dim, - block_dim, - cluster_dim, - smem_bytes, - ptrs); + if constexpr (std::is_same_v) { + add_kernel_node_raw( + func, grid_dim, block_dim, cluster_dim, smem_bytes, ptrs); + } else { + add_kernel_node_raw( + reinterpret_cast(func), + grid_dim, + block_dim, + cluster_dim, + smem_bytes, + ptrs); + } } void add_kernel_node_raw( diff --git a/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh b/mlx/backend/cuda/device/cute_dequant.cuh similarity index 78% rename from mlx/backend/cuda/quantized/qmm/cute_dequant.cuh rename to mlx/backend/cuda/device/cute_dequant.cuh index e507b6a9ac..6416c5b87a 100644 --- a/mlx/backend/cuda/quantized/qmm/cute_dequant.cuh +++ b/mlx/backend/cuda/device/cute_dequant.cuh @@ -110,13 +110,13 @@ namespace cute { // Required by tiled copy for 3/5/6-bit weights. struct uint24_t { - cuda::std::array bytes; + cuda::std::array bytes; }; struct uint40_t { - cuda::std::array bytes; + cuda::std::array bytes; }; struct uint48_t { - cuda::std::array bytes; + cuda::std::array bytes; }; template <> @@ -134,15 +134,22 @@ struct uint_bit<48> { } // namespace cute -namespace cutlass_gemm { +namespace mlx::core::cu { + +using namespace cute; // Whether the quant type is affine quantization. template constexpr bool quant_has_bias_v = !cutlass::has_negative_zero_v; // Dequantize CuTe tensors with out = w * s + z. -__device__ __forceinline__ void -cute_vectorized_dequant(auto w, auto s, auto z, auto out) { +template < + typename TensorW, + typename TensorS, + typename TensorZ, + typename TensorO> +CUTE_DEVICE void +cute_vectorized_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) { using namespace cute; using Element = typename decltype(out)::value_type; using Quant = typename decltype(w)::value_type; @@ -166,4 +173,36 @@ cute_vectorized_dequant(auto w, auto s, auto z, auto out) { copy(make_tensor(make_rmem_ptr(&w_dq), out.layout()), out); } -} // namespace cutlass_gemm +template < + typename TensorW, + typename TensorS, + typename TensorZ, + typename TensorO> +CUTE_DEVICE void +cute_naive_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) { + using Element = typename decltype(out)::value_type; + using Quant = typename decltype(w)::value_type; + using Scale = typename decltype(s)::value_type; + transform(w, out, [](Quant q) { return Element(q); }); + transform(out, s, out, [](Element e, Scale s) { return e * Element(s); }); + if constexpr (quant_has_bias_v) { + transform(out, z, out, plus{}); + } +} + +template < + typename TensorW, + typename TensorS, + typename TensorZ, + typename TensorO> +CUTE_DEVICE void cute_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) { + if constexpr ( + stride(coalesce(w.layout())) == Int<1>{} && + is_static_v) { + cute_vectorized_dequant(w, s, z, out); + } else { + cute_naive_dequant(w, s, z, out); + } +} + +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cuh b/mlx/backend/cuda/device/qmm_naive.cuh similarity index 64% rename from mlx/backend/cuda/quantized/qmm/qmm_naive.cuh rename to mlx/backend/cuda/device/qmm_naive.cuh index dde43e0aa0..f9ee549962 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cuh +++ b/mlx/backend/cuda/device/qmm_naive.cuh @@ -1,12 +1,12 @@ // Copyright © 2026 Apple Inc. -#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" -#include "mlx/dtype_utils.h" +#include "mlx/backend/cuda/device/cute_dequant.cuh" + +#include // clang-format off -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { +namespace mlx::core::cu { using namespace cute; @@ -16,8 +16,8 @@ struct SharedStorage { ArrayEngine> B; }; -template -inline constexpr auto make_smem_layout(auto bM, auto bK) { +template +inline constexpr auto make_smem_layout(TileM bM, TileN bK) { // TODO: Calculate swizzle based on tile shape. if constexpr (KMajor) { auto swizzle = composition(Swizzle<3,3,3>{}, @@ -31,16 +31,44 @@ inline constexpr auto make_smem_layout(auto bM, auto bK) { } } -template -inline constexpr auto make_smem_layouts(auto cta_tiler) { +template +inline constexpr auto make_smem_layouts(CtaTiler cta_tiler) { + // Note: Kernel launcher assumes cosize being same for all KMajor. auto [bM, bN, bK] = cta_tiler; auto sA_layout = make_smem_layout(bM, bK); auto sB_layout = make_smem_layout(bN, bK); - return std::make_tuple(sA_layout, sB_layout); + return cute::make_tuple(sA_layout, sB_layout); } -template -inline constexpr auto make_tiled_copy(auto num_threads, auto bM, auto bK) { +template +inline constexpr auto make_tiled_mma(CtaTiler cta_tiler) { + // Note: Kernel launcher assumes num_threads being same for all parameters. + using Atom = cuda::std::conditional_t< + SM80, + cuda::std::conditional_t< + cuda::std::is_same_v, + SM80_16x8x16_F32F16F16F32_TN, + cuda::std::conditional_t< + cuda::std::is_same_v, + SM80_16x8x16_F32BF16BF16F32_TN, + UniversalFMA + > + >, + UniversalFMA>; + if constexpr (!SM80 || cuda::std::is_same_v) { + return make_tiled_mma(Atom{}, Layout>{}); + } else { + if constexpr (size<0>(cta_tiler) >= 32) { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + } else { + return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + } + } +} + +template +inline constexpr auto make_tiled_copy(NumThreads num_threads, TileM bM, TileN bK) { // TODO: Only do 1-element read for the tile of residue. auto n_read = Int{}; auto atom = Copy_Atom>>, T>{}; @@ -59,37 +87,13 @@ inline constexpr auto make_tiled_copy(auto num_threads, auto bM, auto bK) { } } - -__device__ __forceinline__ void -cute_naive_dequant(auto w, auto s, auto z, auto out) { - using Element = typename decltype(out)::value_type; - using Quant = typename decltype(w)::value_type; - using Scale = typename decltype(s)::value_type; - transform(w, out, [](Quant q) { return Element(q); } ); - transform(out, s, out, [](Element e, Scale s) { return e * Element(s); }); - if constexpr (quant_has_bias_v) { - transform(out, z, out, plus{}); - } -} - -__device__ __forceinline__ void -cute_dequant(auto w, auto s, auto z, auto out) { - if constexpr (stride(coalesce(w.layout())) == Int<1>{} && - is_static_v) { - cute_vectorized_dequant(w, s, z, out); - } else { - cute_naive_dequant(w, s, z, out); - } -} - template + typename TensorC> CUTE_DEVICE void qmm_naive_mainloop( CtaTiler cta_tiler, TensorA gA, @@ -97,20 +101,20 @@ CUTE_DEVICE void qmm_naive_mainloop( TensorS gS, TensorZ gZ, TensorC gC, - TiledMma mma, int m_max_coord, int n_max_coord, int k_residue, int thread_idx) { // Get the types of operands. - using Element = decltype(gA)::value_type; - using Quant = decltype(gB)::value_type; + using Element = typename decltype(gA)::value_type; + using Quant = typename decltype(gB)::value_type; // Shift tensor so we handle residue of K in the 0th tile. gA = domain_offset(make_coord(0, k_residue, 0), gA); if constexpr (sizeof_bits_v % 8 == 0) { gB = domain_offset(make_coord(0, k_residue, 0), gB); } else { + // TODO: Figure out why domain_offset is not returning wrong offset. gB.data() = recast_ptr(raw_pointer_cast(gB.data()) + gB.layout()(0, k_residue, 0) * cuda::std::min(8, sizeof_bits_v) / 8); } gS = domain_offset(make_coord(0, k_residue, 0), gS); @@ -119,7 +123,7 @@ CUTE_DEVICE void qmm_naive_mainloop( } // Define smem layouts. - auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); + auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); // Shared memory buffer. extern __shared__ char smem_buf[]; @@ -128,8 +132,11 @@ CUTE_DEVICE void qmm_naive_mainloop( Tensor sA = make_tensor(make_smem_ptr(smem.A.begin()), sA_layout); // (BLK_M,BLK_K) Tensor sB = make_tensor(make_smem_ptr(smem.B.begin()), sB_layout); // (BLK_N,BLK_K) - // Define copy atoms. + // Define MMA. + auto mma = make_tiled_mma(CtaTiler{}); auto num_threads = size(mma); + + // Define copy atoms. auto [bM, bN, bK] = cta_tiler; TiledCopy copy_a = make_tiled_copy(num_threads, bM, bK); TiledCopy copy_b = make_tiled_copy(num_threads, bN, bK); @@ -271,16 +278,17 @@ CUTE_DEVICE void qmm_naive_mainloop( } template -inline constexpr auto make_matrix_stride(auto m, auto k) { +inline constexpr auto make_matrix_stride(int m, int k) { if constexpr (KMajor) { - return cute::make_stride(k, cute::Int<1>{}, m * k); + return make_stride(k, Int<1>{}, m * k); } else { - return cute::make_stride(cute::Int<1>{}, m, m * k); + return make_stride(Int<1>{}, m, m * k); } } -template -inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size) { +template +inline constexpr auto make_scales_layout(int n, int k, int l) { + auto group_size = Int{}; if constexpr (KMajor) { return make_layout( make_shape(n, make_shape(group_size, k / group_size), l), @@ -292,106 +300,88 @@ inline constexpr auto make_scales_layout(auto n, auto k, auto l, auto group_size } } -template -inline constexpr auto make_cta_tiler(auto group_size) { - auto bM = Int{}; - auto bN = Int<(!SM80 && group_size > 64) ? 64 : 128>{}; - auto bK = Int{}; - return make_shape(bM, bN, bK); -} +template +__global__ +__launch_bounds__(decltype(size(make_tiled_mma(CtaTiler{})))::value) +void qmm_naive_kernel( + const Element* A, + const Quant* B, + const Scale* S, + const Element* Z, + const uint32_t* lhs_indices, + const uint32_t* rhs_indices, + Element* C, + int m, int n, int k, int l, + bool broadcast_b) { + int thread_idx = int(threadIdx.x); + int m_coord = int(blockIdx.x); + int n_coord = int(blockIdx.y); + int l_coord = int(blockIdx.z); + + // Define layouts (mixed). + auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) + auto dB = make_matrix_stride(n, k); // (dN,dK,dL) + auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) + auto S_layout = make_scales_layout(n, k, l); + + // Handle broadcasting. + if (broadcast_b) { + get<2>(dB) = 0; + get<2>(stride(S_layout)) = 0; + } -template -inline constexpr auto make_tiled_mma(auto cta_tiler) { - using Atom = std::conditional_t< - SM80, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32F16F16F32_TN, - std::conditional_t< - std::is_same_v, - SM80_16x8x16_F32BF16BF16F32_TN, - UniversalFMA - > - >, - UniversalFMA>; - if constexpr (!SM80 || std::is_same_v) { - return make_tiled_mma(Atom{}, Layout>{}); - } else { - if constexpr (size<0>(cta_tiler) >= 32) { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_32,_32,_16>{}); + // Represent the full tensors. + Tensor mA_mkl = make_tensor(make_gmem_ptr(A), make_shape(m, k, l), dA); // (M,K,L) + Tensor mB_nkl = make_tensor(make_gmem_ptr(B), make_shape(n, k, l), dB); // (N,K,L) + Tensor mC_mnl = make_tensor(make_gmem_ptr(C), make_shape(m, n, l), dC); // (M,N,L) + + Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) + + // For gather, use index lookup for input batch slicing. + uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; + uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; + + // Get batch slice. + Tensor mA = mA_mkl(_,_,a_batch); // (M,K) + Tensor mB = mB_nkl(_,_,b_batch); // (N,K) + Tensor mC = mC_mnl(_,_,l_coord); // (M,N) + + Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + + // Get the appropriate blocks for this thread block. + auto cta_tiler = CtaTiler{}; + auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + auto gZ = [&]() { + if constexpr (quant_has_bias_v) { + Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) + Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) + return local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) } else { - return make_tiled_mma(Atom{}, Layout>{}, Tile<_16,_32,_16>{}); + // Dummy tensor; no-bias paths never offset or load gZ. + return gS; } - } -} - -} // namespace cutlass_gemm - -// clang-format on - -namespace mlx::core { - -template -inline void dispatch_element_types(Dtype dtype, const char* tag, F&& f) { - if (dtype == float32) { - f.template operator()(); - } else if (dtype == float16) { - f.template operator()(); - } else if (dtype == bfloat16) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} Unsupported dtype: {}.", tag, dtype_to_string(dtype))); - } -} - -template -inline void dispatch_groups(int group_size, const char* tag, F&& f) { - if (group_size == 32) { - f.template operator()<32>(); - } else if (group_size == 64) { - f.template operator()<64>(); - } else if (group_size == 128) { - f.template operator()<128>(); - } else { - throw std::invalid_argument( - fmt::format("{} Group size {} is not supported.", tag, group_size)); - } -} - -template -inline void dispatch_quant_types( - int bits, - int group_size, - QuantizationMode mode, - const char* tag, - F&& f) { - if (mode == QuantizationMode::Mxfp4) { - f.template operator()(); - } else if (mode == QuantizationMode::Mxfp8) { - f.template operator()(); - } else if (mode == QuantizationMode::Nvfp4) { - f.template operator()(); - } else { - dispatch_groups(group_size, tag, [&]() { - if (bits == 2) { - f.template operator()(); - } else if (bits == 3) { - f.template operator()(); - } else if (bits == 4) { - f.template operator()(); - } else if (bits == 5) { - f.template operator()(); - } else if (bits == 6) { - f.template operator()(); - } else if (bits == 8) { - f.template operator()(); - } else { - throw std::invalid_argument( - fmt::format("{} {}-bit quantization is not supported.", tag, bits)); - } - }); - } + }(); + + // Compute tile residues for predication. + int m_max_coord = m - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord + int n_max_coord = n - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord + int k_residue = k - size<1>(gA) * size<2>(gA); + + qmm_naive_mainloop( + cta_tiler, + gA, + gB, + gS, + gZ, + gC, + m_max_coord, n_max_coord, k_residue, + thread_idx); } -} // namespace mlx::core +} // namespace mlx::core::cu diff --git a/mlx/backend/cuda/hadamard.cu b/mlx/backend/cuda/hadamard.cu index d7def2831a..d6778c4966 100644 --- a/mlx/backend/cuda/hadamard.cu +++ b/mlx/backend/cuda/hadamard.cu @@ -122,7 +122,9 @@ void hadamard_mn_contiguous( read_width_n2, read_width_m); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names = {n2_kernel_name}; if (n1 > 1) { kernel_names.push_back(n1_kernel_name); @@ -142,8 +144,6 @@ void hadamard_mn_contiguous( return std::make_tuple(false, std::move(source), std::move(kernel_names)); }); - auto& encoder = cu::get_command_encoder(s); - if (n1 > 1) { const int64_t num_transforms = x.size() / n1; const uint32_t num_blocks = diff --git a/mlx/backend/cuda/indexing.cpp b/mlx/backend/cuda/indexing.cpp index 0bec840463..68f5ea1393 100644 --- a/mlx/backend/cuda/indexing.cpp +++ b/mlx/backend/cuda/indexing.cpp @@ -87,7 +87,7 @@ void Gather::eval_gpu(const std::vector& inputs, array& out) { dtype_to_string(idx_dtype), nidx); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int large = 0; large <= 1; ++large) { @@ -182,7 +182,9 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { nidx); auto& s = stream(); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int large = 0; large <= 1; ++large) { @@ -231,7 +233,6 @@ void Scatter::eval_gpu(const std::vector& inputs, array& out) { idx_ndim, large ? "int64_t" : "int32_t"); - auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } @@ -262,7 +263,7 @@ void GatherAxis::eval_gpu(const std::vector& inputs, array& out) { dtype_to_string(out.dtype()), dtype_to_string(idx.dtype())); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int contiguous = 0; contiguous < 4; ++contiguous) { @@ -366,7 +367,9 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { op); auto& s = stream(); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int ndim = 0; ndim <= MAX_NDIM; ++ndim) { for (int contiguous = 0; contiguous < 4; ++contiguous) { @@ -429,7 +432,6 @@ void ScatterAxis::eval_gpu(const std::vector& inputs, array& out) { idx.flags().row_contiguous, large ? "int64_t" : "int32_t"); - auto& encoder = cu::get_command_encoder(s); for (const auto& in : inputs) { encoder.set_input_array(in); } @@ -489,7 +491,7 @@ void MaskedScatter::eval_gpu(const std::vector& inputs, array& out) { std::string module_name = fmt::format("masked_scatter_{}", dtype_to_string(out.dtype())); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int src_contiguous = 0; src_contiguous <= 1; ++src_contiguous) { for (int dst_contiguous = 0; dst_contiguous <= 1; ++dst_contiguous) { @@ -626,7 +628,7 @@ void SliceUpdate::eval_gpu(const std::vector& inputs, array& out) { auto& s = stream(); auto& encoder = cu::get_command_encoder(s); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::vector kernel_names; for (int out_c = 0; out_c <= 1; ++out_c) { for (int upd_c = 0; upd_c <= 1; ++upd_c) { diff --git a/mlx/backend/cuda/jit_module.cpp b/mlx/backend/cuda/jit_module.cpp index 433712543a..0246eff4ca 100644 --- a/mlx/backend/cuda/jit_module.cpp +++ b/mlx/backend/cuda/jit_module.cpp @@ -50,12 +50,17 @@ const std::filesystem::path& default_cuda_toolkit_path() { const std::vector& include_path_args() { static std::vector cached_args = []() { std::vector args; - // Add path to bundled CCCL headers. + // Add path to bundled headers. auto root_dir = current_binary_dir(); #if !defined(_WIN32) root_dir = root_dir.parent_path(); #endif - auto path = root_dir / "include" / "cccl"; + auto path = root_dir / "include"; + if (std::filesystem::exists(path)) { + args.push_back(fmt::format("--include-path={}", path.string())); + } + // Add path to CCCL headers. + path = path / "cccl"; #if defined(MLX_CCCL_DIR) if (!std::filesystem::exists(path)) { path = MLX_CCCL_DIR; @@ -246,9 +251,11 @@ constexpr const char* g_include_names[] = { INCLUDE_PREFIX "cast_op.cuh", INCLUDE_PREFIX "config.h", INCLUDE_PREFIX "complex.cuh", + INCLUDE_PREFIX "cute_dequant.cuh", INCLUDE_PREFIX "fp16_math.cuh", INCLUDE_PREFIX "hadamard.cuh", INCLUDE_PREFIX "indexing.cuh", + INCLUDE_PREFIX "qmm_naive.cuh", INCLUDE_PREFIX "scatter_ops.cuh", INCLUDE_PREFIX "unary_ops.cuh", INCLUDE_PREFIX "ternary_ops.cuh", @@ -263,9 +270,11 @@ constexpr const char* g_headers[] = { jit_source_cast_op, jit_source_config, jit_source_complex, + jit_source_cute_dequant, jit_source_fp16_math, jit_source_hadamard, jit_source_indexing, + jit_source_qmm_naive, jit_source_scatter_ops, jit_source_unary_ops, jit_source_ternary_ops, @@ -295,8 +304,11 @@ void compile( CHECK_NVRTC_ERROR(nvrtcAddNameExpression(prog, name.c_str())); } - // Compile program. + // Required for compiling CUTLASS code. std::vector args; + args.push_back("--device-as-default-execution-space"); + + // Target current device. bool use_sass = compiler_supports_device_sass(device); auto cc = device.compute_capability_major(); std::string arch_tag = (cc >= 9) ? "a" : ""; @@ -310,6 +322,8 @@ void compile( for (const auto& include : include_path_args()) { args.push_back(include.c_str()); } + + // Compile program. nvrtcResult compile_result = nvrtcCompileProgram(prog, args.size(), args.data()); if (compile_result != NVRTC_SUCCESS) { @@ -441,7 +455,7 @@ CUfunction JitModule::get_kernel( } JitModule& get_jit_module( - const mlx::core::Device& device, + Device& device, const std::string& name, const KernelBuilder& builder, bool use_disk_cache) { @@ -460,8 +474,7 @@ JitModule& get_jit_module( std::unique_lock wlock(*mtx); auto it = cache->find(name); if (it == cache->end()) { - auto& d = cu::device(device); - it = cache->try_emplace(name, d, name, builder, use_disk_cache).first; + it = cache->try_emplace(name, device, name, builder, use_disk_cache).first; } return it->second; } diff --git a/mlx/backend/cuda/jit_module.h b/mlx/backend/cuda/jit_module.h index 4a779cc3a8..f6796625d1 100644 --- a/mlx/backend/cuda/jit_module.h +++ b/mlx/backend/cuda/jit_module.h @@ -111,7 +111,7 @@ class JitModule { }; JitModule& get_jit_module( - const mlx::core::Device& device, + Device& device, const std::string& name, const KernelBuilder& builder, bool use_disk_cache = true); diff --git a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt b/mlx/backend/cuda/quantized/qmm/CMakeLists.txt index 0d682eade3..2fc2ece0b3 100644 --- a/mlx/backend/cuda/quantized/qmm/CMakeLists.txt +++ b/mlx/backend/cuda/quantized/qmm/CMakeLists.txt @@ -1,6 +1,8 @@ target_sources( mlx - PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/qmm.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmm_naive.cu + ${CMAKE_CURRENT_SOURCE_DIR}/qmv.cu ${CMAKE_CURRENT_SOURCE_DIR}/fp_qmv.cu) foreach(TileN 16 32 64 128 256) @@ -16,20 +18,3 @@ foreach(TileM 16 32 64) "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY) target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}) endforeach() - -foreach(TileM 16 32 64) - foreach(KMajor true false) - foreach(HasKResidue true false) - foreach(SM80 true false) - if(${KMajor} AND ${HasKResidue}) - continue() - endif() - set(OUTPUT_FILE - "qmm_naive_impl_m${TileM}_${KMajor}_${HasKResidue}_${SM80}.cu") - configure_file("${CMAKE_CURRENT_SOURCE_DIR}/qmm_naive.cu" - "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}" @ONLY) - target_sources(mlx PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE}) - endforeach() - endforeach() - endforeach() -endforeach() diff --git a/mlx/backend/cuda/quantized/qmm/qmm.cu b/mlx/backend/cuda/quantized/qmm/qmm.cu index 41e802d6ac..403f9f189d 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm.cu @@ -194,21 +194,6 @@ void qmm_sm80( } } -// Defined in qmm_naive.cu. -template -void qmm_naive_impl( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder); - bool supports_qmm_naive( const array& x, const array& w, @@ -234,68 +219,6 @@ bool supports_qmm_naive( return true; } -void qmm_naive( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - bool transpose, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder) { - auto dispatch = [&]() { - qmm_naive_impl( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - out, - bits, - group_size, - mode, - encoder); - }; - auto dispatch_k = [&](auto k_major, bool has_k_residue, auto&& f) { - if constexpr (k_major.value) { - if (has_k_residue) { - throw std::invalid_argument( - "[quantized_matmul] K must be multiples of max(64, group_size)."); - } - f.template operator()(); - } else { - dispatch_bool(has_k_residue, [&](auto has_k_residue) { - f.template operator()(); - }); - } - }; - int m = out.ndim() > 1 ? out.shape(-2) : 1; - int k = x.shape(-1); - int tile_k = std::max(64, group_size); - bool has_k_residue = k % tile_k != 0; - bool sm80 = encoder.device().compute_capability_major() >= 8; - dispatch_bool(transpose, [&](auto k_major) { - dispatch_k(k_major, has_k_residue, [&]() { - dispatch_bool(sm80, [&](auto sm80) { - constexpr bool KMajor = k_major.value; - constexpr bool SM80 = sm80.value; - if (m <= 16) { - dispatch.template operator()<16, KMajor, HasKResidue, SM80>(); - } else if (m <= 32) { - dispatch.template operator()<32, KMajor, HasKResidue, SM80>(); - } else { - dispatch.template operator()<64, KMajor, HasKResidue, SM80>(); - } - }); - }); - }); -} - bool supports_fp_qmv( const array& x, const array& w, diff --git a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu index d13fd5a150..c01fd2a639 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_naive.cu +++ b/mlx/backend/cuda/quantized/qmm/qmm_naive.cu @@ -1,174 +1,82 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/device/qmm_naive.cuh" +#include "mlx/backend/cuda/jit_module.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" -#include "mlx/backend/cuda/quantized/qmm/qmm_naive.cuh" - -// clang-format off - -// We can't put kernel code in mlx::core due to name conflicts of "Shape". -namespace cutlass_gemm { - -using namespace cute; - -template -__global__ -__launch_bounds__(decltype(size(TiledMma{}))::value) -void qmm_naive_kernel( - ProblemShape shape_MNKL, - CtaTiler cta_tiler, - const Element* A, StrideA dA, - const Quant* B, StrideB dB, - const Scale* S, const Element* Z, LayoutS S_layout, - const uint32_t* lhs_indices, const uint32_t* rhs_indices, - Element* C, StrideC dC, - TiledMma mma) { - CUTE_STATIC_ASSERT_V(congruent(select<0,2,3>(shape_MNKL), dA)); - CUTE_STATIC_ASSERT_V(congruent(select<1,2,3>(shape_MNKL), dB)); - CUTE_STATIC_ASSERT_V(congruent(select<0,1,3>(shape_MNKL), dC)); - - int thread_idx = int(threadIdx.x); - int m_coord = int(blockIdx.x); - int n_coord = int(blockIdx.y); - int l_coord = int(blockIdx.z); - - // Represent the full tensors. - Tensor mA_mkl = make_tensor(make_gmem_ptr(A), select<0,2,3>(shape_MNKL), dA); // (M,K,L) - Tensor mB_nkl = make_tensor(make_gmem_ptr(B), select<1,2,3>(shape_MNKL), dB); // (N,K,L) - Tensor mC_mnl = make_tensor(make_gmem_ptr(C), select<0,1,3>(shape_MNKL), dC); // (M,N,L) - - Tensor mS_nkl = make_tensor(make_gmem_ptr(S), S_layout); // (N,(group_size,K/group_size),L) - - // For gather, use index lookup for input batch slicing. - uint32_t a_batch = lhs_indices ? lhs_indices[l_coord] : l_coord; - uint32_t b_batch = rhs_indices ? rhs_indices[l_coord] : l_coord; - - // Get batch slice. - Tensor mA = mA_mkl(_,_,a_batch); // (M,K) - Tensor mB = mB_nkl(_,_,b_batch); // (N,K) - Tensor mC = mC_mnl(_,_,l_coord); // (M,N) - - Tensor mS = mS_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) - - // Get the appropriate blocks for this thread block. - auto cta_coord = make_coord(m_coord, n_coord, _); // (m,n,k) - Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) - - Tensor gS = local_tile(mS, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - auto gZ = [&]() { - if constexpr (quant_has_bias_v) { - Tensor mZ_nkl = make_tensor(make_gmem_ptr(Z), S_layout); // (N,(group_size,K/group_size),L) - Tensor mZ = mZ_nkl(_,_,b_batch); // (N,(group_size,K/group_size)) - return local_tile(mZ, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) - } else { - // Dummy tensor; no-bias paths never offset or load gZ. - return gS; - } - }(); - - // Compute tile residues for predication. - int m_max_coord = size<0>(shape_MNKL) - size<0>(cta_tiler) * m_coord; // M - BLK_M * m_coord - int n_max_coord = size<1>(shape_MNKL) - size<1>(cta_tiler) * n_coord; // N - BLK_N * n_coord - int k_residue = size<2>(shape_MNKL) - size<1>(gA) * size<2>(gA); - - qmm_naive_mainloop( - cta_tiler, - gA, - gB, - gS, - gZ, - gC, - mma, - m_max_coord, n_max_coord, k_residue, - thread_idx); -} +#include "mlx/dtype_utils.h" -template -void qmm_naive( - const Element* A, - const Quant* B, - const Scale* S, - const Element* Z, - const uint32_t* lhs_indices, - const uint32_t* rhs_indices, - Element* C, - int m, int n, int k, int l, - bool broadcast_b, - auto group_size, - auto&& launch_kernel) { - // Define shapes (dynamic). - auto shape_MNKL = make_shape(m, n, k, l); // (M,N,K,L) - - // Define layouts (mixed). - auto dA = make_stride(k, Int<1>{}, m * k); // (dM,dK,dL) - auto dB = make_matrix_stride(n, k); // (dN,dK,dL) - auto dC = make_stride(n, Int<1>{}, m * n); // (dM,dN,dL) - auto S_layout = make_scales_layout(n, k, l, group_size); - - // Handle broadcasting. - if (broadcast_b) { - get<2>(dB) = 0; - get<2>(stride(S_layout)) = 0; - } +#include "cuda_jit_sources.h" - // Define CTA tile size (static). - auto cta_tiler = make_cta_tiler(group_size); - - // Define MMA. - auto mma = make_tiled_mma(cta_tiler); - auto num_threads = size(mma); - - // Shared memory size. - auto [sA_layout, sB_layout] = make_smem_layouts(cta_tiler); - size_t smem_bytes = sizeof(SharedStorage); - - auto* kernel = &qmm_naive_kernel< - KMajor, HasKResidue, SM80, - Element, Quant, Scale, - decltype(shape_MNKL), - decltype(cta_tiler), - decltype(dA), - decltype(dB), - decltype(S_layout), - decltype(dC), - decltype(mma)>; - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); - - dim3 num_blocks{uint32_t(ceil_div(m, size<0>(cta_tiler))), - uint32_t(ceil_div(n, size<1>(cta_tiler))), - uint32_t(l)}; - dim3 block_dims{uint32_t(num_threads)}; - void* args[] = { - &shape_MNKL, - &cta_tiler, - &A, &dA, - &B, &dB, - &S, &Z, &S_layout, - &lhs_indices, &rhs_indices, - &C, &dC, - &mma}; - launch_kernel(reinterpret_cast(kernel), num_blocks, block_dims, smem_bytes, args); +namespace mlx::core { + +namespace { + +inline auto make_cta_tiler(int itemsize, int m, int group_size, bool sm80) { + bool enough_smem = sm80 && itemsize <= 2 && group_size <= 64; + int tile_m = std::max(16, std::min(64, next_power_of_2(m))); + int tile_n = enough_smem ? 128 : 64; + int tile_k = std::max(64, group_size); + return cute::make_shape(tile_m, tile_n, tile_k); } -} // namespace cutlass_gemm +inline auto cta_tiler_to_string(auto cta_tiler) { + return fmt::format( + "cute::Shape, cute::Int<{}>, cute::Int<{}>>", + cute::size<0>(cta_tiler), + cute::size<1>(cta_tiler), + cute::size<2>(cta_tiler)); +} -// clang-format on +const char* get_weight_cutlass_type(const Dtype& dtype) { + switch (dtype) { + case float16: + return "cutlass::half_t"; + case bfloat16: + return "cutlass::bfloat16_t"; + case float32: + return "float"; + default: + throw std::invalid_argument( + fmt::format( + "[quantized_matmul] Unsupported dtype: {}.", + dtype_to_string(dtype))); + } +} -namespace mlx::core { +inline std::tuple +get_quant_cutlass_types(const char* ctype_x, int bits, QuantizationMode mode) { + if (mode == QuantizationMode::Mxfp4) { + return {"cutlass::float_e2m1_t", "cutlass::float_ue8m0_t"}; + } else if (mode == QuantizationMode::Mxfp8) { + return {"cutlass::float_e4m3_t", "cutlass::float_ue8m0_t"}; + } else if (mode == QuantizationMode::Nvfp4) { + return {"cutlass::float_e2m1_t", "cutlass::float_e4m3_t"}; + } else { + if (bits == 2) { + return {"cutlass::uint2b_t", ctype_x}; + } else if (bits == 3) { + return {"cutlass::uint3b_t", ctype_x}; + } else if (bits == 4) { + return {"cutlass::uint4b_t", ctype_x}; + } else if (bits == 5) { + return {"cutlass::uint5b_t", ctype_x}; + } else if (bits == 6) { + return {"cutlass::uint6b_t", ctype_x}; + } else if (bits == 8) { + return {"uint8_t", ctype_x}; + } else { + throw std::invalid_argument( + fmt::format( + "[quantized_matmul] {}-bit quantization is not supported.", + bits)); + } + } +} + +} // namespace -template -void qmm_naive_impl( +void qmm_naive( const array& x, const array& w, const array& scales, @@ -176,76 +84,92 @@ void qmm_naive_impl( const std::optional& lhs_indices, const std::optional& rhs_indices, array& out, + bool transpose, int bits, int group_size, QuantizationMode mode, cu::CommandEncoder& encoder) { - const char* tag = "[quantized_matmul]"; int m = out.ndim() > 1 ? out.shape(-2) : 1; int n = out.shape(-1); int k = x.shape(-1); int l = out.size() / (m * n); bool broadcast_b = (w.ndim() <= 2) || (w.size() != w.data_size()); - dispatch_element_types(out.dtype(), tag, [&]() { - dispatch_quant_types( - bits, - group_size, - mode, - tag, - [&]() { - encoder.set_input_array(x); - encoder.set_input_array(w); - encoder.set_input_array(scales); - if (biases) { - encoder.set_input_array(*biases); - } - if (lhs_indices) { - encoder.set_input_array(*lhs_indices); - } - if (rhs_indices) { - encoder.set_input_array(*rhs_indices); - } - encoder.set_output_array(out); - cutlass_gemm::qmm_naive( - gpu_ptr(x), - gpu_ptr(w), - gpu_ptr(scales), - biases ? gpu_ptr(*biases) : nullptr, - lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, - rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, - gpu_ptr(out), - m, - n, - k, - l, - broadcast_b, - cute::Int{}, - [&](auto* kernel, - dim3 num_blocks, - dim3 block_dims, - size_t smem_bytes, - void** args) { - encoder.add_kernel_node_raw( - kernel, num_blocks, block_dims, {}, smem_bytes, args); - }); - }); + bool sm80 = encoder.device().compute_capability_major() >= 8; + auto cta_tiler = make_cta_tiler(x.itemsize(), m, group_size, sm80); + bool has_k_residue = (k % cute::size<2>(cta_tiler)) != 0; + + std::string module_name = fmt::format( + "qmm_naive_{}_{}_{}_m{}_b{}_g{}_{}", + dtype_to_string(x.dtype()), + transpose ? "k" : "n", + has_k_residue ? "residue" : "aligned", + cute::size<0>(cta_tiler), + bits, + group_size, + quantization_mode_to_string(mode)); + + auto ctype_x = get_weight_cutlass_type(x.dtype()); + auto [ctype_q, ctype_s] = get_quant_cutlass_types(ctype_x, bits, mode); + + std::string kernel_name = fmt::format( + "mlx::core::cu::qmm_naive_kernel<{}, {}, {}, {}, {}, {}, {}, {}>", + group_size, + transpose, + has_k_residue, + sm80, + ctype_x, + ctype_q, + ctype_s, + cta_tiler_to_string(cta_tiler)); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { + return std::make_tuple( + false, jit_source_qmm_naive, std::vector{kernel_name}); }); -} -// clang-format off -template void qmm_naive_impl<@TileM@, @KMajor@, @HasKResidue@, @SM80@>( - const array& x, - const array& w, - const array& scales, - const std::optional& biases, - const std::optional& lhs_indices, - const std::optional& rhs_indices, - array& out, - int bits, - int group_size, - QuantizationMode mode, - cu::CommandEncoder& encoder); -// clang-format on + encoder.set_input_array(x); + encoder.set_input_array(w); + encoder.set_input_array(scales); + if (biases) { + encoder.set_input_array(*biases); + } + if (lhs_indices) { + encoder.set_input_array(*lhs_indices); + } + if (rhs_indices) { + encoder.set_input_array(*rhs_indices); + } + encoder.set_output_array(out); + + dim3 num_blocks{ + uint32_t(cute::ceil_div(m, cute::size<0>(cta_tiler))), + uint32_t(cute::ceil_div(n, cute::size<1>(cta_tiler))), + uint32_t(l)}; + dim3 block_dims{uint32_t(cute::size(cu::make_tiled_mma(cta_tiler)))}; + + auto [sA_layout, sB_layout] = cu::make_smem_layouts(cta_tiler); + size_t smem_bytes = + x.itemsize() * (cute::cosize(sA_layout) + cute::cosize(sB_layout)); + + encoder.add_kernel_node_ex( + mod.get_kernel(kernel_name), + num_blocks, + block_dims, + {}, + smem_bytes, + gpu_ptr(x), + gpu_ptr(w), + gpu_ptr(scales), + biases ? gpu_ptr(*biases) : nullptr, + lhs_indices ? gpu_ptr(*lhs_indices) : nullptr, + rhs_indices ? gpu_ptr(*rhs_indices) : nullptr, + gpu_ptr(out), + m, + n, + k, + l, + broadcast_b); +} } // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh index fdeceaab5f..d07909747b 100644 --- a/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh +++ b/mlx/backend/cuda/quantized/qmm/qmm_sm80.cuh @@ -1,6 +1,6 @@ // Copyright © 2026 Apple Inc. -#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" +#include "mlx/backend/cuda/device/cute_dequant.cuh" #include "mlx/dtype_utils.h" // clang-format off @@ -181,7 +181,7 @@ CUTE_DEVICE void qmm_sm80_mainloop( // Copy S/Z: GMEM => RMEM. auto fetch_scales = [&](int tile) { copy(g2r_copy_s, g2r_tCgS(_,_,_,tile), g2r_tCrS); - if constexpr (quant_has_bias_v) { + if constexpr (mlx::core::cu::quant_has_bias_v) { copy(g2r_copy_s, g2r_tCgZ(_,_,_,tile), g2r_tCrZ); } }; @@ -191,7 +191,7 @@ CUTE_DEVICE void qmm_sm80_mainloop( copy(s2r_atom_b, s2r_tCsB(_,_,block,smem_pipe_read), s2r_tCrB(_,_,block)); CUTE_UNROLL for (int n = 0; n < size<1>(tCrB); ++n) { - cute_vectorized_dequant( + mlx::core::cu::cute_vectorized_dequant( tCrB(_,n,block), tCrS(_,n,block), tCrZ(_,n,block), diff --git a/mlx/backend/cuda/quantized/qmm/qmv.cu b/mlx/backend/cuda/quantized/qmm/qmv.cu index 540e83cd2e..7a293d2ce8 100644 --- a/mlx/backend/cuda/quantized/qmm/qmv.cu +++ b/mlx/backend/cuda/quantized/qmm/qmv.cu @@ -1,7 +1,7 @@ // Copyright © 2026 Apple Inc. +#include "mlx/backend/cuda/device/cute_dequant.cuh" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/quantized/qmm/cute_dequant.cuh" #include "mlx/backend/cuda/quantized/qmm/qmm.h" #include "mlx/dtype_utils.h" diff --git a/mlx/backend/cuda/slicing.cpp b/mlx/backend/cuda/slicing.cpp index c130e5b13e..9581f7ccbd 100644 --- a/mlx/backend/cuda/slicing.cpp +++ b/mlx/backend/cuda/slicing.cpp @@ -56,7 +56,9 @@ array compute_dynamic_offset( dtype_to_cuda_type(dtype), nidx); - cu::JitModule& mod = cu::get_jit_module(s.device, module_name, [&]() { + auto& encoder = cu::get_command_encoder(s); + + cu::JitModule& mod = cu::get_jit_module(encoder.device(), module_name, [&]() { std::string source = R"( #include "mlx/backend/cuda/device/utils.cuh" @@ -81,7 +83,6 @@ array compute_dynamic_offset( return std::make_tuple(false, std::move(source), std::vector{kernel_name}); }); - auto& encoder = cu::get_command_encoder(s); // Prepare output. array offset({1}, int64, nullptr, {}); bool donate = indices.is_donatable() &&