Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 84 additions & 62 deletions xrspatial/cost_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,20 +374,21 @@ def _cost_distance_dask_cupy(source_da, friction_da,
height, width = source_da.shape

use_map_overlap = False
f_min_cached = None
if np.isfinite(max_cost):
positive_friction = da.where(friction_da > 0, friction_da, np.inf)
f_min = float(da.nanmin(positive_friction).compute())
if np.isfinite(f_min) and f_min > 0:
f_min_cached = float(da.nanmin(positive_friction).compute())
if np.isfinite(f_min_cached) and f_min_cached > 0:
min_cellsize = min(cellsize_x, cellsize_y)
max_radius = max_cost / (f_min * min_cellsize)
max_radius = max_cost / (f_min_cached * min_cellsize)
pad = int(max_radius + 1)
chunks_y, chunks_x = source_da.chunks
if pad < max(chunks_y) and pad < max(chunks_x):
use_map_overlap = True

if use_map_overlap:
pad_y = int(max_cost / (f_min * cellsize_y) + 1)
pad_x = int(max_cost / (f_min * cellsize_x) + 1)
pad_y = int(max_cost / (f_min_cached * cellsize_y) + 1)
pad_x = int(max_cost / (f_min_cached * cellsize_x) + 1)

# Closure captures the scalar parameters
tv = target_values
Expand Down Expand Up @@ -423,6 +424,7 @@ def _chunk_func(source_block, friction_block):
source_np, friction_np,
cellsize_x, cellsize_y, max_cost,
target_values, dy, dx, dd,
_f_min=f_min_cached,
)
# Convert back to dask+cupy
return result.map_blocks(
Expand Down Expand Up @@ -608,10 +610,18 @@ def _dist_to_float32(dist, height, width, max_cost):

def _preprocess_tiles(source_da, friction_da, chunks_y, chunks_x,
target_values):
"""Extract friction boundary strips and identify source tiles.
"""Extract friction boundary strips, identify source tiles, cache data.

Streams one tile at a time so only one chunk is in RAM.
Batch-computes all tiles in a single scheduler pass and caches them
for reuse during the iterative phase (avoids repeated .compute()).
Friction boundaries are stored as float64 to avoid repeated conversion
in _compute_seeds.

Returns (friction_bdry, has_source, tile_cache) where tile_cache is
a dict mapping (iy, ix) -> (source_np, friction_np).
"""
import dask

n_tile_y = len(chunks_y)
n_tile_x = len(chunks_x)
n_values = len(target_values)
Expand All @@ -621,27 +631,40 @@ def _preprocess_tiles(source_da, friction_da, chunks_y, chunks_x,
for side in ('top', 'bottom', 'left', 'right')
}
has_source = [[False] * n_tile_x for _ in range(n_tile_y)]
tile_cache = {}

# Batch-compute all tiles in one scheduler pass
blocks = []
indices = []
for iy in range(n_tile_y):
for ix in range(n_tile_x):
fchunk = friction_da.blocks[iy, ix].compute()
friction_bdry['top'][iy][ix] = fchunk[0, :].astype(np.float32)
friction_bdry['bottom'][iy][ix] = fchunk[-1, :].astype(np.float32)
friction_bdry['left'][iy][ix] = fchunk[:, 0].astype(np.float32)
friction_bdry['right'][iy][ix] = fchunk[:, -1].astype(np.float32)
blocks.append(source_da.blocks[iy, ix])
blocks.append(friction_da.blocks[iy, ix])
indices.append((iy, ix))

schunk = source_da.blocks[iy, ix].compute()
if n_values == 0:
has_source[iy][ix] = bool(
np.any((schunk != 0) & np.isfinite(schunk))
)
else:
for tv in target_values:
if np.any(schunk == tv):
has_source[iy][ix] = True
break
computed = dask.compute(*blocks)

for i, (iy, ix) in enumerate(indices):
schunk = computed[i * 2]
fchunk = computed[i * 2 + 1]
tile_cache[(iy, ix)] = (schunk, fchunk)

friction_bdry['top'][iy][ix] = fchunk[0, :].astype(np.float64)
friction_bdry['bottom'][iy][ix] = fchunk[-1, :].astype(np.float64)
friction_bdry['left'][iy][ix] = fchunk[:, 0].astype(np.float64)
friction_bdry['right'][iy][ix] = fchunk[:, -1].astype(np.float64)

return friction_bdry, has_source
if n_values == 0:
has_source[iy][ix] = bool(
np.any((schunk != 0) & np.isfinite(schunk))
)
else:
for tv in target_values:
if np.any(schunk == tv):
has_source[iy][ix] = True
break

return friction_bdry, has_source, tile_cache


def _init_boundaries(chunks_y, chunks_x):
Expand Down Expand Up @@ -826,7 +849,7 @@ def _can_skip(iy, ix, has_source, boundaries,
return True


def _process_tile(iy, ix, source_da, friction_da,
def _process_tile(iy, ix, tile_cache,
boundaries, friction_bdry,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, chunks_y, chunks_x,
Expand All @@ -835,8 +858,7 @@ def _process_tile(iy, ix, source_da, friction_da,

Returns the maximum absolute boundary change (float).
"""
source_chunk = source_da.blocks[iy, ix].compute()
friction_chunk = friction_da.blocks[iy, ix].compute()
source_chunk, friction_chunk = tile_cache[(iy, ix)]
h, w = source_chunk.shape

seeds = _compute_seeds(
Expand Down Expand Up @@ -894,8 +916,8 @@ def _cost_distance_dask_iterative(source_da, friction_da,
n_tile_y = len(chunks_y)
n_tile_x = len(chunks_x)

# Phase 0: pre-extract friction boundaries & source tile flags
friction_bdry, has_source = _preprocess_tiles(
# Phase 0: batch-compute all tiles, extract boundaries & source flags
friction_bdry, has_source, tile_cache = _preprocess_tiles(
source_da, friction_da, chunks_y, chunks_x, target_values,
)

Expand All @@ -904,7 +926,7 @@ def _cost_distance_dask_iterative(source_da, friction_da,

# Phase 2: iterative forward/backward sweeps
max_iterations = max(n_tile_y, n_tile_x) + 10
args = (source_da, friction_da, boundaries, friction_bdry,
args = (tile_cache, boundaries, friction_bdry,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, chunks_y, chunks_x,
n_tile_y, n_tile_x, connectivity)
Expand Down Expand Up @@ -935,44 +957,40 @@ def _cost_distance_dask_iterative(source_da, friction_da,
if max_change == 0.0:
break

# Phase 3: lazy final assembly via da.map_blocks
# Phase 3: eager assembly from cached tiles with converged seeds
return _assemble_result(
source_da, friction_da, boundaries, friction_bdry,
tile_cache, boundaries, friction_bdry,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, chunks_y, chunks_x,
n_tile_y, n_tile_x, connectivity,
)


def _assemble_result(source_da, friction_da, boundaries, friction_bdry,
def _assemble_result(tile_cache, boundaries, friction_bdry,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, chunks_y, chunks_x,
n_tile_y, n_tile_x, connectivity):
"""Build a lazy dask array by re-running each tile with converged seeds."""

def _tile_fn(source_block, friction_block, block_info=None):
if block_info is None or 0 not in block_info:
return np.full(source_block.shape, np.nan, dtype=np.float32)
iy, ix = block_info[0]['chunk-location']
h, w = source_block.shape
seeds = _compute_seeds(
iy, ix, boundaries, friction_bdry,
cellsize_x, cellsize_y, chunks_y, chunks_x,
n_tile_y, n_tile_x, connectivity,
)
dist = _cost_distance_tile_kernel(
source_block, friction_block, h, w,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, *seeds,
)
return _dist_to_float32(dist, h, w, max_cost)

return da.map_blocks(
_tile_fn,
source_da, friction_da,
dtype=np.float32,
meta=np.array((), dtype=np.float32),
)
"""Build result array from cached tiles and converged boundary seeds."""
rows = []
for iy in range(n_tile_y):
row_blocks = []
for ix in range(n_tile_x):
src, fric = tile_cache[(iy, ix)]
h, w = src.shape
seeds = _compute_seeds(
iy, ix, boundaries, friction_bdry,
cellsize_x, cellsize_y, chunks_y, chunks_x,
n_tile_y, n_tile_x, connectivity,
)
dist = _cost_distance_tile_kernel(
src, fric, h, w,
cellsize_x, cellsize_y, max_cost, target_values,
dy, dx, dd, *seeds,
)
row_blocks.append(_dist_to_float32(dist, h, w, max_cost))
rows.append(np.concatenate(row_blocks, axis=1))
full = np.concatenate(rows, axis=0)
return da.from_array(full, chunks=(chunks_y, chunks_x))


# ---------------------------------------------------------------------------
Expand All @@ -995,14 +1013,18 @@ def _chunk(source_block, friction_block):


def _cost_distance_dask(source_da, friction_da, cellsize_x, cellsize_y,
max_cost, target_values, dy, dx, dd):
max_cost, target_values, dy, dx, dd,
_f_min=None):
"""Dask path: use map_overlap with depth derived from max_cost."""

# We need the global minimum friction to compute max pixel radius.
# This is a tiny scalar .compute().
# Use da.where to avoid boolean indexing (which creates unknown chunks).
positive_friction = da.where(friction_da > 0, friction_da, np.inf)
f_min = da.nanmin(positive_friction).compute()
# This is a tiny scalar .compute(); skip if caller already computed it.
if _f_min is not None:
f_min = _f_min
else:
# Use da.where to avoid boolean indexing (which creates unknown chunks).
positive_friction = da.where(friction_da > 0, friction_da, np.inf)
f_min = da.nanmin(positive_friction).compute()
if not np.isfinite(f_min) or f_min <= 0:
# All friction is non-positive or NaN — nothing reachable
return da.full(source_da.shape, np.nan, dtype=np.float32,
Expand Down
Loading