@@ -307,7 +307,9 @@ def histogramdd(
307307 input_core_dims = [list (dims ) for _ in data ],
308308 output_core_dims = [bins_names ],
309309 vectorize = True ,
310- kwargs = dict (weight = weights is not None , flow = flow , histref = histref ),
310+ kwargs = dict (
311+ weight = weights is not None , flow = flow , histref = (histref .axes , storage )
312+ ),
311313 ).rename (VAR_HIST )
312314
313315 hist = hist .assign_coords (coords )
@@ -320,33 +322,27 @@ def histogramdd(
320322 return hist
321323
322324
323- def get_shape (histref : bh .Histogram , flow : bool ) -> tuple [int , ...]:
325+ def get_shape (axes : bh .axis . AxesTuple , flow : bool ) -> tuple [int , ...]:
324326 """Return shape of histogram."""
325327 if flow :
326- return histref .axes .extent
327- return histref .shape
328-
329-
330- def clone (histref : bh .Histogram ) -> bh .Histogram :
331- """Clone reference histogram."""
332- return bh .Histogram (* histref .axes , storage = histref .storage_type ())
328+ return axes .extent
329+ return axes .size
333330
334331
335332def _blocked_dd (
336333 * data : NDArray ,
337334 weight : bool ,
338335 flow : bool ,
339- histref : bh .Histogram ,
336+ histref : tuple [ bh .axis . AxesTuple , bh . storage . Storage ] ,
340337 keepdims : bool = False ,
341338) -> NDArray :
342339 """Compute histogram on whole arrays.
343340
344341 Arrays are already broadcasted.
345342 """
346- thehist = clone ( histref )
343+ thehist = bh . Histogram ( * histref [ 0 ], storage = histref [ 1 ] )
347344 flattened = (np .reshape (x , (- 1 ,)) for x in data )
348345
349- thehist = clone (histref )
350346 if weight :
351347 * args , weights = flattened
352348 thehist .fill (* args , weight = weights )
@@ -367,7 +363,7 @@ def _blocked_dd_loop(
367363 axis_loop : abc .Sequence [int ],
368364 weight : bool ,
369365 flow : bool ,
370- histref : bh .Histogram ,
366+ histref : tuple [ bh .axis . AxesTuple , bh . storage . Storage ] ,
371367 keepdims : bool = False ,
372368) -> NDArray :
373369 """Compute multiple histograms on looping axis.
@@ -388,9 +384,9 @@ def _blocked_dd_loop(
388384 n_loop = reduce (operator .mul , shape_loop )
389385 flattened = [np .reshape (a , (n_loop , n_agg )) for a in ordered ]
390386
391- counts = np .zeros ((n_loop , * get_shape (histref , flow )))
387+ counts = np .zeros ((n_loop , * get_shape (histref [ 0 ] , flow )))
392388 for i in range (n_loop ):
393- thehist = clone ( histref )
389+ thehist = bh . Histogram ( * histref [ 0 ], storage = histref [ 1 ] )
394390 if weight :
395391 * args , weights = flattened
396392 thehist .fill (* [a [i ] for a in args ], weight = weights [i ])
@@ -399,7 +395,7 @@ def _blocked_dd_loop(
399395
400396 counts [i ] = thehist .values (flow )
401397
402- counts = np .reshape (counts , (* shape_loop , * get_shape (histref , flow )))
398+ counts = np .reshape (counts , (* shape_loop , * get_shape (histref [ 0 ] , flow )))
403399 if keepdims :
404400 counts = np .expand_dims (
405401 counts , tuple (_range (len (axis_loop ), len (axis_loop ) + len (axis_agg )))
@@ -417,17 +413,18 @@ def _histogram_dask(
417413 histref : bh .Histogram ,
418414) -> da .Array :
419415 """Compute histogram for dask data."""
416+ histref_tuple = (histref .axes , histref .storage_type ())
420417 if axis_loop :
421418 func = partial (
422419 _blocked_dd_loop ,
423420 axis_loop = axis_loop ,
424421 axis_agg = axis_agg ,
425422 weight = weight ,
426423 flow = flow ,
427- histref = histref ,
424+ histref = histref_tuple ,
428425 )
429426 else :
430- func = partial (_blocked_dd , weight = weight , flow = flow , histref = histref )
427+ func = partial (_blocked_dd , weight = weight , flow = flow , histref = histref_tuple )
431428
432429 dtype = (
433430 int
@@ -445,13 +442,14 @@ def _histogram_dask(
445442 chunks = (
446443 * [data [0 ].chunks [i ][0 ] for i in axis_loop ],
447444 * [1 for _ in axis_agg ],
448- * get_shape (histref , flow ),
445+ * get_shape (histref . axes , flow ),
449446 ),
450447 new_axis = [data [0 ].ndim + i for i in range (histref .ndim )],
451448 enforce_ndim = True ,
452449 name = "hist-on-block" ,
453450 meta = np .array ((), dtype = dtype ),
454451 )
452+
455453 reduc = da .reductions ._tree_reduce (
456454 blocked ,
457455 da .sum ,
0 commit comments