Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 52 additions & 173 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3827,7 +3827,10 @@ def read_vrt(source: str, *,
``GeoTIFFFallbackWarning`` for missing sources.
"""
from ._reader import _coerce_path
from ._vrt import read_vrt as _read_vrt_internal
from ._vrt import (
read_vrt as _read_vrt_internal,
_apply_integer_sentinel_mask as _vrt_apply_integer_sentinel_mask,
)

source = _coerce_path(source)

Expand Down Expand Up @@ -3939,83 +3942,12 @@ def read_vrt(source: str, *,
# downstream code follows (``attrs['nodata']`` is present iff the
# array has already been NaN-masked).
#
# For multi-band reads (``band is None`` and ``arr.ndim == 3``), each
# band can declare its own ``<NoDataValue>``. The float-VRT path masks
# per-band inline in ``_vrt._read_data``; mirror that here by walking
# ``vrt.bands`` and masking each ``arr[..., i]`` slice against its own
# sentinel. Before this branch, only band 0's sentinel was applied and
# bands 1+ left their integer sentinels as literal finite values in
# the returned float64 array. See issue #1611.
def _sentinel_for_dtype(nodata_val, dtype):
"""Return ``dtype``-cast sentinel for ``nodata_val`` or ``None``
if the value can't be represented in ``dtype`` (non-integer
dtype, out-of-range, non-finite, or fractional). Mirrors the
gating PR #1583 added to other read paths via
``_int_nodata_in_range``.

A plain Python ``int`` ``nodata_val`` is handled without going
through ``float`` first, so 64-bit sentinels such as
``2**64 - 1`` (``UInt64`` max) and ``-2**63`` (``Int64`` min)
round-trip without the float64 rounding that pushes them just
past the dtype's representable range. ``_parse_band_nodata``
in ``_vrt.py`` parses integer-band ``<NoDataValue>`` directly
as ``int`` to feed this path. See issue #1783 follow-up.
"""
if nodata_val is None or dtype.kind not in ('u', 'i'):
return None
info = np.iinfo(dtype)
# Fast/exact path: ``nodata_val`` is already an integer. Avoids
# the float64 round-trip that loses precision near the int64 /
# uint64 extremes. ``bool`` is a subclass of ``int`` -- treat
# True/False as a 1/0 sentinel rather than rejecting outright,
# matching the existing ``int(float(...))`` behaviour.
if isinstance(nodata_val, (int, np.integer)) and not isinstance(
nodata_val, bool):
nodata_int = int(nodata_val)
if info.min <= nodata_int <= info.max:
return dtype.type(nodata_int)
return None
try:
nodata_f = float(nodata_val)
except (TypeError, ValueError):
return None
if not (np.isfinite(nodata_f) and nodata_f.is_integer()
and info.min <= nodata_f <= info.max):
return None
return dtype.type(int(nodata_f))

if arr.dtype.kind in ('u', 'i'):
if arr.ndim == 3 and band is None and vrt.bands:
# Per-band masking: walk ``vrt.bands`` once and stream each
# band's mask. The first band with a sentinel hit promotes
# ``arr`` to float64 in place; ``int_arr`` keeps the original
# integer view alive so subsequent bands still compare against
# the exact sentinel dtype (the post-promotion float64 view
# works too, but staying on the integer dtype avoids any
# rounding edge case on extreme sentinels). Peak boolean-mask
# memory is O(H * W), not O(bands * H * W) like the earlier
# collect-then-apply implementation.
int_arr = arr
int_dtype = int_arr.dtype
for i, vrt_band in enumerate(vrt.bands):
if i >= int_arr.shape[-1]:
break
sentinel = _sentinel_for_dtype(vrt_band.nodata, int_dtype)
if sentinel is None:
continue
mask = int_arr[..., i] == sentinel
if not mask.any():
continue
if arr.dtype != np.float64:
arr = arr.astype(np.float64)
arr[..., i][mask] = np.nan
elif nodata is not None:
sentinel = _sentinel_for_dtype(nodata, arr.dtype)
if sentinel is not None:
mask = arr == sentinel
if mask.any():
arr = arr.astype(np.float64)
arr[mask] = np.nan
# The helper handles both per-band masking (multi-band reads where
# each band has its own ``<NoDataValue>``) and single-band masking,
# promoting ``arr`` to float64 on the first sentinel hit and writing
# NaNs in place on the promoted view. Shared with the chunked path
# (issue #1825) so behaviour stays in lockstep. See issue #1611.
arr = _vrt_apply_integer_sentinel_mask(arr, vrt, band)

# Surface the source GeoTransform in the same rasterio ordering used
# by open_geotiff: (pixel_width, 0, origin_x, 0, pixel_height, origin_y).
Expand Down Expand Up @@ -4063,7 +3995,7 @@ def _sentinel_for_dtype(nodata_val, dtype):

def _vrt_chunk_read(source, r0, c0, r1, c1, *,
band, max_pixels, missing_sources,
declared_dtype, gpu):
declared_dtype, gpu, parsed_vrt):
"""Decode a single chunk window from a VRT.

