Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions backend/python/common/python_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,31 @@
chat-template-compatible message list from proto Message objects.
"""
import json
from urllib.parse import unquote


def resolve_model_path(model, model_file=""):
"""Resolve a LocalAI model reference to something an HF/MLX loader accepts.

LocalAI hands backends either a plain HuggingFace repo id
(``namespace/name``), an already-local filesystem path, or a
``file://`` URI (its ``LocalPrefix``) for models imported from disk.
Loaders such as ``mlx_lm.load`` reject the ``file://`` form because the
scheme is neither a valid repo id nor an existing path, so we normalize
it here before loading.

Resolution order:
1. Prefer ``model_file`` when set and non-empty - that is the resolved
local path LocalAI computed for the model.
2. Strip a ``file://`` scheme and percent-decode it to a plain path.
3. Leave plain repo ids and already-local paths unchanged.
"""
candidate = model_file if model_file else model
if candidate is None:
return candidate
if candidate.startswith("file://"):
return unquote(candidate[len("file://"):])
return candidate


def parse_options(options_list):
Expand Down
16 changes: 10 additions & 6 deletions backend/python/mlx-distributed/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options as _shared_parse_options
from python_utils import messages_to_dicts, parse_options as _shared_parse_options, resolve_model_path
from mlx_utils import parse_tool_calls, split_reasoning


Expand Down Expand Up @@ -99,7 +99,11 @@ async def LoadModel(self, request, context):
from mlx_lm import load
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache

print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr)
# Normalize the model reference: strip LocalAI's file:// LocalPrefix
# and prefer the resolved ModelFile so mlx_lm.load() gets a plain
# repo id or filesystem path (it rejects file:// URIs).
model_path = resolve_model_path(request.Model, request.ModelFile)
print(f"[Rank 0] Loading model: {model_path}", file=sys.stderr)

self.options = parse_options(request.Options)
print(f"Options: {self.options}", file=sys.stderr)
Expand Down Expand Up @@ -128,7 +132,7 @@ async def LoadModel(self, request, context):
)
self.coordinator = DistributedCoordinator(self.group)
self.coordinator.broadcast_command(CMD_LOAD_MODEL)
self.coordinator.broadcast_model_name(request.Model)
self.coordinator.broadcast_model_name(model_path)
else:
print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr)

Expand All @@ -144,9 +148,9 @@ async def LoadModel(self, request, context):

if tokenizer_config:
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
self.model, self.tokenizer = load(model_path, tokenizer_config=tokenizer_config)
else:
self.model, self.tokenizer = load(request.Model)
self.model, self.tokenizer = load(model_path)

if self.group is not None:
from sharding import pipeline_auto_parallel
Expand All @@ -157,7 +161,7 @@ async def LoadModel(self, request, context):
from mlx_cache import ThreadSafeLRUPromptCache
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
self.model_key = request.Model
self.model_key = model_path
self.lru_cache = ThreadSafeLRUPromptCache(
max_size=max_cache_entries,
can_trim_fn=can_trim_prompt_cache,
Expand Down
12 changes: 8 additions & 4 deletions backend/python/mlx-vlm/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from python_utils import messages_to_dicts, parse_options, resolve_model_path
from mlx_utils import parse_tool_calls, split_reasoning

from mlx_vlm import load, stream_generate
Expand Down Expand Up @@ -67,7 +67,11 @@ async def LoadModel(self, request, context):
backend_pb2.Result: The load model result.
"""
try:
print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
# Normalize the model reference: strip LocalAI's file:// LocalPrefix
# and prefer the resolved ModelFile so mlx_vlm.load() gets a plain
# repo id or filesystem path (it rejects file:// URIs).
model_path = resolve_model_path(request.Model, request.ModelFile)
print(f"Loading MLX-VLM model: {model_path}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)

# Parse Options[] key:value strings into a typed dict
Expand All @@ -76,10 +80,10 @@ async def LoadModel(self, request, context):

# Load model and processor using MLX-VLM
# mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
self.model, self.processor = load(request.Model)
self.model, self.processor = load(model_path)

# Load model config for chat template support
self.config = load_config(request.Model)
self.config = load_config(model_path)

# Auto-infer the tool parser from the chat template. mlx-vlm has
# its own _infer_tool_parser that falls back to mlx-lm parsers.
Expand Down
14 changes: 9 additions & 5 deletions backend/python/mlx/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'common'))
from grpc_auth import get_auth_interceptors
from python_utils import messages_to_dicts, parse_options
from python_utils import messages_to_dicts, parse_options, resolve_model_path
from mlx_utils import parse_tool_calls, split_reasoning

