diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index a32681c6..c9a30c5f 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -3827,7 +3827,10 @@ def read_vrt(source: str, *, ``GeoTIFFFallbackWarning`` for missing sources. """ from ._reader import _coerce_path - from ._vrt import read_vrt as _read_vrt_internal + from ._vrt import ( + read_vrt as _read_vrt_internal, + _apply_integer_sentinel_mask as _vrt_apply_integer_sentinel_mask, + ) source = _coerce_path(source) @@ -3939,83 +3942,12 @@ def read_vrt(source: str, *, # downstream code follows (``attrs['nodata']`` is present iff the # array has already been NaN-masked). # - # For multi-band reads (``band is None`` and ``arr.ndim == 3``), each - # band can declare its own ````. The float-VRT path masks - # per-band inline in ``_vrt._read_data``; mirror that here by walking - # ``vrt.bands`` and masking each ``arr[..., i]`` slice against its own - # sentinel. Before this branch, only band 0's sentinel was applied and - # bands 1+ left their integer sentinels as literal finite values in - # the returned float64 array. See issue #1611. - def _sentinel_for_dtype(nodata_val, dtype): - """Return ``dtype``-cast sentinel for ``nodata_val`` or ``None`` - if the value can't be represented in ``dtype`` (non-integer - dtype, out-of-range, non-finite, or fractional). Mirrors the - gating PR #1583 added to other read paths via - ``_int_nodata_in_range``. - - A plain Python ``int`` ``nodata_val`` is handled without going - through ``float`` first, so 64-bit sentinels such as - ``2**64 - 1`` (``UInt64`` max) and ``-2**63`` (``Int64`` min) - round-trip without the float64 rounding that pushes them just - past the dtype's representable range. ``_parse_band_nodata`` - in ``_vrt.py`` parses integer-band ```` directly - as ``int`` to feed this path. See issue #1783 follow-up. - """ - if nodata_val is None or dtype.kind not in ('u', 'i'): - return None - info = np.iinfo(dtype) - # Fast/exact path: ``nodata_val`` is already an integer. Avoids - # the float64 round-trip that loses precision near the int64 / - # uint64 extremes. ``bool`` is a subclass of ``int`` -- treat - # True/False as a 1/0 sentinel rather than rejecting outright, - # matching the existing ``int(float(...))`` behaviour. - if isinstance(nodata_val, (int, np.integer)) and not isinstance( - nodata_val, bool): - nodata_int = int(nodata_val) - if info.min <= nodata_int <= info.max: - return dtype.type(nodata_int) - return None - try: - nodata_f = float(nodata_val) - except (TypeError, ValueError): - return None - if not (np.isfinite(nodata_f) and nodata_f.is_integer() - and info.min <= nodata_f <= info.max): - return None - return dtype.type(int(nodata_f)) - - if arr.dtype.kind in ('u', 'i'): - if arr.ndim == 3 and band is None and vrt.bands: - # Per-band masking: walk ``vrt.bands`` once and stream each - # band's mask. The first band with a sentinel hit promotes - # ``arr`` to float64 in place; ``int_arr`` keeps the original - # integer view alive so subsequent bands still compare against - # the exact sentinel dtype (the post-promotion float64 view - # works too, but staying on the integer dtype avoids any - # rounding edge case on extreme sentinels). Peak boolean-mask - # memory is O(H * W), not O(bands * H * W) like the earlier - # collect-then-apply implementation. - int_arr = arr - int_dtype = int_arr.dtype - for i, vrt_band in enumerate(vrt.bands): - if i >= int_arr.shape[-1]: - break - sentinel = _sentinel_for_dtype(vrt_band.nodata, int_dtype) - if sentinel is None: - continue - mask = int_arr[..., i] == sentinel - if not mask.any(): - continue - if arr.dtype != np.float64: - arr = arr.astype(np.float64) - arr[..., i][mask] = np.nan - elif nodata is not None: - sentinel = _sentinel_for_dtype(nodata, arr.dtype) - if sentinel is not None: - mask = arr == sentinel - if mask.any(): - arr = arr.astype(np.float64) - arr[mask] = np.nan + # The helper handles both per-band masking (multi-band reads where + # each band has its own ````) and single-band masking, + # promoting ``arr`` to float64 on the first sentinel hit and writing + # NaNs in place on the promoted view. Shared with the chunked path + # (issue #1825) so behaviour stays in lockstep. See issue #1611. + arr = _vrt_apply_integer_sentinel_mask(arr, vrt, band) # Surface the source GeoTransform in the same rasterio ordering used # by open_geotiff: (pixel_width, 0, origin_x, 0, pixel_height, origin_y). @@ -4063,7 +3995,7 @@ def _sentinel_for_dtype(nodata_val, dtype): def _vrt_chunk_read(source, r0, c0, r1, c1, *, band, max_pixels, missing_sources, - declared_dtype, gpu): + declared_dtype, gpu, parsed_vrt): """Decode a single chunk window from a VRT. Called by ``dask.delayed`` from :func:`_read_vrt_chunked`. The @@ -4073,57 +4005,35 @@ def _vrt_chunk_read(source, r0, c0, r1, c1, *, dask graph declared up front, and optionally moves the block to the GPU. + ``parsed_vrt`` is the parent dispatcher's already-parsed + :class:`VRTDataset`; the internal reader skips the XML parse and + source-path containment check when this is supplied, removing the + N+1 parse cost an earlier implementation had (issue #1825). + Returning a ``numpy.ndarray`` (or ``cupy.ndarray`` when ``gpu`` is set) whose shape and dtype match the ``shape=`` / ``dtype=`` kwargs of the surrounding ``dask.array.from_delayed`` is the contract; a mismatch would silently produce a wrong-shape dask array. """ - from ._vrt import read_vrt as _read_vrt_internal + from ._vrt import ( + read_vrt as _read_vrt_internal, + _apply_integer_sentinel_mask, + ) - # TODO(#1825): this re-parses the VRT XML and re-validates source - # paths once per chunk task. Plumb the parent's parsed VRT through - # the task graph to remove the N+1 parse cost. arr, vrt = _read_vrt_internal( source, window=(r0, c0, r1, c1), band=band, max_pixels=max_pixels, missing_sources=missing_sources, + parsed=parsed_vrt, ) - # Mirror the eager post-decode integer-sentinel masking in - # ``read_vrt``. The internal reader NaN-masks float source arrays + # Mirror the eager post-decode integer-sentinel masking via the + # shared helper. The internal reader NaN-masks float source arrays # inline but leaves integer sentinels untouched, so the eager path - # promotes to float64 when sentinels hit. Apply the same logic per - # chunk; the surrounding dask graph already declared float64 when - # any band has a representable integer sentinel, so any chunk that - # actually fires the mask returns a buffer whose dtype matches the - # declared one. - if arr.dtype.kind in ('u', 'i'): - if arr.ndim == 3 and band is None and vrt.bands: - int_arr = arr - int_dtype = int_arr.dtype - for i, vrt_band in enumerate(vrt.bands): - if i >= int_arr.shape[-1]: - break - sentinel = _vrt_sentinel_for_dtype(vrt_band.nodata, int_dtype) - if sentinel is None: - continue - mask = int_arr[..., i] == sentinel - if not mask.any(): - continue - if arr.dtype != np.float64: - arr = arr.astype(np.float64) - arr[..., i][mask] = np.nan - else: - band_idx = band if band is not None else 0 - nodata = None - if vrt.bands and 0 <= band_idx < len(vrt.bands): - nodata = vrt.bands[band_idx].nodata - if nodata is not None: - sentinel = _vrt_sentinel_for_dtype(nodata, arr.dtype) - if sentinel is not None: - mask = arr == sentinel - if mask.any(): - arr = arr.astype(np.float64) - arr[mask] = np.nan + # promotes to float64 when sentinels hit. The surrounding dask graph + # already declared float64 when any band has a representable integer + # sentinel, so any chunk that actually fires the mask returns a + # buffer whose dtype matches the declared one. + arr = _apply_integer_sentinel_mask(arr, vrt, band) if declared_dtype is not None and arr.dtype != declared_dtype: arr = arr.astype(declared_dtype) @@ -4135,33 +4045,6 @@ def _vrt_chunk_read(source, r0, c0, r1, c1, *, return arr -def _vrt_sentinel_for_dtype(nodata_val, dtype): - """Return ``dtype``-cast sentinel for ``nodata_val`` or None. - - Module-level twin of the closure ``_sentinel_for_dtype`` defined - inside :func:`read_vrt`. Lifted to module scope so the per-chunk - helper :func:`_vrt_chunk_read` can call it without paying the cost - of re-binding the closure on every block. - """ - if nodata_val is None or dtype.kind not in ('u', 'i'): - return None - info = np.iinfo(dtype) - if isinstance(nodata_val, (int, np.integer)) and not isinstance( - nodata_val, bool): - nodata_int = int(nodata_val) - if info.min <= nodata_int <= info.max: - return dtype.type(nodata_int) - return None - try: - nodata_f = float(nodata_val) - except (TypeError, ValueError): - return None - if not (np.isfinite(nodata_f) and nodata_f.is_integer() - and info.min <= nodata_f <= info.max): - return None - return dtype.type(int(nodata_f)) - - def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, max_pixels, missing_sources): """Lazy ``read_vrt`` dispatch when ``chunks=`` is set (issue #1814). @@ -4188,13 +4071,21 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, import dask.array as da from ._reader import MAX_PIXELS_DEFAULT - from ._vrt import _read_vrt_xml, parse_vrt + from ._vrt import ( + parse_vrt, + _read_vrt_xml, + _effective_dtype_for_bands, + _sentinel_for_dtype, + ) # Parse the VRT XML up-front (cheap; no pixel decode). Route through # ``_read_vrt_xml`` so the 64 MiB ``XRSPATIAL_VRT_MAX_XML_BYTES`` cap # added in #1818 applies to the chunked dispatcher too; a raw # ``open().read()`` here would let a multi-GB attacker-supplied VRT # exhaust memory before any parser-side guard fires (issue #1831). + # The parsed VRTDataset is plumbed into every per-chunk task so each + # task can skip the redundant XML parse and source-path allowlist + # validation the internal reader otherwise performs (issue #1825). xml_str = _read_vrt_xml(source) vrt_dir = _os.path.dirname(_os.path.abspath(source)) vrt = parse_vrt(xml_str, vrt_dir) @@ -4277,17 +4168,18 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, "VRT has no elements; cannot determine " "output dtype") - # Compute the declared dtype. Match the internal reader's - # ``np.result_type`` over per-band effective dtypes, then widen to - # float64 only when at least one selected band declares an integer - # nodata sentinel that round-trips through the band's dtype. + # Compute the declared dtype. Share the per-band effective-dtype + # rule (ComplexSource scale/offset promotes to float64) with the + # eager path via ``_effective_dtype_for_bands`` so both paths agree + # on the result_type (issue #1825). Then widen to float64 if any + # selected band declares an integer nodata sentinel that round-trips + # through the band's dtype. # - # The eager path (``read_vrt`` at lines ~4033-4064) defers the - # promotion to runtime: it scans every band's pixels and promotes - # only if at least one sentinel hits. The chunked path cannot - # afford that scan up front (it would require decoding the mosaic - # the dask graph was constructed to defer), so this is a - # parse-time approximation. The trade-off: + # The eager path defers the promotion to runtime: it scans every + # band's pixels and promotes only if at least one sentinel hits. + # The chunked path cannot afford that scan up front (it would + # require decoding the mosaic the dask graph was constructed to + # defer), so this is a parse-time approximation. The trade-off: # * if a band declares nodata and no chunk contains the # sentinel, the chunked output is float64 where the eager # output would have stayed integer (acceptable: the user @@ -4296,25 +4188,13 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, # source integer dtype (handled by the ``promotes is False`` # fall-through below). # See also Copilot review on PR #1822. - # TODO(#1825): share dtype + scale/offset/sentinel logic with - # the eager path instead of re-implementing it here. - effective_dtypes = [] - for vrt_band in selected_bands: - eff = vrt_band.dtype - for src in vrt_band.sources: - scaled = src.scale is not None and src.scale != 1.0 - offset = src.offset is not None and src.offset != 0.0 - if scaled or offset: - eff = np.dtype(np.float64) - break - effective_dtypes.append(eff) - declared_dtype = np.result_type(*effective_dtypes) + declared_dtype = _effective_dtype_for_bands(selected_bands) if declared_dtype.kind in ('u', 'i'): promotes = False for vrt_band in selected_bands: - if _vrt_sentinel_for_dtype(vrt_band.nodata, - declared_dtype) is not None: + if _sentinel_for_dtype(vrt_band.nodata, + declared_dtype) is not None: promotes = True break if promotes: @@ -4360,6 +4240,7 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, missing_sources=missing_sources, declared_dtype=declared_dtype, gpu=gpu, + parsed_vrt=vrt, ) block = da.from_delayed(d, shape=block_shape, dtype=declared_dtype, meta=meta) @@ -4427,8 +4308,6 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, # Empty list is omitted so the attr only appears when a hole is # actually present. Each entry mirrors the eager schema: # ``{'source', 'band', 'dst_rect', 'error'}``. - # TODO(#1825): the per-task path independently re-parses and - # re-resolves source paths; refactor to share the parent's scan. chunked_holes: list[dict] = [] for vrt_band in vrt.bands: for src in vrt_band.sources: diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py index 553d939f..812b5bf4 100644 --- a/xrspatial/geotiff/_vrt.py +++ b/xrspatial/geotiff/_vrt.py @@ -9,7 +9,7 @@ import os import struct import zlib -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace as _dc_replace from typing import Union from xml.sax.saxutils import escape as _xml_escape, quoteattr as _xml_quoteattr @@ -721,10 +721,140 @@ def _resample_nearest_window(src_sub: np.ndarray, return src_sub[y_idx[:, None], x_idx[None, :]] +# --------------------------------------------------------------------------- +# Shared helpers used by both the eager VRT read path (this module) and the +# chunked dask path in ``xrspatial.geotiff.__init__._read_vrt_chunked``. +# Centralised here so both call sites agree on dtype promotion, sentinel +# masking, and effective-dtype computation. See issue #1825. +# --------------------------------------------------------------------------- + + +def _sentinel_for_dtype(nodata_val, dtype): + """Return ``dtype``-cast sentinel for ``nodata_val`` or ``None``. + + ``None`` is returned when the value cannot be represented in ``dtype`` + (non-integer dtype, out-of-range, non-finite, or fractional). Mirrors + the gating used by other read paths via ``_int_nodata_in_range``. + + A plain Python ``int`` ``nodata_val`` is handled without going through + ``float`` first, so 64-bit sentinels such as ``2**64 - 1`` (``UInt64`` + max) and ``-2**63`` (``Int64`` min) round-trip without the float64 + rounding that pushes them past the dtype's representable range. + ``_parse_band_nodata`` parses integer-band ```` directly + as ``int`` to feed this path. See issue #1783 follow-up. + + This is the single shared implementation used by both the eager path + in :func:`read_vrt` and the chunked path in + :func:`xrspatial.geotiff._read_vrt_chunked`; previously a closure in + the eager path and a module-level twin in the chunked path duplicated + the logic (issue #1825). + """ + if nodata_val is None or dtype.kind not in ('u', 'i'): + return None + info = np.iinfo(dtype) + # Fast/exact path: ``nodata_val`` is already an integer. Avoids the + # float64 round-trip that loses precision near the int64 / uint64 + # extremes. ``bool`` is a subclass of ``int`` -- treat True/False as + # a 1/0 sentinel rather than rejecting outright, matching the + # existing ``int(float(...))`` behaviour. + if isinstance(nodata_val, (int, np.integer)) and not isinstance( + nodata_val, bool): + nodata_int = int(nodata_val) + if info.min <= nodata_int <= info.max: + return dtype.type(nodata_int) + return None + try: + nodata_f = float(nodata_val) + except (TypeError, ValueError): + return None + if not (np.isfinite(nodata_f) and nodata_f.is_integer() + and info.min <= nodata_f <= info.max): + return None + return dtype.type(int(nodata_f)) + + +def _effective_dtype_for_bands(selected_bands) -> np.dtype: + """Return the output buffer dtype that holds every selected band losslessly. + + Computes ``np.result_type`` over each band's effective dtype, where + each band's effective dtype is widened to ``float64`` if any of its + ``ComplexSource`` declarations apply a non-identity ``ScaleRatio`` + (``scale``) or ``ScaleOffset`` (``offset``). Mirrors the historic + inline computation in :func:`read_vrt` and matches the parse-time + declared dtype the chunked path emits up front. Issue #1825. + """ + if not selected_bands: + raise ValueError( + "VRT has no elements; cannot determine " + "output dtype" + ) + effective_dtypes = [] + for vrt_band in selected_bands: + eff = vrt_band.dtype + for src in vrt_band.sources: + scaled = src.scale is not None and src.scale != 1.0 + offset = src.offset is not None and src.offset != 0.0 + if scaled or offset: + eff = np.dtype(np.float64) + break + effective_dtypes.append(eff) + return np.result_type(*effective_dtypes) + + +def _apply_integer_sentinel_mask(arr, vrt, band): + """NaN-mask integer sentinels in a freshly decoded VRT buffer. + + Mirrors the post-decode integer-promotion branch the eager path + applies after :func:`read_vrt` and the chunked path applies inside + each per-chunk task. Walks the relevant ``vrt.bands`` entry / entries, + promotes ``arr`` to ``float64`` on the first sentinel hit, and rewrites + matching pixels to ``NaN`` in place on the promoted view. + + Returns the (possibly promoted) ``arr``. The internal reader already + NaN-masks float source arrays inline; this helper only fires for + integer-dtype outputs paired with an integer ````. + Issue #1825. + """ + if arr.dtype.kind not in ('u', 'i'): + return arr + if arr.ndim == 3 and band is None and vrt.bands: + int_arr = arr + int_dtype = int_arr.dtype + for i, vrt_band in enumerate(vrt.bands): + if i >= int_arr.shape[-1]: + break + sentinel = _sentinel_for_dtype(vrt_band.nodata, int_dtype) + if sentinel is None: + continue + mask = int_arr[..., i] == sentinel + if not mask.any(): + continue + if arr.dtype != np.float64: + arr = arr.astype(np.float64) + arr[..., i][mask] = np.nan + return arr + band_idx = band if band is not None else 0 + nodata = None + if vrt.bands and 0 <= band_idx < len(vrt.bands): + nodata = vrt.bands[band_idx].nodata + if nodata is None: + return arr + sentinel = _sentinel_for_dtype(nodata, arr.dtype) + if sentinel is None: + return arr + mask = arr == sentinel + if mask.any(): + arr = arr.astype(np.float64) + arr[mask] = np.nan + return arr + + def read_vrt(vrt_path: str, *, window=None, band: int | None = None, max_pixels: int | None = None, - missing_sources: str = 'warn') -> tuple[np.ndarray, VRTDataset]: + missing_sources: str = 'warn', + parsed: VRTDataset | None = None, + ) -> tuple[np.ndarray, VRTDataset]: """Read a VRT file by assembling pixel data from its source files. Parameters @@ -743,6 +873,13 @@ def read_vrt(vrt_path: str, *, window=None, ``'warn'`` emits ``GeoTIFFFallbackWarning`` and records ``vrt.holes`` unless ``XRSPATIAL_GEOTIFF_STRICT=1`` is set. ``'raise'`` fails immediately. + parsed : VRTDataset or None + Pre-parsed VRT structure. When supplied, ``vrt_path`` is not + re-read or re-parsed and the source-path containment check is + skipped (the supplied ``VRTDataset`` is assumed to have been + produced by :func:`parse_vrt` already, which performs the check). + Used by the chunked dask path (issue #1825) so each per-chunk + task can skip the redundant XML parse and allowlist validation. Returns ------- @@ -750,10 +887,21 @@ def read_vrt(vrt_path: str, *, window=None, """ from ._reader import PixelSafetyLimitError, read_to_array - xml_str = _read_vrt_xml(vrt_path) - - vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) - vrt = parse_vrt(xml_str, vrt_dir) + if parsed is not None: + # Shallow-copy with a fresh ``holes`` list. ``read_vrt`` appends + # to ``vrt.holes`` on missing/unreadable sources, and under + # chunked dispatch (issue #1825) the same ``parsed`` instance is + # threaded into every per-chunk task. Mutating the shared list + # would leak skipped-source records across tasks (racy growth + # under the threaded scheduler, and cumulative duplication + # across calls if a caller ever reused the parsed object). The + # dataclass replace is O(1) over a handful of fields; the bands + # / sources / dtypes references are intentionally shared. + vrt = _dc_replace(parsed, holes=[]) + else: + xml_str = _read_vrt_xml(vrt_path) + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + vrt = parse_vrt(xml_str, vrt_dir) if missing_sources not in ('warn', 'raise'): raise ValueError( f"missing_sources must be 'warn' or 'raise', got " @@ -838,23 +986,9 @@ def read_vrt(vrt_path: str, *, window=None, # Guard against a malformed VRT with zero ```` # elements: ``np.result_type()`` with no args raises a generic # "at least one array or dtype is required" message that gives the - # caller no hint about the underlying cause. - if not selected_bands: - raise ValueError( - "VRT has no elements; cannot determine " - "output dtype" - ) - effective_dtypes = [] - for vrt_band in selected_bands: - eff = vrt_band.dtype - for src in vrt_band.sources: - scaled = src.scale is not None and src.scale != 1.0 - offset = src.offset is not None and src.offset != 0.0 - if scaled or offset: - eff = np.dtype(np.float64) - break - effective_dtypes.append(eff) - dtype = np.result_type(*effective_dtypes) + # caller no hint about the underlying cause. The helper raises + # ``ValueError`` for the empty case with that explicit message. + dtype = _effective_dtype_for_bands(selected_bands) fill = np.nan if dtype.kind in ('f', 'c') else 0 if len(selected_bands) == 1: result = np.full((out_h, out_w), fill, dtype=dtype) diff --git a/xrspatial/geotiff/tests/test_vrt_single_parse_1825.py b/xrspatial/geotiff/tests/test_vrt_single_parse_1825.py new file mode 100644 index 00000000..4c5e089c --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_single_parse_1825.py @@ -0,0 +1,266 @@ +"""Chunked ``read_vrt`` parses the VRT XML once (issue #1825). + +Before the refactor each per-chunk task re-parsed the VRT XML and +re-validated source-path containment, so an N-chunk read paid an N+1 +parse cost. The dispatcher now parses once and threads the parsed +``VRTDataset`` into every task via the dask graph, removing the +per-task XML parse and allowlist validation. + +These tests pin the new behaviour: + +* the dispatcher calls ``parse_vrt`` exactly once during construction, + and ``.compute()`` does not parse the XML again per task; +* the parsed VRT object survives pickling, so the dask graph can ship + it to workers under any scheduler; +* numerical results match the eager path byte-for-byte (regression + guard for the helper extraction). +""" +from __future__ import annotations + +import os +import pickle +import tempfile + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import read_vrt, to_geotiff +from xrspatial.geotiff._vrt import write_vrt as _write_vrt_internal + + +@pytest.fixture +def two_by_two_vrt_1825(): + """4-tile mosaic via the to_geotiff(.vrt, ...) dask path.""" + arr = np.arange(256 * 256, dtype=np.float32).reshape(256, 256) + y = np.linspace(41.0, 40.0, 256) + x = np.linspace(-106.0, -105.0, 256) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1825_2x2_') + vrt_path = os.path.join(td, 'mosaic_1825.vrt') + to_geotiff(raster, vrt_path, tile_size=128) + yield vrt_path, arr + + +@pytest.fixture +def single_tile_vrt_1825(): + """One 64x64 float32 tile wrapped in a VRT.""" + arr = np.arange(64 * 64, dtype=np.float32).reshape(64, 64) + y = np.linspace(41.0, 40.0, 64) + x = np.linspace(-106.0, -105.0, 64) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1825_single_') + tile_path = os.path.join(td, 'tile_1825.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'single_1825.vrt') + _write_vrt_internal(vrt_path, [tile_path]) + yield vrt_path, arr + + +def test_chunked_path_parses_xml_once(monkeypatch, two_by_two_vrt_1825): + """Construction parses once, and ``.compute()`` adds zero parses. + + The previous implementation re-parsed inside every per-chunk task, + so a 4x4 chunk grid produced 17 parses total. After #1825 the + dispatcher parses once and threads the already-parsed VRTDataset + through the task graph. + """ + vrt_path, _ = two_by_two_vrt_1825 + + from xrspatial.geotiff import _vrt as vrt_module + + counter = {'parses': 0} + real_parse = vrt_module.parse_vrt + + def counting_parse(*args, **kwargs): + counter['parses'] += 1 + return real_parse(*args, **kwargs) + + monkeypatch.setattr(vrt_module, 'parse_vrt', counting_parse) + + result = read_vrt(vrt_path, chunks=(64, 64)) + + # Construction parses exactly once. + assert counter['parses'] == 1, ( + f"expected 1 parse during construction, got {counter['parses']}" + ) + + computed = result.compute() + + # 4x4 chunk grid would re-parse 16 more times under the old code. + assert counter['parses'] == 1, ( + f"expected 1 parse total (construction only); got " + f"{counter['parses']} -- per-chunk tasks are still reparsing" + ) + + # Sanity: the computed array is the original data. + assert computed.shape == (256, 256) + assert computed.dtype == np.float32 + + +def test_chunked_path_reads_xml_file_once(monkeypatch, two_by_two_vrt_1825): + """The chunked dispatcher reads the VRT XML file exactly once. + + Pin the file-read side too: before #1825 every per-chunk task + re-opened the .vrt file via ``_read_vrt_xml``. After the refactor + only the dispatcher reads it. + """ + vrt_path, _ = two_by_two_vrt_1825 + + from xrspatial.geotiff import _vrt as vrt_module + + counter = {'reads': 0} + real_read_xml = vrt_module._read_vrt_xml + + def counting_read_xml(*args, **kwargs): + counter['reads'] += 1 + return real_read_xml(*args, **kwargs) + + monkeypatch.setattr(vrt_module, '_read_vrt_xml', counting_read_xml) + + result = read_vrt(vrt_path, chunks=(64, 64)) + assert counter['reads'] == 1, ( + f"expected 1 XML file read during construction, got " + f"{counter['reads']}" + ) + + result.compute() + assert counter['reads'] == 1, ( + f"expected 1 XML file read total; got {counter['reads']} -- " + f"per-chunk tasks are still re-opening the .vrt file" + ) + + +def test_parsed_vrt_is_picklable(single_tile_vrt_1825): + """The parsed VRTDataset round-trips through pickle. + + The chunked dispatcher embeds the parsed VRT into the dask graph, + so dask must be able to serialise it for the distributed and + process-pool schedulers. Pin picklability with the stdlib pickler + (cloudpickle is a strict superset). + """ + vrt_path, _ = single_tile_vrt_1825 + from xrspatial.geotiff._vrt import parse_vrt, _read_vrt_xml + + xml_str = _read_vrt_xml(vrt_path) + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + vrt = parse_vrt(xml_str, vrt_dir) + + blob = pickle.dumps(vrt) + restored = pickle.loads(blob) + + assert restored.width == vrt.width + assert restored.height == vrt.height + assert len(restored.bands) == len(vrt.bands) + assert restored.bands[0].dtype == vrt.bands[0].dtype + assert [s.filename for s in restored.bands[0].sources] == [ + s.filename for s in vrt.bands[0].sources + ] + + +def test_chunked_matches_eager_after_refactor(two_by_two_vrt_1825): + """Byte-identical eager vs chunked results after the helper consolidation. + + The eager path uses ``_apply_integer_sentinel_mask`` / + ``_effective_dtype_for_bands`` / ``_sentinel_for_dtype`` from + ``_vrt`` directly; the chunked path imports the same helpers. A + regression in either call site would surface here. + """ + vrt_path, original = two_by_two_vrt_1825 + eager = read_vrt(vrt_path) + chunked = read_vrt(vrt_path, chunks=(64, 64)).compute() + assert eager.dtype == chunked.dtype + np.testing.assert_array_equal(eager.values, chunked.values) + np.testing.assert_array_equal(eager.values, original) + + +def test_no_path_containment_revalidation_per_chunk(monkeypatch, + two_by_two_vrt_1825): + """Per-chunk tasks skip the source-path containment check. + + ``parse_vrt`` is the only place that resolves and validates source + paths against the VRT directory / ``XRSPATIAL_VRT_ALLOWED_ROOTS``. + Because each task now receives the already-parsed VRT, ``parse_vrt`` + must not run during ``.compute()`` even when the graph is hydrated. + """ + vrt_path, _ = two_by_two_vrt_1825 + + from xrspatial.geotiff import _vrt as vrt_module + + parse_calls = {'n': 0} + real_parse = vrt_module.parse_vrt + + def counting_parse(*args, **kwargs): + parse_calls['n'] += 1 + return real_parse(*args, **kwargs) + + monkeypatch.setattr(vrt_module, 'parse_vrt', counting_parse) + + result = read_vrt(vrt_path, chunks=(64, 64)) + parses_after_construction = parse_calls['n'] + + # Compute one block via dask's sliced API; confirm parse count + # stays at the construction baseline (no extra parses fired). + da_arr = result.data + if isinstance(da_arr, da.Array): + _block = da_arr.blocks[0, 0].compute() + assert _block.shape[0] > 0 and _block.shape[1] > 0 + + assert parse_calls['n'] == parses_after_construction, ( + f"per-block compute triggered extra parses " + f"({parse_calls['n']} vs {parses_after_construction})" + ) + + +def test_parsed_kwarg_does_not_mutate_caller_holes(single_tile_vrt_1825): + """``read_vrt(parsed=...)`` must not mutate the caller's ``holes``. + + The chunked dispatcher threads a single parsed ``VRTDataset`` into + every per-chunk task. ``read_vrt`` appends skipped-source records to + ``vrt.holes`` when a backing file is missing; without a defensive + copy the appends would land on the dispatcher's shared object and + leak across tasks (racy under the threaded scheduler, and + cumulatively across calls if a caller ever reused the parsed + object). Pin that ``parsed.holes`` stays untouched. + """ + vrt_path, _ = single_tile_vrt_1825 + from xrspatial.geotiff._vrt import ( + _read_vrt_xml, + parse_vrt, + read_vrt as _read_vrt_internal, + ) + + xml_str = _read_vrt_xml(vrt_path) + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + parsed = parse_vrt(xml_str, vrt_dir) + + # Point the only source at a path that does not exist so the + # lenient ``missing_sources='warn'`` branch fires and would append + # a record onto ``holes``. + parsed.bands[0].sources[0].filename = os.path.join(vrt_dir, 'gone.tif') + holes_id_before = id(parsed.holes) + + import warnings + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + arr, returned = _read_vrt_internal( + vrt_path, parsed=parsed, missing_sources='warn', + ) + + assert parsed.holes == [], ( + f"parsed.holes was mutated across the read; got {parsed.holes!r}" + ) + assert id(parsed.holes) == holes_id_before, ( + "parsed.holes list object was replaced -- the caller's reference " + "is now stale" + ) + # The returned VRTDataset is the per-call working copy and DID + # collect the skipped-source record. + assert len(returned.holes) == 1 + assert returned.holes[0]['source'].endswith('gone.tif') + assert arr.shape == (64, 64)