diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py index d3485919..e76b31d1 100644 --- a/xrspatial/geotiff/__init__.py +++ b/xrspatial/geotiff/__init__.py @@ -3947,185 +3947,6 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp): _write_bytes(file_bytes, path) -def _vrt_effective_dtype(vrt, band): - """Return the dtype a VRT read is expected to materialize.""" - selected = [vrt.bands[band]] if band is not None else vrt.bands - if not selected: - raise ValueError( - "VRT has no elements; cannot determine " - "output dtype" - ) - effective = [] - for vrt_band in selected: - dt = vrt_band.dtype - for src in vrt_band.sources: - scaled = src.scale is not None and src.scale != 1.0 - offset = src.offset is not None and src.offset != 0.0 - if scaled or offset: - dt = np.dtype(np.float64) - break - if dt.kind in ('u', 'i') and vrt_band.nodata is not None: - try: - if isinstance(vrt_band.nodata, (int, np.integer)): - nd = int(vrt_band.nodata) - else: - nf = float(vrt_band.nodata) - nd = int(nf) if np.isfinite(nf) and nf.is_integer() else None - if nd is not None: - info = np.iinfo(dt) - if info.min <= nd <= info.max: - dt = np.dtype(np.float64) - except (TypeError, ValueError): - pass - effective.append(dt) - return np.result_type(*effective) - - -def _read_vrt_dask(source: str, *, dtype=None, window=None, band=None, - name=None, chunks=None, max_pixels=None, - missing_sources='warn'): - """Build a truly lazy dask-backed VRT DataArray from window tasks.""" - import os - import dask - import dask.array as da - from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT - from ._vrt import parse_vrt - - with open(source, 'r') as f: - xml_str = f.read() - vrt_dir = os.path.dirname(os.path.abspath(source)) - vrt = parse_vrt(xml_str, vrt_dir) - - if band is not None: - if not isinstance(band, (int, np.integer)) or isinstance(band, bool): - raise ValueError(f"band must be a non-negative int, got {band!r}") - if band < 0 or band >= len(vrt.bands): - raise 
ValueError( - f"band index {band} out of range for VRT with " - f"{len(vrt.bands)} band(s)") - - if window is not None: - win_r0, win_c0, win_r1, win_c1 = window - if (win_r0 < 0 or win_c0 < 0 - or win_r1 > vrt.height or win_c1 > vrt.width - or win_r0 >= win_r1 or win_c0 >= win_c1): - raise ValueError( - f"window={window} is outside the VRT extent " - f"({vrt.height}x{vrt.width}) or has non-positive size.") - else: - win_r0, win_c0, win_r1, win_c1 = 0, 0, vrt.height, vrt.width - - height = win_r1 - win_r0 - width = win_c1 - win_c0 - n_bands = len([vrt.bands[band]] if band is not None else vrt.bands) - if max_pixels is None: - max_pixels = MAX_PIXELS_DEFAULT - _check_dimensions(width, height, n_bands, max_pixels) - - out_dtype = np.dtype(dtype) if dtype is not None else _vrt_effective_dtype(vrt, band) - if dtype is not None: - _validate_dtype_cast(_vrt_effective_dtype(vrt, band), out_dtype) - - if isinstance(chunks, int): - ch_h = ch_w = chunks - else: - ch_h, ch_w = chunks - - # Match read_geotiff_dask's graph-size guard. Each VRT chunk becomes a - # delayed task, so tiny chunks over very large VRT extents can OOM the - # driver during graph construction before any source read executes. - _MAX_DASK_CHUNKS = 50_000 - n_chunks = ((height + ch_h - 1) // ch_h) * ((width + ch_w - 1) // ch_w) - if n_chunks > _MAX_DASK_CHUNKS: - import math - scale = math.sqrt(n_chunks / _MAX_DASK_CHUNKS) - suggested_h = int(math.ceil(ch_h * scale)) - suggested_w = int(math.ceil(ch_w * scale)) - raise ValueError( - f"read_vrt: chunks=({ch_h}, {ch_w}) on a {height}x{width} " - f"VRT window would produce {n_chunks:,} dask tasks, exceeding " - f"the {_MAX_DASK_CHUNKS:,}-task cap. Pass a larger chunks=... " - f"value explicitly (e.g. chunks=({suggested_h}, " - f"{suggested_w}) keeps the task count under the cap)." 
- ) - - rows = list(range(0, height, ch_h)) - cols = list(range(0, width, ch_w)) - out_has_band_axis = band is None and n_bands > 1 - - @dask.delayed - def _read_chunk(chunk_window): - chunk_da = read_vrt( - source, dtype=dtype, window=chunk_window, band=band, - chunks=None, gpu=False, max_pixels=max_pixels, - missing_sources=missing_sources, - ) - arr = np.asarray(chunk_da.values) - if arr.dtype != out_dtype: - arr = arr.astype(out_dtype) - return arr - - dask_rows = [] - for r0 in rows: - r1 = min(r0 + ch_h, height) - dask_cols = [] - for c0 in cols: - c1 = min(c0 + ch_w, width) - chunk_window = (r0 + win_r0, c0 + win_c0, - r1 + win_r0, c1 + win_c0) - shape = ((r1 - r0, c1 - c0, n_bands) - if out_has_band_axis else (r1 - r0, c1 - c0)) - dask_cols.append(da.from_delayed( - _read_chunk(chunk_window), shape=shape, dtype=out_dtype)) - dask_rows.append(da.concatenate(dask_cols, axis=1)) - dask_arr = da.concatenate(dask_rows, axis=0) - - coords = {} - gt = vrt.geo_transform - if gt is not None: - origin_x, res_x, _, origin_y, _, res_y = gt - if vrt.raster_type == 'point': - x_shift = win_c0 * res_x - y_shift = win_r0 * res_y - else: - x_shift = (win_c0 + 0.5) * res_x - y_shift = (win_r0 + 0.5) * res_y - coords = { - 'x': np.arange(width, dtype=np.float64) * res_x + origin_x + x_shift, - 'y': np.arange(height, dtype=np.float64) * res_y + origin_y + y_shift, - } - - attrs = {} - if vrt.crs_wkt: - epsg = _wkt_to_epsg(vrt.crs_wkt) - if epsg is not None: - attrs['crs'] = epsg - attrs['crs_wkt'] = vrt.crs_wkt - if vrt.raster_type == 'point': - attrs['raster_type'] = 'point' - if vrt.bands: - band_idx_for_nodata = band if band is not None else 0 - nodata = vrt.bands[band_idx_for_nodata].nodata - if nodata is not None: - attrs['nodata'] = nodata - if gt is not None: - origin_x, res_x, _, origin_y, _, res_y = gt - attrs['transform'] = ( - float(res_x), 0.0, float(origin_x) + win_c0 * float(res_x), - 0.0, float(res_y), float(origin_y) + win_r0 * float(res_y), - ) - - if name is 
None: - name = os.path.splitext(os.path.basename(source))[0] - if out_has_band_axis: - dims = ['y', 'x', 'band'] - coords['band'] = np.arange(n_bands) - else: - dims = ['y', 'x'] - return xr.DataArray(dask_arr, dims=dims, coords=coords, - name=name, attrs=attrs) - - def read_vrt(source: str, *, dtype: str | np.dtype | None = None, window: tuple | None = None, @@ -4195,6 +4016,15 @@ def read_vrt(source: str, *, ``relativeToVRT='1'`` source that escapes the VRT directory (e.g. ``../../etc/passwd`` or a symlink to a file outside the directory) is rejected regardless of the allowlist. + + Lazy chunked reads (issue #1814): when ``chunks=`` is set, the + returned DataArray wraps a dask graph that decodes one chunk + window per task. Construction does not materialise any pixels; + only the VRT XML is parsed. The eager read populates + ``attrs['vrt_holes']`` from skipped sources; the chunked path does + not aggregate per-task hole records, so that attribute is not set + when ``chunks=`` is used. Each worker still emits + ``GeoTIFFFallbackWarning`` for missing sources. """ from ._reader import _coerce_path from ._vrt import read_vrt as _read_vrt_internal @@ -4213,10 +4043,23 @@ def read_vrt(source: str, *, f"missing_sources must be 'warn' or 'raise', got " f"{missing_sources!r}") - if chunks is not None and not gpu: - return _read_vrt_dask( - source, dtype=dtype, window=window, band=band, name=name, - chunks=chunks, max_pixels=max_pixels, + # Lazy chunked path (issue #1814). The eager call below materialises + # the full mosaic on host RAM and then wraps the array via + # ``.chunk()``, so chunks= gave no memory protection and gpu=True + + # chunks= still assembled the full mosaic on the CPU before moving to + # the device. When chunks= is set, dispatch to a delayed-per-window + # builder so each task decodes only the sources intersecting its + # destination window. 
+ if chunks is not None: + return _read_vrt_chunked( + source, + window=window, + band=band, + name=name, + chunks=chunks, + gpu=gpu, + dtype=dtype, + max_pixels=max_pixels, missing_sources=missing_sources, ) @@ -4414,14 +4257,420 @@ def _sentinel_for_dtype(nodata_val, dtype): result = xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs) - # Chunk for Dask (or Dask+CuPy if gpu=True) - if chunks is not None: - if isinstance(chunks, int): - chunk_dict = {'y': chunks, 'x': chunks} + # ``chunks is not None`` is handled by ``_read_vrt_chunked`` higher up + # in this function (issue #1814); reaching this point implies the + # eager path, so no post-decode chunking is needed. + return result + + +# Hard cap on the per-VRT chunk task count. Matches the +# ``_MAX_DASK_CHUNKS`` value used by ``read_geotiff_dask`` so the two +# entry points refuse the same scheduler-busting chunk grids. See +# issue #1814. +_MAX_VRT_DASK_CHUNKS = 50_000 + + +def _vrt_chunk_read(source, r0, c0, r1, c1, *, + band, max_pixels, missing_sources, + declared_dtype, gpu): + """Decode a single chunk window from a VRT. + + Called by ``dask.delayed`` from :func:`_read_vrt_chunked`. The + function reads only the destination window via the existing VRT + internal reader, applies the same integer-sentinel masking the + eager :func:`read_vrt` does post-decode, casts to the dtype the + dask graph declared up front, and optionally moves the block to + the GPU. + + Returning a ``numpy.ndarray`` (or ``cupy.ndarray`` when ``gpu`` is + set) whose shape and dtype match the ``shape=`` / ``dtype=`` kwargs + of the surrounding ``dask.array.from_delayed`` is the contract; a + mismatch would silently produce a wrong-shape dask array. + """ + from ._vrt import read_vrt as _read_vrt_internal + + # TODO(#1825): this re-parses the VRT XML and re-validates source + # paths once per chunk task. Plumb the parent's parsed VRT through + # the task graph to remove the N+1 parse cost. 
+ arr, vrt = _read_vrt_internal( + source, window=(r0, c0, r1, c1), band=band, + max_pixels=max_pixels, missing_sources=missing_sources, + ) + + # Mirror the eager post-decode integer-sentinel masking in + # ``read_vrt``. The internal reader NaN-masks float source arrays + # inline but leaves integer sentinels untouched, so the eager path + # promotes to float64 when sentinels hit. Apply the same logic per + # chunk; the surrounding dask graph already declared float64 when + # any band has a representable integer sentinel, so any chunk that + # actually fires the mask returns a buffer whose dtype matches the + # declared one. + if arr.dtype.kind in ('u', 'i'): + if arr.ndim == 3 and band is None and vrt.bands: + int_arr = arr + int_dtype = int_arr.dtype + for i, vrt_band in enumerate(vrt.bands): + if i >= int_arr.shape[-1]: + break + sentinel = _vrt_sentinel_for_dtype(vrt_band.nodata, int_dtype) + if sentinel is None: + continue + mask = int_arr[..., i] == sentinel + if not mask.any(): + continue + if arr.dtype != np.float64: + arr = arr.astype(np.float64) + arr[..., i][mask] = np.nan else: - chunk_dict = {'y': chunks[0], 'x': chunks[1]} - result = result.chunk(chunk_dict) + band_idx = band if band is not None else 0 + nodata = None + if vrt.bands and 0 <= band_idx < len(vrt.bands): + nodata = vrt.bands[band_idx].nodata + if nodata is not None: + sentinel = _vrt_sentinel_for_dtype(nodata, arr.dtype) + if sentinel is not None: + mask = arr == sentinel + if mask.any(): + arr = arr.astype(np.float64) + arr[mask] = np.nan + + if declared_dtype is not None and arr.dtype != declared_dtype: + arr = arr.astype(declared_dtype) + + if gpu: + import cupy + arr = cupy.asarray(arr) + + return arr + + +def _vrt_sentinel_for_dtype(nodata_val, dtype): + """Return ``dtype``-cast sentinel for ``nodata_val`` or None. + + Module-level twin of the closure ``_sentinel_for_dtype`` defined + inside :func:`read_vrt`. 
Lifted to module scope so the per-chunk + helper :func:`_vrt_chunk_read` can call it without paying the cost + of re-binding the closure on every block. + """ + if nodata_val is None or dtype.kind not in ('u', 'i'): + return None + info = np.iinfo(dtype) + if isinstance(nodata_val, (int, np.integer)) and not isinstance( + nodata_val, bool): + nodata_int = int(nodata_val) + if info.min <= nodata_int <= info.max: + return dtype.type(nodata_int) + return None + try: + nodata_f = float(nodata_val) + except (TypeError, ValueError): + return None + if not (np.isfinite(nodata_f) and nodata_f.is_integer() + and info.min <= nodata_f <= info.max): + return None + return dtype.type(int(nodata_f)) + + +def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype, + max_pixels, missing_sources): + """Lazy ``read_vrt`` dispatch when ``chunks=`` is set (issue #1814). + + Parses the VRT XML once to recover the extent, CRS, GeoTransform, + and per-band metadata, then builds a dask graph with one task per + chunk window. Each task calls into the existing VRT internal reader + with its own ``window=`` so only the sources intersecting the + chunk's destination rectangle are decoded. + + ``attrs['vrt_holes']`` is populated from a parse-time + ``os.path.exists`` sweep over every source referenced by the parsed + VRT; this preserves the eager-path contract documented in #1734 so + callers switching from eager to chunked can still detect partial + mosaics by attribute lookup (rather than monitoring the + ``GeoTIFFFallbackWarning`` stream). The check is a static + approximation of the eager path's per-source decode-time exception + handling: it catches the dominant "missing file" case but does not + detect decode-time codec failures, which surface as per-task + ``GeoTIFFFallbackWarning`` from each worker. 
+ """ + import os as _os + import dask + import dask.array as da + + from ._reader import MAX_PIXELS_DEFAULT + from ._vrt import parse_vrt + + # Parse the VRT XML up-front (cheap; no pixel decode). + with open(source, 'r') as f: + xml_str = f.read() + vrt_dir = _os.path.dirname(_os.path.abspath(source)) + vrt = parse_vrt(xml_str, vrt_dir) + + # Validate ``band`` against the parsed band count, matching the + # internal reader's contract so the failure mode is the same whether + # the user reads eagerly or chunked. + if band is not None: + if not isinstance(band, (int, np.integer)) or isinstance(band, bool): + raise ValueError( + f"band must be a non-negative int, got {band!r}") + if band < 0 or band >= len(vrt.bands): + raise ValueError( + f"band index {band} out of range for VRT with " + f"{len(vrt.bands)} band(s)") + + # Resolve the windowed extent against the VRT. + if window is not None: + r0, c0, r1, c1 = window + if (r0 < 0 or c0 < 0 + or r1 > vrt.height or c1 > vrt.width + or r0 >= r1 or c0 >= c1): + raise ValueError( + f"window={window} is outside the VRT extent " + f"({vrt.height}x{vrt.width}) or has non-positive size.") + win_r0, win_c0 = r0, c0 + full_h, full_w = r1 - r0, c1 - c0 + else: + win_r0, win_c0 = 0, 0 + full_h, full_w = vrt.height, vrt.width + + max_pixels_effective = ( + max_pixels if max_pixels is not None else MAX_PIXELS_DEFAULT + ) + + # Up-front pixel-count guard against the windowed extent. Mirrors + # the eager ``_vrt.read_vrt`` (which calls ``_check_dimensions`` on + # the full output shape) and ``read_geotiff_dask`` (which guards + # ``full_h * full_w * eff_bands`` before scheduling any task). Each + # chunk task additionally re-checks via ``max_pixels`` through the + # internal reader, but catching an oversized request up front saves + # the caller from a misleading per-chunk error. 
+ eff_bands = 1 if band is not None else max(1, len(vrt.bands)) + if full_h * full_w * eff_bands > max_pixels_effective: + raise ValueError( + f"Requested region {full_h}x{full_w}x{eff_bands} exceeds " + f"max_pixels={max_pixels_effective:,}.") + + if isinstance(chunks, int): + ch_h = ch_w = chunks + else: + ch_h, ch_w = chunks + + # Refuse chunk grids that would build more tasks than the scheduler + # can hold without OOMing the driver. ``read_geotiff_dask`` uses the + # same cap with the same suggestion logic (see issue #1814 and the + # ``_MAX_DASK_CHUNKS`` guard upstream). + n_chunks = ((full_h + ch_h - 1) // ch_h) * ((full_w + ch_w - 1) // ch_w) + if n_chunks > _MAX_VRT_DASK_CHUNKS: + scale = math.sqrt(n_chunks / _MAX_VRT_DASK_CHUNKS) + suggested_h = int(math.ceil(ch_h * scale)) + suggested_w = int(math.ceil(ch_w * scale)) + raise ValueError( + f"read_vrt: chunks=({ch_h}, {ch_w}) on a " + f"{full_h}x{full_w} VRT region would produce {n_chunks:,} " + f"dask tasks, exceeding the {_MAX_VRT_DASK_CHUNKS:,}-task " + f"cap. Pass a larger chunks=... value explicitly (e.g. " + f"chunks=({suggested_h}, {suggested_w}) keeps the task " + f"count under the cap)." + ) + + # Select bands for shape/dtype declaration. + if band is not None: + selected_bands = [vrt.bands[band]] + else: + selected_bands = vrt.bands + + if not selected_bands: + raise ValueError( + "VRT has no elements; cannot determine " + "output dtype") + # Compute the declared dtype. Match the internal reader's + # ``np.result_type`` over per-band effective dtypes, then widen to + # float64 only when at least one selected band declares an integer + # nodata sentinel that round-trips through the band's dtype. + # + # The eager path (``read_vrt`` at lines ~4033-4064) defers the + # promotion to runtime: it scans every band's pixels and promotes + # only if at least one sentinel hits. 
The chunked path cannot + # afford that scan up front (it would require decoding the mosaic + # the dask graph was constructed to defer), so this is a + # parse-time approximation. The trade-off: + # * if a band declares nodata and no chunk contains the + # sentinel, the chunked output is float64 where the eager + # output would have stayed integer (acceptable: the user + # asked the source for nodata, so they expect NaN masking); + # * if a band does not declare nodata, both paths keep the + # source integer dtype (handled by the ``promotes is False`` + # fall-through below). + # See also Copilot review on PR #1822. + # TODO(#1825): share dtype + scale/offset/sentinel logic with + # the eager path instead of re-implementing it here. + effective_dtypes = [] + for vrt_band in selected_bands: + eff = vrt_band.dtype + for src in vrt_band.sources: + scaled = src.scale is not None and src.scale != 1.0 + offset = src.offset is not None and src.offset != 0.0 + if scaled or offset: + eff = np.dtype(np.float64) + break + effective_dtypes.append(eff) + declared_dtype = np.result_type(*effective_dtypes) + + if declared_dtype.kind in ('u', 'i'): + promotes = False + for vrt_band in selected_bands: + if _vrt_sentinel_for_dtype(vrt_band.nodata, + declared_dtype) is not None: + promotes = True + break + if promotes: + declared_dtype = np.dtype(np.float64) + + out_has_band_axis = band is None and len(vrt.bands) > 1 + n_out_bands = len(selected_bands) + + # Build the dask graph: one ``from_delayed`` per chunk window. The + # destination coordinate space is the VRT's full extent (or the + # windowed extent), so chunk windows are computed relative to that + # space and translated to absolute VRT coords before being passed + # into the per-chunk reader. 
+ rows = list(range(0, full_h, ch_h)) + cols = list(range(0, full_w, ch_w)) + + delayed_read = dask.delayed(_vrt_chunk_read) + + if gpu: + import cupy + meta = cupy.empty((0,) * (3 if out_has_band_axis else 2), + dtype=declared_dtype) + else: + meta = np.empty((0,) * (3 if out_has_band_axis else 2), + dtype=declared_dtype) + + dask_rows = [] + for r0c in rows: + r1c = min(r0c + ch_h, full_h) + dask_cols = [] + for c0c in cols: + c1c = min(c0c + ch_w, full_w) + if out_has_band_axis: + block_shape = (r1c - r0c, c1c - c0c, n_out_bands) + else: + block_shape = (r1c - r0c, c1c - c0c) + d = delayed_read( + source, + r0c + win_r0, c0c + win_c0, + r1c + win_r0, c1c + win_c0, + band=band, + max_pixels=max_pixels_effective, + missing_sources=missing_sources, + declared_dtype=declared_dtype, + gpu=gpu, + ) + block = da.from_delayed(d, shape=block_shape, + dtype=declared_dtype, meta=meta) + dask_cols.append(block) + dask_rows.append(da.concatenate(dask_cols, axis=1)) + + dask_arr = da.concatenate(dask_rows, axis=0) + + # Optional user-requested dtype cast happens lazily on the dask + # array so the per-chunk decode dtype stays predictable. + if dtype is not None: + target = np.dtype(dtype) + _validate_dtype_cast(declared_dtype, target) + dask_arr = dask_arr.astype(target) + final_dtype = target + else: + final_dtype = declared_dtype + + # Coordinates: derive from the VRT GeoTransform and the windowed + # extent. Mirrors the eager branch in ``read_vrt`` so chunked and + # eager reads share the same x/y arrays. 
+ gt = vrt.geo_transform + coords = {} + attrs = {} + if gt is not None: + origin_x, res_x, _, origin_y, _, res_y = gt + if vrt.raster_type == 'point': + x_shift = win_c0 * res_x + y_shift = win_r0 * res_y + else: + x_shift = (win_c0 + 0.5) * res_x + y_shift = (win_r0 + 0.5) * res_y + x = np.arange(full_w, dtype=np.float64) * res_x + origin_x + x_shift + y = np.arange(full_h, dtype=np.float64) * res_y + origin_y + y_shift + coords['y'] = y + coords['x'] = x + origin_x_out = float(origin_x) + win_c0 * float(res_x) + origin_y_out = float(origin_y) + win_r0 * float(res_y) + attrs['transform'] = ( + float(res_x), 0.0, origin_x_out, + 0.0, float(res_y), origin_y_out, + ) + + if vrt.crs_wkt: + epsg = _wkt_to_epsg(vrt.crs_wkt) + if epsg is not None: + attrs['crs'] = epsg + attrs['crs_wkt'] = vrt.crs_wkt + if vrt.raster_type == 'point': + attrs['raster_type'] = 'point' + + # Surface the nodata sentinel for the selected band. + nodata_meta = None + if vrt.bands: + band_idx_for_nodata = band if band is not None else 0 + nodata_meta = vrt.bands[band_idx_for_nodata].nodata + if nodata_meta is not None: + attrs['nodata'] = nodata_meta + + # Static hole detection: mirror the eager-path ``attrs['vrt_holes']`` + # contract (#1734) by scanning every source referenced in the parsed + # VRT and recording the ones whose backing file does not exist on + # disk. The eager path discovers holes at decode time (per-source + # OSError / codec error) and aggregates them onto ``vrt.holes``; + # under chunked dispatch each per-task decode catches its own + # missing source and warns, but those records cannot be reduced + # back onto the parent DataArray without an extra synchronisation + # pass. The parse-time existence sweep catches the dominant + # missing-file case before scheduling and lets callers branch on + # ``"vrt_holes" in da.attrs`` exactly as with the eager reader. + # Empty list is omitted so the attr only appears when a hole is + # actually present. 
Each entry mirrors the eager schema: + # ``{'source', 'band', 'dst_rect', 'error'}``. + # TODO(#1825): the per-task path independently re-parses and + # re-resolves source paths; refactor to share the parent's scan. + chunked_holes: list[dict] = [] + for vrt_band in vrt.bands: + for src in vrt_band.sources: + if not _os.path.exists(src.filename): + chunked_holes.append({ + 'source': src.filename, + 'band': vrt_band.band_num, + 'dst_rect': (src.dst_rect.x_off, src.dst_rect.y_off, + src.dst_rect.x_size, src.dst_rect.y_size), + 'error': 'FileNotFoundError: source file not found', + }) + if chunked_holes: + attrs['vrt_holes'] = chunked_holes + + if out_has_band_axis: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(n_out_bands) + else: + dims = ['y', 'x'] + + if name is None: + name = _os.path.splitext(_os.path.basename(source))[0] + + result = xr.DataArray( + dask_arr, dims=dims, coords=coords, name=name, attrs=attrs, + ) + # Sanity: the declared dtype on the dask array is what we return. + assert result.dtype == final_dtype, ( + f"internal: result dtype {result.dtype} != declared {final_dtype}" + ) return result diff --git a/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py new file mode 100644 index 00000000..c692bdcd --- /dev/null +++ b/xrspatial/geotiff/tests/test_vrt_lazy_chunks_1814.py @@ -0,0 +1,365 @@ +"""Lazy chunked read_vrt builds a real dask graph (issue #1814). + +The pre-fix ``read_vrt(chunks=...)`` materialised the full VRT mosaic +on host RAM, then wrapped the resulting numpy array via ``.chunk()``. +That defeated the purpose of ``chunks=`` for memory protection and +made ``gpu=True`` + ``chunks=`` even worse: the entire mosaic was +moved to the device before chunking. 
+ +These tests cover the new lazy path: + +* construction does not decode any pixels; +* per-chunk decode happens at ``.compute()`` time; +* the resulting array is byte-identical to the eager read; +* the chunk task count is bounded so a typo in ``chunks=`` cannot + build a graph the scheduler refuses to dispatch. +""" +from __future__ import annotations + +import os +import tempfile + +import dask.array as da +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import read_vrt, to_geotiff +from xrspatial.geotiff._vrt import write_vrt as _write_vrt_internal + + +def _gpu_available() -> bool: + try: + import cupy # noqa: F401 + except ImportError: + return False + try: + import cupy + return bool(cupy.cuda.is_available()) + except Exception: + return False + + +_HAS_GPU = _gpu_available() + + +@pytest.fixture +def single_tile_vrt(): + """One 128x128 float32 tile wrapped in a VRT.""" + arr = np.arange(128 * 128, dtype=np.float32).reshape(128, 128) + y = np.linspace(41.0, 40.0, 128) + x = np.linspace(-106.0, -105.0, 128) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1814_single_') + tile_path = os.path.join(td, 'tile.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'mosaic.vrt') + _write_vrt_internal(vrt_path, [tile_path]) + yield vrt_path, arr + + +@pytest.fixture +def two_by_two_vrt(): + """4-tile mosaic via the to_geotiff(.vrt, ...) dask path.""" + arr = np.arange(256 * 256, dtype=np.float32).reshape(256, 256) + y = np.linspace(41.0, 40.0, 256) + x = np.linspace(-106.0, -105.0, 256) + raster = xr.DataArray(arr, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + td = tempfile.mkdtemp(prefix='tmp_1814_2x2_') + vrt_path = os.path.join(td, 'mosaic.vrt') + # ``tile_size=128`` produces a 2x2 mosaic of 128x128 tiles. 
+ to_geotiff(raster, vrt_path, tile_size=128) + yield vrt_path, arr + + +@pytest.fixture +def multiband_vrt(): + """3-band single-tile VRT.""" + rng = np.random.default_rng(1814) + arr = rng.random((64, 64, 3), dtype=np.float32) + y = np.linspace(41.0, 40.0, 64) + x = np.linspace(-106.0, -105.0, 64) + raster = xr.DataArray( + arr, + dims=['y', 'x', 'band'], + coords={'y': y, 'x': x, 'band': np.arange(3)}, + attrs={'crs': 4326}, + ) + td = tempfile.mkdtemp(prefix='tmp_1814_mb_') + tile_path = os.path.join(td, 'tile.tif') + to_geotiff(raster, tile_path) + vrt_path = os.path.join(td, 'mosaic.vrt') + _write_vrt_internal(vrt_path, [tile_path]) + yield vrt_path, arr + + +# --------------------------------------------------------------------------- +# 1. Construction is lazy: no pixels are decoded before .compute(). +# --------------------------------------------------------------------------- + +def test_chunks_builds_dask_array_with_multiple_blocks(two_by_two_vrt): + """``read_vrt(chunks=(N,N))`` returns a dask-backed DataArray + whose underlying array has more than one chunk along each spatial + axis. Before the fix the array was numpy-backed under + ``result.chunk()``, so this asserts the new lazy graph is in + play. + """ + vrt_path, _ = two_by_two_vrt + result = read_vrt(vrt_path, chunks=(64, 64)) + assert isinstance(result.data, da.Array), ( + f"expected dask Array, got {type(result.data).__name__}" + ) + # 256 / 64 = 4 blocks per axis. + assert result.data.numblocks == (4, 4), ( + f"expected 4x4 blocks, got {result.data.numblocks}" + ) + + +def test_chunks_is_lazy_does_not_call_internal_reader(monkeypatch, + two_by_two_vrt): + """Construction-time call count of the internal VRT reader is zero; + after ``.compute()`` it equals the chunk count. 
+ """ + vrt_path, _ = two_by_two_vrt + + from xrspatial.geotiff import _vrt as vrt_module + + counter = {'calls': 0} + real_read = vrt_module.read_vrt + + def counting_read(*args, **kwargs): + counter['calls'] += 1 + return real_read(*args, **kwargs) + + monkeypatch.setattr(vrt_module, 'read_vrt', counting_read) + + result = read_vrt(vrt_path, chunks=(64, 64)) + + assert counter['calls'] == 0, ( + f"_read_vrt_internal called {counter['calls']} times before " + f".compute(); the chunked path leaked an eager decode" + ) + + computed = result.compute() + # 4 row blocks * 4 col blocks = 16 expected decodes. + assert counter['calls'] == 16, ( + f"expected 16 per-chunk decodes after compute, got {counter['calls']}" + ) + assert computed.shape == (256, 256) + + +# --------------------------------------------------------------------------- +# 2. Byte-identical to the eager path. +# --------------------------------------------------------------------------- + +def test_chunked_compute_matches_eager(two_by_two_vrt): + vrt_path, _ = two_by_two_vrt + eager = read_vrt(vrt_path) + chunked = read_vrt(vrt_path, chunks=(64, 64)).compute() + assert eager.shape == chunked.shape + assert np.array_equal(eager.values, chunked.values), ( + "chunked compute diverged from eager read" + ) + # Coords and key attrs must match too. + np.testing.assert_array_equal(eager['x'].values, chunked['x'].values) + np.testing.assert_array_equal(eager['y'].values, chunked['y'].values) + assert eager.attrs.get('transform') == chunked.attrs.get('transform') + assert eager.attrs.get('crs') == chunked.attrs.get('crs') + + +def test_chunked_single_tile_matches_eager(single_tile_vrt): + """Single-tile VRT (one source) should still match eager when + chunked. Exercises the path where many chunk windows hit the + same single source. 
+ """ + vrt_path, _ = single_tile_vrt + eager = read_vrt(vrt_path) + chunked = read_vrt(vrt_path, chunks=(32, 32)).compute() + assert np.array_equal(eager.values, chunked.values) + + +# --------------------------------------------------------------------------- +# 3. Task-count cap. +# --------------------------------------------------------------------------- + +def test_chunks_task_cap_raises(two_by_two_vrt): + """``chunks=(1, 1)`` on a 256x256 VRT would build 65,536 tasks, + blowing past the 50,000-task cap. The reader should refuse with + a ValueError that names ``chunks=`` and suggests a larger size. + """ + vrt_path, _ = two_by_two_vrt + with pytest.raises(ValueError, match=r"chunks=.*task"): + read_vrt(vrt_path, chunks=(1, 1)) + + +# --------------------------------------------------------------------------- +# 4. Window + chunks: chunks tile the window, not the full extent. +# --------------------------------------------------------------------------- + +def test_window_plus_chunks_matches_eager(two_by_two_vrt): + """When both ``window=`` and ``chunks=`` are passed, the dask + graph must tile the window (not the full VRT extent). The output + shape and pixel values match an eager windowed read. + """ + vrt_path, _ = two_by_two_vrt + window = (32, 48, 160, 192) # 128 high, 144 wide + + eager = read_vrt(vrt_path, window=window) + chunked = read_vrt(vrt_path, window=window, chunks=(64, 64)) + + assert isinstance(chunked.data, da.Array) + # The chunk grid is sized off the window extent (128, 144) with + # chunks=64 => (2, 3) numblocks. + assert chunked.data.numblocks == (2, 3), ( + f"expected (2, 3) numblocks over the window, got " + f"{chunked.data.numblocks}" + ) + + computed = chunked.compute() + assert computed.shape == eager.shape == (128, 144) + assert np.array_equal(eager.values, computed.values) + + +# --------------------------------------------------------------------------- +# 5. GPU + chunks: each block is a cupy array. 
+# --------------------------------------------------------------------------- + +@pytest.mark.skipif(not _HAS_GPU, reason="cupy + CUDA required") +def test_gpu_plus_chunks_returns_dask_on_cupy(two_by_two_vrt): + """``read_vrt(gpu=True, chunks=...)`` must build a dask graph whose + blocks are cupy-backed (not numpy that gets cupy-wrapped at + compute time on the host). + """ + import cupy + + vrt_path, _ = two_by_two_vrt + result = read_vrt(vrt_path, gpu=True, chunks=(64, 64)) + + assert isinstance(result.data, da.Array) + assert isinstance(result.data._meta, cupy.ndarray), ( + f"expected cupy _meta, got " + f"{type(result.data._meta).__module__}." + f"{type(result.data._meta).__name__}" + ) + computed = result.compute() + assert isinstance(computed.data, cupy.ndarray) + + +# --------------------------------------------------------------------------- +# 6. Multi-band VRT + chunks. +# --------------------------------------------------------------------------- + +def test_multiband_plus_chunks_preserves_band_dim(multiband_vrt): + """3-band VRT read with ``chunks=`` keeps the band dimension on + every block and the assembled DataArray. + """ + vrt_path, src = multiband_vrt + result = read_vrt(vrt_path, chunks=(32, 32)) + + assert isinstance(result.data, da.Array) + assert result.dims == ('y', 'x', 'band') + assert result.shape == (64, 64, 3) + # Per-block shape on the band axis is 3 (whole band axis in one + # chunk because we did not pass a band-chunk size). + assert result.data.chunks[2] == (3,) + + computed = result.compute() + np.testing.assert_allclose(computed.values, src, rtol=0, atol=0) + + +# --------------------------------------------------------------------------- +# 7. Copilot review: ``attrs['vrt_holes']`` must propagate to the chunked +# path so users switching from eager to chunked keep the #1734 contract. 
# ---------------------------------------------------------------------------

def test_chunked_propagates_vrt_holes_when_source_missing(two_by_two_vrt):
    """When a source referenced by the VRT does not exist on disk the
    chunked reader must populate ``attrs['vrt_holes']`` with the same
    schema the eager reader uses, so callers can branch on
    ``"vrt_holes" in da.attrs`` regardless of which code path produced
    the DataArray.
    """
    import warnings
    from xrspatial.geotiff import GeoTIFFFallbackWarning
    from xrspatial.geotiff._reader import _mmap_cache

    vrt_path, _ = two_by_two_vrt
    vrt_dir = os.path.dirname(vrt_path)
    # Remove one of the four source tiles. ``to_geotiff(.vrt, tile_size=128)``
    # writes tile files into a ``_tiles/`` subdirectory next to the
    # .vrt; walk the tree for any .tif and unlink the first one.
    tile_files = [
        os.path.join(root, f)
        for root, _dirs, files in os.walk(vrt_dir)
        for f in files
        if f.endswith('.tif')
    ]
    assert len(tile_files) >= 1
    # write_vrt() opens each tile via _FileSource to read its header;
    # _FileSource.close() decrements the refcount but the mmap stays
    # cached. On Windows an active mmap blocks os.unlink (WinError 32).
    _mmap_cache.clear()
    os.unlink(tile_files[0])

    with warnings.catch_warnings():
        warnings.simplefilter('ignore', GeoTIFFFallbackWarning)
        result = read_vrt(vrt_path, chunks=(64, 64))

    assert 'vrt_holes' in result.attrs, (
        "chunked path dropped vrt_holes contract from #1734"
    )
    holes = result.attrs['vrt_holes']
    assert isinstance(holes, list) and len(holes) >= 1
    entry = holes[0]
    # Schema parity with the eager path (see read_vrt at ~line 3963).
    assert set(entry.keys()) >= {'source', 'band', 'dst_rect', 'error'}
    assert isinstance(entry['dst_rect'], tuple)
    assert len(entry['dst_rect']) == 4


def test_chunked_no_vrt_holes_attr_when_complete(two_by_two_vrt):
    """When every source is on disk the chunked reader must not set
    ``attrs['vrt_holes']`` (eager parity: empty hole list is omitted).
    """
    vrt_path, _ = two_by_two_vrt
    result = read_vrt(vrt_path, chunks=(64, 64))
    assert 'vrt_holes' not in result.attrs


# ---------------------------------------------------------------------------
# 8. Copilot review: integer source with no declared nodata must keep its
# integer dtype through the chunked path (no spurious float64 promotion).
# ---------------------------------------------------------------------------

def test_chunked_integer_no_nodata_keeps_source_dtype():
    """A uint16 source with no declared nodata must produce a
    uint16 chunked DataArray, not float64. The eager path stays integer
    in this case because its runtime ``mask.any()`` is False; the
    chunked path approximates with a static "any band declares nodata?"
    check, which yields the same answer here.
    """
    import shutil

    arr = np.arange(128 * 128, dtype=np.uint16).reshape(128, 128)
    y = np.linspace(41.0, 40.0, 128)
    x = np.linspace(-106.0, -105.0, 128)
    raster = xr.DataArray(arr, dims=['y', 'x'],
                          coords={'y': y, 'x': x},
                          attrs={'crs': 4326})
    td = tempfile.mkdtemp(prefix='tmp_1814_uint16_nonodata_')
    try:
        tile_path = os.path.join(td, 'tile.tif')
        to_geotiff(raster, tile_path)
        vrt_path = os.path.join(td, 'mosaic.vrt')
        # No ``nodata=`` passed: the VRT will not declare a nodata value
        # for this band, exercising the no-promotion branch.
        _write_vrt_internal(vrt_path, [tile_path])

        result = read_vrt(vrt_path, chunks=(32, 32))
        assert result.dtype == np.uint16, (
            f"expected uint16 (source dtype), got {result.dtype}; "
            f"chunked path promoted to float64 despite no declared nodata"
        )
        computed = result.compute()
        assert computed.dtype == np.uint16
        np.testing.assert_array_equal(computed.values, arr)
    finally:
        # The temp dir used to leak on every run. Cleanup is best-effort
        # (ignore_errors) because the cached tile mmap can keep the .tif
        # locked on Windows, as noted in the vrt_holes test above.
        shutil.rmtree(td, ignore_errors=True)