# NOTE: this is a unified git patch, not importable Python. Its line breaks
# appear to have been collapsed (each physical line below holds many patch
# lines), so the payload is left exactly as found. These '#' lines sit before
# the first "diff --git" header and are annotation only.
#
# Patch summary
#   xrspatial/terrain.py
#     - Adds imports: math, threading, and njit/prange from numba.
#     - _PARALLEL_KERNEL_LOCK: module-level threading.Lock; _gen_terrain_fast
#       holds it around the kernel launch.
#     - _fade_scalar(t): returns 6*t^5 - 15*t^4 + 10*t^3.
#     - _grad_scalar(h, x, y): uses the low two bits of h; bit 1 selects x or
#       y, bit 0 selects the sign.
#     - _perlin_point(p, x, y): floors x/y, masks the cell indices with 255,
#       looks up four corner hashes through p, and interpolates the four
#       gradient values with the faded fractional offsets.
#     - _fused_terrain_numpy(...): njit(parallel=True) kernel. prange over
#       rows; per pixel it optionally offsets (xb, yb) by warp noise, then
#       accumulates either the fbm or the ridged octave loop and stores
#       (acc / norm) ** 3 in out[i, j].
#     - _gen_terrain_fast(...): builds float32 linspace coordinates, stacks
#       permutation tables (seed + i per octave; seed + 100 + wi and
#       seed + 200 + wi for warp), allocates a float32 output array and runs
#       the kernel under _PARALLEL_KERNEL_LOCK. When warp is off it passes
#       zero-filled (1, 512) int32 placeholders for the warp tables.
#     - _gen_terrain: returns _gen_terrain_fast(...) when worley_blend <= 0.
#   xrspatial/tests/test_terrain.py
#     - _reference_terrain(...): recomputes the octave sum from _perlin and
#       _make_perm_table, divides by norm, cubes, clips to [-1, 1], maps to
#       [0, 1], zeroes values below 0.3 and multiplies by zfactor.
#     - test_fused_numpy_matches_reference (fbm, ridged) and
#       test_fused_numpy_matches_reference_nondefault_params compare
#       generate_terrain output with that reference (rtol=1e-4, atol=2e-2).
#
# NOTE(review): _gen_terrain_fast hands np.stack a list built from
#   range(octaves) and computes norm as sum(persistence ** i ...). With
#   octaves == 0 the list is empty and norm is 0; the same applies to
#   warp_octaves == 0 when warp_strength > 0. Confirm callers validate these
#   values before reaching this function.
# NOTE(review): the lock comment calls 'workqueue' the default threading
#   layer. Which layer Numba picks depends on what is installed -- confirm the
#   wording against the Numba threading-layer documentation.
# NOTE(review): the fast path uses height_map only for its shape and returns a
#   newly allocated array. The remainder of _gen_terrain is not visible in this
#   patch, so confirm callers consume the return value rather than the buffer
#   they passed in.
diff --git a/xrspatial/terrain.py b/xrspatial/terrain.py index c66b32f8a..ebd680378 100644 --- a/xrspatial/terrain.py +++ b/xrspatial/terrain.py @@ -1,12 +1,15 @@ from __future__ import annotations # std lib +import math +import threading from functools import partial from typing import Dict, List, Optional, Tuple, Union # 3rd-party import numpy as np import xarray as xr +from numba import njit, prange try: import cupy @@ -99,6 +102,154 @@ def _scale(value, old_range, new_range): # numpy backend # --------------------------------------------------------------------------- +# Fused, multithreaded reimplementation of the perlin octave loop. The plain +# numpy version in _gen_terrain calls _perlin once per octave, and each _perlin +# call allocates ~20 full-size temporaries (floor, gather, fade, lerp ...). At +# 16 octaves that is hundreds of array allocations. The kernel below collapses +# warp + fbm/ridged + cube into one per-pixel scalar computation parallelised +# over rows with prange, so there are no per-octave temporaries. Output matches +# the numpy path to float32 epsilon. Worley blending is not handled here (it +# needs a global min/max pass); _gen_terrain keeps the numpy path for that. + +# Numba parallel=True kernels must not be launched concurrently from multiple +# Python threads: the default 'workqueue' threading layer is not threadsafe and +# aborts the process (SIGABRT on macOS) when two host threads enter a parallel +# region at once. _terrain_dask_numpy calls _gen_terrain_fast per chunk under +# dask's threaded scheduler, so the kernel launch is serialized behind this +# lock. Same hazard and fix as the reproject kernels (#3141). 
+_PARALLEL_KERNEL_LOCK = threading.Lock() + + +@njit(nogil=True, inline='always') +def _fade_scalar(t): + t3 = t * t * t + t4 = t3 * t + t5 = t4 * t + return 6 * t5 - 15 * t4 + 10 * t3 + + +@njit(nogil=True, inline='always') +def _grad_scalar(h, x, y): + hv = h & 3 + sel = (hv >> 1) & 1 # 0 -> y axis, 1 -> x axis + u = x * sel + y * (1 - sel) + return u * (1 - 2 * (hv & 1)) + + +@njit(nogil=True, inline='always') +def _perlin_point(p, x, y): + xfl = np.float32(math.floor(x)) + yfl = np.float32(math.floor(y)) + xi = np.int32(xfl) & 255 + yi = np.int32(yfl) & 255 + xf = x - xfl + yf = y - yfl + u = _fade_scalar(xf) + v = _fade_scalar(yf) + one = np.float32(1.0) + n00 = _grad_scalar(p[p[xi] + yi], xf, yf) + n01 = _grad_scalar(p[p[xi] + yi + 1], xf, yf - one) + n11 = _grad_scalar(p[p[xi + 1] + yi + 1], xf - one, yf - one) + n10 = _grad_scalar(p[p[xi + 1] + yi], xf - one, yf) + x1 = n00 + u * (n10 - n00) + x2 = n01 + u * (n11 - n01) + return x1 + v * (x2 - x1) + + +@njit(parallel=True, nogil=True, cache=True) +def _fused_terrain_numpy(out, linx, liny, perms, octaves, + persistence, lacunarity, norm, is_ridged, + do_warp, warp_px, warp_py, warp_octaves, + warp_norm, warp_strength): + height = out.shape[0] + width = out.shape[1] + one = np.float32(1.0) + zero = np.float32(0.0) + for i in prange(height): + yb0 = liny[i] + for j in range(width): + xb = linx[j] + yb = yb0 + + # --- domain warping --- + if do_warp: + wx = zero + wy = zero + amp = one + freq = one + for wi in range(warp_octaves): + wx += _perlin_point(warp_px[wi], xb * freq, yb * freq) * amp + wy += _perlin_point(warp_py[wi], xb * freq, yb * freq) * amp + amp *= persistence + freq *= lacunarity + xb = xb + (wx / warp_norm) * warp_strength + yb = yb + (wy / warp_norm) * warp_strength + + # --- octave noise loop --- + acc = zero + amp = one + freq = one + if is_ridged: + weight = one + for o in range(octaves): + nval = _perlin_point(perms[o], xb * freq, yb * freq) + nval = one - abs(nval) + nval = nval * nval + 
+ nval = nval * weight + weight = min(max(nval, zero), one) + acc += nval * amp + amp *= persistence + freq *= lacunarity + else: # fbm + for o in range(octaves): + acc += _perlin_point(perms[o], xb * freq, yb * freq) * amp + amp *= persistence + freq *= lacunarity + + val = acc / norm + out[i, j] = val * val * val + + +def _gen_terrain_fast(height, width, seed, x_range, y_range, octaves, + lacunarity, persistence, noise_mode, + warp_strength, warp_octaves): + """Fast path for _gen_terrain when worley blending is off (the default).""" + linx = np.linspace( + x_range[0], x_range[1], width, endpoint=False, dtype=np.float32 + ) + liny = np.linspace( + y_range[0], y_range[1], height, endpoint=False, dtype=np.float32 + ) + + perms = np.stack([_make_perm_table(seed + i) for i in range(octaves)]) + norm = float(sum(persistence ** i for i in range(octaves))) + + do_warp = warp_strength > 0 + if do_warp: + warp_px = np.stack( + [_make_perm_table(seed + 100 + wi) for wi in range(warp_octaves)] + ) + warp_py = np.stack( + [_make_perm_table(seed + 200 + wi) for wi in range(warp_octaves)] + ) + warp_norm = float(sum(persistence ** i for i in range(warp_octaves))) + else: + # numba needs concrete array types even on the unused branch + warp_px = np.zeros((1, 512), dtype=np.int32) + warp_py = np.zeros((1, 512), dtype=np.int32) + warp_norm = 1.0 + + out = np.empty((height, width), dtype=np.float32) + with _PARALLEL_KERNEL_LOCK: + _fused_terrain_numpy( + out, linx, liny, perms, octaves, + np.float32(persistence), np.float32(lacunarity), np.float32(norm), + noise_mode == 'ridged', do_warp, warp_px, warp_py, + warp_octaves, np.float32(warp_norm), np.float32(warp_strength), + ) + return out + + def _gen_terrain(height_map, seed, x_range=(0, 1), y_range=(0, 1), octaves=16, lacunarity=2.0, persistence=0.5, noise_mode='fbm', warp_strength=0.0, warp_octaves=4, @@ -106,6 +257,16 @@ def _gen_terrain(height_map, seed, x_range=(0, 1), y_range=(0, 1), worley_norm_range=None): height, width = 
height_map.shape + # Fast path: when worley blending is off (the default), the whole warp + + # octave + cube pipeline runs in one fused, multithreaded kernel instead of + # the per-octave numpy loop below. Worley needs a global min/max pass, so + # it stays on the numpy path. + if worley_blend <= 0: + return _gen_terrain_fast( + height, width, seed, x_range, y_range, octaves, + lacunarity, persistence, noise_mode, warp_strength, warp_octaves, + ) + linx = np.linspace( x_range[0], x_range[1], width, endpoint=False, dtype=np.float32 ) diff --git a/xrspatial/tests/test_terrain.py b/xrspatial/tests/test_terrain.py index 797faa4d5..96f297dda 100644 --- a/xrspatial/tests/test_terrain.py +++ b/xrspatial/tests/test_terrain.py @@ -528,3 +528,76 @@ def test_terrain_all_nan_template_dask_cupy_matches_numpy(): assert np.isfinite(t_dc.data.get()).all() np.testing.assert_allclose(t_np.data, t_dc.data.get(), rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# Fused numpy fast path (worley off) -- independent correctness anchor +# --------------------------------------------------------------------------- + +def _reference_terrain(height, width, seed, noise_mode='fbm', octaves=16, + lacunarity=2.0, persistence=0.5, zfactor=4000): + """Recompute the worley-off terrain pipeline from the trusted, un-fused + _perlin building block. + + The fused numpy kernel and the dask+numpy path both go through + _gen_terrain, so dask-vs-numpy parity cannot catch a bug in the kernel + itself, and the GPU parity tests are skipped on CPU-only CI. This mirrors + _gen_terrain + _terrain_numpy with x_range/y_range scaled to (0, 1) (the + default full_extent) so it stays an independent reference. 
+ """ + from xrspatial.perlin import _make_perm_table, _perlin + + linx = np.linspace(0, 1, width, endpoint=False, dtype=np.float32) + liny = np.linspace(0, 1, height, endpoint=False, dtype=np.float32) + x, y = np.meshgrid(linx, liny) + + hm = np.zeros((height, width), dtype=np.float32) + norm = sum(persistence ** i for i in range(octaves)) + if noise_mode == 'ridged': + weight = np.ones((height, width), dtype=np.float32) + for i in range(octaves): + amp = persistence ** i + freq = lacunarity ** i + noise = _perlin(_make_perm_table(seed + i), x * freq, y * freq) + noise = 1.0 - np.abs(noise) + noise = noise * noise + noise *= weight + weight = np.clip(noise, 0, 1) + hm += noise * amp + else: + for i in range(octaves): + amp = persistence ** i + freq = lacunarity ** i + hm += _perlin(_make_perm_table(seed + i), x * freq, y * freq) * amp + + hm /= norm + hm = hm ** 3 + hm = np.clip(hm, -1, 1) + hm = (hm + 1) / 2 + hm[hm < 0.3] = 0 + hm *= zfactor + return hm + + +# rtol is the real drift guard; the loose atol only absorbs the ridged +# feedback path's ~6e-3 absolute gap on zfactor-scaled values. Do not +# tighten atol or the ridged case flakes. +_REF_RTOL = 1e-4 +_REF_ATOL = 2e-2 + + +@pytest.mark.parametrize('noise_mode', ['fbm', 'ridged']) +def test_fused_numpy_matches_reference(noise_mode): + data = xr.DataArray(np.zeros((60, 80), dtype=np.float32), dims=['y', 'x']) + out = generate_terrain(data, seed=10, noise_mode=noise_mode) + ref = _reference_terrain(60, 80, seed=10, noise_mode=noise_mode) + np.testing.assert_allclose(out.data, ref, rtol=_REF_RTOL, atol=_REF_ATOL) + + +def test_fused_numpy_matches_reference_nondefault_params(): + data = xr.DataArray(np.zeros((60, 80), dtype=np.float32), dims=['y', 'x']) + out = generate_terrain(data, seed=7, octaves=10, + lacunarity=2.3, persistence=0.45) + ref = _reference_terrain(60, 80, seed=7, octaves=10, + lacunarity=2.3, persistence=0.45) + np.testing.assert_allclose(out.data, ref, rtol=_REF_RTOL, atol=_REF_ATOL)