-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy path parameter_scale.py
More file actions
189 lines (169 loc) · 6.74 KB
/
parameter_scale.py
File metadata and controls
189 lines (169 loc) · 6.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import copy
import os
import typing
from typing import Any, Iterable
from collections import OrderedDict
from policyengine_core import commons, parameters, tools
from policyengine_core.errors import ParameterParsingError
from policyengine_core.parameters import AtInstantLike, config, helpers
from policyengine_core.periods.instant_ import Instant
from policyengine_core.taxscales import (
LinearAverageRateTaxScale,
MarginalAmountTaxScale,
MarginalRateTaxScale,
SingleAmountTaxScale,
TaxScaleLike,
)
class ParameterScale(AtInstantLike):
    """
    A parameter scale (for instance a marginal scale).

    The scale holds an ordered list of ``ParameterScaleBracket`` objects built
    from the ``brackets`` entry of its data. Evaluating it at an instant
    (see ``_get_at_instant``) yields a tax-scale object whose class is chosen
    from ``metadata["type"]`` and from which children the brackets define.
    """

    # 'unit' and 'reference' are only listed here for backward compatibility
    _allowed_keys = config.COMMON_KEYS.union({"brackets"})
    # The keys to be excluded from the node when output to a yaml file.
    _exclusion_list = ["parent", "_at_instant_cache"]

    def __init__(self, name: str, data: dict, file_path: str):
        """
        :param name: name of the scale, eg "taxes.some_scale"
        :param data: Data loaded from a YAML file. In case of a reform, the data can also be created dynamically.
        :param file_path: File the parameter was loaded from.
        :raises ParameterParsingError: if ``data`` has a ``brackets`` entry
            that is not a list.
        """
        self.name: str = name
        self.file_path: str = file_path
        helpers._validate_parameter(
            self, data, data_type=dict, allowed_keys=self._allowed_keys
        )
        self.description: typing.Optional[str] = data.get("description")
        self.metadata: typing.Dict = {}
        self.metadata.update(data.get("metadata", {}))

        brackets_data = data.get("brackets", [])
        if not isinstance(brackets_data, list):
            raise ParameterParsingError(
                "Property 'brackets' of scale '{}' must be of type array.".format(
                    self.name
                ),
                self.file_path,
            )
        self.brackets: typing.List[parameters.ParameterScaleBracket] = [
            parameters.ParameterScaleBracket(
                # Each bracket is named from the scale name and its position.
                name=helpers._compose_name(name, item_name=i),
                data=bracket_data,
                file_path=file_path,
            )
            for i, bracket_data in enumerate(brackets_data)
        ]
        # Push scale-level metadata (uprating, units) down to the brackets.
        self.propagate_uprating()
        self.propagate_units()

    def __getitem__(self, key: int) -> Any:
        """
        Return the bracket at integer position ``key``.

        Negative positions count from the end, as for a list.

        :raises KeyError: if ``key`` is not an int or is out of range (in
            either direction), so callers only ever have to catch KeyError.
        """
        size = len(self.brackets)
        if isinstance(key, int) and -size <= key < size:
            return self.brackets[key]
        raise KeyError(key)

    def __repr__(self) -> str:
        """Render as a ``brackets:`` header followed by one indented, dash-prefixed entry per bracket."""
        return os.linesep.join(
            ["brackets:"]
            + [
                tools.indent("-" + tools.indent(repr(bracket))[1:])
                for bracket in self.brackets
            ]
        )

    def propagate_units(self) -> None:
        """
        Copy scale-level unit metadata onto the matching bracket children.

        For every allowed unit key present in ``self.metadata``, the unit is
        written to ``metadata["unit"]`` of the same-named child in each
        bracket, unless that child already declares a ``unit``.
        """
        unit_keys = [
            key
            for key in parameters.ParameterScaleBracket.allowed_unit_keys()
            if key in self.metadata
        ]
        for unit_key in unit_keys:
            # Drop the last 5 characters of the key -- presumably a "_unit"
            # suffix (e.g. "rate_unit" -> "rate"); depends on allowed_unit_keys().
            child_key = unit_key[:-5]
            for bracket in self.brackets:
                child = bracket.children.get(child_key)
                if child is not None and "unit" not in child.metadata:
                    child.metadata["unit"] = self.metadata[unit_key]

    def propagate_uprating(self) -> None:
        """Pass this scale's ``uprating`` metadata and ``uprate_thresholds`` flag (default False) to every bracket."""
        for bracket in self.brackets:
            bracket.propagate_uprating(
                self.metadata.get("uprating"),
                threshold=self.metadata.get("uprate_thresholds", False),
            )

    def get_descendants(self) -> Iterable:
        """Yield each bracket, followed immediately by that bracket's own descendants."""
        for bracket in self.brackets:
            yield bracket
            yield from bracket.get_descendants()

    def clone(self) -> "ParameterScale":
        """
        Return a copy of this scale.

        Brackets are cloned individually and ``metadata`` is deep-copied; all
        other attributes are shared shallowly with the original.
        """
        clone = commons.empty_clone(self)
        clone.__dict__ = self.__dict__.copy()
        clone.brackets = [bracket.clone() for bracket in self.brackets]
        clone.metadata = copy.deepcopy(self.metadata)
        return clone

    def _get_at_instant(self, instant: Instant) -> TaxScaleLike:
        """
        Evaluate every bracket at ``instant`` and assemble a tax scale.

        The scale class is selected in this order:

        - ``metadata["type"] == "single_amount"`` -> SingleAmountTaxScale,
          filled from ``amount``;
        - any bracket defines ``amount`` -> MarginalAmountTaxScale, filled
          from ``amount``;
        - any bracket defines ``average_rate`` -> LinearAverageRateTaxScale,
          filled from ``average_rate * base``;
        - otherwise -> MarginalRateTaxScale, filled from ``rate * base``.

        ``base`` defaults to 1.0 when a bracket does not define it. Brackets
        missing either the value child or ``threshold`` are skipped.
        """
        brackets = [
            bracket.get_at_instant(instant) for bracket in self.brackets
        ]
        if self.metadata.get("type") == "single_amount":
            return self._fill_scale(SingleAmountTaxScale(), brackets, "amount")
        if any("amount" in bracket._children for bracket in brackets):
            return self._fill_scale(
                MarginalAmountTaxScale(), brackets, "amount"
            )
        if any("average_rate" in bracket._children for bracket in brackets):
            return self._fill_scale(
                LinearAverageRateTaxScale(),
                brackets,
                "average_rate",
                scale_by_base=True,
            )
        return self._fill_scale(
            MarginalRateTaxScale(), brackets, "rate", scale_by_base=True
        )

    @staticmethod
    def _fill_scale(
        scale: TaxScaleLike,
        brackets: typing.List[Any],
        value_key: str,
        scale_by_base: bool = False,
    ) -> TaxScaleLike:
        """
        Add ``(threshold, value)`` to ``scale`` for each evaluated bracket.

        Only brackets that define both ``value_key`` and ``threshold`` are
        used. When ``scale_by_base`` is true, the value is multiplied by the
        bracket's ``base``, or by 1.0 if it has none. Returns ``scale``.
        """
        for bracket in brackets:
            children = bracket._children
            if value_key not in children or "threshold" not in children:
                continue
            value = getattr(bracket, value_key)
            if scale_by_base:
                base = bracket.base if "base" in children else 1.0
                value = value * base
            scale.add_bracket(bracket.threshold, value)
        return scale

    def get_attr_dict(self) -> dict:
        """
        Return a plain-dict snapshot of this scale for YAML output.

        Attributes named in ``_exclusion_list`` are dropped. Each bracket is
        converted with its own ``get_attr_dict()``, and ``brackets`` is moved
        to the end of the dict. The scale itself is not modified.
        """
        data = OrderedDict(
            (key, value)
            for key, value in self.__dict__.items()
            if key not in self._exclusion_list
        )
        if "brackets" in data:
            # Build a fresh list rather than converting in place. The dict copy
            # is shallow, so data["brackets"] *is* self.brackets; overwriting
            # its items would replace this scale's bracket objects with dicts.
            data["brackets"] = [
                bracket.get_attr_dict() for bracket in data["brackets"]
            ]
            data.move_to_end("brackets")
        return dict(data)