Skip to content

Commit b11abc0

Browse files
Optimise vectorial parameter lookups for US simulation
Replace O(N×K) numpy.select with O(N) index-based selection in VectorialParameterNodeAtInstant.__getitem__. For enum/EnumArray keys, build a lookup table mapping integer codes directly to child indices, avoiding the intermediate string conversion entirely. For string keys, use numpy.unique to reduce N×K string comparisons to U dict lookups (where U = unique keys, typically ≪ N). Also cache build_from_node results on ParameterNodeAtInstant to avoid rebuilding the recarray on every vectorial access. US household_net_income compute: 12.8s → 9.0s (-30%). Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 047a845 commit b11abc0

3 files changed

Lines changed: 239 additions & 67 deletions

File tree

changelog_entry.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
changes:
33
fixed:
44
- Optimisation improvements for loading tax-benefit systems (caching).
5+
- Replaced O(N×K) numpy.select in vectorial parameter lookups with O(N) index-based selection, and cached build_from_node results. US simulation compute -30%, from 12.8s to 9.0s.

policyengine_core/parameters/parameter_node_at_instant.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,19 @@ def __getitem__(
5555
) -> Union["ParameterNodeAtInstant", VectorialParameterNodeAtInstant]:
5656
# If fancy indexing is used, cast to a vectorial node
5757
if isinstance(key, numpy.ndarray):
58-
return parameters.VectorialParameterNodeAtInstant.build_from_node(
59-
self
60-
)[key]
58+
# Cache the vectorial node to avoid rebuilding the recarray on
59+
# every call — build_from_node is expensive (walks the full
60+
# parameter subtree each time).
61+
try:
62+
vectorial = self._vectorial_node
63+
except AttributeError:
64+
vectorial = (
65+
parameters.VectorialParameterNodeAtInstant.build_from_node(
66+
self
67+
)
68+
)
69+
self._vectorial_node = vectorial
70+
return vectorial[key]
6171
return self._children[key]
6272

6373
def __iter__(self) -> Iterable:

policyengine_core/parameters/vectorial_parameter_node_at_instant.py

Lines changed: 225 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -197,87 +197,248 @@ def __getitem__(self, key: str) -> Any:
197197
return self.__getattr__(key)
198198
# If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1']
199199
elif isinstance(key, numpy.ndarray):
200-
if not numpy.issubdtype(key.dtype, numpy.str_):
201-
# In case the key is not a string vector, stringify it
202-
if key.dtype == object and issubclass(type(key[0]), Enum):
203-
enum = type(key[0])
204-
key = numpy.select(
205-
[key == item for item in enum],
206-
[str(item.name) for item in enum],
207-
default="unknown",
208-
)
209-
elif isinstance(key, EnumArray):
210-
enum = key.possible_values
211-
key = numpy.select(
212-
[key == item.index for item in enum],
213-
[item.name for item in enum],
214-
default="unknown",
215-
)
216-
else:
200+
names = self.dtype.names
201+
# Build name→child-index mapping (cached on instance)
202+
if not hasattr(self, "_name_to_child_idx"):
203+
self._name_to_child_idx = {
204+
name: i for i, name in enumerate(names)
205+
}
206+
207+
name_to_child_idx = self._name_to_child_idx
208+
n = len(key)
209+
SENTINEL = len(names)
210+
211+
# Convert key to integer indices directly, avoiding
212+
# expensive intermediate string arrays where possible.
213+
if isinstance(key, EnumArray):
214+
# EnumArray: map enum int codes → child indices via
215+
# a pre-built lookup table (O(N), no string comparison).
216+
enum = key.possible_values
217+
cache_key = id(enum)
218+
if not hasattr(self, "_enum_lut_cache"):
219+
self._enum_lut_cache = {}
220+
lut = self._enum_lut_cache.get(cache_key)
221+
if lut is None:
222+
enum_items = list(enum)
223+
max_code = max(item.index for item in enum_items) + 1
224+
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
225+
for item in enum_items:
226+
child_idx = name_to_child_idx.get(item.name)
227+
if child_idx is not None:
228+
lut[item.index] = child_idx
229+
self._enum_lut_cache[cache_key] = lut
230+
idx = lut[numpy.asarray(key)]
231+
elif key.dtype == object and len(key) > 0 and issubclass(type(key[0]), Enum):
232+
# Object array of Enum instances
233+
enum = type(key[0])
234+
cache_key = id(enum)
235+
if not hasattr(self, "_enum_lut_cache"):
236+
self._enum_lut_cache = {}
237+
lut = self._enum_lut_cache.get(cache_key)
238+
if lut is None:
239+
enum_items = list(enum)
240+
max_code = max(item.index for item in enum_items) + 1
241+
lut = numpy.full(max_code, SENTINEL, dtype=numpy.intp)
242+
for item in enum_items:
243+
child_idx = name_to_child_idx.get(str(item.name))
244+
if child_idx is not None:
245+
lut[item.index] = child_idx
246+
self._enum_lut_cache[cache_key] = lut
247+
codes = numpy.array(
248+
[v.index for v in key], dtype=numpy.intp
249+
)
250+
idx = lut[codes]
251+
else:
252+
# String keys: map via dict lookup
253+
if not numpy.issubdtype(key.dtype, numpy.str_):
217254
key = key.astype("str")
218-
names = list(
219-
self.dtype.names
220-
) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2']
221-
conditions = [key == name for name in names]
255+
# Vectorised dict lookup using numpy unique + scatter
256+
uniq, inverse = numpy.unique(key, return_inverse=True)
257+
uniq_idx = numpy.array(
258+
[name_to_child_idx.get(u, SENTINEL) for u in uniq],
259+
dtype=numpy.intp,
260+
)
261+
idx = uniq_idx[inverse]
262+
263+
# Gather values by child index using take on a stacked array.
222264
values = [self.vector[name] for name in names]
223265