from mlx_lm import load, stream_generate
Expand Down Expand Up @@ -63,7 +63,11 @@ async def LoadModel(self, request, context):
backend_pb2.Result: The load model result.
"""
try:
print(f"Loading MLX model: {request.Model}", file=sys.stderr)
# Normalize the model reference: strip LocalAI's file:// LocalPrefix
# and prefer the resolved ModelFile so mlx_lm.load() gets a plain
# repo id or filesystem path (it rejects file:// URIs).
model_path = resolve_model_path(request.Model, request.ModelFile)
print(f"Loading MLX model: {model_path}", file=sys.stderr)
print(f"Request: {request}", file=sys.stderr)

# Parse Options[] key:value strings into a typed dict (shared helper)
Expand All @@ -89,9 +93,9 @@ async def LoadModel(self, request, context):
# Load model and tokenizer using MLX
if tokenizer_config:
print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
self.model, self.tokenizer = load(model_path, tokenizer_config=tokenizer_config)
else:
self.model, self.tokenizer = load(request.Model)
self.model, self.tokenizer = load(model_path)

# mlx_lm.load() returns a TokenizerWrapper that detects tool
# calling and thinking markers from the chat template / vocab.
Expand All @@ -111,7 +115,7 @@ async def LoadModel(self, request, context):
# Initialize thread-safe LRU prompt cache for efficient generation
max_cache_entries = self.options.get("max_cache_entries", 10)
self.max_kv_size = self.options.get("max_kv_size", None)
self.model_key = request.Model
self.model_key = model_path
self.lru_cache = ThreadSafeLRUPromptCache(
max_size=max_cache_entries,
can_trim_fn=can_trim_prompt_cache,
Expand Down
38 changes: 37 additions & 1 deletion backend/python/mlx/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# Make the shared helpers importable so we can unit-test them without a
# running gRPC server.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'common'))
from python_utils import messages_to_dicts, parse_options
from python_utils import messages_to_dicts, parse_options, resolve_model_path
from mlx_utils import parse_tool_calls, split_reasoning

class TestBackendServicer(unittest.TestCase):
Expand Down Expand Up @@ -322,6 +322,42 @@ def test_split_reasoning_no_marker(self):
self.assertEqual(r, "")
self.assertEqual(c, "just text")

def test_resolve_model_path_file_uri(self):
# file:// LocalPrefix (LocalAI import) is stripped to a plain path.
self.assertEqual(resolve_model_path("file:///a/b"), "/a/b")

def test_resolve_model_path_file_uri_percent_decoded(self):
# Percent-encoded characters (e.g. spaces) are decoded.
self.assertEqual(
resolve_model_path("file:///Users/me/My%20Models/Qwen3"),
"/Users/me/My Models/Qwen3",
)

def test_resolve_model_path_hf_repo_id_unchanged(self):
# Plain HuggingFace repo ids must pass through untouched.
self.assertEqual(
resolve_model_path("mlx-community/Qwen3-Coder-30B"),
"mlx-community/Qwen3-Coder-30B",
)

def test_resolve_model_path_local_path_unchanged(self):
# An already-local absolute path is left as-is.
self.assertEqual(resolve_model_path("/models/Qwen3"), "/models/Qwen3")

def test_resolve_model_path_prefers_model_file(self):
# The resolved ModelFile wins over Model when both are set.
self.assertEqual(
resolve_model_path("file:///ignored", "/resolved/local/path"),
"/resolved/local/path",
)

def test_resolve_model_path_model_file_file_uri(self):
# A ModelFile that is itself a file:// URI is also normalized.
self.assertEqual(
resolve_model_path("ignored", "file:///a/b"),
"/a/b",
)

def test_parse_tool_calls_with_shim(self):
tm = types.SimpleNamespace(
tool_call_start="<tool_call>",
Expand Down
Loading