Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions xrspatial/terrain.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

# std lib
import math
import threading
from functools import partial
from typing import Dict, List, Optional, Tuple, Union

# 3rd-party
import numpy as np
import xarray as xr
from numba import njit, prange

try:
import cupy
Expand Down Expand Up @@ -99,13 +102,171 @@ def _scale(value, old_range, new_range):
# numpy backend
# ---------------------------------------------------------------------------

# Fused, multithreaded reimplementation of the perlin octave loop. The plain
# numpy version in _gen_terrain calls _perlin once per octave, and each _perlin
# call allocates ~20 full-size temporaries (floor, gather, fade, lerp ...). At
# 16 octaves that is hundreds of array allocations. The kernel below collapses
# warp + fbm/ridged + cube into one per-pixel scalar computation parallelised
# over rows with prange, so there are no per-octave temporaries. Output matches
# the numpy path to float32 epsilon. Worley blending is not handled here (it
# needs a global min/max pass); _gen_terrain keeps the numpy path for that.

# Numba parallel=True kernels must not be launched concurrently from multiple
# Python threads: the default 'workqueue' threading layer is not threadsafe and
# aborts the process (SIGABRT on macOS) when two host threads enter a parallel
# region at once. _terrain_dask_numpy calls _gen_terrain_fast per chunk under
# dask's threaded scheduler, so the kernel launch is serialized behind this
# lock. Same hazard and fix as the reproject kernels (#3141).
_PARALLEL_KERNEL_LOCK = threading.Lock()


@njit(nogil=True, inline='always')
def _fade_scalar(t):
t3 = t * t * t
t4 = t3 * t
t5 = t4 * t
return 6 * t5 - 15 * t4 + 10 * t3


@njit(nogil=True, inline='always')
def _grad_scalar(h, x, y):
hv = h & 3
sel = (hv >> 1) & 1 # 0 -> y axis, 1 -> x axis
u = x * sel + y * (1 - sel)
return u * (1 - 2 * (hv & 1))


@njit(nogil=True, inline='always')
def _perlin_point(p, x, y):
xfl = np.float32(math.floor(x))
yfl = np.float32(math.floor(y))
xi = np.int32(xfl) & 255
yi = np.int32(yfl) & 255
xf = x - xfl
yf = y - yfl
u = _fade_scalar(xf)
v = _fade_scalar(yf)
one = np.float32(1.0)
n00 = _grad_scalar(p[p[xi] + yi], xf, yf)
n01 = _grad_scalar(p[p[xi] + yi + 1], xf, yf - one)
n11 = _grad_scalar(p[p[xi + 1] + yi + 1], xf - one, yf - one)
n10 = _grad_scalar(p[p[xi + 1] + yi], xf - one, yf)
x1 = n00 + u * (n10 - n00)
x2 = n01 + u * (n11 - n01)
return x1 + v * (x2 - x1)


@njit(parallel=True, nogil=True, cache=True)
def _fused_terrain_numpy(out, linx, liny, perms, octaves,
persistence, lacunarity, norm, is_ridged,
do_warp, warp_px, warp_py, warp_octaves,
warp_norm, warp_strength):
height = out.shape[0]
width = out.shape[1]
one = np.float32(1.0)
zero = np.float32(0.0)
for i in prange(height):
yb0 = liny[i]
for j in range(width):
xb = linx[j]
yb = yb0

# --- domain warping ---
if do_warp:
wx = zero
wy = zero
amp = one
freq = one
for wi in range(warp_octaves):
wx += _perlin_point(warp_px[wi], xb * freq, yb * freq) * amp
wy += _perlin_point(warp_py[wi], xb * freq, yb * freq) * amp
amp *= persistence
freq *= lacunarity
xb = xb + (wx / warp_norm) * warp_strength
yb = yb + (wy / warp_norm) * warp_strength

# --- octave noise loop ---
acc = zero
amp = one
freq = one
if is_ridged:
weight = one
for o in range(octaves):
nval = _perlin_point(perms[o], xb * freq, yb * freq)
nval = one - abs(nval)
nval = nval * nval
nval = nval * weight
weight = min(max(nval, zero), one)
acc += nval * amp
amp *= persistence
freq *= lacunarity
else: # fbm
for o in range(octaves):
acc += _perlin_point(perms[o], xb * freq, yb * freq) * amp
amp *= persistence
freq *= lacunarity

val = acc / norm
out[i, j] = val * val * val


def _gen_terrain_fast(height, width, seed, x_range, y_range, octaves,
lacunarity, persistence, noise_mode,
warp_strength, warp_octaves):
"""Fast path for _gen_terrain when worley blending is off (the default)."""
linx = np.linspace(
x_range[0], x_range[1], width, endpoint=False, dtype=np.float32
)
liny = np.linspace(
y_range[0], y_range[1], height, endpoint=False, dtype=np.float32
)

perms = np.stack([_make_perm_table(seed + i) for i in range(octaves)])
norm = float(sum(persistence ** i for i in range(octaves)))

