Skip to content

Commit 71be6a8

Browse files
authored
Merge pull request #146 from jdebacker/marshmallow4
Marshmallow 4 upgrades
2 parents 81d19b4 + 58e2ad9 commit 71be6a8

8 files changed

Lines changed: 43 additions & 63 deletions

File tree

conda.recipe/meta.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ package:
55
requirements:
66
build:
77
- python
8-
- "marshmallow>=3.0.0rc4"
8+
- "marshmallow>=4.0.0"
99
- "numpy>=1.13"
1010
- "python-dateutil>=2.8.0"
1111

1212
run:
1313
- python
14-
- "marshmallow>=3.0.0rc4"
14+
- "marshmallow>=4.0.0"
1515
- "numpy>=1.13"
1616
- "python-dateutil>=2.8.0"
1717

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: paramtools-dev
22
channels:
33
- conda-forge
44
dependencies:
5-
- "marshmallow>=3.22.0"
5+
- "marshmallow>=4.0.0"
66
- "numpy>=2.1.0"
77
- "python-dateutil>=2.8.0"
88
- "pytest>=6.0.0"

paramtools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454

5555
name = "paramtools"
56-
__version__ = "0.19.0"
56+
__version__ = "0.20.0"
5757

5858
__all__ = [
5959
"SchemaFactory",

paramtools/parameters.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def __init__(
101101
else:
102102
self._stateless_label_grid[name] = []
103103
self.label_grid = copy.deepcopy(self._stateless_label_grid)
104-
self._validator_schema.context["spec"] = self
104+
self._validator_schema.pt_context["spec"] = self
105105
self._warnings = {}
106106
self._errors = {}
107107
self._defer_validation = False
@@ -364,7 +364,7 @@ def _adjust(
364364
for param, value in parsed_params.items():
365365
self._update_param(param, value)
366366

367-
self._validator_schema.context["spec"] = self
367+
self._validator_schema.pt_context["spec"] = self
368368

369369
has_errors = bool(self._errors.get("messages"))
370370
has_warnings = bool(self._warnings.get("messages"))
@@ -525,7 +525,7 @@ def _delete(
525525
if self.label_to_extend is not None and extend_adj:
526526
self.extend()
527527

528-
self._validator_schema.context["spec"] = self
528+
self._validator_schema.pt_context["spec"] = self
529529

530530
has_errors = bool(self._errors.get("messages"))
531531
has_warnings = bool(self._warnings.get("messages"))
@@ -1414,4 +1414,4 @@ def get_defaults(self):
14141414
- `params`: String if URL or file path. Dict if this is the loaded params
14151415
dict.
14161416
"""
1417-
return utils.read_json(self.defaults)
1417+
return utils.read_json(self.defaults)

paramtools/schema.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
validates_schema,
88
ValidationError as MarshmallowValidationError,
99
decorators,
10+
RAISE as RAISEUNKNOWNOPTION,
1011
)
1112
from marshmallow.error_store import ErrorStore
1213

@@ -28,14 +29,14 @@ class RangeSchema(Schema):
2829
}
2930
"""
3031

31-
_min = fields.Field(attribute="min", data_key="min")
32-
_max = fields.Field(attribute="max", data_key="max")
33-
step = fields.Field()
32+
_min = fields.Raw(attribute="min", data_key="min")
33+
_max = fields.Raw(attribute="max", data_key="max")
34+
step = fields.Raw()
3435
level = fields.String(validate=[validate.OneOf(["warn", "error"])])
3536

3637

3738
class ChoiceSchema(Schema):
38-
choices = fields.List(fields.Field)
39+
choices = fields.List(fields.Raw)
3940
level = fields.String(validate=[validate.OneOf(["warn", "error"])])
4041

4142

@@ -53,9 +54,9 @@ class ValueValidatorSchema(Schema):
5354

5455

5556
class IsSchema(Schema):
56-
equal_to = fields.Field(required=False)
57-
greater_than = fields.Field(required=False)
58-
less_than = fields.Field(required=False)
57+
equal_to = fields.Raw(required=False)
58+
greater_than = fields.Raw(required=False)
59+
less_than = fields.Raw(required=False)
5960

6061
@validates_schema
6162
def just_one(self, data, **kwargs):
@@ -107,15 +108,12 @@ class BaseParamSchema(Schema):
107108
data_key="type",
108109
)
109110
number_dims = fields.Integer(required=False, load_default=0)
110-
value = fields.Field(required=True) # will be specified later
111+
value = fields.Raw(required=True) # will be specified later
111112
validators = fields.Nested(
112113
ValueValidatorSchema(), required=False, load_default={}
113114
)
114115
indexed = fields.Boolean(required=False)
115116

116-
class Meta:
117-
ordered = True
118-
119117

120118
class EmptySchema(Schema):
121119
"""
@@ -126,15 +124,6 @@ class EmptySchema(Schema):
126124
pass
127125

128126

129-
class OrderedSchema(Schema):
130-
"""
131-
Same as `EmptySchema`, but preserves the order of its fields.
132-
"""
133-
134-
class Meta:
135-
ordered = True
136-
137-
138127
class ValueObject(fields.Nested):
139128
"""
140129
Schema for value objects
@@ -182,16 +171,17 @@ class BaseValidatorSchema(Schema):
182171
class.
183172
"""
184173

185-
class Meta:
186-
ordered = True
187-
188174
WRAPPER_MAP = {
189175
"range": "_get_range_validator",
190176
"date_range": "_get_range_validator",
191177
"choice": "_get_choice_validator",
192178
"when": "_get_when_validator",
193179
}
194180

181+
def __init__(self, *args, **kwargs):
182+
self.pt_context = {}
183+
super().__init__(*args, **kwargs)
184+
195185
def validate_only(self, data):
196186
"""
197187
Bypass deserialization and just run field validators. This is taken
@@ -208,21 +198,23 @@ def validate_only(self, data):
208198
field_errors = bool(error_store.errors)
209199
self._invoke_schema_validators(
210200
error_store=error_store,
211-
pass_many=True,
201+
pass_collection=True,
212202
data=data,
213203
original_data=data,
214204
many=None,
215205
partial=None,
216206
field_errors=field_errors,
207+
unknown=RAISEUNKNOWNOPTION,
217208
)
218209
self._invoke_schema_validators(
219210
error_store=error_store,
220-
pass_many=False,
211+
pass_collection=False,
221212
data=data,
222213
original_data=data,
223214
many=None,
224215
partial=None,
225216
field_errors=field_errors,
217+
unknown=RAISEUNKNOWNOPTION,
226218
)
227219
errors = error_store.errors
228220
if errors:
@@ -271,7 +263,7 @@ def validate_param(self, param_name, param_spec, raw_data):
271263
Do range validation for a parameter.
272264
"""
273265
validate_schema = not getattr(
274-
self.context["spec"], "_defer_validation", False
266+
self.pt_context["spec"], "_defer_validation", False
275267
)
276268
validators = self.validators(
277269
param_name, param_spec, raw_data, validate_schema=validate_schema
@@ -290,15 +282,15 @@ def validate_param(self, param_name, param_spec, raw_data):
290282
return warnings, errors
291283

292284
def field_keyfunc(self, param_name):
293-
data = self.context["spec"]._data[param_name]
285+
data = self.pt_context["spec"]._data[param_name]
294286
field = get_type(data, self.validators(param_name))
295287
try:
296288
return field.cmp_funcs()["key"]
297289
except AttributeError:
298290
return None
299291

300292
def field(self, param_name):
301-
data = self.context["spec"]._data[param_name]
293+
data = self.pt_context["spec"]._data[param_name]
302294
return get_type(data, self.validators(param_name))
303295

304296
def validators(
@@ -309,7 +301,7 @@ def validators(
309301
if raw_data is None:
310302
raw_data = {}
311303

312-
param_info = self.context["spec"]._data[param_name]
304+
param_info = self.pt_context["spec"]._data[param_name]
313305
# sort keys to guarantee order.
314306
validator_spec = param_info.get("validators", {})
315307
validators = []
@@ -347,7 +339,7 @@ def _get_when_validator(
347339
when_param = when_dict["param"]
348340

349341
if (
350-
when_param not in self.context["spec"]._data.keys()
342+
when_param not in self.pt_context["spec"]._data.keys()
351343
and when_param != "default"
352344
):
353345
raise MarshmallowValidationError(
@@ -382,8 +374,8 @@ def _get_when_validator(
382374
)
383375
)
384376

385-
_type = self.context["spec"]._data[oth_param]["type"]
386-
number_dims = self.context["spec"]._data[oth_param]["number_dims"]
377+
_type = self.pt_context["spec"]._data[oth_param]["type"]
378+
number_dims = self.pt_context["spec"]._data[oth_param]["number_dims"]
387379

388380
error_then = (
389381
f"When {oth_param}{{when_labels}}{{ix}} is {{is_val}}, "
@@ -469,9 +461,9 @@ def _get_range_validator(
469461
)
470462

471463
def _sort_by_label_to_extend(self, vos):
472-
label_to_extend = self.context["spec"].label_to_extend
464+
label_to_extend = self.pt_context["spec"].label_to_extend
473465
if label_to_extend is not None:
474-
label_grid = self.context["spec"]._stateless_label_grid
466+
label_grid = self.pt_context["spec"]._stateless_label_grid
475467
extend_vals = label_grid[label_to_extend]
476468
return sorted(
477469
vos,
@@ -533,9 +525,9 @@ def _get_related_value(
533525
# If comparing against the "default" value then get the current
534526
# value of the parameter being updated.
535527
if oth_param_name == "default":
536-
oth_param = self.context["spec"]._data[param_name]
528+
oth_param = self.pt_context["spec"]._data[param_name]
537529
else:
538-
oth_param = self.context["spec"]._data[oth_param_name]
530+
oth_param = self.pt_context["spec"]._data[oth_param_name]
539531
vals = oth_param["value"]
540532
labs_to_check = {k for k in param_spec if k not in ("value", "_auto")}
541533
if labs_to_check:
@@ -560,11 +552,11 @@ def _check_ndim_restriction(
560552
if other_param is None:
561553
continue
562554
if other_param == "default":
563-
ndims = self.context["spec"]._data[param_name][
555+
ndims = self.pt_context["spec"]._data[param_name][
564556
"number_dims"
565557
]
566558
else:
567-
ndims = self.context["spec"]._data[other_param][
559+
ndims = self.pt_context["spec"]._data[other_param][
568560
"number_dims"
569561
]
570562
if ndims > 0:

paramtools/schema_factory.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from marshmallow import fields
1+
from marshmallow import fields, Schema
22

33
from paramtools.schema import (
4-
OrderedSchema,
54
BaseValidatorSchema,
65
ValueObject,
76
get_type,
@@ -67,17 +66,15 @@ def schemas(self):
6766
# if not isinstance(v["value"], list):
6867
# v["value"] = [{"value": v["value"]}]
6968

70-
validator_dict[k] = type(
71-
"ValidatorItem", (OrderedSchema,), classattrs
72-
)
69+
validator_dict[k] = type("ValidatorItem", (Schema,), classattrs)
7370

7471
classattrs = {"value": ValueObject(validator_dict[k], many=True)}
7572
param_dict[k] = type(
7673
"IndividualParamSchema", (self.BaseParamSchema,), classattrs
7774
)
7875

7976
classattrs = {k: fields.Nested(v) for k, v in param_dict.items()}
80-
DefaultsSchema = type("DefaultsSchema", (OrderedSchema,), classattrs)
77+
DefaultsSchema = type("DefaultsSchema", (Schema,), classattrs)
8178
defaults_schema = DefaultsSchema()
8279

8380
classattrs = {

paramtools/tests/test_parameters.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,6 @@ def get_defaults(self):
166166
assert params.hello_world == "hello world"
167167
assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]}
168168

169-
170169
def test_schema_not_dropped(self, defaults_spec_path):
171170
with open(defaults_spec_path, "r") as f:
172171
defaults_ = json.loads(f.read())
@@ -379,14 +378,6 @@ def test_specification(self, TestParams, defaults_spec_path):
379378

380379
assert spec1["min_int_param"] == exp["min_int_param"]["value"]
381380

382-
def test_is_ordered(self, TestParams):
383-
params = TestParams()
384-
spec1 = params.specification()
385-
assert isinstance(spec1, OrderedDict)
386-
387-
spec2 = params.specification(meta_data=True, serializable=True)
388-
assert isinstance(spec2, OrderedDict)
389-
390381
def test_specification_query(self, TestParams):
391382
params = TestParams()
392383
spec1 = params.specification()

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
setuptools.setup(
88
name="paramtools",
9-
version=os.environ.get("VERSION", "0.19.0"),
9+
version=os.environ.get("VERSION", "0.20.0"),
1010
author="Hank Doupe",
1111
author_email="henrymdoupe@gmail.com",
1212
description=(
@@ -18,7 +18,7 @@
1818
url="https://github.com/hdoupe/ParamTools",
1919
packages=setuptools.find_packages(),
2020
install_requires=[
21-
"marshmallow>=3.0.0",
21+
"marshmallow>=4.0.0",
2222
"numpy",
2323
"python-dateutil>=2.8.0",
2424
"fsspec",

0 commit comments

Comments
 (0)