diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 5a4ccbb4952..167ceb7da83 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -66,7 +66,11 @@ jobs: echo "::endgroup::" echo "::group::Build test runners" - ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 )) + ${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner mlx_mutable_state_test -j$(( $(sysctl -n hw.ncpu) - 1 )) + echo "::endgroup::" + + echo "::group::Run mutable-state (multi-session) unit test" + ./cmake-out/backends/mlx/test/mlx_mutable_state_test echo "::endgroup::" echo "::group::Run op unit tests" diff --git a/backends/mlx/CMakeLists.txt b/backends/mlx/CMakeLists.txt index 43968d09b5d..acb96fb1ed9 100644 --- a/backends/mlx/CMakeLists.txt +++ b/backends/mlx/CMakeLists.txt @@ -255,8 +255,10 @@ option(ET_MLX_ALLOW_CUSTOM_KERNEL_EXECUTION ON ) -set(_mlx_backend__srcs ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp +set(_mlx_backend__srcs + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXLoader.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/MLXBackend.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/runtime/mlx_mutable_state.cpp ) add_library(mlxdelegate ${_mlx_backend__srcs}) diff --git a/backends/mlx/runtime/MLXBackend.cpp b/backends/mlx/runtime/MLXBackend.cpp index 5bd3bf263d1..0dbdec22436 100644 --- a/backends/mlx/runtime/MLXBackend.cpp +++ b/backends/mlx/runtime/MLXBackend.cpp @@ -9,6 +9,7 @@ #include "MLXExecutor.h" #include "MLXInterpreter.h" #include "MLXLoader.h" +#include "mlx_mutable_state.h" #include #include @@ -277,6 +278,12 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { eval(handle->constants.tensors); } + // Register the handle with the per-session mutable-state manager. This is + // a no-op unless a multi-session owner is active for this load (see + // mlx_mutable_state.h); single-session execution is unaffected. + mutable_state_note_handle( + handle, &handle->program, &handle->mutable_buffers); + } catch (const std::exception& e) { ET_LOG(Error, "Failed to load MLX program: %s", e.what()); handle->~MLXHandle(); @@ -366,6 +373,14 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { } } + // Select the active session's mutable buffers (KV cache, recurrent/conv + // state) before running. No-op for single-session handles; weights stay + // shared via ExecutionState::constants. + if (Error rebind_err = mutable_state_rebind_for_execute(h, h->state); + rebind_err != Error::Ok) { + return rebind_err; + } + // Run the MLX program (builds lazy computation graph) h->interpreter.run(program, h->state, h->stream); @@ -431,6 +446,7 @@ class MLXBackend final : public ::executorch::runtime::BackendInterface { void destroy(DelegateHandle* handle) const override { std::lock_guard lock(mlx_global_mutex()); if (handle != nullptr) { + mutable_state_forget_handle(handle); auto* mlx_handle = static_cast(handle); mlx_handle->~MLXHandle(); } diff --git a/backends/mlx/runtime/mlx_mutable_state.cpp b/backends/mlx/runtime/mlx_mutable_state.cpp new file mode 100644 index 00000000000..429f3fea5da --- /dev/null +++ b/backends/mlx/runtime/mlx_mutable_state.cpp @@ -0,0 +1,268 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "mlx_mutable_state.h" + +#include "MLXExecutor.h" +#include "MLXLoader.h" + +#include + +#include +#include + +namespace executorch { +namespace backends { +namespace mlx { + +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +namespace { + +struct HandleInfo { + const MLXProgram* program{nullptr}; + MutableBufferData* default_buffers{nullptr}; +}; + +struct Context { + // Delegate handles associated with this loaded program (one per loaded + // method). Keyed by opaque MLXHandle pointer. + std::unordered_map handles; + // Per-session mutable buffers: token -> (handle -> buffers). Allocated lazily + // on first execute for a given (session, handle). + std::unordered_map> + sessions; + int next_token{0}; +}; + +// Process-global registry. MLX serializes execution via its own global mutex and +// the engine serializes per session, but the registry itself is guarded here so +// context/session lifecycle calls from other threads are safe. +std::mutex& registry_mutex() { + static std::mutex m; + return m; +} + +std::unordered_map& contexts() { + static std::unordered_map c; + return c; +} + +std::unordered_map& handle_ctx() { + static std::unordered_map m; + return m; +} + +MutableStateContext g_next_ctx = 1; // 0 is reserved as invalid. + +// Thread-local load scope and active (ctx, session) selection. +thread_local MutableStateContext tl_loading_ctx = kInvalidMutableContext; +thread_local MutableStateContext tl_active_ctx = kInvalidMutableContext; +thread_local int tl_active_token = kNoMutableSession; + +} // namespace + +namespace detail { + +MutableStateContext mutable_state_create_context() { + std::lock_guard g(registry_mutex()); + MutableStateContext ctx = g_next_ctx++; + if (ctx == kInvalidMutableContext) { + ctx = g_next_ctx++; + } + contexts()[ctx]; + return ctx; +} + +void mutable_state_destroy_context(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return; + } + for (const auto& kv : it->second.handles) { + handle_ctx().erase(kv.first); + } + contexts().erase(it); +} + +void mutable_state_begin_load(MutableStateContext ctx) { + tl_loading_ctx = ctx; +} + +void mutable_state_end_load() { + tl_loading_ctx = kInvalidMutableContext; +} + +bool mutable_state_available(MutableStateContext ctx) { + if (ctx == kInvalidMutableContext) { + return false; + } + std::lock_guard g(registry_mutex()); + return contexts().count(ctx) != 0; +} + +int64_t mutable_state_bytes_per_session(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return 0; + } + int64_t total = 0; + for (const auto& kv : it->second.handles) { + const MutableBufferData* bufs = kv.second.default_buffers; + if (bufs == nullptr) { + continue; + } + for (const auto& t : bufs->tensors) { + if (t.has_value()) { + total += static_cast(t->nbytes()); + } + } + } + return total; +} + +Error mutable_state_validate_coverage(MutableStateContext ctx) { + // MLX clones all mutable buffers by tid; there is no FQN coverage to verify. + (void)ctx; + return Error::Ok; +} + +Result mutable_state_create_session(MutableStateContext ctx) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + ET_LOG(Error, "mutable_state_create_session: unknown context %d", ctx); + return Error::InvalidState; + } + int token = it->second.next_token++; + // Per-handle buffers are allocated lazily on first execute. + it->second.sessions[token]; + return token; +} + +void mutable_state_destroy_session(MutableStateContext ctx, int token) { + std::lock_guard g(registry_mutex()); + auto it = contexts().find(ctx); + if (it == contexts().end()) { + return; + } + it->second.sessions.erase(token); +} + +void mutable_state_set_active(MutableStateContext ctx, int token) { + tl_active_ctx = ctx; + tl_active_token = token; +} + +} // namespace detail + +void mutable_state_note_handle( + const void* handle, + const MLXProgram* program, + MutableBufferData* default_buffers) { + if (tl_loading_ctx == kInvalidMutableContext) { + return; // No multi-session owner active during this load: single-session. + } + std::lock_guard g(registry_mutex()); + auto it = contexts().find(tl_loading_ctx); + if (it == contexts().end()) { + return; + } + it->second.handles[handle] = HandleInfo{program, default_buffers}; + handle_ctx()[handle] = tl_loading_ctx; +} + +void mutable_state_forget_handle(const void* handle) { + std::lock_guard g(registry_mutex()); + auto hit = handle_ctx().find(handle); + if (hit == handle_ctx().end()) { + return; + } + auto cit = contexts().find(hit->second); + if (cit != contexts().end()) { + cit->second.handles.erase(handle); + for (auto& session : cit->second.sessions) { + session.second.erase(handle); + } + } + handle_ctx().erase(hit); +} + +Error mutable_state_rebind_for_execute( + const void* handle, + ExecutionState& state) { + std::lock_guard g(registry_mutex()); + auto hit = handle_ctx().find(handle); + if (hit == handle_ctx().end()) { + // Handle was not loaded under a multi-session owner: keep default buffers. + return Error::Ok; + } + auto cit = contexts().find(hit->second); + if (cit == contexts().end()) { + return Error::Ok; + } + Context& ctx = cit->second; + HandleInfo& info = ctx.handles[handle]; + + const bool active_for_this_ctx = + tl_active_token != kNoMutableSession && tl_active_ctx == hit->second; + + if (!active_for_this_ctx) { + // No session selected. Refuse if sessions exist (running against the default + // buffers here would not isolate state from created sessions). + if (!ctx.sessions.empty()) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: no active session selected but " + "sessions exist for this program"); + return Error::InvalidState; + } + state.mutable_buffers = info.default_buffers; + return Error::Ok; + } + + auto sit = ctx.sessions.find(tl_active_token); + if (sit == ctx.sessions.end()) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: unknown session token %d", + tl_active_token); + return Error::InvalidState; + } + + auto& per_handle = sit->second; + auto bit = per_handle.find(handle); + if (bit == per_handle.end()) { + // First execute for this (session, handle): allocate fresh zeroed buffers. + // Constants/weights stay shared (ExecutionState::constants is untouched); + // only the mutable buffers are per-session. + MutableBufferData buffers; + try { + load_mutable_buffers(*info.program, buffers); + } catch (const std::exception& e) { + ET_LOG( + Error, + "mutable_state_rebind_for_execute: failed to allocate session " + "buffers: %s", + e.what()); + return Error::MemoryAllocationFailed; + } + bit = per_handle.emplace(handle, std::move(buffers)).first; + } + // unordered_map keeps element pointers stable across rehash, so this remains + // valid for the duration of the execute. + state.mutable_buffers = &bit->second; + return Error::Ok; +} + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/runtime/mlx_mutable_state.h b/backends/mlx/runtime/mlx_mutable_state.h new file mode 100644 index 00000000000..250b70c5a6b --- /dev/null +++ b/backends/mlx/runtime/mlx_mutable_state.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include +#include +#include + +// MLX-private support for running one loaded MLX program with multiple isolated +// instances of its mutable buffers (KV cache, conv/recurrent state). Callers +// create sessions and execute with one active session selected. +// +// Unlike the CUDA backend, the MLX runtime owns mutable buffers directly in a +// swappable container (ExecutionState::mutable_buffers is a MutableBufferData*), +// so per-session isolation is a pointer swap to a freshly zero-allocated +// MutableBufferData — no FQN registration / constant-repoint hook is needed. + +namespace executorch { +namespace backends { +namespace mlx { + +// Forward declarations (defined in MLXLoader.h / MLXExecutor.h). +struct MLXProgram; +struct MutableBufferData; +struct ExecutionState; + +// Opaque per-loaded-program context id (0 = invalid). +using MutableStateContext = int; +constexpr MutableStateContext kInvalidMutableContext = 0; + +// Sentinel for execution without per-session rebinding. +constexpr int kNoMutableSession = -1; + +// Implementation entry points. Callers should use MutableStateContextOwner. +namespace detail { + +MutableStateContext mutable_state_create_context(); +void mutable_state_destroy_context(MutableStateContext ctx); +void mutable_state_begin_load(MutableStateContext ctx); +void mutable_state_end_load(); +bool mutable_state_available(MutableStateContext ctx); +int64_t mutable_state_bytes_per_session(MutableStateContext ctx); +::executorch::runtime::Error mutable_state_validate_coverage( + MutableStateContext ctx); +::executorch::runtime::Result mutable_state_create_session( + MutableStateContext ctx); +void mutable_state_destroy_session(MutableStateContext ctx, int token); +void mutable_state_set_active(MutableStateContext ctx, int token); + +} // namespace detail + +// Caller-facing owner for one mutable-state context. Mirrors the CUDA backend's +// MutableStateContextOwner so the example engine can use a symmetric API. +class ET_EXPERIMENTAL MutableStateContextOwner final { + class LoadScope final { + public: + explicit LoadScope(MutableStateContext ctx) { + detail::mutable_state_begin_load(ctx); + } + + ~LoadScope() { + detail::mutable_state_end_load(); + } + + LoadScope(const LoadScope&) = delete; + LoadScope& operator=(const LoadScope&) = delete; + }; + + class ActiveSessionScope final { + public: + ActiveSessionScope(MutableStateContext ctx, int token) { + detail::mutable_state_set_active(ctx, token); + } + + ~ActiveSessionScope() { + detail::mutable_state_set_active( + kInvalidMutableContext, kNoMutableSession); + } + + ActiveSessionScope(const ActiveSessionScope&) = delete; + ActiveSessionScope& operator=(const ActiveSessionScope&) = delete; + }; + + public: + MutableStateContextOwner() : ctx_(detail::mutable_state_create_context()) {} + + ~MutableStateContextOwner() { + destroy(); + } + + MutableStateContextOwner(const MutableStateContextOwner&) = delete; + MutableStateContextOwner& operator=(const MutableStateContextOwner&) = delete; + + MutableStateContextOwner(MutableStateContextOwner&& other) noexcept + : ctx_(std::exchange(other.ctx_, kInvalidMutableContext)) {} + + MutableStateContextOwner& operator=( + MutableStateContextOwner&& other) noexcept { + if (this != &other) { + destroy(); + ctx_ = std::exchange(other.ctx_, kInvalidMutableContext); + } + return *this; + } + + MutableStateContext get() const { + return ctx_; + } + + explicit operator bool() const { + return ctx_ != kInvalidMutableContext; + } + + // Associates delegate handles created by `fn` with this context. + template + auto with_load_scope(Fn&& fn) const -> decltype(std::forward(fn)()) { + LoadScope scope(ctx_); + return std::forward(fn)(); + } + + // Selects this context/session while `fn` executes. The caller is responsible + // for serializing execution that touches the same loaded program. + template + auto with_active_session(int token, Fn&& fn) const + -> decltype(std::forward(fn)()) { + ActiveSessionScope scope(ctx_, token); + return std::forward(fn)(); + } + + bool available() const { + return detail::mutable_state_available(ctx_); + } + + int64_t bytes_per_session() const { + return detail::mutable_state_bytes_per_session(ctx_); + } + + ::executorch::runtime::Error validate_coverage() const { + return detail::mutable_state_validate_coverage(ctx_); + } + + ::executorch::runtime::Result create_session() const { + return detail::mutable_state_create_session(ctx_); + } + + void destroy_session(int token) const { + detail::mutable_state_destroy_session(ctx_, token); + } + + private: + void destroy() { + if (ctx_ != kInvalidMutableContext) { + detail::mutable_state_destroy_context(ctx_); + ctx_ = kInvalidMutableContext; + } + } + + MutableStateContext ctx_ = kInvalidMutableContext; +}; + +// --- MLXBackend hooks -------------------------------------------------------- +// +// Called from MLXBackend init/execute/destroy. `handle` is an opaque key (the +// MLXHandle pointer). `program` and `default_buffers` are the handle's own +// program and (init-time) mutable buffers; the manager swaps in per-session +// buffers (or restores the default) by re-pointing `state.mutable_buffers`. + +void mutable_state_note_handle( + const void* handle, + const MLXProgram* program, + MutableBufferData* default_buffers); + +void mutable_state_forget_handle(const void* handle); + +::executorch::runtime::Error mutable_state_rebind_for_execute( + const void* handle, + ExecutionState& state); + +} // namespace mlx +} // namespace backends +} // namespace executorch diff --git a/backends/mlx/test/CMakeLists.txt b/backends/mlx/test/CMakeLists.txt index 39024639d1d..c518b2a232d 100644 --- a/backends/mlx/test/CMakeLists.txt +++ b/backends/mlx/test/CMakeLists.txt @@ -69,3 +69,22 @@ if(EXECUTORCH_MLX_ENABLE_SANITIZERS) multi_thread_test_runner PRIVATE ${_mlx_sanitizer_link_options} ) endif() + +# Per-session mutable-state manager unit test (no model/tokenizer needed). +add_executable(mlx_mutable_state_test mlx_mutable_state_test.cpp) +target_include_directories( + mlx_mutable_state_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../runtime +) +target_link_libraries( + mlx_mutable_state_test PRIVATE mlxdelegate mlx executorch_core +) +if(EXECUTORCH_MLX_ENABLE_SANITIZERS) + target_compile_options( + mlx_mutable_state_test PRIVATE -fsanitize=address,undefined + -fno-omit-frame-pointer + ) + target_link_options( + mlx_mutable_state_test PRIVATE ${_mlx_sanitizer_link_options} + ) +endif() +add_test(NAME mlx_mutable_state COMMAND mlx_mutable_state_test) diff --git a/backends/mlx/test/mlx_mutable_state_test.cpp b/backends/mlx/test/mlx_mutable_state_test.cpp new file mode 100644 index 00000000000..ef34962c998 --- /dev/null +++ b/backends/mlx/test/mlx_mutable_state_test.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Unit test for the MLX per-session mutable-state manager +// (backends/mlx/runtime/mlx_mutable_state.{h,cpp}). +// +// Verifies that two sessions created on one loaded program get independent +// mutable buffers: writing into session A's buffer does not leak into session +// B's, and A's value persists across a rebind to B and back. This is the MLX +// analogue of the CUDA "no-bleed" guarantee, exercised directly on the manager +// (no model or tokenizer needed). + +#include "MLXExecutor.h" +#include "MLXLoader.h" +#include "mlx_mutable_state.h" + +#include + +#include + +using namespace ::executorch::backends::mlx; + +namespace { + +int g_failures = 0; + +#define CHECK(cond) \ + do { \ + if (!(cond)) { \ + std::printf("FAIL: %s (line %d)\n", #cond, __LINE__); \ + ++g_failures; \ + } \ + } while (0) + +// Build a minimal program with a single 1-element float mutable buffer at tid 0. +MLXProgram make_program() { + MLXProgram program; + program.num_mutable_buffer_tensors = 1; + program.mutable_buffer_map.push_back(SlotVariant{0, SlotType::TensorSlot}); + TensorMeta meta; + meta.shape.push_back(ShapeDim{/*value=*/1}); + meta.scalar_type = ScalarType::Float; + program.tensor_meta.resize(1); + program.tensor_meta[0] = meta; + return program; +} + +float read0(const MutableBufferData& bufs) { + auto arr = bufs.get(Tid{0}); + ::mlx::core::eval(arr); + return arr.item(); +} + +} // namespace + +int main() { + MLXProgram program = make_program(); + + // Handle's default (init-time) mutable buffers. + MutableBufferData default_bufs; + load_mutable_buffers(program, default_bufs); + + int dummy = 0; + const void* handle = &dummy; + + MutableStateContextOwner owner; + CHECK(static_cast(owner)); + + // Associate the handle with the context (as MLXBackend::init would). + owner.with_load_scope( + [&]() { mutable_state_note_handle(handle, &program, &default_bufs); }); + + CHECK(owner.available()); + CHECK(owner.bytes_per_session() == static_cast(sizeof(float))); + + auto tokA = owner.create_session(); + auto tokB = owner.create_session(); + CHECK(tokA.ok()); + CHECK(tokB.ok()); + CHECK(tokA.get() != tokB.get()); + + ExecutionState state; + + // Session A: rebind, then write a marker (7.0) into its buffer. + owner.with_active_session(tokA.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + state.mutable_buffers->set( + Tid{0}, ::mlx::core::full({1}, 7.0f, ::mlx::core::float32)); + return err; + }); + + // Session B: a fresh rebind must see zeros, not A's marker. + owner.with_active_session(tokB.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + CHECK(read0(*state.mutable_buffers) == 0.0f); + return err; + }); + + // Back to session A: the marker must persist (isolation, no bleed). + owner.with_active_session(tokA.get(), [&]() { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::Ok); + CHECK(read0(*state.mutable_buffers) == 7.0f); + return err; + }); + + // With sessions present, executing without an active session is refused + // (prevents running against unmanaged/shared state). + { + auto err = mutable_state_rebind_for_execute(handle, state); + CHECK(err == ::executorch::runtime::Error::InvalidState); + } + + owner.destroy_session(tokA.get()); + owner.destroy_session(tokB.get()); + mutable_state_forget_handle(handle); + + if (g_failures == 0) { + std::printf("OK: mlx_mutable_state isolation test passed\n"); + return 0; + } + std::printf("FAILED: %d checks\n", g_failures); + return 1; +} diff --git a/examples/models/qwen3_5_moe/CMakeLists.txt b/examples/models/qwen3_5_moe/CMakeLists.txt index 726657a3779..aeb97f76ab7 100644 --- a/examples/models/qwen3_5_moe/CMakeLists.txt +++ b/examples/models/qwen3_5_moe/CMakeLists.txt @@ -89,6 +89,7 @@ endif() if(TARGET mlxdelegate) executorch_target_copy_mlx_metallib(qwen3_5_moe_runner) + executorch_target_copy_mlx_metallib(qwen3_5_moe_worker) endif() if(EXECUTORCH_BUILD_CUDA) diff --git a/examples/models/qwen3_5_moe/CMakePresets.json b/examples/models/qwen3_5_moe/CMakePresets.json index 276c2116148..6adcb8aa9cb 100644 --- a/examples/models/qwen3_5_moe/CMakePresets.json +++ b/examples/models/qwen3_5_moe/CMakePresets.json @@ -70,9 +70,9 @@ }, { "name": "qwen3-5-moe-mlx", - "displayName": "Build Qwen3.5 MoE runner (MLX)", + "displayName": "Build Qwen3.5 MoE runner and worker (MLX)", "configurePreset": "qwen3-5-moe-mlx", - "targets": ["qwen3_5_moe_runner"] + "targets": ["qwen3_5_moe_runner", "qwen3_5_moe_worker"] } ], "workflowPresets": [ diff --git a/examples/models/qwen3_5_moe/README.md b/examples/models/qwen3_5_moe/README.md index c275641bfd7..65e3d3c38f1 100644 --- a/examples/models/qwen3_5_moe/README.md +++ b/examples/models/qwen3_5_moe/README.md @@ -302,6 +302,63 @@ python -m executorch.examples.models.qwen3_5_moe.run \ --max-new-tokens 50 ``` +### Serving (MLX, multi-session) + +The MLX worker hosts multiple isolated sessions on **one** weight load, so an +OpenAI-compatible server can serve concurrent conversations without duplicating +the ~weights. `make qwen3_5_moe-mlx` builds both `qwen3_5_moe_runner` and +`qwen3_5_moe_worker` (each with `mlx.metallib` copied alongside). + +Start the server (it auto-locates the worker binary): + +```bash +# tokenizer.json the C++ worker opens (resolve from the HF cache) +TOKENIZER_JSON=$(ls "${HF_HOME:-$HOME/.cache/huggingface}"/hub/models--Qwen--Qwen3.5-35B-A3B/snapshots/*/tokenizer.json | head -n1) + +python -m executorch.examples.models.qwen3_5_moe.serve \ + --model-path ./qwen35_moe_mlx/model.pte \ + --tokenizer-path "$TOKENIZER_JSON" \ + --hf-tokenizer Qwen/Qwen3.5-35B-A3B \ + --max-sessions 4 \ + --host 127.0.0.1 \ + --port 8000 +``` + +- `--tokenizer-path` is the raw `tokenizer.json` **file** the worker loads; + `--hf-tokenizer` (HF id or local dir) supplies the chat template on the Python + side. No `--data-path` (the MLX `.pte` is self-contained). +- `--max-sessions N` caps physical sessions on the single weight load. One slot + is reserved for anonymous requests (requests sent without a session id), so + `N` allows `N-1` concurrently named sessions. + +Query it (OpenAI-compatible) from another terminal. Route each conversation to a +session with the `session_id` header: + +```bash +curl http://127.0.0.1:8000/v1/chat/completions \ + -H "Content-Type: application/json" -H "session_id: alice" \ + -d '{"model":"qwen3.5-moe", + "messages":[{"role":"user","content":"What is the capital of France?"}], + "max_tokens":50,"chat_template_kwargs":{"enable_thinking":false}}' +``` + +Endpoints: `GET /health`, `GET /v1/models`, `POST /v1/chat/completions`, +`DELETE /v1/sessions/{id}` (free a session + its slot), `POST /v1/sessions/{id}/reset`. + +Session/memory semantics on MLX: +- This server uses the standard **stateless** OpenAI contract — send the full + `messages` history each request. `session_id` + warm-resume is a KV-cache reuse + optimization for the shared prefix, not server-side memory. +- Each session adds **one** set of mutable buffers (KV + recurrent/conv state) on + top of the shared weights; per-session cost scales with `max_seq_len`. Weights + are never duplicated. +- KV persists across requests for a live session and is **released on close** + (`DELETE`/reset). Named sessions are not auto-closed — close them to free slots. + MLX's Metal allocator pools freed buffers (so RSS may not shrink immediately), + but they are reused by later sessions, keeping memory bounded. +- Sessions interleave rather than run in parallel (MLX serializes GPU dispatch via + a global mutex). + ### Tiny Model Test For CI or quick pipeline validation (no model download needed): diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp index 713f6211330..6a6f03918b1 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.cpp @@ -183,9 +183,9 @@ class Qwen35MoESession : public LLMSession { ::tokenizers::Tokenizer* tokenizer, std::unordered_map metadata, std::unordered_set eos_ids -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , - ::executorch::backends::cuda::MutableStateContextOwner* mutable_state, + MutableStateContextOwner* mutable_state, int session_token #endif ) @@ -195,7 +195,7 @@ class Qwen35MoESession : public LLMSession { tokenizer_(tokenizer), metadata_(std::move(metadata)), eos_ids_(std::move(eos_ids)) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_(mutable_state), session_token_(session_token) @@ -212,9 +212,8 @@ class Qwen35MoESession : public LLMSession { } ~Qwen35MoESession() override { -#ifdef EXECUTORCH_BUILD_CUDA - if (mutable_state_ != nullptr && - session_token_ != ::executorch::backends::cuda::kNoMutableSession) { +#ifdef QWEN_HAS_MUTABLE_STATE + if (mutable_state_ != nullptr && session_token_ != kNoMutableSession) { mutable_state_->destroy_session(session_token_); } #endif @@ -425,8 +424,8 @@ class Qwen35MoESession : public LLMSession { float temperature, bool sync_after) { std::lock_guard guard(*exec_mutex_); -#ifdef EXECUTORCH_BUILD_CUDA - Result> res = mutable_state_ != nullptr +#ifdef QWEN_HAS_MUTABLE_STATE + auto res = mutable_state_ != nullptr ? mutable_state_->with_active_session( session_token_, [&]() { return module_->execute(method, inputs); }) @@ -465,10 +464,11 @@ class Qwen35MoESession : public LLMSession { int64_t decode_pos_data_[1] = {0}; TensorPtr decode_tokens_; TensorPtr decode_pos_; +#ifdef QWEN_HAS_MUTABLE_STATE + MutableStateContextOwner* mutable_state_ = nullptr; + int session_token_ = kNoMutableSession; +#endif #ifdef EXECUTORCH_BUILD_CUDA - ::executorch::backends::cuda::MutableStateContextOwner* mutable_state_ = - nullptr; - int session_token_ = ::executorch::backends::cuda::kNoMutableSession; float temp_val_ = 1e-6f; TensorPtr temp_tensor_; #endif @@ -529,17 +529,17 @@ Result> Qwen35MoEEngine::create( "not stop at end of turn"); } +#ifdef QWEN_HAS_MUTABLE_STATE + std::unique_ptr mutable_state; +#endif #ifdef EXECUTORCH_BUILD_CUDA - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state; if (config.enable_cuda_graph) { ET_LOG( Info, "Qwen35MoEEngine: CUDA graph requested; per-session rebinding disabled " "and serving capacity clamped to 1 session."); } else { - auto candidate = std::make_unique< - ::executorch::backends::cuda::MutableStateContextOwner>(); + auto candidate = std::make_unique(); if (Error e = register_mutable_fqns(meta_module.get(), *candidate); e == Error::Ok) { mutable_state = std::move(candidate); @@ -550,9 +550,13 @@ Result> Qwen35MoEEngine::create( "serving capacity clamped to 1 session."); } } +#elif defined(EXECUTORCH_BUILD_MLX) + // MLX owns mutable buffers directly and clones them per session; no FQN + // registration or coverage check is required. + mutable_state = std::make_unique(); #endif -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE auto module_res = mutable_state != nullptr ? mutable_state->with_load_scope( [&]() { return build_qwen_module(config); }) @@ -566,16 +570,14 @@ Result> Qwen35MoEEngine::create( std::unique_ptr shared_module = std::move(module_res.get()); bool rebind_available = false; -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE rebind_available = mutable_state != nullptr && mutable_state->available(); - if (rebind_available) { - if (mutable_state->validate_coverage() != Error::Ok) { - ET_LOG( - Error, - "Qwen35MoEEngine: mutable-buffer coverage check failed; disabling " - "multi-session (capacity clamped to 1)."); - rebind_available = false; - } + if (rebind_available && mutable_state->validate_coverage() != Error::Ok) { + ET_LOG( + Error, + "Qwen35MoEEngine: mutable-buffer coverage check failed; disabling " + "multi-session (capacity clamped to 1)."); + rebind_available = false; } if (!rebind_available) { ET_LOG( @@ -592,7 +594,7 @@ Result> Qwen35MoEEngine::create( std::move(eos_ids), std::move(shared_module), rebind_available -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , std::move(mutable_state) #endif @@ -621,7 +623,7 @@ Result> Qwen35MoEEngine::create_session() { } int token = -1; // kNoMutableSession: single-session / no rebind -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE if (rebind_available_) { auto t = mutable_state_->create_session(); if (t.error() != Error::Ok) { @@ -638,7 +640,7 @@ Result> Qwen35MoEEngine::create_session() { tokenizer_.get(), metadata_, eos_ids_ -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_.get(), token @@ -648,7 +650,7 @@ Result> Qwen35MoEEngine::create_session() { LLMServingCapacity Qwen35MoEEngine::serving_capacity() const { LLMServingCapacity cap; // default: 1 session, 0 bytes (unknown) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE if (rebind_available_) { cap.max_physical_sessions_without_weight_duplication = config_.max_sessions > 1 ? config_.max_sessions : 1; diff --git a/examples/models/qwen3_5_moe/qwen35_moe_engine.h b/examples/models/qwen3_5_moe/qwen35_moe_engine.h index c7ea53115b8..683e797ea68 100644 --- a/examples/models/qwen3_5_moe/qwen35_moe_engine.h +++ b/examples/models/qwen3_5_moe/qwen35_moe_engine.h @@ -28,10 +28,28 @@ #ifdef EXECUTORCH_BUILD_CUDA #include +#elif defined(EXECUTORCH_BUILD_MLX) +#include +#endif + +#if defined(EXECUTORCH_BUILD_CUDA) || defined(EXECUTORCH_BUILD_MLX) +#define QWEN_HAS_MUTABLE_STATE 1 #endif namespace executorch::extension::llm { +#if defined(EXECUTORCH_BUILD_CUDA) +using MutableStateContextOwner = + ::executorch::backends::cuda::MutableStateContextOwner; +constexpr int kNoMutableSession = + ::executorch::backends::cuda::kNoMutableSession; +#elif defined(EXECUTORCH_BUILD_MLX) +using MutableStateContextOwner = + ::executorch::backends::mlx::MutableStateContextOwner; +constexpr int kNoMutableSession = + ::executorch::backends::mlx::kNoMutableSession; +#endif + /// Immutable configuration for a Qwen3.5 MoE engine. struct Qwen35MoEConfig { std::string model_path; // .pte @@ -77,10 +95,9 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { std::unordered_set eos_ids, std::unique_ptr shared_module, bool rebind_available -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state + std::unique_ptr mutable_state #endif ) : config_(std::move(config)), @@ -89,7 +106,7 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { eos_ids_(std::move(eos_ids)), shared_module_(std::move(shared_module)), rebind_available_(rebind_available) -#ifdef EXECUTORCH_BUILD_CUDA +#ifdef QWEN_HAS_MUTABLE_STATE , mutable_state_(std::move(mutable_state)) #endif @@ -104,9 +121,8 @@ class ET_EXPERIMENTAL Qwen35MoEEngine : public LLMEngine { std::unique_ptr shared_module_; std::mutex exec_mutex_; bool rebind_available_ = false; -#ifdef EXECUTORCH_BUILD_CUDA - std::unique_ptr<::executorch::backends::cuda::MutableStateContextOwner> - mutable_state_; +#ifdef QWEN_HAS_MUTABLE_STATE + std::unique_ptr mutable_state_; #endif std::atomic live_sessions_{0}; };