Called by ``dask.delayed`` from :func:`_read_vrt_chunked`. The
Expand All @@ -4073,57 +4005,35 @@ def _vrt_chunk_read(source, r0, c0, r1, c1, *,
dask graph declared up front, and optionally moves the block to
the GPU.

``parsed_vrt`` is the parent dispatcher's already-parsed
:class:`VRTDataset`; the internal reader skips the XML parse and
source-path containment check when this is supplied, removing the
N+1 parse cost an earlier implementation had (issue #1825).

Returning a ``numpy.ndarray`` (or ``cupy.ndarray`` when ``gpu`` is
set) whose shape and dtype match the ``shape=`` / ``dtype=`` kwargs
of the surrounding ``dask.array.from_delayed`` is the contract; a
mismatch would silently produce a wrong-shape dask array.
"""
from ._vrt import read_vrt as _read_vrt_internal
from ._vrt import (
read_vrt as _read_vrt_internal,
_apply_integer_sentinel_mask,
)

# TODO(#1825): this re-parses the VRT XML and re-validates source
# paths once per chunk task. Plumb the parent's parsed VRT through
# the task graph to remove the N+1 parse cost.
arr, vrt = _read_vrt_internal(
source, window=(r0, c0, r1, c1), band=band,
max_pixels=max_pixels, missing_sources=missing_sources,
parsed=parsed_vrt,
)

# Mirror the eager post-decode integer-sentinel masking in
# ``read_vrt``. The internal reader NaN-masks float source arrays
# Mirror the eager post-decode integer-sentinel masking via the
# shared helper. The internal reader NaN-masks float source arrays
# inline but leaves integer sentinels untouched, so the eager path
# promotes to float64 when sentinels hit. Apply the same logic per
# chunk; the surrounding dask graph already declared float64 when
# any band has a representable integer sentinel, so any chunk that
# actually fires the mask returns a buffer whose dtype matches the
# declared one.
if arr.dtype.kind in ('u', 'i'):
if arr.ndim == 3 and band is None and vrt.bands:
int_arr = arr
int_dtype = int_arr.dtype
for i, vrt_band in enumerate(vrt.bands):
if i >= int_arr.shape[-1]:
break
sentinel = _vrt_sentinel_for_dtype(vrt_band.nodata, int_dtype)
if sentinel is None:
continue
mask = int_arr[..., i] == sentinel
if not mask.any():
continue
if arr.dtype != np.float64:
arr = arr.astype(np.float64)
arr[..., i][mask] = np.nan
else:
band_idx = band if band is not None else 0
nodata = None
if vrt.bands and 0 <= band_idx < len(vrt.bands):
nodata = vrt.bands[band_idx].nodata
if nodata is not None:
sentinel = _vrt_sentinel_for_dtype(nodata, arr.dtype)
if sentinel is not None:
mask = arr == sentinel
if mask.any():
arr = arr.astype(np.float64)
arr[mask] = np.nan
# promotes to float64 when sentinels hit. The surrounding dask graph
# already declared float64 when any band has a representable integer
# sentinel, so any chunk that actually fires the mask returns a
# buffer whose dtype matches the declared one.
arr = _apply_integer_sentinel_mask(arr, vrt, band)

if declared_dtype is not None and arr.dtype != declared_dtype:
arr = arr.astype(declared_dtype)
Expand All @@ -4135,33 +4045,6 @@ def _vrt_chunk_read(source, r0, c0, r1, c1, *,
return arr


def _vrt_sentinel_for_dtype(nodata_val, dtype):
"""Return ``dtype``-cast sentinel for ``nodata_val`` or None.

Module-level twin of the closure ``_sentinel_for_dtype`` defined
inside :func:`read_vrt`. Lifted to module scope so the per-chunk
helper :func:`_vrt_chunk_read` can call it without paying the cost
of re-binding the closure on every block.
"""
if nodata_val is None or dtype.kind not in ('u', 'i'):
return None
info = np.iinfo(dtype)
if isinstance(nodata_val, (int, np.integer)) and not isinstance(
nodata_val, bool):
nodata_int = int(nodata_val)
if info.min <= nodata_int <= info.max:
return dtype.type(nodata_int)
return None
try:
nodata_f = float(nodata_val)
except (TypeError, ValueError):
return None
if not (np.isfinite(nodata_f) and nodata_f.is_integer()
and info.min <= nodata_f <= info.max):
return None
return dtype.type(int(nodata_f))


def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
max_pixels, missing_sources):
"""Lazy ``read_vrt`` dispatch when ``chunks=`` is set (issue #1814).
Expand All @@ -4188,13 +4071,21 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
import dask.array as da

from ._reader import MAX_PIXELS_DEFAULT
from ._vrt import _read_vrt_xml, parse_vrt
from ._vrt import (
parse_vrt,
_read_vrt_xml,
_effective_dtype_for_bands,
_sentinel_for_dtype,
)

# Parse the VRT XML up-front (cheap; no pixel decode). Route through
# ``_read_vrt_xml`` so the 64 MiB ``XRSPATIAL_VRT_MAX_XML_BYTES`` cap
# added in #1818 applies to the chunked dispatcher too; a raw
# ``open().read()`` here would let a multi-GB attacker-supplied VRT
# exhaust memory before any parser-side guard fires (issue #1831).
# The parsed VRTDataset is plumbed into every per-chunk task so each
# task can skip the redundant XML parse and source-path allowlist
# validation the internal reader otherwise performs (issue #1825).
xml_str = _read_vrt_xml(source)
vrt_dir = _os.path.dirname(_os.path.abspath(source))
vrt = parse_vrt(xml_str, vrt_dir)
Expand Down Expand Up @@ -4277,17 +4168,18 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
"VRT has no <VRTRasterBand> elements; cannot determine "
"output dtype")

# Compute the declared dtype. Match the internal reader's
# ``np.result_type`` over per-band effective dtypes, then widen to
# float64 only when at least one selected band declares an integer
# nodata sentinel that round-trips through the band's dtype.
# Compute the declared dtype. Share the per-band effective-dtype
# rule (ComplexSource scale/offset promotes to float64) with the
# eager path via ``_effective_dtype_for_bands`` so both paths agree
# on the result_type (issue #1825). Then widen to float64 if any
# selected band declares an integer nodata sentinel that round-trips
# through the band's dtype.
#
# The eager path (``read_vrt`` at lines ~4033-4064) defers the
# promotion to runtime: it scans every band's pixels and promotes
# only if at least one sentinel hits. The chunked path cannot
# afford that scan up front (it would require decoding the mosaic
# the dask graph was constructed to defer), so this is a
# parse-time approximation. The trade-off:
# The eager path defers the promotion to runtime: it scans every
# band's pixels and promotes only if at least one sentinel hits.
# The chunked path cannot afford that scan up front (it would
# require decoding the mosaic the dask graph was constructed to
# defer), so this is a parse-time approximation. The trade-off:
# * if a band declares nodata and no chunk contains the
# sentinel, the chunked output is float64 where the eager
# output would have stayed integer (acceptable: the user
Expand All @@ -4296,25 +4188,13 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
# source integer dtype (handled by the ``promotes is False``
# fall-through below).
# See also Copilot review on PR #1822.
# TODO(#1825): share dtype + scale/offset/sentinel logic with
# the eager path instead of re-implementing it here.
effective_dtypes = []
for vrt_band in selected_bands:
eff = vrt_band.dtype
for src in vrt_band.sources:
scaled = src.scale is not None and src.scale != 1.0
offset = src.offset is not None and src.offset != 0.0
if scaled or offset:
eff = np.dtype(np.float64)
break
effective_dtypes.append(eff)
declared_dtype = np.result_type(*effective_dtypes)
declared_dtype = _effective_dtype_for_bands(selected_bands)

if declared_dtype.kind in ('u', 'i'):
promotes = False
for vrt_band in selected_bands:
if _vrt_sentinel_for_dtype(vrt_band.nodata,
declared_dtype) is not None:
if _sentinel_for_dtype(vrt_band.nodata,
declared_dtype) is not None:
promotes = True
break
if promotes:
Expand Down Expand Up @@ -4360,6 +4240,7 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
missing_sources=missing_sources,
declared_dtype=declared_dtype,
gpu=gpu,
parsed_vrt=vrt,
)
block = da.from_delayed(d, shape=block_shape,
dtype=declared_dtype, meta=meta)
Expand Down Expand Up @@ -4427,8 +4308,6 @@ def _read_vrt_chunked(source, *, window, band, name, chunks, gpu, dtype,
# Empty list is omitted so the attr only appears when a hole is
# actually present. Each entry mirrors the eager schema:
# ``{'source', 'band', 'dst_rect', 'error'}``.
# TODO(#1825): the per-task path independently re-parses and
# re-resolves source paths; refactor to share the parent's scan.
chunked_holes: list[dict] = []
for vrt_band in vrt.bands:
for src in vrt_band.sources:
Expand Down
Loading
Loading