66import re
77import threading
88from collections .abc import Mapping
9- from functools import reduce
10- from typing import Any , Literal , Optional , Union
9+ from typing import Annotated , Any , Literal , Optional , Union , get_args , get_origin
1110
1211import msgspec
1312import voluptuous
@@ -70,11 +69,37 @@ def validate_schema(schema, obj, msg_prefix):
7069 raise Exception (f"{ msg_prefix } \n { str (exc )} \n { pprint .pformat (obj )} " )
7170
7271
73- def UnionTypes (* types ):
74- """Use `functools.reduce` to simulate `Union[*allowed_types]` on older
75- Python versions.
76- """
77- return reduce (lambda a , b : Union [a , b ], types )
72+ class OptionallyKeyedBy :
73+ """Metadata class for optionally_keyed_by fields in msgspec schemas."""
74+
75+ def __init__ (self , * fields , wrapped_type ):
76+ self .fields = {f"by-{ field } " for field in fields }
77+ self .wrapped_type = wrapped_type
78+
79+ def uses_keyed_by (self , obj ) -> bool :
80+ if not isinstance (obj , dict ) or len (obj ) != 1 :
81+ return False
82+
83+ key = list (obj )[0 ]
84+ if key not in self .fields :
85+ return False
86+
87+ return True
88+
89+ def validate (self , obj ) -> None :
90+ if not self .uses_keyed_by (obj ):
91+ # Not using keyed by, validate directly against wrapped type
92+ msgspec .convert (obj , self .wrapped_type )
93+ return
94+
95+ # First validate the outer keyed-by dict
96+ msgspec .convert (obj , dict [str , dict ])
97+
98+ # Next validate each inner value. We call self.validate recursively to
99+ # support nested `by-*` keys.
100+ keyed_by_dict = list (obj .values ())[0 ]
101+ for value in keyed_by_dict .values ():
102+ self .validate (value )
78103
79104
80105def optionally_keyed_by (* arguments , use_msgspec = False ):
@@ -86,13 +111,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
86111 use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
87112 """
88113 if use_msgspec :
89- # msgspec implementation - return type hints
114+ # msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
90115 _type = arguments [- 1 ]
91116 if _type is object :
92117 return object
93118 fields = arguments [:- 1 ]
94- bykeys = [Literal [f"by-{ field } " ] for field in fields ]
95- return Union [_type , dict [UnionTypes (* bykeys ), dict [str , Any ]]]
119+ wrapper = OptionallyKeyedBy (* fields , wrapped_type = _type )
120+ # Annotating Any allows msgspec to accept any value without validation.
121+ # The actual validation then happens in Schema.__post_init__
122+ return Annotated [Any , wrapper ]
96123 else :
97124 # voluptuous implementation - return validator function
98125 schema = arguments [- 1 ]
@@ -318,6 +345,31 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
318345 foo: str
319346 """
320347
348+ def __post_init__ (self ):
349+ if taskgraph .fast :
350+ return
351+
352+ # Validate fields that use optionally_keyed_by. We need to validate this
353+ # manually because msgspec doesn't support union types with multiple
354+ # dicts. Any fields that use `optionally_keyed_by("foo", dict)` would
355+ # otherwise raise an exception.
356+ for field_name , field_type in self .__class__ .__annotations__ .items ():
357+ origin = get_origin (field_type )
358+ args = get_args (field_type )
359+
360+ if (
361+ origin is not Annotated
362+ or len (args ) < 2
363+ or not isinstance (args [1 ], OptionallyKeyedBy )
364+ ):
365+ # Not using `optionally_keyed_by`
366+ continue
367+
368+ keyed_by = args [1 ]
369+ obj = getattr (self , field_name )
370+
371+ keyed_by .validate (obj )
372+
321373 @classmethod
322374 def validate (cls , data ):
323375 """Validate data against this schema."""
0 commit comments