Skip to content

Commit 5638cac

Browse files
committed
Fix oversized Dask graph
By passing bh.Histogram as an argument to blowckwise functions, I was embedding an array in the Dask task graph. This led to large graphs when having many partitions or/and high-dimensionality histograms. I now pass the axes (and storage) to avoid serializing bh.Histogram array
1 parent 8af59e5 commit 5638cac

2 files changed

Lines changed: 21 additions & 19 deletions

File tree

changelog.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11

2+
### 0.2.2
3+
4+
- Fix oversized Dask graph
5+
26
### 0.2.1
37

48
- Fix: under/over flow attributes are int instead of bool to conform with NetCDF

src/xarray_histogram/core.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,9 @@ def histogramdd(
307307
input_core_dims=[list(dims) for _ in data],
308308
output_core_dims=[bins_names],
309309
vectorize=True,
310-
kwargs=dict(weight=weights is not None, flow=flow, histref=histref),
310+
kwargs=dict(
311+
weight=weights is not None, flow=flow, histref=(histref.axes, storage)
312+
),
311313
).rename(VAR_HIST)
312314

313315
hist = hist.assign_coords(coords)
@@ -320,33 +322,27 @@ def histogramdd(
320322
return hist
321323

322324

323-
def get_shape(histref: bh.Histogram, flow: bool) -> tuple[int, ...]:
325+
def get_shape(axes: bh.axis.AxesTuple, flow: bool) -> tuple[int, ...]:
324326
"""Return shape of histogram."""
325327
if flow:
326-
return histref.axes.extent
327-
return histref.shape
328-
329-
330-
def clone(histref: bh.Histogram) -> bh.Histogram:
331-
"""Clone reference histogram."""
332-
return bh.Histogram(*histref.axes, storage=histref.storage_type())
328+
return axes.extent
329+
return axes.size
333330

334331

335332
def _blocked_dd(
336333
*data: NDArray,
337334
weight: bool,
338335
flow: bool,
339-
histref: bh.Histogram,
336+
histref: tuple[bh.axis.AxesTuple, bh.storage.Storage],
340337
keepdims: bool = False,
341338
) -> NDArray:
342339
"""Compute histogram on whole arrays.
343340
344341
Arrays are already broadcasted.
345342
"""
346-
thehist = clone(histref)
343+
thehist = bh.Histogram(*histref[0], storage=histref[1])
347344
flattened = (np.reshape(x, (-1,)) for x in data)
348345

349-
thehist = clone(histref)
350346
if weight:
351347
*args, weights = flattened
352348
thehist.fill(*args, weight=weights)
@@ -367,7 +363,7 @@ def _blocked_dd_loop(
367363
axis_loop: abc.Sequence[int],
368364
weight: bool,
369365
flow: bool,
370-
histref: bh.Histogram,
366+
histref: tuple[bh.axis.AxesTuple, bh.storage.Storage],
371367
keepdims: bool = False,
372368
) -> NDArray:
373369
"""Compute multiple histograms on looping axis.
@@ -388,9 +384,9 @@ def _blocked_dd_loop(
388384
n_loop = reduce(operator.mul, shape_loop)
389385
flattened = [np.reshape(a, (n_loop, n_agg)) for a in ordered]
390386

391-
counts = np.zeros((n_loop, *get_shape(histref, flow)))
387+
counts = np.zeros((n_loop, *get_shape(histref[0], flow)))
392388
for i in range(n_loop):
393-
thehist = clone(histref)
389+
thehist = bh.Histogram(*histref[0], storage=histref[1])
394390
if weight:
395391
*args, weights = flattened
396392
thehist.fill(*[a[i] for a in args], weight=weights[i])
@@ -399,7 +395,7 @@ def _blocked_dd_loop(
399395

400396
counts[i] = thehist.values(flow)
401397

402-
counts = np.reshape(counts, (*shape_loop, *get_shape(histref, flow)))
398+
counts = np.reshape(counts, (*shape_loop, *get_shape(histref[0], flow)))
403399
if keepdims:
404400
counts = np.expand_dims(
405401
counts, tuple(_range(len(axis_loop), len(axis_loop) + len(axis_agg)))
@@ -417,17 +413,18 @@ def _histogram_dask(
417413
histref: bh.Histogram,
418414
) -> da.Array:
419415
"""Compute histogram for dask data."""
416+
histref_tuple = (histref.axes, histref.storage_type())
420417
if axis_loop:
421418
func = partial(
422419
_blocked_dd_loop,
423420
axis_loop=axis_loop,
424421
axis_agg=axis_agg,
425422
weight=weight,
426423
flow=flow,
427-
histref=histref,
424+
histref=histref_tuple,
428425
)
429426
else:
430-
func = partial(_blocked_dd, weight=weight, flow=flow, histref=histref)
427+
func = partial(_blocked_dd, weight=weight, flow=flow, histref=histref_tuple)
431428

432429
dtype = (
433430
int
@@ -445,13 +442,14 @@ def _histogram_dask(
445442
chunks=(
446443
*[data[0].chunks[i][0] for i in axis_loop],
447444
*[1 for _ in axis_agg],
448-
*get_shape(histref, flow),
445+
*get_shape(histref.axes, flow),
449446
),
450447
new_axis=[data[0].ndim + i for i in range(histref.ndim)],
451448
enforce_ndim=True,
452449
name="hist-on-block",
453450
meta=np.array((), dtype=dtype),
454451
)
452+
455453
reduc = da.reductions._tree_reduce(
456454
blocked,
457455
da.sum,

0 commit comments

Comments
 (0)