Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ jobs:
echo "::endgroup::"

echo "::group::Build test runners"
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 ))
echo "::endgroup::"

echo "::group::Run mutable-state (multi-session) unit test"
./cmake-out/backends/mlx/test/mlx_mutable_state_test
echo "::endgroup::"

echo "::group::Run op unit tests"
Expand Down
6 changes: 4 additions & 2 deletions backends/mlx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION
ON
)

set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
set(_mlx_backend__srcs
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp
${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp
)

add_library(mlxdelegate ${_mlx_backend__srcs})
Expand Down
16 changes: 16 additions & 0 deletions backends/mlx/runtime/MLXBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "MLXExecutor.h"
#include "MLXInterpreter.h"
#include "MLXLoader.h"
#include "mlx_mutable_state.h"

#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/core/error.h>
Expand Down Expand Up @@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
eval(handle->constants.tensors);
}

// Register the handle with the per-session mutable-state manager. This is
// a no-op unless a multi-session owner is active for this load (see
// mlx_mutable_state.h); single-session execution is unaffected.
mutable_state_note_handle(
handle, &handle->program, &handle->mutable_buffers);

} catch (const std::exception& e) {
ET_LOG(Error, "Failed to load MLX program: %s", e.what());
handle->~MLXHandle();
Expand Down Expand Up @@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
}
}

// Select the active session's mutable buffers (KV cache, recurrent/conv
// state) before running. No-op for single-session handles; weights stay
// shared via ExecutionState::constants.
if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state);
rebind_err != Error::Ok) {
return rebind_err;
}

// Run the MLX program (builds lazy computation graph)
h->interpreter.run(program, h->state, h->stream);

