diff --git a/docs/source/reference/templates.rst b/docs/source/reference/templates.rst index 1b8f8c57c..43bc55cbb 100644 --- a/docs/source/reference/templates.rst +++ b/docs/source/reference/templates.rst @@ -25,3 +25,28 @@ From Template xrspatial.templates.from_template xrspatial.templates.list_templates + +Putting your data on a template +=============================== + +A template is an empty canvas; ``coregister`` fills it with your own data. It +reprojects a raster :class:`~xarray.DataArray` onto the template's grid, or +rasterizes a ``GeoDataFrame`` onto it, so every layer lines up cell-for-cell. + +.. sourcecode:: python + + from xrspatial import from_template + + grid = from_template("conus", resolution=1000) # empty Albers grid + elevation = grid.xrs.coregister(my_dem) # raster -> grid + roads = grid.xrs.coregister(my_roads_gdf) # vectors -> grid + slope = elevation.xrs.slope() + +For an out-of-core run, ask for a dask backend. The template tiles into even, +square blocks tuned for the neighborhood ops that follow, so the downstream +task graph stays parallel and overlap-friendly: + +.. sourcecode:: python + + grid = from_template("conus", resolution=250, backend="dask") + slope = grid.xrs.coregister(my_dem).xrs.slope() # stays lazy diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index cff20a752..78b88702b 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -850,6 +850,79 @@ def _rasterize_on_accessor(obj, geometries, *, coregister, **kwargs): return rasterize(geometries, like=obj, **kwargs) +def _coregister_on_accessor(obj, data, *, resampling, **kwargs): + """Put *data* on caller raster *obj*'s grid (CRS + bounds + shape). + + Dispatches on the input type: + + - A GeoDataFrame is rasterized onto the grid, reprojected into the caller + CRS first -- the same path as ``rasterize(coregister=True)``. The + ``resampling`` argument does not apply to a burn and is ignored. + - A raster ``DataArray`` is reprojected and resampled onto the grid. + + Either way the result shares *obj*'s ``y``/``x`` coordinates exactly, so it + lines up cell-for-cell with the template and any other layer coregistered + the same way. + """ + from xrspatial.interpolate._vector import is_geodataframe + + if is_geodataframe(data): + return _rasterize_on_accessor(obj, data, coregister=True, **kwargs) + if isinstance(data, xr.DataArray): + return _reproject_onto_accessor(obj, data, resampling=resampling, + **kwargs) + raise TypeError( + "coregister expects a raster xarray.DataArray or a GeoDataFrame, " + f"got {type(data).__name__}" + ) + + +def _reproject_onto_accessor(target, source, *, resampling, **kwargs): + """Reproject raster *source* onto *target*'s exact grid.""" + import numpy as np + + from .reproject import reproject + from .rasterize import _like_crs + + target_crs = _like_crs(target) + if target_crs is None: + raise ValueError( + "coregister of a raster needs a CRS on the caller grid " + "(its detected raster CRS, e.g. attrs['crs'])" + ) + y = np.asarray(target['y'].values, dtype='float64') + x = np.asarray(target['x'].values, dtype='float64') + if y.size < 2 or x.size < 2: + raise ValueError( + "coregister needs a caller grid of at least 2x2 cells" + ) + res = target.attrs.get('res') + if res is not None: + res_x, res_y = abs(float(res[0])), abs(float(res[1])) + else: + res_x = abs(float(x[1] - x[0])) + res_y = abs(float(y[1] - y[0])) + bounds = (float(x.min()) - res_x / 2.0, float(y.min()) - res_y / 2.0, + float(x.max()) + res_x / 2.0, float(y.max()) + res_y / 2.0) + out = reproject(source, target_crs, bounds=bounds, + width=x.size, height=y.size, resampling=resampling, + **kwargs) + # reproject emits north-up (descending y, ascending x) regardless of the + # source order. Match the caller's axis directions before snapping so a + # grid stored in either order lines up by geography, not by row index. + if x[0] > x[-1]: + out = out.isel(x=slice(None, None, -1)) + if y[0] < y[-1]: + out = out.isel(y=slice(None, None, -1)) + # Snap to the caller's exact coordinates and carry its CRS attrs so the + # result is a cell-for-cell drop-in for the template grid. + out = out.assign_coords(y=target['y'], x=target['x']) + for k in ('crs', 'crs_wkt', 'grid_mapping_name'): + if k in target.attrs: + out.attrs[k] = target.attrs[k] + return out + + @xr.register_dataarray_accessor("xrs") class XrsSpatialDataArrayAccessor: """DataArray accessor exposing xarray-spatial operations.""" @@ -1441,6 +1514,49 @@ def rasterize(self, geometries, *, coregister=False, **kwargs): return _rasterize_on_accessor(self._obj, geometries, coregister=coregister, **kwargs) + def coregister(self, data, *, resampling='bilinear', **kwargs): + """Put *data* on this raster's grid, ready to analyze. + + One call to land another layer on this grid, whatever form it takes: + + - A raster ``xarray.DataArray`` is reprojected and resampled from its + own CRS onto this grid's CRS, bounds, and shape. + - A ``GeoDataFrame`` is burned onto the grid with + :meth:`rasterize`, reprojected into this CRS first (the + ``coregister=True`` path). + + The result shares this raster's ``y``/``x`` coordinates exactly, so it + stacks cell-for-cell with the template and with anything else + coregistered onto it. Pair it with :func:`xrspatial.from_template` to + go from a region name to an analysis-ready grid:: + + grid = from_template('conus') + elevation = grid.xrs.coregister(my_dem) # raster -> grid + roads = grid.xrs.coregister(my_roads_gdf) # vectors -> grid + slope = elevation.xrs.slope() + + Parameters + ---------- + data : xarray.DataArray or geopandas.GeoDataFrame + The layer to place on this grid. A raster DataArray needs a + detectable CRS (``attrs['crs']`` or ``crs_wkt``); a GeoDataFrame + needs ``.crs``. + resampling : str, default 'bilinear' + Resampling method for the raster path, forwarded to + :func:`xrspatial.reproject.reproject` (e.g. ``'nearest'`` for + categorical data). Ignored when *data* is a GeoDataFrame. + **kwargs + Forwarded to :func:`~xrspatial.reproject.reproject` (raster input) + or :func:`~xrspatial.rasterize.rasterize` (GeoDataFrame input). + + Returns + ------- + xarray.DataArray + *data* on this raster's grid. + """ + return _coregister_on_accessor(self._obj, data, resampling=resampling, + **kwargs) + # ---- GeoTIFF I/O ---- def to_geotiff(self, path, **kwargs): diff --git a/xrspatial/templates.py b/xrspatial/templates.py index bcd9f9b03..177dc4efa 100644 --- a/xrspatial/templates.py +++ b/xrspatial/templates.py @@ -36,6 +36,46 @@ # ~26k blocks, so 1e6 leaves wide headroom while catching the runaway case. _MAX_CHUNKS = 1_000_000 +# Target block edge (cells) for the default dask tiling. dask's byte-based +# 'auto' picks one chunk shape from the array's total bytes, which leaves most +# templates as a single giant block (no parallelism) or with thin ragged edge +# slivers. A fixed square block tiles evenly and keeps each chunk friendly to +# the neighborhood ops (slope, focal, ...) that run on the result via +# map_overlap. 2048 is ~16 MB at float32: small enough to parallelize a +# moderate grid, large enough that overlap halos stay cheap. +_DASK_BLOCK = 2048 + +# Ceiling on the block count for the default tiling. A 2048-cell block would +# explode the graph at a typo-level fine resolution, so for very large grids the +# block edge grows to keep the count near this many blocks. That keeps the +# default 'auto' path always under _MAX_CHUNKS (it never errors on its own), and +# leaves any grid up to ~8e11 cells on the plain 2048 block. +_DASK_MAX_BLOCKS = 200_000 + + +def _balanced_axis(size, block): + """Split ``size`` into near-equal blocks of about ``block`` cells. + + Returns a tuple of chunk lengths that sum to ``size`` and differ by at most + one, so there is no tiny trailing sliver. A dimension at or below ~1.5x the + block stays a single chunk. + """ + n = max(1, round(size / block)) + base, rem = divmod(size, n) + return tuple(base + 1 if i < rem else base for i in range(n)) + + +def _neighborhood_chunks(shape): + """Even, square-ish dask chunks for the default tiling. + + Uses a ~``_DASK_BLOCK`` block, growing it for very large grids so the block + count stays near ``_DASK_MAX_BLOCKS`` instead of exploding the task graph. + """ + import math + cells = shape[0] * shape[1] + block = max(_DASK_BLOCK, math.ceil(math.sqrt(cells / _DASK_MAX_BLOCKS))) + return tuple(_balanced_axis(size, block) for size in shape) + def _resolve(name): """Resolve a name to a spec dict. @@ -338,12 +378,18 @@ def from_template(name: str, Dask chunk specification. Supplying it returns a lazy, chunked grid: an eager backend is promoted to its dask variant (``'numpy'`` to ``'dask+numpy'``, ``'cupy'`` to ``'dask+cupy'``), and the cell cap no - longer applies. When omitted, the dask backends use ``'auto'``. The - data stays lazy, but a very fine resolution still builds one task per - chunk, so an extreme shape with small chunks can make a task graph - large enough to bog down the client. To prevent that, a grid that would - split into more than 1,000,000 chunks raises ``ValueError``; coarsen the - resolution or use larger chunks. + longer applies. When omitted (or ``'auto'``), the dask backends tile + the grid into even, square-ish blocks (~2048 cells per side) tuned for + the neighborhood ops -- ``slope``, ``hillshade``, ``focal`` -- that run + on the result through ``map_overlap``. That gives a parallel, + well-formed task graph instead of one giant block or thin ragged edges. + A grid small enough to not be worth splitting stays a single chunk. + Pass an explicit value to override the tiling. The data stays lazy, but + a very fine resolution still builds one task per chunk, so an extreme + shape with small explicit chunks can make a task graph large enough to + bog down the client; a grid that would split into more than 1,000,000 + chunks raises ``ValueError``. The default tiling grows its block for + such grids so it never trips that cap on its own. Returns ------- @@ -393,6 +439,29 @@ def from_template(name: str, >>> agg = from_template("new_england", resolution=10, chunks=512) >>> type(agg.data).__name__ 'Array' + + Recipes + ------- + Pick a grid, drop your own data onto it with + :meth:`DataArray.xrs.coregister `, + then run any tool. ``coregister`` reprojects a raster or rasterizes a + GeoDataFrame onto the template's exact grid, so layers line up + cell-for-cell. + + .. sourcecode:: python + + >>> grid = from_template("conus", resolution=1000) + >>> elevation = grid.xrs.coregister(my_dem) # raster -> grid + >>> roads = grid.xrs.coregister(my_roads_gdf) # vectors -> grid + >>> slope = elevation.xrs.slope() + + For an out-of-core workflow, ask for a dask backend; the default tiling + keeps the downstream graph parallel and overlap-friendly: + + .. sourcecode:: python + + >>> grid = from_template("conus", resolution=250, backend="dask") + >>> slope = grid.xrs.coregister(my_dem).xrs.slope() # stays lazy """ spec = _resolve(name) key = spec["key"] @@ -436,14 +505,25 @@ def from_template(name: str, f"Use a coarser resolution, or pass chunks=... for a lazy dask grid." ) - effective_chunks = "auto" if chunks is None else chunks + # 'auto' (the default) tiles into even square blocks tuned for the + # neighborhood ops that run on the result; an explicit chunks= is honored + # verbatim. Either way the block count is checked against _MAX_CHUNKS below. + if chunks is None or chunks == "auto": + effective_chunks = _neighborhood_chunks((height, width)) if is_dask \ + else "auto" + else: + effective_chunks = chunks if is_dask: n_chunks = _estimate_n_chunks((height, width), effective_chunks) if n_chunks > _MAX_CHUNKS: + # Report the request, not the expanded per-block tuple the default + # tiling produces, so the message stays readable. + chunks_display = chunks if chunks not in (None, "auto") \ + else f"the default ~{_DASK_BLOCK}-cell tiling" raise ValueError( f"resolution {(res_x, res_y)} produces a {height} x {width} " f"grid that splits into {n_chunks:,} chunks with " - f"chunks={effective_chunks!r}, exceeding the " + f"{chunks_display!r}, exceeding the " f"{_MAX_CHUNKS:,}-chunk limit. A task graph this large can bog " f"down the client even though no data is computed. Use a " f"coarser resolution or larger chunks." diff --git a/xrspatial/tests/test_accessor.py b/xrspatial/tests/test_accessor.py index 2901cef88..72e763b49 100644 --- a/xrspatial/tests/test_accessor.py +++ b/xrspatial/tests/test_accessor.py @@ -97,6 +97,7 @@ def test_dataarray_accessor_has_expected_methods(elevation): 'generate_terrain', 'perlin', 'ndvi', 'evi', 'arvi', 'savi', 'nbr', 'sipi', 'rasterize', + 'coregister', 'rechunk_no_shuffle', 'fused_overlap', 'multi_overlap', @@ -703,3 +704,114 @@ def test_catalog_repr_still_works_with_proxy(elevation): assert 'slope' in text # _repr_html_ on the accessor itself still renders the table. assert '' in elevation.xrs._repr_html_() + + +# --------------------------------------------------------------------------- +# 12. coregister — put your own data on a template grid (#3561) +# --------------------------------------------------------------------------- + +def _conus_grid(resolution=20000): + """Small CONUS template (EPSG:5070) to coregister onto.""" + from xrspatial import from_template + return from_template('conus', resolution=resolution) + + +def _latlon_constant(value=7.0, shape=(50, 70)): + """A constant-valued EPSG:4326 raster covering the lower 48.""" + ys = np.linspace(50, 24, shape[0]) + xs = np.linspace(-125, -66, shape[1]) + return xr.DataArray( + np.full(shape, value, dtype='float32'), + dims=['y', 'x'], coords={'y': ys, 'x': xs}, attrs={'crs': 4326}, + ) + + +def test_coregister_raster_lands_on_template_grid(): + pytest.importorskip('pyproj') + grid = _conus_grid() + out = grid.xrs.coregister(_latlon_constant(value=7.0)) + # exact grid match: same shape, same coordinates, template CRS carried over + assert out.shape == grid.shape + assert np.array_equal(out['x'].values, grid['x'].values) + assert np.array_equal(out['y'].values, grid['y'].values) + assert out.attrs['crs'] == grid.attrs['crs'] == 5070 + # a constant field stays constant through the bilinear resample + interior = out.values[np.isfinite(out.values)] + assert interior.size > 0 + np.testing.assert_allclose(interior, 7.0, atol=1e-4) + + +def test_coregister_vector_matches_rasterize_coregister(): + pytest.importorskip('pyproj') + gpd = pytest.importorskip('geopandas') + from shapely.geometry import box + grid = _conus_grid() + gdf = gpd.GeoDataFrame({'v': [1]}, geometry=[box(-110, 35, -95, 45)], crs=4326) + out = grid.xrs.coregister(gdf) + direct = grid.xrs.rasterize(gdf, coregister=True) + assert np.array_equal(np.nan_to_num(out.values), np.nan_to_num(direct.values)) + assert out.shape == grid.shape + + +def test_coregister_rejects_unsupported_type(): + grid = _conus_grid() + with pytest.raises(TypeError, match='coregister expects'): + grid.xrs.coregister(123) + + +def test_coregister_raster_requires_target_crs(): + pytest.importorskip('pyproj') + # a caller grid with no detectable CRS cannot anchor the reprojection + grid = xr.DataArray( + np.zeros((10, 10), dtype='float32'), dims=['y', 'x'], + coords={'y': np.arange(10, dtype=float), 'x': np.arange(10, dtype=float)}, + ) + with pytest.raises(ValueError, match='needs a CRS'): + grid.xrs.coregister(_latlon_constant(shape=(10, 10))) + + +def test_coregister_raster_nearest_resampling_runs(): + pytest.importorskip('pyproj') + grid = _conus_grid() + out = grid.xrs.coregister(_latlon_constant(value=3.0), resampling='nearest') + assert out.shape == grid.shape + interior = out.values[np.isfinite(out.values)] + # nearest keeps the exact source value (no interpolation blur) + assert set(np.unique(interior)).issubset({3.0}) + + +def test_coregister_raster_handles_ascending_y_target(): + pytest.importorskip('pyproj') + # A caller grid stored with ascending y must still line up by geography: + # the north (large-y) row should hold the northern source value, not get + # silently flipped by the coordinate snap. + ys = np.linspace(2.7e5, 3.2e6, 40) # ascending y (EPSG:5070 metres) + xs = np.linspace(-2.3e6, 2.3e6, 60) + grid = xr.DataArray( + np.zeros((40, 60), dtype='float32'), dims=['y', 'x'], + coords={'y': ys, 'x': xs}, attrs={'crs': 5070, 'res': (76667.0, 76250.0)}, + ) + # source value == latitude, so larger value is further north + sys = np.linspace(50, 24, 50) + sxs = np.linspace(-125, -66, 70) + src = xr.DataArray( + np.repeat(sys[:, None], 70, axis=1).astype('float32'), + dims=['y', 'x'], coords={'y': sys, 'x': sxs}, attrs={'crs': 4326}, + ) + out = grid.xrs.coregister(src) + assert np.array_equal(out['y'].values, ys) + col = out.values[:, out.shape[1] // 2] + fin = col[np.isfinite(col)] + # ascending y: last row is the northernmost, so it holds the larger latitude + assert fin[-1] > fin[0] + + +def test_coregister_raster_preserves_dask_backend(): + pytest.importorskip('pyproj') + da = pytest.importorskip('dask.array') + grid = _conus_grid() + src = _latlon_constant(value=2.0) + src.data = da.from_array(src.data, chunks=(25, 35)) + out = grid.xrs.coregister(src) + assert isinstance(out.data, da.Array) + assert out.shape == grid.shape diff --git a/xrspatial/tests/test_templates.py b/xrspatial/tests/test_templates.py index 110a87f14..7bac23ad0 100644 --- a/xrspatial/tests/test_templates.py +++ b/xrspatial/tests/test_templates.py @@ -466,6 +466,52 @@ def test_legit_large_dask_grid_passes(): assert agg.data.npartitions < 1_000_000 +@dask_array_available +def test_default_dask_chunks_are_balanced_square_blocks(): + import dask.array as da + from xrspatial.templates import _DASK_BLOCK + # The default 'auto' path tiles into even, square-ish blocks (no thin edge + # slivers) sized near _DASK_BLOCK, so downstream map_overlap ops parallelize + # cleanly. conus @ 1 km is bigger than one block, so it must split. + agg = from_template("conus", resolution=1000, backend="dask") + assert isinstance(agg.data, da.Array) + assert agg.data.npartitions > 1 + for axis in agg.data.chunks: + # balanced: every block within one cell of the others, none tiny + assert max(axis) - min(axis) <= 1 + assert min(axis) > _DASK_BLOCK // 2 + + +@dask_array_available +def test_small_dask_template_stays_one_chunk(): + # A grid at or below ~1.5x the block edge is not worth splitting; it comes + # back as a single chunk just like before, so tiny templates keep zero + # task-graph overhead. + agg = from_template("conus", resolution=5000, backend="dask") + assert agg.data.npartitions == 1 + + +@dask_array_available +def test_default_tiling_block_grows_for_huge_grids(): + # At a typo-level fine resolution a fixed 2048 block would explode the graph. + # The default block edge grows instead, keeping the count under the cap so + # 'auto' never trips the guard on its own (the #3557 contract). + from xrspatial.templates import _MAX_CHUNKS, _DASK_BLOCK + agg = from_template("conus", resolution=1, backend="dask") + assert agg.data.npartitions <= _MAX_CHUNKS + # the block grew well past the nominal edge + assert agg.data.chunks[0][0] > _DASK_BLOCK + + +@dask_array_available +def test_explicit_chunks_bypass_default_tiling(): + # An explicit chunks= is honored verbatim, not replaced by the default + # tiling, so callers keep full control. + agg = from_template("conus", resolution=1000, chunks=512) + assert agg.data.chunks[0][0] == 512 + assert agg.data.chunks[1][0] == 512 + + def test_single_pixel_grid(): # a resolution coarser than the whole study-area box clamps width and height # to the max(1, ...) floor, giving a 1x1 grid that still obeys the contract.