From 0277857f4e6669484948d1c14d54add7ba8b7a13 Mon Sep 17 00:00:00 2001 From: Sheldon Date: Sun, 7 Jun 2026 12:01:53 +0800 Subject: [PATCH] Fix tenant sequence validation --- test/collection/test_tenant_validation.py | 47 +++++++++++++++++++++++ weaviate/validator.py | 10 +++-- 2 files changed, 54 insertions(+), 3 deletions(-) create mode 100644 test/collection/test_tenant_validation.py diff --git a/test/collection/test_tenant_validation.py b/test/collection/test_tenant_validation.py new file mode 100644 index 000000000..509e7f58e --- /dev/null +++ b/test/collection/test_tenant_validation.py @@ -0,0 +1,47 @@ +from typing import Sequence, Union + +import pytest + +from weaviate.collections.classes.tenants import Tenant, TenantCreate, TenantUpdate +from weaviate.exceptions import WeaviateInvalidInputError +from weaviate.validator import _validate_input, _ValidateArgument + + +def test_tenant_create_sequence_validation_rejects_mixed_invalid_values() -> None: + with pytest.raises(WeaviateInvalidInputError): + _validate_input( + _ValidateArgument( + expected=[Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]], + name="tenants", + value=[Tenant(name="tenant-a"), object()], + ) + ) + + +def test_tenant_update_sequence_validation_rejects_mixed_invalid_values() -> None: + with pytest.raises(WeaviateInvalidInputError): + _validate_input( + _ValidateArgument( + expected=[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]], + name="tenants", + value=[Tenant(name="tenant-a"), object()], + ) + ) + + +def test_tenant_sequence_validation_accepts_valid_values() -> None: + _validate_input( + _ValidateArgument( + expected=[Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]], + name="tenants", + value=["tenant-a", Tenant(name="tenant-b"), TenantCreate(name="tenant-c")], + ) + ) + + _validate_input( + _ValidateArgument( + expected=[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]], + name="tenants", + value=[Tenant(name="tenant-a"), TenantUpdate(name="tenant-b")], + ) + ) diff --git a/weaviate/validator.py b/weaviate/validator.py index 7fe11945c..148275225 100644 --- a/weaviate/validator.py +++ b/weaviate/validator.py @@ -35,6 +35,8 @@ def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) - def _is_valid(expected: Any, value: Any) -> bool: + if expected is Any: + return True if expected is None: return value is None @@ -46,7 +48,7 @@ def _is_valid(expected: Any, value: Any) -> bool: expected_origin = get_origin(expected) if expected_origin is Union: args = get_args(expected) - return any(isinstance(value, arg) for arg in args) + return any(_is_valid(arg, value) for arg in args) if expected_origin is not None and ( issubclass(expected_origin, Sequence) or expected_origin is list ): @@ -56,7 +58,9 @@ def _is_valid(expected: Any, value: Any) -> bool: if len(args) == 1: if get_origin(args[0]) is Union: union_args = get_args(args[0]) - return any(isinstance(val, union_arg) for val in value for union_arg in union_args) + return len(value) > 0 and all( + any(_is_valid(union_arg, val) for union_arg in union_args) for val in value + ) else: - return all(isinstance(val, args[0]) for val in value) + return all(_is_valid(args[0], val) for val in value) return isinstance(value, expected)