From cb3c543f79ea81e02740679994ff209e9198cb1a Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 10:43:16 +0100 Subject: [PATCH 1/9] Add filter type-hints; allow float values in `Num` --- redisvl/query/filter.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/redisvl/query/filter.py b/redisvl/query/filter.py index 0295568f..f30870f2 100644 --- a/redisvl/query/filter.py +++ b/redisvl/query/filter.py @@ -164,7 +164,7 @@ def __eq__(self, other: Union[List[str], str]) -> "FilterExpression": return FilterExpression(str(self)) @check_operator_misuse - def __ne__(self, other) -> "FilterExpression": + def __ne__(self, other: Union[List[str], str]) -> "FilterExpression": """Create a Tag inequality filter expression. Args: @@ -298,7 +298,7 @@ def __eq__(self, other) -> "FilterExpression": return FilterExpression(str(self)) @check_operator_misuse - def __ne__(self, other) -> "FilterExpression": + def __ne__(self, other: GeoRadius) -> "FilterExpression": """Create a geographic filter outside of a specified GeoRadius. Args: @@ -349,11 +349,11 @@ class Num(FilterField): SUPPORTED_VAL_TYPES = (int, float, tuple, type(None)) - def __eq__(self, other: int) -> "FilterExpression": + def __eq__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric equality filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -364,11 +364,11 @@ def __eq__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.EQ) return FilterExpression(str(self)) - def __ne__(self, other: int) -> "FilterExpression": + def __ne__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric inequality filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -380,11 +380,11 @@ def __ne__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.NE) return FilterExpression(str(self)) - def __gt__(self, other: int) -> "FilterExpression": + def __gt__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric greater than filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -396,11 +396,11 @@ def __gt__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.GT) return FilterExpression(str(self)) - def __lt__(self, other: int) -> "FilterExpression": + def __lt__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric less than filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -412,11 +412,11 @@ def __lt__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.LT) return FilterExpression(str(self)) - def __ge__(self, other: int) -> "FilterExpression": + def __ge__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric greater than or equal to filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -428,11 +428,11 @@ def __ge__(self, other: int) -> "FilterExpression": self._set_value(other, self.SUPPORTED_VAL_TYPES, FilterOperator.GE) return FilterExpression(str(self)) - def __le__(self, other: int) -> "FilterExpression": + def __le__(self, other: Union[int, float]) -> "FilterExpression": """Create a Numeric less than or equal to filter expression. Args: - other (int): The value to filter on. + other (Union[int, float]): The value to filter on. .. code-block:: python @@ -759,7 +759,9 @@ def _convert_to_timestamp(self, value, end_date=False): raise TypeError(f"Unsupported type for timestamp conversion: {type(value)}") - def __eq__(self, other) -> FilterExpression: + def __eq__( + self, other: Union[datetime.datetime, datetime.date, str, int, float] + ) -> FilterExpression: """ Filter for timestamps equal to the specified value. For date objects (without time), this matches the entire day. @@ -774,6 +776,7 @@ def __eq__(self, other) -> FilterExpression: # For date objects, match the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + assert isinstance(other, datetime.date) # validate for mypy start = datetime.datetime.combine(other, datetime.time.min).astimezone( datetime.timezone.utc ) @@ -786,7 +789,9 @@ def __eq__(self, other) -> FilterExpression: self._set_value(timestamp, self.SUPPORTED_TYPES, FilterOperator.EQ) return FilterExpression(str(self)) - def __ne__(self, other) -> FilterExpression: + def __ne__( + self, other: Union[datetime.datetime, datetime.date, str, int, float] + ) -> FilterExpression: """ Filter for timestamps not equal to the specified value. For date objects (without time), this excludes the entire day. @@ -801,6 +806,7 @@ def __ne__(self, other) -> FilterExpression: # For date objects, exclude the entire day if isinstance(other, str): other = datetime.datetime.strptime(other, "%Y-%m-%d").date() + assert isinstance(other, datetime.date) # validate for mypy start = datetime.datetime.combine(other, datetime.time.min) end = datetime.datetime.combine(other, datetime.time.max) return self.between(start, end) From 7529d612e3fbcd99120c3f84310cdbc8ace4f207 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 11:07:25 +0100 Subject: [PATCH 2/9] Implement MCP search-records tool --- redisvl/mcp/errors.py | 1 + redisvl/mcp/filters.py | 236 ++++++++++++ redisvl/mcp/server.py | 30 +- redisvl/mcp/tools/__init__.py | 3 + redisvl/mcp/tools/search.py | 333 +++++++++++++++++ .../integration/test_mcp/test_search_tool.py | 250 +++++++++++++ tests/unit/test_mcp/test_errors.py | 12 + tests/unit/test_mcp/test_filters.py | 136 +++++++ tests/unit/test_mcp/test_search_tool_unit.py | 345 ++++++++++++++++++ 9 files changed, 1345 insertions(+), 1 deletion(-) create mode 100644 redisvl/mcp/filters.py create mode 100644 redisvl/mcp/tools/__init__.py create mode 100644 redisvl/mcp/tools/search.py create mode 100644 tests/integration/test_mcp/test_search_tool.py create mode 100644 tests/unit/test_mcp/test_filters.py create mode 100644 tests/unit/test_mcp/test_search_tool_unit.py diff --git a/redisvl/mcp/errors.py b/redisvl/mcp/errors.py index 54fb59bc..6befad3b 100644 --- a/redisvl/mcp/errors.py +++ b/redisvl/mcp/errors.py @@ -12,6 +12,7 @@ class MCPErrorCode(str, Enum): """Stable internal error codes exposed by the MCP framework.""" INVALID_REQUEST = "invalid_request" + INVALID_FILTER = "invalid_filter" DEPENDENCY_MISSING = "dependency_missing" BACKEND_UNAVAILABLE = "backend_unavailable" INTERNAL_ERROR = "internal_error" diff --git a/redisvl/mcp/filters.py b/redisvl/mcp/filters.py new file mode 100644 index 00000000..cc870439 --- /dev/null +++ b/redisvl/mcp/filters.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from typing import Any, Iterable, Optional + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.query.filter import FilterExpression, Num, Tag, Text +from redisvl.schema import IndexSchema + + +def parse_filter( + value: Optional[str | dict[str, Any]], schema: IndexSchema +) -> Optional[str | FilterExpression]: + """Parse an MCP filter value into a RedisVL filter representation.""" + if value is None: + return None + if isinstance(value, str): + return value + if not isinstance(value, dict): + raise RedisVLMCPError( + "filter must be a string or object", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return _parse_expression(value, schema) + + +def _parse_expression(value: dict[str, Any], schema: IndexSchema) -> FilterExpression: + logical_keys = [key for key in ("and", "or", "not") if key in value] + if logical_keys: + if len(logical_keys) != 1 or len(value) != 1: + raise RedisVLMCPError( + "logical filter objects must contain exactly one operator", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + logical_key = logical_keys[0] + if logical_key == "not": + child = value["not"] + if not isinstance(child, dict): + raise RedisVLMCPError( + "not filter must wrap a single object expression", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return FilterExpression(f"(-({str(_parse_expression(child, schema))}))") + + children = value[logical_key] + if not isinstance(children, list) or not children: + raise RedisVLMCPError( + f"{logical_key} filter must contain a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + expressions: list[FilterExpression] = [] + for child in children: + if not isinstance(child, dict): + raise RedisVLMCPError( + "logical filter children must be objects", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + expressions.append(_parse_expression(child, schema)) + + combined = expressions[0] + for child in expressions[1:]: + combined = combined & child if logical_key == "and" else combined | child + return combined + + field_name = value.get("field") + op = value.get("op") + if not isinstance(field_name, str) or not field_name.strip(): + raise RedisVLMCPError( + "filter.field must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + if not isinstance(op, str) or not op.strip(): + raise RedisVLMCPError( + "filter.op must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + field = schema.fields.get(field_name) + if field is None: + raise RedisVLMCPError( + f"Unknown filter field: {field_name}", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + normalized_op = op.lower() + if normalized_op == "exists": + return FilterExpression(f"(-ismissing(@{field_name}))") + + if "value" not in value: + raise RedisVLMCPError( + "filter.value is required for this operator", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + operand = value["value"] + if field.type == "tag": + return _parse_tag_expression(field_name, normalized_op, operand) + if field.type == "text": + return _parse_text_expression(field_name, normalized_op, operand) + if field.type == "numeric": + return _parse_numeric_expression(field_name, normalized_op, operand) + + raise RedisVLMCPError( + f"Unsupported filter field type for {field_name}: {field.type}", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_tag_expression(field_name: str, op: str, operand: Any) -> FilterExpression: + field = Tag(field_name) + if op == "eq": + return field == _require_string(operand, field_name, op) + if op == "ne": + return field != _require_string(operand, field_name, op) + if op == "in": + return field == _require_string_list(operand, field_name, op) + if op == "like": + return field % _require_string(operand, field_name, op) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for tag field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_text_expression(field_name: str, op: str, operand: Any) -> FilterExpression: + field = Text(field_name) + if op == "eq": + return field == _require_string(operand, field_name, op) + if op == "ne": + return field != _require_string(operand, field_name, op) + if op == "like": + return field % _require_string(operand, field_name, op) + if op == "in": + return _combine_or( + [field == item for item in _require_string_list(operand, field_name, op)] + ) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for text field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _parse_numeric_expression( + field_name: str, op: str, operand: Any +) -> FilterExpression: + field = Num(field_name) + if op == "eq": + return field == _require_number(operand, field_name, op) + if op == "ne": + return field != _require_number(operand, field_name, op) + if op == "gt": + return field > _require_number(operand, field_name, op) + if op == "gte": + return field >= _require_number(operand, field_name, op) + if op == "lt": + return field < _require_number(operand, field_name, op) + if op == "lte": + return field <= _require_number(operand, field_name, op) + if op == "in": + return _combine_or( + [field == item for item in _require_number_list(operand, field_name, op)] + ) + raise RedisVLMCPError( + f"Unsupported operator '{op}' for numeric field '{field_name}'", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + +def _combine_or(expressions: Iterable[FilterExpression]) -> FilterExpression: + expression_list = list(expressions) + if not expression_list: + raise RedisVLMCPError( + "in operator requires a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + combined = expression_list[0] + for expression in expression_list[1:]: + combined = combined | expression + return combined + + +def _require_string(value: Any, field_name: str, op: str) -> str: + if not isinstance(value, str) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty string", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return value + + +def _require_string_list(value: Any, field_name: str, op: str) -> list[str]: + if not isinstance(value, list) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + strings = [_require_string(item, field_name, op) for item in value] + return strings + + +def _require_number(value: Any, field_name: str, op: str) -> int | float: + if isinstance(value, bool) or not isinstance(value, (int, float)): + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be numeric", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return value + + +def _require_number_list(value: Any, field_name: str, op: str) -> list[int | float]: + if not isinstance(value, list) or not value: + raise RedisVLMCPError( + f"filter value for field '{field_name}' and operator '{op}' must be a non-empty array", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + return [_require_number(item, field_name, op) for item in value] diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index 12e1d6db..06c07f62 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -2,11 +2,13 @@ from importlib import import_module from typing import Any, Awaitable, Optional, Type +from redis import __version__ as redis_py_version + from redisvl.exceptions import RedisSearchError from redisvl.index import AsyncSearchIndex from redisvl.mcp.config import MCPConfig, load_mcp_config from redisvl.mcp.settings import MCPSettings -from redisvl.redis.connection import RedisConnectionFactory +from redisvl.redis.connection import RedisConnectionFactory, is_version_gte from redisvl.schema import IndexSchema try: @@ -41,6 +43,7 @@ def __init__(self, settings: MCPSettings): self._index: Optional[AsyncSearchIndex] = None self._vectorizer: Optional[Any] = None self._semaphore: Optional[asyncio.Semaphore] = None + self._tools_registered = False async def startup(self) -> None: """Load config, inspect the configured index, and initialize dependencies.""" @@ -82,6 +85,7 @@ async def startup(self) -> None: timeout=timeout, ) self._validate_vectorizer_dims(effective_schema) + self._register_tools() except Exception: if self._index is not None: await self.shutdown() @@ -155,6 +159,30 @@ def _validate_vectorizer_dims(self, schema: IndexSchema) -> None: f"Vectorizer dims {actual_dims} do not match configured vector field dims {configured_dims}" ) + async def supports_native_hybrid_search(self) -> bool: + """Return whether the current runtime supports Redis native hybrid search.""" + if self._index is None: + raise RuntimeError("MCP server has not been started") + if not is_version_gte(redis_py_version, "7.1.0"): + return False + + client = await self._index._get_client() + info = await client.info("server") + if not is_version_gte(info.get("redis_version", "0.0.0"), "8.4.0"): + return False + + return hasattr(client.ft(self._index.schema.index.name), "hybrid_search") + + def _register_tools(self) -> None: + """Register MCP tools once the server is ready.""" + if self._tools_registered or not hasattr(self, "tool"): + return + + from redisvl.mcp.tools.search import register_search_tool + + register_search_tool(self) + self._tools_registered = True + @staticmethod def _is_missing_index_error(exc: RedisSearchError) -> bool: """Detect the Redis search errors that mean the configured index is absent.""" diff --git a/redisvl/mcp/tools/__init__.py b/redisvl/mcp/tools/__init__.py new file mode 100644 index 00000000..e47aef7c --- /dev/null +++ b/redisvl/mcp/tools/__init__.py @@ -0,0 +1,3 @@ +from redisvl.mcp.tools.search import search_records + +__all__ = ["search_records"] diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py new file mode 100644 index 00000000..39a789b8 --- /dev/null +++ b/redisvl/mcp/tools/search.py @@ -0,0 +1,333 @@ +import asyncio +import inspect +from typing import Any, Optional + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError, map_exception +from redisvl.mcp.filters import parse_filter +from redisvl.query import AggregateHybridQuery, HybridQuery, TextQuery, VectorQuery + +DEFAULT_SEARCH_DESCRIPTION = "Search records in the configured Redis index." + + +def _validate_request( + *, + query: str, + search_type: str, + limit: Optional[int], + offset: int, + return_fields: Optional[list[str]], + server: Any, + index: Any, +) -> tuple[int, list[str]]: + """Validate the MCP search request and resolve effective request defaults. + + This function enforces the public MCP contract for `search-records` before + any RedisVL query objects are constructed. It also derives the default + return-field projection from the effective index schema. + """ + runtime = server.config.runtime + + if not isinstance(query, str) or not query.strip(): + raise RedisVLMCPError( + "query must be a non-empty string", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if search_type not in {"vector", "fulltext", "hybrid"}: + raise RedisVLMCPError( + "search_type must be one of: vector, fulltext, hybrid", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + effective_limit = runtime.default_limit if limit is None else limit + if not isinstance(effective_limit, int) or effective_limit <= 0: + raise RedisVLMCPError( + "limit must be greater than 0", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if effective_limit > runtime.max_limit: + raise RedisVLMCPError( + f"limit must be less than or equal to {runtime.max_limit}", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if not isinstance(offset, int) or offset < 0: + raise RedisVLMCPError( + "offset must be greater than or equal to 0", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + schema_fields = set(index.schema.field_names) + vector_field_name = runtime.vector_field_name + + if return_fields is None: + fields = [ + field_name + for field_name in index.schema.field_names + if field_name != vector_field_name + ] + else: + if not isinstance(return_fields, list): + raise RedisVLMCPError( + "return_fields must be a list of field names", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + fields = [] + for field_name in return_fields: + if not isinstance(field_name, str) or not field_name: + raise RedisVLMCPError( + "return_fields must contain non-empty strings", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if field_name not in schema_fields: + raise RedisVLMCPError( + f"Unknown return field '{field_name}'", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + if field_name == vector_field_name: + raise RedisVLMCPError( + f"Vector field '{vector_field_name}' cannot be returned", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + fields.append(field_name) + + return effective_limit, fields + + +def _normalize_record( + result: dict[str, Any], score_field: str, score_type: str +) -> dict[str, Any]: + """Convert one RedisVL search result into the stable MCP result shape. + + RedisVL and redis-py expose scores and document identifiers under slightly + different field names depending on the query type, so normalization happens + here before the MCP response is returned. + """ + score = result.get(score_field) + if score is None and score_field == "score": + score = result.get("__score") + if score is None: + raise RedisVLMCPError( + f"Search result missing expected score field '{score_field}'", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + record = dict(result) + doc_id = record.pop("id", None) + if doc_id is None: + doc_id = record.pop("__key", None) + if doc_id is None: + doc_id = record.pop("key", None) + if doc_id is None: + raise RedisVLMCPError( + "Search result missing id", + code=MCPErrorCode.INVALID_REQUEST, + retryable=False, + ) + + for field_name in ( + "vector_distance", + "score", + "__score", + "text_score", + "vector_similarity", + "hybrid_score", + ): + record.pop(field_name, None) + + return { + "id": doc_id, + "score": float(score), + "score_type": score_type, + "record": record, + } + + +async def _embed_query(vectorizer: Any, query: str) -> Any: + """Embed the user query through either an async or sync vectorizer API.""" + if hasattr(vectorizer, "aembed"): + return await vectorizer.aembed(query) + embed = getattr(vectorizer, "embed") + if inspect.iscoroutinefunction(embed): + return await embed(query) + return await asyncio.to_thread(embed, query) + + +async def _build_query( + *, + server: Any, + index: Any, + query: str, + search_type: str, + limit: int, + offset: int, + filter_value: str | dict[str, Any] | None, + return_fields: list[str], +) -> tuple[Any, str, str]: + """Build the RedisVL query object and score metadata for one search mode. + + Returns the constructed query object along with the raw score field name and + the stable MCP `score_type` label that the response should expose. + """ + runtime = server.config.runtime + num_results = limit + offset + filter_expression = parse_filter(filter_value, index.schema) + + if search_type == "vector": + vectorizer = await server.get_vectorizer() + embedding = await _embed_query(vectorizer, query) + return ( + VectorQuery( + vector=embedding, + vector_field_name=runtime.vector_field_name, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + normalize_vector_distance=True, + ), + "vector_distance", + "vector_distance_normalized", + ) + + if search_type == "fulltext": + return ( + TextQuery( + text=query, + text_field_name=runtime.text_field_name, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + stopwords=None, + ), + "score", + "text_score", + ) + + vectorizer = await server.get_vectorizer() + embedding = await _embed_query(vectorizer, query) + if await server.supports_native_hybrid_search(): + native_query = HybridQuery( + text=query, + text_field_name=runtime.text_field_name, + vector=embedding, + vector_field_name=runtime.vector_field_name, + filter_expression=filter_expression, + return_fields=["__key", *return_fields], + num_results=num_results, + stopwords=None, + combination_method="LINEAR", + linear_alpha=0.7, + yield_text_score_as="text_score", + yield_vsim_score_as="vector_similarity", + yield_combined_score_as="hybrid_score", + ) + native_query.postprocessing_config.apply(__key="@__key") + return ( + native_query, + "hybrid_score", + "hybrid_score", + ) + + fallback_query = AggregateHybridQuery( + text=query, + text_field_name=runtime.text_field_name, + vector=embedding, + vector_field_name=runtime.vector_field_name, + filter_expression=filter_expression, + return_fields=["__key", *return_fields], + num_results=num_results, + stopwords=None, + ) + return ( + fallback_query, + "hybrid_score", + "hybrid_score", + ) + + +async def search_records( + server: Any, + *, + query: str, + search_type: str = "vector", + limit: Optional[int] = None, + offset: int = 0, + filter: str | dict[str, Any] | None = None, + return_fields: Optional[list[str]] = None, +) -> dict[str, Any]: + """Execute `search-records` against the server's configured Redis index.""" + try: + index = await server.get_index() + effective_limit, effective_return_fields = _validate_request( + query=query, + search_type=search_type, + limit=limit, + offset=offset, + return_fields=return_fields, + server=server, + index=index, + ) + built_query, score_field, score_type = await _build_query( + server=server, + index=index, + query=query.strip(), + search_type=search_type, + limit=effective_limit, + offset=offset, + filter_value=filter, + return_fields=effective_return_fields, + ) + raw_results = await server.run_guarded( + "search-records", + index.query(built_query), + ) + sliced_results = raw_results[offset : offset + effective_limit] + return { + "search_type": search_type, + "offset": offset, + "limit": effective_limit, + "results": [ + _normalize_record(result, score_field, score_type) + for result in sliced_results + ], + } + except RedisVLMCPError: + raise + except Exception as exc: + raise map_exception(exc) from exc + + +def register_search_tool(server: Any) -> None: + """Register the MCP search tool on a server-like object.""" + description = ( + server.mcp_settings.tool_search_description or DEFAULT_SEARCH_DESCRIPTION + ) + + async def search_records_tool( + query: str, + search_type: str = "vector", + limit: Optional[int] = None, + offset: int = 0, + filter: str | dict[str, Any] | None = None, + return_fields: Optional[list[str]] = None, + ): + """FastMCP wrapper for the `search-records` tool.""" + return await search_records( + server, + query=query, + search_type=search_type, + limit=limit, + offset=offset, + filter=filter, + return_fields=return_fields, + ) + + server.tool(name="search-records", description=description)(search_records_tool) diff --git a/tests/integration/test_mcp/test_search_tool.py b/tests/integration/test_mcp/test_search_tool.py new file mode 100644 index 00000000..4f25a6b6 --- /dev/null +++ b/tests/integration/test_mcp/test_search_tool.py @@ -0,0 +1,250 @@ +from pathlib import Path + +import pytest +import yaml + +from redisvl.index import AsyncSearchIndex +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.server import RedisVLMCPServer +from redisvl.mcp.settings import MCPSettings +from redisvl.mcp.tools.search import search_records +from redisvl.redis.connection import is_version_gte +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema +from tests.conftest import get_redis_version_async, skip_if_redis_version_below_async + + +class FakeVectorizer: + def __init__(self, model: str, dims: int = 3, **kwargs): + self.model = model + self.dims = dims + self.kwargs = kwargs + + def embed(self, content: str = "", **kwargs): + del content, kwargs + return [0.1, 0.1, 0.5] + + +@pytest.fixture +async def searchable_index(async_client, worker_id): + schema = IndexSchema.from_dict( + { + "index": { + "name": f"mcp-search-{worker_id}", + "prefix": f"mcp-search:{worker_id}", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + index = AsyncSearchIndex(schema=schema, redis_client=async_client) + await index.create(overwrite=True, drop=True) + + def preprocess(record: dict) -> dict: + return { + **record, + "embedding": array_to_buffer(record["embedding"], "float32"), + } + + await index.load( + [ + { + "id": f"doc:{worker_id}:1", + "content": "science article about planets", + "category": "science", + "rating": 5, + "embedding": [0.1, 0.1, 0.5], + }, + { + "id": f"doc:{worker_id}:2", + "content": "medical science and health", + "category": "health", + "rating": 4, + "embedding": [0.1, 0.1, 0.4], + }, + { + "id": f"doc:{worker_id}:3", + "content": "sports update and scores", + "category": "sports", + "rating": 3, + "embedding": [-0.2, 0.1, 0.0], + }, + ], + preprocess=preprocess, + ) + + yield index + + await index.delete(drop=True) + + +@pytest.fixture +def mcp_config_path(tmp_path: Path, redis_url: str): + def factory(redis_name: str) -> str: + config = { + "server": {"redis_url": redis_url}, + "indexes": { + "knowledge": { + "redis_name": redis_name, + "vectorizer": { + "class": "FakeVectorizer", + "model": "fake-model", + "dims": 3, + }, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + }, + } + }, + } + config_path = tmp_path / f"{redis_name}.yaml" + config_path.write_text(yaml.safe_dump(config), encoding="utf-8") + return str(config_path) + + return factory + + +@pytest.fixture +async def started_server(monkeypatch, searchable_index, mcp_config_path): + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings(config=mcp_config_path(searchable_index.schema.index.name)) + ) + await server.startup() + yield server + await server.shutdown() + + +@pytest.mark.asyncio +async def test_search_records_vector_success_with_pagination_and_projection( + started_server, +): + response = await search_records( + started_server, + query="science", + limit=1, + offset=1, + return_fields=["content", "category"], + ) + + assert response["search_type"] == "vector" + assert response["offset"] == 1 + assert response["limit"] == 1 + assert len(response["results"]) == 1 + assert response["results"][0]["score_type"] == "vector_distance_normalized" + assert set(response["results"][0]["record"]) == {"content", "category"} + + +@pytest.mark.asyncio +async def test_search_records_fulltext_success(started_server): + response = await search_records( + started_server, + query="science", + search_type="fulltext", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "fulltext" + assert response["results"] + assert response["results"][0]["score_type"] == "text_score" + assert response["results"][0]["score"] is not None + assert "science" in response["results"][0]["record"]["content"] + + +@pytest.mark.asyncio +async def test_search_records_respects_raw_string_filter(started_server): + response = await search_records( + started_server, + query="science", + filter="@category:{science}", + return_fields=["content", "category"], + ) + + assert response["results"] + assert all( + result["record"]["category"] == "science" for result in response["results"] + ) + + +@pytest.mark.asyncio +async def test_search_records_respects_dsl_filter(started_server): + response = await search_records( + started_server, + query="science", + filter={"field": "rating", "op": "gte", "value": 4.5}, + return_fields=["content", "category", "rating"], + ) + + assert response["results"] + assert all( + float(result["record"]["rating"]) >= 4.5 for result in response["results"] + ) + + +@pytest.mark.asyncio +async def test_search_records_invalid_filter_returns_invalid_filter(started_server): + with pytest.raises(RedisVLMCPError) as exc_info: + await search_records( + started_server, + query="science", + filter={"field": "missing", "op": "eq", "value": "science"}, + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +@pytest.mark.asyncio +async def test_search_records_native_hybrid_success(started_server, async_client): + await skip_if_redis_version_below_async(async_client, "8.4.0") + + response = await search_records( + started_server, + query="science", + search_type="hybrid", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "hybrid" + assert response["results"] + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] is not None + + +@pytest.mark.asyncio +async def test_search_records_fallback_hybrid_success(started_server, async_client): + redis_version = await get_redis_version_async(async_client) + if is_version_gte(redis_version, "8.4.0"): + pytest.skip(f"Redis version {redis_version} uses native hybrid search") + + response = await search_records( + started_server, + query="science", + search_type="hybrid", + return_fields=["content", "category"], + ) + + assert response["search_type"] == "hybrid" + assert response["results"] + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] is not None diff --git a/tests/unit/test_mcp/test_errors.py b/tests/unit/test_mcp/test_errors.py index 066e3173..ddd28622 100644 --- a/tests/unit/test_mcp/test_errors.py +++ b/tests/unit/test_mcp/test_errors.py @@ -26,6 +26,18 @@ def test_import_error_maps_to_dependency_missing(): assert mapped.retryable is False +def test_filter_error_is_preserved(): + original = RedisVLMCPError( + "bad filter", + code=MCPErrorCode.INVALID_FILTER, + retryable=False, + ) + + mapped = map_exception(original) + + assert mapped is original + + def test_redis_errors_map_to_backend_unavailable(): mapped = map_exception(RedisSearchError("redis unavailable")) diff --git a/tests/unit/test_mcp/test_filters.py b/tests/unit/test_mcp/test_filters.py new file mode 100644 index 00000000..4fb43b6a --- /dev/null +++ b/tests/unit/test_mcp/test_filters.py @@ -0,0 +1,136 @@ +import pytest + +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.filters import parse_filter +from redisvl.query.filter import FilterExpression +from redisvl.schema import IndexSchema + + +def _schema() -> IndexSchema: + return IndexSchema.from_dict( + { + "index": { + "name": "docs-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +def _render_filter(value): + if isinstance(value, FilterExpression): + return str(value) + return value + + +def test_parse_filter_passes_through_raw_string(): + raw = "@category:{science} @rating:[4 +inf]" + + parsed = parse_filter(raw, _schema()) + + assert parsed == raw + + +def test_parse_filter_builds_atomic_expression(): + parsed = parse_filter( + {"field": "category", "op": "eq", "value": "science"}, + _schema(), + ) + + assert isinstance(parsed, FilterExpression) + assert str(parsed) == "@category:{science}" + + +def test_parse_filter_builds_nested_logical_expression(): + parsed = parse_filter( + { + "and": [ + {"field": "category", "op": "eq", "value": "science"}, + { + "or": [ + {"field": "rating", "op": "gte", "value": 4.5}, + {"field": "content", "op": "like", "value": "quant*"}, + ] + }, + ] + }, + _schema(), + ) + + assert isinstance(parsed, FilterExpression) + assert ( + str(parsed) == "(@category:{science} (@rating:[4.5 +inf] | @content:(quant*)))" + ) + + +def test_parse_filter_builds_not_expression(): + parsed = parse_filter( + { + "not": {"field": "category", "op": "eq", "value": "science"}, + }, + _schema(), + ) + + assert _render_filter(parsed) == "(-(@category:{science}))" + + +def test_parse_filter_builds_exists_expression(): + parsed = parse_filter( + {"field": "content", "op": "exists"}, + _schema(), + ) + + assert _render_filter(parsed) == "(-ismissing(@content))" + + +def test_parse_filter_rejects_unknown_field(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "missing", "op": "eq", "value": "science"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_unknown_operator(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter( + {"field": "category", "op": "contains", "value": "science"}, _schema() + ) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_type_mismatch(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "rating", "op": "gte", "value": "high"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_empty_logical_array(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"and": []}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER + + +def test_parse_filter_rejects_malformed_payload(): + with pytest.raises(RedisVLMCPError) as exc_info: + parse_filter({"field": "category", "value": "science"}, _schema()) + + assert exc_info.value.code == MCPErrorCode.INVALID_FILTER diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py new file mode 100644 index 00000000..185c850f --- /dev/null +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -0,0 +1,345 @@ +from types import SimpleNamespace + +import pytest + +from redisvl.mcp.config import MCPConfig +from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError +from redisvl.mcp.tools.search import register_search_tool, search_records +from redisvl.schema import IndexSchema + + +def _schema() -> IndexSchema: + return IndexSchema.from_dict( + { + "index": { + "name": "docs-index", + "prefix": "doc", + "storage_type": "hash", + }, + "fields": [ + {"name": "content", "type": "text"}, + {"name": "category", "type": "tag"}, + {"name": "rating", "type": "numeric"}, + { + "name": "embedding", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": 3, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + } + ) + + +def _config() -> MCPConfig: + return MCPConfig.model_validate( + { + "server": {"redis_url": "redis://localhost:6379"}, + "indexes": { + "knowledge": { + "redis_name": "docs-index", + "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "runtime": { + "text_field_name": "content", + "vector_field_name": "embedding", + "default_embed_text_field": "content", + "default_limit": 2, + "max_limit": 5, + }, + } + }, + } + ) + + +class FakeVectorizer: + async def embed(self, text: str): + return [0.1, 0.2, 0.3] + + +class FakeIndex: + def __init__(self): + self.schema = _schema() + self.query_calls = [] + + async def query(self, query): + self.query_calls.append(query) + return [] + + +class FakeServer: + def __init__(self): + self.config = _config() + self.mcp_settings = SimpleNamespace(tool_search_description=None) + self.index = FakeIndex() + self.vectorizer = FakeVectorizer() + self.registered_tools = [] + self.native_hybrid_supported = False + + async def get_index(self): + return self.index + + async def get_vectorizer(self): + return self.vectorizer + + async def run_guarded(self, operation_name, awaitable): + return await awaitable + + async def supports_native_hybrid_search(self): + return self.native_hybrid_supported + + def tool(self, name=None, description=None, **kwargs): + def decorator(fn): + self.registered_tools.append( + { + "name": name, + "description": description, + "fn": fn, + } + ) + return fn + + return decorator + + +class FakeQuery: + def __init__(self, **kwargs): + self.kwargs = kwargs + + +@pytest.mark.asyncio +async def test_search_records_rejects_blank_query(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as exc_info: + await search_records(server, query=" ") + + assert exc_info.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_search_records_rejects_invalid_limit_and_offset(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as limit_exc: + await search_records(server, query="science", limit=0) + + with pytest.raises(RedisVLMCPError) as offset_exc: + await search_records(server, query="science", offset=-1) + + assert limit_exc.value.code == MCPErrorCode.INVALID_REQUEST + assert offset_exc.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_search_records_rejects_unknown_or_vector_return_fields(): + server = FakeServer() + + with pytest.raises(RedisVLMCPError) as unknown_exc: + await search_records(server, query="science", return_fields=["missing"]) + + with pytest.raises(RedisVLMCPError) as vector_exc: + await search_records(server, query="science", return_fields=["embedding"]) + + assert unknown_exc.value.code == MCPErrorCode.INVALID_REQUEST + assert vector_exc.value.code == MCPErrorCode.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_search_records_builds_vector_query_and_normalizes_results(monkeypatch): + server = FakeServer() + built_queries = [] + + class FakeVectorQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(kwargs) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:1", + "content": "science doc", + "category": "science", + "vector_distance": "0.93", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.VectorQuery", FakeVectorQuery) + server.index.query = fake_query + + response = await search_records(server, query="science") + + assert built_queries[0]["vector"] == [0.1, 0.2, 0.3] + assert built_queries[0]["vector_field_name"] == "embedding" + assert built_queries[0]["return_fields"] == ["content", "category", "rating"] + assert built_queries[0]["num_results"] == 2 + assert built_queries[0]["normalize_vector_distance"] is True + assert response == { + "search_type": "vector", + "offset": 0, + "limit": 2, + "results": [ + { + "id": "doc:1", + "score": 0.93, + "score_type": "vector_distance_normalized", + "record": { + "content": "science doc", + "category": "science", + }, + } + ], + } + + +@pytest.mark.asyncio +async def test_search_records_builds_fulltext_query(monkeypatch): + server = FakeServer() + built_queries = [] + + class FakeTextQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(kwargs) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:2", + "content": "medical science", + "category": "health", + "__score": "1.5", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.TextQuery", FakeTextQuery) + server.index.query = fake_query + + response = await search_records( + server, + query="medical science", + search_type="fulltext", + limit=1, + return_fields=["content", "category"], + ) + + assert built_queries[0]["text"] == "medical science" + assert built_queries[0]["text_field_name"] == "content" + assert built_queries[0]["num_results"] == 1 + assert response["results"][0]["score"] == 1.5 + assert response["results"][0]["score_type"] == "text_score" + + +@pytest.mark.asyncio +async def test_search_records_builds_hybrid_query_for_native_runtime(monkeypatch): + server = FakeServer() + server.native_hybrid_supported = True + built_queries = [] + + class FakePostProcessingConfig: + def __init__(self): + self.apply_calls = [] + + def apply(self, **kwargs): + self.apply_calls.append(kwargs) + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + self.postprocessing_config = FakePostProcessingConfig() + built_queries.append(("native", kwargs, self.postprocessing_config)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:3", + "content": "hybrid doc", + "hybrid_score": "2.5", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid", search_type="hybrid") + + assert built_queries[0][0] == "native" + assert built_queries[0][1]["vector"] == [0.1, 0.2, 0.3] + assert built_queries[0][2].apply_calls == [{"__key": "@__key"}] + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] == 2.5 + + +@pytest.mark.asyncio +async def test_search_records_builds_hybrid_query_for_fallback_runtime(monkeypatch): + server = FakeServer() + built_queries = [] + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("native", kwargs)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:4", + "content": "fallback hybrid", + "hybrid_score": "0.7", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid", search_type="hybrid") + + assert built_queries[0][0] == "fallback" + assert built_queries[0][1]["return_fields"] == [ + "__key", + "content", + "category", + "rating", + ] + assert response["results"][0]["score"] == 0.7 + + +def test_register_search_tool_uses_default_and_override_descriptions(): + default_server = FakeServer() + register_search_tool(default_server) + + assert default_server.registered_tools[0]["name"] == "search-records" + assert "Search records" in default_server.registered_tools[0]["description"] + + custom_server = FakeServer() + custom_server.mcp_settings.tool_search_description = "Custom search description" + register_search_tool(custom_server) + + assert ( + custom_server.registered_tools[0]["description"] == "Custom search description" + ) From 9bdbee390f0188309fb545ff442f292b98041dcf Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 11:11:00 +0100 Subject: [PATCH 3/9] Add Codex config files to .gitignore --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index cd800581..16c99a57 100644 --- a/.gitignore +++ b/.gitignore @@ -189,6 +189,9 @@ dmypy.json # Cython debug symbols cython_debug/ +# Codex +.codex/ + # PyCharm # JetBrains specific template is maintained in a separate JetBrains.gitignore that can # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore From cb2d0942e3df0079be54b99db0c55e5feafdc678 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 11:19:39 +0100 Subject: [PATCH 4/9] Define config-owned search behavior in MCP spec --- spec/MCP.md | 149 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 113 insertions(+), 36 deletions(-) diff --git a/spec/MCP.md b/spec/MCP.md index 5f09e723..c160db79 100644 --- a/spec/MCP.md +++ b/spec/MCP.md @@ -13,6 +13,8 @@ metadata: This specification defines a Model Context Protocol (MCP) server for RedisVL that allows MCP clients to search and upsert data in an existing Redis index. +Search behavior is owned by server configuration. MCP clients provide query text, filtering, pagination, and field projection, but do not choose the search mode or runtime tuning parameters. + The MCP design targets indexes hosted on open-source Redis Stack, Redis Cloud, or Redis Enterprise, provided the required Search capabilities are available for the configured tool behavior. The server is designed for stdio transport first and must be runnable via: @@ -25,7 +27,7 @@ For a production-oriented usage narrative and end-to-end example, see [MCP-produ ### Goals -1. Expose RedisVL search capabilities (`vector`, `fulltext`, `hybrid`) through stable MCP tools. +1. Expose configured RedisVL search capabilities (`vector`, `fulltext`, `hybrid`) through stable MCP tools without requiring MCP clients to configure retrieval strategy. 2. Support controlled write access via an upsert tool. 3. Automatically reconstruct the index schema from an existing Redis index instead of requiring a full manual schema definition. 4. Keep the vectorizer configuration explicit and user-defined. @@ -59,7 +61,7 @@ These are hard compatibility expectations for v1. Notes: - This spec standardizes on the standalone `fastmcp` package for server implementation. It does not assume the official `mcp` package is on a 2.x line. - Client SDK examples may still use whichever client-side MCP package their ecosystem requires. -- Native hybrid support is preferred when available because it aligns with current Redis runtime capabilities, but lack of native support is not a blocker for `search_type=\"hybrid\"`. +- Native hybrid support is preferred when available because it aligns with current Redis runtime capabilities, but lack of native support is not a blocker for `indexes..search.type=\"hybrid\"` when the configured search params remain compatible with the aggregate fallback. --- @@ -148,6 +150,16 @@ indexes: dims: 1536 datatype: float32 + search: + type: hybrid + params: + text_scorer: BM25STD + stopwords: english + vector_search_method: KNN + combination_method: LINEAR + linear_text_weight: 0.3 + knn_ef_runtime: 150 + runtime: # required explicit field mapping for tool behavior text_field_name: content @@ -170,6 +182,51 @@ indexes: max_concurrency: 16 ``` +### Search Configuration (Normative) + +`indexes..search` defines the retrieval strategy for the sole bound index in v1. Tool callers must not override this configuration. + +Required fields: + +- `type`: `vector` | `fulltext` | `hybrid` +- `params`: optional object whose allowed keys depend on `type` + +Allowed `params` by `type`: + +- `vector` + - `hybrid_policy` + - `batch_size` + - `ef_runtime` + - `epsilon` + - `search_window_size` + - `use_search_history` + - `search_buffer_capacity` + - `normalize_vector_distance` +- `fulltext` + - `text_scorer` + - `stopwords` + - `text_weights` +- `hybrid` + - `text_scorer` + - `stopwords` + - `text_weights` + - `vector_search_method` + - `knn_ef_runtime` + - `range_radius` + - `range_epsilon` + - `combination_method` + - `rrf_window` + - `rrf_constant` + - `linear_text_weight` + +Normalization rules: + +1. `linear_text_weight` is the MCP config's stable meaning for linear hybrid fusion and always represents the text-side weight. +2. When building native `HybridQuery`, the server must pass `linear_text_weight` through as `linear_alpha`. +3. When building `AggregateHybridQuery`, the server must translate `linear_text_weight` to `alpha = 1 - linear_text_weight` so the config meaning does not change across implementations. +4. `linear_text_weight` is only valid when `combination_method` is `LINEAR`. +5. Hybrid configs using FT.SEARCH-only runtime params (`knn_ef_runtime`) must fail startup if the environment only supports the aggregate fallback path. + ### Schema Discovery and Override Rules 1. `server.redis_url` is required. @@ -179,19 +236,22 @@ indexes: 5. The server must reconstruct the base schema from Redis metadata, preferably via existing RedisVL inspection primitives built on `FT.INFO`. 6. `indexes..vectorizer` remains fully manual and is never inferred from Redis index metadata in v1. 7. `indexes..schema_overrides` is optional and exists only to supplement incomplete inspection data. -8. Discovered index identity is authoritative: +8. `indexes..search.type` is required and is authoritative for query construction. +9. `indexes..search.params` is optional but, when present, may only contain keys valid for the configured `search.type`. +10. Tool requests implicitly target the sole configured index binding and its configured search behavior in v1. No `index`, `search_type`, or search-tuning request parameters are exposed. +11. Tool callers may control only query text, filtering, pagination, and returned fields for `search-records`. +12. Discovered index identity is authoritative: - `indexes..redis_name` - storage type - field identity (`name`, `type`, and `path` when applicable) -9. Overrides may: +13. Overrides may: - add missing attrs for a discovered field - replace discovered attrs for a discovered field when needed for compatibility -10. Overrides must not: +14. Overrides must not: - redefine index identity - add entirely new fields that do not exist in the inspected index - change a discovered field's `name`, `type`, or `path` -11. Override conflicts must fail startup with a config error. -12. Tool requests implicitly target the sole configured index binding in v1. No `index` request parameter is exposed yet. +15. Override conflicts must fail startup with a config error. ### Env Substitution Rules @@ -210,13 +270,17 @@ Server startup must fail fast if: 4. `indexes` missing, empty, or containing more than one entry. 5. The configured binding id is blank. 6. `indexes..redis_name` missing or blank. -7. The referenced Redis index does not exist. -8. Schema inspection fails and no valid `indexes..schema_overrides` resolve the issue. -9. `indexes..runtime.text_field_name` not in the effective schema. -10. `indexes..runtime.vector_field_name` not in the effective schema or not vector type. -11. `indexes..runtime.default_embed_text_field` not in the effective schema. -12. `default_limit <= 0` or `max_limit < default_limit`. -13. `max_upsert_records <= 0`. +7. `indexes..search.type` missing or not one of `vector`, `fulltext`, `hybrid`. +8. `indexes..search.params` contains keys that are incompatible with the configured `search.type`. +9. `indexes..search.params.linear_text_weight` is present without `combination_method: LINEAR`. +10. A hybrid config relies on FT.SEARCH-only runtime params and the environment only supports the aggregate fallback path. +11. The referenced Redis index does not exist. +12. Schema inspection fails and no valid `indexes..schema_overrides` resolve the issue. +13. `indexes..runtime.text_field_name` not in the effective schema. +14. `indexes..runtime.vector_field_name` not in the effective schema or not vector type. +15. `indexes..runtime.default_embed_text_field` not in the effective schema. +16. `default_limit <= 0` or `max_limit < default_limit`. +17. `max_upsert_records <= 0`. --- @@ -234,9 +298,10 @@ On server startup: 6. Convert the inspected index metadata into an `IndexSchema`. 7. Apply any validated `indexes..schema_overrides` to produce the effective schema. 8. Instantiate `AsyncSearchIndex` from the effective schema. -9. Instantiate the configured `indexes..vectorizer`. -10. Validate vectorizer dimensions against the effective vector field dims when available. -11. Register tools (omit upsert in read-only mode). +9. Validate `indexes..search` against the effective schema and current runtime capabilities. +10. Instantiate the configured `indexes..vectorizer`. +11. Validate vectorizer dimensions against the effective vector field dims when available. +12. Register tools (omit upsert in read-only mode). If vector field attributes cannot be reconstructed from Redis metadata on the target Redis version, startup must fail with an actionable error unless `indexes..schema_overrides` provides the missing attrs. @@ -299,14 +364,13 @@ Tool executions are bounded by an async semaphore (`runtime.max_concurrency`). R ## Tool: `search-records` -Search records using vector, full-text, or hybrid query. +Search records using the configured search behavior for the bound index. ### Request Contract | Parameter | Type | Required | Default | Constraints | |----------|------|----------|---------|-------------| | `query` | str | yes | - | non-empty | -| `search_type` | enum | no | `vector` | `vector` \| `fulltext` \| `hybrid` | | `limit` | int | no | `runtime.default_limit` | `1..runtime.max_limit` | | `offset` | int | no | `0` | `>=0` | | `filter` | string \\| object | no | `null` | Raw RedisVL filter string or DSL object | @@ -335,12 +399,14 @@ Search records using vector, full-text, or hybrid query. ### Search Semantics -- `vector`: embeds `query` with configured vectorizer, builds `VectorQuery`. -- `fulltext`: builds `TextQuery`. +- `search_type` in the response is informational metadata derived from `indexes..search.type`. +- `search-records` must reject deprecated client-side search-mode or tuning inputs with `invalid_request`. +- `vector`: embeds `query` with the configured vectorizer and builds `VectorQuery` using `indexes..search.params`. +- `fulltext`: builds `TextQuery` using `indexes..search.params`. - `hybrid`: embeds `query` and selects the query implementation by runtime capability: - use native `HybridQuery` when Redis `>=8.4.0` and redis-py `>=7.1.0` are available - otherwise fall back to `AggregateHybridQuery` -- The MCP request/response contract for `hybrid` is identical across both implementation paths. +- The MCP request/response contract for `hybrid` is identical across both implementation paths because config normalization hides class-specific fusion semantics from tool callers. - In v1, `filter` is applied uniformly to the hybrid query rather than allowing separate text-side and vector-side filters. This is intentional to keep the API simple; future versions may expose finer-grained hybrid filtering controls. ### Errors @@ -421,8 +487,10 @@ For the sole configured binding in v1, the server owns these validated values: - `text_field_name` - `vector_field_name` - `default_embed_text_field` +- `search.type` +- `search.params` -Schema discovery is automatic in v1. Field mapping is not. Runtime field mappings remain explicit so the server does not guess among multiple valid text or vector fields. +Schema discovery is automatic in v1. Field mapping is not. Search construction is configuration-owned. Runtime field mappings remain explicit so the server does not guess among multiple valid text or vector fields, and MCP callers do not choose retrieval mode or tuning. --- @@ -478,7 +546,7 @@ async def main(): ) as server: agent = Agent( name="search-agent", - instructions="Search and maintain Redis-backed knowledge.", + instructions="Search and maintain Redis-backed knowledge using the server-configured retrieval strategy.", mcp_servers=[server], ) ``` @@ -494,7 +562,7 @@ from mcp import StdioServerParameters root_agent = LlmAgent( model="gemini-2.0-flash", name="redis_search_agent", - instruction="Search and maintain Redis-backed knowledge using vector search.", + instruction="Search and maintain Redis-backed knowledge using the server-configured retrieval strategy.", tools=[ McpToolset( connection_params=StdioConnectionParams( @@ -570,6 +638,8 @@ Note: Full n8n MCP client support depends on n8n's MCP implementation. Refer to - env substitution success/failure - schema inspection merge and override validation - field mapping validation + - `indexes..search` validation by type + - normalized hybrid fusion validation - `test_filters.py` - DSL parsing, invalid operators, type mismatches - `test_errors.py` @@ -582,10 +652,14 @@ Note: Full n8n MCP client support depends on n8n's MCP implementation. Refer to - missing index failure - vector field inspection gap resolved by `indexes..schema_overrides` - conflicting override failure + - hybrid config with FT.SEARCH-only params rejected when only aggregate fallback is available - `test_search_tool.py` - - vector/fulltext/hybrid success paths + - configured `vector` / `fulltext` / `hybrid` success paths + - request without `search_type` succeeds + - deprecated client-side search-mode or tuning params rejected with `invalid_request` + - response reports configured `search_type` - native hybrid path on Redis `>=8.4.0` - - aggregate hybrid fallback path on older supported runtimes + - aggregate hybrid fallback path on older supported runtimes when config is compatible - pagination and field projection - filter behavior - `test_upsert_tool.py` @@ -622,12 +696,13 @@ DoD: Deliverables: 1. `search-records` request/response contract. 2. Filter parser (JSON DSL + raw string pass-through). -3. Hybrid query selection between native and aggregate implementations. +3. Config-owned search construction and hybrid query selection between native and aggregate implementations. DoD: 1. All search modes tested. 2. Invalid filter returns `invalid_filter`. -3. `hybrid` uses native execution when available and `AggregateHybridQuery` otherwise, without changing the MCP contract. +3. Deprecated client-side search-mode and tuning inputs return `invalid_request`. +4. `hybrid` uses native execution when available and `AggregateHybridQuery` otherwise, without changing the MCP contract or the meaning of `linear_text_weight`. ### Phase 3: Upsert Tool @@ -657,12 +732,11 @@ DoD: Deliverables: 1. Config reference and examples. 2. Client setup examples. -3. Companion production example document. -4. Troubleshooting guide with common errors and fixes. +3. Troubleshooting guide with common errors and fixes. DoD: 1. Docs reflect normative contracts in this spec. -2. Companion example is aligned with the config and lifecycle contract. +2. Client-facing examples do not imply MCP callers choose retrieval mode. --- @@ -670,17 +744,20 @@ DoD: 1. Runtime mismatch for hybrid search. - Native hybrid requires newer Redis and redis-py capabilities, while older supported environments may still need the aggregate fallback path. - - Mitigation: explicitly detect runtime capability and select native `HybridQuery` or `AggregateHybridQuery` deterministically. + - Mitigation: explicitly detect runtime capability, reject incompatible hybrid configs at startup, and otherwise select native `HybridQuery` or `AggregateHybridQuery` deterministically. 2. Dependency drift across provider vectorizers. - Mitigation: dependency matrix and startup validation. -3. Ambiguous filter behavior causing agent retries. - - Mitigation: explicit raw-string pass-through semantics and deterministic DSL parser errors. +3. Search behavior drift caused by client-selected tuning. + - Mitigation: keep search mode and query construction params in config, not in the MCP request surface. 4. Hidden partial writes during failures. - Mitigation: conservative `partial_write_possible` signaling. 5. Incomplete schema reconstruction on older Redis versions. - `FT.INFO` may not return enough vector metadata on some older Redis versions to fully reconstruct vector field attrs. - Mitigation: fail fast with an actionable error and support targeted `indexes..schema_overrides` for missing attrs. -6. Security and deployment limitations (v1 scope). +6. Hybrid fusion semantics differ between `HybridQuery` and `AggregateHybridQuery`. + - Native `HybridQuery` uses text-weight semantics while `AggregateHybridQuery` exposes vector-weight semantics. + - Mitigation: normalize on `linear_text_weight` in MCP config and translate internally per execution path. +7. Security and deployment limitations (v1 scope). - This implementation is stdio-first and not production-hardened by itself. It does not include: - Authentication/authorization mechanisms. - Remote transports (SSE/HTTP) that would enable multi-tenant or networked deployments. From 9db4c13ec37ea59b28215a48c5a370b3dab8d514 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 11:48:21 +0100 Subject: [PATCH 5/9] Configure search in config, tool just takes query --- redisvl/mcp/config.py | 96 ++++++++- redisvl/mcp/server.py | 3 + redisvl/mcp/tools/search.py | 201 +++++++++++------- .../integration/test_mcp/test_search_tool.py | 93 ++++++-- .../test_mcp/test_server_startup.py | 36 ++++ tests/unit/test_mcp/test_config.py | 122 +++++++++-- tests/unit/test_mcp/test_search_tool_unit.py | 90 ++++++-- 7 files changed, 517 insertions(+), 124 deletions(-) diff --git a/redisvl/mcp/config.py b/redisvl/mcp/config.py index 939c7c6d..af7104f5 100644 --- a/redisvl/mcp/config.py +++ b/redisvl/mcp/config.py @@ -2,7 +2,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, Literal, Optional import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -71,6 +71,84 @@ class MCPServerConfig(BaseModel): redis_url: str = Field(..., min_length=1) +class MCPIndexSearchConfig(BaseModel): + """Configured search mode and query tuning for the bound index. + + The MCP request contract only exposes query text, filtering, pagination, and + field projection. Search mode and query-tuning behavior are owned entirely by + YAML config and validated here. + """ + + type: Literal["vector", "fulltext", "hybrid"] + params: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="after") + def _validate_params(self) -> "MCPIndexSearchConfig": + """Reject params that do not belong to the configured search mode.""" + allowed_params = { + "vector": { + "hybrid_policy", + "batch_size", + "ef_runtime", + "epsilon", + "search_window_size", + "use_search_history", + "search_buffer_capacity", + "normalize_vector_distance", + }, + "fulltext": { + "text_scorer", + "stopwords", + "text_weights", + }, + "hybrid": { + "text_scorer", + "stopwords", + "text_weights", + "vector_search_method", + "knn_ef_runtime", + "range_radius", + "range_epsilon", + "combination_method", + "rrf_window", + "rrf_constant", + "linear_text_weight", + }, + } + invalid_keys = sorted(set(self.params) - allowed_params[self.type]) + if invalid_keys: + raise ValueError( + "search.params contains keys incompatible with " + f"search.type '{self.type}': {', '.join(invalid_keys)}" + ) + + if ( + "linear_text_weight" in self.params + and self.params.get("combination_method") != "LINEAR" + ): + raise ValueError( + "search.params.linear_text_weight requires combination_method to be LINEAR" + ) + return self + + def to_query_params(self) -> Dict[str, Any]: + """Return normalized query kwargs exactly as configured.""" + return dict(self.params) + + def validate_runtime_capabilities( + self, *, supports_native_hybrid_search: bool + ) -> None: + """Fail startup when hybrid config depends on native-only FT.SEARCH params.""" + if ( + self.type == "hybrid" + and not supports_native_hybrid_search + and "knn_ef_runtime" in self.params + ): + raise ValueError( + "search.params.knn_ef_runtime requires native hybrid search support" + ) + + class MCPSchemaOverrideField(BaseModel): """Allowed schema override fragment for one already-discovered field.""" @@ -91,6 +169,7 @@ class MCPIndexBindingConfig(BaseModel): redis_name: str = Field(..., min_length=1) vectorizer: MCPVectorizerConfig + search: MCPIndexSearchConfig runtime: MCPRuntimeConfig schema_overrides: MCPSchemaOverrides = Field(default_factory=MCPSchemaOverrides) @@ -134,6 +213,11 @@ def vectorizer(self) -> MCPVectorizerConfig: """Expose the sole binding's vectorizer config for phase 1.""" return self.binding.vectorizer + @property + def search(self) -> MCPIndexSearchConfig: + """Expose the sole binding's configured search behavior.""" + return self.binding.search + @property def redis_name(self) -> str: """Return the existing Redis index name that must be inspected at startup.""" @@ -255,6 +339,16 @@ def get_vector_field_dims(self, schema: IndexSchema) -> Optional[int]: attrs = self.get_vector_field(schema).attrs return getattr(attrs, "dims", None) + def validate_search( + self, + *, + supports_native_hybrid_search: bool, + ) -> None: + """Validate configured search behavior against current runtime support.""" + self.search.validate_runtime_capabilities( + supports_native_hybrid_search=supports_native_hybrid_search + ) + def _substitute_env(value: Any) -> Any: """Recursively resolve `${VAR}` and `${VAR:-default}` placeholders.""" diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index 06c07f62..4e07512c 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -79,6 +79,9 @@ async def startup(self) -> None: # The server acquired this client explicitly during startup, so hand # ownership to the index for a single shutdown path. self._index._owns_redis_client = True + self.config.validate_search( + supports_native_hybrid_search=await self.supports_native_hybrid_search(), + ) self._vectorizer = await asyncio.wait_for( asyncio.to_thread(self._build_vectorizer), diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index 39a789b8..88ef26cb 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -8,23 +8,28 @@ DEFAULT_SEARCH_DESCRIPTION = "Search records in the configured Redis index." +_NATIVE_HYBRID_DEFAULTS = { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, +} + def _validate_request( *, query: str, - search_type: str, limit: Optional[int], offset: int, return_fields: Optional[list[str]], server: Any, index: Any, ) -> tuple[int, list[str]]: - """Validate the MCP search request and resolve effective request defaults. + """Validate a `search-records` request and resolve default projection. - This function enforces the public MCP contract for `search-records` before - any RedisVL query objects are constructed. It also derives the default - return-field projection from the effective index schema. + The MCP caller can only supply query text, pagination, filters, and return + fields. Search mode and tuning are sourced from config, so this validation + step focuses only on the public request contract. """ + runtime = server.config.runtime if not isinstance(query, str) or not query.strip(): @@ -33,12 +38,6 @@ def _validate_request( code=MCPErrorCode.INVALID_REQUEST, retryable=False, ) - if search_type not in {"vector", "fulltext", "hybrid"}: - raise RedisVLMCPError( - "search_type must be one of: vector, fulltext, hybrid", - code=MCPErrorCode.INVALID_REQUEST, - retryable=False, - ) effective_limit = runtime.default_limit if limit is None else limit if not isinstance(effective_limit, int) or effective_limit <= 0: @@ -104,12 +103,7 @@ def _validate_request( def _normalize_record( result: dict[str, Any], score_field: str, score_type: str ) -> dict[str, Any]: - """Convert one RedisVL search result into the stable MCP result shape. - - RedisVL and redis-py expose scores and document identifiers under slightly - different field names depending on the query type, so normalization happens - here before the MCP response is returned. - """ + """Convert one RedisVL result into the stable MCP result shape.""" score = result.get(score_field) if score is None and score_field == "score": score = result.get("__score") @@ -152,49 +146,125 @@ def _normalize_record( async def _embed_query(vectorizer: Any, query: str) -> Any: - """Embed the user query through either an async or sync vectorizer API.""" - if hasattr(vectorizer, "aembed"): - return await vectorizer.aembed(query) + """Embed the query text, tolerating vectorizers without real async support.""" + aembed = getattr(vectorizer, "aembed", None) + if callable(aembed): + try: + return await aembed(query) + except NotImplementedError: + pass embed = getattr(vectorizer, "embed") if inspect.iscoroutinefunction(embed): return await embed(query) return await asyncio.to_thread(embed, query) +def _get_configured_search(server: Any) -> tuple[str, dict[str, Any]]: + """Return the configured search mode and normalized query params.""" + search_config = server.config.search + return search_config.type, search_config.to_query_params() + + +def _build_native_hybrid_kwargs( + *, + query: str, + embedding: Any, + runtime: Any, + filter_expression: Any, + return_fields: list[str], + num_results: int, + search_params: dict[str, Any], +) -> dict[str, Any]: + """Build native `HybridQuery` kwargs from MCP config-owned hybrid params.""" + params = {**_NATIVE_HYBRID_DEFAULTS, **search_params} + linear_text_weight = params.pop("linear_text_weight", None) + if linear_text_weight is not None: + params["linear_alpha"] = linear_text_weight + + return { + "text": query, + "text_field_name": runtime.text_field_name, + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": ["__key", *return_fields], + "num_results": num_results, + "yield_text_score_as": "text_score", + "yield_vsim_score_as": "vector_similarity", + "yield_combined_score_as": "hybrid_score", + **params, + } + + +def _build_fallback_hybrid_kwargs( + *, + query: str, + embedding: Any, + runtime: Any, + filter_expression: Any, + return_fields: list[str], + num_results: int, + search_params: dict[str, Any], +) -> dict[str, Any]: + """Build aggregate fallback kwargs while preserving MCP fusion semantics.""" + params = { + key: value + for key, value in search_params.items() + if key in {"text_scorer", "stopwords", "text_weights"} + } + linear_text_weight = search_params.get("linear_text_weight", 0.3) + params["alpha"] = 1 - linear_text_weight + + return { + "text": query, + "text_field_name": runtime.text_field_name, + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": ["__key", *return_fields], + "num_results": num_results, + **params, + } + + async def _build_query( *, server: Any, index: Any, query: str, - search_type: str, limit: int, offset: int, filter_value: str | dict[str, Any] | None, return_fields: list[str], -) -> tuple[Any, str, str]: - """Build the RedisVL query object and score metadata for one search mode. +) -> tuple[Any, str, str, str]: + """Build the RedisVL query object from configured search mode and params. - Returns the constructed query object along with the raw score field name and - the stable MCP `score_type` label that the response should expose. + Returns the query instance, the raw score field to read from RedisVL + results, the public MCP `score_type`, and the configured `search_type`. """ runtime = server.config.runtime + search_type, search_params = _get_configured_search(server) num_results = limit + offset filter_expression = parse_filter(filter_value, index.schema) if search_type == "vector": vectorizer = await server.get_vectorizer() embedding = await _embed_query(vectorizer, query) + vector_kwargs = { + "vector": embedding, + "vector_field_name": runtime.vector_field_name, + "filter_expression": filter_expression, + "return_fields": return_fields, + "num_results": num_results, + **search_params, + } + if "normalize_vector_distance" not in vector_kwargs: + vector_kwargs["normalize_vector_distance"] = True return ( - VectorQuery( - vector=embedding, - vector_field_name=runtime.vector_field_name, - filter_expression=filter_expression, - return_fields=return_fields, - num_results=num_results, - normalize_vector_distance=True, - ), + VectorQuery(**vector_kwargs), "vector_distance", "vector_distance_normalized", + search_type, ) if search_type == "fulltext": @@ -205,81 +275,68 @@ async def _build_query( filter_expression=filter_expression, return_fields=return_fields, num_results=num_results, - stopwords=None, + **search_params, ), "score", "text_score", + search_type, ) vectorizer = await server.get_vectorizer() embedding = await _embed_query(vectorizer, query) if await server.supports_native_hybrid_search(): native_query = HybridQuery( - text=query, - text_field_name=runtime.text_field_name, - vector=embedding, - vector_field_name=runtime.vector_field_name, - filter_expression=filter_expression, - return_fields=["__key", *return_fields], - num_results=num_results, - stopwords=None, - combination_method="LINEAR", - linear_alpha=0.7, - yield_text_score_as="text_score", - yield_vsim_score_as="vector_similarity", - yield_combined_score_as="hybrid_score", + **_build_native_hybrid_kwargs( + query=query, + embedding=embedding, + runtime=runtime, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + search_params=search_params, + ) ) native_query.postprocessing_config.apply(__key="@__key") - return ( - native_query, - "hybrid_score", - "hybrid_score", - ) + return native_query, "hybrid_score", "hybrid_score", search_type fallback_query = AggregateHybridQuery( - text=query, - text_field_name=runtime.text_field_name, - vector=embedding, - vector_field_name=runtime.vector_field_name, - filter_expression=filter_expression, - return_fields=["__key", *return_fields], - num_results=num_results, - stopwords=None, - ) - return ( - fallback_query, - "hybrid_score", - "hybrid_score", + **_build_fallback_hybrid_kwargs( + query=query, + embedding=embedding, + runtime=runtime, + filter_expression=filter_expression, + return_fields=return_fields, + num_results=num_results, + search_params=search_params, + ) ) + return fallback_query, "hybrid_score", "hybrid_score", search_type async def search_records( server: Any, *, query: str, - search_type: str = "vector", limit: Optional[int] = None, offset: int = 0, filter: str | dict[str, Any] | None = None, return_fields: Optional[list[str]] = None, ) -> dict[str, Any]: - """Execute `search-records` against the server's configured Redis index.""" + """Execute `search-records` against the configured Redis index binding.""" try: index = await server.get_index() effective_limit, effective_return_fields = _validate_request( query=query, - search_type=search_type, limit=limit, offset=offset, return_fields=return_fields, server=server, index=index, ) - built_query, score_field, score_type = await _build_query( + built_query, score_field, score_type, search_type = await _build_query( server=server, index=index, query=query.strip(), - search_type=search_type, limit=effective_limit, offset=offset, filter_value=filter, @@ -306,14 +363,13 @@ async def search_records( def register_search_tool(server: Any) -> None: - """Register the MCP search tool on a server-like object.""" + """Register the MCP `search-records` tool with its config-owned contract.""" description = ( server.mcp_settings.tool_search_description or DEFAULT_SEARCH_DESCRIPTION ) async def search_records_tool( query: str, - search_type: str = "vector", limit: Optional[int] = None, offset: int = 0, filter: str | dict[str, Any] | None = None, @@ -323,7 +379,6 @@ async def search_records_tool( return await search_records( server, query=query, - search_type=search_type, limit=limit, offset=offset, filter=filter, diff --git a/tests/integration/test_mcp/test_search_tool.py b/tests/integration/test_mcp/test_search_tool.py index 4f25a6b6..a5eaf8f3 100644 --- a/tests/integration/test_mcp/test_search_tool.py +++ b/tests/integration/test_mcp/test_search_tool.py @@ -94,7 +94,7 @@ def preprocess(record: dict) -> dict: @pytest.fixture def mcp_config_path(tmp_path: Path, redis_url: str): - def factory(redis_name: str) -> str: + def factory(redis_name: str, search: dict) -> str: config = { "server": {"redis_url": redis_url}, "indexes": { @@ -105,6 +105,7 @@ def factory(redis_name: str) -> str: "model": "fake-model", "dims": 3, }, + "search": search, "runtime": { "text_field_name": "content", "vector_field_name": "embedding", @@ -115,7 +116,7 @@ def factory(redis_name: str) -> str: } }, } - config_path = tmp_path / f"{redis_name}.yaml" + config_path = tmp_path / f"{redis_name}-{search['type']}.yaml" config_path.write_text(yaml.safe_dump(config), encoding="utf-8") return str(config_path) @@ -128,20 +129,42 @@ async def started_server(monkeypatch, searchable_index, mcp_config_path): "redisvl.mcp.server.resolve_vectorizer_class", lambda class_name: FakeVectorizer, ) - server = RedisVLMCPServer( - MCPSettings(config=mcp_config_path(searchable_index.schema.index.name)) - ) - await server.startup() - yield server - await server.shutdown() + + async def factory(search: dict) -> RedisVLMCPServer: + server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path(searchable_index.schema.index.name, search) + ) + ) + await server.startup() + return server + + servers = [] + + async def started(search: dict) -> RedisVLMCPServer: + server = await factory(search) + servers.append(server) + return server + + yield started + + for server in servers: + await server.shutdown() @pytest.mark.asyncio async def test_search_records_vector_success_with_pagination_and_projection( started_server, ): + server = await started_server( + { + "type": "vector", + "params": {"normalize_vector_distance": True}, + } + ) + response = await search_records( - started_server, + server, query="science", limit=1, offset=1, @@ -158,10 +181,19 @@ async def test_search_records_vector_success_with_pagination_and_projection( @pytest.mark.asyncio async def test_search_records_fulltext_success(started_server): + server = await started_server( + { + "type": "fulltext", + "params": { + "text_scorer": "BM25STD.NORM", + "stopwords": None, + }, + } + ) + response = await search_records( - started_server, + server, query="science", - search_type="fulltext", return_fields=["content", "category"], ) @@ -174,8 +206,10 @@ async def test_search_records_fulltext_success(started_server): @pytest.mark.asyncio async def test_search_records_respects_raw_string_filter(started_server): + server = await started_server({"type": "vector"}) + response = await search_records( - started_server, + server, query="science", filter="@category:{science}", return_fields=["content", "category"], @@ -189,8 +223,10 @@ async def test_search_records_respects_raw_string_filter(started_server): @pytest.mark.asyncio async def test_search_records_respects_dsl_filter(started_server): + server = await started_server({"type": "vector"}) + response = await search_records( - started_server, + server, query="science", filter={"field": "rating", "op": "gte", "value": 4.5}, return_fields=["content", "category", "rating"], @@ -204,9 +240,11 @@ async def test_search_records_respects_dsl_filter(started_server): @pytest.mark.asyncio async def test_search_records_invalid_filter_returns_invalid_filter(started_server): + server = await started_server({"type": "vector"}) + with pytest.raises(RedisVLMCPError) as exc_info: await search_records( - started_server, + server, query="science", filter={"field": "missing", "op": "eq", "value": "science"}, ) @@ -217,11 +255,20 @@ async def test_search_records_invalid_filter_returns_invalid_filter(started_serv @pytest.mark.asyncio async def test_search_records_native_hybrid_success(started_server, async_client): await skip_if_redis_version_below_async(async_client, "8.4.0") + server = await started_server( + { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + "stopwords": None, + }, + } + ) response = await search_records( - started_server, + server, query="science", - search_type="hybrid", return_fields=["content", "category"], ) @@ -237,10 +284,20 @@ async def test_search_records_fallback_hybrid_success(started_server, async_clie if is_version_gte(redis_version, "8.4.0"): pytest.skip(f"Redis version {redis_version} uses native hybrid search") + server = await started_server( + { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + "stopwords": None, + }, + } + ) + response = await search_records( - started_server, + server, query="science", - search_type="hybrid", return_fields=["content", "category"], ) diff --git a/tests/integration/test_mcp/test_server_startup.py b/tests/integration/test_mcp/test_server_startup.py index dd41ce91..8ed235b2 100644 --- a/tests/integration/test_mcp/test_server_startup.py +++ b/tests/integration/test_mcp/test_server_startup.py @@ -6,7 +6,9 @@ from redisvl.index import AsyncSearchIndex from redisvl.mcp.server import RedisVLMCPServer from redisvl.mcp.settings import MCPSettings +from redisvl.redis.connection import is_version_gte from redisvl.schema import IndexSchema +from tests.conftest import get_redis_version_async class FakeVectorizer: @@ -79,6 +81,7 @@ def factory( vector_dims: int = 3, schema_overrides: dict | None = None, runtime_overrides: dict | None = None, + search: dict | None = None, ) -> str: runtime = { "text_field_name": "content", @@ -98,6 +101,7 @@ def factory( "model": "fake-model", "dims": vector_dims, }, + "search": search or {"type": "vector"}, "runtime": runtime, } }, @@ -135,6 +139,38 @@ async def test_server_startup_success(monkeypatch, existing_index, mcp_config_pa await server.shutdown() +@pytest.mark.asyncio +async def test_server_fails_when_hybrid_config_requires_native_runtime( + monkeypatch, existing_index, mcp_config_path, async_client +): + redis_version = await get_redis_version_async(async_client) + if is_version_gte(redis_version, "8.4.0"): + pytest.skip(f"Redis version {redis_version} supports native hybrid search") + + index = await existing_index(index_name="mcp-native-required") + monkeypatch.setattr( + "redisvl.mcp.server.resolve_vectorizer_class", + lambda class_name: FakeVectorizer, + ) + server = RedisVLMCPServer( + MCPSettings( + config=mcp_config_path( + redis_name=index.name, + search={ + "type": "hybrid", + "params": { + "vector_search_method": "KNN", + "knn_ef_runtime": 150, + }, + }, + ) + ) + ) + + with pytest.raises(ValueError, match="knn_ef_runtime"): + await server.startup() + + @pytest.mark.asyncio async def test_server_fails_when_configured_index_is_missing( monkeypatch, mcp_config_path, worker_id diff --git a/tests/unit/test_mcp/test_config.py b/tests/unit/test_mcp/test_config.py index 4a0520f0..a524a52a 100644 --- a/tests/unit/test_mcp/test_config.py +++ b/tests/unit/test_mcp/test_config.py @@ -15,6 +15,7 @@ def _valid_config() -> dict: "knowledge": { "redis_name": "docs-index", "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "search": {"type": "vector"}, "runtime": { "text_field_name": "content", "vector_field_name": "embedding", @@ -68,17 +69,19 @@ def test_load_mcp_config_env_substitution(tmp_path: Path, monkeypatch): server: redis_url: ${REDIS_URL:-redis://localhost:6379} indexes: - knowledge: - redis_name: docs-index - vectorizer: - class: FakeVectorizer - model: ${VECTOR_MODEL:-test-model} - api_config: - api_key: ${OPENAI_API_KEY} - runtime: - text_field_name: content - vector_field_name: embedding - default_embed_text_field: content + knowledge: + redis_name: docs-index + vectorizer: + class: FakeVectorizer + model: ${VECTOR_MODEL:-test-model} + api_config: + api_key: ${OPENAI_API_KEY} + search: + type: vector + runtime: + text_field_name: content + vector_field_name: embedding + default_embed_text_field: content """.strip(), encoding="utf-8", ) @@ -101,15 +104,17 @@ def test_load_mcp_config_required_env_missing(tmp_path: Path, monkeypatch): server: redis_url: redis://localhost:6379 indexes: - knowledge: - redis_name: docs-index - vectorizer: - class: FakeVectorizer - model: ${VECTOR_MODEL} - runtime: - text_field_name: content - vector_field_name: embedding - default_embed_text_field: content + knowledge: + redis_name: docs-index + vectorizer: + class: FakeVectorizer + model: ${VECTOR_MODEL} + search: + type: vector + runtime: + text_field_name: content + vector_field_name: embedding + default_embed_text_field: content """.strip(), encoding="utf-8", ) @@ -166,6 +171,7 @@ def test_mcp_config_binding_helpers(): assert config.binding_id == "knowledge" assert config.binding.redis_name == "docs-index" + assert config.binding.search.type == "vector" assert config.runtime.default_embed_text_field == "content" assert config.vectorizer.class_name == "FakeVectorizer" assert config.redis_name == "docs-index" @@ -275,3 +281,79 @@ def test_load_mcp_config_requires_exactly_one_binding(tmp_path: Path): with pytest.raises(ValueError, match="exactly one configured index binding"): load_mcp_config(str(config_path)) + + +@pytest.mark.parametrize("search_type", ["vector", "fulltext", "hybrid"]) +def test_mcp_config_accepts_search_types(search_type): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = {"type": search_type} + + loaded = MCPConfig.model_validate(config) + + assert loaded.binding.search.type == search_type + assert loaded.binding.search.params == {} + + +def test_mcp_config_requires_search_type(): + config = _valid_config() + del config["indexes"]["knowledge"]["search"]["type"] + + with pytest.raises(ValueError, match="type"): + MCPConfig.model_validate(config) + + +def test_mcp_config_rejects_invalid_search_type(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = {"type": "semantic"} + + with pytest.raises(ValueError, match="vector|fulltext|hybrid"): + MCPConfig.model_validate(config) + + +@pytest.mark.parametrize( + ("search_type", "params"), + [ + ("vector", {"text_scorer": "BM25STD"}), + ("fulltext", {"normalize_vector_distance": True}), + ("hybrid", {"normalize_vector_distance": True}), + ], +) +def test_mcp_config_rejects_invalid_search_params(search_type, params): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": search_type, + "params": params, + } + + with pytest.raises(ValueError, match="search.params"): + MCPConfig.model_validate(config) + + +def test_mcp_config_rejects_linear_text_weight_without_linear_combination(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "combination_method": "RRF", + "linear_text_weight": 0.3, + }, + } + + with pytest.raises(ValueError, match="linear_text_weight"): + MCPConfig.model_validate(config) + + +def test_mcp_config_normalizes_hybrid_linear_text_weight(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + }, + } + + loaded = MCPConfig.model_validate(config) + + assert loaded.binding.search.type == "hybrid" + assert loaded.binding.search.params["linear_text_weight"] == 0.3 diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index 185c850f..fa8367a9 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -4,7 +4,7 @@ from redisvl.mcp.config import MCPConfig from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError -from redisvl.mcp.tools.search import register_search_tool, search_records +from redisvl.mcp.tools.search import _embed_query, register_search_tool, search_records from redisvl.schema import IndexSchema @@ -35,7 +35,7 @@ def _schema() -> IndexSchema: ) -def _config() -> MCPConfig: +def _config_with_search(search_type: str, params: dict | None = None) -> MCPConfig: return MCPConfig.model_validate( { "server": {"redis_url": "redis://localhost:6379"}, @@ -43,6 +43,7 @@ def _config() -> MCPConfig: "knowledge": { "redis_name": "docs-index", "vectorizer": {"class": "FakeVectorizer", "model": "test-model"}, + "search": {"type": search_type, "params": params or {}}, "runtime": { "text_field_name": "content", "vector_field_name": "embedding", @@ -72,8 +73,10 @@ async def query(self, query): class FakeServer: - def __init__(self): - self.config = _config() + def __init__( + self, *, search_type: str = "vector", search_params: dict | None = None + ): + self.config = _config_with_search(search_type, search_params) self.mcp_settings = SimpleNamespace(tool_search_description=None) self.index = FakeIndex() self.vectorizer = FakeVectorizer() @@ -111,6 +114,20 @@ def __init__(self, **kwargs): self.kwargs = kwargs +@pytest.mark.asyncio +async def test_embed_query_falls_back_to_sync_embed_when_aembed_is_not_implemented(): + class FallbackVectorizer: + async def aembed(self, text: str): + raise NotImplementedError + + def embed(self, text: str): + return [0.4, 0.5, 0.6] + + embedding = await _embed_query(FallbackVectorizer(), "science") + + assert embedding == [0.4, 0.5, 0.6] + + @pytest.mark.asyncio async def test_search_records_rejects_blank_query(): server = FakeServer() @@ -151,7 +168,10 @@ async def test_search_records_rejects_unknown_or_vector_return_fields(): @pytest.mark.asyncio async def test_search_records_builds_vector_query_and_normalizes_results(monkeypatch): - server = FakeServer() + server = FakeServer( + search_type="vector", + search_params={"normalize_vector_distance": False, "ef_runtime": 42}, + ) built_queries = [] class FakeVectorQuery(FakeQuery): @@ -179,7 +199,8 @@ async def fake_query(query): assert built_queries[0]["vector_field_name"] == "embedding" assert built_queries[0]["return_fields"] == ["content", "category", "rating"] assert built_queries[0]["num_results"] == 2 - assert built_queries[0]["normalize_vector_distance"] is True + assert built_queries[0]["normalize_vector_distance"] is False + assert built_queries[0]["ef_runtime"] == 42 assert response == { "search_type": "vector", "offset": 0, @@ -200,7 +221,14 @@ async def fake_query(query): @pytest.mark.asyncio async def test_search_records_builds_fulltext_query(monkeypatch): - server = FakeServer() + server = FakeServer( + search_type="fulltext", + search_params={ + "text_scorer": "BM25STD.NORM", + "stopwords": None, + "text_weights": {"medical": 2.5}, + }, + ) built_queries = [] class FakeTextQuery(FakeQuery): @@ -225,7 +253,6 @@ async def fake_query(query): response = await search_records( server, query="medical science", - search_type="fulltext", limit=1, return_fields=["content", "category"], ) @@ -233,13 +260,28 @@ async def fake_query(query): assert built_queries[0]["text"] == "medical science" assert built_queries[0]["text_field_name"] == "content" assert built_queries[0]["num_results"] == 1 + assert built_queries[0]["text_scorer"] == "BM25STD.NORM" + assert built_queries[0]["stopwords"] is None + assert built_queries[0]["text_weights"] == {"medical": 2.5} + assert response["search_type"] == "fulltext" assert response["results"][0]["score"] == 1.5 assert response["results"][0]["score_type"] == "text_score" @pytest.mark.asyncio async def test_search_records_builds_hybrid_query_for_native_runtime(monkeypatch): - server = FakeServer() + server = FakeServer( + search_type="hybrid", + search_params={ + "text_scorer": "TFIDF", + "stopwords": None, + "text_weights": {"hybrid": 2.0}, + "vector_search_method": "KNN", + "knn_ef_runtime": 77, + "combination_method": "LINEAR", + "linear_text_weight": 0.2, + }, + ) server.native_hybrid_supported = True built_queries = [] @@ -277,18 +319,35 @@ async def fake_query(query): ) server.index.query = fake_query - response = await search_records(server, query="hybrid", search_type="hybrid") + response = await search_records(server, query="hybrid") assert built_queries[0][0] == "native" assert built_queries[0][1]["vector"] == [0.1, 0.2, 0.3] + assert built_queries[0][1]["text_scorer"] == "TFIDF" + assert built_queries[0][1]["stopwords"] is None + assert built_queries[0][1]["text_weights"] == {"hybrid": 2.0} + assert built_queries[0][1]["vector_search_method"] == "KNN" + assert built_queries[0][1]["knn_ef_runtime"] == 77 + assert built_queries[0][1]["combination_method"] == "LINEAR" + assert built_queries[0][1]["linear_alpha"] == 0.2 assert built_queries[0][2].apply_calls == [{"__key": "@__key"}] + assert response["search_type"] == "hybrid" assert response["results"][0]["score_type"] == "hybrid_score" assert response["results"][0]["score"] == 2.5 @pytest.mark.asyncio async def test_search_records_builds_hybrid_query_for_fallback_runtime(monkeypatch): - server = FakeServer() + server = FakeServer( + search_type="hybrid", + search_params={ + "text_scorer": "TFIDF", + "stopwords": None, + "text_weights": {"hybrid": 2.0}, + "combination_method": "LINEAR", + "linear_text_weight": 0.2, + }, + ) built_queries = [] class FakeHybridQuery(FakeQuery): @@ -317,15 +376,20 @@ async def fake_query(query): ) server.index.query = fake_query - response = await search_records(server, query="hybrid", search_type="hybrid") + response = await search_records(server, query="hybrid") assert built_queries[0][0] == "fallback" + assert built_queries[0][1]["text_scorer"] == "TFIDF" + assert built_queries[0][1]["stopwords"] is None + assert built_queries[0][1]["text_weights"] == {"hybrid": 2.0} + assert built_queries[0][1]["alpha"] == pytest.approx(0.8) assert built_queries[0][1]["return_fields"] == [ "__key", "content", "category", "rating", ] + assert response["search_type"] == "hybrid" assert response["results"][0]["score"] == 0.7 @@ -335,6 +399,8 @@ def test_register_search_tool_uses_default_and_override_descriptions(): assert default_server.registered_tools[0]["name"] == "search-records" assert "Search records" in default_server.registered_tools[0]["description"] + assert "query" in default_server.registered_tools[0]["fn"].__annotations__ + assert "search_type" not in default_server.registered_tools[0]["fn"].__annotations__ custom_server = FakeServer() custom_server.mcp_settings.tool_search_description = "Custom search description" From aa93dd2a010de740d5ca421ec02aa1a5e4968727 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 11:57:15 +0100 Subject: [PATCH 6/9] Python 3.9 compat --- redisvl/mcp/tools/search.py | 8 ++++---- tests/integration/test_mcp/test_server_startup.py | 9 +++++---- tests/unit/test_mcp/test_search_tool_unit.py | 8 ++++++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index 88ef26cb..ae59c783 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -1,6 +1,6 @@ import asyncio import inspect -from typing import Any, Optional +from typing import Any, Optional, Union from redisvl.mcp.errors import MCPErrorCode, RedisVLMCPError, map_exception from redisvl.mcp.filters import parse_filter @@ -234,7 +234,7 @@ async def _build_query( query: str, limit: int, offset: int, - filter_value: str | dict[str, Any] | None, + filter_value: Optional[Union[str, dict[str, Any]]], return_fields: list[str], ) -> tuple[Any, str, str, str]: """Build the RedisVL query object from configured search mode and params. @@ -319,7 +319,7 @@ async def search_records( query: str, limit: Optional[int] = None, offset: int = 0, - filter: str | dict[str, Any] | None = None, + filter: Optional[Union[str, dict[str, Any]]] = None, return_fields: Optional[list[str]] = None, ) -> dict[str, Any]: """Execute `search-records` against the configured Redis index binding.""" @@ -372,7 +372,7 @@ async def search_records_tool( query: str, limit: Optional[int] = None, offset: int = 0, - filter: str | dict[str, Any] | None = None, + filter: Optional[Union[str, dict[str, Any]]] = None, return_fields: Optional[list[str]] = None, ): """FastMCP wrapper for the `search-records` tool.""" diff --git a/tests/integration/test_mcp/test_server_startup.py b/tests/integration/test_mcp/test_server_startup.py index 8ed235b2..953aa6df 100644 --- a/tests/integration/test_mcp/test_server_startup.py +++ b/tests/integration/test_mcp/test_server_startup.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Optional import pytest import yaml @@ -31,7 +32,7 @@ async def factory( *, index_name: str, storage_type: str = "hash", - vector_path: str | None = None, + vector_path: Optional[str] = None, ) -> AsyncSearchIndex: fields = [{"name": "content", "type": "text"}] vector_field = { @@ -79,9 +80,9 @@ def factory( *, redis_name: str, vector_dims: int = 3, - schema_overrides: dict | None = None, - runtime_overrides: dict | None = None, - search: dict | None = None, + schema_overrides: Optional[dict] = None, + runtime_overrides: Optional[dict] = None, + search: Optional[dict] = None, ) -> str: runtime = { "text_field_name": "content", diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index fa8367a9..0afc37fa 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -1,4 +1,5 @@ from types import SimpleNamespace +from typing import Optional import pytest @@ -35,7 +36,7 @@ def _schema() -> IndexSchema: ) -def _config_with_search(search_type: str, params: dict | None = None) -> MCPConfig: +def _config_with_search(search_type: str, params: Optional[dict] = None) -> MCPConfig: return MCPConfig.model_validate( { "server": {"redis_url": "redis://localhost:6379"}, @@ -74,7 +75,10 @@ async def query(self, query): class FakeServer: def __init__( - self, *, search_type: str = "vector", search_params: dict | None = None + self, + *, + search_type: str = "vector", + search_params: Optional[dict] = None, ): self.config = _config_with_search(search_type, search_params) self.mcp_settings = SimpleNamespace(tool_search_description=None) From ac3cc904e4a52f0fa9d55e8c3c9066d6fca363a7 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 17:28:28 +0100 Subject: [PATCH 7/9] Cache hybrid support checks and validate fallback search params --- redisvl/mcp/config.py | 29 +++++++++++--- redisvl/mcp/server.py | 12 +++++- redisvl/mcp/tools/search.py | 7 +++- tests/unit/test_mcp/test_config.py | 37 +++++++++++++++++ tests/unit/test_mcp/test_search_tool_unit.py | 16 +++++++- tests/unit/test_mcp/test_server_unit.py | 42 ++++++++++++++++++++ 6 files changed, 134 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_mcp/test_server_unit.py diff --git a/redisvl/mcp/config.py b/redisvl/mcp/config.py index af7104f5..e226986b 100644 --- a/redisvl/mcp/config.py +++ b/redisvl/mcp/config.py @@ -139,13 +139,30 @@ def validate_runtime_capabilities( self, *, supports_native_hybrid_search: bool ) -> None: """Fail startup when hybrid config depends on native-only FT.SEARCH params.""" - if ( - self.type == "hybrid" - and not supports_native_hybrid_search - and "knn_ef_runtime" in self.params - ): + if self.type != "hybrid" or supports_native_hybrid_search: + return + + unsupported_params = set() + if self.params.get("combination_method") not in (None, "LINEAR"): + unsupported_params.add("combination_method") + unsupported_params.update( + key + for key in ( + "vector_search_method", + "knn_ef_runtime", + "range_radius", + "range_epsilon", + "rrf_window", + "rrf_constant", + ) + if key in self.params + ) + + if unsupported_params: + unsupported_list = ", ".join(sorted(unsupported_params)) raise ValueError( - "search.params.knn_ef_runtime requires native hybrid search support" + "search.params requires native hybrid search support for: " + f"{unsupported_list}" ) diff --git a/redisvl/mcp/server.py b/redisvl/mcp/server.py index 4e07512c..fb0df041 100644 --- a/redisvl/mcp/server.py +++ b/redisvl/mcp/server.py @@ -42,6 +42,7 @@ def __init__(self, settings: MCPSettings): self.config: Optional[MCPConfig] = None self._index: Optional[AsyncSearchIndex] = None self._vectorizer: Optional[Any] = None + self._supports_native_hybrid_search: Optional[bool] = None self._semaphore: Optional[asyncio.Semaphore] = None self._tools_registered = False @@ -49,6 +50,7 @@ async def startup(self) -> None: """Load config, inspect the configured index, and initialize dependencies.""" self.config = load_mcp_config(self.mcp_settings.config) self._semaphore = asyncio.Semaphore(self.config.runtime.max_concurrency) + self._supports_native_hybrid_search = None timeout = self.config.runtime.startup_timeout_seconds client = None @@ -109,6 +111,7 @@ async def shutdown(self) -> None: elif callable(close): close() finally: + self._supports_native_hybrid_search = None if self._index is not None: index = self._index self._index = None @@ -164,17 +167,24 @@ def _validate_vectorizer_dims(self, schema: IndexSchema) -> None: async def supports_native_hybrid_search(self) -> bool: """Return whether the current runtime supports Redis native hybrid search.""" + if self._supports_native_hybrid_search is not None: + return self._supports_native_hybrid_search if self._index is None: raise RuntimeError("MCP server has not been started") if not is_version_gte(redis_py_version, "7.1.0"): + self._supports_native_hybrid_search = False return False client = await self._index._get_client() info = await client.info("server") if not is_version_gte(info.get("redis_version", "0.0.0"), "8.4.0"): + self._supports_native_hybrid_search = False return False - return hasattr(client.ft(self._index.schema.index.name), "hybrid_search") + self._supports_native_hybrid_search = hasattr( + client.ft(self._index.schema.index.name), "hybrid_search" + ) + return self._supports_native_hybrid_search def _register_tools(self) -> None: """Register MCP tools once the server is ready.""" diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index ae59c783..91ea8d2b 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -260,10 +260,15 @@ async def _build_query( } if "normalize_vector_distance" not in vector_kwargs: vector_kwargs["normalize_vector_distance"] = True + normalize_vector_distance = vector_kwargs["normalize_vector_distance"] return ( VectorQuery(**vector_kwargs), "vector_distance", - "vector_distance_normalized", + ( + "vector_distance_normalized" + if normalize_vector_distance + else "vector_distance" + ), search_type, ) diff --git a/tests/unit/test_mcp/test_config.py b/tests/unit/test_mcp/test_config.py index a524a52a..bd20ed2d 100644 --- a/tests/unit/test_mcp/test_config.py +++ b/tests/unit/test_mcp/test_config.py @@ -357,3 +357,40 @@ def test_mcp_config_normalizes_hybrid_linear_text_weight(): assert loaded.binding.search.type == "hybrid" assert loaded.binding.search.params["linear_text_weight"] == 0.3 + + +@pytest.mark.parametrize( + "params", + [ + {"knn_ef_runtime": 42}, + {"vector_search_method": "RANGE", "range_radius": 0.4}, + {"combination_method": "RRF", "rrf_window": 50}, + ], +) +def test_mcp_config_rejects_native_only_hybrid_runtime_params(params): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": params, + } + + loaded = MCPConfig.model_validate(config) + + with pytest.raises(ValueError, match="native hybrid search support"): + loaded.validate_search(supports_native_hybrid_search=False) + + +def test_mcp_config_allows_linear_hybrid_fallback_params(): + config = _valid_config() + config["indexes"]["knowledge"]["search"] = { + "type": "hybrid", + "params": { + "text_scorer": "TFIDF", + "combination_method": "LINEAR", + "linear_text_weight": 0.3, + }, + } + + loaded = MCPConfig.model_validate(config) + + loaded.validate_search(supports_native_hybrid_search=False) diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index 0afc37fa..4bc48358 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -213,7 +213,7 @@ async def fake_query(query): { "id": "doc:1", "score": 0.93, - "score_type": "vector_distance_normalized", + "score_type": "vector_distance", "record": { "content": "science doc", "category": "science", @@ -397,6 +397,20 @@ async def fake_query(query): assert response["results"][0]["score"] == 0.7 +@pytest.mark.asyncio +async def test_search_records_rejects_native_only_hybrid_runtime_params(monkeypatch): + server = FakeServer( + search_type="hybrid", + search_params={ + "combination_method": "RRF", + "rrf_window": 50, + }, + ) + + with pytest.raises(ValueError, match="native hybrid search support"): + server.config.validate_search(supports_native_hybrid_search=False) + + def test_register_search_tool_uses_default_and_override_descriptions(): default_server = FakeServer() register_search_tool(default_server) diff --git a/tests/unit/test_mcp/test_server_unit.py b/tests/unit/test_mcp/test_server_unit.py new file mode 100644 index 00000000..13ba23d5 --- /dev/null +++ b/tests/unit/test_mcp/test_server_unit.py @@ -0,0 +1,42 @@ +from types import SimpleNamespace + +import pytest + +from redisvl.mcp.server import RedisVLMCPServer + + +class FakeClient: + def __init__(self): + self.info_calls = 0 + + async def info(self, section: str): + self.info_calls += 1 + assert section == "server" + return {"redis_version": "8.4.0"} + + def ft(self, index_name: str): + assert index_name == "docs-index" + return SimpleNamespace(hybrid_search=object()) + + +class FakeIndex: + def __init__(self, client: FakeClient): + self.schema = SimpleNamespace(index=SimpleNamespace(name="docs-index")) + self._client = client + + async def _get_client(self): + return self._client + + +@pytest.mark.asyncio +async def test_supports_native_hybrid_search_caches_runtime_probe(monkeypatch): + client = FakeClient() + server = RedisVLMCPServer.__new__(RedisVLMCPServer) + server._index = FakeIndex(client) + server._supports_native_hybrid_search = None + + monkeypatch.setattr("redisvl.mcp.server.redis_py_version", "7.1.0") + + assert await server.supports_native_hybrid_search() is True + assert await server.supports_native_hybrid_search() is True + assert client.info_calls == 1 From 32f4a0758a644811215c6d5c4d92d3f624e8b991 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Wed, 25 Mar 2026 17:53:24 +0100 Subject: [PATCH 8/9] Classify malformed search results as internal errors --- redisvl/mcp/tools/search.py | 4 +- tests/unit/test_mcp/test_search_tool_unit.py | 40 ++++++++++++++++++++ 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index 91ea8d2b..0178a9c4 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -110,7 +110,7 @@ def _normalize_record( if score is None: raise RedisVLMCPError( f"Search result missing expected score field '{score_field}'", - code=MCPErrorCode.INVALID_REQUEST, + code=MCPErrorCode.INTERNAL_ERROR, retryable=False, ) @@ -123,7 +123,7 @@ def _normalize_record( if doc_id is None: raise RedisVLMCPError( "Search result missing id", - code=MCPErrorCode.INVALID_REQUEST, + code=MCPErrorCode.INTERNAL_ERROR, retryable=False, ) diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index 4bc48358..5d1b3ee2 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -170,6 +170,46 @@ async def test_search_records_rejects_unknown_or_vector_return_fields(): assert vector_exc.value.code == MCPErrorCode.INVALID_REQUEST +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("result", "message"), + [ + ( + { + "id": "doc:broken", + "content": "science doc", + "category": "science", + }, + "missing expected score field", + ), + ( + { + "content": "science doc", + "category": "science", + "vector_distance": "0.93", + }, + "missing id", + ), + ], +) +async def test_search_records_treats_malformed_backend_results_as_internal_errors( + result, message +): + server = FakeServer(search_type="vector") + + async def fake_query(query): + server.index.query_calls.append(query) + return [result] + + server.index.query = fake_query + + with pytest.raises(RedisVLMCPError, match=message) as exc_info: + await search_records(server, query="science") + + assert exc_info.value.code == MCPErrorCode.INTERNAL_ERROR + assert exc_info.value.retryable is False + + @pytest.mark.asyncio async def test_search_records_builds_vector_query_and_normalizes_results(monkeypatch): server = FakeServer( From c7b21542b5d22f6210c3933549c3e722beb1a862 Mon Sep 17 00:00:00 2001 From: Vishal Bala Date: Thu, 26 Mar 2026 11:00:58 +0100 Subject: [PATCH 9/9] Fix native hybrid linear defaults for RRF --- redisvl/mcp/tools/search.py | 15 ++++- tests/unit/test_mcp/test_search_tool_unit.py | 61 ++++++++++++++++++++ 2 files changed, 73 insertions(+), 3 deletions(-) diff --git a/redisvl/mcp/tools/search.py b/redisvl/mcp/tools/search.py index 0178a9c4..29da0496 100644 --- a/redisvl/mcp/tools/search.py +++ b/redisvl/mcp/tools/search.py @@ -176,10 +176,19 @@ def _build_native_hybrid_kwargs( search_params: dict[str, Any], ) -> dict[str, Any]: """Build native `HybridQuery` kwargs from MCP config-owned hybrid params.""" - params = {**_NATIVE_HYBRID_DEFAULTS, **search_params} - linear_text_weight = params.pop("linear_text_weight", None) - if linear_text_weight is not None: + params = dict(search_params) + combination_method = params.setdefault( + "combination_method", + _NATIVE_HYBRID_DEFAULTS["combination_method"], + ) + if combination_method == "LINEAR": + linear_text_weight = params.pop( + "linear_text_weight", + _NATIVE_HYBRID_DEFAULTS["linear_text_weight"], + ) params["linear_alpha"] = linear_text_weight + else: + params.pop("linear_text_weight", None) return { "text": query, diff --git a/tests/unit/test_mcp/test_search_tool_unit.py b/tests/unit/test_mcp/test_search_tool_unit.py index 5d1b3ee2..5bb49a0d 100644 --- a/tests/unit/test_mcp/test_search_tool_unit.py +++ b/tests/unit/test_mcp/test_search_tool_unit.py @@ -380,6 +380,67 @@ async def fake_query(query): assert response["results"][0]["score"] == 2.5 +@pytest.mark.asyncio +async def test_search_records_avoids_linear_defaults_for_rrf_native_hybrid_query( + monkeypatch, +): + server = FakeServer( + search_type="hybrid", + search_params={ + "combination_method": "RRF", + "rrf_window": 50, + }, + ) + server.native_hybrid_supported = True + built_queries = [] + + class FakePostProcessingConfig: + def __init__(self): + self.apply_calls = [] + + def apply(self, **kwargs): + self.apply_calls.append(kwargs) + + class FakeHybridQuery(FakeQuery): + def __init__(self, **kwargs): + self.postprocessing_config = FakePostProcessingConfig() + built_queries.append(("native", kwargs, self.postprocessing_config)) + super().__init__(**kwargs) + + class FakeAggregateHybridQuery(FakeQuery): + def __init__(self, **kwargs): + built_queries.append(("fallback", kwargs)) + super().__init__(**kwargs) + + async def fake_query(query): + server.index.query_calls.append(query) + return [ + { + "id": "doc:rrf", + "content": "hybrid doc", + "hybrid_score": "1.2", + } + ] + + monkeypatch.setattr("redisvl.mcp.tools.search.HybridQuery", FakeHybridQuery) + monkeypatch.setattr( + "redisvl.mcp.tools.search.AggregateHybridQuery", FakeAggregateHybridQuery + ) + server.index.query = fake_query + + response = await search_records(server, query="hybrid") + + assert built_queries[0][0] == "native" + assert built_queries[0][1]["combination_method"] == "RRF" + assert built_queries[0][1]["rrf_window"] == 50 + assert "linear_alpha" not in built_queries[0][1] + assert "linear_text_weight" not in built_queries[0][1] + assert built_queries[0][2].apply_calls == [{"__key": "@__key"}] + assert response["search_type"] == "hybrid" + assert response["results"][0]["score_type"] == "hybrid_score" + assert response["results"][0]["score"] == 1.2 + + @pytest.mark.asyncio async def test_search_records_builds_hybrid_query_for_fallback_runtime(monkeypatch): server = FakeServer(