224-
# NumPy 2.x requires all arrays in numpy.select to have identical dtypes
225-
# For structured arrays with different field sets, we need to normalize them
226-
if (
266+
is_structured = (
227267
len(values) > 0
228268
and hasattr(values[0].dtype, "names")
229269
and values[0].dtype.names
230-
):
231-
# Check if all values have the same dtype
270+
)
271+
272+
if is_structured:
232273
dtypes_match = all(
233274
val.dtype == values[0].dtype for val in values
234275
)
276+
v0_len = len(values[0])
235277

236-
if not dtypes_match:
237-
# Find the union of all field names across all values, preserving first seen order
238-
all_fields = []
239-
seen = set()
240-
for val in values:
241-
for field in val.dtype.names:
242-
if field not in seen:
243-
all_fields.append(field)
244-
seen.add(field)
245-
246-
# Create unified dtype with all fields
247-
unified_dtype = numpy.dtype(
248-
[(f, "<f8") for f in all_fields]
249-
)
278+
if v0_len <= 1:
279+
# 1-element structured arrays: simple concat + index
280+
if not dtypes_match:
281+
all_fields = []
282+
seen = set()
283+
for val in values:
284+
for field in val.dtype.names:
285+
if field not in seen:
286+
all_fields.append(field)
287+
seen.add(field)
250288

251-
# Cast all values to unified dtype
252-
values_cast = []
253-
for val in values:
254-
casted = numpy.zeros(len(val), dtype=unified_dtype)
255-
for field in val.dtype.names:
256-
casted[field] = val[field]
257-
values_cast.append(casted)
289+
unified_dtype = numpy.dtype(
290+
[(f, "<f8") for f in all_fields]
291+
)
258292

259-
default = numpy.zeros(
260-
len(values_cast[0]), dtype=unified_dtype
261-
)
262-
# Fill with NaN
263-
for field in unified_dtype.names:
264-
default[field] = numpy.nan
293+
values_cast = []
294+
for val in values:
295+
casted = numpy.zeros(
296+
len(val), dtype=unified_dtype
297+
)
298+
for field in val.dtype.names:
299+
casted[field] = val[field]
300+
values_cast.append(casted)
265301

