Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions mlx/backend/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,9 @@ FetchContent_Declare(
FetchContent_MakeAvailable(cutlass)
target_include_directories(
mlx SYSTEM PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)

# Install CUTLASS headers for JIT.
install(DIRECTORY ${cutlass_SOURCE_DIR}/include/cute
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
install(DIRECTORY ${cutlass_SOURCE_DIR}/include/cutlass
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
6 changes: 3 additions & 3 deletions mlx/backend/cuda/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,9 @@ void Compiled::eval_gpu(
}
int work_per_thread = 16 / max_size;

cu::JitModule& mod = cu::get_jit_module(s.device, lib_name(), [&]() {
auto& encoder = cu::get_command_encoder(s);

cu::JitModule& mod = cu::get_jit_module(encoder.device(), lib_name(), [&]() {
// Build source code.
cu::FusedKernelBuilder builder{
g_jit_includes, lib_name(), inputs_, outputs_, tape_, is_constant_};
Expand Down Expand Up @@ -305,8 +307,6 @@ void Compiled::eval_gpu(
}
}

auto& encoder = cu::get_command_encoder(s);

// Put outputs.
compiled_allocate_outputs(
inputs, outputs, is_constant_, contiguous, [&](auto n) {
Expand Down
2 changes: 1 addition & 1 deletion mlx/backend/cuda/custom_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ void CustomKernel::eval_gpu(
std::string kernel_name =
(is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
cu::JitModule& mod = cu::get_jit_module(
s.device,
encoder.device(),
name_,
[&]() {
return std::make_tuple(
Expand Down
21 changes: 13 additions & 8 deletions mlx/backend/cuda/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class CommandEncoder {

template <typename F, typename... Params>
void add_kernel_node_ex(
F* func,
F func,
dim3 grid_dim,
dim3 block_dim,
dim3 cluster_dim,
Expand All @@ -69,13 +69,18 @@ class CommandEncoder {
([&](auto&& p) { ptrs[i++] = static_cast<void*>(&p); }(
std::forward<Params>(params)),
...);
add_kernel_node_raw(
reinterpret_cast<void*>(func),
grid_dim,
block_dim,
cluster_dim,
smem_bytes,
ptrs);
if constexpr (std::is_same_v<F, CUfunction>) {
add_kernel_node_raw(
func, grid_dim, block_dim, cluster_dim, smem_bytes, ptrs);
} else {
add_kernel_node_raw(
reinterpret_cast<void*>(func),
grid_dim,
block_dim,
cluster_dim,
smem_bytes,
ptrs);
}
}

void add_kernel_node_raw(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ namespace cute {

// Required by tiled copy for 3/5/6-bit weights.
struct uint24_t {
cuda::std::array<std::uint8_t, 3> bytes;
cuda::std::array<uint8_t, 3> bytes;
};
struct uint40_t {
cuda::std::array<std::uint8_t, 5> bytes;
cuda::std::array<uint8_t, 5> bytes;
};
struct uint48_t {
cuda::std::array<std::uint8_t, 6> bytes;
cuda::std::array<uint8_t, 6> bytes;
};

template <>
Expand All @@ -134,15 +134,22 @@ struct uint_bit<48> {

} // namespace cute

namespace cutlass_gemm {
namespace mlx::core::cu {

using namespace cute;

// Whether the quant type is affine quantization.
//
// Evaluates to true exactly when cutlass::has_negative_zero_v<Quant> is false.
// The dequant helpers in this file use it to decide whether the bias term is
// applied after scaling.
// NOTE(review): presumably types that can represent negative zero are the
// floating-point-style quant formats, which carry no bias - confirm against
// the cutlass trait's definition.
template <typename Quant>
constexpr bool quant_has_bias_v = !cutlass::has_negative_zero_v<Quant>;

// Dequantize CuTe tensors with out = w * s + z.
__device__ __forceinline__ void
cute_vectorized_dequant(auto w, auto s, auto z, auto out) {
template <
typename TensorW,
typename TensorS,
typename TensorZ,
typename TensorO>
CUTE_DEVICE void
cute_vectorized_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) {
using namespace cute;
using Element = typename decltype(out)::value_type;
using Quant = typename decltype(w)::value_type;
Expand All @@ -166,4 +173,36 @@ cute_vectorized_dequant(auto w, auto s, auto z, auto out) {
copy(make_tensor(make_rmem_ptr<Element>(&w_dq), out.layout()), out);
}

} // namespace cutlass_gemm
// Element-wise dequantization of CuTe tensors: out = w * s (+ z).
//
// Runs as up to three passes over `out`:
//   1. convert every quantized value of `w` to the output element type,
//   2. multiply by the matching scale from `s` (converted to the output type),
//   3. add the matching bias from `z`, only when quant_has_bias_v<Quant> holds.
// `z` is not read when the quant type has no bias.
template <
    typename TensorW,
    typename TensorS,
    typename TensorZ,
    typename TensorO>
CUTE_DEVICE void
cute_naive_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) {
  using OutT = typename TensorO::value_type;
  using QuantT = typename TensorW::value_type;
  using ScaleT = typename TensorS::value_type;

  // Pass 1: widen the quantized weights into the output tensor.
  auto to_output = [](QuantT packed) { return OutT(packed); };
  transform(w, out, to_output);

  // Pass 2: apply the scales in place.
  auto apply_scale = [](OutT value, ScaleT scale) {
    return value * OutT(scale);
  };
  transform(out, s, out, apply_scale);

  // Pass 3: add the bias in place, only for quant types that carry one.
  if constexpr (quant_has_bias_v<QuantT>) {
    transform(out, z, out, plus{});
  }
}

// Dequantize CuTe tensors, picking the implementation at compile time.
//
// Forwards (w, s, z, out) unchanged to cute_vectorized_dequant when both of
// the following hold, and to cute_naive_dequant otherwise:
//   - the coalesced layout of `w` has a stride equal to Int<1>, and
//   - the layout type of `s` is fully static (is_static_v).
// NOTE(review): presumably these are the preconditions of the vectorized
// path (contiguous weights, compile-time-known scale layout) - confirm
// against cute_vectorized_dequant.
template <
typename TensorW,
typename TensorS,
typename TensorZ,
typename TensorO>
CUTE_DEVICE void cute_dequant(TensorW w, TensorS s, TensorZ z, TensorO out) {
// The whole condition must be usable as a constant expression, because it
// is the operand of `if constexpr`.
if constexpr (
stride(coalesce(w.layout())) == Int<1>{} &&
is_static_v<decltype(s.layout())>) {
cute_vectorized_dequant(w, s, z, out);
} else {
cute_naive_dequant(w, s, z, out);
}
}

} // namespace mlx::core::cu
Loading
Loading