From 565395f51ed2845de115a4c23b5620711afb2b90 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 11 May 2026 17:36:37 -0700 Subject: [PATCH 1/2] Fix _coords_to_transform for 3D (y,x,band) DataArrays (#1643) _coords_to_transform read y/x coords via dims[-2:] which on a 3D (y, x, band) DataArray picked (x, band) instead of (y, x). to_geotiff and write_geotiff_gpu silently emitted a wrong GeoTransform on the fallback path when attrs['transform'] was absent (the round-tripped file used the band axis spacing as pixel_width). The helper now skips any trailing/leading dim named band/bands/channel and uses the two remaining spatial dims. 2D inputs and 3D (band, y, x) inputs are both handled. --- xrspatial/geotiff/__init__.py | 29 ++- .../tests/test_coords_to_transform_3d_1643.py | 187 ++++++++++++++++++ 2 files changed, 214 insertions(+), 2 deletions(-) create mode 100644 xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 8495101b..693ccc0a 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -191,9 +191,34 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: on raster_type: - PixelIsArea (default): origin = center - half_pixel (edge of pixel 0) - PixelIsPoint: origin = center (center of pixel 0) + + For 3D arrays the spatial dims are the two non-band dims. The helper + skips any trailing/leading dim named ``band`` / ``bands`` / ``channel`` + so a ``(y, x, band)`` or ``(band, y, x)`` DataArray returns the y/x + transform rather than picking up the band axis spacing as a pixel + size. ``to_geotiff`` itself remaps ``(band, y, x)`` arrays to + ``(y, x, band)`` before writing pixel bytes, but it calls + :func:`_coords_to_transform` against the original DataArray, so the + helper must handle both layouts to keep the geo-transform consistent + with the file's coord arrays. See issue #1643. """ - ydim = da.dims[-2] - xdim = da.dims[-1] + _BAND_DIM_NAMES = ('band', 'bands', 'channel') + if da.ndim == 3: + # Drop the band-like dim and keep the two spatial dims in their + # original (y, x) order. Position-based fallback covers the case + # where none of the dims are named like a band axis. + spatial = tuple(d for d in da.dims if d not in _BAND_DIM_NAMES) + if len(spatial) == 2: + ydim, xdim = spatial[0], spatial[1] + else: + # No identifiable band dim; fall back to dims[-2:] so the + # original 2-D-style behaviour applies. This branch only + # triggers for unusual 3D layouts callers built by hand. + ydim = da.dims[-2] + xdim = da.dims[-1] + else: + ydim = da.dims[-2] + xdim = da.dims[-1] if xdim not in da.coords or ydim not in da.coords: return None diff --git a/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py new file mode 100644 index 00000000..4c2d951a --- /dev/null +++ b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py @@ -0,0 +1,187 @@ +"""Regression test for issue #1643. + +``_coords_to_transform`` previously used ``dims[-2]`` and ``dims[-1]`` to +look up y/x coords. On a 3D ``(y, x, band)`` DataArray that picked +``x`` and ``band``, so ``to_geotiff`` silently wrote a wrong +GeoTransform when ``attrs['transform']`` was absent. The helper now +detects the band-like trailing/leading dim and uses the two spatial +dims regardless of position. +""" +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import _coords_to_transform, open_geotiff, to_geotiff + +try: + import cupy # noqa: F401 + HAS_CUPY = True +except ImportError: + HAS_CUPY = False + + +def _make_geo_da_3d(dims): + """3D DataArray with georeferenced y/x coords and a band axis.""" + shape = [] + for d in dims: + if d in ('y',): + shape.append(10) + elif d in ('x',): + shape.append(20) + else: + shape.append(3) + arr = np.arange(int(np.prod(shape)), dtype=np.uint8).reshape(shape) + coords = { + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + 'band': np.arange(3), + } + return xr.DataArray(arr, dims=list(dims), coords=coords) + + +def test_coords_to_transform_yxband_returns_yx_spacing(): + """3D (y, x, band) picks y/x spacing rather than (x, band) spacing.""" + da = _make_geo_da_3d(('y', 'x', 'band')) + gt = _coords_to_transform(da) + # y spacing = (200 - 100) / 9, x spacing = (700 - 500) / 19 + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_coords_to_transform_bandyx_returns_yx_spacing(): + """3D (band, y, x) also returns the y/x transform.""" + da = _make_geo_da_3d(('band', 'y', 'x')) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_coords_to_transform_2d_unchanged(): + """2D (y, x) keeps its original behaviour.""" + da = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_to_geotiff_roundtrip_3d_yxband_no_transform_attr(tmp_path): + """to_geotiff -> open_geotiff round-trip on 3D arrays preserves coords. + + Before the fix the on-disk transform was derived from (x, band) + spacing, so the round-tripped y/x coords had wrong pixel size and + origin. After the fix the 3D output matches the 2D output. + """ + da_3d = _make_geo_da_3d(('y', 'x', 'band')) + da_2d = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + + p2 = str(tmp_path / 'roundtrip_1643_2d.tif') + p3 = str(tmp_path / 'roundtrip_1643_3d.tif') + to_geotiff(da_2d, p2) + to_geotiff(da_3d, p3) + + rt2 = open_geotiff(p2) + rt3 = open_geotiff(p3) + np.testing.assert_allclose(rt3.y.values, rt2.y.values) + np.testing.assert_allclose(rt3.x.values, rt2.x.values) + assert rt3.attrs.get('transform') == rt2.attrs.get('transform') + + +def test_to_geotiff_roundtrip_3d_bandyx_no_transform_attr(tmp_path): + """(band, y, x) input round-trips with the correct transform. + + ``to_geotiff`` remaps a (band, y, x) input to (y, x, band) before + writing, but ``_coords_to_transform`` runs against the original + dim order. The fix handles both 3D layouts. + """ + da_3d = _make_geo_da_3d(('band', 'y', 'x')) + da_2d = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + + p2 = str(tmp_path / 'roundtrip_1643_2d_b.tif') + p3 = str(tmp_path / 'roundtrip_1643_3d_bandfirst.tif') + to_geotiff(da_2d, p2) + to_geotiff(da_3d, p3) + + rt2 = open_geotiff(p2) + rt3 = open_geotiff(p3) + np.testing.assert_allclose(rt3.y.values, rt2.y.values) + np.testing.assert_allclose(rt3.x.values, rt2.x.values) + + +def test_to_geotiff_3d_without_transform_attr_does_not_invent_unit_pixels( + tmp_path): + """Regression sanity: the bad transform was pixel_width=1.0 (band + axis spacing). Assert the round-tripped pixel_width is finite, + non-unit, and matches the source x spacing. + """ + da = _make_geo_da_3d(('y', 'x', 'band')) + p = str(tmp_path / 'roundtrip_1643_3d_not_unit.tif') + to_geotiff(da, p) + rt = open_geotiff(p) + pw = abs(float(rt.x.values[1] - rt.x.values[0])) + # Source x spacing is (700-500)/19 = ~10.526. The buggy path would + # have produced pw=1.0 (the band axis spacing). + assert pw > 1.5, ( + f"round-tripped pixel_width={pw} suggests the band-axis spacing " + f"leaked into the GeoTransform; expected ~10.526") + + +@pytest.mark.skipif(not HAS_CUPY, reason="cupy not available") +def test_write_geotiff_gpu_roundtrip_3d_no_transform_attr(tmp_path): + """GPU writer shares ``_coords_to_transform`` with the CPU writer. + + Same regression on the GPU path: a 3D ``(y, x, band)`` cupy + DataArray without ``attrs['transform']`` would previously round-trip + through a unit pixel-width transform. + """ + import cupy as cp + + from xrspatial.geotiff import write_geotiff_gpu + + np_arr = np.arange(10 * 20 * 3, dtype=np.uint8).reshape(10, 20, 3) + da = xr.DataArray( + cp.asarray(np_arr), + dims=['y', 'x', 'band'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + 'band': np.arange(3), + }, + ) + p = str(tmp_path / 'roundtrip_1643_3d_gpu.tif') + write_geotiff_gpu(da, p) + rt = open_geotiff(p) + pw = abs(float(rt.x.values[1] - rt.x.values[0])) + assert pw > 1.5, ( + f"GPU writer round-tripped pixel_width={pw}; expected ~10.526") + ph = abs(float(rt.y.values[1] - rt.y.values[0])) + assert ph > 1.5, ( + f"GPU writer round-tripped pixel_height={ph}; expected ~11.111") From bd84f70d187cb1ebed70530fdb65abd7de78e542 Mon Sep 17 00:00:00 2001 From: Brendan Collins Date: Mon, 11 May 2026 17:45:17 -0700 Subject: [PATCH 2/2] Address Copilot review feedback on #1648 - Lift _BAND_DIM_NAMES to module scope and reuse at the three (band,y,x) remap sites in __init__.py to avoid drift between _coords_to_transform and the writer paths. - Reword _coords_to_transform docstring: filter is position-independent, not trailing/leading. - Drop unused os/tempfile imports from the regression test. - Replace `import cupy` guard with the repo's standard _gpu_available() pattern that also checks `cupy.cuda.is_available()` and swallows non-ImportError import failures. - Add parametrized helper coverage for 'bands' and 'channel' dim names. --- xrspatial/geotiff/__init__.py | 18 +++++--- .../tests/test_coords_to_transform_3d_1643.py | 42 +++++++++++++++---- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 693ccc0a..59b9ac37 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -68,6 +68,12 @@ _GPU_DEPRECATED_SENTINEL = object() _ON_GPU_FAILURE_SENTINEL = object() +# Names of dims that ``to_geotiff`` / ``write_geotiff_gpu`` treat as the +# non-spatial band axis. Used both to remap ``(band, y, x)`` inputs to +# ``(y, x, band)`` before writing and to skip the band axis when inferring +# a GeoTransform from coords (see ``_coords_to_transform`` and issue #1643). +_BAND_DIM_NAMES = ('band', 'bands', 'channel') + def _wkt_to_epsg(wkt_or_proj: str) -> int | None: """Try to extract an EPSG code from a WKT or PROJ string. @@ -193,8 +199,9 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: - PixelIsPoint: origin = center (center of pixel 0) For 3D arrays the spatial dims are the two non-band dims. The helper - skips any trailing/leading dim named ``band`` / ``bands`` / ``channel`` - so a ``(y, x, band)`` or ``(band, y, x)`` DataArray returns the y/x + filters out any dim named ``band`` / ``bands`` / ``channel`` (see + ``_BAND_DIM_NAMES``) regardless of position, so a ``(y, x, band)``, + ``(band, y, x)``, or ``(y, band, x)`` DataArray returns the y/x transform rather than picking up the band axis spacing as a pixel size. ``to_geotiff`` itself remaps ``(band, y, x)`` arrays to ``(y, x, band)`` before writing pixel bytes, but it calls @@ -202,7 +209,6 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: helper must handle both layouts to keep the geo-transform consistent with the file's coord arrays. See issue #1643. """ - _BAND_DIM_NAMES = ('band', 'bands', 'channel') if da.ndim == 3: # Drop the band-like dim and keep the two spatial dims in their # original (y, x) order. Position-based fallback covers the case @@ -1191,7 +1197,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *, if hasattr(raw, 'dask') and not cog and not _path_is_file_like: dask_arr = raw # Handle band-first dimension order (band, y, x) -> (y, x, band) - if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if raw.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: import dask.array as da dask_arr = da.moveaxis(raw, 0, -1) if dask_arr.ndim not in (2, 3): @@ -1240,7 +1246,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *, else: arr = np.asarray(raw) # Handle band-first dimension order (band, y, x) -> (y, x, band) - if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: arr = np.moveaxis(arr, 0, -1) else: if hasattr(data, 'get'): @@ -2855,7 +2861,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, # this remap the writer treats arr.shape[2] as the band axis and # produces a transposed file (issue #1580). The CPU writer does # the same remap at the matching step in to_geotiff(). - if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: arr = cupy.ascontiguousarray(cupy.moveaxis(arr, 0, -1)) # Prefer attrs['transform'] over the coord-derived transform: it diff --git a/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py index 4c2d951a..c7cbb721 100644 --- a/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py +++ b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py @@ -9,8 +9,7 @@ """ from __future__ import annotations -import os -import tempfile +import importlib.util import numpy as np import pytest @@ -18,11 +17,18 @@ from xrspatial.geotiff import _coords_to_transform, open_geotiff, to_geotiff -try: - import cupy # noqa: F401 - HAS_CUPY = True -except ImportError: - HAS_CUPY = False + +def _gpu_available() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +_HAS_GPU = _gpu_available() def _make_geo_da_3d(dims): @@ -63,6 +69,26 @@ def test_coords_to_transform_bandyx_returns_yx_spacing(): np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) +@pytest.mark.parametrize('band_name', ['band', 'bands', 'channel']) +def test_coords_to_transform_3d_band_name_variants(band_name): + """All recognized band-dim names (band, bands, channel) are filtered + out when picking the y/x spatial dims.""" + arr = np.zeros((10, 20, 3), dtype=np.uint8) + da = xr.DataArray( + arr, + dims=['y', 'x', band_name], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + band_name: np.arange(3), + }, + ) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + def test_coords_to_transform_2d_unchanged(): """2D (y, x) keeps its original behaviour.""" da = xr.DataArray( @@ -154,7 +180,7 @@ def test_to_geotiff_3d_without_transform_attr_does_not_invent_unit_pixels( f"leaked into the GeoTransform; expected ~10.526") -@pytest.mark.skipif(not HAS_CUPY, reason="cupy not available") +@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required") def test_write_geotiff_gpu_roundtrip_3d_no_transform_attr(tmp_path): """GPU writer shares ``_coords_to_transform`` with the CPU writer.