From 962427e5c8b305c5ce6caf6fa6aadfaf297c7d8c Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 16:50:30 +0100 Subject: [PATCH 1/4] Implement upsert-records tool for RedisVL MCP --- redisvl/mcp/server.py | 6 +- redisvl/mcp/tools/__init__.py | 3 +- redisvl/mcp/tools/upsert.py | 272 +++++++++++++++ .../integration/test_mcp/test_upsert_tool.py | 326 ++++++++++++++++++ tests/unit/test_mcp/test_upsert_tool_unit.py | 326 ++++++++++++++++++ 5 files changed, 930 insertions(+), 3 deletions(-) create mode 100644 redisvl/mcp/tools/upsert.py create mode 100644 tests/integration/test_mcp/test_upsert_tool.py create mode 100644 tests/unit/test_mcp/test_upsert_tool_unit.py diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index 4e07512c..a2fb2bd6 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -8,6 +8,8 @@ from redisvl.index import AsyncSearchIndex from redisvl.mcp.config import MCPConfig, load_mcp_config from redisvl.mcp.settings import MCPSettings +from redisvl.mcp.tools.search import register_search_tool +from redisvl.mcp.tools.upsert import register_upsert_tool from redisvl.redis.connection import RedisConnectionFactory, is_version_gte from redisvl.schema import IndexSchema @@ -181,9 +183,9 @@ def _register_tools(self) -> None: if self._tools_registered or not hasattr(self, "tool"): return - from redisvl.mcp.tools.search import register_search_tool - register_search_tool(self) + if not self.mcp_settings.read_only: + register_upsert_tool(self) self._tools_registered = True @staticmethod diff --git a/redisvl/mcp/tools/__init__.py b/redisvl/mcp/tools/__init__.py index e47aef7c..40e0a59e 100644 --- a/redisvl/mcp/tools/__init__.py +++ b/redisvl/mcp/tools/__init__.py @@ -1,3 +1,4 @@ from redisvl.mcp.tools.search import search_records +from redisvl.mcp.tools.upsert import upsert_records -__all__ = ["search_records"] +__all__ = ["search_records", "upsert_records"] diff --git a/redisvl/mcp/tools/upsert.py b/redisvl/mcp/tools/upsert.py new file mode 100644 index 00000000..3ed379ea --- /dev/null +++ b/redisvl/mcp/tools/upsert.py @@ -0,0 +1,272 @@ +import asyncio +import inspect +from typing import Any, Dict, List, Optional + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError, map_exception +from redisvl.redis.utils import array_to_buffer +from redisvl.schema.schema import StorageType +from redisvl.schema.validation import validate_object + +DEFAULT_UPSERT_DESCRIPTION = "Upsert records in the configured Redis index." + + +def _validate_request( + *, + server: Any, + records: List[Dict[str, Any]], + id_field: Optional[str], + skip_embedding_if_present: Optional[bool], +) -> bool: + """Validate the public upsert request contract and resolve defaults.""" + runtime = server.config.runtime + + if not isinstance(records, list) or not records: + raise RedisVLMCPError( + "records must be a non-empty list", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if len(records) > runtime.max_upsert_records: + raise RedisVLMCPError( + "records length must be less than or equal to " + f"{runtime.max_upsert_records}", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if id_field is not None and (not isinstance(id_field, str) or not id_field): + raise RedisVLMCPError( + "id_field must be a non-empty string when provided", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + effective_skip_embedding = runtime.skip_embedding_if_present + if skip_embedding_if_present is not None: + if not isinstance(skip_embedding_if_present, bool): + raise RedisVLMCPError( + "skip_embedding_if_present must be a boolean when provided", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + effective_skip_embedding = skip_embedding_if_present + + for record in records: + if not isinstance(record, dict): + raise RedisVLMCPError( + "records must contain only objects", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if id_field is not None and id_field not in record: + raise RedisVLMCPError( + "id_field '{id_field}' must exist in every record".format( + id_field=id_field + ), + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + return effective_skip_embedding + + +def _record_needs_embedding( + record: Dict[str, Any], + *, + vector_field_name: str, + skip_embedding_if_present: bool, +) -> bool: + """Determine whether a record requires server-side embedding.""" + return ( + not skip_embedding_if_present + or vector_field_name not in record + or record[vector_field_name] is None + ) + + +def _validate_embed_sources( + records: List[Dict[str, Any]], + *, + embed_text_field: str, + vector_field_name: str, + skip_embedding_if_present: bool, +) -> List[str]: + """Collect embed sources for records that require embedding.""" + contents = [] + for record in records: + if not _record_needs_embedding( + record, + vector_field_name=vector_field_name, + skip_embedding_if_present=skip_embedding_if_present, + ): + continue + + content = record.get(embed_text_field) + if not isinstance(content, str) or not content.strip(): + raise RedisVLMCPError( + "records requiring embedding must include a non-empty " + "'{field}' field".format(field=embed_text_field), + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + contents.append(content) + + return contents + + +async def _embed_one(vectorizer: Any, content: str) -> List[float]: + """Embed one record, falling back from async to sync implementations.""" + aembed = getattr(vectorizer, "aembed", None) + if callable(aembed): + try: + return await aembed(content) + except NotImplementedError: + pass + + embed = getattr(vectorizer, "embed", None) + if embed is None: + raise AttributeError("Configured vectorizer does not support embed()") + if inspect.iscoroutinefunction(embed): + return await embed(content) + return await asyncio.to_thread(embed, content) + + +async def _embed_many(vectorizer: Any, contents: List[str]) -> List[List[float]]: + """Embed multiple records with batch-first fallbacks.""" + if not contents: + return [] + + aembed_many = getattr(vectorizer, "aembed_many", None) + if callable(aembed_many): + try: + return await aembed_many(contents) + except NotImplementedError: + pass + + embed_many = getattr(vectorizer, "embed_many", None) + if callable(embed_many): + if inspect.iscoroutinefunction(embed_many): + return await embed_many(contents) + return await asyncio.to_thread(embed_many, contents) + + embeddings = [] + for content in contents: + embeddings.append(await _embed_one(vectorizer, content)) + return embeddings + + +def _vector_dtype(server: Any, index: Any) -> str: + """Resolve the configured vector field datatype as a lowercase string.""" + field = server.config.get_vector_field(index.schema) + datatype = getattr(field.attrs.datatype, "value", field.attrs.datatype) + return str(datatype).lower() + + +def _prepare_record_for_storage( + record: Dict[str, Any], + *, + server: Any, + index: Any, +) -> Dict[str, Any]: + """Serialize vector fields for storage and validate the prepared record.""" + prepared = dict(record) + vector_field_name = server.config.runtime.vector_field_name + vector_value = prepared.get(vector_field_name) + + if index.schema.index.storage_type == StorageType.HASH: + if isinstance(vector_value, list): + prepared[vector_field_name] = array_to_buffer( + vector_value, + _vector_dtype(server, index), + ) + validate_object(index.schema, prepared) + return prepared + + +async def upsert_records( + server: Any, + *, + records: List[Dict[str, Any]], + id_field: Optional[str] = None, + skip_embedding_if_present: Optional[bool] = None, +) -> Dict[str, Any]: + """Execute `upsert-records` against the configured Redis index.""" + try: + index = await server.get_index() + effective_skip_embedding = _validate_request( + server=server, + records=records, + id_field=id_field, + skip_embedding_if_present=skip_embedding_if_present, + ) + # Copy caller-provided records before enriching them with embeddings or + # storage-specific serialization so the MCP tool does not mutate inputs. + prepared_records = [record.copy() for record in records] + runtime = server.config.runtime + embed_contents = _validate_embed_sources( + prepared_records, + embed_text_field=runtime.default_embed_text_field, + vector_field_name=runtime.vector_field_name, + skip_embedding_if_present=effective_skip_embedding, + ) + + if embed_contents: + vectorizer = await server.get_vectorizer() + embeddings = await _embed_many(vectorizer, embed_contents) + # Tracks position in the compact embeddings list, which only contains + # vectors for records that still need server-side embedding. + embedding_index = 0 + for record in prepared_records: + if _record_needs_embedding( + record, + vector_field_name=runtime.vector_field_name, + skip_embedding_if_present=effective_skip_embedding, + ): + record[runtime.vector_field_name] = embeddings[embedding_index] + embedding_index += 1 + + loadable_records = [ + _prepare_record_for_storage(record, server=server, index=index) + for record in prepared_records + ] + + try: + keys = await server.run_guarded( + "upsert-records", + index.load(loadable_records, id_field=id_field), + ) + except Exception as exc: + mapped = map_exception(exc) + mapped.metadata["partial_write_possible"] = True + raise mapped + + return { + "status": "success", + "keys_upserted": len(keys), + "keys": keys, + } + except RedisVLMCPError: + raise + except Exception as exc: + raise map_exception(exc) + + +def register_upsert_tool(server: Any) -> None: + """Register the MCP upsert tool on a server-like object.""" + description = ( + server.mcp_settings.tool_upsert_description or DEFAULT_UPSERT_DESCRIPTION + ) + + async def upsert_records_tool( + records: List[Dict[str, Any]], + id_field: Optional[str] = None, + skip_embedding_if_present: Optional[bool] = None, + ): + """FastMCP wrapper for the `upsert-records` tool.""" + return await upsert_records( + server, + records=records, + id_field=id_field, + skip_embedding_if_present=skip_embedding_if_present, + ) + + server.tool(name="upsert-records", description=description)(upsert_records_tool) diff --git a/tests/integration/test_mcp/test_upsert_tool.py b/tests/integration/test_mcp/test_upsert_tool.py new file mode 100644 index 00000000..e238ac3a --- /dev/null +++ b/tests/integration/test_mcp/test_upsert_tool.py @@ -0,0 +1,326 @@ +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest +import yaml + +from redisvl.index import AsyncSearchIndex +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.server import RedisVLMCPServer +from redisvl.mcp.settings import MCPSettings +from redisvl.mcp.tools.upsert import upsert_records +from redisvl.schema import IndexSchema + + +class RecordingVectorizer: + def __init__(self, model: str, dims: int = 3, **kwargs: Any) -> None: + self.model = model + self.dims = dims + self.kwargs = kwargs + self.aembed_many_inputs: List[List[str]] = [] + self.embed_many_inputs: List[List[str]] = [] + self.aembed_inputs: List[str] = [] + self.embed_inputs: List[str] = [] + + @staticmethod + def _vector_for(text: str) -> List[float]: + base = float(len(text)) + return [base, base + 0.1, base + 0.2] + + async def aembed(self, content: str = "", **kwargs: Any) -> List[float]: + del kwargs + self.aembed_inputs.append(content) + return self._vector_for(content) + + def embed(self, content: str = "", **kwargs: Any) -> List[float]: + del kwargs + self.embed_inputs.append(content) + return self._vector_for(content) + + async def aembed_many( + self, + contents: Optional[List[str]] = None, + texts: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[List[float]]: + del kwargs + items = contents or texts or [] + self.aembed_many_inputs.append(list(items)) + return [self._vector_for(text) for text in items] + + def embed_many( + self, + contents: Optional[List[str]] = None, + texts: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[List[float]]: + del kwargs + items = contents or texts or [] + self.embed_many_inputs.append(list(items)) + return [self._vector_for(text) for text in items] + + +@pytest.fixture +async def upsertable_index(async_client, worker_id): + schema = IndexSchema.from_dict( + { + "index": { + "name": f"mcp-upsert-{worker_id}", + "prefix": f"mcp-upsert:{worker_id}", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + index = AsyncSearchIndex(schema=schema, redis_client=async_client) + await index.create(overwrite=True, drop=True) + + yield index + + await index.delete(drop=True) + + +@pytest.fixture +def mcp_config_path(tmp_path: Path, redis_url: str): + def factory( + *, + redis_name: str, + read_only: bool = False, + runtime_overrides: Optional[Dict[str, Any]] = None, + ) -> str: + runtime = { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + "max_upsert_records": 64, + "skip_embedding_if_present": True, + } + if runtime_overrides: + runtime.update(runtime_overrides) + + config = { + "server": {"redis_url": redis_url}, + "indexes": { + "knowledge": { + "redis_name": redis_name, + "vectorizer": { + "class": "RecordingVectorizer", + "model": "fake-model", + "dims": 3, + }, + "search": {"type": "vector"}, + "runtime": runtime, + } + }, + } + config_path = tmp_path / ( + f"{redis_name}-{'readonly' if read_only else 'readwrite'}.yaml" + ) + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + return str(config_path) + + return factory + + +@pytest.fixture +async def started_server(monkeypatch, upsertable_index, mcp_config_path): + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: RecordingVectorizer, + ) + + servers: List[RedisVLMCPServer] = [] + + async def factory( + *, + read_only: bool = False, + runtime_overrides: Optional[Dict[str, Any]] = None, + ) -> RedisVLMCPServer: + server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path( + redis_name=upsertable_index.schema.index.name, + read_only=read_only, + runtime_overrides=runtime_overrides, + ) + ) + ) + await server.startup() + servers.append(server) + return server + + yield factory + + for server in servers: + await server.shutdown() + + +def _record_id_from_key(key: str) -> str: + return key.rsplit(":", 1)[-1] + + +@pytest.mark.asyncio +async def test_upsert_records_inserts_rows_into_hash_index( + started_server, upsertable_index +): + server = await started_server() + + records = [ + {"content": "first upserted document", "category": "science", "rating": 5}, + {"content": "second upserted document", "category": "health", "rating": 4}, + ] + + response = await upsert_records(server, records=records) + + assert response["status"] == "success" + assert response["keys_upserted"] == 2 + assert len(response["keys"]) == 2 + + vectorizer = await server.get_vectorizer() + assert vectorizer.aembed_many_inputs == [ + ["first upserted document", "second upserted document"] + ] + + stored = await upsertable_index.fetch(_record_id_from_key(response["keys"][0])) + assert stored is not None + assert stored["content"] == "first upserted document" + assert stored["category"] == "science" + + +@pytest.mark.asyncio +async def test_upsert_records_updates_existing_row_with_id_field( + started_server, upsertable_index +): + server = await started_server() + + first_response = await upsert_records( + server, + records=[ + { + "doc_id": "doc-1", + "content": "original content", + "category": "science", + "rating": 3, + } + ], + id_field="doc_id", + ) + + second_response = await upsert_records( + server, + records=[ + { + "doc_id": "doc-1", + "content": "updated content", + "category": "engineering", + "rating": 5, + } + ], + id_field="doc_id", + ) + + assert first_response["keys"] == second_response["keys"] + assert second_response["keys_upserted"] == 1 + + stored = await upsertable_index.fetch( + _record_id_from_key(second_response["keys"][0]) + ) + assert stored is not None + assert stored["content"] == "updated content" + assert stored["category"] == "engineering" + assert int(stored["rating"]) == 5 + + +@pytest.mark.asyncio +async def test_upsert_records_rejects_invalid_records_before_write( + monkeypatch, started_server +): + server = await started_server() + + called = False + + async def fail_load(*args: Any, **kwargs: Any) -> Any: + del args, kwargs + nonlocal called + called = True + raise AssertionError("load should not be called for invalid records") + + monkeypatch.setattr( + "redisvl.index.index.AsyncSearchIndex.load", + fail_load, + ) + + with pytest.raises(RedisVLMCPError) as exc_info: + await upsert_records( + server, + records=[{"category": "science"}], + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + assert called is False + + +@pytest.mark.asyncio +async def test_read_only_mode_excludes_upsert_tool( + monkeypatch, upsertable_index, mcp_config_path +): + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: RecordingVectorizer, + ) + + called: List[bool] = [] + + def fake_register_upsert_tool(server: Any) -> None: + called.append(server.mcp_settings.read_only) + + monkeypatch.setattr( + "redisvl.mcp.server.register_upsert_tool", + fake_register_upsert_tool, + raising=False, + ) + + writeable_server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path( + redis_name=upsertable_index.schema.index.name, + ) + ) + ) + await writeable_server.startup() + try: + assert called == [False] + finally: + await writeable_server.shutdown() + + read_only_server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path( + redis_name=upsertable_index.schema.index.name, + read_only=True, + ), + read_only=True, + ) + ) + + await read_only_server.startup() + try: + assert called == [False] + finally: + await read_only_server.shutdown() diff --git a/tests/unit/test_mcp/test_upsert_tool_unit.py b/tests/unit/test_mcp/test_upsert_tool_unit.py new file mode 100644 index 00000000..999c2c1c --- /dev/null +++ b/tests/unit/test_mcp/test_upsert_tool_unit.py @@ -0,0 +1,326 @@ +from types import SimpleNamespace +from typing import Any, List, Optional + +import pytest +from redis.exceptions import RedisError + +from redisvl.mcp.config import MCPConfig +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.tools.upsert import register_upsert_tool, upsert_records +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema + + +def _schema(storage_type: str = "hash") -> IndexSchema: + return IndexSchema.from_dict( + { + "index": { + "name": "docs-index", + "prefix": "doc", + "storage_type": storage_type, + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +def _config( + storage_type: str = "hash", + *, + max_upsert_records: int = 5, + skip_embedding_if_present: bool = True, +) -> MCPConfig: + return MCPConfig.model_validate( + { + "server": {"redis_url": "redis://localhost:6379"}, + "indexes": { + "knowledge": { + "redis_name": "docs-index", + "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "search": {"type": "vector"}, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + "max_upsert_records": max_upsert_records, + "skip_embedding_if_present": skip_embedding_if_present, + }, + } + }, + } + ) + + +class FakeVectorizer: + def __init__(self): + self.aembed_many_calls = [] + self.embed_many_calls = [] + self.aembed_calls = [] + self.embed_calls = [] + + async def aembed_many(self, contents: List[str], **kwargs): + self.aembed_many_calls.append((contents, kwargs)) + return [ + [float(index), float(index), float(index)] + for index, _ in enumerate(contents, start=1) + ] + + def embed_many(self, contents: List[str], **kwargs): + self.embed_many_calls.append((contents, kwargs)) + return [[9.0, 9.0, 9.0] for _ in contents] + + async def aembed(self, content: str, **kwargs): + self.aembed_calls.append((content, kwargs)) + return [8.0, 8.0, 8.0] + + def embed(self, content: str, **kwargs): + self.embed_calls.append((content, kwargs)) + return [7.0, 7.0, 7.0] + + +class FallbackBatchVectorizer(FakeVectorizer): + async def aembed_many(self, contents: List[str], **kwargs): + raise NotImplementedError + + +class FakeIndex: + def __init__(self, storage_type: str = "hash"): + self.schema = _schema(storage_type) + self.load_calls = [] + self.keys_to_return = ["doc:1"] + self.load_exception = None + + async def load(self, data, id_field=None, **kwargs): + materialized = list(data) + self.load_calls.append( + { + "data": materialized, + "id_field": id_field, + "kwargs": kwargs, + } + ) + if self.load_exception is not None: + raise self.load_exception + return self.keys_to_return + + +class FakeServer: + def __init__( + self, + *, + storage_type: str = "hash", + max_upsert_records: int = 5, + skip_embedding_if_present: bool = True, + vectorizer: Optional[FakeVectorizer] = None, + ): + self.config = _config( + storage_type, + max_upsert_records=max_upsert_records, + skip_embedding_if_present=skip_embedding_if_present, + ) + self.mcp_settings = SimpleNamespace(tool_upsert_description=None) + self.index = FakeIndex(storage_type) + self.vectorizer = vectorizer or FakeVectorizer() + self.registered_tools = [] + + async def get_index(self): + return self.index + + async def get_vectorizer(self): + return self.vectorizer + + async def run_guarded(self, operation_name: str, awaitable: Any): + return await awaitable + + def tool(self, name=None, description=None, **kwargs): + def decorator(fn): + self.registered_tools.append( + { + "name": name, + "description": description, + "fn": fn, + } + ) + return fn + + return decorator + + +@pytest.mark.asyncio +async def test_upsert_records_generates_missing_vectors_and_serializes_hash_vectors(): + server = FakeServer(storage_type="hash") + server.index.keys_to_return = ["doc:alpha", "doc:beta"] + + response = await upsert_records( + server, + records=[ + {"id": "alpha", "content": "alpha doc", "category": "science"}, + {"id": "beta", "content": "beta doc", "category": "health"}, + ], + id_field="id", + ) + + assert response == { + "status": "success", + "keys_upserted": 2, + "keys": ["doc:alpha", "doc:beta"], + } + assert server.vectorizer.aembed_many_calls == [(["alpha doc", "beta doc"], {})] + assert len(server.index.load_calls) == 1 + loaded_records = server.index.load_calls[0]["data"] + assert loaded_records[0]["embedding"] == array_to_buffer([1.0, 1.0, 1.0], "float32") + assert loaded_records[1]["embedding"] == array_to_buffer([2.0, 2.0, 2.0], "float32") + assert server.index.load_calls[0]["id_field"] == "id" + + +@pytest.mark.asyncio +async def test_upsert_records_preserves_supplied_vectors_when_skip_embedding_if_present(): + server = FakeServer(storage_type="hash", skip_embedding_if_present=True) + + existing_vector = [0.1, 0.2, 0.3] + await upsert_records( + server, + records=[ + {"id": "alpha", "content": "alpha doc", "embedding": existing_vector}, + {"id": "beta", "content": "beta doc"}, + ], + id_field="id", + ) + + loaded_records = server.index.load_calls[0]["data"] + assert loaded_records[0]["embedding"] == array_to_buffer(existing_vector, "float32") + assert loaded_records[1]["embedding"] == array_to_buffer([1.0, 1.0, 1.0], "float32") + assert server.vectorizer.aembed_many_calls == [(["beta doc"], {})] + + +@pytest.mark.asyncio +async def test_upsert_records_overwrites_supplied_vectors_when_skip_embedding_if_present_false(): + server = FakeServer(storage_type="hash", skip_embedding_if_present=True) + + await upsert_records( + server, + records=[{"id": "alpha", "content": "alpha doc", "embedding": [0.1, 0.2, 0.3]}], + id_field="id", + skip_embedding_if_present=False, + ) + + loaded_record = server.index.load_calls[0]["data"][0] + assert loaded_record["embedding"] == array_to_buffer([1.0, 1.0, 1.0], "float32") + assert server.vectorizer.aembed_many_calls == [(["alpha doc"], {})] + + +@pytest.mark.asyncio +async def test_upsert_records_uses_batch_fallback_when_aembed_many_is_not_implemented(): + server = FakeServer(vectorizer=FallbackBatchVectorizer()) + + await upsert_records( + server, + records=[{"content": "alpha doc"}], + ) + + loaded_record = server.index.load_calls[0]["data"][0] + assert loaded_record["embedding"] == array_to_buffer([9.0, 9.0, 9.0], "float32") + assert server.vectorizer.embed_many_calls == [(["alpha doc"], {})] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("records", "id_field", "message"), + [ + ([], None, "records must be a non-empty list"), + ("bad", None, "records must be a non-empty list"), + ([1], None, "records must contain only objects"), + ([{"content": "alpha"}], "id", "id_field 'id' must exist"), + ], +) +async def test_upsert_records_rejects_invalid_request_shapes( + records, id_field, message +): + server = FakeServer() + + with pytest.raises(RedisVLMCPError, match=message) as exc_info: + await upsert_records(server, records=records, id_field=id_field) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_upsert_records_rejects_batches_above_runtime_limit(): + server = FakeServer(max_upsert_records=1) + + with pytest.raises( + RedisVLMCPError, match="must be less than or equal to 1" + ) as exc_info: + await upsert_records( + server, + records=[{"content": "alpha"}, {"content": "beta"}], + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_upsert_records_requires_configured_embed_source_when_embedding_needed(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError, match="content") as exc_info: + await upsert_records( + server, + records=[{"category": "science"}], + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_upsert_records_surfaces_partial_write_possible_on_backend_failures(): + server = FakeServer() + server.index.load_exception = RedisError("boom") + + with pytest.raises(RedisVLMCPError) as exc_info: + await upsert_records(server, records=[{"content": "alpha doc"}]) + + assert exc_info.value.code == MCPErrorCode.BACKEND_UNAVAILABLE + assert exc_info.value.metadata["partial_write_possible"] is True + + +def test_register_upsert_tool_uses_default_and_override_descriptions(): + default_server = FakeServer() + register_upsert_tool(default_server) + + assert default_server.registered_tools[0]["name"] == "upsert-records" + assert "Upsert records" in default_server.registered_tools[0]["description"] + + custom_server = FakeServer() + custom_server.mcp_settings.tool_upsert_description = "Custom upsert description" + register_upsert_tool(custom_server) + + assert ( + custom_server.registered_tools[0]["description"] == "Custom upsert description" + ) + + +@pytest.mark.asyncio +async def test_registered_upsert_tool_rejects_deprecated_embed_text_field_argument(): + server = FakeServer() + register_upsert_tool(server) + + tool_fn = server.registered_tools[0]["fn"] + + with pytest.raises(TypeError): + await tool_fn(records=[{"content": "alpha doc"}], embed_text_field="content") From c748db2c6c65c8294aea707304fef4877b6f8e1e Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 17:28:57 +0100 Subject: [PATCH 2/4] Python 3.9 compat --- redisvl/mcp/filters.py | 20 +++++++++---------- .../integration/test_mcp/test_upsert_tool.py | 14 +++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/redisvl/mcp/filters.py b/redisvl/mcp/filters.py index cc870439..af5eef3f 100644 --- a/redisvl/mcp/filters.py +++ b/redisvl/mcp/filters.py @@ -1,6 +1,4 @@ -from __future__ import annotations - -from typing import Any, Iterable, Optional +from typing import Any, Dict, Iterable, List, Optional, Union from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError from redisvl.query.filter import FilterExpression, Num, Tag, Text @@ -8,8 +6,8 @@ def parse_filter( - value: Optional[str | dict[str, Any]], schema: IndexSchema -) -> Optional[str | FilterExpression]: + value: Optional[Union[str, Dict[str, Any]]], schema: IndexSchema +) -> Optional[Union[str, FilterExpression]]: """Parse an MCP filter value into a RedisVL filter representation.""" if value is None: return None @@ -24,7 +22,7 @@ def parse_filter( return _parse_expression(value, schema) -def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpression: +def _parse_expression(value: Dict[str, Any], schema: IndexSchema) -> FilterExpression: logical_keys = [key for key in ("and", "or", "not") if key in value] if logical_keys: if len(logical_keys) != 1 or len(value) != 1: @@ -53,7 +51,7 @@ def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpre retryable=False, ) - expressions: list[FilterExpression] = [] + expressions: List[FilterExpression] = [] for child in children: if not isinstance(child, dict): raise RedisVLMCPError( @@ -205,7 +203,7 @@ def _require_string(value: Any, field_name: str, op: str) -> str: return value -def _require_string_list(value: Any, field_name: str, op: str) -> list[str]: +def _require_string_list(value: Any, field_name: str, op: str) -> List[str]: if not isinstance(value, list) or not value: raise RedisVLMCPError( f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", @@ -216,7 +214,7 @@ def _require_string_list(value: Any, field_name: str, op: str) -> list[str]: return strings -def _require_number(value: Any, field_name: str, op: str) -> int | float: +def _require_number(value: Any, field_name: str, op: str) -> Union[int, float]: if isinstance(value, bool) or not isinstance(value, (int, float)): raise RedisVLMCPError( f"filter value for field '{field_name}' and operator '{op}' must be numeric", @@ -226,7 +224,9 @@ def _require_number(value: Any, field_name: str, op: str) -> int | float: return value -def _require_number_list(value: Any, field_name: str, op: str) -> list[int | float]: +def _require_number_list( + value: Any, field_name: str, op: str +) -> List[Union[int, float]]: if not isinstance(value, list) or not value: raise RedisVLMCPError( f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", diff --git a/tests/integration/test_mcp/test_upsert_tool.py b/tests/integration/test_mcp/test_upsert_tool.py index e238ac3a..819a0584 100644 --- a/tests/integration/test_mcp/test_upsert_tool.py +++ b/tests/integration/test_mcp/test_upsert_tool.py @@ -284,6 +284,20 @@ async def test_read_only_mode_excludes_upsert_tool( "redisvl.mcp.server.resolve_vectorizer_class", lambda class_name: RecordingVectorizer, ) + monkeypatch.setattr( + "redisvl.mcp.server.register_search_tool", + lambda server: None, + ) + + def fake_tool(*args: Any, **kwargs: Any): + del args, kwargs + + def decorator(func: Any) -> Any: + return func + + return decorator + + monkeypatch.setattr(RedisVLMCPServer, "tool", fake_tool, raising=False) called: List[bool] = [] From f475a1fb406f3f7e8a655afc5ac734a1ed6b157f Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 18:34:44 +0100 Subject: [PATCH 3/4] fix(mcp): validate hash vectors before serialization --- redisvl/mcp/tools/upsert.py | 28 ++++++++++++++++++-- tests/unit/test_mcp/test_upsert_tool_unit.py | 18 +++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/redisvl/mcp/tools/upsert.py b/redisvl/mcp/tools/upsert.py index 3ed379ea..1154ea5e 100644 --- a/redisvl/mcp/tools/upsert.py +++ b/redisvl/mcp/tools/upsert.py @@ -161,15 +161,40 @@ def _vector_dtype(server: Any, index: Any) -> str: return str(datatype).lower() +def _validation_schema_for_record( + index: Any, + *, + vector_field_name: str, + record: Dict[str, Any], +) -> Any: + """Use a JSON-shaped schema when validating list vectors for HASH storage.""" + if index.schema.index.storage_type == StorageType.HASH and isinstance( + record.get(vector_field_name), list + ): + schema = index.schema.model_copy(deep=True) + schema.index.storage_type = StorageType.JSON + return schema + return index.schema + + def _prepare_record_for_storage( record: Dict[str, Any], *, server: Any, index: Any, ) -> Dict[str, Any]: - """Serialize vector fields for storage and validate the prepared record.""" + """Validate records before serializing HASH vectors for storage.""" prepared = dict(record) vector_field_name = server.config.runtime.vector_field_name + validate_object( + _validation_schema_for_record( + index, + vector_field_name=vector_field_name, + record=prepared, + ), + prepared, + ) + vector_value = prepared.get(vector_field_name) if index.schema.index.storage_type == StorageType.HASH: @@ -178,7 +203,6 @@ def _prepare_record_for_storage( vector_value, _vector_dtype(server, index), ) - validate_object(index.schema, prepared) return prepared diff --git a/tests/unit/test_mcp/test_upsert_tool_unit.py b/tests/unit/test_mcp/test_upsert_tool_unit.py index 999c2c1c..45e569a5 100644 --- a/tests/unit/test_mcp/test_upsert_tool_unit.py +++ b/tests/unit/test_mcp/test_upsert_tool_unit.py @@ -208,6 +208,24 @@ async def test_upsert_records_preserves_supplied_vectors_when_skip_embedding_if_ assert server.vectorizer.aembed_many_calls == [(["beta doc"], {})] +@pytest.mark.asyncio +async def test_upsert_records_rejects_invalid_hash_vector_dimensions_before_serializing(): + server = FakeServer(storage_type="hash", skip_embedding_if_present=True) + + with pytest.raises( + RedisVLMCPError, match="must have 3 dimensions, got 2" + ) as exc_info: + await upsert_records( + server, + records=[{"id": "alpha", "content": "alpha doc", "embedding": [0.1, 0.2]}], + id_field="id", + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + assert server.index.load_calls == [] + assert server.vectorizer.aembed_many_calls == [] + + @pytest.mark.asyncio async def test_upsert_records_overwrites_supplied_vectors_when_skip_embedding_if_present_false(): server = FakeServer(storage_type="hash", skip_embedding_if_present=True) From 11f5d07ae5ea2cbbb33640e8ccabcf234a02ed4d Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Thu, 26 Mar 2026 10:54:52 +0100 Subject: [PATCH 4/4] fix(mcp): validate records before embedding --- redisvl/mcp/tools/upsert.py | 29 ++++++++++++++------ tests/unit/test_mcp/test_upsert_tool_unit.py | 15 ++++++++++ 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/redisvl/mcp/tools/upsert.py b/redisvl/mcp/tools/upsert.py index 1154ea5e..61a18883 100644 --- a/redisvl/mcp/tools/upsert.py +++ b/redisvl/mcp/tools/upsert.py @@ -177,6 +177,20 @@ def _validation_schema_for_record( return index.schema +def _validate_record( + record: Dict[str, Any], *, index: Any, vector_field_name: str +) -> None: + """Validate one record against the schema, allowing HASH list vectors.""" + validate_object( + _validation_schema_for_record( + index, + vector_field_name=vector_field_name, + record=record, + ), + record, + ) + + def _prepare_record_for_storage( record: Dict[str, Any], *, @@ -186,14 +200,7 @@ def _prepare_record_for_storage( """Validate records before serializing HASH vectors for storage.""" prepared = dict(record) vector_field_name = server.config.runtime.vector_field_name - validate_object( - _validation_schema_for_record( - index, - vector_field_name=vector_field_name, - record=prepared, - ), - prepared, - ) + _validate_record(prepared, index=index, vector_field_name=vector_field_name) vector_value = prepared.get(vector_field_name) @@ -226,6 +233,12 @@ async def upsert_records( # storage-specific serialization so the MCP tool does not mutate inputs. prepared_records = [record.copy() for record in records] runtime = server.config.runtime + for record in prepared_records: + _validate_record( + record, + index=index, + vector_field_name=runtime.vector_field_name, + ) embed_contents = _validate_embed_sources( prepared_records, embed_text_field=runtime.default_embed_text_field, diff --git a/tests/unit/test_mcp/test_upsert_tool_unit.py b/tests/unit/test_mcp/test_upsert_tool_unit.py index 45e569a5..8bb59ce0 100644 --- a/tests/unit/test_mcp/test_upsert_tool_unit.py +++ b/tests/unit/test_mcp/test_upsert_tool_unit.py @@ -305,6 +305,21 @@ async def test_upsert_records_requires_configured_embed_source_when_embedding_ne assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST +@pytest.mark.asyncio +async def test_upsert_records_validates_non_vector_fields_before_embedding(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError, match="category") as exc_info: + await upsert_records( + server, + records=[{"content": "alpha doc", "category": ["science"]}], + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + assert server.vectorizer.aembed_many_calls == [] + assert server.index.load_calls == [] + + @pytest.mark.asyncio async def test_upsert_records_surfaces_partial_write_possible_on_backend_failures(): server = FakeServer()