266-
result = numpy.select(conditions, values_cast, default)
302+
default = numpy.zeros(1, dtype=unified_dtype)
303+
for field in unified_dtype.names:
304+
default[field] = numpy.nan
305+
stacked = numpy.concatenate(
306+
values_cast + [default]
307+
)
308+
result = stacked[idx]
309+
else:
310+
default = numpy.full(
311+
1, numpy.nan, dtype=values[0].dtype
312+
)
313+
stacked = numpy.concatenate(values + [default])
314+
result = stacked[idx]
267315
else:
268-
# All dtypes match, use original logic
269-
default = numpy.full_like(values[0], numpy.nan)
270-
result = numpy.select(conditions, values, default)
316+
# N-element structured arrays: check if fields are
317+
# simple scalars (fast path) or nested records
318+
# (fall back to numpy.select).
319+
first_field = values[0].dtype.names[0]
320+
field_dtype = values[0][first_field].dtype
321+
is_nested = (
322+
hasattr(field_dtype, "names")
323+
and field_dtype.names is not None
324+
)
325+
326+
if is_nested:
327+
# Nested structured: fall back to numpy.select
328+
conditions = [
329+
idx == i for i in range(len(values))
330+
]
331+
if not dtypes_match:
332+
all_fields = []
333+
seen = set()
334+
for val in values:
335+
for field in val.dtype.names:
336+
if field not in seen:
337+
all_fields.append(field)
338+
seen.add(field)
339+
unified_dtype = numpy.dtype(
340+
[(f, "<f8") for f in all_fields]
341+
)
342+
values_cast = []
343+
for val in values:
344+
casted = numpy.zeros(
345+
len(val), dtype=unified_dtype
346+
)
347+
for field in val.dtype.names:
348+
casted[field] = val[field]
349+
values_cast.append(casted)
350+
default = numpy.zeros(
351+
v0_len, dtype=unified_dtype
352+
)
353+
for field in unified_dtype.names:
354+
default[field] = numpy.nan
355+
result = numpy.select(
356+
conditions, values_cast, default
357+
)
358+
else:
359+
default = numpy.full_like(
360+
values[0], numpy.nan
361+
)
362+
result = numpy.select(
363+
conditions, values, default
364+
)
365+
else:
366+
# Flat structured: fast per-field indexing
367+
if not dtypes_match:
368+
all_fields = []
369+
seen = set()
370+
for val in values:
371+
for field in val.dtype.names:
372+
if field not in seen:
373+
all_fields.append(field)
374+
seen.add(field)
375+
unified_dtype = numpy.dtype(
376+
[(f, "<f8") for f in all_fields]
377+
)
378+
values_unified = []
379+
for val in values:
380+
casted = numpy.zeros(
381+
len(val), dtype=unified_dtype
382+
)
383+
for field in val.dtype.names:
384+
casted[field] = val[field]
385+
values_unified.append(casted)
386+
field_names = all_fields
387+
result_dtype = unified_dtype
388+
else:
389+
values_unified = values
390+
field_names = values[0].dtype.names
391+
result_dtype = values[0].dtype
392+
393+
result = numpy.empty(n, dtype=result_dtype)
394+
arange_n = numpy.arange(v0_len)
395+
for field in field_names:
396+
field_stack = numpy.empty(
397+
(len(values_unified) + 1, v0_len),
398+
dtype=numpy.float64,
399+
)
400+
for i, v in enumerate(values_unified):
401+
field_stack[i] = v[field]
402+
field_stack[-1] = numpy.nan
403+
result[field] = field_stack[
404+
idx, arange_n
405+
]
271406
else:
272-
# Non-structured array case
273-
default = numpy.full_like(
274-
values[0] if values else self.vector[key[0]], numpy.nan
275-
)
276-
result = numpy.select(conditions, values, default)
407+
# Non-structured: values are either scalars (1-elem arrays)
408+
# or N-element vectors (after prior vectorial indexing).
409+
if values:
410+
v0 = numpy.asarray(values[0])
411+
if v0.ndim == 0 or v0.shape[0] <= 1:
412+
# Scalar per child: 1D lookup
413+
scalar_vals = numpy.empty(
414+
len(values) + 1, dtype=numpy.float64
415+
)
416+
for i, v in enumerate(values):
417+
scalar_vals[i] = float(v)
418+
scalar_vals[-1] = numpy.nan
419+
result = scalar_vals[idx]
420+
else:
421+
# N-element vectors: stack into (K+1, N) matrix
422+
m = v0.shape[0]
423+
stacked = numpy.empty(
424+
(len(values) + 1, m), dtype=numpy.float64
425+
)
426+
for i, v in enumerate(values):
427+
stacked[i] = v
428+
stacked[-1] = numpy.nan
429+
result = stacked[idx, numpy.arange(m)]
430+
else:
431+
result = numpy.full(n, numpy.nan)
277432

278-
# Check for unexpected keys (NaN results from missing keys)
433+
# Check for unexpected keys
279434
if helpers.contains_nan(result):
280-
unexpected_keys = set(key).difference(self.vector.dtype.names)
435+
unexpected_keys = set(
436+
numpy.asarray(key, dtype=str)
437+
if not numpy.issubdtype(
438+
numpy.asarray(key).dtype, numpy.str_
439+
)
440+
else key
441+
).difference(self.vector.dtype.names)
281442
if unexpected_keys:
282443
unexpected_key = unexpected_keys.pop()
283444
raise ParameterNotFoundError(

0 commit comments

Comments (0)