Skip to content

Commit f6ca4bc

Browse files
committed
fix: support optionally_keyed_by with underlying dict
1 parent 109c700 commit f6ca4bc

File tree

2 files changed

+81
-14
lines changed

2 files changed

+81
-14
lines changed

src/taskgraph/util/schema.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
import re
77
import threading
88
from collections.abc import Mapping
9-
from functools import reduce
10-
from typing import Any, Literal, Optional, Union
9+
from typing import Annotated, Any, Literal, Optional, Union, get_args, get_origin
1110

1211
import msgspec
1312
import voluptuous
@@ -70,11 +69,37 @@ def validate_schema(schema, obj, msg_prefix):
7069
raise Exception(f"{msg_prefix}\n{str(exc)}\n{pprint.pformat(obj)}")
7170

7271

73-
def UnionTypes(*types):
74-
"""Use `functools.reduce` to simulate `Union[*allowed_types]` on older
75-
Python versions.
76-
"""
77-
return reduce(lambda a, b: Union[a, b], types)
72+
class OptionallyKeyedBy:
73+
"""Metadata class for optionally_keyed_by fields in msgspec schemas."""
74+
75+
def __init__(self, *fields, wrapped_type):
76+
self.fields = {f"by-{field}" for field in fields}
77+
self.wrapped_type = wrapped_type
78+
79+
def uses_keyed_by(self, obj) -> bool:
80+
if not isinstance(obj, dict) or len(obj) != 1:
81+
return False
82+
83+
key = list(obj)[0]
84+
if key not in self.fields:
85+
return False
86+
87+
return True
88+
89+
def validate(self, obj) -> None:
90+
if not self.uses_keyed_by(obj):
91+
# Not using keyed by, validate directly against wrapped type
92+
msgspec.convert(obj, self.wrapped_type)
93+
return
94+
95+
# First validate the outer keyed-by dict
96+
msgspec.convert(obj, dict[str, dict])
97+
98+
# Next validate each inner value. We call self.validate recursively to
99+
# support nested `by-*` keys.
100+
keyed_by_dict = list(obj.values())[0]
101+
for value in keyed_by_dict.values():
102+
self.validate(value)
78103

79104

80105
def optionally_keyed_by(*arguments, use_msgspec=False):
@@ -86,13 +111,15 @@ def optionally_keyed_by(*arguments, use_msgspec=False):
86111
use_msgspec: If True, return msgspec type hints; if False, return voluptuous validator
87112
"""
88113
if use_msgspec:
89-
# msgspec implementation - return type hints
114+
# msgspec implementation - use Annotated[Any, OptionallyKeyedBy]
90115
_type = arguments[-1]
91116
if _type is object:
92117
return object
93118
fields = arguments[:-1]
94-
bykeys = [Literal[f"by-{field}"] for field in fields]
95-
return Union[_type, dict[UnionTypes(*bykeys), dict[str, Any]]]
119+
wrapper = OptionallyKeyedBy(*fields, wrapped_type=_type)
120+
# Annotating Any allows msgspec to accept any value without validation.
121+
# The actual validation then happens in Schema.__post_init__
122+
return Annotated[Any, wrapper]
96123
else:
97124
# voluptuous implementation - return validator function
98125
schema = arguments[-1]
@@ -318,6 +345,31 @@ class MySchema(Schema, forbid_unknown_fields=False, kw_only=True):
318345
foo: str
319346
"""
320347

348+
def __post_init__(self):
349+
if taskgraph.fast:
350+
return
351+
352+
# Validate fields that use optionally_keyed_by. We need to validate this
353+
# manually because msgspec doesn't support union types with multiple
354+
# dicts. Any fields that use `optionally_keyed_by("foo", dict)` would
355+
# otherwise raise an exception.
356+
for field_name, field_type in self.__class__.__annotations__.items():
357+
origin = get_origin(field_type)
358+
args = get_args(field_type)
359+
360+
if (
361+
origin is not Annotated
362+
or len(args) < 2
363+
or not isinstance(args[1], OptionallyKeyedBy)
364+
):
365+
# Not using `optionally_keyed_by`
366+
continue
367+
368+
keyed_by = args[1]
369+
obj = getattr(self, field_name)
370+
371+
keyed_by.validate(obj)
372+
321373
@classmethod
322374
def validate(cls, data):
323375
"""Validate data against this schema."""

test/test_util_schema.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,24 +290,37 @@ class TestSchema(Schema):
290290
TestSchema.validate({"field": "baz"})
291291
TestSchema.validate({"field": {"by-foo": {"a": "b", "c": "d"}}})
292292

293-
# Inner dict values are Any, so mixed types are accepted
294-
TestSchema.validate({"field": {"by-foo": {"a": 1, "c": "d"}}})
293+
with pytest.raises(msgspec.ValidationError):
294+
TestSchema.validate({"field": 1})
295+
296+
with pytest.raises(msgspec.ValidationError):
297+
TestSchema.validate({"field": {"by-bar": "a"}})
298+
299+
with pytest.raises(msgspec.ValidationError):
300+
TestSchema.validate({"field": {"by-bar": {1: "b"}}})
295301

296302
with pytest.raises(msgspec.ValidationError):
297303
TestSchema.validate({"field": {"by-bar": {"a": "b"}}})
298304

305+
with pytest.raises(msgspec.ValidationError):
306+
TestSchema.validate({"field": {"by-foo": {"a": 1, "c": "d"}}})
299307

300-
def test_optionally_keyed_by_mulitple_keys():
308+
309+
def test_optionally_keyed_by_multiple_keys():
301310
class TestSchema(Schema):
302311
field: optionally_keyed_by("foo", "bar", str, use_msgspec=True) # type: ignore
303312

304313
TestSchema.validate({"field": {"by-foo": {"a": "b"}}})
305314
TestSchema.validate({"field": {"by-bar": {"x": "y"}}})
315+
TestSchema.validate({"field": {"by-foo": {"a": {"by-bar": {"x": "y"}}}}})
306316

307317
# Test invalid keyed-by field
308318
with pytest.raises(msgspec.ValidationError):
309319
TestSchema.validate({"field": {"by-unknown": {"a": "b"}}})
310320

321+
with pytest.raises(msgspec.ValidationError):
322+
TestSchema.validate({"field": {"by-foo": {"a": {"by-bar": {"x": 1}}}}})
323+
311324

312325
def test_optionally_keyed_by_object_passthrough():
313326
"""When the type argument is `object`, optionally_keyed_by returns object directly."""
@@ -320,14 +333,16 @@ def test_optionally_keyed_by_object_passthrough():
320333
assert msgspec.convert({"arbitrary": "dict"}, typ) == {"arbitrary": "dict"}
321334

322335

323-
@pytest.mark.xfail
324336
def test_optionally_keyed_by_dict():
325337
class TestSchema(Schema):
326338
field: optionally_keyed_by("foo", dict[str, str], use_msgspec=True) # type: ignore
327339

328340
TestSchema.validate({"field": {"by-foo": {"a": {"x": "y"}}}})
329341
TestSchema.validate({"field": {"a": "b"}})
330342

343+
with pytest.raises(msgspec.ValidationError):
344+
TestSchema.validate({"field": {"a": 1}})
345+
331346
with pytest.raises(msgspec.ValidationError):
332347
TestSchema.validate({"field": {"by-foo": {"a": {"x": 1}}}})
333348

0 commit comments

Comments
 (0)