Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 68 additions & 4 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ def _geotiff_strict_mode() -> bool:
'XRSPATIAL_GEOTIFF_STRICT', '').lower() in ('1', 'true', 'yes')


def _gpu_fallback_warning_message(auto_detected: bool, exc: BaseException) -> str:
"""Build the ``to_geotiff`` GPU-to-CPU fallback warning text.

``to_geotiff`` reaches the GPU writer two ways: an explicit
``gpu=True`` argument, or the auto-detect branch when ``gpu is
None`` and the data lives on a CuPy device. The wording differs
because blaming the fallback on a flag the caller never set sends
them to fix the wrong thing. Both routes share the exception
payload format so callers can grep ``type(e).__name__: e`` either
way.
"""
suffix = f"({type(exc).__name__}: {exc})."
if auto_detected:
return (
"Data is on the GPU and was routed to the GPU writer, but "
"the writer is unavailable; falling back to CPU and copying "
"the array to host. " + suffix
)
return (
"to_geotiff(gpu=True) was requested but the GPU writer is "
"unavailable; falling back to CPU. " + suffix
)


def _wkt_to_epsg(wkt_or_proj: str) -> int | None:
"""Try to extract an EPSG code from a WKT or PROJ string.

Expand Down Expand Up @@ -1104,7 +1128,6 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
# non-default size alongside strip mode (it would otherwise be silently
# ignored).
if not tiled and tile_size != 256:
import warnings
warnings.warn(
f"tile_size={tile_size} is ignored when tiled=False "
"(strip layout). Pass tiled=True to use tile_size, or drop "
Expand Down Expand Up @@ -1133,7 +1156,11 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
max_z_error=max_z_error)
return

# Auto-detect GPU data and dispatch to write_geotiff_gpu
# Auto-detect GPU data and dispatch to write_geotiff_gpu. ``gpu is
# None`` is the implicit "use whatever fits the data" path; preserve
# that distinction in the fallback warning below so users who never
# set ``gpu=True`` are not told their explicit request was dropped.
auto_detected_gpu = gpu is None
use_gpu = gpu if gpu is not None else _is_gpu_data(data)
if use_gpu and _path_is_file_like:
# write_geotiff_gpu's nvCOMP path materialises tile parts and then
Expand Down Expand Up @@ -1171,8 +1198,45 @@ def to_geotiff(data: xr.DataArray | np.ndarray,
bigtiff=bigtiff,
streaming_buffer_bytes=streaming_buffer_bytes)
return
except (ImportError, Exception):
pass # fall through to CPU path
except ImportError as e:
# ``write_geotiff_gpu`` raises ImportError when cupy itself
# can't be imported. nvCOMP absence doesn't surface here:
# ``_try_nvcomp_from_device_bufs`` returns None when the
# library can't load, and the writer drops to CPU
# compression internally instead of re-raising. Fall back
# to the CPU writer with a typed warning so callers see
# that gpu=True (or auto-detected CuPy data) didn't go
# through. Strict mode re-raises so CI can fail loudly on
# missing GPU stacks.
if _geotiff_strict_mode():
raise
warnings.warn(
_gpu_fallback_warning_message(auto_detected_gpu, e),
GeoTIFFFallbackWarning,
stacklevel=2,
)
Comment on lines +1211 to +1217
except RuntimeError as e:
# Only fall back when the message names a GPU-availability
# problem; any other RuntimeError is a real bug in the GPU
# writer and the broad ``except (ImportError, Exception)``
# used to hide it from the user. Keep the keyword list
# tight: nvCOMP / CUDA / no device / no GPU / cuInit cover
# the realistic "no GPU present" failure modes without
# masking, e.g., a CRS or compression error that happens to
# raise RuntimeError. Strict mode re-raises in either case.
_gpu_unavail_tokens = (
'nvcomp', 'cuda', 'no device', 'no gpu', 'cuinit',
)
msg = str(e).lower()
if not any(tok in msg for tok in _gpu_unavail_tokens):
raise
if _geotiff_strict_mode():
raise
warnings.warn(
_gpu_fallback_warning_message(auto_detected_gpu, e),
GeoTIFFFallbackWarning,
stacklevel=2,
)

geo_transform = None
epsg = None
Expand Down
Loading
Loading