do_warp = warp_strength > 0
if do_warp:
warp_px = np.stack(
[_make_perm_table(seed + 100 + wi) for wi in range(warp_octaves)]
)
warp_py = np.stack(
[_make_perm_table(seed + 200 + wi) for wi in range(warp_octaves)]
)
warp_norm = float(sum(persistence ** i for i in range(warp_octaves)))
else:
# numba needs concrete array types even on the unused branch
warp_px = np.zeros((1, 512), dtype=np.int32)
warp_py = np.zeros((1, 512), dtype=np.int32)
warp_norm = 1.0

out = np.empty((height, width), dtype=np.float32)
with _PARALLEL_KERNEL_LOCK:
_fused_terrain_numpy(
out, linx, liny, perms, octaves,
np.float32(persistence), np.float32(lacunarity), np.float32(norm),
noise_mode == 'ridged', do_warp, warp_px, warp_py,
warp_octaves, np.float32(warp_norm), np.float32(warp_strength),
)
return out


def _gen_terrain(height_map, seed, x_range=(0, 1), y_range=(0, 1),
octaves=16, lacunarity=2.0, persistence=0.5,
noise_mode='fbm', warp_strength=0.0, warp_octaves=4,
worley_blend=0.0, worley_seed=None,
worley_norm_range=None):
height, width = height_map.shape

# Fast path: when worley blending is off (the default), the whole warp +
# octave + cube pipeline runs in one fused, multithreaded kernel instead of
# the per-octave numpy loop below. Worley needs a global min/max pass, so
# it stays on the numpy path.
if worley_blend <= 0:
return _gen_terrain_fast(
height, width, seed, x_range, y_range, octaves,
lacunarity, persistence, noise_mode, warp_strength, warp_octaves,
)

linx = np.linspace(
x_range[0], x_range[1], width, endpoint=False, dtype=np.float32
)
Expand Down
73 changes: 73 additions & 0 deletions xrspatial/tests/test_terrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,76 @@ def test_terrain_all_nan_template_dask_cupy_matches_numpy():
assert np.isfinite(t_dc.data.get()).all()
np.testing.assert_allclose(t_np.data, t_dc.data.get(),
rtol=1e-4, atol=1e-4)


# ---------------------------------------------------------------------------
# Fused numpy fast path (worley off) -- independent correctness anchor
# ---------------------------------------------------------------------------

def _reference_terrain(height, width, seed, noise_mode='fbm', octaves=16,
lacunarity=2.0, persistence=0.5, zfactor=4000):
"""Recompute the worley-off terrain pipeline from the trusted, un-fused
_perlin building block.

The fused numpy kernel and the dask+numpy path both go through
_gen_terrain, so dask-vs-numpy parity cannot catch a bug in the kernel
itself, and the GPU parity tests are skipped on CPU-only CI. This mirrors
_gen_terrain + _terrain_numpy with x_range/y_range scaled to (0, 1) (the
default full_extent) so it stays an independent reference.
"""
from xrspatial.perlin import _make_perm_table, _perlin

linx = np.linspace(0, 1, width, endpoint=False, dtype=np.float32)
liny = np.linspace(0, 1, height, endpoint=False, dtype=np.float32)
x, y = np.meshgrid(linx, liny)

hm = np.zeros((height, width), dtype=np.float32)
norm = sum(persistence ** i for i in range(octaves))
if noise_mode == 'ridged':
weight = np.ones((height, width), dtype=np.float32)
for i in range(octaves):
amp = persistence ** i
freq = lacunarity ** i
noise = _perlin(_make_perm_table(seed + i), x * freq, y * freq)
noise = 1.0 - np.abs(noise)
noise = noise * noise
noise *= weight
weight = np.clip(noise, 0, 1)
hm += noise * amp
else:
for i in range(octaves):
amp = persistence ** i
freq = lacunarity ** i
hm += _perlin(_make_perm_table(seed + i), x * freq, y * freq) * amp

hm /= norm
hm = hm ** 3
hm = np.clip(hm, -1, 1)
hm = (hm + 1) / 2
hm[hm < 0.3] = 0
hm *= zfactor
return hm


# rtol is the real drift guard; the loose atol only absorbs the ridged
# feedback path's ~6e-3 absolute gap on zfactor-scaled values. Do not
# tighten atol or the ridged case flakes.
_REF_RTOL = 1e-4
_REF_ATOL = 2e-2


@pytest.mark.parametrize('noise_mode', ['fbm', 'ridged'])
def test_fused_numpy_matches_reference(noise_mode):
data = xr.DataArray(np.zeros((60, 80), dtype=np.float32), dims=['y', 'x'])
out = generate_terrain(data, seed=10, noise_mode=noise_mode)
ref = _reference_terrain(60, 80, seed=10, noise_mode=noise_mode)
np.testing.assert_allclose(out.data, ref, rtol=_REF_RTOL, atol=_REF_ATOL)


def test_fused_numpy_matches_reference_nondefault_params():
data = xr.DataArray(np.zeros((60, 80), dtype=np.float32), dims=['y', 'x'])
out = generate_terrain(data, seed=7, octaves=10,
lacunarity=2.3, persistence=0.45)
ref = _reference_terrain(60, 80, seed=7, octaves=10,
lacunarity=2.3, persistence=0.45)
np.testing.assert_allclose(out.data, ref, rtol=_REF_RTOL, atol=_REF_ATOL)
Loading