Expand Down Expand Up @@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface {
void destroy(DelegateHandle* handle) const override {
std::lock_guard<std::mutex> lock(mlx_global_mutex());
if (handle != nullptr) {
mutable_state_forget_handle(handle);
auto* mlx_handle = static_cast<MLXHandle*>(handle);
mlx_handle->~MLXHandle();
}
Expand Down
268 changes: 268 additions & 0 deletions backends/mlx/runtime/mlx_mutable_state.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "mlx_mutable_state.h"

#include "MLXExecutor.h"
#include "MLXLoader.h"

#include <executorch/runtime/platform/log.h>

#include <mutex>
#include <unordered_map>

namespace executorch {
namespace backends {
namespace mlx {

using ::executorch::runtime::Error;
using ::executorch::runtime::Result;

namespace {

struct HandleInfo {
const MLXProgram* program{nullptr};
MutableBufferData* default_buffers{nullptr};
};

struct Context {
// Delegate handles associated with this loaded program (one per loaded
// method). Keyed by opaque MLXHandle pointer.
std::unordered_map<const void*, HandleInfo> handles;
// Per-session mutable buffers: token -> (handle -> buffers). Allocated lazily
// on first execute for a given (session, handle).
std::unordered_map<int, std::unordered_map<const void*, MutableBufferData>>
sessions;
int next_token{0};
};

// Process-global registry. MLX serializes execution via its own global mutex and
// the engine serializes per session, but the registry itself is guarded here so
// context/session lifecycle calls from other threads are safe.
std::mutex& registry_mutex() {
static std::mutex m;
return m;
}

std::unordered_map<MutableStateContext, Context>& contexts() {
static std::unordered_map<MutableStateContext, Context> c;
return c;
}

std::unordered_map<const void*, MutableStateContext>& handle_ctx() {
static std::unordered_map<const void*, MutableStateContext> m;
return m;
}

MutableStateContext g_next_ctx = 1; // 0 is reserved as invalid.

// Thread-local load scope and active (ctx, session) selection.
thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext;
thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext;
thread_local int tl_active_token = kNoMutableSession;

} // namespace

namespace detail {

MutableStateContext mutable_state_create_context() {
std::lock_guard<std::mutex> g(registry_mutex());
MutableStateContext ctx = g_next_ctx++;
if (ctx == kInvalidMutableContext) {
ctx = g_next_ctx++;
}
contexts()[ctx];
return ctx;
}

void mutable_state_destroy_context(MutableStateContext ctx) {
std::lock_guard<std::mutex> g(registry_mutex());
auto it = contexts().find(ctx);
if (it == contexts().end()) {
return;
}
for (const auto& kv : it->second.handles) {
handle_ctx().erase(kv.first);
}
contexts().erase(it);
}

void mutable_state_begin_load(MutableStateContext ctx) {
tl_loading_ctx = ctx;
}

void mutable_state_end_load() {
tl_loading_ctx = kInvalidMutableContext;
}

bool mutable_state_available(MutableStateContext ctx) {
if (ctx == kInvalidMutableContext) {
return false;
}
std::lock_guard<std::mutex> g(registry_mutex());
return contexts().count(ctx) != 0;
}

int64_t mutable_state_bytes_per_session(MutableStateContext ctx) {
std::lock_guard<std::mutex> g(registry_mutex());
auto it = contexts().find(ctx);
if (it == contexts().end()) {
return 0;
}
int64_t total = 0;
for (const auto& kv : it->second.handles) {
const MutableBufferData* bufs = kv.second.default_buffers;
if (bufs == nullptr) {
continue;
}
for (const auto& t : bufs->tensors) {
if (t.has_value()) {
total += static_cast<int64_t>(t->nbytes());
}
}
}
return total;
}

Error mutable_state_validate_coverage(MutableStateContext ctx) {
// MLX clones all mutable buffers by tid; there is no FQN coverage to verify.
(void)ctx;
return Error::Ok;
}

Result<int> mutable_state_create_session(MutableStateContext ctx) {
std::lock_guard<std::mutex> g(registry_mutex());
auto it = contexts().find(ctx);
if (it == contexts().end()) {
ET_LOG(Error, "mutable_state_create_session: unknown context %d", ctx);
return Error::InvalidState;
}
int token = it->second.next_token++;
// Per-handle buffers are allocated lazily on first execute.
it->second.sessions[token];
return token;
}

void mutable_state_destroy_session(MutableStateContext ctx, int token) {
std::lock_guard<std::mutex> g(registry_mutex());
auto it = contexts().find(ctx);
if (it == contexts().end()) {
return;
}
it->second.sessions.erase(token);
}

void mutable_state_set_active(MutableStateContext ctx, int token) {
tl_active_ctx = ctx;
tl_active_token = token;
}

} // namespace detail

void mutable_state_note_handle(
const void* handle,
const MLXProgram* program,
MutableBufferData* default_buffers) {
if (tl_loading_ctx == kInvalidMutableContext) {
return; // No multi-session owner active during this load: single-session.
}
std::lock_guard<std::mutex> g(registry_mutex());
auto it = contexts().find(tl_loading_ctx);
if (it == contexts().end()) {
return;
}
it->second.handles[handle] = HandleInfo{program, default_buffers};
handle_ctx()[handle] = tl_loading_ctx;
}

void mutable_state_forget_handle(const void* handle) {
std::lock_guard<std::mutex> g(registry_mutex());
auto hit = handle_ctx().find(handle);
if (hit == handle_ctx().end()) {
return;
}
auto cit = contexts().find(hit->second);
if (cit != contexts().end()) {
cit->second.handles.erase(handle);
for (auto& session : cit->second.sessions) {
session.second.erase(handle);
}
}
handle_ctx().erase(hit);
}

Error mutable_state_rebind_for_execute(
const void* handle,
ExecutionState& state) {
std::lock_guard<std::mutex> g(registry_mutex());
auto hit = handle_ctx().find(handle);
if (hit == handle_ctx().end()) {
// Handle was not loaded under a multi-session owner: keep default buffers.
return Error::Ok;
}
auto cit = contexts().find(hit->second);
if (cit == contexts().end()) {
return Error::Ok;
}
Context& ctx = cit->second;
HandleInfo& info = ctx.handles[handle];

const bool active_for_this_ctx =
tl_active_token != kNoMutableSession && tl_active_ctx == hit->second;

if (!active_for_this_ctx) {
// No session selected. Refuse if sessions exist (running against the default
// buffers here would not isolate state from created sessions).
if (!ctx.sessions.empty()) {
ET_LOG(
Error,
"mutable_state_rebind_for_execute: no active session selected but "
"sessions exist for this program");
return Error::InvalidState;
}
state.mutable_buffers = info.default_buffers;
return Error::Ok;
}

auto sit = ctx.sessions.find(tl_active_token);
if (sit == ctx.sessions.end()) {
ET_LOG(
Error,
"mutable_state_rebind_for_execute: unknown session token %d",
tl_active_token);
return Error::InvalidState;
}

auto& per_handle = sit->second;
auto bit = per_handle.find(handle);
if (bit == per_handle.end()) {
// First execute for this (session, handle): allocate fresh zeroed buffers.
// Constants/weights stay shared (ExecutionState::constants is untouched);
// only the mutable buffers are per-session.
MutableBufferData buffers;
try {
load_mutable_buffers(*info.program, buffers);
} catch (const std::exception& e) {
ET_LOG(
Error,
"mutable_state_rebind_for_execute: failed to allocate session "
"buffers: %s",
e.what());
return Error::MemoryAllocationFailed;
}
bit = per_handle.emplace(handle, std::move(buffers)).first;
}
// unordered_map keeps element pointers stable across rehash, so this remains
// valid for the duration of the execute.
state.mutable_buffers = &bit->second;
return Error::Ok;
}

} // namespace mlx
} // namespace backends
} // namespace executorch
Loading
Loading