From 993e8011dc4882a6d04ccb1ad3bd11a338169cda Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Tue, 17 Mar 2026 17:20:43 -0700 Subject: [PATCH] Speed up cost_distance iterative tile Dijkstra 2-4x Batch-compute all dask tiles in a single scheduler pass and cache them for reuse across iterations, replacing per-tile .compute() calls that re-executed the dask graph each time. Store friction boundaries as float64 to skip repeated dtype conversion. Assemble the final result eagerly from cached tiles instead of through da.map_blocks. Pass precomputed f_min from dask+cupy fallback to avoid a redundant da.nanmin().compute(). Benchmarked improvement on the iterative (unbounded max_cost) path: 200x100: 0.206s -> 0.050s (4.1x), 1.90MB -> 0.80MB (-58%) 300x150: 0.229s -> 0.075s (3.0x), 2.91MB -> 1.38MB (-53%) 400x200: 0.263s -> 0.114s (2.3x), 4.02MB -> 2.19MB (-46%) numpy and dask-bounded (map_overlap) paths are unchanged. --- xrspatial/cost_distance.py | 146 +++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 62 deletions(-) diff --git a/xrspatial/cost_distance.py b/xrspatial/cost_distance.py index f3e08927..68c3b57a 100644 --- a/xrspatial/cost_distance.py +++ b/xrspatial/cost_distance.py @@ -374,20 +374,21 @@ def _cost_distance_dask_cupy(source_da, friction_da, height, width = source_da.shape use_map_overlap = False + f_min_cached = None if np.isfinite(max_cost): positive_friction = da.where(friction_da > 0, friction_da, np.inf) - f_min = float(da.nanmin(positive_friction).compute()) - if np.isfinite(f_min) and f_min > 0: + f_min_cached = float(da.nanmin(positive_friction).compute()) + if np.isfinite(f_min_cached) and f_min_cached > 0: min_cellsize = min(cellsize_x, cellsize_y) - max_radius = max_cost / (f_min * min_cellsize) + max_radius = max_cost / (f_min_cached * min_cellsize) pad = int(max_radius + 1) chunks_y, chunks_x = source_da.chunks if pad < max(chunks_y) and pad < max(chunks_x): use_map_overlap = True if use_map_overlap: - pad_y = int(max_cost / (f_min * cellsize_y) + 1) - pad_x = int(max_cost / (f_min * cellsize_x) + 1) + pad_y = int(max_cost / (f_min_cached * cellsize_y) + 1) + pad_x = int(max_cost / (f_min_cached * cellsize_x) + 1) # Closure captures the scalar parameters tv = target_values @@ -423,6 +424,7 @@ def _chunk_func(source_block, friction_block): source_np, friction_np, cellsize_x, cellsize_y, max_cost, target_values, dy, dx, dd, + _f_min=f_min_cached, ) # Convert back to dask+cupy return result.map_blocks( @@ -608,10 +610,18 @@ def _dist_to_float32(dist, height, width, max_cost): def _preprocess_tiles(source_da, friction_da, chunks_y, chunks_x, target_values): - """Extract friction boundary strips and identify source tiles. + """Extract friction boundary strips, identify source tiles, cache data. - Streams one tile at a time so only one chunk is in RAM. + Batch-computes all tiles in a single scheduler pass and caches them + for reuse during the iterative phase (avoids repeated .compute()). + Friction boundaries are stored as float64 to avoid repeated conversion + in _compute_seeds. + + Returns (friction_bdry, has_source, tile_cache) where tile_cache is + a dict mapping (iy, ix) -> (source_np, friction_np). """ + import dask + n_tile_y = len(chunks_y) n_tile_x = len(chunks_x) n_values = len(target_values) @@ -621,27 +631,40 @@ def _preprocess_tiles(source_da, friction_da, chunks_y, chunks_x, for side in ('top', 'bottom', 'left', 'right') } has_source = [[False] * n_tile_x for _ in range(n_tile_y)] + tile_cache = {} + # Batch-compute all tiles in one scheduler pass + blocks = [] + indices = [] for iy in range(n_tile_y): for ix in range(n_tile_x): - fchunk = friction_da.blocks[iy, ix].compute() - friction_bdry['top'][iy][ix] = fchunk[0, :].astype(np.float32) - friction_bdry['bottom'][iy][ix] = fchunk[-1, :].astype(np.float32) - friction_bdry['left'][iy][ix] = fchunk[:, 0].astype(np.float32) - friction_bdry['right'][iy][ix] = fchunk[:, -1].astype(np.float32) + blocks.append(source_da.blocks[iy, ix]) + blocks.append(friction_da.blocks[iy, ix]) + indices.append((iy, ix)) - schunk = source_da.blocks[iy, ix].compute() - if n_values == 0: - has_source[iy][ix] = bool( - np.any((schunk != 0) & np.isfinite(schunk)) - ) - else: - for tv in target_values: - if np.any(schunk == tv): - has_source[iy][ix] = True - break + computed = dask.compute(*blocks) + + for i, (iy, ix) in enumerate(indices): + schunk = computed[i * 2] + fchunk = computed[i * 2 + 1] + tile_cache[(iy, ix)] = (schunk, fchunk) + + friction_bdry['top'][iy][ix] = fchunk[0, :].astype(np.float64) + friction_bdry['bottom'][iy][ix] = fchunk[-1, :].astype(np.float64) + friction_bdry['left'][iy][ix] = fchunk[:, 0].astype(np.float64) + friction_bdry['right'][iy][ix] = fchunk[:, -1].astype(np.float64) - return friction_bdry, has_source + if n_values == 0: + has_source[iy][ix] = bool( + np.any((schunk != 0) & np.isfinite(schunk)) + ) + else: + for tv in target_values: + if np.any(schunk == tv): + has_source[iy][ix] = True + break + + return friction_bdry, has_source, tile_cache def _init_boundaries(chunks_y, chunks_x): @@ -826,7 +849,7 @@ def _can_skip(iy, ix, has_source, boundaries, return True -def _process_tile(iy, ix, source_da, friction_da, +def _process_tile(iy, ix, tile_cache, boundaries, friction_bdry, cellsize_x, cellsize_y, max_cost, target_values, dy, dx, dd, chunks_y, chunks_x, @@ -835,8 +858,7 @@ def _process_tile(iy, ix, source_da, friction_da, Returns the maximum absolute boundary change (float). """ - source_chunk = source_da.blocks[iy, ix].compute() - friction_chunk = friction_da.blocks[iy, ix].compute() + source_chunk, friction_chunk = tile_cache[(iy, ix)] h, w = source_chunk.shape seeds = _compute_seeds( @@ -894,8 +916,8 @@ def _cost_distance_dask_iterative(source_da, friction_da, n_tile_y = len(chunks_y) n_tile_x = len(chunks_x) - # Phase 0: pre-extract friction boundaries & source tile flags - friction_bdry, has_source = _preprocess_tiles( + # Phase 0: batch-compute all tiles, extract boundaries & source flags + friction_bdry, has_source, tile_cache = _preprocess_tiles( source_da, friction_da, chunks_y, chunks_x, target_values, ) @@ -904,7 +926,7 @@ def _cost_distance_dask_iterative(source_da, friction_da, # Phase 2: iterative forward/backward sweeps max_iterations = max(n_tile_y, n_tile_x) + 10 - args = (source_da, friction_da, boundaries, friction_bdry, + args = (tile_cache, boundaries, friction_bdry, cellsize_x, cellsize_y, max_cost, target_values, dy, dx, dd, chunks_y, chunks_x, n_tile_y, n_tile_x, connectivity) @@ -935,44 +957,40 @@ def _cost_distance_dask_iterative(source_da, friction_da, if max_change == 0.0: break - # Phase 3: lazy final assembly via da.map_blocks + # Phase 3: eager assembly from cached tiles with converged seeds return _assemble_result( - source_da, friction_da, boundaries, friction_bdry, + tile_cache, boundaries, friction_bdry, cellsize_x, cellsize_y, max_cost, target_values, dy, dx, dd, chunks_y, chunks_x, n_tile_y, n_tile_x, connectivity, ) -def _assemble_result(source_da, friction_da, boundaries, friction_bdry, +def _assemble_result(tile_cache, boundaries, friction_bdry, cellsize_x, cellsize_y, max_cost, target_values, dy, dx, dd, chunks_y, chunks_x, n_tile_y, n_tile_x, connectivity): - """Build a lazy dask array by re-running each tile with converged seeds.""" - - def _tile_fn(source_block, friction_block, block_info=None): - if block_info is None or 0 not in block_info: - return np.full(source_block.shape, np.nan, dtype=np.float32) - iy, ix = block_info[0]['chunk-location'] - h, w = source_block.shape - seeds = _compute_seeds( - iy, ix, boundaries, friction_bdry, - cellsize_x, cellsize_y, chunks_y, chunks_x, - n_tile_y, n_tile_x, connectivity, - ) - dist = _cost_distance_tile_kernel( - source_block, friction_block, h, w, - cellsize_x, cellsize_y, max_cost, target_values, - dy, dx, dd, *seeds, - ) - return _dist_to_float32(dist, h, w, max_cost) - - return da.map_blocks( - _tile_fn, - source_da, friction_da, - dtype=np.float32, - meta=np.array((), dtype=np.float32), - ) + """Build result array from cached tiles and converged boundary seeds.""" + rows = [] + for iy in range(n_tile_y): + row_blocks = [] + for ix in range(n_tile_x): + src, fric = tile_cache[(iy, ix)] + h, w = src.shape + seeds = _compute_seeds( + iy, ix, boundaries, friction_bdry, + cellsize_x, cellsize_y, chunks_y, chunks_x, + n_tile_y, n_tile_x, connectivity, + ) + dist = _cost_distance_tile_kernel( + src, fric, h, w, + cellsize_x, cellsize_y, max_cost, target_values, + dy, dx, dd, *seeds, + ) + row_blocks.append(_dist_to_float32(dist, h, w, max_cost)) + rows.append(np.concatenate(row_blocks, axis=1)) + full = np.concatenate(rows, axis=0) + return da.from_array(full, chunks=(chunks_y, chunks_x)) # --------------------------------------------------------------------------- @@ -995,14 +1013,18 @@ def _chunk(source_block, friction_block): def _cost_distance_dask(source_da, friction_da, cellsize_x, cellsize_y, - max_cost, target_values, dy, dx, dd): + max_cost, target_values, dy, dx, dd, + _f_min=None): """Dask path: use map_overlap with depth derived from max_cost.""" # We need the global minimum friction to compute max pixel radius. - # This is a tiny scalar .compute(). - # Use da.where to avoid boolean indexing (which creates unknown chunks). - positive_friction = da.where(friction_da > 0, friction_da, np.inf) - f_min = da.nanmin(positive_friction).compute() + # This is a tiny scalar .compute(); skip if caller already computed it. + if _f_min is not None: + f_min = _f_min + else: + # Use da.where to avoid boolean indexing (which creates unknown chunks). + positive_friction = da.where(friction_da > 0, friction_da, np.inf) + f_min = da.nanmin(positive_friction).compute() if not np.isfinite(f_min) or f_min <= 0: # All friction is non-positive or NaN — nothing reachable return da.full(source_da.shape, np.nan, dtype=np.float32,