@@ -197,87 +197,248 @@ def __getitem__(self, key: str) -> Any:
197197 return self .__getattr__ (key )
198198 # If the key is a vector, e.g. ['zone_1', 'zone_2', 'zone_1']
199199 elif isinstance (key , numpy .ndarray ):
200- if not numpy .issubdtype (key .dtype , numpy .str_ ):
201- # In case the key is not a string vector, stringify it
202- if key .dtype == object and issubclass (type (key [0 ]), Enum ):
203- enum = type (key [0 ])
204- key = numpy .select (
205- [key == item for item in enum ],
206- [str (item .name ) for item in enum ],
207- default = "unknown" ,
208- )
209- elif isinstance (key , EnumArray ):
210- enum = key .possible_values
211- key = numpy .select (
212- [key == item .index for item in enum ],
213- [item .name for item in enum ],
214- default = "unknown" ,
215- )
216- else :
200+ names = self .dtype .names
201+ # Build name→child-index mapping (cached on instance)
202+ if not hasattr (self , "_name_to_child_idx" ):
203+ self ._name_to_child_idx = {
204+ name : i for i , name in enumerate (names )
205+ }
206+
207+ name_to_child_idx = self ._name_to_child_idx
208+ n = len (key )
209+ SENTINEL = len (names )
210+
211+ # Convert key to integer indices directly, avoiding
212+ # expensive intermediate string arrays where possible.
213+ if isinstance (key , EnumArray ):
214+ # EnumArray: map enum int codes → child indices via
215+ # a pre-built lookup table (O(N), no string comparison).
216+ enum = key .possible_values
217+ cache_key = id (enum )
218+ if not hasattr (self , "_enum_lut_cache" ):
219+ self ._enum_lut_cache = {}
220+ lut = self ._enum_lut_cache .get (cache_key )
221+ if lut is None :
222+ enum_items = list (enum )
223+ max_code = max (item .index for item in enum_items ) + 1
224+ lut = numpy .full (max_code , SENTINEL , dtype = numpy .intp )
225+ for item in enum_items :
226+ child_idx = name_to_child_idx .get (item .name )
227+ if child_idx is not None :
228+ lut [item .index ] = child_idx
229+ self ._enum_lut_cache [cache_key ] = lut
230+ idx = lut [numpy .asarray (key )]
231+ elif key .dtype == object and len (key ) > 0 and issubclass (type (key [0 ]), Enum ):
232+ # Object array of Enum instances
233+ enum = type (key [0 ])
234+ cache_key = id (enum )
235+ if not hasattr (self , "_enum_lut_cache" ):
236+ self ._enum_lut_cache = {}
237+ lut = self ._enum_lut_cache .get (cache_key )
238+ if lut is None :
239+ enum_items = list (enum )
240+ max_code = max (item .index for item in enum_items ) + 1
241+ lut = numpy .full (max_code , SENTINEL , dtype = numpy .intp )
242+ for item in enum_items :
243+ child_idx = name_to_child_idx .get (str (item .name ))
244+ if child_idx is not None :
245+ lut [item .index ] = child_idx
246+ self ._enum_lut_cache [cache_key ] = lut
247+ codes = numpy .array (
248+ [v .index for v in key ], dtype = numpy .intp
249+ )
250+ idx = lut [codes ]
251+ else :
252+ # String keys: map via dict lookup
253+ if not numpy .issubdtype (key .dtype , numpy .str_ ):
217254 key = key .astype ("str" )
218- names = list (
219- self .dtype .names
220- ) # Get all the names of the subnodes, e.g. ['zone_1', 'zone_2']
221- conditions = [key == name for name in names ]
255+ # Vectorised dict lookup using numpy unique + scatter
256+ uniq , inverse = numpy .unique (key , return_inverse = True )
257+ uniq_idx = numpy .array (
258+ [name_to_child_idx .get (u , SENTINEL ) for u in uniq ],
259+ dtype = numpy .intp ,
260+ )
261+ idx = uniq_idx [inverse ]
262+
263+ # Gather values by child index using take on a stacked array.
222264 values = [self .vector [name ] for name in names ]
223265
224- # NumPy 2.x requires all arrays in numpy.select to have identical dtypes
225- # For structured arrays with different field sets, we need to normalize them
226- if (
266+ is_structured = (
227267 len (values ) > 0
228268 and hasattr (values [0 ].dtype , "names" )
229269 and values [0 ].dtype .names
230- ):
231- # Check if all values have the same dtype
270+ )
271+
272+ if is_structured :
232273 dtypes_match = all (
233274 val .dtype == values [0 ].dtype for val in values
234275 )
276+ v0_len = len (values [0 ])
235277
236- if not dtypes_match :
237- # Find the union of all field names across all values, preserving first seen order
238- all_fields = []
239- seen = set ()
240- for val in values :
241- for field in val .dtype .names :
242- if field not in seen :
243- all_fields .append (field )
244- seen .add (field )
245-
246- # Create unified dtype with all fields
247- unified_dtype = numpy .dtype (
248- [(f , "<f8" ) for f in all_fields ]
249- )
278+ if v0_len <= 1 :
279+ # 1-element structured arrays: simple concat + index
280+ if not dtypes_match :
281+ all_fields = []
282+ seen = set ()
283+ for val in values :
284+ for field in val .dtype .names :
285+ if field not in seen :
286+ all_fields .append (field )
287+ seen .add (field )
250288
251- # Cast all values to unified dtype
252- values_cast = []
253- for val in values :
254- casted = numpy .zeros (len (val ), dtype = unified_dtype )
255- for field in val .dtype .names :
256- casted [field ] = val [field ]
257- values_cast .append (casted )
289+ unified_dtype = numpy .dtype (
290+ [(f , "<f8" ) for f in all_fields ]
291+ )
258292
259- default = numpy .zeros (
260- len (values_cast [0 ]), dtype = unified_dtype
261- )
262- # Fill with NaN
263- for field in unified_dtype .names :
264- default [field ] = numpy .nan
293+ values_cast = []
294+ for val in values :
295+ casted = numpy .zeros (
296+ len (val ), dtype = unified_dtype
297+ )
298+ for field in val .dtype .names :
299+ casted [field ] = val [field ]
300+ values_cast .append (casted )
265301
266- result = numpy .select (conditions , values_cast , default )
302+ default = numpy .zeros (1 , dtype = unified_dtype )
303+ for field in unified_dtype .names :
304+ default [field ] = numpy .nan
305+ stacked = numpy .concatenate (
306+ values_cast + [default ]
307+ )
308+ result = stacked [idx ]
309+ else :
310+ default = numpy .full (
311+ 1 , numpy .nan , dtype = values [0 ].dtype
312+ )
313+ stacked = numpy .concatenate (values + [default ])
314+ result = stacked [idx ]
267315 else :
268- # All dtypes match, use original logic
269- default = numpy .full_like (values [0 ], numpy .nan )
270- result = numpy .select (conditions , values , default )
316+ # N-element structured arrays: check if fields are
317+ # simple scalars (fast path) or nested records
318+ # (fall back to numpy.select).
319+ first_field = values [0 ].dtype .names [0 ]
320+ field_dtype = values [0 ][first_field ].dtype
321+ is_nested = (
322+ hasattr (field_dtype , "names" )
323+ and field_dtype .names is not None
324+ )
325+
326+ if is_nested :
327+ # Nested structured: fall back to numpy.select
328+ conditions = [
329+ idx == i for i in range (len (values ))
330+ ]
331+ if not dtypes_match :
332+ all_fields = []
333+ seen = set ()
334+ for val in values :
335+ for field in val .dtype .names :
336+ if field not in seen :
337+ all_fields .append (field )
338+ seen .add (field )
339+ unified_dtype = numpy .dtype (
340+ [(f , "<f8" ) for f in all_fields ]
341+ )
342+ values_cast = []
343+ for val in values :
344+ casted = numpy .zeros (
345+ len (val ), dtype = unified_dtype
346+ )
347+ for field in val .dtype .names :
348+ casted [field ] = val [field ]
349+ values_cast .append (casted )
350+ default = numpy .zeros (
351+ v0_len , dtype = unified_dtype
352+ )
353+ for field in unified_dtype .names :
354+ default [field ] = numpy .nan
355+ result = numpy .select (
356+ conditions , values_cast , default
357+ )
358+ else :
359+ default = numpy .full_like (
360+ values [0 ], numpy .nan
361+ )
362+ result = numpy .select (
363+ conditions , values , default
364+ )
365+ else :
366+ # Flat structured: fast per-field indexing
367+ if not dtypes_match :
368+ all_fields = []
369+ seen = set ()
370+ for val in values :
371+ for field in val .dtype .names :
372+ if field not in seen :
373+ all_fields .append (field )
374+ seen .add (field )
375+ unified_dtype = numpy .dtype (
376+ [(f , "<f8" ) for f in all_fields ]
377+ )
378+ values_unified = []
379+ for val in values :
380+ casted = numpy .zeros (
381+ len (val ), dtype = unified_dtype
382+ )
383+ for field in val .dtype .names :
384+ casted [field ] = val [field ]
385+ values_unified .append (casted )
386+ field_names = all_fields
387+ result_dtype = unified_dtype
388+ else :
389+ values_unified = values
390+ field_names = values [0 ].dtype .names
391+ result_dtype = values [0 ].dtype
392+
393+ result = numpy .empty (n , dtype = result_dtype )
394+ arange_n = numpy .arange (v0_len )
395+ for field in field_names :
396+ field_stack = numpy .empty (
397+ (len (values_unified ) + 1 , v0_len ),
398+ dtype = numpy .float64 ,
399+ )
400+ for i , v in enumerate (values_unified ):
401+ field_stack [i ] = v [field ]
402+ field_stack [- 1 ] = numpy .nan
403+ result [field ] = field_stack [
404+ idx , arange_n
405+ ]
271406 else :
272- # Non-structured array case
273- default = numpy .full_like (
274- values [0 ] if values else self .vector [key [0 ]], numpy .nan
275- )
276- result = numpy .select (conditions , values , default )
407+ # Non-structured: values are either scalars (1-elem arrays)
408+ # or N-element vectors (after prior vectorial indexing).
409+ if values :
410+ v0 = numpy .asarray (values [0 ])
411+ if v0 .ndim == 0 or v0 .shape [0 ] <= 1 :
412+ # Scalar per child: 1D lookup
413+ scalar_vals = numpy .empty (
414+ len (values ) + 1 , dtype = numpy .float64
415+ )
416+ for i , v in enumerate (values ):
417+ scalar_vals [i ] = float (v )
418+ scalar_vals [- 1 ] = numpy .nan
419+ result = scalar_vals [idx ]
420+ else :
421+ # N-element vectors: stack into (K+1, N) matrix
422+ m = v0 .shape [0 ]
423+ stacked = numpy .empty (
424+ (len (values ) + 1 , m ), dtype = numpy .float64
425+ )
426+ for i , v in enumerate (values ):
427+ stacked [i ] = v
428+ stacked [- 1 ] = numpy .nan
429+ result = stacked [idx , numpy .arange (m )]
430+ else :
431+ result = numpy .full (n , numpy .nan )
277432
278- # Check for unexpected keys (NaN results from missing keys)
433+ # Check for unexpected keys
279434 if helpers .contains_nan (result ):
280- unexpected_keys = set (key ).difference (self .vector .dtype .names )
435+ unexpected_keys = set (
436+ numpy .asarray (key , dtype = str )
437+ if not numpy .issubdtype (
438+ numpy .asarray (key ).dtype , numpy .str_
439+ )
440+ else key
441+ ).difference (self .vector .dtype .names )
281442 if unexpected_keys :
282443 unexpected_key = unexpected_keys .pop ()
283444 raise ParameterNotFoundError (
0 commit comments