diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index 8495101b..59b9ac37 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -68,6 +68,12 @@ _GPU_DEPRECATED_SENTINEL = object() _ON_GPU_FAILURE_SENTINEL = object() +# Names of dims that ``to_geotiff`` / ``write_geotiff_gpu`` treat as the +# non-spatial band axis. Used both to remap ``(band, y, x)`` inputs to +# ``(y, x, band)`` before writing and to skip the band axis when inferring +# a GeoTransform from coords (see ``_coords_to_transform`` and issue #1643). +_BAND_DIM_NAMES = ('band', 'bands', 'channel') + def _wkt_to_epsg(wkt_or_proj: str) -> int | None: """Try to extract an EPSG code from a WKT or PROJ string. @@ -191,9 +197,34 @@ def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: on raster_type: - PixelIsArea (default): origin = center - half_pixel (edge of pixel 0) - PixelIsPoint: origin = center (center of pixel 0) + + For 3D arrays the spatial dims are the two non-band dims. The helper + filters out any dim named ``band`` / ``bands`` / ``channel`` (see + ``_BAND_DIM_NAMES``) regardless of position, so a ``(y, x, band)``, + ``(band, y, x)``, or ``(y, band, x)`` DataArray returns the y/x + transform rather than picking up the band axis spacing as a pixel + size. ``to_geotiff`` itself remaps ``(band, y, x)`` arrays to + ``(y, x, band)`` before writing pixel bytes, but it calls + :func:`_coords_to_transform` against the original DataArray, so the + helper must handle both layouts to keep the geo-transform consistent + with the file's coord arrays. See issue #1643. """ - ydim = da.dims[-2] - xdim = da.dims[-1] + if da.ndim == 3: + # Drop the band-like dim and keep the two spatial dims in their + # original (y, x) order. Position-based fallback covers the case + # where none of the dims are named like a band axis. + spatial = tuple(d for d in da.dims if d not in _BAND_DIM_NAMES) + if len(spatial) == 2: + ydim, xdim = spatial[0], spatial[1] + else: + # No identifiable band dim; fall back to dims[-2:] so the + # original 2-D-style behaviour applies. This branch only + # triggers for unusual 3D layouts callers built by hand. + ydim = da.dims[-2] + xdim = da.dims[-1] + else: + ydim = da.dims[-2] + xdim = da.dims[-1] if xdim not in da.coords or ydim not in da.coords: return None @@ -1166,7 +1197,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *, if hasattr(raw, 'dask') and not cog and not _path_is_file_like: dask_arr = raw # Handle band-first dimension order (band, y, x) -> (y, x, band) - if raw.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if raw.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: import dask.array as da dask_arr = da.moveaxis(raw, 0, -1) if dask_arr.ndim not in (2, 3): @@ -1215,7 +1246,7 @@ def to_geotiff(data: xr.DataArray | np.ndarray, path, *, else: arr = np.asarray(raw) # Handle band-first dimension order (band, y, x) -> (y, x, band) - if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: arr = np.moveaxis(arr, 0, -1) else: if hasattr(data, 'get'): @@ -2830,7 +2861,7 @@ def write_geotiff_gpu(data: xr.DataArray | cupy.ndarray | np.ndarray, # this remap the writer treats arr.shape[2] as the band axis and # produces a transposed file (issue #1580). The CPU writer does # the same remap at the matching step in to_geotiff(). - if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + if arr.ndim == 3 and data.dims[0] in _BAND_DIM_NAMES: arr = cupy.ascontiguousarray(cupy.moveaxis(arr, 0, -1)) # Prefer attrs['transform'] over the coord-derived transform: it diff --git a/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py new file mode 100644 index 00000000..c7cbb721 --- /dev/null +++ b/xrspatial/geotiff/tests/test_coords_to_transform_3d_1643.py @@ -0,0 +1,213 @@ +"""Regression test for issue #1643. + +``_coords_to_transform`` previously used ``dims[-2]`` and ``dims[-1]`` to +look up y/x coords. On a 3D ``(y, x, band)`` DataArray that picked +``x`` and ``band``, so ``to_geotiff`` silently wrote a wrong +GeoTransform when ``attrs['transform']`` was absent. The helper now +detects the band-like trailing/leading dim and uses the two spatial +dims regardless of position. +""" +from __future__ import annotations + +import importlib.util + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import _coords_to_transform, open_geotiff, to_geotiff + + +def _gpu_available() -> bool: + if importlib.util.find_spec("cupy") is None: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +_HAS_GPU = _gpu_available() + + +def _make_geo_da_3d(dims): + """3D DataArray with georeferenced y/x coords and a band axis.""" + shape = [] + for d in dims: + if d in ('y',): + shape.append(10) + elif d in ('x',): + shape.append(20) + else: + shape.append(3) + arr = np.arange(int(np.prod(shape)), dtype=np.uint8).reshape(shape) + coords = { + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + 'band': np.arange(3), + } + return xr.DataArray(arr, dims=list(dims), coords=coords) + + +def test_coords_to_transform_yxband_returns_yx_spacing(): + """3D (y, x, band) picks y/x spacing rather than (x, band) spacing.""" + da = _make_geo_da_3d(('y', 'x', 'band')) + gt = _coords_to_transform(da) + # y spacing = (200 - 100) / 9, x spacing = (700 - 500) / 19 + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_coords_to_transform_bandyx_returns_yx_spacing(): + """3D (band, y, x) also returns the y/x transform.""" + da = _make_geo_da_3d(('band', 'y', 'x')) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +@pytest.mark.parametrize('band_name', ['band', 'bands', 'channel']) +def test_coords_to_transform_3d_band_name_variants(band_name): + """All recognized band-dim names (band, bands, channel) are filtered + out when picking the y/x spatial dims.""" + arr = np.zeros((10, 20, 3), dtype=np.uint8) + da = xr.DataArray( + arr, + dims=['y', 'x', band_name], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + band_name: np.arange(3), + }, + ) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_coords_to_transform_2d_unchanged(): + """2D (y, x) keeps its original behaviour.""" + da = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + gt = _coords_to_transform(da) + assert gt is not None + np.testing.assert_allclose(gt.pixel_width, (700.0 - 500.0) / 19) + np.testing.assert_allclose(gt.pixel_height, (200.0 - 100.0) / 9) + + +def test_to_geotiff_roundtrip_3d_yxband_no_transform_attr(tmp_path): + """to_geotiff -> open_geotiff round-trip on 3D arrays preserves coords. + + Before the fix the on-disk transform was derived from (x, band) + spacing, so the round-tripped y/x coords had wrong pixel size and + origin. After the fix the 3D output matches the 2D output. + """ + da_3d = _make_geo_da_3d(('y', 'x', 'band')) + da_2d = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + + p2 = str(tmp_path / 'roundtrip_1643_2d.tif') + p3 = str(tmp_path / 'roundtrip_1643_3d.tif') + to_geotiff(da_2d, p2) + to_geotiff(da_3d, p3) + + rt2 = open_geotiff(p2) + rt3 = open_geotiff(p3) + np.testing.assert_allclose(rt3.y.values, rt2.y.values) + np.testing.assert_allclose(rt3.x.values, rt2.x.values) + assert rt3.attrs.get('transform') == rt2.attrs.get('transform') + + +def test_to_geotiff_roundtrip_3d_bandyx_no_transform_attr(tmp_path): + """(band, y, x) input round-trips with the correct transform. + + ``to_geotiff`` remaps a (band, y, x) input to (y, x, band) before + writing, but ``_coords_to_transform`` runs against the original + dim order. The fix handles both 3D layouts. + """ + da_3d = _make_geo_da_3d(('band', 'y', 'x')) + da_2d = xr.DataArray( + np.zeros((10, 20), dtype=np.uint8), + dims=['y', 'x'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + }, + ) + + p2 = str(tmp_path / 'roundtrip_1643_2d_b.tif') + p3 = str(tmp_path / 'roundtrip_1643_3d_bandfirst.tif') + to_geotiff(da_2d, p2) + to_geotiff(da_3d, p3) + + rt2 = open_geotiff(p2) + rt3 = open_geotiff(p3) + np.testing.assert_allclose(rt3.y.values, rt2.y.values) + np.testing.assert_allclose(rt3.x.values, rt2.x.values) + + +def test_to_geotiff_3d_without_transform_attr_does_not_invent_unit_pixels( + tmp_path): + """Regression sanity: the bad transform was pixel_width=1.0 (band + axis spacing). Assert the round-tripped pixel_width is finite, + non-unit, and matches the source x spacing. + """ + da = _make_geo_da_3d(('y', 'x', 'band')) + p = str(tmp_path / 'roundtrip_1643_3d_not_unit.tif') + to_geotiff(da, p) + rt = open_geotiff(p) + pw = abs(float(rt.x.values[1] - rt.x.values[0])) + # Source x spacing is (700-500)/19 = ~10.526. The buggy path would + # have produced pw=1.0 (the band axis spacing). + assert pw > 1.5, ( + f"round-tripped pixel_width={pw} suggests the band-axis spacing " + f"leaked into the GeoTransform; expected ~10.526") + + +@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required") +def test_write_geotiff_gpu_roundtrip_3d_no_transform_attr(tmp_path): + """GPU writer shares ``_coords_to_transform`` with the CPU writer. + + Same regression on the GPU path: a 3D ``(y, x, band)`` cupy + DataArray without ``attrs['transform']`` would previously round-trip + through a unit pixel-width transform. + """ + import cupy as cp + + from xrspatial.geotiff import write_geotiff_gpu + + np_arr = np.arange(10 * 20 * 3, dtype=np.uint8).reshape(10, 20, 3) + da = xr.DataArray( + cp.asarray(np_arr), + dims=['y', 'x', 'band'], + coords={ + 'y': np.linspace(100.0, 200.0, 10), + 'x': np.linspace(500.0, 700.0, 20), + 'band': np.arange(3), + }, + ) + p = str(tmp_path / 'roundtrip_1643_3d_gpu.tif') + write_geotiff_gpu(da, p) + rt = open_geotiff(p) + pw = abs(float(rt.x.values[1] - rt.x.values[0])) + assert pw > 1.5, ( + f"GPU writer round-tripped pixel_width={pw}; expected ~10.526") + ph = abs(float(rt.y.values[1] - rt.y.values[0])) + assert ph > 1.5, ( + f"GPU writer round-tripped pixel_height={ph}; expected ~11.111")