Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions test/collection/test_tenant_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Sequence, Union

import pytest

from weaviate.collections.classes.tenants import Tenant, TenantCreate, TenantUpdate
from weaviate.exceptions import WeaviateInvalidInputError
from weaviate.validator import _validate_input, _ValidateArgument


def test_tenant_create_sequence_validation_rejects_mixed_invalid_values() -> None:
with pytest.raises(WeaviateInvalidInputError):
_validate_input(
_ValidateArgument(
expected=[Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]],
name="tenants",
value=[Tenant(name="tenant-a"), object()],
)
)


def test_tenant_update_sequence_validation_rejects_mixed_invalid_values() -> None:
with pytest.raises(WeaviateInvalidInputError):
_validate_input(
_ValidateArgument(
expected=[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]],
name="tenants",
value=[Tenant(name="tenant-a"), object()],
)
)


def test_tenant_sequence_validation_accepts_valid_values() -> None:
_validate_input(
_ValidateArgument(
expected=[Tenant, TenantCreate, Sequence[Union[str, Tenant, TenantCreate]]],
name="tenants",
value=["tenant-a", Tenant(name="tenant-b"), TenantCreate(name="tenant-c")],
)
)

_validate_input(
_ValidateArgument(
expected=[Tenant, TenantUpdate, Sequence[Union[Tenant, TenantUpdate]]],
name="tenants",
value=[Tenant(name="tenant-a"), TenantUpdate(name="tenant-b")],
)
)
10 changes: 7 additions & 3 deletions weaviate/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def _validate_input(inputs: Union[List[_ValidateArgument], _ValidateArgument]) -


def _is_valid(expected: Any, value: Any) -> bool:
if expected is Any:
return True
if expected is None:
return value is None

Expand All @@ -46,7 +48,7 @@ def _is_valid(expected: Any, value: Any) -> bool:
expected_origin = get_origin(expected)
if expected_origin is Union:
args = get_args(expected)
return any(isinstance(value, arg) for arg in args)
return any(_is_valid(arg, value) for arg in args)
if expected_origin is not None and (
issubclass(expected_origin, Sequence) or expected_origin is list
):
Expand All @@ -56,7 +58,9 @@ def _is_valid(expected: Any, value: Any) -> bool:
if len(args) == 1:
if get_origin(args[0]) is Union:
union_args = get_args(args[0])
return any(isinstance(val, union_arg) for val in value for union_arg in union_args)
return len(value) > 0 and all(
any(_is_valid(union_arg, val) for union_arg in union_args) for val in value
)
else:
return all(isinstance(val, args[0]) for val in value)
return all(_is_valid(args[0], val) for val in value)
return isinstance(value, expected)