diff --git a/README.md b/README.md index ac7eaefd..d6834656 100644 --- a/README.md +++ b/README.md @@ -130,131 +130,82 @@ Rasters are regularly gridded datasets like GeoTIFFs, JPGs, and PNGs. In the GIS world, rasters are used for representing continuous phenomena (e.g. elevation, rainfall, distance), either directly as numerical values, or as RGB images created for humans to view. Rasters typically have two spatial dimensions, but may have any number of other dimensions (time, type of measurement, etc.) #### Supported Spatial Functions with Supported Inputs - ✅ = native backend    🔄 = accepted (CPU fallback) ------- +### **GeoTIFF / COG I/O** -### **Classification** - -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Box Plot](xrspatial/classify.py) | Classifies values into bins based on box plot quartile boundaries | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | -| [Equal Interval](xrspatial/classify.py) | Divides the value range into equal-width bins | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | -| [Head/Tail Breaks](xrspatial/classify.py) | Classifies heavy-tailed distributions using recursive mean splitting | PySAL mapclassify | ✅️ |✅ | 🔄 | 🔄 | -| [Maximum Breaks](xrspatial/classify.py) | Finds natural groupings by maximizing differences between sorted values | PySAL mapclassify | ✅️ |✅ | 🔄 | 🔄 | -| [Natural Breaks](xrspatial/classify.py) | Optimizes class boundaries to minimize within-class variance (Jenks) | Jenks 1967, PySAL | ✅️ |✅ | 🔄 | 🔄 | -| [Percentiles](xrspatial/classify.py) | Assigns classes based on user-defined percentile breakpoints | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | -| [Quantile](xrspatial/classify.py) | Distributes values into classes with equal observation counts | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | -| [Reclassify](xrspatial/classify.py) | Remaps pixel values to new classes using a user-defined lookup | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | -| [Std Mean](xrspatial/classify.py) | Classifies values by standard deviation intervals from the mean | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | - -------- - -### **Diffusion** - -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Diffuse](xrspatial/diffusion.py) | Runs explicit forward-Euler diffusion on a 2D scalar field | Standard (heat equation) | ✅️ | ✅️ | ✅️ | ✅️ | - -------- - -### **Focal** - -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Apply](xrspatial/focal.py) | Applies a custom function over a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | -| [Hotspots](xrspatial/focal.py) | Identifies statistically significant spatial clusters using Getis-Ord Gi* | Getis & Ord 1992 | ✅️ | ✅️ | ✅️ | ✅️ | -| [Emerging Hotspots](xrspatial/emerging_hotspots.py) | Classifies time-series hot/cold spot trends using Gi* and Mann-Kendall | Getis & Ord 1992, Mann 1945 | ✅️ | ✅️ | ✅️ | ✅️ | -| [Mean](xrspatial/focal.py) | Computes the mean value within a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | -| [Focal Statistics](xrspatial/focal.py) | Computes summary statistics over a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | -| [Bilateral](xrspatial/bilateral.py) | Feature-preserving smoothing via bilateral filtering | Tomasi & Manduchi 1998 | ✅️ | ✅️ | ✅️ | ✅️ | -| [GLCM Texture](xrspatial/glcm.py) | Computes Haralick GLCM texture features over a sliding window | Haralick et al. 1973 | ✅️ | ✅️ | 🔄 | 🔄 | - -------- - -### **Morphological** - -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Erode](xrspatial/morphology.py) | Morphological erosion (local minimum over structuring element) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Dilate](xrspatial/morphology.py) | Morphological dilation (local maximum over structuring element) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Opening](xrspatial/morphology.py) | Erosion then dilation (removes small bright features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Closing](xrspatial/morphology.py) | Dilation then erosion (fills small dark gaps) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Gradient](xrspatial/morphology.py) | Dilation minus erosion (edge detection) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [White Top-hat](xrspatial/morphology.py) | Original minus opening (isolate bright features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Black Top-hat](xrspatial/morphology.py) | Closing minus original (isolate dark features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | - -------- - -### **Fire** +Native GeoTIFF and Cloud Optimized GeoTIFF reader/writer. No GDAL required. -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [dNBR](xrspatial/fire.py) | Differenced Normalized Burn Ratio (pre minus post NBR) | USGS | ✅️ | ✅️ | ✅️ | ✅️ | -| [RdNBR](xrspatial/fire.py) | Relative dNBR normalized by pre-fire vegetation density | USGS | ✅️ | ✅️ | ✅️ | ✅️ | -| [Burn Severity Class](xrspatial/fire.py) | USGS 7-class burn severity from dNBR thresholds | USGS | ✅️ | ✅️ | ✅️ | ✅️ | -| [Fireline Intensity](xrspatial/fire.py) | Byram's fireline intensity from fuel load and spread rate (kW/m) | Byram 1959 | ✅️ | ✅️ | ✅️ | ✅️ | -| [Flame Length](xrspatial/fire.py) | Flame length derived from fireline intensity (m) | Byram 1959 | ✅️ | ✅️ | ✅️ | ✅️ | -| [Rate of Spread](xrspatial/fire.py) | Simplified Rothermel spread rate with Anderson 13 fuel models (m/min) | Rothermel 1972, Anderson 1982 | ✅️ | ✅️ | ✅️ | ✅️ | -| [KBDI](xrspatial/fire.py) | Keetch-Byram Drought Index single time-step update (0-800 mm) | Keetch & Byram 1968 | ✅️ | ✅️ | ✅️ | ✅️ | - -------- - -### **Multispectral** +| Name | Description | NumPy | Dask | CuPy GPU | Dask+CuPy GPU | Cloud | +|:-----|:------------|:-----:|:----:|:--------:|:-------------:|:-----:| +| [read_geotiff](xrspatial/geotiff/__init__.py) | Read GeoTIFF / COG / VRT | ✅️ | ✅️ | ✅️ | ✅️ | ✅️ | +| [write_geotiff](xrspatial/geotiff/__init__.py) | Write DataArray as GeoTIFF / COG | ✅️ | ✅️ | ✅️ | ✅️ | ✅️ | +| [write_vrt](xrspatial/geotiff/__init__.py) | Generate VRT mosaic from GeoTIFFs | ✅️ | | | | | -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Atmospherically Resistant Vegetation Index (ARVI)](xrspatial/multispectral.py) | Vegetation index resistant to atmospheric effects using blue band correction | Kaufman & Tanre 1992 | ✅️ |✅️ | ✅️ |✅️ | -| [Enhanced Built-Up and Bareness Index (EBBI)](xrspatial/multispectral.py) | Highlights built-up areas and barren land from thermal and SWIR bands | As-syakur et al. 2012 | ✅️ |✅️ | ✅️ |✅️ | -| [Enhanced Vegetation Index (EVI)](xrspatial/multispectral.py) | Enhanced vegetation index reducing soil and atmospheric noise | Huete et al. 2002 | ✅️ |✅️ | ✅️ |✅️ | -| [Green Chlorophyll Index (GCI)](xrspatial/multispectral.py) | Estimates leaf chlorophyll content from green and NIR reflectance | Gitelson et al. 2003 | ✅️ |✅️ | ✅️ |✅️ | -| [Normalized Burn Ratio (NBR)](xrspatial/multispectral.py) | Measures burn severity using NIR and SWIR band difference | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | -| [Normalized Burn Ratio 2 (NBR2)](xrspatial/multispectral.py) | Refines burn severity mapping using two SWIR bands | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | -| [Normalized Difference Moisture Index (NDMI)](xrspatial/multispectral.py) | Detects vegetation moisture stress from NIR and SWIR reflectance | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | -| [Normalized Difference Water Index (NDWI)](xrspatial/multispectral.py) | Maps open water bodies using green and NIR band difference | McFeeters 1996 | ✅️ |✅️ | ✅️ |✅️ | -| [Modified Normalized Difference Water Index (MNDWI)](xrspatial/multispectral.py) | Detects water in urban areas using green and SWIR bands | Xu 2006 | ✅️ |✅️ | ✅️ |✅️ | -| [Normalized Difference Vegetation Index (NDVI)](xrspatial/multispectral.py) | Quantifies vegetation density from red and NIR band difference | Rouse et al. 1973 | ✅️ |✅️ | ✅️ |✅️ | -| [Soil Adjusted Vegetation Index (SAVI)](xrspatial/multispectral.py) | Vegetation index with soil brightness correction factor | Huete 1988 | ✅️ |✅️ | ✅️ |✅️ | -| [Structure Insensitive Pigment Index (SIPI)](xrspatial/multispectral.py) | Estimates carotenoid-to-chlorophyll ratio for plant stress detection | Penuelas et al. 1995 | ✅️ |✅️ | ✅️ |✅️ | -| [True Color](xrspatial/multispectral.py) | Composites red, green, and blue bands into a natural color image | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +`read_geotiff` and `write_geotiff` auto-dispatch to the correct backend: -------- +```python +read_geotiff('dem.tif') # NumPy +read_geotiff('dem.tif', chunks=512) # Dask +read_geotiff('dem.tif', gpu=True) # CuPy (nvCOMP + GDS) +read_geotiff('dem.tif', gpu=True, chunks=512) # Dask + CuPy +read_geotiff('https://example.com/cog.tif') # HTTP COG +read_geotiff('s3://bucket/dem.tif') # Cloud (S3/GCS/Azure) +read_geotiff('mosaic.vrt') # VRT mosaic (auto-detected) + +write_geotiff(cupy_array, 'out.tif') # auto-detects GPU +write_geotiff(data, 'out.tif', gpu=True) # force GPU compress +write_vrt('mosaic.vrt', ['tile1.tif', 'tile2.tif']) # generate VRT +``` +**Compression codecs:** Deflate, LZW (Numba JIT), ZSTD, PackBits, JPEG (Pillow), uncompressed -### **Multivariate** +**GPU codecs:** Deflate and ZSTD via nvCOMP batch API; LZW via Numba CUDA kernels -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Mahalanobis Distance](xrspatial/mahalanobis.py) | Measures statistical distance from a multi-band reference distribution, accounting for band correlations | Mahalanobis 1936 | ✅️ |✅️ | ✅️ |✅️ | +**Features:** +- Tiled, stripped, BigTIFF, multi-band (RGB/RGBA), sub-byte (1/2/4/12-bit) +- Predictors: horizontal differencing (pred=2), floating-point (pred=3) +- GeoKeys: EPSG, WKT/PROJ (via pyproj), citations, units, ellipsoid, vertical CRS +- Metadata: nodata masking, palette colormaps, DPI/resolution, GDALMetadata XML, arbitrary tag preservation +- Cloud storage: S3 (`s3://`), GCS (`gs://`), Azure (`az://`) via fsspec +- GPUDirect Storage: SSD→GPU direct DMA via KvikIO (optional) +- Thread-safe mmap reads, atomic writes, HTTP connection reuse (urllib3) +- Overview generation: mean, nearest, min, max, median, mode, cubic +- Planar config, big-endian byte swap, PixelIsArea/PixelIsPoint -------- +**Read performance** (real-world files, A6000 GPU): -### **Pathfinding** +| File | Format | xrspatial CPU | rioxarray | GPU (nvCOMP) | +|:-----|:-------|:------------:|:---------:|:------------:| +| render_demo 187x253 | uncompressed | **0.2ms** | 2.4ms | 0.7ms | +| Landsat B4 1310x1093 | uncompressed | **1.0ms** | 6.0ms | 1.7ms | +| Copernicus 3600x3600 | deflate+fp3 | 241ms | 195ms | 872ms | +| USGS 1as 3612x3612 | LZW+fp3 | 275ms | 215ms | 747ms | +| USGS 1m 10012x10012 | LZW | **1.25s** | 1.80s | **990ms** | -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [A* Pathfinding](xrspatial/pathfinding.py) | Finds the least-cost path between two cells on a cost surface | Hart et al. 1968 | ✅️ | ✅ | 🔄 | 🔄 | -| [Multi-Stop Search](xrspatial/pathfinding.py) | Routes through N waypoints in sequence, with optional TSP reordering | Custom | ✅️ | ✅ | 🔄 | 🔄 | +**Read performance** (synthetic tiled, GPU shines at scale): ----------- +| Size | Codec | xrspatial CPU | rioxarray | GPU (nvCOMP) | +|:-----|:------|:------------:|:---------:|:------------:| +| 4096x4096 | deflate | 265ms | 211ms | **158ms** | +| 4096x4096 | zstd | **73ms** | 159ms | **58ms** | +| 8192x8192 | deflate | 1.06s | 859ms | **565ms** | +| 8192x8192 | zstd | **288ms** | 668ms | **171ms** | -### **Proximity** +**Write performance** (synthetic tiled): -| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Allocation](xrspatial/proximity.py) | Assigns each cell to the identity of the nearest source feature | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | -| [Balanced Allocation](xrspatial/balanced_allocation.py) | Partitions a cost surface into territories of roughly equal cost-weighted area | Custom | ✅️ | ✅ | ✅️ | ✅️ | -| [Cost Distance](xrspatial/cost_distance.py) | Computes minimum accumulated cost to the nearest source through a friction surface | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | -| [Least-Cost Corridor](xrspatial/corridor.py) | Identifies zones of low cumulative cost between two source locations | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | -| [Direction](xrspatial/proximity.py) | Computes the direction from each cell to the nearest source feature | Standard | ✅️ | ✅ | ✅️ | ✅️ | -| [Proximity](xrspatial/proximity.py) | Computes the distance from each cell to the nearest source feature | Standard | ✅️ | ✅ | ✅️ | ✅️ | -| [Surface Distance](xrspatial/surface_distance.py) | Computes distance along the 3D terrain surface to the nearest source | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | -| [Surface Allocation](xrspatial/surface_distance.py) | Assigns each cell to the nearest source by terrain surface distance | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | -| [Surface Direction](xrspatial/surface_distance.py) | Computes compass direction to the nearest source by terrain surface distance | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| Size | Codec | xrspatial CPU | rioxarray | GPU (nvCOMP) | +|:-----|:------|:------------:|:---------:|:------------:| +| 2048x2048 | deflate | 424ms | 110ms | **135ms** | +| 2048x2048 | zstd | 49ms | 83ms | 81ms | +| 4096x4096 | deflate | 1.68s | 447ms | **302ms** | +| 8192x8192 | deflate | 6.84s | 2.03s | **1.11s** | +| 8192x8192 | zstd | 847ms | 822ms | 1.03s | --------- +**Consistency:** 100% pixel-exact match vs rioxarray on all tested files (Landsat 8, Copernicus DEM, USGS 1-arc-second, USGS 1-meter). +----------- ### **Reproject / Merge** | Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | @@ -264,15 +215,15 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e ------- -### **Raster / Vector Conversion** +### **Utilities** | Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | -|:-----|:------------|:------:|:------------------:|:-----------------:|:---------------------:|:---------------------:| -| [Polygonize](xrspatial/polygonize.py) | Converts contiguous regions of equal value into vector polygons | Standard (CCL) | ✅️ | ✅️ | ✅️ | 🔄 | -| [Contours](xrspatial/contour.py) | Extracts elevation contour lines (isolines) from a raster surface | Standard (marching squares) | ✅️ | ✅️ | 🔄 | 🔄 | -| [Rasterize](xrspatial/rasterize.py) | Rasterizes vector geometries (polygons, lines, points) from a GeoDataFrame | Standard (scanline, Bresenham) | ✅️ | | ✅️ | | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Preview](xrspatial/preview.py) | Downsamples a raster to target pixel dimensions for visualization | Custom | ✅️ | ✅️ | ✅️ | 🔄 | +| [Rescale](xrspatial/normalize.py) | Min-max normalization to a target range (default [0, 1]) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [Standardize](xrspatial/normalize.py) | Z-score normalization (subtract mean, divide by std) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | --------- +----------- ### **Surface** @@ -339,24 +290,72 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e ----------- -### **Interpolation** +### **Multispectral** | Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | |:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [IDW](xrspatial/interpolate/_idw.py) | Inverse Distance Weighting from scattered points to a raster grid | Standard (IDW) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Kriging](xrspatial/interpolate/_kriging.py) | Ordinary Kriging with automatic variogram fitting (spherical, exponential, gaussian) | Standard (ordinary kriging) | ✅️ | ✅️ | ✅️ | ✅️ | -| [Spline](xrspatial/interpolate/_spline.py) | Thin Plate Spline interpolation with optional smoothing | Standard (TPS) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Atmospherically Resistant Vegetation Index (ARVI)](xrspatial/multispectral.py) | Vegetation index resistant to atmospheric effects using blue band correction | Kaufman & Tanre 1992 | ✅️ |✅️ | ✅️ |✅️ | +| [Enhanced Built-Up and Bareness Index (EBBI)](xrspatial/multispectral.py) | Highlights built-up areas and barren land from thermal and SWIR bands | As-syakur et al. 2012 | ✅️ |✅️ | ✅️ |✅️ | +| [Enhanced Vegetation Index (EVI)](xrspatial/multispectral.py) | Enhanced vegetation index reducing soil and atmospheric noise | Huete et al. 2002 | ✅️ |✅️ | ✅️ |✅️ | +| [Green Chlorophyll Index (GCI)](xrspatial/multispectral.py) | Estimates leaf chlorophyll content from green and NIR reflectance | Gitelson et al. 2003 | ✅️ |✅️ | ✅️ |✅️ | +| [Normalized Burn Ratio (NBR)](xrspatial/multispectral.py) | Measures burn severity using NIR and SWIR band difference | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | +| [Normalized Burn Ratio 2 (NBR2)](xrspatial/multispectral.py) | Refines burn severity mapping using two SWIR bands | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | +| [Normalized Difference Moisture Index (NDMI)](xrspatial/multispectral.py) | Detects vegetation moisture stress from NIR and SWIR reflectance | USGS Landsat | ✅️ |✅️ | ✅️ |✅️ | +| [Normalized Difference Water Index (NDWI)](xrspatial/multispectral.py) | Maps open water bodies using green and NIR band difference | McFeeters 1996 | ✅️ |✅️ | ✅️ |✅️ | +| [Modified Normalized Difference Water Index (MNDWI)](xrspatial/multispectral.py) | Detects water in urban areas using green and SWIR bands | Xu 2006 | ✅️ |✅️ | ✅️ |✅️ | +| [Normalized Difference Vegetation Index (NDVI)](xrspatial/multispectral.py) | Quantifies vegetation density from red and NIR band difference | Rouse et al. 1973 | ✅️ |✅️ | ✅️ |✅️ | +| [Soil Adjusted Vegetation Index (SAVI)](xrspatial/multispectral.py) | Vegetation index with soil brightness correction factor | Huete 1988 | ✅️ |✅️ | ✅️ |✅️ | +| [Structure Insensitive Pigment Index (SIPI)](xrspatial/multispectral.py) | Estimates carotenoid-to-chlorophyll ratio for plant stress detection | Penuelas et al. 1995 | ✅️ |✅️ | ✅️ |✅️ | +| [True Color](xrspatial/multispectral.py) | Composites red, green, and blue bands into a natural color image | Standard | ✅️ | ✅️ | ✅️ | ✅️ | ------------ +------- -### **Dasymetric** + +### **Classification** | Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | |:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Disaggregate](xrspatial/dasymetric.py) | Redistributes zonal totals to pixels using an ancillary weight surface | Mennis 2003 | ✅️ | ✅️ | ✅️ | ✅️ | -| [Pycnophylactic](xrspatial/dasymetric.py) | Tobler's pycnophylactic interpolation preserving zone totals via Laplacian smoothing | Tobler 1979 | ✅️ | | | | +| [Box Plot](xrspatial/classify.py) | Classifies values into bins based on box plot quartile boundaries | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | +| [Equal Interval](xrspatial/classify.py) | Divides the value range into equal-width bins | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | +| [Head/Tail Breaks](xrspatial/classify.py) | Classifies heavy-tailed distributions using recursive mean splitting | PySAL mapclassify | ✅️ |✅ | 🔄 | 🔄 | +| [Maximum Breaks](xrspatial/classify.py) | Finds natural groupings by maximizing differences between sorted values | PySAL mapclassify | ✅️ |✅ | 🔄 | 🔄 | +| [Natural Breaks](xrspatial/classify.py) | Optimizes class boundaries to minimize within-class variance (Jenks) | Jenks 1967, PySAL | ✅️ |✅ | 🔄 | 🔄 | +| [Percentiles](xrspatial/classify.py) | Assigns classes based on user-defined percentile breakpoints | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | +| [Quantile](xrspatial/classify.py) | Distributes values into classes with equal observation counts | PySAL mapclassify | ✅️ |✅ | ✅ | 🔄 | +| [Reclassify](xrspatial/classify.py) | Remaps pixel values to new classes using a user-defined lookup | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | +| [Std Mean](xrspatial/classify.py) | Classifies values by standard deviation intervals from the mean | PySAL mapclassify | ✅️ |✅ | ✅ |✅ | ------------ +------- + +### **Focal** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Apply](xrspatial/focal.py) | Applies a custom function over a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [Hotspots](xrspatial/focal.py) | Identifies statistically significant spatial clusters using Getis-Ord Gi* | Getis & Ord 1992 | ✅️ | ✅️ | ✅️ | ✅️ | +| [Emerging Hotspots](xrspatial/emerging_hotspots.py) | Classifies time-series hot/cold spot trends using Gi* and Mann-Kendall | Getis & Ord 1992, Mann 1945 | ✅️ | ✅️ | ✅️ | ✅️ | +| [Mean](xrspatial/focal.py) | Computes the mean value within a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [Focal Statistics](xrspatial/focal.py) | Computes summary statistics over a sliding neighborhood window | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [Bilateral](xrspatial/bilateral.py) | Feature-preserving smoothing via bilateral filtering | Tomasi & Manduchi 1998 | ✅️ | ✅️ | ✅️ | ✅️ | +| [GLCM Texture](xrspatial/glcm.py) | Computes Haralick GLCM texture features over a sliding window | Haralick et al. 1973 | ✅️ | ✅️ | 🔄 | 🔄 | + +------- + +### **Proximity** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Allocation](xrspatial/proximity.py) | Assigns each cell to the identity of the nearest source feature | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| [Balanced Allocation](xrspatial/balanced_allocation.py) | Partitions a cost surface into territories of roughly equal cost-weighted area | Custom | ✅️ | ✅ | ✅️ | ✅️ | +| [Cost Distance](xrspatial/cost_distance.py) | Computes minimum accumulated cost to the nearest source through a friction surface | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| [Least-Cost Corridor](xrspatial/corridor.py) | Identifies zones of low cumulative cost between two source locations | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| [Direction](xrspatial/proximity.py) | Computes the direction from each cell to the nearest source feature | Standard | ✅️ | ✅ | ✅️ | ✅️ | +| [Proximity](xrspatial/proximity.py) | Computes the distance from each cell to the nearest source feature | Standard | ✅️ | ✅ | ✅️ | ✅️ | +| [Surface Distance](xrspatial/surface_distance.py) | Computes distance along the 3D terrain surface to the nearest source | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| [Surface Allocation](xrspatial/surface_distance.py) | Assigns each cell to the nearest source by terrain surface distance | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | +| [Surface Direction](xrspatial/surface_distance.py) | Computes compass direction to the nearest source by terrain surface distance | Standard (Dijkstra) | ✅️ | ✅ | ✅️ | ✅️ | + +-------- ### **Zonal** @@ -371,13 +370,88 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e ----------- -### **Utilities** +### **Interpolation** | Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | |:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| -| [Preview](xrspatial/preview.py) | Downsamples a raster to target pixel dimensions for visualization | Custom | ✅️ | ✅️ | ✅️ | 🔄 | -| [Rescale](xrspatial/normalize.py) | Min-max normalization to a target range (default [0, 1]) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | -| [Standardize](xrspatial/normalize.py) | Z-score normalization (subtract mean, divide by std) | Standard | ✅️ | ✅️ | ✅️ | ✅️ | +| [IDW](xrspatial/interpolate/_idw.py) | Inverse Distance Weighting from scattered points to a raster grid | Standard (IDW) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Kriging](xrspatial/interpolate/_kriging.py) | Ordinary Kriging with automatic variogram fitting (spherical, exponential, gaussian) | Standard (ordinary kriging) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Spline](xrspatial/interpolate/_spline.py) | Thin Plate Spline interpolation with optional smoothing | Standard (TPS) | ✅️ | ✅️ | ✅️ | ✅️ | + +----------- + +### **Morphological** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Erode](xrspatial/morphology.py) | Morphological erosion (local minimum over structuring element) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Dilate](xrspatial/morphology.py) | Morphological dilation (local maximum over structuring element) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Opening](xrspatial/morphology.py) | Erosion then dilation (removes small bright features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Closing](xrspatial/morphology.py) | Dilation then erosion (fills small dark gaps) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Gradient](xrspatial/morphology.py) | Dilation minus erosion (edge detection) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [White Top-hat](xrspatial/morphology.py) | Original minus opening (isolate bright features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | +| [Black Top-hat](xrspatial/morphology.py) | Closing minus original (isolate dark features) | Standard (morphology) | ✅️ | ✅️ | ✅️ | ✅️ | + +------- + +### **Fire** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [dNBR](xrspatial/fire.py) | Differenced Normalized Burn Ratio (pre minus post NBR) | USGS | ✅️ | ✅️ | ✅️ | ✅️ | +| [RdNBR](xrspatial/fire.py) | Relative dNBR normalized by pre-fire vegetation density | USGS | ✅️ | ✅️ | ✅️ | ✅️ | +| [Burn Severity Class](xrspatial/fire.py) | USGS 7-class burn severity from dNBR thresholds | USGS | ✅️ | ✅️ | ✅️ | ✅️ | +| [Fireline Intensity](xrspatial/fire.py) | Byram's fireline intensity from fuel load and spread rate (kW/m) | Byram 1959 | ✅️ | ✅️ | ✅️ | ✅️ | +| [Flame Length](xrspatial/fire.py) | Flame length derived from fireline intensity (m) | Byram 1959 | ✅️ | ✅️ | ✅️ | ✅️ | +| [Rate of Spread](xrspatial/fire.py) | Simplified Rothermel spread rate with Anderson 13 fuel models (m/min) | Rothermel 1972, Anderson 1982 | ✅️ | ✅️ | ✅️ | ✅️ | +| [KBDI](xrspatial/fire.py) | Keetch-Byram Drought Index single time-step update (0-800 mm) | Keetch & Byram 1968 | ✅️ | ✅️ | ✅️ | ✅️ | + +------- + +### **Raster / Vector Conversion** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:-----|:------------|:------:|:------------------:|:-----------------:|:---------------------:|:---------------------:| +| [Polygonize](xrspatial/polygonize.py) | Converts contiguous regions of equal value into vector polygons | Standard (CCL) | ✅️ | ✅️ | ✅️ | 🔄 | +| [Contours](xrspatial/contour.py) | Extracts elevation contour lines (isolines) from a raster surface | Standard (marching squares) | ✅️ | ✅️ | 🔄 | 🔄 | +| [Rasterize](xrspatial/rasterize.py) | Rasterizes vector geometries (polygons, lines, points) from a GeoDataFrame | Standard (scanline, Bresenham) | ✅️ | | ✅️ | | + +-------- + +### **Multivariate** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Mahalanobis Distance](xrspatial/mahalanobis.py) | Measures statistical distance from a multi-band reference distribution, accounting for band correlations | Mahalanobis 1936 | ✅️ |✅️ | ✅️ |✅️ | + +------- + +### **Pathfinding** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [A* Pathfinding](xrspatial/pathfinding.py) | Finds the least-cost path between two cells on a cost surface | Hart et al. 1968 | ✅️ | ✅ | 🔄 | 🔄 | +| [Multi-Stop Search](xrspatial/pathfinding.py) | Routes through N waypoints in sequence, with optional TSP reordering | Custom | ✅️ | ✅ | 🔄 | 🔄 | + +---------- + +### **Diffusion** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Diffuse](xrspatial/diffusion.py) | Runs explicit forward-Euler diffusion on a 2D scalar field | Standard (heat equation) | ✅️ | ✅️ | ✅️ | ✅️ | + +------- + +### **Dasymetric** + +| Name | Description | Source | NumPy xr.DataArray | Dask xr.DataArray | CuPy GPU xr.DataArray | Dask GPU xr.DataArray | +|:----------:|:------------|:------:|:----------------------:|:--------------------:|:-------------------:|:------:| +| [Disaggregate](xrspatial/dasymetric.py) | Redistributes zonal totals to pixels using an ancillary weight surface | Mennis 2003 | ✅️ | ✅️ | ✅️ | ✅️ | +| [Pycnophylactic](xrspatial/dasymetric.py) | Tobler's pycnophylactic interpolation preserving zone totals via Laplacian smoothing | Tobler 1979 | ✅️ | | | | + +----------- + #### Usage @@ -386,12 +460,11 @@ In the GIS world, rasters are used for representing continuous phenomena (e.g. e Importing `xrspatial` registers an `.xrs` accessor on DataArrays and Datasets, giving you tab-completable access to every spatial operation: ```python -import numpy as np -import xarray as xr import xrspatial +from xrspatial.geotiff import read_geotiff -# Create or load a raster -elevation = xr.DataArray(np.random.rand(100, 100) * 1000, dims=['y', 'x']) +# Read a GeoTIFF (no GDAL required) +elevation = read_geotiff('dem.tif') # Surface analysis — call operations directly on the DataArray slope = elevation.xrs.slope() @@ -449,20 +522,27 @@ Check out the user guide [here](/examples/user_guide/). #### Dependencies -`xarray-spatial` currently depends on Datashader, but will soon be updated to depend only on `xarray` and `numba`, while still being able to make use of Datashader output when available. +**Core:** numpy, numba, scipy, xarray, matplotlib, zstandard + +**Optional:** +- `pyproj` — WKT/PROJ CRS resolution +- `cupy` — GPU acceleration +- `dask` — out-of-core processing +- `libnvcomp` — GPU batch decompression (deflate, ZSTD) +- `kvikio` — GPUDirect Storage (SSD → GPU) +- `fsspec` + `s3fs`/`gcsfs`/`adlfs` — cloud storage ![title](img/dependencies.svg) #### Notes on GDAL -Within the Python ecosystem, many geospatial libraries interface with the GDAL C++ library for raster and vector input, output, and analysis (e.g. rasterio, rasterstats, geopandas). GDAL is robust, performant, and has decades of great work behind it. For years, off-loading expensive computations to the C/C++ level in this way has been a key performance strategy for Python libraries (obviously...Python itself is implemented in C!). +`xarray-spatial` does not depend on GDAL. The built-in GeoTIFF/COG reader and writer (`xrspatial.geotiff`) handles raster I/O natively using only numpy, numba, and the standard library. This means: -However, wrapping GDAL has a few drawbacks for Python developers and data scientists: -- GDAL can be a pain to build / install. -- GDAL is hard for Python developers/analysts to extend, because it requires understanding multiple languages. -- GDAL's data structures are defined at the C/C++ level, which constrains how they can be accessed from Python. +- **Zero GDAL installation hassle.** `pip install xarray-spatial` gets you everything needed to read and write GeoTIFFs, COGs, and VRT files. +- **Pure Python, fully extensible.** All codec, header parsing, and metadata code is readable Python/Numba, not wrapped C/C++. +- **GPU-accelerated reads.** With optional nvCOMP, compressed tiles decompress directly on the GPU via CUDA -- something GDAL cannot do. -With the introduction of projects like Numba, Python gained new ways to provide high-performance code directly in Python, without depending on or being constrained by separate C/C++ extensions. `xarray-spatial` implements algorithms using Numba and Dask, making all of its source code available as pure Python without any "black box" barriers that obscure what is going on and prevent full optimization. Projects can make use of the functionality provided by `xarray-spatial` where available, while still using GDAL where required for other tasks. +The native reader is pixel-exact against rasterio/GDAL across Landsat 8, Copernicus DEM, USGS 1-arc-second, and USGS 1-meter DEMs. For uncompressed files it reads 5-7x faster than rioxarray; for compressed COGs it is comparable or faster with GPU acceleration. #### Citation Cite this code: diff --git a/docs/source/user_guide/multispectral.ipynb b/docs/source/user_guide/multispectral.ipynb index f736de73..60ff5f4e 100644 --- a/docs/source/user_guide/multispectral.ipynb +++ b/docs/source/user_guide/multispectral.ipynb @@ -41,18 +41,7 @@ }, "outputs": [], "source": [ - "import datashader as ds\n", - "from datashader.colors import Elevation\n", - "import datashader.transfer_functions as tf\n", - "from datashader.transfer_functions import shade\n", - "from datashader.transfer_functions import stack\n", - "from datashader.transfer_functions import dynspread\n", - "from datashader.transfer_functions import set_background\n", - "from datashader.transfer_functions import Images, Image\n", - "from datashader.utils import orient_array\n", - "import numpy as np\n", - "import xarray as xr\n", - "import rioxarray" + "import datashader as ds\nfrom datashader.colors import Elevation\nimport datashader.transfer_functions as tf\nfrom datashader.transfer_functions import shade\nfrom datashader.transfer_functions import stack\nfrom datashader.transfer_functions import dynspread\nfrom datashader.transfer_functions import set_background\nfrom datashader.transfer_functions import Images, Image\nfrom datashader.utils import orient_array\nimport numpy as np\nimport xarray as xr\nfrom xrspatial.geotiff import read_geotiff" ] }, { @@ -143,23 +132,7 @@ } ], "source": [ - "SCENE_ID = \"LC80030172015001LGN00\"\n", - "EXTS = {\n", - " \"blue\": \"B2\",\n", - " \"green\": \"B3\",\n", - " \"red\": \"B4\",\n", - " \"nir\": \"B5\",\n", - "}\n", - "\n", - "cvs = ds.Canvas(plot_width=1024, plot_height=1024)\n", - "layers = {}\n", - "for name, ext in EXTS.items():\n", - " layer = rioxarray.open_rasterio(f\"../../../xrspatial-examples/data/{SCENE_ID}_{ext}.tiff\").load()[0]\n", - " layer.name = name\n", - " layer = cvs.raster(layer, agg=\"mean\")\n", - " layer.data = orient_array(layer)\n", - " layers[name] = layer\n", - "layers" + "SCENE_ID = \"LC80030172015001LGN00\"\nEXTS = {\n \"blue\": \"B2\",\n \"green\": \"B3\",\n \"red\": \"B4\",\n \"nir\": \"B5\",\n}\n\ncvs = ds.Canvas(plot_width=1024, plot_height=1024)\nlayers = {}\nfor name, ext in EXTS.items():\n layer = read_geotiff(f\"../../../xrspatial-examples/data/{SCENE_ID}_{ext}.tiff\", band=0)\n layer.name = name\n layer = cvs.raster(layer, agg=\"mean\")\n layer.data = orient_array(layer)\n layers[name] = layer\nlayers" ] }, { @@ -362,7 +335,7 @@ "}\n", "\n", ".xr-group-name::before {\n", - " content: \"📁\";\n", + " content: \"\ud83d\udcc1\";\n", " padding-right: 0.3em;\n", "}\n", "\n", @@ -425,7 +398,7 @@ "\n", ".xr-section-summary-in + label:before {\n", " display: inline-block;\n", - " content: \"►\";\n", + " content: \"\u25ba\";\n", " font-size: 11px;\n", " width: 15px;\n", " text-align: center;\n", @@ -436,7 +409,7 @@ "}\n", "\n", ".xr-section-summary-in:checked + label:before {\n", - " content: \"▼\";\n", + " content: \"\u25bc\";\n", "}\n", "\n", ".xr-section-summary-in:checked + label > span {\n", diff --git a/examples/user_guide/25_GLCM_Texture.ipynb b/examples/user_guide/25_GLCM_Texture.ipynb index c1623471..9ff23695 100644 --- a/examples/user_guide/25_GLCM_Texture.ipynb +++ b/examples/user_guide/25_GLCM_Texture.ipynb @@ -264,7 +264,7 @@ "id": "ec79xdunce9", "metadata": {}, "source": [ - "### Step 1 — Download a Sentinel-2 NIR band\n", + "### Step 1 \u2014 Download a Sentinel-2 NIR band\n", "\n", "We read a 500 x 500 pixel window (5 km x 5 km at 10 m resolution) straight from a\n", "Cloud-Optimized GeoTIFF hosted on AWS. The scene is\n", @@ -282,39 +282,7 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import rioxarray\n", - "\n", - "os.environ['AWS_NO_SIGN_REQUEST'] = 'YES'\n", - "os.environ['GDAL_DISABLE_READDIR_ON_OPEN'] = 'EMPTY_DIR'\n", - "\n", - "COG_URL = (\n", - " 'https://sentinel-cogs.s3.us-west-2.amazonaws.com/'\n", - " 'sentinel-s2-l2a-cogs/10/S/EG/2023/9/'\n", - " 'S2B_10SEG_20230921_0_L2A/B08.tif'\n", - ")\n", - "\n", - "try:\n", - " nir_da = rioxarray.open_rasterio(COG_URL).isel(band=0, y=slice(2100, 2600), x=slice(5300, 5800))\n", - " nir = nir_da.load().values.astype(np.float64)\n", - " print(f'Downloaded NIR band: {nir.shape}, range {nir.min():.0f} to {nir.max():.0f}')\n", - "except Exception as exc:\n", - " print(f'Remote read failed ({exc}), using synthetic fallback')\n", - " rng_sat = np.random.default_rng(99)\n", - " nir = np.zeros((500, 500), dtype=np.float64)\n", - " nir[:, 250:] = rng_sat.normal(80, 10, (500, 250)).clip(20, 200)\n", - " nir[:, :250] = rng_sat.normal(1800, 400, (500, 250)).clip(300, 4000)\n", - "\n", - "satellite = xr.DataArray(nir, dims=['y', 'x'],\n", - " coords={'y': np.arange(nir.shape[0], dtype=float),\n", - " 'x': np.arange(nir.shape[1], dtype=float)})\n", - "\n", - "fig, ax = plt.subplots(figsize=(7, 7))\n", - "satellite.plot.imshow(ax=ax, cmap='gray', vmax=float(np.percentile(nir, 98)),\n", - " add_colorbar=False)\n", - "ax.set_title('Sentinel-2 NIR band')\n", - "ax.set_axis_off()\n", - "plt.tight_layout()" + "import os\nfrom xrspatial.geotiff import read_geotiff\n\n\nCOG_URL = (\n 'https://sentinel-cogs.s3.us-west-2.amazonaws.com/'\n 'sentinel-s2-l2a-cogs/10/S/EG/2023/9/'\n 'S2B_10SEG_20230921_0_L2A/B08.tif'\n)\n\ntry:\n nir_da = read_geotiff(COG_URL, band=0, window=(2100, 5300, 2600, 5800))\n nir = nir_da.values.astype(np.float64)\n print(f'Downloaded NIR band: {nir.shape}, range {nir.min():.0f} to {nir.max():.0f}')\nexcept Exception as exc:\n print(f'Remote read failed ({exc}), using synthetic fallback')\n rng_sat = np.random.default_rng(99)\n nir = np.zeros((500, 500), dtype=np.float64)\n nir[:, 250:] = rng_sat.normal(80, 10, (500, 250)).clip(20, 200)\n nir[:, :250] = rng_sat.normal(1800, 400, (500, 250)).clip(300, 4000)\n\nsatellite = xr.DataArray(nir, dims=['y', 'x'],\n coords={'y': np.arange(nir.shape[0], dtype=float),\n 'x': np.arange(nir.shape[1], dtype=float)})\n\nfig, ax = plt.subplots(figsize=(7, 7))\nsatellite.plot.imshow(ax=ax, cmap='gray', vmax=float(np.percentile(nir, 98)),\n add_colorbar=False)\nax.set_title('Sentinel-2 NIR band')\nax.set_axis_off()\nplt.tight_layout()" ] }, { @@ -322,7 +290,7 @@ "id": "joxz7n8olpc", "metadata": {}, "source": [ - "### Step 2 — Compute GLCM texture features\n", + "### Step 2 \u2014 Compute GLCM texture features\n", "\n", "We pick four metrics that tend to separate water (uniform, high energy, high homogeneity) from land (rough, high contrast):\n", "\n", @@ -485,4 +453,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/examples/viewshed_gpu.ipynb b/examples/viewshed_gpu.ipynb index 845995d3..61f1ffa8 100644 --- a/examples/viewshed_gpu.ipynb +++ b/examples/viewshed_gpu.ipynb @@ -34,7 +34,9 @@ } }, "outputs": [], - "source": "import pandas\nimport matplotlib.pyplot as plt\nimport geopandas as gpd\n\nimport xarray as xr\nimport numpy as np\nimport cupy\nimport rioxarray\n\nimport xrspatial" + "source": [ + "import pandas\nimport matplotlib.pyplot as plt\nimport geopandas as gpd\n\nimport xarray as xr\nimport numpy as np\nimport cupy\nfrom xrspatial.geotiff import read_geotiff\n\nimport xrspatial" + ] }, { "cell_type": "markdown", @@ -64,15 +66,7 @@ }, "outputs": [], "source": [ - "file_name = '../xrspatial-examples/data/colorado_merge_3arc_resamp.tif'\n", - "\n", - "raster = rioxarray.open_rasterio(file_name).sel(band=1).drop_vars('band')\n", - "raster.name = 'Colorado Elevation Raster'\n", - "\n", - "xmin, xmax = raster.x.data.min(), raster.x.data.max()\n", - "ymin, ymax = raster.y.data.min(), raster.y.data.max()\n", - "\n", - "xmin, xmax, ymin, ymax" + "file_name = '../xrspatial-examples/data/colorado_merge_3arc_resamp.tif'\n\nraster = read_geotiff(file_name, band=0)\nraster.name = 'Colorado Elevation Raster'\n\nxmin, xmax = raster.x.data.min(), raster.x.data.max()\nymin, ymax = raster.y.data.min(), raster.y.data.max()\n\nxmin, xmax, ymin, ymax" ] }, { diff --git a/examples/xarray-spatial_classification-methods.ipynb b/examples/xarray-spatial_classification-methods.ipynb index 8d4416f0..ab56f074 100644 --- a/examples/xarray-spatial_classification-methods.ipynb +++ b/examples/xarray-spatial_classification-methods.ipynb @@ -46,7 +46,9 @@ } }, "outputs": [], - "source": "import xarray as xr\nimport rioxarray\nimport xrspatial\n\nfile_name = '../xrspatial-examples/data/colorado_merge_3arc_resamp.tif'\nraster = rioxarray.open_rasterio(file_name).sel(band=1).drop_vars('band')\nraster.name = 'Colorado Elevation Raster'\n\nxmin, xmax = raster.x.data.min(), raster.x.data.max()\nymin, ymax = raster.y.data.min(), raster.y.data.max()\n\nxmin, xmax, ymin, ymax" + "source": [ + "import xarray as xr\nfrom xrspatial.geotiff import read_geotiff\nimport xrspatial\n\nfile_name = '../xrspatial-examples/data/colorado_merge_3arc_resamp.tif'\nraster = read_geotiff(file_name, band=0)\nraster.name = 'Colorado Elevation Raster'\n\nxmin, xmax = raster.x.data.min(), raster.x.data.max()\nymin, ymax = raster.y.data.min(), raster.y.data.max()\n\nxmin, xmax, ymin, ymax" + ] }, { "cell_type": "code", diff --git a/setup.cfg b/setup.cfg index 85c1a741..9f7648ad 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,6 +23,8 @@ install_requires = scipy xarray numpy + matplotlib + zstandard packages = find: python_requires = >=3.12 setup_requires = setuptools_scm diff --git a/xrspatial/accessor.py b/xrspatial/accessor.py index 51eb1007..c17b2949 100644 --- a/xrspatial/accessor.py +++ b/xrspatial/accessor.py @@ -21,6 +21,33 @@ class XrsSpatialDataArrayAccessor: def __init__(self, obj): self._obj = obj + # ---- Plot ---- + + def plot(self, **kwargs): + """Plot the DataArray, using an embedded TIFF colormap if present. + + For palette/indexed-color GeoTIFFs (read via ``read_geotiff``), + the TIFF's color table is applied automatically with correct + normalization. For all other DataArrays, falls through to the + standard ``da.plot()``. + + Usage:: + + da = read_geotiff('landcover.tif') + da.xrs.plot() # palette colors used automatically + """ + import numpy as np + cmap = self._obj.attrs.get('cmap') + if cmap is not None and 'cmap' not in kwargs: + from matplotlib.colors import BoundaryNorm + n_colors = len(cmap.colors) + boundaries = np.arange(n_colors + 1) - 0.5 + norm = BoundaryNorm(boundaries, n_colors) + kwargs.setdefault('cmap', cmap) + kwargs.setdefault('norm', norm) + kwargs.setdefault('add_colorbar', True) + return self._obj.plot(**kwargs) + # ---- Surface ---- def slope(self, **kwargs): diff --git a/xrspatial/geotiff/__init__.py b/xrspatial/geotiff/__init__.py new file mode 100644 index 00000000..2940ba4c --- /dev/null +++ b/xrspatial/geotiff/__init__.py @@ -0,0 +1,969 @@ +"""Lightweight GeoTIFF/COG reader and writer. + +No GDAL dependency -- uses only numpy, numba, xarray, and the standard library. + +Public API +---------- +read_geotiff(source, ...) + Read a GeoTIFF file to an xarray.DataArray. +write_geotiff(data, path, ...) + Write an xarray.DataArray as a GeoTIFF or COG. +open_cog(url, ...) + Read a Cloud Optimized GeoTIFF from an HTTP URL. +""" +from __future__ import annotations + +import numpy as np +import xarray as xr + +from ._geotags import GeoTransform, RASTER_PIXEL_IS_AREA, RASTER_PIXEL_IS_POINT +from ._reader import read_to_array +from ._writer import write + +__all__ = ['read_geotiff', 'write_geotiff', 'write_vrt'] + + +def _wkt_to_epsg(wkt_or_proj: str) -> int | None: + """Try to extract an EPSG code from a WKT or PROJ string. + + Returns None if pyproj is not installed or the string can't be parsed. + """ + try: + from pyproj import CRS + crs = CRS.from_user_input(wkt_or_proj) + epsg = crs.to_epsg() + return epsg + except Exception: + return None + + +def _geo_to_coords(geo_info, height: int, width: int) -> dict: + """Build y/x coordinate arrays from GeoInfo. + + For PixelIsArea (default): origin is the edge of pixel (0,0), so pixel + centers are at origin + 0.5*pixel_size. + For PixelIsPoint: origin (tiepoint) is already the center of pixel (0,0), + so no half-pixel offset is needed. + """ + t = geo_info.transform + if geo_info.raster_type == RASTER_PIXEL_IS_POINT: + # Tiepoint is pixel center -- no offset needed + x = np.arange(width, dtype=np.float64) * t.pixel_width + t.origin_x + y = np.arange(height, dtype=np.float64) * t.pixel_height + t.origin_y + else: + # Tiepoint is pixel edge -- shift to center + x = np.arange(width, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5 + y = np.arange(height, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5 + return {'y': y, 'x': x} + + +def _coords_to_transform(da: xr.DataArray) -> GeoTransform | None: + """Infer GeoTransform from DataArray coordinates. + + Coordinates are always pixel-center values. The transform origin depends + on raster_type: + - PixelIsArea (default): origin = center - half_pixel (edge of pixel 0) + - PixelIsPoint: origin = center (center of pixel 0) + """ + ydim = da.dims[-2] + xdim = da.dims[-1] + + if xdim not in da.coords or ydim not in da.coords: + return None + + x = da.coords[xdim].values + y = da.coords[ydim].values + + if len(x) < 2 or len(y) < 2: + return None + + pixel_width = float(x[1] - x[0]) + pixel_height = float(y[1] - y[0]) + + is_point = da.attrs.get('raster_type') == 'point' + if is_point: + # PixelIsPoint: tiepoint is at the pixel center + origin_x = float(x[0]) + origin_y = float(y[0]) + else: + # PixelIsArea: tiepoint is at the edge (center - half pixel) + origin_x = float(x[0]) - pixel_width * 0.5 + origin_y = float(y[0]) - pixel_height * 0.5 + + return GeoTransform( + origin_x=origin_x, + origin_y=origin_y, + pixel_width=pixel_width, + pixel_height=pixel_height, + ) + + +def read_geotiff(source: str, *, window=None, + overview_level: int | None = None, + band: int | None = None, + name: str | None = None, + chunks: int | tuple | None = None, + gpu: bool = False) -> xr.DataArray: + """Read a GeoTIFF, COG, or VRT file into an xarray.DataArray. + + Automatically dispatches to the best backend: + - ``gpu=True``: GPU-accelerated read via nvCOMP (returns CuPy) + - ``chunks=N``: Dask lazy read via windowed chunks + - ``gpu=True, chunks=N``: Dask+CuPy for out-of-core GPU pipelines + - Default: NumPy eager read + + VRT files are auto-detected by extension. + + Parameters + ---------- + source : str + File path, HTTP URL, or cloud URI (s3://, gs://, az://). + window : tuple or None + (row_start, col_start, row_stop, col_stop) for windowed reading. + overview_level : int or None + Overview level (0 = full resolution). + band : int or None + Band index (0-based). None returns all bands. + name : str or None + Name for the DataArray. + chunks : int, tuple, or None + Chunk size for Dask lazy reading. + gpu : bool + Use GPU-accelerated decompression (requires cupy + nvCOMP). + + Returns + ------- + xr.DataArray + NumPy, Dask, CuPy, or Dask+CuPy backed depending on options. + """ + # VRT files + if source.lower().endswith('.vrt'): + return read_vrt(source, window=window, band=band, name=name, + chunks=chunks, gpu=gpu) + + # GPU path + if gpu: + return read_geotiff_gpu(source, overview_level=overview_level, + name=name, chunks=chunks) + + # Dask path (CPU) + if chunks is not None: + return read_geotiff_dask(source, chunks=chunks, + overview_level=overview_level, name=name) + + arr, geo_info = read_to_array( + source, window=window, + overview_level=overview_level, band=band, + ) + + height, width = arr.shape[:2] + coords = _geo_to_coords(geo_info, height, width) + + if window is not None: + # Adjust coordinates for windowed read + r0, c0, r1, c1 = window + t = geo_info.transform + full_x = np.arange(c0, c1, dtype=np.float64) * t.pixel_width + t.origin_x + t.pixel_width * 0.5 + full_y = np.arange(r0, r1, dtype=np.float64) * t.pixel_height + t.origin_y + t.pixel_height * 0.5 + coords = {'y': full_y, 'x': full_x} + + if name is None: + # Derive from source path + import os + name = os.path.splitext(os.path.basename(source))[0] + + attrs = {} + if geo_info.crs_epsg is not None: + attrs['crs'] = geo_info.crs_epsg + if geo_info.crs_wkt is not None: + attrs['crs_wkt'] = geo_info.crs_wkt + if geo_info.raster_type == RASTER_PIXEL_IS_POINT: + attrs['raster_type'] = 'point' + + # CRS description fields + if geo_info.crs_name is not None: + attrs['crs_name'] = geo_info.crs_name + if geo_info.geog_citation is not None: + attrs['geog_citation'] = geo_info.geog_citation + if geo_info.datum_code is not None: + attrs['datum_code'] = geo_info.datum_code + if geo_info.angular_units is not None: + attrs['angular_units'] = geo_info.angular_units + if geo_info.linear_units is not None: + attrs['linear_units'] = geo_info.linear_units + if geo_info.semi_major_axis is not None: + attrs['semi_major_axis'] = geo_info.semi_major_axis + if geo_info.inv_flattening is not None: + attrs['inv_flattening'] = geo_info.inv_flattening + if geo_info.projection_code is not None: + attrs['projection_code'] = geo_info.projection_code + # Vertical CRS + if geo_info.vertical_epsg is not None: + attrs['vertical_crs'] = geo_info.vertical_epsg + if geo_info.vertical_citation is not None: + attrs['vertical_citation'] = geo_info.vertical_citation + if geo_info.vertical_units is not None: + attrs['vertical_units'] = geo_info.vertical_units + + # GDAL metadata (tag 42112) + if geo_info.gdal_metadata is not None: + attrs['gdal_metadata'] = geo_info.gdal_metadata + if geo_info.gdal_metadata_xml is not None: + attrs['gdal_metadata_xml'] = geo_info.gdal_metadata_xml + + # Extra (non-managed) TIFF tags for pass-through + if geo_info.extra_tags is not None: + attrs['extra_tags'] = geo_info.extra_tags + + # Resolution / DPI metadata + if geo_info.x_resolution is not None: + attrs['x_resolution'] = geo_info.x_resolution + if geo_info.y_resolution is not None: + attrs['y_resolution'] = geo_info.y_resolution + if geo_info.resolution_unit is not None: + _unit_names = {1: 'none', 2: 'inch', 3: 'centimeter'} + attrs['resolution_unit'] = _unit_names.get( + geo_info.resolution_unit, str(geo_info.resolution_unit)) + + # Attach palette colormap for indexed-color TIFFs + if geo_info.colormap is not None: + try: + from matplotlib.colors import ListedColormap + cmap = ListedColormap(geo_info.colormap, name='tiff_palette') + attrs['cmap'] = cmap + attrs['colormap_rgba'] = geo_info.colormap + except ImportError: + # matplotlib not available -- store raw RGBA tuples only + attrs['colormap_rgba'] = geo_info.colormap + + # Apply nodata mask: replace nodata sentinel values with NaN + nodata = geo_info.nodata + if nodata is not None: + attrs['nodata'] = nodata + if arr.dtype.kind == 'f': + if not np.isnan(nodata): + arr = arr.copy() + arr[arr == arr.dtype.type(nodata)] = np.nan + elif arr.dtype.kind in ('u', 'i'): + # Integer arrays: convert to float to represent NaN + nodata_int = int(nodata) + mask = arr == arr.dtype.type(nodata_int) + if mask.any(): + arr = arr.astype(np.float64) + arr[mask] = np.nan + + if arr.ndim == 3: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(arr.shape[2]) + else: + dims = ['y', 'x'] + + da = xr.DataArray( + arr, + dims=dims, + coords=coords, + name=name, + attrs=attrs, + ) + return da + + +def _is_gpu_data(data) -> bool: + """Check if data is CuPy-backed (raw array or DataArray).""" + try: + import cupy + _cupy_type = cupy.ndarray + except ImportError: + return False + + if isinstance(data, xr.DataArray): + raw = data.data + if hasattr(raw, 'compute'): + meta = getattr(raw, '_meta', None) + return isinstance(meta, _cupy_type) + return isinstance(raw, _cupy_type) + return isinstance(data, _cupy_type) + + +def write_geotiff(data: xr.DataArray | np.ndarray, path: str, *, + crs: int | str | None = None, + nodata=None, + compression: str = 'deflate', + tiled: bool = True, + tile_size: int = 256, + predictor: bool = False, + cog: bool = False, + overview_levels: list[int] | None = None, + overview_resampling: str = 'mean', + bigtiff: bool | None = None, + gpu: bool | None = None) -> None: + """Write data as a GeoTIFF or Cloud Optimized GeoTIFF. + + Automatically dispatches to GPU compression when: + - ``gpu=True`` is passed, or + - The input data is CuPy-backed (auto-detected) + + GPU write uses nvCOMP batch compression (deflate/ZSTD) and keeps + the array on device. Falls back to CPU if nvCOMP is not available. + + Parameters + ---------- + data : xr.DataArray or np.ndarray + 2D raster data. + path : str + Output file path. + crs : int, str, or None + EPSG code (int), WKT string, or PROJ string. If None and data + is a DataArray, tries to read from attrs ('crs' for EPSG, + 'crs_wkt' for WKT). + nodata : float, int, or None + NoData value. + compression : str + 'none', 'deflate', or 'lzw'. + tiled : bool + Use tiled layout (default True). + tile_size : int + Tile size in pixels (default 256). + predictor : bool + Use horizontal differencing predictor. + cog : bool + Write as Cloud Optimized GeoTIFF. + overview_levels : list[int] or None + Overview decimation factors. Only used when cog=True. + overview_resampling : str + Resampling method for overviews: 'mean' (default), 'nearest', + 'min', 'max', 'median', 'mode', or 'cubic'. + gpu : bool or None + Force GPU compression. None (default) auto-detects CuPy data. + """ + # Auto-detect GPU data and dispatch to write_geotiff_gpu + use_gpu = gpu if gpu is not None else _is_gpu_data(data) + if use_gpu: + try: + write_geotiff_gpu(data, path, crs=crs, nodata=nodata, + compression=compression, tile_size=tile_size, + predictor=predictor) + return + except (ImportError, Exception): + pass # fall through to CPU path + + geo_transform = None + epsg = None + raster_type = RASTER_PIXEL_IS_AREA + x_res = None + y_res = None + res_unit = None + gdal_meta_xml = None + extra_tags_list = None + + # Resolve crs argument: can be int (EPSG) or str (WKT/PROJ) + if isinstance(crs, int): + epsg = crs + elif isinstance(crs, str): + epsg = _wkt_to_epsg(crs) # try to extract EPSG from WKT/PROJ + + if isinstance(data, xr.DataArray): + # Handle CuPy-backed DataArrays: convert to numpy for CPU write + raw = data.data + if hasattr(raw, 'get'): + arr = raw.get() # CuPy -> numpy + elif hasattr(raw, 'compute'): + arr = raw.compute() # Dask -> numpy + if hasattr(arr, 'get'): + arr = arr.get() # Dask+CuPy -> numpy + else: + arr = np.asarray(raw) + # Handle band-first dimension order (band, y, x) -> (y, x, band) + if arr.ndim == 3 and data.dims[0] in ('band', 'bands', 'channel'): + arr = np.moveaxis(arr, 0, -1) + if geo_transform is None: + geo_transform = _coords_to_transform(data) + if epsg is None and crs is None: + epsg = data.attrs.get('crs') + if epsg is None: + # Try resolving EPSG from a WKT string in attrs + wkt = data.attrs.get('crs_wkt') + if isinstance(wkt, str): + epsg = _wkt_to_epsg(wkt) + if nodata is None: + nodata = data.attrs.get('nodata') + if data.attrs.get('raster_type') == 'point': + raster_type = RASTER_PIXEL_IS_POINT + # GDAL metadata from attrs (prefer raw XML, fall back to dict) + gdal_meta_xml = data.attrs.get('gdal_metadata_xml') + if gdal_meta_xml is None: + gdal_meta_dict = data.attrs.get('gdal_metadata') + if isinstance(gdal_meta_dict, dict): + from ._geotags import _build_gdal_metadata_xml + gdal_meta_xml = _build_gdal_metadata_xml(gdal_meta_dict) + # Extra tags for pass-through + extra_tags_list = data.attrs.get('extra_tags') + # Resolution / DPI from attrs + x_res = data.attrs.get('x_resolution') + y_res = data.attrs.get('y_resolution') + unit_str = data.attrs.get('resolution_unit') + if unit_str is not None: + _unit_ids = {'none': 1, 'inch': 2, 'centimeter': 3} + res_unit = _unit_ids.get(str(unit_str), None) + else: + if hasattr(data, 'get'): + arr = data.get() # CuPy -> numpy + else: + arr = np.asarray(data) + + if arr.ndim not in (2, 3): + raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D") + + # Auto-promote unsupported dtypes + if arr.dtype == np.float16: + arr = arr.astype(np.float32) + elif arr.dtype == np.bool_: + arr = arr.astype(np.uint8) + + write( + arr, path, + geo_transform=geo_transform, + crs_epsg=epsg, + nodata=nodata, + compression=compression, + tiled=tiled, + tile_size=tile_size, + predictor=predictor, + cog=cog, + overview_levels=overview_levels, + overview_resampling=overview_resampling, + raster_type=raster_type, + x_resolution=x_res, + y_resolution=y_res, + resolution_unit=res_unit, + gdal_metadata_xml=gdal_meta_xml, + extra_tags=extra_tags_list, + bigtiff=bigtiff, + ) + + +def open_cog(url: str, **kwargs) -> xr.DataArray: + """Deprecated: use ``read_geotiff(url, ...)`` instead. + + read_geotiff handles HTTP URLs, cloud URIs, and local files. + """ + return read_geotiff(url, **kwargs) + + +def read_geotiff_dask(source: str, *, chunks: int | tuple = 512, + overview_level: int | None = None, + name: str | None = None) -> xr.DataArray: + """Read a GeoTIFF as a dask-backed DataArray for out-of-core processing. + + Each chunk is loaded lazily via windowed reads. + + Parameters + ---------- + source : str + File path. + chunks : int or (row_chunk, col_chunk) tuple + Chunk size in pixels. Default 512. + overview_level : int or None + Overview level (0 = full resolution). + name : str or None + Name for the DataArray. + + Returns + ------- + xr.DataArray + Dask-backed DataArray with y/x coordinates. + """ + import dask.array as da + + # VRT files: delegate to read_vrt which handles chunks + if source.lower().endswith('.vrt'): + return read_vrt(source, name=name, chunks=chunks) + + # First, do a metadata-only read to get shape, dtype, coords, attrs + arr, geo_info = read_to_array(source, overview_level=overview_level) + full_h, full_w = arr.shape[:2] + n_bands = arr.shape[2] if arr.ndim == 3 else 0 + dtype = arr.dtype + + coords = _geo_to_coords(geo_info, full_h, full_w) + + if name is None: + import os + name = os.path.splitext(os.path.basename(source))[0] + + attrs = {} + if geo_info.crs_epsg is not None: + attrs['crs'] = geo_info.crs_epsg + if geo_info.raster_type == RASTER_PIXEL_IS_POINT: + attrs['raster_type'] = 'point' + if geo_info.nodata is not None: + attrs['nodata'] = geo_info.nodata + + if isinstance(chunks, int): + ch_h = ch_w = chunks + else: + ch_h, ch_w = chunks + + # Build dask array from delayed windowed reads + rows = list(range(0, full_h, ch_h)) + cols = list(range(0, full_w, ch_w)) + + # For multi-band, each window read returns (h, w, bands); for single-band (h, w) + # read_to_array with band=0 extracts a single band, band=None returns all + band_arg = None # return all bands (or 2D if single-band) + + dask_rows = [] + for r0 in rows: + r1 = min(r0 + ch_h, full_h) + dask_cols = [] + for c0 in cols: + c1 = min(c0 + ch_w, full_w) + if n_bands > 0: + block_shape = (r1 - r0, c1 - c0, n_bands) + else: + block_shape = (r1 - r0, c1 - c0) + block = da.from_delayed( + _delayed_read_window(source, r0, c0, r1, c1, + overview_level, geo_info.nodata, + dtype, band_arg), + shape=block_shape, + dtype=dtype, + ) + dask_cols.append(block) + dask_rows.append(da.concatenate(dask_cols, axis=1)) + + dask_arr = da.concatenate(dask_rows, axis=0) + + if n_bands > 0: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(n_bands) + else: + dims = ['y', 'x'] + + return xr.DataArray( + dask_arr, dims=dims, coords=coords, name=name, attrs=attrs, + ) + + +def _delayed_read_window(source, r0, c0, r1, c1, overview_level, nodata, + dtype, band): + """Dask-delayed function to read a single window.""" + import dask + @dask.delayed + def _read(): + arr, _ = read_to_array(source, window=(r0, c0, r1, c1), + overview_level=overview_level, band=band) + if nodata is not None: + if arr.dtype.kind == 'f' and not np.isnan(nodata): + arr = arr.copy() + arr[arr == arr.dtype.type(nodata)] = np.nan + elif arr.dtype.kind in ('u', 'i'): + mask = arr == arr.dtype.type(int(nodata)) + if mask.any(): + arr = arr.astype(np.float64) + arr[mask] = np.nan + return arr + return _read() + + +def read_geotiff_gpu(source: str, *, + overview_level: int | None = None, + name: str | None = None, + chunks: int | tuple | None = None) -> xr.DataArray: + """Read a GeoTIFF with GPU-accelerated decompression via Numba CUDA. + + Decompresses all tiles in parallel on the GPU and returns a + CuPy-backed DataArray that stays on device memory. No CPU->GPU + transfer needed for downstream xrspatial GPU operations. + + With ``chunks=``, returns a Dask+CuPy DataArray for out-of-core + GPU pipelines. + + Requires: cupy, numba with CUDA support. + + Parameters + ---------- + source : str + File path. + overview_level : int or None + Overview level (0 = full resolution). + chunks : int, tuple, or None + If set, return a Dask-chunked CuPy DataArray. int for square + chunks, (row, col) tuple for rectangular. + name : str or None + Name for the DataArray. + + Returns + ------- + xr.DataArray + CuPy-backed DataArray on GPU device. + """ + try: + import cupy + except ImportError: + raise ImportError( + "cupy is required for GPU reads. " + "Install it with: pip install cupy-cuda12x") + + from ._reader import _FileSource + from ._header import parse_header, parse_all_ifds + from ._dtypes import tiff_dtype_to_numpy + from ._geotags import extract_geo_info + from ._gpu_decode import gpu_decode_tiles + + # Parse metadata on CPU (fast, <1ms) + src = _FileSource(source) + data = src.read_all() + + try: + header = parse_header(data) + ifds = parse_all_ifds(data, header) + + if len(ifds) == 0: + raise ValueError("No IFDs found in TIFF file") + + ifd_idx = 0 + if overview_level is not None: + ifd_idx = min(overview_level, len(ifds) - 1) + ifd = ifds[ifd_idx] + + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) + geo_info = extract_geo_info(ifd, data, header.byte_order) + + if not ifd.is_tiled: + # Fall back to CPU for stripped files + src.close() + arr_cpu, _ = read_to_array(source, overview_level=overview_level) + arr_gpu = cupy.asarray(arr_cpu) + coords = _geo_to_coords(geo_info, arr_gpu.shape[0], arr_gpu.shape[1]) + if name is None: + import os + name = os.path.splitext(os.path.basename(source))[0] + attrs = {} + if geo_info.crs_epsg is not None: + attrs['crs'] = geo_info.crs_epsg + return xr.DataArray(arr_gpu, dims=['y', 'x'], + coords=coords, name=name, attrs=attrs) + + offsets = ifd.tile_offsets + byte_counts = ifd.tile_byte_counts + compression = ifd.compression + predictor = ifd.predictor + samples = ifd.samples_per_pixel + tw = ifd.tile_width + th = ifd.tile_height + width = ifd.width + height = ifd.height + + finally: + src.close() + + # GPU decode: try GDS (SSD→GPU direct) first, then CPU mmap path + from ._gpu_decode import gpu_decode_tiles_from_file + arr_gpu = None + + try: + arr_gpu = gpu_decode_tiles_from_file( + source, offsets, byte_counts, + tw, th, width, height, + compression, predictor, dtype, samples, + ) + except Exception: + pass + + if arr_gpu is None: + # Fallback: extract tiles via CPU mmap, then GPU decode + src2 = _FileSource(source) + data2 = src2.read_all() + try: + compressed_tiles = [ + bytes(data2[offsets[i]:offsets[i] + byte_counts[i]]) + for i in range(len(offsets)) + ] + finally: + src2.close() + + if arr_gpu is None: + try: + arr_gpu = gpu_decode_tiles( + compressed_tiles, + tw, th, width, height, + compression, predictor, dtype, samples, + ) + except (ValueError, Exception): + # Unsupported compression -- fall back to CPU then transfer + arr_cpu, _ = read_to_array(source, overview_level=overview_level) + arr_gpu = cupy.asarray(arr_cpu) + + # Build DataArray + if name is None: + import os + name = os.path.splitext(os.path.basename(source))[0] + + coords = _geo_to_coords(geo_info, height, width) + + attrs = {} + if geo_info.crs_epsg is not None: + attrs['crs'] = geo_info.crs_epsg + if geo_info.crs_wkt is not None: + attrs['crs_wkt'] = geo_info.crs_wkt + + if arr_gpu.ndim == 3: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(arr_gpu.shape[2]) + else: + dims = ['y', 'x'] + + result = xr.DataArray(arr_gpu, dims=dims, coords=coords, + name=name, attrs=attrs) + + if chunks is not None: + if isinstance(chunks, int): + chunk_dict = {'y': chunks, 'x': chunks} + else: + chunk_dict = {'y': chunks[0], 'x': chunks[1]} + result = result.chunk(chunk_dict) + + return result + + +def write_geotiff_gpu(data, path: str, *, + crs: int | str | None = None, + nodata=None, + compression: str = 'zstd', + tile_size: int = 256, + predictor: bool = False) -> None: + """Write a CuPy-backed DataArray as a GeoTIFF with GPU compression. + + Tiles are extracted and compressed on the GPU via nvCOMP, then + assembled into a TIFF file on CPU. The CuPy array stays on device + throughout compression -- only the compressed bytes transfer to CPU + for file writing. + + Falls back to CPU compression if nvCOMP is not available. + + Parameters + ---------- + data : xr.DataArray (CuPy-backed) or cupy.ndarray + 2D raster on GPU. + path : str + Output file path. + crs : int, str, or None + EPSG code or WKT string. + nodata : float, int, or None + NoData value. + compression : str + 'zstd' (default, fastest on GPU), 'deflate', or 'none'. + tile_size : int + Tile size in pixels (default 256). + predictor : bool + Apply horizontal differencing predictor. + """ + try: + import cupy + except ImportError: + raise ImportError("cupy is required for GPU writes") + + from ._gpu_decode import gpu_compress_tiles + from ._writer import ( + _compression_tag, _assemble_tiff, _write_bytes, + GeoTransform as _GT, + ) + from ._dtypes import numpy_to_tiff_dtype + + # Extract array and metadata + geo_transform = None + epsg = None + raster_type = 1 + + if isinstance(crs, int): + epsg = crs + elif isinstance(crs, str): + epsg = _wkt_to_epsg(crs) + + if isinstance(data, xr.DataArray): + arr = data.data + # Handle Dask arrays: compute to materialize + if hasattr(arr, 'compute'): + arr = arr.compute() + # Now arr should be CuPy or numpy + if hasattr(arr, 'get'): + pass # CuPy array, already on GPU + else: + arr = cupy.asarray(np.asarray(arr)) # numpy -> GPU + + geo_transform = _coords_to_transform(data) + if epsg is None: + epsg = data.attrs.get('crs') + if nodata is None: + nodata = data.attrs.get('nodata') + if data.attrs.get('raster_type') == 'point': + raster_type = RASTER_PIXEL_IS_POINT + else: + if hasattr(data, 'compute'): + data = data.compute() # Dask -> CuPy or numpy + if hasattr(data, 'device'): + arr = data # already CuPy + elif hasattr(data, 'get'): + arr = data # CuPy + else: + arr = cupy.asarray(np.asarray(data)) # numpy/list -> GPU + + if arr.ndim not in (2, 3): + raise ValueError(f"Expected 2D or 3D array, got {arr.ndim}D") + + height, width = arr.shape[:2] + samples = arr.shape[2] if arr.ndim == 3 else 1 + np_dtype = np.dtype(str(arr.dtype)) # cupy dtype -> numpy dtype + + comp_tag = _compression_tag(compression) + pred_val = 2 if predictor else 1 + + # GPU compress + compressed_tiles = gpu_compress_tiles( + arr, tile_size, tile_size, width, height, + comp_tag, pred_val, np_dtype, samples) + + # Build offset/bytecount lists + rel_offsets = [] + byte_counts = [] + offset = 0 + for tile in compressed_tiles: + rel_offsets.append(offset) + byte_counts.append(len(tile)) + offset += len(tile) + + # Assemble TIFF on CPU (only metadata + compressed bytes) + # _assemble_tiff needs an array in parts[0] to detect samples_per_pixel + shape_stub = np.empty((1, 1, samples) if samples > 1 else (1, 1), dtype=np_dtype) + parts = [(shape_stub, width, height, rel_offsets, byte_counts, compressed_tiles)] + + file_bytes = _assemble_tiff( + width, height, np_dtype, comp_tag, predictor, True, tile_size, + parts, geo_transform, epsg, nodata, is_cog=False, + raster_type=raster_type) + + _write_bytes(file_bytes, path) + + +def read_vrt(source: str, *, window=None, + band: int | None = None, + name: str | None = None, + chunks: int | tuple | None = None, + gpu: bool = False) -> xr.DataArray: + """Read a GDAL Virtual Raster Table (.vrt) into an xarray.DataArray. + + The VRT's source GeoTIFFs are read via windowed reads and assembled + into a single array. + + Parameters + ---------- + source : str + Path to the .vrt file. + window : tuple or None + (row_start, col_start, row_stop, col_stop) for windowed reading. + band : int or None + Band index (0-based). None returns all bands. + name : str or None + Name for the DataArray. + chunks : int, tuple, or None + If set, return a Dask-chunked DataArray. int for square chunks, + (row, col) tuple for rectangular. + gpu : bool + If True, return a CuPy-backed DataArray on GPU. + + Returns + ------- + xr.DataArray + NumPy, Dask, CuPy, or Dask+CuPy backed depending on options. + """ + from ._vrt import read_vrt as _read_vrt_internal + + arr, vrt = _read_vrt_internal(source, window=window, band=band) + + if name is None: + import os + name = os.path.splitext(os.path.basename(source))[0] + + # Build coordinates from GeoTransform + gt = vrt.geo_transform + if gt is not None: + origin_x, res_x, _, origin_y, _, res_y = gt + if window is not None: + r0, c0, r1, c1 = window + r0 = max(0, r0) + c0 = max(0, c0) + else: + r0, c0 = 0, 0 + height, width = arr.shape[:2] + x = np.arange(width, dtype=np.float64) * res_x + origin_x + (c0 + 0.5) * res_x + y = np.arange(height, dtype=np.float64) * res_y + origin_y + (r0 + 0.5) * res_y + coords = {'y': y, 'x': x} + else: + coords = {} + + attrs = {} + if vrt.crs_wkt: + epsg = _wkt_to_epsg(vrt.crs_wkt) + if epsg is not None: + attrs['crs'] = epsg + attrs['crs_wkt'] = vrt.crs_wkt + if vrt.bands: + nodata = vrt.bands[0].nodata + if nodata is not None: + attrs['nodata'] = nodata + + # Transfer to GPU if requested + if gpu: + import cupy + arr = cupy.asarray(arr) + + if arr.ndim == 3: + dims = ['y', 'x', 'band'] + coords['band'] = np.arange(arr.shape[2]) + else: + dims = ['y', 'x'] + + result = xr.DataArray(arr, dims=dims, coords=coords, name=name, attrs=attrs) + + # Chunk for Dask (or Dask+CuPy if gpu=True) + if chunks is not None: + if isinstance(chunks, int): + chunk_dict = {'y': chunks, 'x': chunks} + else: + chunk_dict = {'y': chunks[0], 'x': chunks[1]} + result = result.chunk(chunk_dict) + + return result + + +def write_vrt(vrt_path: str, source_files: list[str], **kwargs) -> str: + """Generate a VRT file that mosaics multiple GeoTIFF tiles. + + Parameters + ---------- + vrt_path : str + Output .vrt file path. + source_files : list of str + Paths to the source GeoTIFF files. + **kwargs + relative, crs_wkt, nodata -- see _vrt.write_vrt. + + Returns + ------- + str + Path to the written VRT file. + """ + from ._vrt import write_vrt as _write_vrt_internal + return _write_vrt_internal(vrt_path, source_files, **kwargs) + + +def plot_geotiff(da: xr.DataArray, **kwargs): + """Plot a DataArray using its embedded colormap if present. + + Deprecated: use ``da.xrs.plot()`` instead. + """ + return da.xrs.plot(**kwargs) diff --git a/xrspatial/geotiff/_compression.py b/xrspatial/geotiff/_compression.py new file mode 100644 index 00000000..c78c6ebc --- /dev/null +++ b/xrspatial/geotiff/_compression.py @@ -0,0 +1,807 @@ +"""Compression codecs: deflate (zlib) and LZW (Numba), plus horizontal predictor.""" +from __future__ import annotations + +import zlib + +import numpy as np + +from xrspatial.utils import ngjit + +# -- Deflate (zlib wrapper) -------------------------------------------------- + + +def deflate_decompress(data: bytes) -> bytes: + """Decompress deflate/zlib data.""" + return zlib.decompress(data) + + +def deflate_compress(data: bytes, level: int = 6) -> bytes: + """Compress data with deflate/zlib.""" + return zlib.compress(data, level) + + +# -- LZW constants ----------------------------------------------------------- + +LZW_CLEAR_CODE = 256 +LZW_EOI_CODE = 257 +LZW_FIRST_CODE = 258 +LZW_MAX_CODE = 4095 +LZW_MAX_BITS = 12 + + +# -- LZW decode (Numba) ------------------------------------------------------ + +@ngjit +def _lzw_decode_kernel(src, src_len, dst, dst_len): + """Decode TIFF-variant LZW (MSB-first) into dst buffer. + + Parameters + ---------- + src : uint8 array + Compressed bytes. + src_len : int + Number of valid bytes in src. + dst : uint8 array + Output buffer (must be pre-allocated large enough). + dst_len : int + Maximum bytes to write. + + Returns + ------- + int + Number of bytes written to dst. + """ + # Table: prefix-chain representation + table_prefix = np.full(4096, -1, dtype=np.int32) + table_suffix = np.zeros(4096, dtype=np.uint8) + table_length = np.zeros(4096, dtype=np.int32) + + # Small stack for chain reversal + stack = np.empty(4096, dtype=np.uint8) + + # Bit reader state + bit_pos = 0 + code_size = 9 + next_code = LZW_FIRST_CODE + + # Initialize table with single-byte entries + for i in range(256): + table_prefix[i] = -1 + table_suffix[i] = np.uint8(i) + table_length[i] = 1 + + out_pos = 0 + old_code = -1 + + while True: + # Read next code (MSB-first bit packing) + byte_offset = bit_pos >> 3 + if byte_offset >= src_len: + break + + # Gather up to 24 bits from available bytes + bits = np.int32(src[byte_offset]) << 16 + if byte_offset + 1 < src_len: + bits |= np.int32(src[byte_offset + 1]) << 8 + if byte_offset + 2 < src_len: + bits |= np.int32(src[byte_offset + 2]) + + bit_offset_in_byte = bit_pos & 7 + # Shift to align the code_size bits at the LSB side + bits = (bits >> (24 - bit_offset_in_byte - code_size)) & ((1 << code_size) - 1) + bit_pos += code_size + code = bits + + if code == LZW_EOI_CODE: + break + + if code == LZW_CLEAR_CODE: + code_size = 9 + next_code = LZW_FIRST_CODE + old_code = -1 + continue + + if old_code == -1: + # First code after clear + if code < 256: + if out_pos < dst_len: + dst[out_pos] = np.uint8(code) + out_pos += 1 + old_code = code + continue + + # Determine the string for this code + if code < next_code: + # Code is in table -- walk the chain, push to stack, emit reversed + c = code + stack_pos = 0 + while c >= 0 and c < 4096 and stack_pos < 4096: + stack[stack_pos] = table_suffix[c] + stack_pos += 1 + c = table_prefix[c] + + # Emit in correct order + for i in range(stack_pos - 1, -1, -1): + if out_pos < dst_len: + dst[out_pos] = stack[i] + out_pos += 1 + + # Add new entry: old_code string + first char of code string + if next_code <= LZW_MAX_CODE and stack_pos > 0: + table_prefix[next_code] = old_code + table_suffix[next_code] = stack[stack_pos - 1] # first char + table_length[next_code] = table_length[old_code] + 1 + next_code += 1 + else: + # Special case: code == next_code + # String = old_code string + first char of old_code string + c = old_code + stack_pos = 0 + while c >= 0 and c < 4096 and stack_pos < 4096: + stack[stack_pos] = table_suffix[c] + stack_pos += 1 + c = table_prefix[c] + + if stack_pos == 0: + old_code = code + continue + + first_char = stack[stack_pos - 1] + + # Emit old_code string + for i in range(stack_pos - 1, -1, -1): + if out_pos < dst_len: + dst[out_pos] = stack[i] + out_pos += 1 + # Emit first char again + if out_pos < dst_len: + dst[out_pos] = first_char + out_pos += 1 + + # Add new entry + if next_code <= LZW_MAX_CODE: + table_prefix[next_code] = old_code + table_suffix[next_code] = first_char + table_length[next_code] = table_length[old_code] + 1 + next_code += 1 + + # Bump code size (TIFF LZW uses "early change": bump one code before + # the table fills the current code_size capacity) + if next_code > (1 << code_size) - 2 and code_size < LZW_MAX_BITS: + code_size += 1 + + old_code = code + + return out_pos + + +def lzw_decompress(data: bytes, expected_size: int = 0) -> np.ndarray: + """Decompress TIFF-variant LZW data. + + Parameters + ---------- + data : bytes + LZW compressed data. + expected_size : int + Expected decompressed size. If 0, uses 10x compressed size as buffer. + + Returns + ------- + np.ndarray + Mutable uint8 array of decompressed data. + """ + src = np.frombuffer(data, dtype=np.uint8) + if expected_size <= 0: + expected_size = len(data) * 10 + dst = np.empty(expected_size, dtype=np.uint8) + n = _lzw_decode_kernel(src, len(src), dst, expected_size) + return dst[:n].copy() # owned, mutable slice + + +# -- LZW encode (Numba) ------------------------------------------------------ + +@ngjit +def _lzw_encode_kernel(src, src_len, dst, dst_len): + """Encode data as TIFF-variant LZW (MSB-first). + + Returns number of bytes written to dst. + """ + # Hash table for string matching + # Key: (prefix_code << 8) | suffix_byte -> code + # Uses generation counter to avoid clearing: an entry is valid only when + # ht_gen[slot] == current_gen. + HT_SIZE = 8209 # prime > 4096*2 + ht_keys = np.empty(HT_SIZE, dtype=np.int64) + ht_values = np.empty(HT_SIZE, dtype=np.int32) + ht_gen = np.zeros(HT_SIZE, dtype=np.int32) + current_gen = np.int32(1) + + # Bit accumulator: collect bits and flush whole bytes + bit_buf = np.int32(0) # up to 24 bits pending + bits_in_buf = np.int32(0) + out_pos = 0 + + code_size = 9 + next_code = LZW_FIRST_CODE + + def flush_code(code, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos): + """Pack a code into the bit accumulator and flush complete bytes.""" + # Merge code bits (MSB-first) into accumulator + bit_buf = (bit_buf << code_size) | code + bits_in_buf += code_size + # Flush whole bytes from the top of the accumulator + while bits_in_buf >= 8: + bits_in_buf -= 8 + if out_pos < dst_len: + dst[out_pos] = np.uint8((bit_buf >> bits_in_buf) & 0xFF) + out_pos += 1 + return bit_buf, bits_in_buf, out_pos + + # Write initial clear code + bit_buf, bits_in_buf, out_pos = flush_code( + LZW_CLEAR_CODE, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + + if src_len == 0: + bit_buf, bits_in_buf, out_pos = flush_code( + LZW_EOI_CODE, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + # Flush remaining bits + if bits_in_buf > 0 and out_pos < dst_len: + dst[out_pos] = np.uint8((bit_buf << (8 - bits_in_buf)) & 0xFF) + out_pos += 1 + return out_pos + + prefix = np.int32(src[0]) + pos = 1 + + while pos < src_len: + suffix = np.int32(src[pos]) + # Look up (prefix, suffix) in hash table + key = np.int64(prefix) * 256 + np.int64(suffix) + h = int(key % HT_SIZE) + if h < 0: + h += HT_SIZE + + found = False + for _ in range(HT_SIZE): + if ht_gen[h] == current_gen and ht_keys[h] == key: + prefix = ht_values[h] + found = True + break + elif ht_gen[h] != current_gen: + break + h = (h + 1) % HT_SIZE + + if not found: + # Output the prefix code + bit_buf, bits_in_buf, out_pos = flush_code( + prefix, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + + # Add new entry to table + if next_code <= LZW_MAX_CODE: + ht_gen[h] = current_gen + ht_keys[h] = key + ht_values[h] = next_code + next_code += 1 + + # Encoder bumps one entry later than decoder (decoder trails by 1) + if next_code > (1 << code_size) - 1 and code_size < LZW_MAX_BITS: + code_size += 1 + + else: + # Table full, emit clear code and reset + bit_buf, bits_in_buf, out_pos = flush_code( + LZW_CLEAR_CODE, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + code_size = 9 + next_code = LZW_FIRST_CODE + current_gen += 1 + + prefix = suffix + pos += 1 + + # Output last prefix + bit_buf, bits_in_buf, out_pos = flush_code( + prefix, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + bit_buf, bits_in_buf, out_pos = flush_code( + LZW_EOI_CODE, code_size, bit_buf, bits_in_buf, dst, dst_len, out_pos) + + # Flush remaining bits + if bits_in_buf > 0 and out_pos < dst_len: + dst[out_pos] = np.uint8((bit_buf << (8 - bits_in_buf)) & 0xFF) + out_pos += 1 + + return out_pos + + +def lzw_compress(data: bytes) -> bytes: + """Compress data using TIFF-variant LZW. + + Parameters + ---------- + data : bytes + Raw data to compress. + + Returns + ------- + bytes + """ + src = np.frombuffer(data, dtype=np.uint8) + # Worst case: output slightly larger than input + max_out = len(data) + len(data) // 2 + 256 + dst = np.empty(max_out, dtype=np.uint8) + n = _lzw_encode_kernel(src, len(src), dst, max_out) + return dst[:n].tobytes() + + +# -- Horizontal predictor (Numba) -------------------------------------------- + +@ngjit +def _predictor_decode(data, width, height, bytes_per_sample): + """Undo horizontal differencing predictor (TIFF predictor=2). + + Operates in-place on the flat byte array, performing cumulative sum + per row at the sample level. + """ + row_bytes = width * bytes_per_sample + for row in range(height): + row_start = row * row_bytes + for col in range(bytes_per_sample, row_bytes): + idx = row_start + col + data[idx] = np.uint8((np.int32(data[idx]) + np.int32(data[idx - bytes_per_sample])) & 0xFF) + + +@ngjit +def _predictor_encode(data, width, height, bytes_per_sample): + """Apply horizontal differencing predictor (TIFF predictor=2). + + Operates in-place, converting absolute values to differences. + Process right-to-left to avoid overwriting values we still need. + """ + row_bytes = width * bytes_per_sample + for row in range(height): + row_start = row * row_bytes + for col in range(row_bytes - 1, bytes_per_sample - 1, -1): + idx = row_start + col + data[idx] = np.uint8((np.int32(data[idx]) - np.int32(data[idx - bytes_per_sample])) & 0xFF) + + +def predictor_decode(data: np.ndarray, width: int, height: int, + bytes_per_sample: int) -> np.ndarray: + """Undo horizontal differencing predictor (predictor=2). + + Parameters + ---------- + data : np.ndarray + Flat uint8 array of decompressed pixel data (modified in-place). + width, height : int + Image dimensions. + bytes_per_sample : int + Bytes per sample (e.g. 1 for uint8, 4 for float32). + + Returns + ------- + np.ndarray + Same array, modified in-place. + """ + buf = np.ascontiguousarray(data) + _predictor_decode(buf, width, height, bytes_per_sample) + return buf + + +def predictor_encode(data: np.ndarray, width: int, height: int, + bytes_per_sample: int) -> np.ndarray: + """Apply horizontal differencing predictor (predictor=2). + + Parameters + ---------- + data : np.ndarray + Flat uint8 array of pixel data (modified in-place). + width, height : int + Image dimensions. + bytes_per_sample : int + Bytes per sample. + + Returns + ------- + np.ndarray + Same array, modified in-place. + """ + buf = np.ascontiguousarray(data) + _predictor_encode(buf, width, height, bytes_per_sample) + return buf + + +# -- Floating-point predictor (predictor=3) ----------------------------------- +# +# TIFF predictor=3 (floating-point horizontal differencing): +# During encoding, bytes of each sample are rearranged into byte-lane order +# (MSB lane first, LSB lane last), then horizontal differencing is applied +# across the entire transposed row. +# +# For little-endian float32 with N samples: +# Swizzled layout: [MSB_s0..MSB_sN-1, byte2_s0..byte2_sN-1, +# byte1_s0..byte1_sN-1, LSB_s0..LSB_sN-1] +# i.e. lane 0 = native byte (bps-1), lane 1 = native byte (bps-2), etc. +# +# Decode: undo differencing, then un-transpose (lane b -> native byte bps-1-b). + +@ngjit +def _fp_predictor_decode_row(row_data, width, bps): + """Undo floating-point predictor for one row (in-place). + + row_data: uint8 array of length width * bps + """ + n = width * bps + + # Step 1: undo horizontal differencing on the byte-swizzled row + for i in range(1, n): + row_data[i] = np.uint8((np.int32(row_data[i]) + np.int32(row_data[i - 1])) & 0xFF) + + # Step 2: un-transpose bytes back to native sample order + tmp = np.empty(n, dtype=np.uint8) + for sample in range(width): + for b in range(bps): + tmp[sample * bps + b] = row_data[(bps - 1 - b) * width + sample] + for i in range(n): + row_data[i] = tmp[i] + + +@ngjit +def _fp_predictor_decode_rows(data, width, height, bps): + """Dispatch per-row decode from Numba, avoiding Python loop overhead.""" + row_len = width * bps + for row in range(height): + start = row * row_len + _fp_predictor_decode_row(data[start:start + row_len], width, bps) + + +def fp_predictor_decode(data: np.ndarray, width: int, height: int, + bytes_per_sample: int) -> np.ndarray: + """Undo floating-point predictor (predictor=3). + + Parameters + ---------- + data : np.ndarray + Flat uint8 array of decompressed tile/strip data. + width, height : int + Tile/strip dimensions. + bytes_per_sample : int + Bytes per sample (e.g. 4 for float32, 8 for float64). + + Returns + ------- + np.ndarray + Corrected array. + """ + buf = np.ascontiguousarray(data) + _fp_predictor_decode_rows(buf, width, height, bytes_per_sample) + return buf + + +@ngjit +def _fp_predictor_encode_row(row_data, width, bps): + """Apply floating-point predictor for one row (in-place).""" + n = width * bps + + # Step 1: transpose to byte-swizzled layout (MSB lane first) + # Native byte b of each sample goes to lane (bps-1-b). + tmp = np.empty(n, dtype=np.uint8) + for sample in range(width): + for b in range(bps): + tmp[(bps - 1 - b) * width + sample] = row_data[sample * bps + b] + for i in range(n): + row_data[i] = tmp[i] + + # Step 2: horizontal differencing on the swizzled row (right to left) + for i in range(n - 1, 0, -1): + row_data[i] = np.uint8((np.int32(row_data[i]) - np.int32(row_data[i - 1])) & 0xFF) + + +def fp_predictor_encode(data: np.ndarray, width: int, height: int, + bytes_per_sample: int) -> np.ndarray: + """Apply floating-point predictor (predictor=3). + + Parameters + ---------- + data : np.ndarray + Flat uint8 array of pixel data. + width, height : int + Dimensions. + bytes_per_sample : int + Bytes per sample. + + Returns + ------- + np.ndarray + Encoded array. + """ + buf = np.ascontiguousarray(data) + row_len = width * bytes_per_sample + for row in range(height): + start = row * row_len + _fp_predictor_encode_row(buf[start:start + row_len], width, bytes_per_sample) + return buf + + +# -- Sub-byte bit unpacking --------------------------------------------------- + +def unpack_bits(data: np.ndarray, bps: int, pixel_count: int) -> np.ndarray: + """Unpack sub-byte pixel data into one value per array element. + + Parameters + ---------- + data : np.ndarray + Flat uint8 array of packed bytes. + bps : int + Bits per sample (1, 2, 4, or 12). + pixel_count : int + Number of pixels to unpack. + + Returns + ------- + np.ndarray + uint8 for bps <= 8, uint16 for bps=12. + """ + if bps == 1: + # MSB-first: each byte holds 8 pixels + out = np.unpackbits(data)[:pixel_count] + return out.astype(np.uint8) + elif bps == 2: + # 4 pixels per byte, MSB-first + out = np.empty(pixel_count, dtype=np.uint8) + for i in range(min(len(data), (pixel_count + 3) // 4)): + b = data[i] + base = i * 4 + if base < pixel_count: + out[base] = (b >> 6) & 0x03 + if base + 1 < pixel_count: + out[base + 1] = (b >> 4) & 0x03 + if base + 2 < pixel_count: + out[base + 2] = (b >> 2) & 0x03 + if base + 3 < pixel_count: + out[base + 3] = b & 0x03 + return out + elif bps == 4: + # 2 pixels per byte, high nibble first + out = np.empty(pixel_count, dtype=np.uint8) + for i in range(min(len(data), (pixel_count + 1) // 2)): + b = data[i] + base = i * 2 + if base < pixel_count: + out[base] = (b >> 4) & 0x0F + if base + 1 < pixel_count: + out[base + 1] = b & 0x0F + return out + elif bps == 12: + # 2 pixels per 3 bytes, MSB-first + out = np.empty(pixel_count, dtype=np.uint16) + n_pairs = pixel_count // 2 + remainder = pixel_count % 2 + for i in range(n_pairs): + off = i * 3 + if off + 2 < len(data): + b0 = int(data[off]) + b1 = int(data[off + 1]) + b2 = int(data[off + 2]) + out[i * 2] = (b0 << 4) | (b1 >> 4) + out[i * 2 + 1] = ((b1 & 0x0F) << 8) | b2 + if remainder and n_pairs * 3 + 1 < len(data): + off = n_pairs * 3 + out[pixel_count - 1] = (int(data[off]) << 4) | (int(data[off + 1]) >> 4) + return out + else: + raise ValueError(f"Unsupported sub-byte bit depth: {bps}") + + +# -- PackBits (simple RLE) ---------------------------------------------------- + +def packbits_decompress(data: bytes) -> bytes: + """Decompress PackBits (TIFF compression tag 32773). + + Simple RLE: read a header byte n. + - 0 <= n <= 127: copy the next n+1 bytes literally. + - -127 <= n <= -1: repeat the next byte 1-n times. + - n == -128: no-op. + """ + src = data if isinstance(data, (bytes, bytearray)) else bytes(data) + out = bytearray() + i = 0 + length = len(src) + while i < length: + n = src[i] + if n > 127: + n = n - 256 # interpret as signed + i += 1 + if 0 <= n <= 127: + count = n + 1 + out.extend(src[i:i + count]) + i += count + elif -127 <= n <= -1: + if i < length: + out.extend(bytes([src[i]]) * (1 - n)) + i += 1 + # n == -128: skip + return bytes(out) + + +def packbits_compress(data: bytes) -> bytes: + """Compress data using PackBits.""" + src = data if isinstance(data, (bytes, bytearray)) else bytes(data) + out = bytearray() + i = 0 + length = len(src) + while i < length: + # Check for a run of identical bytes + j = i + 1 + while j < length and j - i < 128 and src[j] == src[i]: + j += 1 + run_len = j - i + + if run_len >= 3: + # Encode as run + out.append((256 - (run_len - 1)) & 0xFF) + out.append(src[i]) + i = j + else: + # Literal run: accumulate non-repeating bytes + lit_start = i + i = j + while i < length and i - lit_start < 128: + # Check if a run starts here + if i + 2 < length and src[i] == src[i + 1] == src[i + 2]: + break + i += 1 + lit_len = i - lit_start + out.append(lit_len - 1) + out.extend(src[lit_start:lit_start + lit_len]) + return bytes(out) + + +# -- JPEG codec (via Pillow) -------------------------------------------------- + +JPEG_AVAILABLE = False +try: + from PIL import Image + JPEG_AVAILABLE = True +except ImportError: + pass + + +def jpeg_decompress(data: bytes, width: int = 0, height: int = 0, + samples: int = 1) -> bytes: + """Decompress JPEG tile/strip data. Requires Pillow.""" + if not JPEG_AVAILABLE: + raise ImportError( + "Pillow is required to read JPEG-compressed TIFFs. " + "Install it with: pip install Pillow") + import io + img = Image.open(io.BytesIO(data)) + return np.asarray(img).tobytes() + + +def jpeg_compress(data: bytes, width: int, height: int, + samples: int = 1, quality: int = 75) -> bytes: + """Compress raw pixel data as JPEG. Requires Pillow.""" + if not JPEG_AVAILABLE: + raise ImportError( + "Pillow is required to write JPEG-compressed TIFFs. " + "Install it with: pip install Pillow") + import io + if samples == 1: + arr = np.frombuffer(data, dtype=np.uint8).reshape(height, width) + img = Image.fromarray(arr, mode='L') + elif samples == 3: + arr = np.frombuffer(data, dtype=np.uint8).reshape(height, width, 3) + img = Image.fromarray(arr, mode='RGB') + else: + raise ValueError(f"JPEG compression requires 1 or 3 bands, got {samples}") + buf = io.BytesIO() + img.save(buf, format='JPEG', quality=quality) + return buf.getvalue() + + +# -- ZSTD codec (via zstandard) ----------------------------------------------- + +ZSTD_AVAILABLE = False +try: + import zstandard as _zstd + ZSTD_AVAILABLE = True +except ImportError: + _zstd = None + + +def zstd_decompress(data: bytes) -> bytes: + """Decompress Zstandard data. Requires the ``zstandard`` package.""" + if not ZSTD_AVAILABLE: + raise ImportError( + "zstandard is required to read ZSTD-compressed TIFFs. " + "Install it with: pip install zstandard") + return _zstd.ZstdDecompressor().decompress(data) + + +def zstd_compress(data: bytes, level: int = 3) -> bytes: + """Compress data with Zstandard. Requires the ``zstandard`` package.""" + if not ZSTD_AVAILABLE: + raise ImportError( + "zstandard is required to write ZSTD-compressed TIFFs. " + "Install it with: pip install zstandard") + return _zstd.ZstdCompressor(level=level).compress(data) + + +# -- Dispatch helpers --------------------------------------------------------- + +# TIFF compression tag values +COMPRESSION_NONE = 1 +COMPRESSION_LZW = 5 +COMPRESSION_JPEG = 7 +COMPRESSION_DEFLATE = 8 +COMPRESSION_ZSTD = 50000 +COMPRESSION_PACKBITS = 32773 +COMPRESSION_ADOBE_DEFLATE = 32946 + + +def decompress(data, compression: int, expected_size: int = 0, + width: int = 0, height: int = 0, samples: int = 1) -> np.ndarray: + """Decompress tile/strip data based on TIFF compression tag. + + Parameters + ---------- + data : bytes + Compressed data. + compression : int + TIFF compression tag value. + expected_size : int + Expected decompressed size (used for LZW buffer allocation). + + Returns + ------- + np.ndarray + uint8 array. Mutable for LZW/deflate; may be read-only view for + uncompressed data (caller must .copy() if mutation is needed). + """ + if compression == COMPRESSION_NONE: + return np.frombuffer(data, dtype=np.uint8) + elif compression in (COMPRESSION_DEFLATE, COMPRESSION_ADOBE_DEFLATE): + return np.frombuffer(deflate_decompress(data), dtype=np.uint8) + elif compression == COMPRESSION_LZW: + return lzw_decompress(data, expected_size) + elif compression == COMPRESSION_PACKBITS: + return np.frombuffer(packbits_decompress(data), dtype=np.uint8) + elif compression == COMPRESSION_JPEG: + return np.frombuffer(jpeg_decompress(data, width, height, samples), + dtype=np.uint8) + elif compression == COMPRESSION_ZSTD: + return np.frombuffer(zstd_decompress(data), dtype=np.uint8) + else: + raise ValueError(f"Unsupported compression type: {compression}") + + +def compress(data: bytes, compression: int, level: int = 6) -> bytes: + """Compress data based on TIFF compression tag. + + Parameters + ---------- + data : bytes + Raw data. + compression : int + TIFF compression tag value. + level : int + Compression level (for deflate). + + Returns + ------- + bytes + """ + if compression == COMPRESSION_NONE: + return data + elif compression in (COMPRESSION_DEFLATE, COMPRESSION_ADOBE_DEFLATE): + return deflate_compress(data, level) + elif compression == COMPRESSION_LZW: + return lzw_compress(data) + elif compression == COMPRESSION_PACKBITS: + return packbits_compress(data) + elif compression == COMPRESSION_ZSTD: + return zstd_compress(data, level) + elif compression == COMPRESSION_JPEG: + raise ValueError("Use jpeg_compress() directly with width/height/samples") + else: + raise ValueError(f"Unsupported compression type: {compression}") diff --git a/xrspatial/geotiff/_dtypes.py b/xrspatial/geotiff/_dtypes.py new file mode 100644 index 00000000..a510061d --- /dev/null +++ b/xrspatial/geotiff/_dtypes.py @@ -0,0 +1,136 @@ +"""TIFF type ID <-> numpy dtype mapping.""" +from __future__ import annotations + +import numpy as np + +# TIFF type IDs (baseline + BigTIFF extensions) +BYTE = 1 +ASCII = 2 +SHORT = 3 +LONG = 4 +RATIONAL = 5 +SBYTE = 6 +UNDEFINED = 7 +SSHORT = 8 +SLONG = 9 +SRATIONAL = 10 +FLOAT = 11 +DOUBLE = 12 +# BigTIFF additions +LONG8 = 16 +SLONG8 = 17 +IFD8 = 18 + +# Bytes per element for each TIFF type +TIFF_TYPE_SIZES: dict[int, int] = { + BYTE: 1, + ASCII: 1, + SHORT: 2, + LONG: 4, + RATIONAL: 8, # two LONGs + SBYTE: 1, + UNDEFINED: 1, + SSHORT: 2, + SLONG: 4, + SRATIONAL: 8, # two SLONGs + FLOAT: 4, + DOUBLE: 8, + LONG8: 8, + SLONG8: 8, + IFD8: 8, +} + +# struct format characters for single values (excludes RATIONAL/SRATIONAL) +TIFF_TYPE_STRUCT_CODES: dict[int, str] = { + BYTE: 'B', + ASCII: 's', + SHORT: 'H', + LONG: 'I', + SBYTE: 'b', + UNDEFINED: 'B', + SSHORT: 'h', + SLONG: 'i', + FLOAT: 'f', + DOUBLE: 'd', + LONG8: 'Q', + SLONG8: 'q', + IFD8: 'Q', +} + +# SampleFormat tag values +SAMPLE_FORMAT_UINT = 1 +SAMPLE_FORMAT_INT = 2 +SAMPLE_FORMAT_FLOAT = 3 +SAMPLE_FORMAT_UNDEFINED = 4 + + +def tiff_dtype_to_numpy(bits_per_sample: int, sample_format: int = 1) -> np.dtype: + """Convert TIFF BitsPerSample + SampleFormat to a numpy dtype. + + Parameters + ---------- + bits_per_sample : int + Bits per sample (8, 16, 32, 64). + sample_format : int + TIFF SampleFormat tag value (1=uint, 2=int, 3=float). + + Returns + ------- + np.dtype + """ + _map = { + (8, SAMPLE_FORMAT_UINT): np.dtype('uint8'), + (8, SAMPLE_FORMAT_INT): np.dtype('int8'), + (16, SAMPLE_FORMAT_UINT): np.dtype('uint16'), + (16, SAMPLE_FORMAT_INT): np.dtype('int16'), + (32, SAMPLE_FORMAT_UINT): np.dtype('uint32'), + (32, SAMPLE_FORMAT_INT): np.dtype('int32'), + (32, SAMPLE_FORMAT_FLOAT): np.dtype('float32'), + (64, SAMPLE_FORMAT_UINT): np.dtype('uint64'), + (64, SAMPLE_FORMAT_INT): np.dtype('int64'), + (64, SAMPLE_FORMAT_FLOAT): np.dtype('float64'), + # treat UNDEFINED same as UINT + (8, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint8'), + (16, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint16'), + (32, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint32'), + (64, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint64'), + # Sub-byte and non-standard bit depths: promoted to smallest + # numpy type that can hold the values. + (1, SAMPLE_FORMAT_UINT): np.dtype('uint8'), + (1, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint8'), + (2, SAMPLE_FORMAT_UINT): np.dtype('uint8'), + (2, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint8'), + (4, SAMPLE_FORMAT_UINT): np.dtype('uint8'), + (4, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint8'), + (12, SAMPLE_FORMAT_UINT): np.dtype('uint16'), + (12, SAMPLE_FORMAT_UNDEFINED): np.dtype('uint16'), + } + key = (bits_per_sample, sample_format) + if key not in _map: + raise ValueError( + f"Unsupported BitsPerSample={bits_per_sample}, " + f"SampleFormat={sample_format}" + ) + return _map[key] + + +# Set of BitsPerSample values that require bit-level unpacking +SUB_BYTE_BPS = {1, 2, 4, 12} + + +def numpy_to_tiff_dtype(dt: np.dtype) -> tuple[int, int]: + """Convert a numpy dtype to (bits_per_sample, sample_format). + + Returns + ------- + (bits_per_sample, sample_format) tuple + """ + dt = np.dtype(dt) + if dt.kind == 'u': + return (dt.itemsize * 8, SAMPLE_FORMAT_UINT) + elif dt.kind == 'i': + return (dt.itemsize * 8, SAMPLE_FORMAT_INT) + elif dt.kind == 'f': + return (dt.itemsize * 8, SAMPLE_FORMAT_FLOAT) + else: + raise ValueError(f"Unsupported numpy dtype: {dt}") diff --git a/xrspatial/geotiff/_geotags.py b/xrspatial/geotiff/_geotags.py new file mode 100644 index 00000000..d3352819 --- /dev/null +++ b/xrspatial/geotiff/_geotags.py @@ -0,0 +1,598 @@ +"""GeoTIFF tag interpretation: CRS, affine transform, GeoKeys.""" +from __future__ import annotations + +import struct +from dataclasses import dataclass, field + +from ._header import ( + IFD, + TAG_IMAGE_WIDTH, TAG_IMAGE_LENGTH, TAG_BITS_PER_SAMPLE, + TAG_COMPRESSION, TAG_PHOTOMETRIC, + TAG_STRIP_OFFSETS, TAG_SAMPLES_PER_PIXEL, + TAG_ROWS_PER_STRIP, TAG_STRIP_BYTE_COUNTS, + TAG_X_RESOLUTION, TAG_Y_RESOLUTION, + TAG_PLANAR_CONFIG, TAG_RESOLUTION_UNIT, + TAG_PREDICTOR, TAG_COLORMAP, + TAG_TILE_WIDTH, TAG_TILE_LENGTH, + TAG_TILE_OFFSETS, TAG_TILE_BYTE_COUNTS, + TAG_SAMPLE_FORMAT, TAG_GDAL_METADATA, TAG_GDAL_NODATA, + TAG_MODEL_PIXEL_SCALE, TAG_MODEL_TIEPOINT, + TAG_MODEL_TRANSFORMATION, + TAG_GEO_KEY_DIRECTORY, TAG_GEO_DOUBLE_PARAMS, TAG_GEO_ASCII_PARAMS, +) + +# Tags that the writer manages -- everything else can be passed through +_MANAGED_TAGS = frozenset({ + TAG_IMAGE_WIDTH, TAG_IMAGE_LENGTH, TAG_BITS_PER_SAMPLE, + TAG_COMPRESSION, TAG_PHOTOMETRIC, + TAG_STRIP_OFFSETS, TAG_SAMPLES_PER_PIXEL, + TAG_ROWS_PER_STRIP, TAG_STRIP_BYTE_COUNTS, + TAG_X_RESOLUTION, TAG_Y_RESOLUTION, + TAG_PLANAR_CONFIG, TAG_RESOLUTION_UNIT, + TAG_PREDICTOR, TAG_COLORMAP, + TAG_TILE_WIDTH, TAG_TILE_LENGTH, + TAG_TILE_OFFSETS, TAG_TILE_BYTE_COUNTS, + TAG_SAMPLE_FORMAT, TAG_GDAL_METADATA, TAG_GDAL_NODATA, + TAG_MODEL_PIXEL_SCALE, TAG_MODEL_TIEPOINT, + TAG_MODEL_TRANSFORMATION, + TAG_GEO_KEY_DIRECTORY, TAG_GEO_DOUBLE_PARAMS, TAG_GEO_ASCII_PARAMS, +}) + +# GeoKey IDs +GEOKEY_MODEL_TYPE = 1024 +GEOKEY_RASTER_TYPE = 1025 +GEOKEY_CITATION = 1026 +GEOKEY_GEOGRAPHIC_TYPE = 2048 +GEOKEY_GEOG_CITATION = 2049 +GEOKEY_GEODETIC_DATUM = 2050 +GEOKEY_GEOG_LINEAR_UNITS = 2052 +GEOKEY_GEOG_ANGULAR_UNITS = 2054 +GEOKEY_GEOG_SEMI_MAJOR_AXIS = 2057 +GEOKEY_GEOG_INV_FLATTENING = 2059 +GEOKEY_PROJECTED_CS_TYPE = 3072 +GEOKEY_PROJ_CITATION = 3073 +GEOKEY_PROJECTION = 3074 +GEOKEY_PROJ_LINEAR_UNITS = 3076 +GEOKEY_VERTICAL_CS_TYPE = 4096 +GEOKEY_VERTICAL_CITATION = 4097 +GEOKEY_VERTICAL_DATUM = 4098 +GEOKEY_VERTICAL_UNITS = 4099 + +# Well-known EPSG unit codes +ANGULAR_UNITS = { + 9101: 'radian', + 9102: 'degree', + 9103: 'arc-minute', + 9104: 'arc-second', + 9105: 'grad', +} + +LINEAR_UNITS = { + 9001: 'metre', + 9002: 'foot', + 9003: 'us_survey_foot', + 9030: 'nautical_mile', + 9036: 'kilometre', +} + +# ModelType values +MODEL_TYPE_PROJECTED = 1 +MODEL_TYPE_GEOGRAPHIC = 2 +MODEL_TYPE_GEOCENTRIC = 3 + +# RasterType values +RASTER_PIXEL_IS_AREA = 1 +RASTER_PIXEL_IS_POINT = 2 + + +@dataclass +class GeoTransform: + """Affine transform from pixel to geographic coordinates. + + For pixel (col, row): + x = origin_x + col * pixel_width + y = origin_y + row * pixel_height + + pixel_height is typically negative (y decreases downward). + """ + origin_x: float = 0.0 + origin_y: float = 0.0 + pixel_width: float = 1.0 + pixel_height: float = -1.0 + + +@dataclass +class GeoInfo: + """Geographic metadata extracted from GeoTIFF tags.""" + transform: GeoTransform = field(default_factory=GeoTransform) + crs_epsg: int | None = None + model_type: int = 0 + raster_type: int = RASTER_PIXEL_IS_AREA + nodata: float | None = None + colormap: list | None = None # list of (R, G, B, A) float tuples, or None + x_resolution: float | None = None + y_resolution: float | None = None + resolution_unit: int | None = None # 1=none, 2=inch, 3=cm + # CRS description fields + crs_name: str | None = None # GTCitationGeoKey or ProjCitationGeoKey + geog_citation: str | None = None # e.g. "WGS 84", "NAD83" + datum_code: int | None = None # GeogGeodeticDatumGeoKey + angular_units: str | None = None # e.g. "degree" + angular_units_code: int | None = None + linear_units: str | None = None # e.g. "metre" + linear_units_code: int | None = None + semi_major_axis: float | None = None + inv_flattening: float | None = None + projection_code: int | None = None + # Vertical CRS + vertical_epsg: int | None = None + vertical_citation: str | None = None + vertical_datum: int | None = None + vertical_units: str | None = None + vertical_units_code: int | None = None + # WKT CRS string (resolved from EPSG via pyproj, or provided by caller) + crs_wkt: str | None = None + # GDAL metadata: dict of {name: value} for dataset-level items, + # and {(name, band): value} for per-band items. Raw XML also kept. + gdal_metadata: dict | None = None + gdal_metadata_xml: str | None = None + # Extra TIFF tags not managed by the writer (pass-through on round-trip) + # List of (tag_id, type_id, count, raw_value) tuples. + extra_tags: list | None = None + # Raw geokeys dict for anything else + geokeys: dict[int, int | float | str] = field(default_factory=dict) + + +def _parse_gdal_metadata(xml_str: str) -> dict: + """Parse GDALMetadata XML into a flat dict. + + Dataset-level items are stored as ``{name: value}``. + Per-band items are stored as ``{(name, band_int): value}``. + """ + import xml.etree.ElementTree as ET + result = {} + try: + root = ET.fromstring(xml_str) + for item in root.findall('Item'): + name = item.get('name', '') + sample = item.get('sample') + text = item.text or '' + if sample is not None: + result[(name, int(sample))] = text + else: + result[name] = text + except ET.ParseError: + pass + return result + + +def _build_gdal_metadata_xml(meta: dict) -> str: + """Serialize a metadata dict back to GDALMetadata XML. + + Accepts the same dict format that _parse_gdal_metadata produces: + string keys for dataset-level, (name, band) tuples for per-band. + """ + lines = [''] + for key, value in meta.items(): + if isinstance(key, tuple): + name, sample = key + lines.append( + f' {value}') + else: + lines.append(f' {value}') + lines.append('') + return '\n'.join(lines) + '\n' + + +def _epsg_to_wkt(epsg: int) -> str | None: + """Resolve an EPSG code to a WKT string using pyproj. + + Returns None if pyproj is not installed or the code is unknown. + """ + try: + from pyproj import CRS + return CRS.from_epsg(epsg).to_wkt() + except Exception: + return None + + +def _parse_geokeys(ifd: IFD, data: bytes | memoryview, + byte_order: str) -> dict[int, int | float | str]: + """Parse the GeoKeyDirectory and resolve values from param tags. + + The GeoKeyDirectoryTag (34735) contains a header: + [key_directory_version, key_revision, minor_revision, num_keys] + followed by num_keys entries of: + [key_id, tiff_tag_location, count, value_offset] + + If tiff_tag_location == 0, value_offset is the value itself. + If tiff_tag_location == 34736, look up in GeoDoubleParamsTag. + If tiff_tag_location == 34737, look up in GeoAsciiParamsTag. + """ + geokeys: dict[int, int | float | str] = {} + + dir_entry = ifd.entries.get(TAG_GEO_KEY_DIRECTORY) + if dir_entry is None: + return geokeys + + dir_values = dir_entry.value + if isinstance(dir_values, int): + return geokeys + if not isinstance(dir_values, tuple): + dir_values = (dir_values,) + + if len(dir_values) < 4: + return geokeys + + num_keys = dir_values[3] + + # Get param tags + double_params = ifd.get_value(TAG_GEO_DOUBLE_PARAMS) + if double_params is not None: + if not isinstance(double_params, tuple): + double_params = (double_params,) + else: + double_params = () + + ascii_params = ifd.get_value(TAG_GEO_ASCII_PARAMS) + if ascii_params is None: + ascii_params = '' + if isinstance(ascii_params, bytes): + ascii_params = ascii_params.decode('ascii', errors='replace') + + for i in range(num_keys): + base = 4 + i * 4 + if base + 3 >= len(dir_values): + break + + key_id = dir_values[base] + tag_loc = dir_values[base + 1] + count = dir_values[base + 2] + value_offset = dir_values[base + 3] + + if tag_loc == 0: + # Value is inline + geokeys[key_id] = value_offset + elif tag_loc == TAG_GEO_DOUBLE_PARAMS: + # Value in double params + if value_offset < len(double_params): + if count == 1: + geokeys[key_id] = double_params[value_offset] + else: + end = min(value_offset + count, len(double_params)) + geokeys[key_id] = double_params[value_offset:end] + else: + geokeys[key_id] = 0.0 + elif tag_loc == TAG_GEO_ASCII_PARAMS: + # Value in ASCII params + end = value_offset + count + val = ascii_params[value_offset:end].rstrip('|\x00') + geokeys[key_id] = val + else: + geokeys[key_id] = value_offset + + return geokeys + + +def _extract_transform(ifd: IFD) -> GeoTransform: + """Extract affine transform from ModelTransformation, or + ModelTiepoint + ModelPixelScale tags.""" + + # Try ModelTransformationTag (4x4 matrix) + transform_tag = ifd.get_value(TAG_MODEL_TRANSFORMATION) + if transform_tag is not None: + if isinstance(transform_tag, tuple) and len(transform_tag) >= 12: + # 4x4 row-major matrix + # x = M[0]*col + M[1]*row + M[3] + # y = M[4]*col + M[5]*row + M[7] + return GeoTransform( + origin_x=transform_tag[3], + origin_y=transform_tag[7], + pixel_width=transform_tag[0], + pixel_height=transform_tag[5], + ) + + # Try ModelTiepoint + ModelPixelScale + tiepoint = ifd.get_value(TAG_MODEL_TIEPOINT) + scale = ifd.get_value(TAG_MODEL_PIXEL_SCALE) + + if scale is not None: + if not isinstance(scale, tuple): + scale = (scale,) + + sx = scale[0] if len(scale) > 0 else 1.0 + sy = scale[1] if len(scale) > 1 else 1.0 + + if tiepoint is not None: + if not isinstance(tiepoint, tuple): + tiepoint = (tiepoint,) + # tiepoint: (I, J, K, X, Y, Z) + tp_i = tiepoint[0] if len(tiepoint) > 0 else 0.0 + tp_j = tiepoint[1] if len(tiepoint) > 1 else 0.0 + tp_x = tiepoint[3] if len(tiepoint) > 3 else 0.0 + tp_y = tiepoint[4] if len(tiepoint) > 4 else 0.0 + + origin_x = tp_x - tp_i * sx + origin_y = tp_y + tp_j * sy # sy is positive, but y goes down + + return GeoTransform( + origin_x=origin_x, + origin_y=origin_y, + pixel_width=sx, + pixel_height=-sy, # negative because y decreases + ) + + return GeoTransform(pixel_width=sx, pixel_height=-sy) + + return GeoTransform() + + +def extract_geo_info(ifd: IFD, data: bytes | memoryview, + byte_order: str) -> GeoInfo: + """Extract full geographic metadata from a parsed IFD. + + Parameters + ---------- + ifd : IFD + Parsed IFD. + data : bytes + Full file data (needed for resolving GeoKey param offsets). + byte_order : str + '<' or '>'. + + Returns + ------- + GeoInfo + """ + transform = _extract_transform(ifd) + geokeys = _parse_geokeys(ifd, data, byte_order) + + # Extract EPSG + epsg = None + if GEOKEY_PROJECTED_CS_TYPE in geokeys: + val = geokeys[GEOKEY_PROJECTED_CS_TYPE] + if isinstance(val, (int, float)) and val != 32767: + epsg = int(val) + if epsg is None and GEOKEY_GEOGRAPHIC_TYPE in geokeys: + val = geokeys[GEOKEY_GEOGRAPHIC_TYPE] + if isinstance(val, (int, float)) and val != 32767: + epsg = int(val) + + model_type = geokeys.get(GEOKEY_MODEL_TYPE, 0) + raster_type = geokeys.get(GEOKEY_RASTER_TYPE, RASTER_PIXEL_IS_AREA) + + # CRS name: prefer GTCitationGeoKey, fall back to ProjCitationGeoKey + crs_name = geokeys.get(GEOKEY_CITATION) + if crs_name is None: + crs_name = geokeys.get(GEOKEY_PROJ_CITATION) + if isinstance(crs_name, str): + crs_name = crs_name.strip().rstrip('|') + else: + crs_name = None + + geog_citation = geokeys.get(GEOKEY_GEOG_CITATION) + if isinstance(geog_citation, str): + geog_citation = geog_citation.strip().rstrip('|') + else: + geog_citation = None + + datum_code = geokeys.get(GEOKEY_GEODETIC_DATUM) + if isinstance(datum_code, (int, float)): + datum_code = int(datum_code) + else: + datum_code = None + + # Angular units (geographic CRS) + ang_code = geokeys.get(GEOKEY_GEOG_ANGULAR_UNITS) + ang_name = None + if isinstance(ang_code, (int, float)): + ang_code = int(ang_code) + ang_name = ANGULAR_UNITS.get(ang_code) + else: + ang_code = None + + # Linear units (projected CRS) + lin_code = geokeys.get(GEOKEY_PROJ_LINEAR_UNITS) + lin_name = None + if isinstance(lin_code, (int, float)): + lin_code = int(lin_code) + lin_name = LINEAR_UNITS.get(lin_code) + else: + lin_code = None + + # Ellipsoid parameters + semi_major = geokeys.get(GEOKEY_GEOG_SEMI_MAJOR_AXIS) + if not isinstance(semi_major, (int, float)): + semi_major = None + inv_flat = geokeys.get(GEOKEY_GEOG_INV_FLATTENING) + if not isinstance(inv_flat, (int, float)): + inv_flat = None + + proj_code = geokeys.get(GEOKEY_PROJECTION) + if isinstance(proj_code, (int, float)): + proj_code = int(proj_code) + else: + proj_code = None + + # Vertical CRS + vert_epsg = geokeys.get(GEOKEY_VERTICAL_CS_TYPE) + if isinstance(vert_epsg, (int, float)) and vert_epsg != 32767: + vert_epsg = int(vert_epsg) + else: + vert_epsg = None + + vert_citation = geokeys.get(GEOKEY_VERTICAL_CITATION) + if isinstance(vert_citation, str): + vert_citation = vert_citation.strip().rstrip('|') + else: + vert_citation = None + + vert_datum = geokeys.get(GEOKEY_VERTICAL_DATUM) + if isinstance(vert_datum, (int, float)): + vert_datum = int(vert_datum) + else: + vert_datum = None + + vert_units_code = geokeys.get(GEOKEY_VERTICAL_UNITS) + vert_units_name = None + if isinstance(vert_units_code, (int, float)): + vert_units_code = int(vert_units_code) + vert_units_name = LINEAR_UNITS.get(vert_units_code) + else: + vert_units_code = None + + # Extract nodata from GDAL_NODATA tag + nodata = None + nodata_str = ifd.nodata_str + if nodata_str is not None: + try: + nodata = float(nodata_str) + except (ValueError, TypeError): + pass + + # Parse GDALMetadata XML (tag 42112) + gdal_metadata = None + gdal_metadata_xml = ifd.gdal_metadata + if gdal_metadata_xml is not None: + gdal_metadata = _parse_gdal_metadata(gdal_metadata_xml) + + # Extract palette colormap (Photometric=3, tag 320) + colormap = None + if ifd.photometric == 3: + raw_cmap = ifd.colormap + if raw_cmap is not None: + bps_val = ifd.bits_per_sample + if isinstance(bps_val, tuple): + bps_val = bps_val[0] + n_colors = 1 << bps_val # 2^BitsPerSample + # TIFF ColorMap: 3 * n_colors uint16 values + # Layout: [R0..R_{n-1}, G0..G_{n-1}, B0..B_{n-1}] + # Values are 0-65535, scale to 0.0-1.0 for matplotlib + if len(raw_cmap) >= 3 * n_colors: + colormap = [] + for i in range(n_colors): + r = raw_cmap[i] / 65535.0 + g = raw_cmap[n_colors + i] / 65535.0 + b = raw_cmap[2 * n_colors + i] / 65535.0 + colormap.append((r, g, b, 1.0)) + + # Collect extra (non-managed) tags for pass-through + extra_tags = [] + for tag_id, entry in ifd.entries.items(): + if tag_id not in _MANAGED_TAGS: + extra_tags.append((tag_id, entry.type_id, entry.count, entry.value)) + if not extra_tags: + extra_tags = None + + # Resolve EPSG -> WKT via pyproj if available + crs_wkt = None + if epsg is not None: + crs_wkt = _epsg_to_wkt(epsg) + + return GeoInfo( + transform=transform, + crs_epsg=epsg, + model_type=int(model_type) if isinstance(model_type, (int, float)) else 0, + raster_type=int(raster_type) if isinstance(raster_type, (int, float)) else RASTER_PIXEL_IS_AREA, + nodata=nodata, + colormap=colormap, + x_resolution=ifd.x_resolution, + y_resolution=ifd.y_resolution, + resolution_unit=ifd.resolution_unit, + crs_name=crs_name, + geog_citation=geog_citation, + datum_code=datum_code, + angular_units=ang_name, + angular_units_code=ang_code, + linear_units=lin_name, + linear_units_code=lin_code, + semi_major_axis=float(semi_major) if semi_major is not None else None, + inv_flattening=float(inv_flat) if inv_flat is not None else None, + projection_code=proj_code, + vertical_epsg=vert_epsg, + vertical_citation=vert_citation, + vertical_datum=vert_datum, + vertical_units=vert_units_name, + vertical_units_code=vert_units_code, + crs_wkt=crs_wkt, + gdal_metadata=gdal_metadata, + gdal_metadata_xml=gdal_metadata_xml, + extra_tags=extra_tags, + geokeys=geokeys, + ) + + +def build_geo_tags(transform: GeoTransform, crs_epsg: int | None = None, + nodata=None, + raster_type: int = RASTER_PIXEL_IS_AREA) -> dict[int, tuple]: + """Build GeoTIFF IFD tag entries for writing. + + Parameters + ---------- + transform : GeoTransform + Pixel-to-coordinate mapping. + crs_epsg : int or None + EPSG code for the CRS. + nodata : float, int, or None + NoData value. + raster_type : int + RASTER_PIXEL_IS_AREA (1) or RASTER_PIXEL_IS_POINT (2). + + Returns + ------- + dict mapping tag ID to (type_id, count, value_bytes) tuples, + where value_bytes is already serialized for little-endian output. + """ + tags = {} + + # ModelPixelScaleTag (33550): (ScaleX, ScaleY, ScaleZ) + sx = abs(transform.pixel_width) + sy = abs(transform.pixel_height) + tags[TAG_MODEL_PIXEL_SCALE] = (sx, sy, 0.0) + + # ModelTiepointTag (33922): (I, J, K, X, Y, Z) + tags[TAG_MODEL_TIEPOINT] = ( + 0.0, 0.0, 0.0, + transform.origin_x, transform.origin_y, 0.0, + ) + + # GeoKeyDirectoryTag (34735) + geokeys = [] + # Header: version=1, revision=1, minor=0 + num_keys = 1 # at least RasterType + key_entries = [] + + # ModelType + if crs_epsg is not None: + # Guess model type from EPSG (simple heuristic) + if crs_epsg == 4326 or (crs_epsg >= 4000 and crs_epsg < 5000): + model_type = MODEL_TYPE_GEOGRAPHIC + else: + model_type = MODEL_TYPE_PROJECTED + key_entries.append((GEOKEY_MODEL_TYPE, 0, 1, model_type)) + num_keys += 1 + + # RasterType + key_entries.append((GEOKEY_RASTER_TYPE, 0, 1, raster_type)) + + # CRS + if crs_epsg is not None: + if model_type == MODEL_TYPE_GEOGRAPHIC: + key_entries.append((GEOKEY_GEOGRAPHIC_TYPE, 0, 1, crs_epsg)) + else: + key_entries.append((GEOKEY_PROJECTED_CS_TYPE, 0, 1, crs_epsg)) + num_keys += 1 + + num_keys = len(key_entries) + header = [1, 1, 0, num_keys] + flat = header.copy() + for entry in key_entries: + flat.extend(entry) + + tags[TAG_GEO_KEY_DIRECTORY] = tuple(flat) + + # GDAL_NODATA + if nodata is not None: + tags[TAG_GDAL_NODATA] = str(nodata) + + return tags diff --git a/xrspatial/geotiff/_gpu_decode.py b/xrspatial/geotiff/_gpu_decode.py new file mode 100644 index 00000000..93f2ae1a --- /dev/null +++ b/xrspatial/geotiff/_gpu_decode.py @@ -0,0 +1,1571 @@ +"""GPU-accelerated TIFF tile decompression via Numba CUDA. + +Provides CUDA kernels for LZW decode, horizontal predictor decode, +and floating-point predictor decode. Each tile is processed by one +thread (LZW is sequential per-stream), but all tiles run in parallel. +""" +from __future__ import annotations + +import math + +import numpy as np +from numba import cuda + +# LZW constants (same as _compression.py) +LZW_CLEAR_CODE = 256 +LZW_EOI_CODE = 257 +LZW_FIRST_CODE = 258 +LZW_MAX_CODE = 4095 +LZW_MAX_BITS = 12 + + +# --------------------------------------------------------------------------- +# LZW decode kernel -- one thread per tile +# --------------------------------------------------------------------------- + +@cuda.jit +def _lzw_decode_tiles_kernel( + compressed_buf, # uint8: all compressed tile data concatenated + tile_offsets, # int64: start offset of each tile in compressed_buf + tile_sizes, # int64: compressed size of each tile + decompressed_buf, # uint8: output buffer (all tiles concatenated) + tile_out_offsets, # int64: start offset of each tile in decompressed_buf + tile_out_sizes, # int64: expected decompressed size per tile + tile_actual_sizes, # int64: actual bytes written per tile (output) +): + """Decode one LZW tile per thread block. + + One thread block = one tile. Thread 0 in each block does the sequential + LZW decode. The table lives in shared memory (fast, ~20KB per block) + instead of local memory (slow DRAM spill). + """ + tile_idx = cuda.blockIdx.x + if tile_idx >= tile_offsets.shape[0]: + return + + # Only thread 0 in each block does the work + if cuda.threadIdx.x != 0: + return + + src_start = tile_offsets[tile_idx] + src_len = tile_sizes[tile_idx] + dst_start = tile_out_offsets[tile_idx] + dst_len = tile_out_sizes[tile_idx] + + if src_len == 0: + tile_actual_sizes[tile_idx] = 0 + return + + # LZW table in shared memory (fast on-chip SRAM) + table_prefix = cuda.shared.array(4096, dtype=numba_int32) + table_suffix = cuda.shared.array(4096, dtype=numba_uint8) + stack = cuda.shared.array(4096, dtype=numba_uint8) + + # Initialize single-byte entries + for i in range(256): + table_prefix[i] = -1 + table_suffix[i] = numba_uint8(i) + for i in range(256, 4096): + table_prefix[i] = -1 + table_suffix[i] = numba_uint8(0) + + bit_pos = 0 + code_size = 9 + next_code = LZW_FIRST_CODE + out_pos = 0 + old_code = -1 + + while True: + # Read next code (MSB-first) + byte_offset = bit_pos >> 3 + if byte_offset >= src_len: + break + + b0 = numba_int32(compressed_buf[src_start + byte_offset]) << 16 + if byte_offset + 1 < src_len: + b0 |= numba_int32(compressed_buf[src_start + byte_offset + 1]) << 8 + if byte_offset + 2 < src_len: + b0 |= numba_int32(compressed_buf[src_start + byte_offset + 2]) + + bit_off = bit_pos & 7 + code = (b0 >> (24 - bit_off - code_size)) & ((1 << code_size) - 1) + bit_pos += code_size + + if code == LZW_EOI_CODE: + break + + if code == LZW_CLEAR_CODE: + code_size = 9 + next_code = LZW_FIRST_CODE + old_code = -1 + continue + + if old_code == -1: + if code < 256 and out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = numba_uint8(code) + out_pos += 1 + old_code = code + continue + + if code < next_code: + # Walk chain, push to stack + c = code + sp = 0 + while c >= 0 and c < 4096 and sp < 4096: + stack[sp] = table_suffix[c] + sp += 1 + c = table_prefix[c] + + # Emit reversed + for i in range(sp - 1, -1, -1): + if out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = stack[i] + out_pos += 1 + + if next_code <= LZW_MAX_CODE and sp > 0: + table_prefix[next_code] = old_code + table_suffix[next_code] = stack[sp - 1] + next_code += 1 + else: + # Special case: code == next_code + c = old_code + sp = 0 + while c >= 0 and c < 4096 and sp < 4096: + stack[sp] = table_suffix[c] + sp += 1 + c = table_prefix[c] + + if sp == 0: + old_code = code + continue + + first_char = stack[sp - 1] + for i in range(sp - 1, -1, -1): + if out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = stack[i] + out_pos += 1 + if out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = first_char + out_pos += 1 + + if next_code <= LZW_MAX_CODE: + table_prefix[next_code] = old_code + table_suffix[next_code] = first_char + next_code += 1 + + # Early change + if next_code > (1 << code_size) - 2 and code_size < LZW_MAX_BITS: + code_size += 1 + + old_code = code + + tile_actual_sizes[tile_idx] = out_pos + + +# Type aliases for Numba CUDA local arrays +from numba import int32 as numba_int32, uint8 as numba_uint8, int64 as numba_int64 + + +# --------------------------------------------------------------------------- +# Deflate/inflate decode kernel -- one thread block per tile +# --------------------------------------------------------------------------- + +# Static tables for deflate +# Length base values and extra bits for codes 257-285 +_LEN_BASE = np.array([ + 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31, + 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, +], dtype=np.int32) +_LEN_EXTRA = np.array([ + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, + 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, +], dtype=np.int32) +# Distance base values and extra bits for codes 0-29 +_DIST_BASE = np.array([ + 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, + 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, + 12289, 16385, 24577, +], dtype=np.int32) +_DIST_EXTRA = np.array([ + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, + 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, +], dtype=np.int32) +# Code length code order (for dynamic Huffman) +_CL_ORDER = np.array([ + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15, +], dtype=np.int32) + + +@cuda.jit(device=True) +def _inflate_read_bits(src, src_start, src_len, bit_pos, n): + """Read n bits (LSB-first) from the source stream.""" + val = numba_int32(0) + for i in range(n): + byte_idx = (bit_pos[0] >> 3) + bit_idx = bit_pos[0] & 7 + if byte_idx < src_len: + val |= numba_int32((src[src_start + byte_idx] >> bit_idx) & 1) << i + bit_pos[0] += 1 + return val + + +@cuda.jit(device=True) +def _inflate_build_table(lengths, n_codes, table, max_bits, + overflow_codes, overflow_lens, n_overflow): + """Build a Huffman decode table from code lengths. + + Codes <= max_bits go into the fast table: table[reversed_code] = (sym << 5) | length. + Codes > max_bits go into overflow arrays for slow-path decode. + """ + bl_count = cuda.local.array(16, dtype=numba_int32) + for i in range(16): + bl_count[i] = 0 + for i in range(n_codes): + bl_count[lengths[i]] += 1 + bl_count[0] = 0 + + next_code = cuda.local.array(16, dtype=numba_int32) + code = 0 + for bits in range(1, 16): + code = (code + bl_count[bits - 1]) << 1 + next_code[bits] = code + + for i in range(1 << max_bits): + table[i] = 0 + + n_overflow[0] = 0 + + for sym in range(n_codes): + ln = lengths[sym] + if ln == 0: + continue + code = next_code[ln] + next_code[ln] += 1 + + # Reverse the code bits for LSB-first lookup + rev = numba_int32(0) + c = code + for b in range(ln): + rev = (rev << 1) | (c & 1) + c >>= 1 + + if ln <= max_bits: + # Fast table: fill all entries that share this prefix + # (entries where the extra high bits vary) + step = 1 << ln + idx = rev + while idx < (1 << max_bits): + table[idx] = numba_int32((sym << 5) | ln) + idx += step + else: + # Overflow: store reversed code + length for slow-path scan + oi = n_overflow[0] + if oi < overflow_codes.shape[0]: + overflow_codes[oi] = rev + overflow_lens[oi] = (sym << 5) | ln + n_overflow[0] = oi + 1 + + +@cuda.jit(device=True) +def _inflate_decode_symbol(src, src_start, src_len, bit_pos, table, max_bits, + overflow_codes, overflow_lens, n_overflow): + """Decode one Huffman symbol. Fast table for short codes, overflow scan for long.""" + # Peek 15 bits (max deflate code length) + peek = numba_int64(0) + for i in range(15): + byte_idx = (bit_pos[0] + i) >> 3 + bit_idx = (bit_pos[0] + i) & 7 + if byte_idx < src_len: + peek |= numba_int64((src[src_start + byte_idx] >> bit_idx) & 1) << i + + # Try fast table first + entry = table[numba_int32(peek) & ((1 << max_bits) - 1)] + length = entry & 0x1F + symbol = entry >> 5 + + if length > 0: + bit_pos[0] += length + return symbol + + # Slow path: scan overflow entries + for i in range(n_overflow[0]): + ov_rev = overflow_codes[i] + ov_entry = overflow_lens[i] + ov_len = ov_entry & 0x1F + ov_sym = ov_entry >> 5 + mask = (1 << ov_len) - 1 + if (numba_int32(peek) & mask) == ov_rev: + bit_pos[0] += ov_len + return ov_sym + + # Should not happen with valid data -- advance 1 bit to avoid freeze + bit_pos[0] += 1 + return 0 + + +@cuda.jit +def _inflate_tiles_kernel( + compressed_buf, + tile_offsets, + tile_sizes, + decompressed_buf, + tile_out_offsets, + tile_out_sizes, + tile_actual_sizes, + d_len_base, d_len_extra, d_dist_base, d_dist_extra, d_cl_order, +): + """Inflate (decompress) one zlib-wrapped deflate tile per thread block. + + Thread 0 in each block does the sequential inflate. + Huffman table in shared memory. + """ + tile_idx = cuda.blockIdx.x + if tile_idx >= tile_offsets.shape[0]: + return + if cuda.threadIdx.x != 0: + return + + src_start = tile_offsets[tile_idx] + src_len = tile_sizes[tile_idx] + dst_start = tile_out_offsets[tile_idx] + dst_len = tile_out_sizes[tile_idx] + + if src_len <= 2: + tile_actual_sizes[tile_idx] = 0 + return + + # Skip 2-byte zlib header (0x78 0x9C or similar) + bit_pos = cuda.local.array(1, dtype=numba_int64) + bit_pos[0] = numba_int64(16) # skip 2 bytes = 16 bits + + out_pos = 0 + + # Two-level Huffman tables: + # Level 1 (shared memory, fast): 10-bit lookup (1024 entries) + # Level 2 (local memory, slow): overflow for codes > 10 bits + MAX_LIT_BITS = 10 + MAX_DIST_BITS = 10 + lit_table = cuda.shared.array(1024, dtype=numba_int32) + dist_table = cuda.shared.array(1024, dtype=numba_int32) + + # Overflow arrays for long codes (rarely > 50 entries) + lit_ov_codes = cuda.local.array(64, dtype=numba_int32) + lit_ov_lens = cuda.local.array(64, dtype=numba_int32) + n_lit_ov = cuda.local.array(1, dtype=numba_int32) + dist_ov_codes = cuda.local.array(32, dtype=numba_int32) + dist_ov_lens = cuda.local.array(32, dtype=numba_int32) + n_dist_ov = cuda.local.array(1, dtype=numba_int32) + n_lit_ov[0] = 0 + n_dist_ov[0] = 0 + + code_lengths = cuda.local.array(320, dtype=numba_int32) + + while True: + # Read block header + bfinal = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 1) + btype = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 2) + + if btype == 0: + # Stored block: align to byte boundary, read len + bit_pos[0] = ((bit_pos[0] + 7) >> 3) << 3 + ln = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 16) + _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 16) # nlen (complement) + for i in range(ln): + byte_idx = bit_pos[0] >> 3 + if byte_idx < src_len and out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = compressed_buf[src_start + byte_idx] + out_pos += 1 + bit_pos[0] += 8 + + elif btype == 1: + # Fixed Huffman: build fixed tables + for i in range(144): + code_lengths[i] = 8 + for i in range(144, 256): + code_lengths[i] = 9 + for i in range(256, 280): + code_lengths[i] = 7 + for i in range(280, 288): + code_lengths[i] = 8 + _inflate_build_table(code_lengths, 288, lit_table, MAX_LIT_BITS, + lit_ov_codes, lit_ov_lens, n_lit_ov) + + for i in range(30): + code_lengths[i] = 5 + _inflate_build_table(code_lengths, 30, dist_table, MAX_DIST_BITS, + dist_ov_codes, dist_ov_lens, n_dist_ov) + + # Decode symbols + while True: + sym = _inflate_decode_symbol( + compressed_buf, src_start, src_len, bit_pos, + lit_table, MAX_LIT_BITS, + lit_ov_codes, lit_ov_lens, n_lit_ov) + + if sym < 256: + if out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = numba_uint8(sym) + out_pos += 1 + elif sym == 256: + break + else: + # Length-distance pair + li = sym - 257 + if li < 29: + length = d_len_base[li] + if d_len_extra[li] > 0: + length += _inflate_read_bits( + compressed_buf, src_start, src_len, + bit_pos, d_len_extra[li]) + else: + length = 3 + + dsym = _inflate_decode_symbol( + compressed_buf, src_start, src_len, bit_pos, + dist_table, MAX_DIST_BITS, + dist_ov_codes, dist_ov_lens, n_dist_ov) + if dsym < 30: + dist = d_dist_base[dsym] + if d_dist_extra[dsym] > 0: + dist += _inflate_read_bits( + compressed_buf, src_start, src_len, + bit_pos, d_dist_extra[dsym]) + else: + dist = 1 + + # Copy from output window + for i in range(length): + if out_pos < dst_len and dist <= out_pos: + decompressed_buf[dst_start + out_pos] = \ + decompressed_buf[dst_start + out_pos - dist] + out_pos += 1 + + elif btype == 2: + # Dynamic Huffman: read code length codes, then build tables + hlit = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 5) + 257 + hdist = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 5) + 1 + hclen = _inflate_read_bits(compressed_buf, src_start, src_len, bit_pos, 4) + 4 + + # Read code length code lengths + cl_lengths = cuda.local.array(19, dtype=numba_int32) + for i in range(19): + cl_lengths[i] = 0 + for i in range(hclen): + cl_lengths[d_cl_order[i]] = _inflate_read_bits( + compressed_buf, src_start, src_len, bit_pos, 3) + + # Build code length Huffman table (small: 7 bits max, no overflow) + cl_table = cuda.local.array(128, dtype=numba_int32) + cl_ov_c = cuda.local.array(4, dtype=numba_int32) + cl_ov_l = cuda.local.array(4, dtype=numba_int32) + n_cl_ov = cuda.local.array(1, dtype=numba_int32) + n_cl_ov[0] = 0 + _inflate_build_table(cl_lengths, 19, cl_table, 7, + cl_ov_c, cl_ov_l, n_cl_ov) + + # Decode literal/length + distance code lengths + total_codes = hlit + hdist + idx = 0 + for i in range(320): + code_lengths[i] = 0 + + while idx < total_codes: + sym = numba_int32(0) + # Decode from cl_table (7-bit) + peek = numba_int32(0) + for b in range(7): + byte_idx = (bit_pos[0] + b) >> 3 + bit_idx = (bit_pos[0] + b) & 7 + if byte_idx < src_len: + peek |= numba_int32( + (compressed_buf[src_start + byte_idx] >> bit_idx) & 1) << b + entry = cl_table[peek & 127] + ln = entry & 0x1F + sym = entry >> 5 + if ln > 0: + bit_pos[0] += ln + else: + bit_pos[0] += 1 + + if sym < 16: + code_lengths[idx] = sym + idx += 1 + elif sym == 16: + rep = _inflate_read_bits( + compressed_buf, src_start, src_len, bit_pos, 2) + 3 + val = code_lengths[idx - 1] if idx > 0 else 0 + for _ in range(rep): + if idx < 320: + code_lengths[idx] = val + idx += 1 + elif sym == 17: + rep = _inflate_read_bits( + compressed_buf, src_start, src_len, bit_pos, 3) + 3 + for _ in range(rep): + if idx < 320: + code_lengths[idx] = 0 + idx += 1 + elif sym == 18: + rep = _inflate_read_bits( + compressed_buf, src_start, src_len, bit_pos, 7) + 11 + for _ in range(rep): + if idx < 320: + code_lengths[idx] = 0 + idx += 1 + + # Build lit/len and dist tables + n_lit_ov[0] = 0 + _inflate_build_table(code_lengths, hlit, lit_table, MAX_LIT_BITS, + lit_ov_codes, lit_ov_lens, n_lit_ov) + # Distance codes start at code_lengths[hlit] + dist_lengths = cuda.local.array(32, dtype=numba_int32) + for i in range(32): + dist_lengths[i] = 0 + for i in range(hdist): + dist_lengths[i] = code_lengths[hlit + i] + n_dist_ov[0] = 0 + _inflate_build_table(dist_lengths, hdist, dist_table, MAX_DIST_BITS, + dist_ov_codes, dist_ov_lens, n_dist_ov) + + # Decode symbols (same loop as fixed Huffman) + while True: + sym = _inflate_decode_symbol( + compressed_buf, src_start, src_len, bit_pos, + lit_table, MAX_LIT_BITS, + lit_ov_codes, lit_ov_lens, n_lit_ov) + + if sym < 256: + if out_pos < dst_len: + decompressed_buf[dst_start + out_pos] = numba_uint8(sym) + out_pos += 1 + elif sym == 256: + break + else: + li = sym - 257 + if li < 29: + length = d_len_base[li] + if d_len_extra[li] > 0: + length += _inflate_read_bits( + compressed_buf, src_start, src_len, + bit_pos, d_len_extra[li]) + else: + length = 3 + + dsym = _inflate_decode_symbol( + compressed_buf, src_start, src_len, bit_pos, + dist_table, MAX_DIST_BITS, + dist_ov_codes, dist_ov_lens, n_dist_ov) + if dsym < 30: + dist = d_dist_base[dsym] + if d_dist_extra[dsym] > 0: + dist += _inflate_read_bits( + compressed_buf, src_start, src_len, + bit_pos, d_dist_extra[dsym]) + else: + dist = 1 + + for i in range(length): + if out_pos < dst_len and dist <= out_pos: + decompressed_buf[dst_start + out_pos] = \ + decompressed_buf[dst_start + out_pos - dist] + out_pos += 1 + else: + break # invalid block type + + if bfinal: + break + + tile_actual_sizes[tile_idx] = out_pos + + +# --------------------------------------------------------------------------- +# Predictor decode kernels -- one thread per row +# --------------------------------------------------------------------------- + +@cuda.jit +def _predictor_decode_kernel(data, width, height, bytes_per_sample): + """Undo horizontal differencing (predictor=2), one thread per row.""" + row = cuda.grid(1) + if row >= height: + return + + row_bytes = width * bytes_per_sample + row_start = row * row_bytes + + for col in range(bytes_per_sample, row_bytes): + idx = row_start + col + data[idx] = numba_uint8( + (numba_int32(data[idx]) + numba_int32(data[idx - bytes_per_sample])) & 0xFF) + + +@cuda.jit +def _fp_predictor_decode_kernel(data, tmp, width, height, bps): + """Undo floating-point predictor (predictor=3), one thread per row. + + data: flat uint8 device array + tmp: scratch buffer, same size as data + """ + row = cuda.grid(1) + if row >= height: + return + + row_len = width * bps + start = row * row_len + + # Step 1: undo horizontal differencing + for i in range(1, row_len): + idx = start + i + data[idx] = numba_uint8( + (numba_int32(data[idx]) + numba_int32(data[idx - 1])) & 0xFF) + + # Step 2: un-transpose byte lanes (MSB-first) back to native order + for sample in range(width): + for b in range(bps): + tmp[start + sample * bps + b] = data[start + (bps - 1 - b) * width + sample] + + # Copy back + for i in range(row_len): + data[start + i] = tmp[start + i] + + +# --------------------------------------------------------------------------- +# Tile assembly kernel -- one thread per output pixel +# --------------------------------------------------------------------------- + +@cuda.jit +def _assemble_tiles_kernel( + decompressed_buf, # uint8: all decompressed tiles concatenated + tile_out_offsets, # int64: byte offset of each tile in decompressed_buf + tile_width, # int: tile width in pixels + tile_height, # int: tile height in pixels + bytes_per_pixel, # int: dtype.itemsize * samples_per_pixel + image_width, # int: output image width + image_height, # int: output image height + tiles_across, # int: number of tile columns + output, # uint8: output image buffer (flat, row-major) +): + """Copy decompressed tile pixels into the output image, one thread per pixel.""" + pixel_idx = cuda.grid(1) + total_pixels = image_width * image_height + if pixel_idx >= total_pixels: + return + + # Output row and column + out_row = pixel_idx // image_width + out_col = pixel_idx % image_width + + # Which tile does this pixel belong to? + tile_row = out_row // tile_height + tile_col = out_col // tile_width + tile_idx = tile_row * tiles_across + tile_col + + # Position within the tile + local_row = out_row - tile_row * tile_height + local_col = out_col - tile_col * tile_width + + # Source and destination byte offsets + tile_offset = tile_out_offsets[tile_idx] + src_byte = tile_offset + (local_row * tile_width + local_col) * bytes_per_pixel + dst_byte = (out_row * image_width + out_col) * bytes_per_pixel + + for b in range(bytes_per_pixel): + output[dst_byte + b] = decompressed_buf[src_byte + b] + + +# --------------------------------------------------------------------------- +# KvikIO GDS (GPUDirect Storage) -- read file directly to GPU +# --------------------------------------------------------------------------- + +def _try_kvikio_read_tiles(file_path, tile_offsets, tile_byte_counts, tile_bytes): + """Read compressed tile bytes directly from SSD to GPU via GDS. + + When kvikio is available and GDS is supported, file data is DMA'd + directly from the NVMe drive to GPU VRAM, bypassing CPU entirely. + Falls back to None if kvikio is not installed or GDS is not available. + + Returns list of cupy arrays (one per tile) on GPU, or None. + """ + try: + import kvikio + import cupy + except ImportError: + return None + + try: + d_tiles = [] + with kvikio.CuFile(file_path, 'r') as f: + for off, bc in zip(tile_offsets, tile_byte_counts): + buf = cupy.empty(bc, dtype=cupy.uint8) + nbytes = f.pread(buf, file_offset=off) + # Verify the read completed correctly + actual = nbytes.get() if hasattr(nbytes, 'get') else int(nbytes) + if actual != bc: + return None # partial read, fall back + d_tiles.append(buf) + cupy.cuda.Device().synchronize() + return d_tiles + except Exception: + # GDS not available, version mismatch, or CUDA error + # Reset CUDA error state if possible + try: + import cupy + cupy.cuda.Device().synchronize() + except Exception: + pass + return None + + +# --------------------------------------------------------------------------- +# nvCOMP batch decompression (optional, fast path) +# --------------------------------------------------------------------------- + +def _find_nvcomp_lib(): + """Find and load libnvcomp.so. Returns ctypes.CDLL or None.""" + import ctypes + import os + + # Try common locations + search_paths = [ + 'libnvcomp.so', # system LD_LIBRARY_PATH + ] + + # Check conda envs + conda_prefix = os.environ.get('CONDA_PREFIX', '') + if conda_prefix: + search_paths.append(os.path.join(conda_prefix, 'lib', 'libnvcomp.so')) + + # Also check sibling conda envs that might have rapids + conda_base = os.path.dirname(conda_prefix) if conda_prefix else '' + if conda_base: + for env in ['rapids', 'test-again', 'rtxpy-fire']: + p = os.path.join(conda_base, env, 'lib', 'libnvcomp.so') + if os.path.exists(p): + search_paths.append(p) + + for path in search_paths: + try: + return ctypes.CDLL(path) + except OSError: + continue + return None + + +_nvcomp_lib = None +_nvcomp_checked = False + + +def _get_nvcomp(): + """Get the nvCOMP library handle (cached). Returns CDLL or None.""" + global _nvcomp_lib, _nvcomp_checked + if not _nvcomp_checked: + _nvcomp_checked = True + _nvcomp_lib = _find_nvcomp_lib() + return _nvcomp_lib + + +def _try_nvcomp_batch_decompress(compressed_tiles, tile_bytes, compression): + """Try batch decompression via nvCOMP C API. Returns CuPy array or None. + + Uses nvcompBatchedDeflateDecompressAsync to decompress all tiles in + one GPU API call. Falls back to None if nvCOMP is not available. + """ + if compression not in (8, 32946, 50000): # Deflate and ZSTD + return None + + lib = _get_nvcomp() + if lib is None: + # Try kvikio.nvcomp as alternative + try: + import kvikio.nvcomp as nvcomp + except ImportError: + return None + + import cupy + try: + raw_tiles = [] + for tile in compressed_tiles: + raw_tiles.append(tile[2:-4] if len(tile) > 6 else tile) + manager = nvcomp.DeflateManager(chunk_size=tile_bytes) + d_compressed = [cupy.asarray(np.frombuffer(t, dtype=np.uint8)) + for t in raw_tiles] + d_decompressed = manager.decompress(d_compressed) + return cupy.concatenate([d.ravel() for d in d_decompressed]) + except Exception: + return None + + # Direct ctypes nvCOMP C API + import ctypes + import cupy + + class _NvcompDecompOpts(ctypes.Structure): + """nvCOMP batched decompression options (passed by value).""" + _fields_ = [ + ('backend', ctypes.c_int), + ('reserved', ctypes.c_char * 60), + ] + + # Deflate has a different struct with sort_before_hw_decompress field + class _NvcompDeflateDecompOpts(ctypes.Structure): + _fields_ = [ + ('backend', ctypes.c_int), + ('sort_before_hw_decompress', ctypes.c_int), + ('reserved', ctypes.c_char * 56), + ] + + try: + n_tiles = len(compressed_tiles) + + # Prepare compressed tiles for nvCOMP + if compression in (8, 32946): # Deflate + # Strip 2-byte zlib header + 4-byte adler32 checksum + raw_tiles = [t[2:-4] if len(t) > 6 else t for t in compressed_tiles] + get_temp_fn = 'nvcompBatchedDeflateDecompressGetTempSizeAsync' + decomp_fn = 'nvcompBatchedDeflateDecompressAsync' + # backend=2 (CUDA) works on all GPUs; backend=1 (HW) needs Ada/Hopper + opts = _NvcompDeflateDecompOpts(backend=2, sort_before_hw_decompress=0, + reserved=b'\x00' * 56) + elif compression == 50000: # ZSTD + raw_tiles = list(compressed_tiles) # no header stripping + get_temp_fn = 'nvcompBatchedZstdDecompressGetTempSizeAsync' + decomp_fn = 'nvcompBatchedZstdDecompressAsync' + opts = _NvcompDecompOpts(backend=0, reserved=b'\x00' * 60) + else: + return None + + # Upload compressed tiles to device + d_comp_bufs = [cupy.asarray(np.frombuffer(t, dtype=np.uint8)) for t in raw_tiles] + d_decomp_bufs = [cupy.empty(tile_bytes, dtype=cupy.uint8) for _ in range(n_tiles)] + + d_comp_ptrs = cupy.array([b.data.ptr for b in d_comp_bufs], dtype=cupy.uint64) + d_decomp_ptrs = cupy.array([b.data.ptr for b in d_decomp_bufs], dtype=cupy.uint64) + d_comp_sizes = cupy.array([len(t) for t in raw_tiles], dtype=cupy.uint64) + d_buf_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.uint64) + d_actual = cupy.empty(n_tiles, dtype=cupy.uint64) + + # Set argtypes for proper struct passing + temp_fn = getattr(lib, get_temp_fn) + temp_fn.restype = ctypes.c_int + + temp_size = ctypes.c_size_t(0) + status = temp_fn( + ctypes.c_size_t(n_tiles), + ctypes.c_size_t(tile_bytes), + opts, + ctypes.byref(temp_size), + ctypes.c_size_t(n_tiles * tile_bytes), + ) + if status != 0: + return None + + ts = max(temp_size.value, 1) + d_temp = cupy.empty(ts, dtype=cupy.uint8) + d_statuses = cupy.zeros(n_tiles, dtype=cupy.int32) + + dec_fn = getattr(lib, decomp_fn) + dec_fn.restype = ctypes.c_int + + status = dec_fn( + ctypes.c_void_p(d_comp_ptrs.data.ptr), + ctypes.c_void_p(d_comp_sizes.data.ptr), + ctypes.c_void_p(d_buf_sizes.data.ptr), + ctypes.c_void_p(d_actual.data.ptr), + ctypes.c_size_t(n_tiles), + ctypes.c_void_p(d_temp.data.ptr), + ctypes.c_size_t(ts), + ctypes.c_void_p(d_decomp_ptrs.data.ptr), + opts, + ctypes.c_void_p(d_statuses.data.ptr), + ctypes.c_void_p(0), # default stream + ) + if status != 0: + return None + + cupy.cuda.Device().synchronize() + + if int(cupy.any(d_statuses != 0)): + return None + + return cupy.concatenate(d_decomp_bufs) + + except Exception: + return None + + +# --------------------------------------------------------------------------- +# High-level GPU decode pipeline +# --------------------------------------------------------------------------- + +def gpu_decode_tiles_from_file( + file_path: str, + tile_offsets: list | tuple, + tile_byte_counts: list | tuple, + tile_width: int, + tile_height: int, + image_width: int, + image_height: int, + compression: int, + predictor: int, + dtype: np.dtype, + samples: int = 1, +): + """Decode tiles from a file, using GDS if available. + + Tries KvikIO GDS (SSD → GPU direct) first, then falls back to + CPU mmap + gpu_decode_tiles. + """ + import cupy + + # Try GDS: read compressed tiles directly from SSD to GPU + d_tiles = _try_kvikio_read_tiles( + file_path, tile_offsets, tile_byte_counts, + tile_width * tile_height * dtype.itemsize * samples) + + if d_tiles is not None: + # Tiles are already on GPU as cupy arrays. + # Try nvCOMP batch decompress on them directly. + tile_bytes = tile_width * tile_height * dtype.itemsize * samples + + if compression in (50000,) and _get_nvcomp() is not None: + # ZSTD: nvCOMP can decompress directly from GPU buffers + result = _try_nvcomp_from_device_bufs( + d_tiles, tile_bytes, compression) + if result is not None: + decomp_offsets = np.arange(len(d_tiles), dtype=np.int64) * tile_bytes + d_decomp = result + d_decomp_offsets = cupy.asarray(decomp_offsets) + # Apply predictor + assemble (shared code below) + return _apply_predictor_and_assemble( + d_decomp, d_decomp_offsets, len(d_tiles), + tile_width, tile_height, image_width, image_height, + predictor, dtype, samples, tile_bytes) + + # GDS read succeeded but nvCOMP can't decompress on GPU, + # or it's LZW/deflate. Copy tiles to host and use normal path. + compressed_tiles = [t.get().tobytes() for t in d_tiles] + else: + # No GDS -- read tiles via CPU mmap (caller provides bytes) + # This path is used when called from gpu_decode_tiles() + return None # signal caller to use the bytes-based path + + return gpu_decode_tiles( + compressed_tiles, tile_width, tile_height, + image_width, image_height, compression, predictor, dtype, samples) + + +def _try_nvcomp_from_device_bufs(d_tiles, tile_bytes, compression): + """Run nvCOMP batch decompress on tiles already in GPU memory.""" + import ctypes + import cupy + + lib = _get_nvcomp() + if lib is None: + return None + + class _NvcompDecompOpts(ctypes.Structure): + _fields_ = [('backend', ctypes.c_int), ('reserved', ctypes.c_char * 60)] + + try: + n = len(d_tiles) + d_decomp_bufs = [cupy.empty(tile_bytes, dtype=cupy.uint8) for _ in range(n)] + + d_comp_ptrs = cupy.array([t.data.ptr for t in d_tiles], dtype=cupy.uint64) + d_decomp_ptrs = cupy.array([b.data.ptr for b in d_decomp_bufs], dtype=cupy.uint64) + d_comp_sizes = cupy.array([t.size for t in d_tiles], dtype=cupy.uint64) + d_buf_sizes = cupy.full(n, tile_bytes, dtype=cupy.uint64) + d_actual = cupy.empty(n, dtype=cupy.uint64) + + opts = _NvcompDecompOpts(backend=0, reserved=b'\x00' * 60) + + fn_name = {50000: 'nvcompBatchedZstdDecompressGetTempSizeAsync'}.get(compression) + dec_name = {50000: 'nvcompBatchedZstdDecompressAsync'}.get(compression) + if fn_name is None: + return None + + temp_fn = getattr(lib, fn_name) + temp_fn.restype = ctypes.c_int + temp_size = ctypes.c_size_t(0) + s = temp_fn(n, tile_bytes, opts, ctypes.byref(temp_size), n * tile_bytes) + if s != 0: + return None + + ts = max(temp_size.value, 1) + d_temp = cupy.empty(ts, dtype=cupy.uint8) + d_statuses = cupy.zeros(n, dtype=cupy.int32) + + dec_fn = getattr(lib, dec_name) + dec_fn.restype = ctypes.c_int + s = dec_fn( + ctypes.c_void_p(d_comp_ptrs.data.ptr), + ctypes.c_void_p(d_comp_sizes.data.ptr), + ctypes.c_void_p(d_buf_sizes.data.ptr), + ctypes.c_void_p(d_actual.data.ptr), + ctypes.c_size_t(n), + ctypes.c_void_p(d_temp.data.ptr), ctypes.c_size_t(ts), + ctypes.c_void_p(d_decomp_ptrs.data.ptr), + opts, + ctypes.c_void_p(d_statuses.data.ptr), + ctypes.c_void_p(0), + ) + if s != 0: + return None + + cupy.cuda.Device().synchronize() + if int(cupy.any(d_statuses != 0)): + return None + + return cupy.concatenate(d_decomp_bufs) + except Exception: + return None + + +def _apply_predictor_and_assemble(d_decomp, d_decomp_offsets, n_tiles, + tile_width, tile_height, + image_width, image_height, + predictor, dtype, samples, tile_bytes): + """Apply predictor decode and tile assembly on GPU.""" + import cupy + + bytes_per_pixel = dtype.itemsize * samples + + if predictor == 2: + total_rows = n_tiles * tile_height + tpb = min(256, total_rows) + bpg = math.ceil(total_rows / tpb) + _predictor_decode_kernel[bpg, tpb]( + d_decomp, tile_width * samples, total_rows, dtype.itemsize * samples) + cuda.synchronize() + elif predictor == 3: + total_rows = n_tiles * tile_height + tpb = min(256, total_rows) + bpg = math.ceil(total_rows / tpb) + d_tmp = cupy.empty_like(d_decomp) + _fp_predictor_decode_kernel[bpg, tpb]( + d_decomp, d_tmp, tile_width * samples, total_rows, dtype.itemsize) + cuda.synchronize() + + tiles_across = math.ceil(image_width / tile_width) + total_pixels = image_width * image_height + d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8) + + tpb = 256 + bpg = math.ceil(total_pixels / tpb) + _assemble_tiles_kernel[bpg, tpb]( + d_decomp, d_decomp_offsets, + tile_width, tile_height, bytes_per_pixel, + image_width, image_height, tiles_across, + d_output, + ) + cuda.synchronize() + + if samples > 1: + return d_output.view(dtype=cupy.dtype(dtype)).reshape( + image_height, image_width, samples) + return d_output.view(dtype=cupy.dtype(dtype)).reshape( + image_height, image_width) + + +def gpu_decode_tiles( + compressed_tiles: list[bytes], + tile_width: int, + tile_height: int, + image_width: int, + image_height: int, + compression: int, + predictor: int, + dtype: np.dtype, + samples: int = 1, +): + """Decode and assemble TIFF tiles entirely on GPU. + + Parameters + ---------- + compressed_tiles : list of bytes + One entry per tile, in row-major tile order. + tile_width, tile_height : int + Tile dimensions. + image_width, image_height : int + Output image dimensions. + compression : int + TIFF compression tag (5=LZW, 1=none). + predictor : int + Predictor tag (1=none, 2=horizontal, 3=float). + dtype : np.dtype + Output pixel dtype. + samples : int + Samples per pixel. + + Returns + ------- + cupy.ndarray + Decoded image on GPU device. + """ + import cupy + + n_tiles = len(compressed_tiles) + bytes_per_pixel = dtype.itemsize * samples + tile_bytes = tile_width * tile_height * bytes_per_pixel + + # Try nvCOMP batch decompression first (much faster if available) + nvcomp_result = _try_nvcomp_batch_decompress( + compressed_tiles, tile_bytes, compression) + if nvcomp_result is not None: + d_decomp = nvcomp_result + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + elif compression == 5: # LZW + # Concatenate all compressed tiles into one device buffer + comp_sizes = [len(t) for t in compressed_tiles] + comp_offsets = np.zeros(n_tiles, dtype=np.int64) + for i in range(1, n_tiles): + comp_offsets[i] = comp_offsets[i - 1] + comp_sizes[i - 1] + total_comp = sum(comp_sizes) + + comp_buf_host = np.empty(total_comp, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + comp_buf_host[comp_offsets[i]:comp_offsets[i] + comp_sizes[i]] = \ + np.frombuffer(tile, dtype=np.uint8) + + # Transfer to device + d_comp = cupy.asarray(comp_buf_host) + d_comp_offsets = cupy.asarray(comp_offsets) + d_comp_sizes = cupy.asarray(np.array(comp_sizes, dtype=np.int64)) + + # Allocate decompressed buffer on device + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8) + d_decomp_offsets = cupy.asarray(decomp_offsets) + d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64) + d_actual_sizes = cupy.zeros(n_tiles, dtype=cupy.int64) + + # Launch LZW decode: one thread block per tile (thread 0 decodes, + # table in shared memory). Block size 32 for warp scheduling. + _lzw_decode_tiles_kernel[n_tiles, 32]( + d_comp, d_comp_offsets, d_comp_sizes, + d_decomp, d_decomp_offsets, d_tile_sizes, d_actual_sizes, + ) + cuda.synchronize() + + elif compression in (8, 32946): # Deflate / Adobe Deflate + comp_sizes = [len(t) for t in compressed_tiles] + comp_offsets = np.zeros(n_tiles, dtype=np.int64) + for i in range(1, n_tiles): + comp_offsets[i] = comp_offsets[i - 1] + comp_sizes[i - 1] + total_comp = sum(comp_sizes) + + comp_buf_host = np.empty(total_comp, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + comp_buf_host[comp_offsets[i]:comp_offsets[i] + comp_sizes[i]] = \ + np.frombuffer(tile, dtype=np.uint8) + + d_comp = cupy.asarray(comp_buf_host) + d_comp_offsets = cupy.asarray(comp_offsets) + d_comp_sizes = cupy.asarray(np.array(comp_sizes, dtype=np.int64)) + + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8) + d_decomp_offsets = cupy.asarray(decomp_offsets) + d_tile_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.int64) + d_actual_sizes = cupy.zeros(n_tiles, dtype=cupy.int64) + + # Static deflate tables on device + d_len_base = cupy.asarray(_LEN_BASE) + d_len_extra = cupy.asarray(_LEN_EXTRA) + d_dist_base = cupy.asarray(_DIST_BASE) + d_dist_extra = cupy.asarray(_DIST_EXTRA) + d_cl_order = cupy.asarray(_CL_ORDER) + + # One thread block per tile, thread 0 does the inflate + _inflate_tiles_kernel[n_tiles, 32]( + d_comp, d_comp_offsets, d_comp_sizes, + d_decomp, d_decomp_offsets, d_tile_sizes, d_actual_sizes, + d_len_base, d_len_extra, d_dist_base, d_dist_extra, d_cl_order, + ) + cuda.synchronize() + + elif compression == 1: # Uncompressed + raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + start = i * tile_bytes + t = np.frombuffer(tile, dtype=np.uint8) + raw_host[start:start + len(t)] = t[:tile_bytes] + d_decomp = cupy.asarray(raw_host) + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + + else: + # Unsupported GPU codec: decompress on CPU, transfer to GPU + from ._compression import decompress as cpu_decompress + raw_host = np.empty(n_tiles * tile_bytes, dtype=np.uint8) + for i, tile in enumerate(compressed_tiles): + start = i * tile_bytes + chunk = cpu_decompress(tile, compression, tile_bytes) + raw_host[start:start + min(len(chunk), tile_bytes)] = \ + chunk[:tile_bytes] if len(chunk) >= tile_bytes else \ + np.pad(chunk, (0, tile_bytes - len(chunk))) + d_decomp = cupy.asarray(raw_host) + decomp_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_decomp_offsets = cupy.asarray(decomp_offsets) + + # Apply predictor on GPU + if predictor == 2: + # Horizontal differencing: one thread per row across all tiles + total_rows = n_tiles * tile_height + tpb = min(256, total_rows) + bpg = math.ceil(total_rows / tpb) + # Reshape so each tile's rows are contiguous (they already are) + _predictor_decode_kernel[bpg, tpb]( + d_decomp, tile_width * samples, total_rows, dtype.itemsize * samples) + cuda.synchronize() + + elif predictor == 3: + # Float predictor: one thread per row + total_rows = n_tiles * tile_height + tpb = min(256, total_rows) + bpg = math.ceil(total_rows / tpb) + d_tmp = cupy.empty_like(d_decomp) + _fp_predictor_decode_kernel[bpg, tpb]( + d_decomp, d_tmp, tile_width * samples, total_rows, dtype.itemsize) + cuda.synchronize() + + # Assemble tiles into output image on GPU + tiles_across = math.ceil(image_width / tile_width) + total_pixels = image_width * image_height + d_output = cupy.empty(total_pixels * bytes_per_pixel, dtype=cupy.uint8) + + tpb = 256 + bpg = math.ceil(total_pixels / tpb) + _assemble_tiles_kernel[bpg, tpb]( + d_decomp, d_decomp_offsets, + tile_width, tile_height, bytes_per_pixel, + image_width, image_height, tiles_across, + d_output, + ) + cuda.synchronize() + + # Reshape to image + if samples > 1: + return d_output.view(dtype=cupy.dtype(dtype)).reshape( + image_height, image_width, samples) + return d_output.view(dtype=cupy.dtype(dtype)).reshape( + image_height, image_width) + + +# --------------------------------------------------------------------------- +# GPU tile extraction kernel -- image → individual tiles +# --------------------------------------------------------------------------- + +@cuda.jit +def _extract_tiles_kernel( + image, # uint8: flat row-major image + tile_bufs, # uint8: output buffer (all tiles concatenated) + tile_offsets, # int64: byte offset of each tile in tile_bufs + tile_width, + tile_height, + bytes_per_pixel, + image_width, + image_height, + tiles_across, +): + """Extract tile pixels from image into per-tile buffers, one thread per pixel.""" + pixel_idx = cuda.grid(1) + total_pixels = image_width * image_height + if pixel_idx >= total_pixels: + return + + row = pixel_idx // image_width + col = pixel_idx % image_width + + tile_row = row // tile_height + tile_col = col // tile_width + tile_idx = tile_row * tiles_across + tile_col + + local_row = row - tile_row * tile_height + local_col = col - tile_col * tile_width + + src_byte = (row * image_width + col) * bytes_per_pixel + tile_off = tile_offsets[tile_idx] + dst_byte = tile_off + (local_row * tile_width + local_col) * bytes_per_pixel + + for b in range(bytes_per_pixel): + tile_bufs[dst_byte + b] = image[src_byte + b] + + +# --------------------------------------------------------------------------- +# GPU predictor encode kernels +# --------------------------------------------------------------------------- + +@cuda.jit +def _predictor_encode_kernel(data, width, height, bytes_per_sample): + """Apply horizontal differencing (predictor=2), one thread per row. + Process right-to-left to avoid overwriting values we still need. + """ + row = cuda.grid(1) + if row >= height: + return + + row_bytes = width * bytes_per_sample + row_start = row * row_bytes + + for col in range(row_bytes - 1, bytes_per_sample - 1, -1): + idx = row_start + col + data[idx] = numba_uint8( + (numba_int32(data[idx]) - numba_int32(data[idx - bytes_per_sample])) & 0xFF) + + +@cuda.jit +def _fp_predictor_encode_kernel(data, tmp, width, height, bps): + """Apply floating-point predictor (predictor=3), one thread per row.""" + row = cuda.grid(1) + if row >= height: + return + + row_len = width * bps + start = row * row_len + + # Step 1: transpose to byte-swizzled layout (MSB lane first) + for sample in range(width): + for b in range(bps): + tmp[start + (bps - 1 - b) * width + sample] = data[start + sample * bps + b] + + # Copy back + for i in range(row_len): + data[start + i] = tmp[start + i] + + # Step 2: horizontal differencing (right to left) + for i in range(row_len - 1, 0, -1): + idx = start + i + data[idx] = numba_uint8( + (numba_int32(data[idx]) - numba_int32(data[idx - 1])) & 0xFF) + + +# --------------------------------------------------------------------------- +# nvCOMP batch compress +# --------------------------------------------------------------------------- + +def _nvcomp_batch_compress(d_tile_bufs, tile_byte_counts, tile_bytes, + compression, n_tiles): + """Compress tiles on GPU via nvCOMP. Returns list of bytes on CPU. + + Parameters + ---------- + d_tile_bufs : list of cupy arrays + Uncompressed tile data on GPU. + tile_byte_counts : not used (all tiles same size) + tile_bytes : int + Size of each uncompressed tile in bytes. + compression : int + TIFF compression tag (8=deflate, 50000=ZSTD). + n_tiles : int + Number of tiles. + + Returns + ------- + list of bytes + Compressed tile data on CPU, ready for file assembly. + """ + import ctypes + import cupy + + lib = _get_nvcomp() + if lib is None: + return None + + class _CompOpts(ctypes.Structure): + _fields_ = [('algorithm', ctypes.c_int), ('reserved', ctypes.c_char * 60)] + + class _DeflateCompOpts(ctypes.Structure): + _fields_ = [('algorithm', ctypes.c_int), ('reserved', ctypes.c_char * 60)] + + try: + # Select codec + if compression == 50000: # ZSTD + get_max_fn = 'nvcompBatchedZstdCompressGetMaxOutputChunkSize' + get_temp_fn = 'nvcompBatchedZstdCompressGetTempSizeAsync' + compress_fn = 'nvcompBatchedZstdCompressAsync' + opts = _CompOpts(algorithm=0, reserved=b'\x00' * 60) + elif compression in (8, 32946): # Deflate + get_max_fn = 'nvcompBatchedDeflateCompressGetMaxOutputChunkSize' + get_temp_fn = 'nvcompBatchedDeflateCompressGetTempSizeAsync' + compress_fn = 'nvcompBatchedDeflateCompressAsync' + opts = _DeflateCompOpts(algorithm=1, reserved=b'\x00' * 60) + else: + return None + + # Get max compressed chunk size + max_comp_size = ctypes.c_size_t(0) + fn = getattr(lib, get_max_fn) + fn.restype = ctypes.c_int + s = fn(ctypes.c_size_t(tile_bytes), opts, ctypes.byref(max_comp_size)) + if s != 0: + return None + max_cs = max_comp_size.value + + # Allocate compressed output buffers on device + d_comp_bufs = [cupy.empty(max_cs, dtype=cupy.uint8) for _ in range(n_tiles)] + + # Build pointer and size arrays + d_uncomp_ptrs = cupy.array([b.data.ptr for b in d_tile_bufs], dtype=cupy.uint64) + d_comp_ptrs = cupy.array([b.data.ptr for b in d_comp_bufs], dtype=cupy.uint64) + d_uncomp_sizes = cupy.full(n_tiles, tile_bytes, dtype=cupy.uint64) + d_comp_sizes = cupy.empty(n_tiles, dtype=cupy.uint64) + + # Get temp size + temp_size = ctypes.c_size_t(0) + fn2 = getattr(lib, get_temp_fn) + fn2.restype = ctypes.c_int + s = fn2(ctypes.c_size_t(n_tiles), ctypes.c_size_t(tile_bytes), + opts, ctypes.byref(temp_size), ctypes.c_size_t(n_tiles * tile_bytes)) + if s != 0: + return None + + d_temp = cupy.empty(max(temp_size.value, 1), dtype=cupy.uint8) + d_statuses = cupy.zeros(n_tiles, dtype=cupy.int32) + + # Compress + fn3 = getattr(lib, compress_fn) + fn3.restype = ctypes.c_int + s = fn3( + ctypes.c_void_p(d_uncomp_ptrs.data.ptr), + ctypes.c_void_p(d_uncomp_sizes.data.ptr), + ctypes.c_size_t(tile_bytes), + ctypes.c_size_t(n_tiles), + ctypes.c_void_p(d_temp.data.ptr), + ctypes.c_size_t(max(temp_size.value, 1)), + ctypes.c_void_p(d_comp_ptrs.data.ptr), + ctypes.c_void_p(d_comp_sizes.data.ptr), + opts, + ctypes.c_void_p(d_statuses.data.ptr), + ctypes.c_void_p(0), # default stream + ) + if s != 0: + return None + + cupy.cuda.Device().synchronize() + + if int(cupy.any(d_statuses != 0)): + return None + + # For deflate, compute adler32 checksums from uncompressed tiles + # before reading compressed data (need the originals) + adler_checksums = None + if compression in (8, 32946): + import zlib + import struct + adler_checksums = [] + for i in range(n_tiles): + uncomp = d_tile_bufs[i].get().tobytes() + adler_checksums.append(zlib.adler32(uncomp)) + + # Read compressed sizes and data back to CPU + comp_sizes = d_comp_sizes.get().astype(int) + result = [] + for i in range(n_tiles): + cs = int(comp_sizes[i]) + raw = d_comp_bufs[i][:cs].get().tobytes() + + if adler_checksums is not None: + # Wrap raw deflate in zlib format: header + data + adler32 + checksum = struct.pack('>I', adler_checksums[i] & 0xFFFFFFFF) + raw = b'\x78\x9c' + raw + checksum + + result.append(raw) + + return result + + except Exception: + return None + + +# --------------------------------------------------------------------------- +# High-level GPU write pipeline +# --------------------------------------------------------------------------- + +def gpu_compress_tiles(d_image, tile_width, tile_height, + image_width, image_height, + compression, predictor, dtype, + samples=1): + """Extract and compress tiles from a CuPy image on GPU. + + Parameters + ---------- + d_image : cupy.ndarray + 2D or 3D image on GPU device. + tile_width, tile_height : int + Tile dimensions. + image_width, image_height : int + Image dimensions. + compression : int + TIFF compression tag. + predictor : int + Predictor tag (1=none, 2=horizontal, 3=float). + dtype : np.dtype + Pixel dtype. + samples : int + Samples per pixel. + + Returns + ------- + list of bytes + Compressed tile data on CPU, ready for _assemble_tiff. + """ + import cupy + + bytes_per_pixel = dtype.itemsize * samples + tile_bytes = tile_width * tile_height * bytes_per_pixel + tiles_across = math.ceil(image_width / tile_width) + tiles_down = math.ceil(image_height / tile_height) + n_tiles = tiles_across * tiles_down + + # Flatten image to uint8 + d_flat = d_image.view(cupy.uint8).ravel() + + # Allocate tile buffer + d_tile_buf = cupy.zeros(n_tiles * tile_bytes, dtype=cupy.uint8) + tile_offsets = np.arange(n_tiles, dtype=np.int64) * tile_bytes + d_tile_offsets = cupy.asarray(tile_offsets) + + # Extract tiles on GPU + total_pixels = image_width * image_height + tpb = 256 + bpg = math.ceil(total_pixels / tpb) + _extract_tiles_kernel[bpg, tpb]( + d_flat, d_tile_buf, d_tile_offsets, + tile_width, tile_height, bytes_per_pixel, + image_width, image_height, tiles_across) + cuda.synchronize() + + # Apply predictor encode on GPU + total_rows = n_tiles * tile_height + if predictor == 2: + tpb_r = min(256, total_rows) + bpg_r = math.ceil(total_rows / tpb_r) + _predictor_encode_kernel[bpg_r, tpb_r]( + d_tile_buf, tile_width * samples, total_rows, dtype.itemsize * samples) + cuda.synchronize() + elif predictor == 3: + tpb_r = min(256, total_rows) + bpg_r = math.ceil(total_rows / tpb_r) + d_tmp = cupy.empty_like(d_tile_buf) + _fp_predictor_encode_kernel[bpg_r, tpb_r]( + d_tile_buf, d_tmp, tile_width * samples, total_rows, dtype.itemsize) + cuda.synchronize() + + # Split into per-tile buffers for nvCOMP + d_tiles = [d_tile_buf[i * tile_bytes:(i + 1) * tile_bytes] for i in range(n_tiles)] + + # Try nvCOMP batch compress + result = _nvcomp_batch_compress(d_tiles, None, tile_bytes, compression, n_tiles) + + if result is not None: + return result + + # Fallback: copy to CPU, compress with CPU codecs + from ._compression import compress as cpu_compress + cpu_buf = d_tile_buf.get() + result = [] + for i in range(n_tiles): + start = i * tile_bytes + tile_data = bytes(cpu_buf[start:start + tile_bytes]) + result.append(cpu_compress(tile_data, compression)) + + return result diff --git a/xrspatial/geotiff/_header.py b/xrspatial/geotiff/_header.py new file mode 100644 index 00000000..1f0751f7 --- /dev/null +++ b/xrspatial/geotiff/_header.py @@ -0,0 +1,392 @@ +"""TIFF/BigTIFF header and IFD parsing.""" +from __future__ import annotations + +import struct +from dataclasses import dataclass, field +from typing import Any + +from ._dtypes import ( + TIFF_TYPE_SIZES, + TIFF_TYPE_STRUCT_CODES, + RATIONAL, + SRATIONAL, + ASCII, + UNDEFINED, +) + +# Well-known TIFF tag IDs +TAG_IMAGE_WIDTH = 256 +TAG_IMAGE_LENGTH = 257 +TAG_BITS_PER_SAMPLE = 258 +TAG_COMPRESSION = 259 +TAG_PHOTOMETRIC = 262 +TAG_STRIP_OFFSETS = 273 +TAG_SAMPLES_PER_PIXEL = 277 +TAG_ROWS_PER_STRIP = 278 +TAG_STRIP_BYTE_COUNTS = 279 +TAG_X_RESOLUTION = 282 +TAG_Y_RESOLUTION = 283 +TAG_PLANAR_CONFIG = 284 +TAG_RESOLUTION_UNIT = 296 +TAG_PREDICTOR = 317 +TAG_TILE_WIDTH = 322 +TAG_TILE_LENGTH = 323 +TAG_TILE_OFFSETS = 324 +TAG_TILE_BYTE_COUNTS = 325 +TAG_COLORMAP = 320 +TAG_EXTRA_SAMPLES = 338 +TAG_SAMPLE_FORMAT = 339 +TAG_GDAL_METADATA = 42112 +TAG_GDAL_NODATA = 42113 + +# GeoTIFF tags +TAG_MODEL_PIXEL_SCALE = 33550 +TAG_MODEL_TIEPOINT = 33922 +TAG_MODEL_TRANSFORMATION = 34264 +TAG_GEO_KEY_DIRECTORY = 34735 +TAG_GEO_DOUBLE_PARAMS = 34736 +TAG_GEO_ASCII_PARAMS = 34737 + + +@dataclass +class TIFFHeader: + """Parsed TIFF file header.""" + byte_order: str # '<' or '>' + is_bigtiff: bool + first_ifd_offset: int + + +@dataclass +class IFDEntry: + """A single IFD entry with its resolved value.""" + tag: int + type_id: int + count: int + value: Any # resolved: int, float, tuple, bytes, or str + + +@dataclass +class IFD: + """Parsed Image File Directory.""" + entries: dict[int, IFDEntry] = field(default_factory=dict) + next_ifd_offset: int = 0 + + def get_value(self, tag: int, default: Any = None) -> Any: + """Get the resolved value for a tag, or default if absent.""" + entry = self.entries.get(tag) + if entry is None: + return default + return entry.value + + def get_values(self, tag: int) -> tuple | None: + """Get a tag's value as a tuple (even if scalar).""" + entry = self.entries.get(tag) + if entry is None: + return None + v = entry.value + if isinstance(v, tuple): + return v + return (v,) + + # Convenience properties + @property + def width(self) -> int: + return self.get_value(TAG_IMAGE_WIDTH, 0) + + @property + def height(self) -> int: + return self.get_value(TAG_IMAGE_LENGTH, 0) + + @property + def bits_per_sample(self) -> int | tuple: + v = self.get_value(TAG_BITS_PER_SAMPLE, 8) + if isinstance(v, tuple): + return v[0] if len(v) == 1 else v + return v + + @property + def samples_per_pixel(self) -> int: + return self.get_value(TAG_SAMPLES_PER_PIXEL, 1) + + @property + def sample_format(self) -> int: + v = self.get_value(TAG_SAMPLE_FORMAT, 1) + if isinstance(v, tuple): + return v[0] + return v + + @property + def compression(self) -> int: + return self.get_value(TAG_COMPRESSION, 1) + + @property + def predictor(self) -> int: + return self.get_value(TAG_PREDICTOR, 1) + + @property + def is_tiled(self) -> bool: + return TAG_TILE_WIDTH in self.entries + + @property + def tile_width(self) -> int: + return self.get_value(TAG_TILE_WIDTH, 0) + + @property + def tile_height(self) -> int: + return self.get_value(TAG_TILE_LENGTH, 0) + + @property + def rows_per_strip(self) -> int: + # Default: entire image in one strip + return self.get_value(TAG_ROWS_PER_STRIP, self.height) + + @property + def strip_offsets(self) -> tuple | None: + return self.get_values(TAG_STRIP_OFFSETS) + + @property + def strip_byte_counts(self) -> tuple | None: + return self.get_values(TAG_STRIP_BYTE_COUNTS) + + @property + def tile_offsets(self) -> tuple | None: + return self.get_values(TAG_TILE_OFFSETS) + + @property + def tile_byte_counts(self) -> tuple | None: + return self.get_values(TAG_TILE_BYTE_COUNTS) + + @property + def photometric(self) -> int: + return self.get_value(TAG_PHOTOMETRIC, 1) + + @property + def planar_config(self) -> int: + return self.get_value(TAG_PLANAR_CONFIG, 1) + + @property + def x_resolution(self) -> float | None: + """XResolution tag (282), or None if absent.""" + v = self.get_value(TAG_X_RESOLUTION) + return float(v) if v is not None else None + + @property + def y_resolution(self) -> float | None: + """YResolution tag (283), or None if absent.""" + v = self.get_value(TAG_Y_RESOLUTION) + return float(v) if v is not None else None + + @property + def resolution_unit(self) -> int | None: + """ResolutionUnit tag (296): 1=none, 2=inch, 3=cm. None if absent.""" + return self.get_value(TAG_RESOLUTION_UNIT) + + @property + def colormap(self) -> tuple | None: + """ColorMap tag (320) values, or None if absent.""" + return self.get_values(TAG_COLORMAP) + + @property + def gdal_metadata(self) -> str | None: + """GDALMetadata XML string (tag 42112), or None if absent.""" + v = self.get_value(TAG_GDAL_METADATA) + if v is None: + return None + if isinstance(v, bytes): + return v.rstrip(b'\x00').decode('ascii', errors='replace') + return str(v).rstrip('\x00') + + @property + def nodata_str(self) -> str | None: + """GDAL_NODATA tag value as string, or None.""" + v = self.get_value(TAG_GDAL_NODATA) + if v is None: + return None + if isinstance(v, bytes): + return v.rstrip(b'\x00').decode('ascii', errors='replace') + return str(v).rstrip('\x00') + + +def parse_header(data: bytes | memoryview) -> TIFFHeader: + """Parse a TIFF/BigTIFF file header. + + Parameters + ---------- + data : bytes + At least the first 16 bytes of the file. + + Returns + ------- + TIFFHeader + """ + if len(data) < 8: + raise ValueError("Not enough data for TIFF header") + + bom = data[0:2] + if bom == b'II': + bo = '<' + elif bom == b'MM': + bo = '>' + else: + raise ValueError(f"Invalid TIFF byte order marker: {bom!r}") + + magic = struct.unpack_from(f'{bo}H', data, 2)[0] + + if magic == 42: + # Standard TIFF + offset = struct.unpack_from(f'{bo}I', data, 4)[0] + return TIFFHeader(byte_order=bo, is_bigtiff=False, first_ifd_offset=offset) + elif magic == 43: + # BigTIFF + if len(data) < 16: + raise ValueError("Not enough data for BigTIFF header") + offset_size = struct.unpack_from(f'{bo}H', data, 4)[0] + if offset_size != 8: + raise ValueError(f"Unexpected BigTIFF offset size: {offset_size}") + # skip 2 bytes padding + offset = struct.unpack_from(f'{bo}Q', data, 8)[0] + return TIFFHeader(byte_order=bo, is_bigtiff=True, first_ifd_offset=offset) + else: + raise ValueError(f"Invalid TIFF magic number: {magic}") + + +def _read_value(data: bytes | memoryview, offset: int, type_id: int, + count: int, bo: str) -> Any: + """Read a typed value array from data at the given offset.""" + type_size = TIFF_TYPE_SIZES.get(type_id, 1) + + if type_id == ASCII: + raw = bytes(data[offset:offset + count]) + # Strip trailing null + return raw.rstrip(b'\x00').decode('ascii', errors='replace') + + if type_id == UNDEFINED: + return bytes(data[offset:offset + count]) + + if type_id == RATIONAL: + values = [] + for i in range(count): + off = offset + i * 8 + num = struct.unpack_from(f'{bo}I', data, off)[0] + den = struct.unpack_from(f'{bo}I', data, off + 4)[0] + values.append(num / den if den != 0 else 0.0) + return tuple(values) if count > 1 else values[0] + + if type_id == SRATIONAL: + values = [] + for i in range(count): + off = offset + i * 8 + num = struct.unpack_from(f'{bo}i', data, off)[0] + den = struct.unpack_from(f'{bo}i', data, off + 4)[0] + values.append(num / den if den != 0 else 0.0) + return tuple(values) if count > 1 else values[0] + + fmt_char = TIFF_TYPE_STRUCT_CODES.get(type_id) + if fmt_char is None: + return bytes(data[offset:offset + count * type_size]) + + if count == 1: + return struct.unpack_from(f'{bo}{fmt_char}', data, offset)[0] + + # Batch unpack: single call for all elements + return struct.unpack_from(f'{bo}{count}{fmt_char}', data, offset) + + +def parse_ifd(data: bytes | memoryview, offset: int, + header: TIFFHeader) -> IFD: + """Parse a single IFD at the given offset. + + Parameters + ---------- + data : bytes + Full file data (or at least enough of it). + offset : int + Byte offset of this IFD. + header : TIFFHeader + Parsed file header. + + Returns + ------- + IFD + """ + bo = header.byte_order + is_big = header.is_bigtiff + + if is_big: + num_entries = struct.unpack_from(f'{bo}Q', data, offset)[0] + entry_offset = offset + 8 + entry_size = 20 + else: + num_entries = struct.unpack_from(f'{bo}H', data, offset)[0] + entry_offset = offset + 2 + entry_size = 12 + + inline_max = 8 if is_big else 4 + entries = {} + + for i in range(num_entries): + eo = entry_offset + i * entry_size + + if is_big: + tag = struct.unpack_from(f'{bo}H', data, eo)[0] + type_id = struct.unpack_from(f'{bo}H', data, eo + 2)[0] + count = struct.unpack_from(f'{bo}Q', data, eo + 4)[0] + value_area_offset = eo + 12 + else: + tag = struct.unpack_from(f'{bo}H', data, eo)[0] + type_id = struct.unpack_from(f'{bo}H', data, eo + 2)[0] + count = struct.unpack_from(f'{bo}I', data, eo + 4)[0] + value_area_offset = eo + 8 + + type_size = TIFF_TYPE_SIZES.get(type_id, 1) + total_size = count * type_size + + if total_size <= inline_max: + value = _read_value(data, value_area_offset, type_id, count, bo) + else: + if is_big: + ptr = struct.unpack_from(f'{bo}Q', data, value_area_offset)[0] + else: + ptr = struct.unpack_from(f'{bo}I', data, value_area_offset)[0] + value = _read_value(data, ptr, type_id, count, bo) + + entries[tag] = IFDEntry(tag=tag, type_id=type_id, count=count, value=value) + + # Next IFD offset + next_offset_pos = entry_offset + num_entries * entry_size + if is_big: + next_ifd = struct.unpack_from(f'{bo}Q', data, next_offset_pos)[0] + else: + next_ifd = struct.unpack_from(f'{bo}I', data, next_offset_pos)[0] + + return IFD(entries=entries, next_ifd_offset=next_ifd) + + +def parse_all_ifds(data: bytes | memoryview, + header: TIFFHeader) -> list[IFD]: + """Parse all IFDs in a TIFF file. + + Parameters + ---------- + data : bytes + Full file data. + header : TIFFHeader + Parsed file header. + + Returns + ------- + list[IFD] + """ + ifds = [] + offset = header.first_ifd_offset + seen = set() + + while offset != 0 and offset not in seen: + seen.add(offset) + if offset >= len(data): + break + ifd = parse_ifd(data, offset, header) + ifds.append(ifd) + offset = ifd.next_ifd_offset + + return ifds diff --git a/xrspatial/geotiff/_reader.py b/xrspatial/geotiff/_reader.py new file mode 100644 index 00000000..8b15c544 --- /dev/null +++ b/xrspatial/geotiff/_reader.py @@ -0,0 +1,701 @@ +"""TIFF/COG reader: tile/strip assembly, windowed reads, HTTP range requests.""" +from __future__ import annotations + +import math +import mmap +import threading +import urllib.request + +import numpy as np + +from ._compression import ( + COMPRESSION_NONE, + decompress, + fp_predictor_decode, + predictor_decode, + unpack_bits, +) +from ._dtypes import SUB_BYTE_BPS, tiff_dtype_to_numpy +from ._geotags import GeoInfo, GeoTransform, extract_geo_info +from ._header import IFD, TIFFHeader, parse_all_ifds, parse_header + + +# --------------------------------------------------------------------------- +# Data source abstraction +# --------------------------------------------------------------------------- + +class _MmapCache: + """Thread-safe, reference-counted mmap cache. + + Multiple threads reading the same file share a single read-only mmap. + The mmap is closed when the last reference is released. + mmap slicing on a read-only mapping is thread-safe (no seek involved). + """ + + def __init__(self): + self._lock = threading.Lock() + # path -> (fh, mm, refcount) + self._entries: dict[str, tuple] = {} + + def acquire(self, path: str): + """Get or create a read-only mmap for *path*. Returns (mm, size).""" + import os + real = os.path.realpath(path) + with self._lock: + if real in self._entries: + fh, mm, size, rc = self._entries[real] + self._entries[real] = (fh, mm, size, rc + 1) + return mm, size + + fh = open(real, 'rb') + fh.seek(0, 2) + size = fh.tell() + fh.seek(0) + if size > 0: + mm = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ) + else: + mm = None + self._entries[real] = (fh, mm, size, 1) + return mm, size + + def release(self, path: str): + """Decrement the reference count; close the mmap when it hits zero.""" + import os + real = os.path.realpath(path) + with self._lock: + entry = self._entries.get(real) + if entry is None: + return + fh, mm, size, rc = entry + rc -= 1 + if rc <= 0: + del self._entries[real] + if mm is not None: + mm.close() + fh.close() + else: + self._entries[real] = (fh, mm, size, rc) + + +# Module-level cache shared across all reads +_mmap_cache = _MmapCache() + + +class _FileSource: + """Local file data source using a shared, thread-safe mmap cache.""" + + def __init__(self, path: str): + self._path = path + self._mm, self._size = _mmap_cache.acquire(path) + + def read_range(self, start: int, length: int) -> bytes: + if self._mm is not None: + return self._mm[start:start + length] + return b'' + + def read_all(self): + """Return mmap object (supports slicing, struct.unpack_from, len).""" + if self._mm is not None: + return self._mm + return b'' + + @property + def size(self) -> int: + return self._size + + def close(self): + _mmap_cache.release(self._path) + + +def _get_http_pool(): + """Return a module-level urllib3 PoolManager, or None if unavailable.""" + global _http_pool + if _http_pool is not None: + return _http_pool + try: + import urllib3 + _http_pool = urllib3.PoolManager( + num_pools=10, + maxsize=10, + retries=urllib3.Retry(total=2, backoff_factor=0.1), + ) + return _http_pool + except ImportError: + return None + + +_http_pool = None + + +class _HTTPSource: + """HTTP data source using range requests with connection reuse. + + Uses urllib3.PoolManager when available (reuses TCP connections and + TLS sessions across range requests to the same host). Falls back to + stdlib urllib.request if urllib3 is not installed. + """ + + def __init__(self, url: str): + self._url = url + self._size = None + self._pool = _get_http_pool() + + def read_range(self, start: int, length: int) -> bytes: + end = start + length - 1 + if self._pool is not None: + resp = self._pool.request( + 'GET', self._url, + headers={'Range': f'bytes={start}-{end}'}, + ) + return resp.data + # Fallback: stdlib + req = urllib.request.Request( + self._url, + headers={'Range': f'bytes={start}-{end}'}, + ) + with urllib.request.urlopen(req) as resp: + return resp.read() + + def read_all(self) -> bytes: + if self._pool is not None: + resp = self._pool.request('GET', self._url) + return resp.data + with urllib.request.urlopen(self._url) as resp: + return resp.read() + + @property + def size(self) -> int | None: + return self._size + + def close(self): + pass + + +_CLOUD_SCHEMES = ('s3://', 'gs://', 'az://', 'abfs://') + + +def _is_fsspec_uri(path: str) -> bool: + """Check if a path is a fsspec-compatible URI (not http/https/local).""" + if path.startswith(('http://', 'https://')): + return False + return '://' in path + + +class _CloudSource: + """Cloud storage data source using fsspec. + + Supports S3, GCS, Azure Blob Storage, and any other fsspec backend. + Requires the appropriate library (s3fs, gcsfs, adlfs) to be installed. + """ + + def __init__(self, url: str, **storage_options): + try: + import fsspec + except ImportError: + raise ImportError( + "fsspec is required to read from cloud storage. " + "Install it with: pip install fsspec") + self._url = url + self._fs, self._path = fsspec.core.url_to_fs(url, **storage_options) + self._size = self._fs.size(self._path) + + def read_range(self, start: int, length: int) -> bytes: + with self._fs.open(self._path, 'rb') as f: + f.seek(start) + return f.read(length) + + def read_all(self) -> bytes: + with self._fs.open(self._path, 'rb') as f: + return f.read() + + @property + def size(self) -> int: + return self._size + + def close(self): + pass + + +def _open_source(source: str): + """Open a data source (local file, URL, or cloud path).""" + if source.startswith(('http://', 'https://')): + return _HTTPSource(source) + if _is_fsspec_uri(source): + return _CloudSource(source) + return _FileSource(source) + + +def _apply_predictor(chunk: np.ndarray, pred: int, width: int, + height: int, bytes_per_sample: int) -> np.ndarray: + """Apply the appropriate predictor decode to decompressed data.""" + if pred == 2: + return predictor_decode(chunk, width, height, bytes_per_sample) + elif pred == 3: + return fp_predictor_decode(chunk, width, height, bytes_per_sample) + return chunk + + +def _packed_byte_count(pixel_count: int, bps: int) -> int: + """Compute the number of packed bytes for sub-byte bit depths.""" + return (pixel_count * bps + 7) // 8 + + +def _decode_strip_or_tile(data_slice, compression, width, height, samples, + bps, bytes_per_sample, is_sub_byte, dtype, pred, + byte_order='<'): + """Decompress, apply predictor, unpack sub-byte, and reshape a strip/tile. + + Parameters + ---------- + byte_order : str + '<' for little-endian, '>' for big-endian. When the file byte + order differs from the system's native order, pixel data is + byte-swapped after decompression. + + Returns an array shaped (height, width) or (height, width, samples). + """ + pixel_count = width * height * samples + if is_sub_byte: + expected = _packed_byte_count(pixel_count, bps) + else: + expected = pixel_count * bytes_per_sample + + chunk = decompress(data_slice, compression, expected, + width=width, height=height, samples=samples) + + if pred in (2, 3) and not is_sub_byte: + if not chunk.flags.writeable: + chunk = chunk.copy() + chunk = _apply_predictor(chunk, pred, width, height, + bytes_per_sample * samples) + + if is_sub_byte: + pixels = unpack_bits(chunk, bps, pixel_count) + else: + # Use the file's byte order for the view, then convert to native + file_dtype = dtype.newbyteorder(byte_order) + pixels = chunk.view(file_dtype) + if file_dtype.byteorder not in ('=', '|', _NATIVE_ORDER): + pixels = pixels.astype(dtype) + + if samples > 1: + return pixels.reshape(height, width, samples) + return pixels.reshape(height, width) + + +import sys as _sys +_NATIVE_ORDER = '<' if _sys.byteorder == 'little' else '>' + + +# --------------------------------------------------------------------------- +# Strip reader +# --------------------------------------------------------------------------- + +def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader, + dtype: np.dtype, window=None) -> np.ndarray: + """Read a strip-organized TIFF image. + + Parameters + ---------- + data : bytes + Full file data. + ifd : IFD + Parsed IFD for this image. + header : TIFFHeader + File header. + dtype : np.dtype + Output pixel dtype. + window : tuple or None + (row_start, col_start, row_stop, col_stop) or None for full image. + + Returns + ------- + np.ndarray with shape (height, width) or windowed subset. + """ + width = ifd.width + height = ifd.height + samples = ifd.samples_per_pixel + compression = ifd.compression + rps = ifd.rows_per_strip + offsets = ifd.strip_offsets + byte_counts = ifd.strip_byte_counts + pred = ifd.predictor + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + bytes_per_sample = bps // 8 + is_sub_byte = bps in SUB_BYTE_BPS + + if offsets is None or byte_counts is None: + raise ValueError("Missing strip offsets or byte counts") + + planar = ifd.planar_config # 1=chunky (interleaved), 2=planar (separate) + + # Determine output region + if window is not None: + r0, c0, r1, c1 = window + r0 = max(0, r0) + c0 = max(0, c0) + r1 = min(height, r1) + c1 = min(width, c1) + else: + r0, c0, r1, c1 = 0, 0, height, width + + out_h = r1 - r0 + out_w = c1 - c0 + + if samples > 1: + result = np.empty((out_h, out_w, samples), dtype=dtype) + else: + result = np.empty((out_h, out_w), dtype=dtype) + + if planar == 2 and samples > 1: + strips_per_band = math.ceil(height / rps) + first_strip = r0 // rps + last_strip = min((r1 - 1) // rps, strips_per_band - 1) + + for band_idx in range(samples): + band_offset = band_idx * strips_per_band + for strip_idx in range(first_strip, last_strip + 1): + global_idx = band_offset + strip_idx + if global_idx >= len(offsets): + continue + strip_row = strip_idx * rps + strip_rows = min(rps, height - strip_row) + if strip_rows <= 0: + continue + + strip_data = data[offsets[global_idx]:offsets[global_idx] + byte_counts[global_idx]] + strip_pixels = _decode_strip_or_tile( + strip_data, compression, width, strip_rows, 1, + bps, bytes_per_sample, is_sub_byte, dtype, pred, + byte_order=header.byte_order) + + src_r0 = max(r0 - strip_row, 0) + src_r1 = min(r1 - strip_row, strip_rows) + dst_r0 = max(strip_row - r0, 0) + dst_r1 = dst_r0 + (src_r1 - src_r0) + if dst_r1 > dst_r0: + result[dst_r0:dst_r1, :, band_idx] = strip_pixels[src_r0:src_r1, c0:c1] + else: + first_strip = r0 // rps + last_strip = min((r1 - 1) // rps, len(offsets) - 1) + + for strip_idx in range(first_strip, last_strip + 1): + strip_row = strip_idx * rps + strip_rows = min(rps, height - strip_row) + if strip_rows <= 0: + continue + + strip_data = data[offsets[strip_idx]:offsets[strip_idx] + byte_counts[strip_idx]] + strip_pixels = _decode_strip_or_tile( + strip_data, compression, width, strip_rows, samples, + bps, bytes_per_sample, is_sub_byte, dtype, pred, + byte_order=header.byte_order) + + src_r0 = max(r0 - strip_row, 0) + src_r1 = min(r1 - strip_row, strip_rows) + dst_r0 = max(strip_row - r0, 0) + dst_r1 = dst_r0 + (src_r1 - src_r0) + if dst_r1 > dst_r0: + result[dst_r0:dst_r1] = strip_pixels[src_r0:src_r1, c0:c1] + + return result + + +# --------------------------------------------------------------------------- +# Tile reader +# --------------------------------------------------------------------------- + +def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader, + dtype: np.dtype, window=None) -> np.ndarray: + """Read a tile-organized TIFF image. + + Parameters + ---------- + data : bytes + Full file data. + ifd : IFD + Parsed IFD for this image. + header : TIFFHeader + File header. + dtype : np.dtype + Output pixel dtype. + window : tuple or None + (row_start, col_start, row_stop, col_stop) or None for full image. + + Returns + ------- + np.ndarray with shape (height, width) or windowed subset. + """ + width = ifd.width + height = ifd.height + tw = ifd.tile_width + th = ifd.tile_height + samples = ifd.samples_per_pixel + compression = ifd.compression + pred = ifd.predictor + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + bytes_per_sample = bps // 8 + is_sub_byte = bps in SUB_BYTE_BPS + + offsets = ifd.tile_offsets + byte_counts = ifd.tile_byte_counts + if offsets is None or byte_counts is None: + raise ValueError("Missing tile offsets or byte counts") + + planar = ifd.planar_config + tiles_across = math.ceil(width / tw) + tiles_down = math.ceil(height / th) + + if window is not None: + r0, c0, r1, c1 = window + r0 = max(0, r0) + c0 = max(0, c0) + r1 = min(height, r1) + c1 = min(width, c1) + else: + r0, c0, r1, c1 = 0, 0, height, width + + out_h = r1 - r0 + out_w = c1 - c0 + + _alloc = np.zeros if window is not None else np.empty + if samples > 1: + result = _alloc((out_h, out_w, samples), dtype=dtype) + else: + result = _alloc((out_h, out_w), dtype=dtype) + + tile_row_start = r0 // th + tile_row_end = min(math.ceil(r1 / th), tiles_down) + tile_col_start = c0 // tw + tile_col_end = min(math.ceil(c1 / tw), tiles_across) + + band_count = samples if (planar == 2 and samples > 1) else 1 + tiles_per_band = tiles_across * tiles_down + + for band_idx in range(band_count): + band_tile_offset = band_idx * tiles_per_band if band_count > 1 else 0 + tile_samples = 1 if band_count > 1 else samples + + for tr in range(tile_row_start, tile_row_end): + for tc in range(tile_col_start, tile_col_end): + tile_idx = band_tile_offset + tr * tiles_across + tc + if tile_idx >= len(offsets): + continue + + tile_data = data[offsets[tile_idx]:offsets[tile_idx] + byte_counts[tile_idx]] + tile_pixels = _decode_strip_or_tile( + tile_data, compression, tw, th, tile_samples, + bps, bytes_per_sample, is_sub_byte, dtype, pred, + byte_order=header.byte_order) + + tile_r0 = tr * th + tile_c0 = tc * tw + + src_r0 = max(r0 - tile_r0, 0) + src_c0 = max(c0 - tile_c0, 0) + src_r1 = min(r1 - tile_r0, th) + src_c1 = min(c1 - tile_c0, tw) + + dst_r0 = max(tile_r0 - r0, 0) + dst_c0 = max(tile_c0 - c0, 0) + + actual_tile_h = min(th, height - tile_r0) + actual_tile_w = min(tw, width - tile_c0) + src_r1 = min(src_r1, actual_tile_h) + src_c1 = min(src_c1, actual_tile_w) + dst_r1 = dst_r0 + (src_r1 - src_r0) + dst_c1 = dst_c0 + (src_c1 - src_c0) + + if dst_r1 > dst_r0 and dst_c1 > dst_c0: + src_slice = tile_pixels[src_r0:src_r1, src_c0:src_c1] + if band_count > 1: + result[dst_r0:dst_r1, dst_c0:dst_c1, band_idx] = src_slice + else: + result[dst_r0:dst_r1, dst_c0:dst_c1] = src_slice + + return result + + +# --------------------------------------------------------------------------- +# COG HTTP reader +# --------------------------------------------------------------------------- + +def _read_cog_http(url: str, overview_level: int | None = None, + band: int | None = None) -> tuple[np.ndarray, GeoInfo]: + """Read a COG via HTTP range requests. + + Parameters + ---------- + url : str + HTTP(S) URL to the COG file. + overview_level : int or None + Which overview to read (0 = full res, 1 = first overview, etc.). + band : int + Band index (0-based, for multi-band files). + + Returns + ------- + (array, geo_info) tuple + """ + source = _HTTPSource(url) + + # Initial fetch: get header + IFDs (COGs put metadata first) + header_bytes = source.read_range(0, 16384) + + header = parse_header(header_bytes) + ifds = parse_all_ifds(header_bytes, header) + + # If we didn't get all IFDs, try a larger fetch + if len(ifds) == 0: + header_bytes = source.read_range(0, 65536) + ifds = parse_all_ifds(header_bytes, header) + + if len(ifds) == 0: + raise ValueError("No IFDs found in COG") + + # Select IFD based on overview level + ifd_idx = 0 + if overview_level is not None: + ifd_idx = min(overview_level, len(ifds) - 1) + ifd = ifds[ifd_idx] + + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) + geo_info = extract_geo_info(ifd, header_bytes, header.byte_order) + + # COGs are tiled -- fetch individual tiles + if not ifd.is_tiled: + # Fallback: fetch entire file + all_data = source.read_all() + arr = _read_strips(all_data, ifd, header, dtype) + source.close() + return arr, geo_info + + width = ifd.width + height = ifd.height + tw = ifd.tile_width + th = ifd.tile_height + samples = ifd.samples_per_pixel + compression = ifd.compression + pred = ifd.predictor + bytes_per_sample = bps // 8 + is_sub_byte = bps in SUB_BYTE_BPS + + offsets = ifd.tile_offsets + byte_counts = ifd.tile_byte_counts + + tiles_across = math.ceil(width / tw) + tiles_down = math.ceil(height / th) + + if samples > 1: + result = np.empty((height, width, samples), dtype=dtype) + else: + result = np.empty((height, width), dtype=dtype) + + for tr in range(tiles_down): + for tc in range(tiles_across): + tile_idx = tr * tiles_across + tc + if tile_idx >= len(offsets): + continue + + off = offsets[tile_idx] + bc = byte_counts[tile_idx] + if bc == 0: + continue + + tile_data = source.read_range(off, bc) + tile_pixels = _decode_strip_or_tile( + tile_data, compression, tw, th, samples, + bps, bytes_per_sample, is_sub_byte, dtype, pred, + byte_order=header.byte_order) + + # Place tile + y0 = tr * th + x0 = tc * tw + y1 = min(y0 + th, height) + x1 = min(x0 + tw, width) + actual_h = y1 - y0 + actual_w = x1 - x0 + result[y0:y1, x0:x1] = tile_pixels[:actual_h, :actual_w] + + source.close() + return result, geo_info + + +# --------------------------------------------------------------------------- +# Main read function +# --------------------------------------------------------------------------- + +def read_to_array(source: str, *, window=None, overview_level: int | None = None, + band: int | None = None) -> tuple[np.ndarray, GeoInfo]: + """Read a GeoTIFF/COG to a numpy array. + + Parameters + ---------- + source : str + File path or URL. + window : tuple or None + (row_start, col_start, row_stop, col_stop). + overview_level : int or None + Overview level (0 = full res). + band : int + Band index for multi-band files. + + Returns + ------- + (np.ndarray, GeoInfo) tuple + """ + if source.startswith(('http://', 'https://')): + return _read_cog_http(source, overview_level=overview_level, band=band) + + # Local file or cloud storage: read all bytes then parse + if _is_fsspec_uri(source): + src = _CloudSource(source) + else: + src = _FileSource(source) + data = src.read_all() + + try: + header = parse_header(data) + ifds = parse_all_ifds(data, header) + + if len(ifds) == 0: + raise ValueError("No IFDs found in TIFF file") + + # Select IFD + ifd_idx = 0 + if overview_level is not None: + ifd_idx = min(overview_level, len(ifds) - 1) + ifd = ifds[ifd_idx] + + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + dtype = tiff_dtype_to_numpy(bps, ifd.sample_format) + geo_info = extract_geo_info(ifd, data, header.byte_order) + + if ifd.is_tiled: + arr = _read_tiles(data, ifd, header, dtype, window) + else: + arr = _read_strips(data, ifd, header, dtype, window) + + # For multi-band with band selection, extract single band + if arr.ndim == 3 and ifd.samples_per_pixel > 1 and band is not None: + arr = arr[:, :, band] + + # MinIsWhite (photometric=0): invert single-band grayscale values + if ifd.photometric == 0 and ifd.samples_per_pixel == 1: + if arr.dtype.kind == 'u': + arr = np.iinfo(arr.dtype).max - arr + elif arr.dtype.kind == 'f': + arr = -arr + finally: + src.close() + + return arr, geo_info diff --git a/xrspatial/geotiff/_vrt.py b/xrspatial/geotiff/_vrt.py new file mode 100644 index 00000000..8a6f2671 --- /dev/null +++ b/xrspatial/geotiff/_vrt.py @@ -0,0 +1,481 @@ +"""Virtual Raster Table (VRT) reader. + +Parses GDAL VRT XML files and assembles a virtual raster from one or +more source GeoTIFF files using windowed reads. +""" +from __future__ import annotations + +import os +import xml.etree.ElementTree as ET +from dataclasses import dataclass, field + +import numpy as np + +# Lazy imports to avoid circular dependency +_DTYPE_MAP = { + 'Byte': np.uint8, + 'UInt16': np.uint16, + 'Int16': np.int16, + 'UInt32': np.uint32, + 'Int32': np.int32, + 'Float32': np.float32, + 'Float64': np.float64, + 'Int8': np.int8, +} + + +@dataclass +class _Rect: + """Pixel rectangle: (x_off, y_off, x_size, y_size).""" + x_off: int + y_off: int + x_size: int + y_size: int + + +@dataclass +class _Source: + """A single source region within a VRT band.""" + filename: str + band: int # 1-based + src_rect: _Rect + dst_rect: _Rect + nodata: float | None = None + # ComplexSource extras + scale: float | None = None + offset: float | None = None + + +@dataclass +class _VRTBand: + """A single band in a VRT dataset.""" + band_num: int # 1-based + dtype: np.dtype + nodata: float | None = None + sources: list[_Source] = field(default_factory=list) + color_interp: str | None = None + + +@dataclass +class VRTDataset: + """Parsed Virtual Raster Table.""" + width: int + height: int + crs_wkt: str | None = None + geo_transform: tuple | None = None # (origin_x, res_x, skew_x, origin_y, skew_y, res_y) + bands: list[_VRTBand] = field(default_factory=list) + + +def _parse_rect(elem) -> _Rect: + """Parse a SrcRect or DstRect element.""" + return _Rect( + x_off=int(float(elem.get('xOff', 0))), + y_off=int(float(elem.get('yOff', 0))), + x_size=int(float(elem.get('xSize', 0))), + y_size=int(float(elem.get('ySize', 0))), + ) + + +def _text(elem, tag, default=None): + """Get text content of a child element.""" + child = elem.find(tag) + if child is not None and child.text: + return child.text.strip() + return default + + +def parse_vrt(xml_str: str, vrt_dir: str = '.') -> VRTDataset: + """Parse a VRT XML string into a VRTDataset. + + Parameters + ---------- + xml_str : str + VRT XML content. + vrt_dir : str + Directory of the VRT file, for resolving relative source paths. + + Returns + ------- + VRTDataset + """ + root = ET.fromstring(xml_str) + + width = int(root.get('rasterXSize', 0)) + height = int(root.get('rasterYSize', 0)) + + # CRS + crs_wkt = _text(root, 'SRS') + + # GeoTransform: "origin_x, res_x, skew_x, origin_y, skew_y, res_y" + gt_str = _text(root, 'GeoTransform') + geo_transform = None + if gt_str: + parts = [float(x.strip()) for x in gt_str.split(',')] + if len(parts) == 6: + geo_transform = tuple(parts) + + # Bands + bands = [] + for band_elem in root.findall('VRTRasterBand'): + band_num = int(band_elem.get('band', 1)) + dtype_name = band_elem.get('dataType', 'Float32') + dtype = np.dtype(_DTYPE_MAP.get(dtype_name, np.float32)) + nodata_str = _text(band_elem, 'NoDataValue') + nodata = float(nodata_str) if nodata_str else None + color_interp = _text(band_elem, 'ColorInterp') + + sources = [] + for src_elem in band_elem: + tag = src_elem.tag + if tag not in ('SimpleSource', 'ComplexSource'): + continue + + filename = _text(src_elem, 'SourceFilename') or '' + relative = src_elem.find('SourceFilename') + is_relative = (relative is not None and + relative.get('relativeToVRT', '0') == '1') + if is_relative and not os.path.isabs(filename): + filename = os.path.join(vrt_dir, filename) + + src_band = int(_text(src_elem, 'SourceBand') or '1') + + src_rect_elem = src_elem.find('SrcRect') + dst_rect_elem = src_elem.find('DstRect') + if src_rect_elem is None or dst_rect_elem is None: + continue + + src_rect = _parse_rect(src_rect_elem) + dst_rect = _parse_rect(dst_rect_elem) + + src_nodata_str = _text(src_elem, 'NODATA') + src_nodata = float(src_nodata_str) if src_nodata_str else None + + # ComplexSource extras + scale = None + offset = None + if tag == 'ComplexSource': + scale_str = _text(src_elem, 'ScaleOffset') + offset_str = _text(src_elem, 'ScaleRatio') + # Note: GDAL uses ScaleOffset=offset, ScaleRatio=scale + if offset_str: + scale = float(offset_str) + if scale_str: + offset = float(scale_str) + + sources.append(_Source( + filename=filename, + band=src_band, + src_rect=src_rect, + dst_rect=dst_rect, + nodata=src_nodata, + scale=scale, + offset=offset, + )) + + bands.append(_VRTBand( + band_num=band_num, + dtype=dtype, + nodata=nodata, + sources=sources, + color_interp=color_interp, + )) + + return VRTDataset( + width=width, + height=height, + crs_wkt=crs_wkt, + geo_transform=geo_transform, + bands=bands, + ) + + +def read_vrt(vrt_path: str, *, window=None, + band: int | None = None) -> tuple[np.ndarray, VRTDataset]: + """Read a VRT file by assembling pixel data from its source files. + + Parameters + ---------- + vrt_path : str + Path to the .vrt file. + window : tuple or None + (row_start, col_start, row_stop, col_stop) for windowed read. + band : int or None + Band index (0-based). None returns all bands. + + Returns + ------- + (np.ndarray, VRTDataset) tuple + """ + from ._reader import read_to_array + + with open(vrt_path, 'r') as f: + xml_str = f.read() + + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + vrt = parse_vrt(xml_str, vrt_dir) + + if window is not None: + r0, c0, r1, c1 = window + r0 = max(0, r0) + c0 = max(0, c0) + r1 = min(vrt.height, r1) + c1 = min(vrt.width, c1) + else: + r0, c0, r1, c1 = 0, 0, vrt.height, vrt.width + + out_h = r1 - r0 + out_w = c1 - c0 + + # Select bands + if band is not None: + selected_bands = [vrt.bands[band]] + else: + selected_bands = vrt.bands + + # Allocate output + if len(selected_bands) == 1: + dtype = selected_bands[0].dtype + result = np.full((out_h, out_w), np.nan if dtype.kind == 'f' else 0, + dtype=dtype) + else: + dtype = selected_bands[0].dtype + result = np.full((out_h, out_w, len(selected_bands)), + np.nan if dtype.kind == 'f' else 0, dtype=dtype) + + for band_idx, vrt_band in enumerate(selected_bands): + nodata = vrt_band.nodata + + for src in vrt_band.sources: + # Compute overlap between source's destination rect and our window + dr = src.dst_rect + sr = src.src_rect + + # Destination rect in virtual raster coordinates + dst_r0 = dr.y_off + dst_c0 = dr.x_off + dst_r1 = dr.y_off + dr.y_size + dst_c1 = dr.x_off + dr.x_size + + # Clip to window + clip_r0 = max(dst_r0, r0) + clip_c0 = max(dst_c0, c0) + clip_r1 = min(dst_r1, r1) + clip_c1 = min(dst_c1, c1) + + if clip_r0 >= clip_r1 or clip_c0 >= clip_c1: + continue # no overlap + + # Map back to source coordinates + # Scale factor: source pixels per destination pixel + scale_y = sr.y_size / dr.y_size if dr.y_size > 0 else 1.0 + scale_x = sr.x_size / dr.x_size if dr.x_size > 0 else 1.0 + + src_r0 = sr.y_off + int((clip_r0 - dst_r0) * scale_y) + src_c0 = sr.x_off + int((clip_c0 - dst_c0) * scale_x) + src_r1 = sr.y_off + int((clip_r1 - dst_r0) * scale_y) + src_c1 = sr.x_off + int((clip_c1 - dst_c0) * scale_x) + + # Read from source file using windowed read + try: + src_arr, _ = read_to_array( + src.filename, + window=(src_r0, src_c0, src_r1, src_c1), + band=src.band - 1, # convert 1-based to 0-based + ) + except Exception: + continue # skip missing/unreadable sources + + # Handle source nodata + src_nodata = src.nodata or nodata + if src_nodata is not None and src_arr.dtype.kind == 'f': + src_arr = src_arr.copy() + src_arr[src_arr == np.float32(src_nodata)] = np.nan + + # Apply ComplexSource scaling + if src.scale is not None and src.scale != 1.0: + src_arr = src_arr.astype(np.float64) * src.scale + if src.offset is not None and src.offset != 0.0: + src_arr = src_arr.astype(np.float64) + src.offset + + # Place into output + out_r0 = clip_r0 - r0 + out_c0 = clip_c0 - c0 + out_r1 = out_r0 + src_arr.shape[0] + out_c1 = out_c0 + src_arr.shape[1] + + # Handle size mismatch from rounding + actual_h = min(src_arr.shape[0], out_r1 - out_r0) + actual_w = min(src_arr.shape[1], out_c1 - out_c0) + + if len(selected_bands) == 1: + result[out_r0:out_r0 + actual_h, + out_c0:out_c0 + actual_w] = src_arr[:actual_h, :actual_w] + else: + result[out_r0:out_r0 + actual_h, + out_c0:out_c0 + actual_w, + band_idx] = src_arr[:actual_h, :actual_w] + + return result, vrt + + +# --------------------------------------------------------------------------- +# VRT writer +# --------------------------------------------------------------------------- + +_NP_TO_VRT_DTYPE = {v: k for k, v in _DTYPE_MAP.items()} + + +def write_vrt(vrt_path: str, source_files: list[str], *, + relative: bool = True, + crs_wkt: str | None = None, + nodata: float | None = None) -> str: + """Generate a VRT file that mosaics multiple GeoTIFF tiles. + + Each source file is placed in the virtual raster based on its + geo transform. Files must share the same CRS and pixel size. + + Parameters + ---------- + vrt_path : str + Output .vrt file path. + source_files : list of str + Paths to the source GeoTIFF files. + relative : bool + Store source paths relative to the VRT file. + crs_wkt : str or None + CRS as WKT string. If None, taken from the first source. + nodata : float or None + NoData value. If None, taken from the first source. + + Returns + ------- + str + Path to the written VRT file. + """ + from ._reader import read_to_array + from ._header import parse_header, parse_all_ifds + from ._geotags import extract_geo_info + from ._reader import _FileSource + + if not source_files: + raise ValueError("source_files must not be empty") + + # Read metadata from all sources + sources_meta = [] + for src_path in source_files: + src = _FileSource(src_path) + data = src.read_all() + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + geo = extract_geo_info(ifd, data, header.byte_order) + src.close() + + bps = ifd.bits_per_sample + if isinstance(bps, tuple): + bps = bps[0] + + sources_meta.append({ + 'path': src_path, + 'width': ifd.width, + 'height': ifd.height, + 'bands': ifd.samples_per_pixel, + 'dtype': np.dtype(_DTYPE_MAP.get( + {v: k for k, v in _DTYPE_MAP.items()}.get( + np.dtype(f'{"f" if ifd.sample_format == 3 else ("i" if ifd.sample_format == 2 else "u")}{bps // 8}').type, + 'Float32'), + np.float32)), + 'bps': bps, + 'sample_format': ifd.sample_format, + 'transform': geo.transform, + 'crs_wkt': geo.crs_wkt, + 'nodata': geo.nodata, + }) + + first = sources_meta[0] + res_x = first['transform'].pixel_width + res_y = first['transform'].pixel_height + + # Compute the bounding box of all sources + all_x0, all_y0, all_x1, all_y1 = [], [], [], [] + for m in sources_meta: + t = m['transform'] + x0 = t.origin_x + y0 = t.origin_y + x1 = x0 + m['width'] * t.pixel_width + y1 = y0 + m['height'] * t.pixel_height + all_x0.append(min(x0, x1)) + all_y0.append(min(y0, y1)) + all_x1.append(max(x0, x1)) + all_y1.append(max(y0, y1)) + + mosaic_x0 = min(all_x0) + mosaic_y_top = max(all_y1) # top edge (y increases upward in geo) + mosaic_x1 = max(all_x1) + mosaic_y_bottom = min(all_y0) + + total_w = int(round((mosaic_x1 - mosaic_x0) / abs(res_x))) + total_h = int(round((mosaic_y_top - mosaic_y_bottom) / abs(res_y))) + + # Determine VRT dtype + sf = first['sample_format'] + bps = first['bps'] + if sf == 3: + vrt_dtype_name = 'Float64' if bps == 64 else 'Float32' + elif sf == 2: + vrt_dtype_name = {8: 'Int8', 16: 'Int16', 32: 'Int32'}.get(bps, 'Int32') + else: + vrt_dtype_name = {8: 'Byte', 16: 'UInt16', 32: 'UInt32'}.get(bps, 'Byte') + + srs = crs_wkt or first.get('crs_wkt') or '' + nd = nodata if nodata is not None else first.get('nodata') + + vrt_dir = os.path.dirname(os.path.abspath(vrt_path)) + n_bands = first['bands'] + + # Build XML + lines = [f''] + if srs: + lines.append(f' {srs}') + lines.append(f' {mosaic_x0}, {res_x}, 0.0, ' + f'{mosaic_y_top}, 0.0, {res_y}') + + for band_num in range(1, n_bands + 1): + lines.append(f' ') + if nd is not None: + lines.append(f' {nd}') + + for m in sources_meta: + t = m['transform'] + # Pixel offset in the virtual raster + dst_x_off = int(round((t.origin_x - mosaic_x0) / abs(res_x))) + dst_y_off = int(round((mosaic_y_top - t.origin_y) / abs(res_y))) + + fname = m['path'] + rel_attr = '0' + if relative: + try: + fname = os.path.relpath(fname, vrt_dir) + rel_attr = '1' + except ValueError: + pass # different drives on Windows + + lines.append(' ') + lines.append(f' ' + f'{fname}') + lines.append(f' {band_num}') + lines.append(f' ') + lines.append(f' ') + lines.append(' ') + + lines.append(' ') + + lines.append('') + + xml = '\n'.join(lines) + '\n' + with open(vrt_path, 'w') as f: + f.write(xml) + + return vrt_path diff --git a/xrspatial/geotiff/_writer.py b/xrspatial/geotiff/_writer.py new file mode 100644 index 00000000..ae7658ab --- /dev/null +++ b/xrspatial/geotiff/_writer.py @@ -0,0 +1,875 @@ +"""GeoTIFF/COG writer.""" +from __future__ import annotations + +import math +import struct + +import numpy as np + +from ._compression import ( + COMPRESSION_DEFLATE, + COMPRESSION_LZW, + COMPRESSION_NONE, + COMPRESSION_PACKBITS, + COMPRESSION_ZSTD, + compress, + predictor_encode, +) +from ._dtypes import ( + DOUBLE, + RATIONAL, + SHORT, + LONG, + ASCII, + numpy_to_tiff_dtype, + TIFF_TYPE_SIZES, +) +from ._geotags import ( + GeoTransform, + build_geo_tags, + TAG_GEO_KEY_DIRECTORY, + TAG_GDAL_NODATA, + TAG_MODEL_PIXEL_SCALE, + TAG_MODEL_TIEPOINT, +) +from ._header import ( + TAG_IMAGE_WIDTH, + TAG_IMAGE_LENGTH, + TAG_BITS_PER_SAMPLE, + TAG_COMPRESSION, + TAG_PHOTOMETRIC, + TAG_SAMPLES_PER_PIXEL, + TAG_SAMPLE_FORMAT, + TAG_STRIP_OFFSETS, + TAG_ROWS_PER_STRIP, + TAG_STRIP_BYTE_COUNTS, + TAG_X_RESOLUTION, + TAG_Y_RESOLUTION, + TAG_RESOLUTION_UNIT, + TAG_TILE_WIDTH, + TAG_TILE_LENGTH, + TAG_TILE_OFFSETS, + TAG_TILE_BYTE_COUNTS, + TAG_EXTRA_SAMPLES, + TAG_PREDICTOR, + TAG_GDAL_METADATA, +) + +# Byte order: always write little-endian +BO = '<' + + +def _compression_tag(compression_name: str) -> int: + """Convert compression name to TIFF tag value.""" + _map = { + 'none': COMPRESSION_NONE, + 'deflate': COMPRESSION_DEFLATE, + 'lzw': COMPRESSION_LZW, + 'packbits': COMPRESSION_PACKBITS, + 'zstd': COMPRESSION_ZSTD, + } + name = compression_name.lower() + if name not in _map: + raise ValueError(f"Unsupported compression: {compression_name!r}. " + f"Use one of: {list(_map.keys())}") + return _map[name] + + +OVERVIEW_METHODS = ('mean', 'nearest', 'min', 'max', 'median', 'mode', 'cubic') + + +def _block_reduce_2d(arr2d, method): + """2x block-reduce a single 2D plane using *method*.""" + h, w = arr2d.shape + h2 = (h // 2) * 2 + w2 = (w // 2) * 2 + cropped = arr2d[:h2, :w2] + oh, ow = h2 // 2, w2 // 2 + + if method == 'nearest': + # Top-left pixel of each 2x2 block + return cropped[::2, ::2].copy() + + if method == 'cubic': + try: + from scipy.ndimage import zoom + except ImportError: + raise ImportError( + "scipy is required for cubic overview resampling. " + "Install it with: pip install scipy") + return zoom(arr2d, 0.5, order=3).astype(arr2d.dtype) + + if method == 'mode': + # Most-common value per 2x2 block (useful for classified rasters) + blocks = cropped.reshape(oh, 2, ow, 2).transpose(0, 2, 1, 3).reshape(oh, ow, 4) + out = np.empty((oh, ow), dtype=arr2d.dtype) + for r in range(oh): + for c in range(ow): + vals, counts = np.unique(blocks[r, c], return_counts=True) + out[r, c] = vals[counts.argmax()] + return out + + # Block reshape for mean/min/max/median + if arr2d.dtype.kind == 'f': + blocks = cropped.reshape(oh, 2, ow, 2) + else: + blocks = cropped.astype(np.float64).reshape(oh, 2, ow, 2) + + if method == 'mean': + result = np.nanmean(blocks, axis=(1, 3)) + elif method == 'min': + result = np.nanmin(blocks, axis=(1, 3)) + elif method == 'max': + result = np.nanmax(blocks, axis=(1, 3)) + elif method == 'median': + flat = blocks.transpose(0, 2, 1, 3).reshape(oh, ow, 4) + result = np.nanmedian(flat, axis=2) + else: + raise ValueError( + f"Unknown overview resampling method: {method!r}. " + f"Use one of: {OVERVIEW_METHODS}") + + if arr2d.dtype.kind != 'f': + return np.round(result).astype(arr2d.dtype) + return result.astype(arr2d.dtype) + + +def _make_overview(arr: np.ndarray, method: str = 'mean') -> np.ndarray: + """Generate a 2x decimated overview. + + Parameters + ---------- + arr : np.ndarray + 2D or 3D (height, width, bands) array. + method : str + Resampling method: 'mean' (default), 'nearest', 'min', 'max', + 'median', 'mode', or 'cubic'. + + Returns + ------- + np.ndarray + Half-resolution array. + """ + if arr.ndim == 3: + bands = [_block_reduce_2d(arr[:, :, b], method) for b in range(arr.shape[2])] + return np.stack(bands, axis=2) + return _block_reduce_2d(arr, method) + + +# --------------------------------------------------------------------------- +# Tag serialization +# --------------------------------------------------------------------------- + +def _float_to_rational(val): + """Convert a float to a TIFF RATIONAL (numerator, denominator) pair.""" + if val == int(val): + return (int(val), 1) + # Use a denominator of 10000 for reasonable precision + den = 10000 + num = int(round(val * den)) + return (num, den) + + +def _serialize_tag_value(type_id, count, values): + """Serialize tag values to bytes.""" + if type_id == ASCII: + if isinstance(values, str): + return values.encode('ascii') + b'\x00' + return values + b'\x00' + elif type_id == SHORT: + if isinstance(values, (list, tuple)): + return struct.pack(f'{BO}{count}H', *values) + return struct.pack(f'{BO}H', values) + elif type_id == LONG: + if isinstance(values, (list, tuple)): + return struct.pack(f'{BO}{count}I', *values) + return struct.pack(f'{BO}I', values) + elif type_id == RATIONAL: + # RATIONAL = two LONGs (numerator, denominator) per value + if isinstance(values, (list, tuple)) and isinstance(values[0], (list, tuple)): + parts = [] + for num, den in values: + parts.extend([int(num), int(den)]) + return struct.pack(f'{BO}{count * 2}I', *parts) + else: + num, den = _float_to_rational(float(values)) + return struct.pack(f'{BO}II', num, den) + elif type_id == DOUBLE: + if isinstance(values, (list, tuple)): + return struct.pack(f'{BO}{count}d', *values) + return struct.pack(f'{BO}d', values) + else: + if isinstance(values, bytes): + return values + return struct.pack(f'{BO}I', values) + + +def _pack_tag_value(tag_id: int, type_id: int, count: int, + values, overflow_buf: bytearray, + overflow_base: int, bigtiff: bool = False) -> bytes: + """Pack a single IFD entry. + + Standard TIFF: 12 bytes (tag:2, type:2, count:4, value:4). + BigTIFF: 20 bytes (tag:2, type:2, count:8, value:8). + """ + val_bytes = _serialize_tag_value(type_id, count, values) + + # For ASCII, count is the actual byte length + if type_id == ASCII: + count = len(val_bytes) + + inline_max = 8 if bigtiff else 4 + + if bigtiff: + entry = struct.pack(f'{BO}HHQ', tag_id, type_id, count) + else: + entry = struct.pack(f'{BO}HHI', tag_id, type_id, count) + + if len(val_bytes) <= inline_max: + value_field = val_bytes.ljust(inline_max, b'\x00') + else: + offset = overflow_base + len(overflow_buf) + if bigtiff: + value_field = struct.pack(f'{BO}Q', offset) + else: + value_field = struct.pack(f'{BO}I', offset) + overflow_buf.extend(val_bytes) + if len(overflow_buf) % 2: + overflow_buf.append(0) + + return entry + value_field + + +def _build_ifd(tags: list[tuple], overflow_base: int, + bigtiff: bool = False) -> tuple[bytes, bytes]: + """Build a complete IFD block. + + Parameters + ---------- + tags : list of (tag_id, type_id, count, values) + Tags sorted by tag_id. + overflow_base : int + Where overflow data starts in the file. + + Returns + ------- + (ifd_bytes, overflow_bytes) + """ + # Sort by tag ID (TIFF spec requires this) + tags = sorted(tags, key=lambda t: t[0]) + + num_entries = len(tags) + overflow_buf = bytearray() + + if bigtiff: + ifd_parts = [struct.pack(f'{BO}Q', num_entries)] + else: + ifd_parts = [struct.pack(f'{BO}H', num_entries)] + + for tag_id, type_id, count, values in tags: + entry = _pack_tag_value(tag_id, type_id, count, values, + overflow_buf, overflow_base, bigtiff=bigtiff) + ifd_parts.append(entry) + + # Next IFD offset (0 = no more IFDs, will be patched for COG) + if bigtiff: + ifd_parts.append(struct.pack(f'{BO}Q', 0)) + else: + ifd_parts.append(struct.pack(f'{BO}I', 0)) + + return b''.join(ifd_parts), bytes(overflow_buf) + + +# --------------------------------------------------------------------------- +# Strip writer +# --------------------------------------------------------------------------- + +def _write_stripped(data: np.ndarray, compression: int, predictor: bool, + rows_per_strip: int = 256) -> tuple[list, list, list]: + """Compress data as strips. + + Returns + ------- + (offsets_placeholder, byte_counts, compressed_chunks) + offsets are relative to the start of the compressed data block. + compressed_chunks is a list of bytes objects (one per strip). + """ + height, width = data.shape[:2] + samples = data.shape[2] if data.ndim == 3 else 1 + dtype = data.dtype + bytes_per_sample = dtype.itemsize + + strips = [] + rel_offsets = [] + byte_counts = [] + current_offset = 0 + + num_strips = math.ceil(height / rows_per_strip) + for i in range(num_strips): + r0 = i * rows_per_strip + r1 = min(r0 + rows_per_strip, height) + strip_rows = r1 - r0 + + if predictor and compression != COMPRESSION_NONE: + strip_arr = np.ascontiguousarray(data[r0:r1]) + buf = strip_arr.view(np.uint8).ravel().copy() + buf = predictor_encode(buf, width, strip_rows, bytes_per_sample * samples) + strip_data = buf.tobytes() + else: + strip_data = np.ascontiguousarray(data[r0:r1]).tobytes() + + compressed = compress(strip_data, compression) + + rel_offsets.append(current_offset) + byte_counts.append(len(compressed)) + strips.append(compressed) + current_offset += len(compressed) + + return rel_offsets, byte_counts, strips + + +# --------------------------------------------------------------------------- +# Tile writer +# --------------------------------------------------------------------------- + +def _write_tiled(data: np.ndarray, compression: int, predictor: bool, + tile_size: int = 256) -> tuple[list, list, list]: + """Compress data as tiles. + + Returns + ------- + (relative_offsets, byte_counts, compressed_chunks) + compressed_chunks is a list of bytes objects (one per tile). + """ + height, width = data.shape[:2] + samples = data.shape[2] if data.ndim == 3 else 1 + dtype = data.dtype + bytes_per_sample = dtype.itemsize + + tw = tile_size + th = tile_size + tiles_across = math.ceil(width / tw) + tiles_down = math.ceil(height / th) + + tiles = [] + rel_offsets = [] + byte_counts = [] + current_offset = 0 + + for tr in range(tiles_down): + for tc in range(tiles_across): + r0 = tr * th + c0 = tc * tw + r1 = min(r0 + th, height) + c1 = min(c0 + tw, width) + + actual_h = r1 - r0 + actual_w = c1 - c0 + + # Extract tile, pad to full tile size if needed + tile_slice = data[r0:r1, c0:c1] + + if actual_h < th or actual_w < tw: + if data.ndim == 3: + padded = np.empty((th, tw, samples), dtype=dtype) + else: + padded = np.empty((th, tw), dtype=dtype) + padded[:actual_h, :actual_w] = tile_slice + # Zero only the padding regions + if actual_h < th: + padded[actual_h:, :] = 0 + if actual_w < tw: + padded[:actual_h, actual_w:] = 0 + tile_arr = padded + else: + tile_arr = np.ascontiguousarray(tile_slice) + + if predictor and compression != COMPRESSION_NONE: + buf = tile_arr.view(np.uint8).ravel().copy() + buf = predictor_encode(buf, tw, th, bytes_per_sample * samples) + tile_data = buf.tobytes() + else: + tile_data = tile_arr.tobytes() + + compressed = compress(tile_data, compression) + + rel_offsets.append(current_offset) + byte_counts.append(len(compressed)) + tiles.append(compressed) + current_offset += len(compressed) + + return rel_offsets, byte_counts, tiles + + +# --------------------------------------------------------------------------- +# File assembly +# --------------------------------------------------------------------------- + +def _assemble_tiff(width: int, height: int, dtype: np.dtype, + compression: int, predictor: bool, + tiled: bool, tile_size: int, + pixel_data_parts: list[tuple], + geo_transform: GeoTransform | None, + crs_epsg: int | None, + nodata, + is_cog: bool = False, + raster_type: int = 1, + gdal_metadata_xml: str | None = None, + extra_tags: list | None = None, + x_resolution: float | None = None, + y_resolution: float | None = None, + resolution_unit: int | None = None, + force_bigtiff: bool | None = None) -> bytes: + """Assemble a complete TIFF file. + + Parameters + ---------- + pixel_data_parts : list of (array, width, height, relative_offsets, byte_counts, compressed_data) + One entry per resolution level (full res first, then overviews). + is_cog : bool + If True, layout IFDs contiguously at file start (COG layout). + raster_type : int + 1 = PixelIsArea, 2 = PixelIsPoint. + + Returns + ------- + bytes + Complete TIFF file. + """ + bits_per_sample, sample_format = numpy_to_tiff_dtype(dtype) + + # Determine samples per pixel from the pixel data + first_arr = pixel_data_parts[0][0] + samples_per_pixel = first_arr.shape[2] if first_arr.ndim == 3 else 1 + + # Build geo tags + geo_tags_dict = {} + if geo_transform is not None: + geo_tags_dict = build_geo_tags( + geo_transform, crs_epsg, nodata, raster_type=raster_type) + else: + # No spatial reference -- still write CRS and nodata if provided + if crs_epsg is not None or nodata is not None: + geo_tags_dict = build_geo_tags( + GeoTransform(), crs_epsg, nodata, raster_type=raster_type, + ) + # Remove the default pixel scale / tiepoint tags since we + # have no real transform -- keep only GeoKeys and NODATA. + geo_tags_dict.pop(TAG_MODEL_PIXEL_SCALE, None) + geo_tags_dict.pop(TAG_MODEL_TIEPOINT, None) + + # Compression tag for predictor + pred_val = 2 if (predictor and compression != COMPRESSION_NONE) else 1 + + # Build IFDs for each resolution level + ifd_specs = [] + for level_idx, (arr, lw, lh, rel_offsets, byte_counts, comp_data) in enumerate(pixel_data_parts): + tags = [] + + tags.append((TAG_IMAGE_WIDTH, LONG, 1, lw)) + tags.append((TAG_IMAGE_LENGTH, LONG, 1, lh)) + if samples_per_pixel > 1: + tags.append((TAG_BITS_PER_SAMPLE, SHORT, samples_per_pixel, + [bits_per_sample] * samples_per_pixel)) + else: + tags.append((TAG_BITS_PER_SAMPLE, SHORT, 1, bits_per_sample)) + tags.append((TAG_COMPRESSION, SHORT, 1, compression)) + # Photometric: RGB for 3+ bands, BlackIsZero for single-band + photometric = 2 if samples_per_pixel >= 3 else 1 + tags.append((TAG_PHOTOMETRIC, SHORT, 1, photometric)) + tags.append((TAG_SAMPLES_PER_PIXEL, SHORT, 1, samples_per_pixel)) + if samples_per_pixel > 1: + tags.append((TAG_SAMPLE_FORMAT, SHORT, samples_per_pixel, + [sample_format] * samples_per_pixel)) + else: + tags.append((TAG_SAMPLE_FORMAT, SHORT, 1, sample_format)) + + # ExtraSamples: for bands beyond what Photometric accounts for + # Photometric=2 (RGB) accounts for 3 bands; any extra are alpha/other + if photometric == 2 and samples_per_pixel > 3: + n_extra = samples_per_pixel - 3 + # 2 = unassociated alpha for the first extra, 0 = unspecified for rest + extra_vals = [2] + [0] * (n_extra - 1) + tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals)) + elif photometric == 1 and samples_per_pixel > 1: + n_extra = samples_per_pixel - 1 + extra_vals = [0] * n_extra # unspecified + tags.append((TAG_EXTRA_SAMPLES, SHORT, n_extra, extra_vals)) + + if pred_val != 1: + tags.append((TAG_PREDICTOR, SHORT, 1, pred_val)) + + # Resolution / DPI tags + if x_resolution is not None: + tags.append((TAG_X_RESOLUTION, RATIONAL, 1, x_resolution)) + if y_resolution is not None: + tags.append((TAG_Y_RESOLUTION, RATIONAL, 1, y_resolution)) + if resolution_unit is not None: + tags.append((TAG_RESOLUTION_UNIT, SHORT, 1, resolution_unit)) + + if tiled: + tags.append((TAG_TILE_WIDTH, SHORT, 1, tile_size)) + tags.append((TAG_TILE_LENGTH, SHORT, 1, tile_size)) + # Placeholder offsets/counts -- will be patched + tags.append((TAG_TILE_OFFSETS, LONG, len(rel_offsets), rel_offsets)) + tags.append((TAG_TILE_BYTE_COUNTS, LONG, len(byte_counts), byte_counts)) + else: + rows_per_strip = 256 + if lh <= rows_per_strip: + rows_per_strip = lh + tags.append((TAG_ROWS_PER_STRIP, SHORT, 1, rows_per_strip)) + tags.append((TAG_STRIP_OFFSETS, LONG, len(rel_offsets), rel_offsets)) + tags.append((TAG_STRIP_BYTE_COUNTS, LONG, len(byte_counts), byte_counts)) + + # Geo tags only on first IFD + if level_idx == 0: + for gtag, gval in geo_tags_dict.items(): + if gtag == TAG_MODEL_PIXEL_SCALE: + tags.append((gtag, DOUBLE, 3, list(gval))) + elif gtag == TAG_MODEL_TIEPOINT: + tags.append((gtag, DOUBLE, 6, list(gval))) + elif gtag == TAG_GEO_KEY_DIRECTORY: + tags.append((gtag, SHORT, len(gval), list(gval))) + elif gtag == TAG_GDAL_NODATA: + tags.append((gtag, ASCII, len(str(gval)) + 1, str(gval))) + + # GDALMetadata XML (tag 42112) + if gdal_metadata_xml is not None: + tags.append((TAG_GDAL_METADATA, ASCII, + len(gdal_metadata_xml) + 1, gdal_metadata_xml)) + + # Extra tags (pass-through from source file) + if extra_tags is not None: + for etag_id, etype_id, ecount, evalue in extra_tags: + # Skip any tag we already wrote to avoid duplicates + existing_ids = {t[0] for t in tags} + if etag_id not in existing_ids: + tags.append((etag_id, etype_id, ecount, evalue)) + + ifd_specs.append(tags) + + # --- Determine if BigTIFF is needed --- + # Classic TIFF uses 32-bit offsets (max ~4.29 GB). Estimate total file + # size including headers, IFDs, overflow data, and all pixel data. + # Switch to BigTIFF if any offset could exceed 2^32. + total_pixel_data = sum(sum(len(c) for c in chunks) + for _, _, _, _, _, chunks in pixel_data_parts) + # Conservative overhead estimate: header + IFDs + overflow + geo tags + num_levels = len(ifd_specs) + max_tags_per_ifd = max(len(tags) for tags in ifd_specs) if ifd_specs else 20 + ifd_overhead = num_levels * (2 + 12 * max_tags_per_ifd + 4 + 1024) # ~1KB overflow per IFD + estimated_file_size = 8 + ifd_overhead + total_pixel_data + + UINT32_MAX = 0xFFFFFFFF # 4,294,967,295 + if force_bigtiff is not None: + bigtiff = force_bigtiff + else: + bigtiff = estimated_file_size > UINT32_MAX + + header_size = 16 if bigtiff else 8 + + if is_cog and len(ifd_specs) > 1: + return _assemble_cog_layout(header_size, ifd_specs, pixel_data_parts, + bigtiff=bigtiff) + else: + return _assemble_standard_layout(header_size, ifd_specs, pixel_data_parts, + bigtiff=bigtiff) + + +def _assemble_standard_layout(header_size: int, + ifd_specs: list, + pixel_data_parts: list, + bigtiff: bool = False) -> bytes: + """Assemble standard TIFF layout (one IFD at a time).""" + output = bytearray() + entry_size = 20 if bigtiff else 12 + + # TIFF header + output.extend(b'II') # little-endian + if bigtiff: + output.extend(struct.pack(f'{BO}H', 43)) # BigTIFF magic + output.extend(struct.pack(f'{BO}H', 8)) # offset size + output.extend(struct.pack(f'{BO}H', 0)) # padding + output.extend(struct.pack(f'{BO}Q', 0)) # first IFD offset placeholder + else: + output.extend(struct.pack(f'{BO}H', 42)) # magic + output.extend(struct.pack(f'{BO}I', 0)) # first IFD offset placeholder + + for level_idx, (tags, (_arr, _lw, _lh, rel_offsets, byte_counts, comp_chunks)) in enumerate( + zip(ifd_specs, pixel_data_parts)): + + ifd_offset = len(output) + + if level_idx == 0: + if bigtiff: + struct.pack_into(f'{BO}Q', output, 8, ifd_offset) + else: + struct.pack_into(f'{BO}I', output, 4, ifd_offset) + + num_entries = len(tags) + count_size = 8 if bigtiff else 2 + next_size = 8 if bigtiff else 4 + ifd_block_size = count_size + entry_size * num_entries + next_size + overflow_base = ifd_offset + ifd_block_size + + ifd_bytes, overflow_bytes = _build_ifd(tags, overflow_base, bigtiff=bigtiff) + + pixel_data_offset = overflow_base + len(overflow_bytes) + + patched_tags = [] + for tag_id, type_id, count, values in tags: + if tag_id in (TAG_STRIP_OFFSETS, TAG_TILE_OFFSETS): + actual_offsets = [pixel_data_offset + ro for ro in rel_offsets] + patched_tags.append((tag_id, type_id, count, actual_offsets)) + else: + patched_tags.append((tag_id, type_id, count, values)) + + ifd_bytes, overflow_bytes = _build_ifd(patched_tags, overflow_base, + bigtiff=bigtiff) + + output.extend(ifd_bytes) + output.extend(overflow_bytes) + # Extend directly from chunk list (no intermediate join copy) + for chunk in comp_chunks: + output.extend(chunk) + + # Patch next IFD pointer if there are more levels + if level_idx < len(ifd_specs) - 1: + next_ifd_offset = len(output) + next_ptr_pos = ifd_offset + count_size + entry_size * num_entries + if bigtiff: + struct.pack_into(f'{BO}Q', output, next_ptr_pos, next_ifd_offset) + else: + struct.pack_into(f'{BO}I', output, next_ptr_pos, next_ifd_offset) + + return bytes(output) + + +def _assemble_cog_layout(header_size: int, + ifd_specs: list, + pixel_data_parts: list, + bigtiff: bool = False) -> bytes: + """Assemble COG layout: all IFDs first, then all pixel data.""" + entry_size = 20 if bigtiff else 12 + count_size = 8 if bigtiff else 2 + next_size = 8 if bigtiff else 4 + + # First pass: compute IFD sizes + ifd_blocks = [] + for tags in ifd_specs: + num_entries = len(tags) + ifd_block_size = count_size + entry_size * num_entries + next_size + _, overflow = _build_ifd(tags, 0, bigtiff=bigtiff) + ifd_blocks.append((ifd_block_size, len(overflow))) + + total_ifd_size = sum(bs + ov for bs, ov in ifd_blocks) + pixel_data_start = header_size + total_ifd_size + + # Second pass: pixel data offsets per level + current_pixel_offset = pixel_data_start + level_pixel_offsets = [] + for _arr, _lw, _lh, rel_offsets, byte_counts, comp_chunks in pixel_data_parts: + level_pixel_offsets.append(current_pixel_offset) + current_pixel_offset += sum(len(c) for c in comp_chunks) + + # Third pass: build IFDs with correct offsets + output = bytearray() + output.extend(b'II') + if bigtiff: + output.extend(struct.pack(f'{BO}H', 43)) + output.extend(struct.pack(f'{BO}H', 8)) + output.extend(struct.pack(f'{BO}H', 0)) + output.extend(struct.pack(f'{BO}Q', header_size)) + else: + output.extend(struct.pack(f'{BO}H', 42)) + output.extend(struct.pack(f'{BO}I', header_size)) + + current_ifd_pos = header_size + for level_idx, (tags, (_arr, _lw, _lh, rel_offsets, byte_counts, comp_chunks)) in enumerate( + zip(ifd_specs, pixel_data_parts)): + + pixel_base = level_pixel_offsets[level_idx] + + patched_tags = [] + for tag_id, type_id, count, values in tags: + if tag_id in (TAG_STRIP_OFFSETS, TAG_TILE_OFFSETS): + actual_offsets = [pixel_base + ro for ro in rel_offsets] + patched_tags.append((tag_id, type_id, count, actual_offsets)) + else: + patched_tags.append((tag_id, type_id, count, values)) + + num_entries = len(patched_tags) + ifd_block_size = count_size + entry_size * num_entries + next_size + overflow_base = current_ifd_pos + ifd_block_size + + ifd_bytes, overflow_bytes = _build_ifd(patched_tags, overflow_base, + bigtiff=bigtiff) + + # Patch next IFD offset + if level_idx < len(ifd_specs) - 1: + next_ifd_pos = current_ifd_pos + ifd_block_size + len(overflow_bytes) + ifd_ba = bytearray(ifd_bytes) + next_ptr_pos = count_size + entry_size * num_entries + if bigtiff: + struct.pack_into(f'{BO}Q', ifd_ba, next_ptr_pos, next_ifd_pos) + else: + struct.pack_into(f'{BO}I', ifd_ba, next_ptr_pos, next_ifd_pos) + ifd_bytes = bytes(ifd_ba) + + output.extend(ifd_bytes) + output.extend(overflow_bytes) + current_ifd_pos = len(output) + + # Append all pixel data + for _arr, _lw, _lh, _rel_offsets, _byte_counts, comp_chunks in pixel_data_parts: + for chunk in comp_chunks: + output.extend(chunk) + + return bytes(output) + + +# --------------------------------------------------------------------------- +# Public write function +# --------------------------------------------------------------------------- + +def write(data: np.ndarray, path: str, *, + geo_transform: GeoTransform | None = None, + crs_epsg: int | None = None, + nodata=None, + compression: str = 'deflate', + tiled: bool = True, + tile_size: int = 256, + predictor: bool = False, + cog: bool = False, + overview_levels: list[int] | None = None, + overview_resampling: str = 'mean', + raster_type: int = 1, + x_resolution: float | None = None, + y_resolution: float | None = None, + resolution_unit: int | None = None, + gdal_metadata_xml: str | None = None, + extra_tags: list | None = None, + bigtiff: bool | None = None) -> None: + """Write a numpy array as a GeoTIFF or COG. + + Parameters + ---------- + data : np.ndarray + 2D array (height x width). + path : str + Output file path. + geo_transform : GeoTransform or None + Pixel-to-coordinate mapping. + crs_epsg : int or None + EPSG code. + nodata : float, int, or None + NoData value. + compression : str + 'none', 'deflate', or 'lzw'. + tiled : bool + Use tiled layout (vs strips). + tile_size : int + Tile width and height. + predictor : bool + Use horizontal differencing predictor. + cog : bool + Write as Cloud Optimized GeoTIFF. + overview_levels : list of int or None + Overview decimation factors (e.g. [2, 4, 8]). + Only used if cog=True. If None and cog=True, auto-generate. + """ + comp_tag = _compression_tag(compression) + + # Build pixel data parts + parts = [] + + # Full resolution + if tiled: + rel_off, bc, comp_data = _write_tiled(data, comp_tag, predictor, tile_size) + else: + rel_off, bc, comp_data = _write_stripped(data, comp_tag, predictor) + + h, w = data.shape[:2] + parts.append((data, w, h, rel_off, bc, comp_data)) + + # Overviews + if cog: + if overview_levels is None: + # Auto-generate: keep halving until < tile_size + overview_levels = [] + oh, ow = h, w + while oh > tile_size and ow > tile_size: + oh //= 2 + ow //= 2 + if oh > 0 and ow > 0: + overview_levels.append(len(overview_levels) + 1) + + current = data + for _ in overview_levels: + current = _make_overview(current, method=overview_resampling) + oh, ow = current.shape[:2] + if tiled: + o_off, o_bc, o_data = _write_tiled(current, comp_tag, predictor, tile_size) + else: + o_off, o_bc, o_data = _write_stripped(current, comp_tag, predictor) + parts.append((current, ow, oh, o_off, o_bc, o_data)) + + file_bytes = _assemble_tiff( + w, h, data.dtype, comp_tag, predictor, tiled, tile_size, + parts, geo_transform, crs_epsg, nodata, is_cog=cog, + raster_type=raster_type, + gdal_metadata_xml=gdal_metadata_xml, + extra_tags=extra_tags, + x_resolution=x_resolution, y_resolution=y_resolution, + resolution_unit=resolution_unit, + force_bigtiff=bigtiff, + ) + + _write_bytes(file_bytes, path) + + # Post-write validation: verify the header is parseable + from ._header import parse_header as _ph + try: + _ph(file_bytes[:16]) + except Exception as e: + import warnings + warnings.warn(f"Written file may be corrupt: {e}", stacklevel=2) + + +def _is_fsspec_uri(path: str) -> bool: + """Check if a path is a fsspec-compatible URI.""" + if path.startswith(('http://', 'https://')): + return False + return '://' in path + + +def _write_bytes(file_bytes: bytes, path: str) -> None: + """Write bytes to a local file (atomic) or cloud storage (via fsspec).""" + import os + + if _is_fsspec_uri(path): + try: + import fsspec + except ImportError: + raise ImportError( + "fsspec is required to write to cloud storage. " + "Install it with: pip install fsspec") + fs, fspath = fsspec.core.url_to_fs(path) + with fs.open(fspath, 'wb') as f: + f.write(file_bytes) + return + + # Local file: write to temp file then atomically rename + import tempfile + dir_name = os.path.dirname(os.path.abspath(path)) + fd, tmp_path = tempfile.mkstemp(dir=dir_name, suffix='.tif.tmp') + try: + with os.fdopen(fd, 'wb') as f: + f.write(file_bytes) + os.replace(tmp_path, path) # atomic on POSIX + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise diff --git a/xrspatial/geotiff/tests/__init__.py b/xrspatial/geotiff/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/xrspatial/geotiff/tests/bench_vs_rioxarray.py b/xrspatial/geotiff/tests/bench_vs_rioxarray.py new file mode 100644 index 00000000..82abe85b --- /dev/null +++ b/xrspatial/geotiff/tests/bench_vs_rioxarray.py @@ -0,0 +1,318 @@ +"""Benchmark xrspatial.geotiff vs rioxarray for read/write performance and consistency.""" +from __future__ import annotations + +import os +import tempfile +import time + +import numpy as np +import xarray as xr + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _timer(fn, warmup=1, runs=5): + """Time a callable, returning (median_seconds, result_from_last_call).""" + for _ in range(warmup): + result = fn() + times = [] + for _ in range(runs): + t0 = time.perf_counter() + result = fn() + times.append(time.perf_counter() - t0) + times.sort() + return times[len(times) // 2], result + + +def _fmt_ms(seconds): + return f"{seconds * 1000:.1f} ms" + + +# --------------------------------------------------------------------------- +# Consistency check +# --------------------------------------------------------------------------- + +def check_consistency(path): + """Compare pixel values and geo metadata between the two readers.""" + import rioxarray # noqa: F401 + from xrspatial.geotiff import read_geotiff + + rio_da = xr.open_dataarray(path, engine='rasterio') + rio_arr = rio_da.squeeze('band').values.astype(np.float64) + + our_da = read_geotiff(path) + our_arr = our_da.values.astype(np.float64) + + # Shape + assert rio_arr.shape == our_arr.shape, ( + f"Shape mismatch: rioxarray {rio_arr.shape} vs ours {our_arr.shape}") + + # Pixel values (count NaN agreement as exact match) + rio_nan = np.isnan(rio_arr) + our_nan = np.isnan(our_arr) + both_nan = rio_nan & our_nan + valid = ~(rio_nan | our_nan) + diff = np.zeros_like(rio_arr) + diff[valid] = np.abs(rio_arr[valid] - our_arr[valid]) + max_diff = float(diff[valid].max()) if valid.any() else 0.0 + mean_diff = float(diff[valid].mean()) if valid.any() else 0.0 + # Exact = same value on valid pixels + both NaN on NaN pixels + exact_count = int(np.sum(diff[valid] == 0)) + int(both_nan.sum()) + pct_exact = exact_count / diff.size * 100 + + # CRS + rio_epsg = rio_da.rio.crs.to_epsg() if rio_da.rio.crs else None + our_epsg = our_da.attrs.get('crs') + + # Coordinate comparison + rio_y = rio_da.coords['y'].values + rio_x = rio_da.coords['x'].values + our_y = our_da.coords['y'].values + our_x = our_da.coords['x'].values + + y_max_diff = float(np.max(np.abs(rio_y - our_y))) if len(rio_y) == len(our_y) else float('inf') + x_max_diff = float(np.max(np.abs(rio_x - our_x))) if len(rio_x) == len(our_x) else float('inf') + + return { + 'shape': rio_arr.shape, + 'dtype_rio': str(rio_da.dtype), + 'dtype_ours': str(our_da.dtype), + 'max_pixel_diff': max_diff, + 'mean_pixel_diff': mean_diff, + 'pct_exact_match': pct_exact, + 'epsg_rio': rio_epsg, + 'epsg_ours': our_epsg, + 'epsg_match': rio_epsg == our_epsg, + 'y_max_diff': y_max_diff, + 'x_max_diff': x_max_diff, + } + + +# --------------------------------------------------------------------------- +# Read benchmark +# --------------------------------------------------------------------------- + +def bench_read(path, runs=10): + """Benchmark read performance.""" + import rioxarray # noqa: F401 + from xrspatial.geotiff import read_geotiff + + def rio_read(): + da = xr.open_dataarray(path, engine='rasterio') + _ = da.values # force load + da.close() + return da + + def our_read(): + return read_geotiff(path) + + rio_time, _ = _timer(rio_read, warmup=2, runs=runs) + our_time, _ = _timer(our_read, warmup=2, runs=runs) + + return rio_time, our_time + + +# --------------------------------------------------------------------------- +# Write benchmark +# --------------------------------------------------------------------------- + +def bench_write(shape=(512, 512), compression='deflate', runs=5): + """Benchmark write performance.""" + import rioxarray # noqa: F401 + from xrspatial.geotiff import write_geotiff + from xrspatial.geotiff._geotags import GeoTransform + + rng = np.random.RandomState(42) + arr = rng.rand(*shape).astype(np.float32) + + y = np.linspace(45.0, 44.0, shape[0]) + x = np.linspace(-120.0, -119.0, shape[1]) + da = xr.DataArray(arr, dims=['y', 'x'], coords={'y': y, 'x': x}) + da = da.rio.write_crs(4326) + da = da.rio.write_nodata(np.nan) + + da_ours = xr.DataArray(arr, dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + + tmpdir = tempfile.mkdtemp() + + comp_map = {'deflate': 'deflate', 'lzw': 'lzw', 'none': None} + rio_comp = comp_map.get(compression, compression) + + def rio_write(): + p = os.path.join(tmpdir, 'rio_out.tif') + if rio_comp: + da.rio.to_raster(p, compress=rio_comp.upper()) + else: + da.rio.to_raster(p) + return os.path.getsize(p) + + def our_write(): + p = os.path.join(tmpdir, 'our_out.tif') + write_geotiff(da_ours, p, compression=compression, tiled=False) + return os.path.getsize(p) + + rio_time, rio_size = _timer(rio_write, warmup=1, runs=runs) + our_time, our_size = _timer(our_write, warmup=1, runs=runs) + + return rio_time, our_time, rio_size, our_size + + +# --------------------------------------------------------------------------- +# Write + read-back consistency +# --------------------------------------------------------------------------- + +def bench_round_trip(shape=(256, 256), compression='deflate'): + """Write with our module, read back with rioxarray, and vice versa.""" + import rioxarray # noqa: F401 + from xrspatial.geotiff import read_geotiff, write_geotiff + + rng = np.random.RandomState(99) + arr = rng.rand(*shape).astype(np.float32) + y = np.linspace(45.0, 44.0, shape[0]) + x = np.linspace(-120.0, -119.0, shape[1]) + + tmpdir = tempfile.mkdtemp() + + # Ours write -> rioxarray read + our_path = os.path.join(tmpdir, 'ours.tif') + da_ours = xr.DataArray(arr, dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'crs': 4326}) + write_geotiff(da_ours, our_path, compression=compression, tiled=False) + + rio_da = xr.open_dataarray(our_path, engine='rasterio') + rio_arr = rio_da.squeeze('band').values if 'band' in rio_da.dims else rio_da.values + rio_da.close() + + diff1 = float(np.nanmax(np.abs(arr - rio_arr))) + + # rioxarray write -> ours read + rio_path = os.path.join(tmpdir, 'rio.tif') + da_rio = xr.DataArray(arr, dims=['y', 'x'], coords={'y': y, 'x': x}) + da_rio = da_rio.rio.write_crs(4326) + comp_map = {'deflate': 'DEFLATE', 'lzw': 'LZW', 'none': None} + rio_comp = comp_map.get(compression) + if rio_comp: + da_rio.rio.to_raster(rio_path, compress=rio_comp) + else: + da_rio.rio.to_raster(rio_path) + + our_da = read_geotiff(rio_path) + our_arr = our_da.values + + diff2 = float(np.nanmax(np.abs(arr - our_arr))) + + return diff1, diff2 + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + landsat_dir = 'docs/source/user_guide/data' + bands = [ + 'LC80030172015001LGN00_B2.tiff', + 'LC80030172015001LGN00_B3.tiff', + 'LC80030172015001LGN00_B4.tiff', + 'LC80030172015001LGN00_B5.tiff', + ] + + print("=" * 72) + print("xrspatial.geotiff vs rioxarray -- Benchmark & Consistency") + print("=" * 72) + + # --- Consistency on real Landsat files --- + print("\n--- Pixel & Metadata Consistency (Landsat 8 bands) ---\n") + for band_file in bands: + path = os.path.join(landsat_dir, band_file) + if not os.path.exists(path): + print(f" {band_file}: SKIPPED (not found)") + continue + c = check_consistency(path) + name = band_file.split('_')[1].replace('.tiff', '') + print(f" {name}: shape={c['shape']} dtype rio={c['dtype_rio']} ours={c['dtype_ours']}") + print(f" pixels: max_diff={c['max_pixel_diff']:.6f} " + f"mean_diff={c['mean_pixel_diff']:.6f} exact={c['pct_exact_match']:.1f}%") + print(f" EPSG: rio={c['epsg_rio']} ours={c['epsg_ours']} match={c['epsg_match']}") + print(f" coords: y_max_diff={c['y_max_diff']:.6f} x_max_diff={c['x_max_diff']:.6f}") + + # --- Read performance --- + print("\n--- Read Performance (median of 10 runs) ---\n") + print(f" {'File':<8} {'rioxarray':>12} {'xrspatial':>12} {'ratio':>8}") + print(f" {'-'*8} {'-'*12} {'-'*12} {'-'*8}") + for band_file in bands: + path = os.path.join(landsat_dir, band_file) + if not os.path.exists(path): + continue + rio_t, our_t = bench_read(path, runs=10) + name = band_file.split('_')[1].replace('.tiff', '') + ratio = our_t / rio_t if rio_t > 0 else float('inf') + print(f" {name:<8} {_fmt_ms(rio_t):>12} {_fmt_ms(our_t):>12} {ratio:>7.2f}x") + + # --- Write performance --- + print("\n--- Write Performance (512x512 float32, median of 5 runs) ---\n") + print(f" {'Compression':<12} {'rioxarray':>12} {'xrspatial':>12} {'ratio':>8} {'size rio':>10} {'size ours':>10}") + print(f" {'-'*12} {'-'*12} {'-'*12} {'-'*8} {'-'*10} {'-'*10}") + for comp in ['none', 'deflate', 'lzw']: + rio_t, our_t, rio_sz, our_sz = bench_write((512, 512), comp, runs=5) + ratio = our_t / rio_t if rio_t > 0 else float('inf') + print(f" {comp:<12} {_fmt_ms(rio_t):>12} {_fmt_ms(our_t):>12} {ratio:>7.2f}x " + f"{rio_sz:>9,} {our_sz:>9,}") + + # --- Write performance (larger) --- + print("\n--- Write Performance (2048x2048 float32, median of 3 runs) ---\n") + print(f" {'Compression':<12} {'rioxarray':>12} {'xrspatial':>12} {'ratio':>8} {'size rio':>10} {'size ours':>10}") + print(f" {'-'*12} {'-'*12} {'-'*12} {'-'*8} {'-'*10} {'-'*10}") + for comp in ['none', 'deflate']: + rio_t, our_t, rio_sz, our_sz = bench_write((2048, 2048), comp, runs=3) + ratio = our_t / rio_t if rio_t > 0 else float('inf') + print(f" {comp:<12} {_fmt_ms(rio_t):>12} {_fmt_ms(our_t):>12} {ratio:>7.2f}x " + f"{rio_sz:>9,} {our_sz:>9,}") + + # --- Cross-library round-trip --- + print("\n--- Cross-Library Round-Trip Consistency ---\n") + for comp in ['none', 'deflate']: + d1, d2 = bench_round_trip((256, 256), comp) + print(f" {comp}: ours->rioxarray max_diff={d1:.8f} rioxarray->ours max_diff={d2:.8f}") + + # --- Real-world files from rtxpy --- + rtxpy_dir = '../rtxpy/examples' + rtxpy_files = [ + ('render_demo_terrain.tif', 'uncompressed strip'), + ('Copernicus_DSM_COG_10_N40_00_W075_00_DEM.tif', 'deflate+fpred COG'), + ('Copernicus_DSM_COG_10_S23_00_W044_00_DEM.tif', 'deflate+fpred COG'), + ('USGS_1_n43w122.tif', 'LZW+fpred COG'), + ('USGS_1_n39w106.tif', 'LZW+fpred COG'), + ('USGS_one_meter_x65y454_NY_LongIsland_Z18_2014.tif', 'LZW tiled COG'), + ] + + print("\n--- Real-World Files: Consistency & Read Performance ---\n") + print(f" {'File':<52} {'Format':<20} {'Shape':>12} {'Exact%':>7} {'rio':>9} {'ours':>9} {'ratio':>7}") + print(f" {'-'*52} {'-'*20} {'-'*12} {'-'*7} {'-'*9} {'-'*9} {'-'*7}") + + for fname, desc in rtxpy_files: + path = os.path.join(rtxpy_dir, fname) + if not os.path.exists(path): + continue + + # Consistency + c = check_consistency(path) + + # Performance (fewer runs for large files) + fsize = os.path.getsize(path) + runs = 3 if fsize > 50_000_000 else 5 + rio_t, our_t = bench_read(path, runs=runs) + ratio = our_t / rio_t if rio_t > 0 else float('inf') + + shape_str = f"{c['shape'][0]}x{c['shape'][1]}" + short_name = fname[:50] + print(f" {short_name:<52} {desc:<20} {shape_str:>12} {c['pct_exact_match']:>6.1f}% " + f"{_fmt_ms(rio_t):>9} {_fmt_ms(our_t):>9} {ratio:>6.2f}x") + + print() + + +if __name__ == '__main__': + main() diff --git a/xrspatial/geotiff/tests/conftest.py b/xrspatial/geotiff/tests/conftest.py new file mode 100644 index 00000000..b90e96f3 --- /dev/null +++ b/xrspatial/geotiff/tests/conftest.py @@ -0,0 +1,269 @@ +"""Shared fixtures for geotiff tests.""" +from __future__ import annotations + +import math +import struct + +import numpy as np +import pytest + + +def make_minimal_tiff( + width: int = 4, + height: int = 4, + dtype: np.dtype = np.dtype('float32'), + pixel_data: np.ndarray | None = None, + compression: int = 1, + tiled: bool = False, + tile_size: int = 4, + big_endian: bool = False, + bigtiff: bool = False, + geo_transform: tuple | None = None, + epsg: int | None = None, +) -> bytes: + """Build a minimal valid TIFF file in memory for testing. + + Uses a three-pass approach: + 1. Collect all tags and their raw value data + 2. Compute file layout (IFD size, overflow positions, pixel data offset) + 3. Serialize everything with correct offsets + """ + bo = '>' if big_endian else '<' + bom = b'MM' if big_endian else b'II' + + if pixel_data is None: + pixel_data = np.arange(width * height, dtype=dtype).reshape(height, width) + else: + dtype = pixel_data.dtype + + bits_per_sample = dtype.itemsize * 8 + if dtype.kind == 'f': + sample_format = 3 + elif dtype.kind == 'i': + sample_format = 2 + else: + sample_format = 1 + + # --- Build pixel data (strips or tiles) --- + if tiled: + tiles_across = math.ceil(width / tile_size) + tiles_down = math.ceil(height / tile_size) + num_tiles = tiles_across * tiles_down + + tile_blobs = [] + for tr in range(tiles_down): + for tc in range(tiles_across): + tile = np.zeros((tile_size, tile_size), dtype=dtype) + r0, c0 = tr * tile_size, tc * tile_size + r1 = min(r0 + tile_size, height) + c1 = min(c0 + tile_size, width) + tile[:r1 - r0, :c1 - c0] = pixel_data[r0:r1, c0:c1] + tile_blobs.append(tile.tobytes()) + + pixel_bytes = b''.join(tile_blobs) + tile_byte_counts = [len(b) for b in tile_blobs] + else: + if big_endian and pixel_data.dtype.itemsize > 1: + pixel_bytes = pixel_data.astype(pixel_data.dtype.newbyteorder('>')).tobytes() + else: + pixel_bytes = pixel_data.tobytes() + + # --- Collect tags as (tag_id, type_id, value_bytes) --- + # value_bytes is the serialized value; if len <= 4 it's inline, else overflow. + tag_list: list[tuple[int, int, int, bytes]] = [] # (tag, type, count, raw_bytes) + + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + + def add_shorts(tag, vals): + tag_list.append((tag, 3, len(vals), struct.pack(f'{bo}{len(vals)}H', *vals))) + + def add_longs(tag, vals): + tag_list.append((tag, 4, len(vals), struct.pack(f'{bo}{len(vals)}I', *vals))) + + def add_doubles(tag, vals): + tag_list.append((tag, 12, len(vals), struct.pack(f'{bo}{len(vals)}d', *vals))) + + add_short(256, width) # ImageWidth + add_short(257, height) # ImageLength + add_short(258, bits_per_sample) # BitsPerSample + add_short(259, compression) # Compression + add_short(262, 1) # PhotometricInterpretation + add_short(277, 1) # SamplesPerPixel + add_short(339, sample_format) # SampleFormat + + if tiled: + add_short(322, tile_size) # TileWidth + add_short(323, tile_size) # TileLength + # Placeholder offsets -- will be patched after layout is known + add_longs(324, [0] * num_tiles) # TileOffsets + add_longs(325, tile_byte_counts) # TileByteCounts + else: + add_short(278, height) # RowsPerStrip + add_long(273, 0) # StripOffsets (placeholder) + add_long(279, len(pixel_bytes)) # StripByteCounts + + if geo_transform is not None: + ox, oy, pw, ph = geo_transform + add_doubles(33550, [abs(pw), abs(ph), 0.0]) # ModelPixelScale + add_doubles(33922, [0.0, 0.0, 0.0, ox, oy, 0.0]) # ModelTiepoint + + if epsg is not None: + if epsg == 4326 or (4000 <= epsg < 5000): + model_type, key_id = 2, 2048 + else: + model_type, key_id = 1, 3072 + gkd = [1, 1, 0, 2, 1024, 0, 1, model_type, key_id, 0, 1, epsg] + add_shorts(34735, gkd) + + # Sort by tag ID (TIFF spec requirement) + tag_list.sort(key=lambda t: t[0]) + + # --- Compute layout --- + num_entries = len(tag_list) + ifd_start = 8 # right after header + ifd_size = 2 + 12 * num_entries + 4 # count + entries + next_ifd_offset + overflow_start = ifd_start + ifd_size + + # Figure out which tags need overflow (value > 4 bytes) + overflow_buf = bytearray() + for _tag, _type, _count, raw in tag_list: + if len(raw) > 4: + # This will go to overflow -- just accumulate size for now + overflow_buf.extend(raw) + # Word-align + if len(overflow_buf) % 2: + overflow_buf.append(0) + + pixel_data_start = overflow_start + len(overflow_buf) + + # --- Patch offset tags --- + # Now we know where pixel data starts, patch strip/tile offsets + patched = [] + for tag, typ, count, raw in tag_list: + if tag == 273: # StripOffsets + patched.append((tag, typ, count, struct.pack(f'{bo}I', pixel_data_start))) + elif tag == 324: # TileOffsets + offsets = [] + pos = 0 + for blob in tile_blobs: + offsets.append(pixel_data_start + pos) + pos += len(blob) + patched.append((tag, typ, count, struct.pack(f'{bo}{num_tiles}I', *offsets))) + else: + patched.append((tag, typ, count, raw)) + tag_list = patched + + # --- Rebuild overflow with final values --- + overflow_buf = bytearray() + tag_offsets = {} # tag -> offset within overflow_buf (or None if inline) + + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + # Recalculate in case overflow size changed from patching + actual_pixel_start = overflow_start + len(overflow_buf) + if actual_pixel_start != pixel_data_start: + # Need another pass to fix offsets + pixel_data_start = actual_pixel_start + patched2 = [] + for tag, typ, count, raw in tag_list: + if tag == 273: + patched2.append((tag, typ, count, struct.pack(f'{bo}I', pixel_data_start))) + elif tag == 324: + offsets = [] + pos = 0 + for blob in tile_blobs: + offsets.append(pixel_data_start + pos) + pos += len(blob) + patched2.append((tag, typ, count, struct.pack(f'{bo}{num_tiles}I', *offsets))) + else: + patched2.append((tag, typ, count, raw)) + tag_list = patched2 + + # Rebuild overflow again + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + # --- Serialize --- + out = bytearray() + + # Header + out.extend(bom) + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + + # IFD + out.extend(struct.pack(f'{bo}H', num_entries)) + + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + if len(raw) <= 4: + # Inline value, padded to 4 bytes + out.extend(raw.ljust(4, b'\x00')) + else: + # Pointer to overflow + ptr = overflow_start + tag_offsets[tag] + out.extend(struct.pack(f'{bo}I', ptr)) + + # Next IFD offset + out.extend(struct.pack(f'{bo}I', 0)) + + # Overflow + out.extend(overflow_buf) + + # Pixel data + out.extend(pixel_bytes) + + return bytes(out) + + +@pytest.fixture +def simple_float32_tiff(): + """4x4 float32 stripped TIFF with sequential values.""" + return make_minimal_tiff(4, 4, np.dtype('float32')) + + +@pytest.fixture +def simple_uint16_tiff(): + """4x4 uint16 stripped TIFF.""" + return make_minimal_tiff(4, 4, np.dtype('uint16')) + + +@pytest.fixture +def geo_tiff_data(): + """4x4 float32 TIFF with geo transform and EPSG 4326.""" + return make_minimal_tiff( + 4, 4, np.dtype('float32'), + geo_transform=(-120.0, 45.0, 0.001, -0.001), + epsg=4326, + ) + + +@pytest.fixture +def tiled_tiff_data(): + """8x8 float32 tiled TIFF with 4x4 tiles.""" + data = np.arange(64, dtype=np.float32).reshape(8, 8) + return make_minimal_tiff( + 8, 8, np.dtype('float32'), + pixel_data=data, + tiled=True, + tile_size=4, + ) diff --git a/xrspatial/geotiff/tests/test_cog.py b/xrspatial/geotiff/tests/test_cog.py new file mode 100644 index 00000000..40b24808 --- /dev/null +++ b/xrspatial/geotiff/tests/test_cog.py @@ -0,0 +1,137 @@ +"""Tests for COG writing and the public API.""" +from __future__ import annotations + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import read_geotiff, write_geotiff +from xrspatial.geotiff._header import parse_header, parse_all_ifds +from xrspatial.geotiff._writer import write +from xrspatial.geotiff._geotags import GeoTransform, extract_geo_info + + +class TestCOGWriter: + def test_cog_layout_ifds_before_data(self, tmp_path): + """COG spec: all IFDs should come before pixel data.""" + arr = np.arange(256, dtype=np.float32).reshape(16, 16) + path = str(tmp_path / 'cog.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=8, + cog=True, overview_levels=[1]) + + with open(path, 'rb') as f: + data = f.read() + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + + assert len(ifds) >= 2 # full res + at least 1 overview + + # All IFD offsets should be < the first tile data offset + all_tile_offsets = [] + for ifd in ifds: + tile_off = ifd.tile_offsets + if tile_off: + all_tile_offsets.extend(tile_off) + + if all_tile_offsets: + first_data_offset = min(all_tile_offsets) + # The last IFD byte should be before the first tile data + # (This is the COG layout requirement) + assert header.first_ifd_offset < first_data_offset + + def test_cog_round_trip(self, tmp_path): + arr = np.arange(256, dtype=np.float32).reshape(16, 16) + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + path = str(tmp_path / 'cog_rt.tif') + write(arr, path, geo_transform=gt, crs_epsg=4326, + compression='deflate', tiled=True, tile_size=8, + cog=True, overview_levels=[1]) + + result, geo = read_to_array_local(path) + np.testing.assert_array_equal(result, arr) + assert geo.crs_epsg == 4326 + + def test_cog_auto_overviews(self, tmp_path): + """Auto-generate overviews when none specified.""" + arr = np.arange(1024, dtype=np.float32).reshape(32, 32) + path = str(tmp_path / 'cog_auto.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=8, + cog=True) + + with open(path, 'rb') as f: + data = f.read() + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + # Should have at least 2 IFDs (full res + overviews) + assert len(ifds) >= 2 + + +class TestPublicAPI: + def test_read_write_round_trip(self, tmp_path): + """Write a DataArray, read it back, verify values and coords.""" + y = np.linspace(45.0, 44.0, 10) + x = np.linspace(-120.0, -119.0, 12) + data = np.random.RandomState(42).rand(10, 12).astype(np.float32) + + da = xr.DataArray( + data, dims=['y', 'x'], + coords={'y': y, 'x': x}, + attrs={'crs': 4326}, + name='test', + ) + + path = str(tmp_path / 'round_trip.tif') + write_geotiff(da, path, compression='deflate', tiled=False) + + result = read_geotiff(path) + np.testing.assert_array_almost_equal(result.values, data, decimal=5) + assert result.attrs.get('crs') == 4326 + + def test_read_geotiff_name(self, tmp_path): + """DataArray name defaults to filename stem.""" + arr = np.zeros((4, 4), dtype=np.float32) + path = str(tmp_path / 'myfile.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert da.name == 'myfile' + + def test_read_geotiff_custom_name(self, tmp_path): + arr = np.zeros((4, 4), dtype=np.float32) + path = str(tmp_path / 'test.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path, name='custom') + assert da.name == 'custom' + + def test_write_numpy_array(self, tmp_path): + """write_geotiff should accept raw numpy arrays too.""" + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + path = str(tmp_path / 'numpy.tif') + write_geotiff(arr, path, compression='none') + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_write_3d_rgb(self, tmp_path): + """3D arrays (height, width, bands) should write multi-band.""" + arr = np.zeros((4, 4, 3), dtype=np.uint8) + arr[:, :, 0] = 255 # red channel + path = str(tmp_path / 'rgb.tif') + write_geotiff(arr, path, compression='none') + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_write_rejects_4d(self, tmp_path): + arr = np.zeros((2, 3, 4, 4), dtype=np.float32) + with pytest.raises(ValueError, match="Expected 2D or 3D"): + write_geotiff(arr, str(tmp_path / 'bad.tif')) + + +def read_to_array_local(path): + """Helper to call read_to_array for local files.""" + from xrspatial.geotiff._reader import read_to_array + return read_to_array(path) diff --git a/xrspatial/geotiff/tests/test_compression.py b/xrspatial/geotiff/tests/test_compression.py new file mode 100644 index 00000000..a296ab88 --- /dev/null +++ b/xrspatial/geotiff/tests/test_compression.py @@ -0,0 +1,129 @@ +"""Tests for compression codecs.""" +from __future__ import annotations + +import zlib + +import numpy as np +import pytest + +from xrspatial.geotiff._compression import ( + COMPRESSION_DEFLATE, + COMPRESSION_LZW, + COMPRESSION_NONE, + compress, + decompress, + deflate_compress, + deflate_decompress, + lzw_compress, + lzw_decompress, + predictor_decode, + predictor_encode, +) + + +class TestDeflate: + def test_round_trip(self): + data = b'hello world! ' * 100 + compressed = deflate_compress(data) + assert compressed != data + assert deflate_decompress(compressed) == data + + def test_empty(self): + compressed = deflate_compress(b'') + assert deflate_decompress(compressed) == b'' + + def test_binary_data(self): + data = bytes(range(256)) * 10 + compressed = deflate_compress(data) + assert deflate_decompress(compressed) == data + + +class TestLZW: + def test_round_trip_simple(self): + data = b'ABCABCABCABC' + compressed = lzw_compress(data) + decompressed = lzw_decompress(compressed, len(data)) + assert decompressed.tobytes() == data + + def test_round_trip_repetitive(self): + data = b'\x00' * 1000 + compressed = lzw_compress(data) + decompressed = lzw_decompress(compressed, len(data)) + assert decompressed.tobytes() == data + + def test_round_trip_sequential(self): + data = bytes(range(256)) + compressed = lzw_compress(data) + decompressed = lzw_decompress(compressed, len(data)) + assert decompressed.tobytes() == data + + def test_round_trip_random(self): + rng = np.random.RandomState(42) + data = bytes(rng.randint(0, 256, size=500, dtype=np.uint8)) + compressed = lzw_compress(data) + decompressed = lzw_decompress(compressed, len(data)) + assert decompressed.tobytes() == data + + def test_round_trip_large(self): + rng = np.random.RandomState(123) + data = bytes(rng.randint(0, 256, size=10000, dtype=np.uint8)) + compressed = lzw_compress(data) + decompressed = lzw_decompress(compressed, len(data)) + assert decompressed.tobytes() == data + + def test_empty(self): + compressed = lzw_compress(b'') + decompressed = lzw_decompress(compressed, 0) + assert decompressed.tobytes() == b'' + + +class TestPredictor: + def test_round_trip_uint8(self): + # 4x4 image, 1 byte per sample + data = np.array([10, 20, 30, 40, 50, 60, 70, 80, + 90, 100, 110, 120, 130, 140, 150, 160], + dtype=np.uint8) + encoded = predictor_encode(data.copy(), 4, 4, 1) + decoded = predictor_decode(encoded.copy(), 4, 4, 1) + np.testing.assert_array_equal(decoded, data) + + def test_round_trip_float32(self): + # 2x3 image, 4 bytes per sample + arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=np.float32) + raw = np.frombuffer(arr.tobytes(), dtype=np.uint8).copy() + encoded = predictor_encode(raw.copy(), 3, 2, 4) + decoded = predictor_decode(encoded.copy(), 3, 2, 4) + np.testing.assert_array_equal(decoded, raw) + + def test_predictor_encode_differences(self): + # First pixel unchanged, rest are differences + data = np.array([10, 20, 30, 40], dtype=np.uint8) + encoded = predictor_encode(data.copy(), 4, 1, 1) + assert encoded[0] == 10 + assert encoded[1] == 10 # 20 - 10 + assert encoded[2] == 10 # 30 - 20 + assert encoded[3] == 10 # 40 - 30 + + +class TestDispatch: + def test_none(self): + data = b'hello' + assert decompress(data, COMPRESSION_NONE).tobytes() == data + assert compress(data, COMPRESSION_NONE) == data + + def test_deflate(self): + data = b'test data ' * 50 + compressed = compress(data, COMPRESSION_DEFLATE) + assert decompress(compressed, COMPRESSION_DEFLATE).tobytes() == data + + def test_lzw(self): + data = b'ABCABC' * 20 + compressed = compress(data, COMPRESSION_LZW) + decompressed = decompress(compressed, COMPRESSION_LZW, len(data)) + assert decompressed.tobytes() == data + + def test_unsupported(self): + with pytest.raises(ValueError, match="Unsupported compression"): + decompress(b'', 99) + with pytest.raises(ValueError, match="Unsupported compression"): + compress(b'', 99) diff --git a/xrspatial/geotiff/tests/test_edge_cases.py b/xrspatial/geotiff/tests/test_edge_cases.py new file mode 100644 index 00000000..1a8a8680 --- /dev/null +++ b/xrspatial/geotiff/tests/test_edge_cases.py @@ -0,0 +1,660 @@ +"""Edge case tests for invalid, corrupt, and boundary-condition inputs.""" +from __future__ import annotations + +import struct +import zlib + +import numpy as np +import pytest +import xarray as xr + +from xrspatial.geotiff import read_geotiff, write_geotiff +from xrspatial.geotiff._compression import ( + COMPRESSION_DEFLATE, + COMPRESSION_LZW, + COMPRESSION_NONE, + compress, + decompress, + deflate_decompress, + lzw_compress, + lzw_decompress, +) +from xrspatial.geotiff._dtypes import numpy_to_tiff_dtype, tiff_dtype_to_numpy +from xrspatial.geotiff._header import parse_all_ifds, parse_header +from xrspatial.geotiff._reader import read_to_array +from xrspatial.geotiff._writer import write + + +# ----------------------------------------------------------------------- +# Writer: invalid inputs +# ----------------------------------------------------------------------- + +class TestWriteInvalidInputs: + """Writer should reject or gracefully handle bad inputs.""" + + def test_4d_array(self, tmp_path): + arr = np.zeros((2, 3, 4, 4), dtype=np.float32) + with pytest.raises(ValueError, match="Expected 2D or 3D"): + write_geotiff(arr, str(tmp_path / 'bad.tif')) + + def test_1d_array(self, tmp_path): + arr = np.zeros(10, dtype=np.float32) + with pytest.raises(ValueError, match="Expected 2D"): + write_geotiff(arr, str(tmp_path / 'bad.tif')) + + def test_0d_scalar(self, tmp_path): + arr = np.float32(42.0) + with pytest.raises(ValueError, match="Expected 2D"): + write_geotiff(arr, str(tmp_path / 'bad.tif')) + + def test_unsupported_compression(self, tmp_path): + arr = np.zeros((4, 4), dtype=np.float32) + with pytest.raises(ValueError, match="Unsupported compression"): + write_geotiff(arr, str(tmp_path / 'bad.tif'), compression='jpeg2000') + + def test_complex_dtype(self, tmp_path): + arr = np.zeros((4, 4), dtype=np.complex64) + with pytest.raises(ValueError, match="Unsupported numpy dtype"): + write_geotiff(arr, str(tmp_path / 'bad.tif')) + + def test_bool_dtype_auto_promoted(self, tmp_path): + """Bool arrays are auto-promoted to uint8.""" + arr = np.array([[True, False], [False, True]]) + path = str(tmp_path / 'bool.tif') + write_geotiff(arr, path, compression='none') + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, arr.astype(np.uint8)) + + +# ----------------------------------------------------------------------- +# Writer: boundary-condition data values +# ----------------------------------------------------------------------- + +class TestWriteSpecialValues: + """Writer should handle NaN, Inf, and extreme values.""" + + def test_all_nan(self, tmp_path): + arr = np.full((4, 4), np.nan, dtype=np.float32) + path = str(tmp_path / 'all_nan.tif') + write(arr, path, compression='none', tiled=False) + + result, _ = read_to_array(path) + assert np.all(np.isnan(result)) + + def test_nan_and_inf(self, tmp_path): + arr = np.array([[np.nan, np.inf], [-np.inf, 0.0]], dtype=np.float32) + path = str(tmp_path / 'nan_inf.tif') + write(arr, path, compression='none', tiled=False) + + result, _ = read_to_array(path) + assert np.isnan(result[0, 0]) + assert np.isposinf(result[0, 1]) + assert np.isneginf(result[1, 0]) + assert result[1, 1] == 0.0 + + def test_nan_with_deflate(self, tmp_path): + arr = np.array([[np.nan, 1.0], [2.0, np.nan]], dtype=np.float32) + path = str(tmp_path / 'nan_deflate.tif') + write(arr, path, compression='deflate', tiled=False) + + result, _ = read_to_array(path) + assert np.isnan(result[0, 0]) + assert np.isnan(result[1, 1]) + assert result[0, 1] == 1.0 + assert result[1, 0] == 2.0 + + def test_nan_with_lzw(self, tmp_path): + arr = np.array([[np.nan, 1.0], [2.0, np.nan]], dtype=np.float32) + path = str(tmp_path / 'nan_lzw.tif') + write(arr, path, compression='lzw', tiled=False) + + result, _ = read_to_array(path) + assert np.isnan(result[0, 0]) + assert np.isnan(result[1, 1]) + + def test_float32_extremes(self, tmp_path): + finfo = np.finfo(np.float32) + arr = np.array([[finfo.max, finfo.min], + [finfo.tiny, -finfo.tiny]], dtype=np.float32) + path = str(tmp_path / 'extremes.tif') + write(arr, path, compression='deflate', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_uint16_full_range(self, tmp_path): + arr = np.array([[0, 65535], [1, 65534]], dtype=np.uint16) + path = str(tmp_path / 'uint16_range.tif') + write(arr, path, compression='none', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_int16_negative(self, tmp_path): + arr = np.array([[-32768, 32767], [-1, 0]], dtype=np.int16) + path = str(tmp_path / 'int16.tif') + write(arr, path, compression='none', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_all_zeros(self, tmp_path): + arr = np.zeros((8, 8), dtype=np.float32) + path = str(tmp_path / 'zeros.tif') + write(arr, path, compression='deflate', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_all_same_value(self, tmp_path): + arr = np.full((16, 16), 42.5, dtype=np.float32) + path = str(tmp_path / 'constant.tif') + write(arr, path, compression='lzw', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + +# ----------------------------------------------------------------------- +# Writer: boundary-condition shapes +# ----------------------------------------------------------------------- + +class TestWriteBoundaryShapes: + """Test extreme and non-aligned image dimensions.""" + + def test_single_pixel(self, tmp_path): + arr = np.array([[42.0]], dtype=np.float32) + path = str(tmp_path / '1x1.tif') + write(arr, path, compression='none', tiled=False) + + result, _ = read_to_array(path) + assert result.shape == (1, 1) + assert result[0, 0] == 42.0 + + def test_single_row(self, tmp_path): + arr = np.arange(10, dtype=np.float32).reshape(1, 10) + path = str(tmp_path / '1x10.tif') + write(arr, path, compression='deflate', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_single_column(self, tmp_path): + arr = np.arange(10, dtype=np.float32).reshape(10, 1) + path = str(tmp_path / '10x1.tif') + write(arr, path, compression='deflate', tiled=False) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_non_tile_aligned(self, tmp_path): + """Image dimensions not divisible by tile size.""" + arr = np.arange(35, dtype=np.float32).reshape(5, 7) + path = str(tmp_path / 'non_aligned.tif') + write(arr, path, compression='none', tiled=True, tile_size=4) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_tile_larger_than_image(self, tmp_path): + """Tile size larger than the image.""" + arr = np.arange(6, dtype=np.float32).reshape(2, 3) + path = str(tmp_path / 'big_tile.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=256) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_odd_dimensions_all_compressions(self, tmp_path): + """Non-power-of-2 dimensions with every compression.""" + arr = np.random.RandomState(99).rand(13, 17).astype(np.float32) + for comp in ['none', 'deflate', 'lzw']: + path = str(tmp_path / f'odd_{comp}.tif') + write(arr, path, compression=comp, tiled=False) + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_very_wide_single_row_tiled(self, tmp_path): + """1 row, many columns, tiled layout.""" + arr = np.arange(500, dtype=np.float32).reshape(1, 500) + path = str(tmp_path / 'wide.tif') + write(arr, path, compression='none', tiled=True, tile_size=64) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_very_tall_single_column_tiled(self, tmp_path): + """Many rows, 1 column, tiled layout.""" + arr = np.arange(500, dtype=np.float32).reshape(500, 1) + path = str(tmp_path / 'tall.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=64) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_predictor_single_pixel(self, tmp_path): + """Predictor on a 1x1 image (no pixels to difference against).""" + arr = np.array([[7.5]], dtype=np.float32) + path = str(tmp_path / 'pred_1x1.tif') + write(arr, path, compression='deflate', tiled=False, predictor=True) + + result, _ = read_to_array(path) + assert result[0, 0] == pytest.approx(7.5) + + +# ----------------------------------------------------------------------- +# Reader: corrupt / truncated files +# ----------------------------------------------------------------------- + +class TestReadCorruptFiles: + """Reader should raise clear errors on malformed input.""" + + def test_empty_file(self, tmp_path): + path = str(tmp_path / 'empty.tif') + with open(path, 'wb') as f: + pass # 0 bytes + with pytest.raises((ValueError, Exception)): + read_to_array(path) + + def test_too_short_for_header(self, tmp_path): + path = str(tmp_path / 'short.tif') + with open(path, 'wb') as f: + f.write(b'II\x2a\x00') # only 4 bytes, need 8 + with pytest.raises((ValueError, Exception)): + read_to_array(path) + + def test_random_bytes(self, tmp_path): + path = str(tmp_path / 'random.tif') + with open(path, 'wb') as f: + f.write(b'\xde\xad\xbe\xef' * 100) + with pytest.raises(ValueError, match="Invalid TIFF"): + read_to_array(path) + + def test_valid_header_but_no_ifd(self, tmp_path): + """TIFF header pointing to IFD beyond file end.""" + path = str(tmp_path / 'no_ifd.tif') + # Valid LE TIFF header pointing to offset 99999 which doesn't exist + with open(path, 'wb') as f: + f.write(b'II') + f.write(struct.pack(' name lookup works for known codes.""" + from xrspatial.geotiff._geotags import ANGULAR_UNITS, LINEAR_UNITS + assert ANGULAR_UNITS[9102] == 'degree' + assert ANGULAR_UNITS[9101] == 'radian' + assert LINEAR_UNITS[9001] == 'metre' + assert LINEAR_UNITS[9002] == 'foot' + assert LINEAR_UNITS[9003] == 'us_survey_foot' + + def test_crs_wkt_from_epsg(self, tmp_path): + """crs_wkt is resolved from EPSG via pyproj.""" + from xrspatial.geotiff._geotags import GeoTransform + arr = np.ones((4, 4), dtype=np.float32) + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + path = str(tmp_path / 'wkt.tif') + write(arr, path, compression='none', tiled=False, + geo_transform=gt, crs_epsg=4326) + + da = read_geotiff(path) + assert 'crs_wkt' in da.attrs + wkt = da.attrs['crs_wkt'] + assert 'WGS 84' in wkt or '4326' in wkt + + def test_write_with_wkt_string(self, tmp_path): + """crs= accepts a WKT string and resolves to EPSG.""" + arr = np.ones((4, 4), dtype=np.float32) + wkt = ('GEOGCRS["WGS 84",DATUM["World Geodetic System 1984",' + 'ELLIPSOID["WGS 84",6378137,298.257223563]],' + 'CS[ellipsoidal,2],' + 'AXIS["geodetic latitude (Lat)",north],' + 'AXIS["geodetic longitude (Lon)",east],' + 'UNIT["degree",0.0174532925199433],' + 'ID["EPSG",4326]]') + path = str(tmp_path / 'wkt_in.tif') + write_geotiff(arr, path, crs=wkt, compression='none') + + da = read_geotiff(path) + assert da.attrs['crs'] == 4326 + + def test_write_with_proj_string(self, tmp_path): + """crs= accepts a PROJ string.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'proj_in.tif') + write_geotiff(arr, path, crs='+proj=utm +zone=18 +datum=NAD83', + compression='none') + + da = read_geotiff(path) + # pyproj should resolve this to EPSG:26918 + assert da.attrs.get('crs') is not None + + def test_crs_wkt_attr_round_trip(self, tmp_path): + """DataArray with crs_wkt attr (no int crs) round-trips.""" + wkt = ('GEOGCRS["WGS 84",DATUM["World Geodetic System 1984",' + 'ELLIPSOID["WGS 84",6378137,298.257223563]],' + 'CS[ellipsoidal,2],' + 'AXIS["geodetic latitude (Lat)",north],' + 'AXIS["geodetic longitude (Lon)",east],' + 'UNIT["degree",0.0174532925199433],' + 'ID["EPSG",4326]]') + y = np.linspace(45.0, 44.0, 4) + x = np.linspace(-120.0, -119.0, 4) + da = xr.DataArray(np.ones((4, 4), dtype=np.float32), + dims=['y', 'x'], coords={'y': y, 'x': x}, + attrs={'crs_wkt': wkt}) + path = str(tmp_path / 'wkt_rt.tif') + write_geotiff(da, path, compression='none') + + result = read_geotiff(path) + assert result.attrs['crs'] == 4326 + assert 'crs_wkt' in result.attrs + + def test_no_crs_no_wkt(self, tmp_path): + """File without CRS has no crs_wkt attr.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_wkt.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert 'crs_wkt' not in da.attrs + + +# ----------------------------------------------------------------------- +# Resolution / DPI tags +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# GDAL metadata (tag 42112) +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# Arbitrary tag preservation +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# Big-endian pixel data +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# Cloud storage (fsspec) support +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# VRT (Virtual Raster Table) support +# ----------------------------------------------------------------------- + +# ----------------------------------------------------------------------- +# Fixes: band-first, MinIsWhite, ExtraSamples, float16, VRT write, etc. +# ----------------------------------------------------------------------- + +class TestFixesBatch: + + def test_band_first_dataarray(self, tmp_path): + """DataArray with (band, y, x) dims is transposed before write.""" + arr = np.zeros((3, 8, 8), dtype=np.uint8) + arr[0] = 200 # red + arr[1] = 100 # green + arr[2] = 50 # blue + + da = xr.DataArray(arr, dims=['band', 'y', 'x']) + path = str(tmp_path / 'band_first.tif') + write_geotiff(da, path, compression='none') + + result = read_geotiff(path) + assert result.shape == (8, 8, 3) + assert result.values[0, 0, 0] == 200 # red channel + assert result.values[0, 0, 1] == 100 # green channel + + def test_band_last_dataarray_unchanged(self, tmp_path): + """DataArray with (y, x, band) dims is not transposed.""" + arr = np.zeros((8, 8, 3), dtype=np.uint8) + arr[:, :, 0] = 200 + da = xr.DataArray(arr, dims=['y', 'x', 'band']) + path = str(tmp_path / 'band_last.tif') + write_geotiff(da, path, compression='none') + + result = read_geotiff(path) + assert result.shape == (8, 8, 3) + assert result.values[0, 0, 0] == 200 + + def test_min_is_white_inversion(self, tmp_path): + """MinIsWhite (photometric=0) inverts grayscale values on read.""" + from .conftest import make_minimal_tiff + import struct + + # Build a minimal TIFF with photometric=0 + # The conftest doesn't support photometric param, so build manually + bo = '<' + width, height = 4, 4 + pixels = np.array([[0, 50, 100, 200]], dtype=np.uint8).repeat(4, axis=0) + + tag_list = [] + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + + add_short(256, width) + add_short(257, height) + add_short(258, 8) + add_short(259, 1) + add_short(262, 0) # MinIsWhite + add_short(277, 1) + add_short(278, height) + add_long(273, 0) + add_long(279, len(pixels.tobytes())) + add_short(339, 1) + + tag_list.sort(key=lambda t: t[0]) + num_entries = len(tag_list) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + overflow_start = ifd_start + ifd_size + pixel_start = overflow_start + # Patch strip offset + for i, (tag, typ, count, raw) in enumerate(tag_list): + if tag == 273: + tag_list[i] = (tag, typ, count, struct.pack(f'{bo}I', pixel_start)) + + out = bytearray() + out.extend(b'II') + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + out.extend(struct.pack(f'{bo}H', num_entries)) + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + out.extend(raw.ljust(4, b'\x00')) + out.extend(struct.pack(f'{bo}I', 0)) + out.extend(pixels.tobytes()) + + path = str(tmp_path / 'miniswhite.tif') + with open(path, 'wb') as f: + f.write(bytes(out)) + + from xrspatial.geotiff._reader import read_to_array + result, _ = read_to_array(path) + # MinIsWhite: 0 -> 255, 50 -> 205, 100 -> 155, 200 -> 55 + assert result[0, 0] == 255 + assert result[0, 1] == 205 + assert result[0, 2] == 155 + assert result[0, 3] == 55 + + def test_extra_samples_rgba(self, tmp_path): + """RGBA write includes ExtraSamples tag.""" + from xrspatial.geotiff._header import parse_header, parse_all_ifds, TAG_EXTRA_SAMPLES + arr = np.ones((4, 4, 4), dtype=np.uint8) * 128 + path = str(tmp_path / 'rgba.tif') + write(arr, path, compression='none', tiled=False) + + with open(path, 'rb') as f: + data = f.read() + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + extra = ifd.entries.get(TAG_EXTRA_SAMPLES) + assert extra is not None + # Value 2 = unassociated alpha + assert extra.value == 2 or (isinstance(extra.value, tuple) and extra.value[0] == 2) + + def test_float16_auto_promotion(self, tmp_path): + """Float16 arrays are auto-promoted to float32.""" + arr = np.ones((4, 4), dtype=np.float16) * 3.14 + path = str(tmp_path / 'f16.tif') + write_geotiff(arr, path, compression='none') + + result = read_geotiff(path) + assert result.dtype == np.float32 + np.testing.assert_array_almost_equal(result.values, 3.14, decimal=2) + + def test_vrt_write_and_read_back(self, tmp_path): + """write_vrt generates a valid VRT that reads back correctly.""" + from xrspatial.geotiff import write_vrt + from xrspatial.geotiff._geotags import GeoTransform + + # Write two tiles with known geo transforms + left = np.arange(16, dtype=np.float32).reshape(4, 4) + right = np.arange(16, 32, dtype=np.float32).reshape(4, 4) + + gt_left = GeoTransform(origin_x=0.0, origin_y=4.0, + pixel_width=1.0, pixel_height=-1.0) + gt_right = GeoTransform(origin_x=4.0, origin_y=4.0, + pixel_width=1.0, pixel_height=-1.0) + + lpath = str(tmp_path / 'left.tif') + rpath = str(tmp_path / 'right.tif') + write(left, lpath, geo_transform=gt_left, compression='none', tiled=False) + write(right, rpath, geo_transform=gt_right, compression='none', tiled=False) + + vrt_path = str(tmp_path / 'mosaic.vrt') + write_vrt(vrt_path, [lpath, rpath]) + + da = read_geotiff(vrt_path) + assert da.shape == (4, 8) + np.testing.assert_array_equal(da.values[:, :4], left) + np.testing.assert_array_equal(da.values[:, 4:], right) + + def test_dask_vrt(self, tmp_path): + """read_geotiff_dask handles VRT files.""" + from xrspatial.geotiff import read_geotiff_dask + + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + tile_path = str(tmp_path / 'tile.tif') + write(arr, tile_path, compression='none', tiled=False) + + vrt_xml = ( + '\n' + ' \n' + ' \n' + f' {os.path.basename(tile_path)}\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + vrt_path = str(tmp_path / 'dask.vrt') + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + import dask.array as da + result = read_geotiff_dask(vrt_path, chunks=2) + assert isinstance(result.data, da.Array) + computed = result.compute() + np.testing.assert_array_equal(computed.values, arr) + + +class TestVRT: + + def _write_tile(self, tmp_path, name, data): + """Write a GeoTIFF tile and return its path.""" + from xrspatial.geotiff._writer import write + path = str(tmp_path / name) + write(data, path, compression='none', tiled=False) + return path + + def _make_mosaic_vrt(self, tmp_path, tile_paths, tile_shapes, + tile_offsets, width, height, dtype='Float32'): + """Build a VRT XML that mosaics multiple tiles.""" + lines = [ + f'', + ' 0.0, 1.0, 0.0, 0.0, 0.0, -1.0', + f' ', + ] + for path, (th, tw), (yo, xo) in zip(tile_paths, tile_shapes, tile_offsets): + lines.append(' ') + lines.append(f' {os.path.basename(path)}') + lines.append(' 1') + lines.append(f' ') + lines.append(f' ') + lines.append(' ') + lines.append(' ') + lines.append('') + + vrt_path = str(tmp_path / 'mosaic.vrt') + with open(vrt_path, 'w') as f: + f.write('\n'.join(lines)) + return vrt_path + + def test_single_tile_vrt(self, tmp_path): + """VRT with one source tile reads correctly.""" + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + tile_path = self._write_tile(tmp_path, 'tile.tif', arr) + + vrt_path = self._make_mosaic_vrt( + tmp_path, + [tile_path], [(4, 4)], [(0, 0)], + width=4, height=4, + ) + + da = read_geotiff(vrt_path) + np.testing.assert_array_equal(da.values, arr) + + def test_2x1_mosaic(self, tmp_path): + """VRT that tiles two images side-by-side.""" + left = np.arange(16, dtype=np.float32).reshape(4, 4) + right = np.arange(16, 32, dtype=np.float32).reshape(4, 4) + lpath = self._write_tile(tmp_path, 'left.tif', left) + rpath = self._write_tile(tmp_path, 'right.tif', right) + + vrt_path = self._make_mosaic_vrt( + tmp_path, + [lpath, rpath], [(4, 4), (4, 4)], [(0, 0), (0, 4)], + width=8, height=4, + ) + + da = read_geotiff(vrt_path) + assert da.shape == (4, 8) + np.testing.assert_array_equal(da.values[:, :4], left) + np.testing.assert_array_equal(da.values[:, 4:], right) + + def test_2x2_mosaic(self, tmp_path): + """VRT that tiles four images in a 2x2 grid.""" + tiles = [] + paths = [] + offsets = [] + for r in range(2): + for c in range(2): + base = (r * 2 + c) * 16 + arr = np.arange(base, base + 16, dtype=np.float32).reshape(4, 4) + name = f'tile_{r}_{c}.tif' + paths.append(self._write_tile(tmp_path, name, arr)) + tiles.append(arr) + offsets.append((r * 4, c * 4)) + + vrt_path = self._make_mosaic_vrt( + tmp_path, + paths, [(4, 4)] * 4, offsets, + width=8, height=8, + ) + + da = read_geotiff(vrt_path) + assert da.shape == (8, 8) + # Check each quadrant + np.testing.assert_array_equal(da.values[0:4, 0:4], tiles[0]) + np.testing.assert_array_equal(da.values[0:4, 4:8], tiles[1]) + np.testing.assert_array_equal(da.values[4:8, 0:4], tiles[2]) + np.testing.assert_array_equal(da.values[4:8, 4:8], tiles[3]) + + def test_windowed_vrt_read(self, tmp_path): + """Windowed read of a VRT mosaic.""" + left = np.arange(16, dtype=np.float32).reshape(4, 4) + right = np.arange(16, 32, dtype=np.float32).reshape(4, 4) + lpath = self._write_tile(tmp_path, 'left.tif', left) + rpath = self._write_tile(tmp_path, 'right.tif', right) + + vrt_path = self._make_mosaic_vrt( + tmp_path, + [lpath, rpath], [(4, 4), (4, 4)], [(0, 0), (0, 4)], + width=8, height=4, + ) + + # Window spanning both tiles + da = read_geotiff(vrt_path, window=(1, 2, 3, 6)) + assert da.shape == (2, 4) + expected = np.hstack([left, right])[1:3, 2:6] + np.testing.assert_array_equal(da.values, expected) + + def test_vrt_with_crs(self, tmp_path): + """VRT with SRS tag populates CRS in attrs.""" + arr = np.ones((4, 4), dtype=np.float32) + tile_path = self._write_tile(tmp_path, 'tile.tif', arr) + + vrt_xml = ( + '\n' + ' EPSG:4326\n' + ' -120.0, 0.001, 0.0, 45.0, 0.0, -0.001\n' + ' \n' + ' \n' + f' {os.path.basename(tile_path)}\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + vrt_path = str(tmp_path / 'crs.vrt') + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + da = read_geotiff(vrt_path) + assert da.attrs.get('crs_wkt') == 'EPSG:4326' + assert len(da.coords['x']) == 4 + assert len(da.coords['y']) == 4 + + def test_vrt_nodata(self, tmp_path): + """VRT NoDataValue is stored in attrs.""" + arr = np.array([[1, 2], [3, -9999]], dtype=np.float32) + tile_path = self._write_tile(tmp_path, 'tile.tif', arr) + + vrt_xml = ( + '\n' + ' \n' + ' -9999\n' + ' \n' + f' {os.path.basename(tile_path)}\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + vrt_path = str(tmp_path / 'nodata.vrt') + with open(vrt_path, 'w') as f: + f.write(vrt_xml) + + da = read_geotiff(vrt_path) + assert da.attrs.get('nodata') == -9999.0 + + def test_read_vrt_function(self, tmp_path): + """read_vrt() works directly.""" + from xrspatial.geotiff import read_vrt + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + tile_path = self._write_tile(tmp_path, 'tile.tif', arr) + + vrt_path = self._make_mosaic_vrt( + tmp_path, + [tile_path], [(4, 4)], [(0, 0)], + width=4, height=4, + ) + + da = read_vrt(vrt_path) + assert da.name == 'mosaic' + np.testing.assert_array_equal(da.values, arr) + + def test_vrt_parser(self): + """VRT XML parser extracts all fields correctly.""" + from xrspatial.geotiff._vrt import parse_vrt + + xml = ( + '\n' + ' EPSG:32610\n' + ' 500000, 30, 0, 4500000, 0, -30\n' + ' \n' + ' 0\n' + ' \n' + ' /data/tile.tif\n' + ' 1\n' + ' \n' + ' \n' + ' \n' + ' \n' + '\n' + ) + vrt = parse_vrt(xml) + assert vrt.width == 100 + assert vrt.height == 200 + assert vrt.crs_wkt == 'EPSG:32610' + assert vrt.geo_transform == (500000.0, 30.0, 0.0, 4500000.0, 0.0, -30.0) + assert len(vrt.bands) == 1 + assert vrt.bands[0].dtype == np.uint16 + assert vrt.bands[0].nodata == 0.0 + assert len(vrt.bands[0].sources) == 1 + src = vrt.bands[0].sources[0] + assert src.filename == '/data/tile.tif' + assert src.src_rect.x_off == 10 + + +import os + +class TestCloudStorage: + + def test_cloud_scheme_detection(self): + """Cloud URI schemes are detected correctly.""" + from xrspatial.geotiff._reader import _is_fsspec_uri + assert _is_fsspec_uri('s3://bucket/key.tif') + assert _is_fsspec_uri('gs://bucket/key.tif') + assert _is_fsspec_uri('az://container/blob.tif') + assert _is_fsspec_uri('abfs://container/blob.tif') + assert _is_fsspec_uri('memory:///test.tif') + assert not _is_fsspec_uri('/local/path.tif') + assert not _is_fsspec_uri('http://example.com/file.tif') + assert not _is_fsspec_uri('relative/path.tif') + + def test_memory_filesystem_read_write(self, tmp_path): + """Round-trip through fsspec's in-memory filesystem.""" + import fsspec + + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + + # Write to memory filesystem via fsspec + from xrspatial.geotiff._writer import write, _write_bytes + from xrspatial.geotiff._writer import _assemble_tiff, _write_stripped + from xrspatial.geotiff._compression import COMPRESSION_NONE + + # First write locally, then copy to memory fs + local_path = str(tmp_path / 'test.tif') + write(arr, local_path, compression='none', tiled=False) + + with open(local_path, 'rb') as f: + tiff_bytes = f.read() + + # Put into fsspec memory filesystem + fs = fsspec.filesystem('memory') + fs.pipe('/test.tif', tiff_bytes) + + # Read via _CloudSource + from xrspatial.geotiff._reader import _CloudSource + src = _CloudSource('memory:///test.tif') + data = src.read_all() + assert len(data) == len(tiff_bytes) + assert data == tiff_bytes + + # Range read + chunk = src.read_range(0, 8) + assert chunk == tiff_bytes[:8] + + # Clean up + fs.rm('/test.tif') + + def test_memory_filesystem_full_roundtrip(self, tmp_path): + """write_geotiff + read_geotiff through memory:// filesystem.""" + import fsspec + + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + + # Write locally first, then copy to memory fs + local_path = str(tmp_path / 'local.tif') + write_geotiff(arr, local_path, compression='deflate') + with open(local_path, 'rb') as f: + tiff_bytes = f.read() + + fs = fsspec.filesystem('memory') + fs.pipe('/roundtrip.tif', tiff_bytes) + + # Read from memory filesystem + from xrspatial.geotiff._reader import read_to_array + result, geo = read_to_array('memory:///roundtrip.tif') + np.testing.assert_array_equal(result, arr) + + fs.rm('/roundtrip.tif') + + def test_writer_cloud_scheme_detection(self): + """Writer detects cloud schemes.""" + from xrspatial.geotiff._writer import _is_fsspec_uri + assert _is_fsspec_uri('s3://bucket/key.tif') + assert _is_fsspec_uri('gs://bucket/key.tif') + assert _is_fsspec_uri('az://container/blob.tif') + assert not _is_fsspec_uri('/local/path.tif') + + def test_write_to_memory_filesystem(self, tmp_path): + """_write_bytes can write to fsspec memory filesystem.""" + import fsspec + from xrspatial.geotiff._writer import write + + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + local_path = str(tmp_path / 'src.tif') + write(arr, local_path, compression='none', tiled=False) + with open(local_path, 'rb') as f: + tiff_bytes = f.read() + + # Write via _write_bytes to memory filesystem + from xrspatial.geotiff._writer import _write_bytes + _write_bytes(tiff_bytes, 'memory:///written.tif') + + fs = fsspec.filesystem('memory') + assert fs.exists('/written.tif') + assert fs.cat('/written.tif') == tiff_bytes + + fs.rm('/written.tif') + + +class TestBigEndian: + + def test_float32_big_endian(self, tmp_path): + """Read a big-endian float32 TIFF.""" + from .conftest import make_minimal_tiff + expected = np.arange(16, dtype=np.float32).reshape(4, 4) + tiff_data = make_minimal_tiff(4, 4, np.dtype('float32'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_f32.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.float32 + np.testing.assert_array_equal(result, expected) + + def test_uint16_big_endian(self, tmp_path): + """Read a big-endian uint16 TIFF.""" + from .conftest import make_minimal_tiff + expected = np.arange(20, dtype=np.uint16).reshape(4, 5) * 1000 + tiff_data = make_minimal_tiff(5, 4, np.dtype('uint16'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_u16.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.uint16 + np.testing.assert_array_equal(result, expected) + + def test_int32_big_endian(self, tmp_path): + """Read a big-endian int32 TIFF.""" + from .conftest import make_minimal_tiff + expected = np.arange(16, dtype=np.int32).reshape(4, 4) - 8 + tiff_data = make_minimal_tiff(4, 4, np.dtype('int32'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_i32.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.int32 + np.testing.assert_array_equal(result, expected) + + def test_float64_big_endian(self, tmp_path): + """Read a big-endian float64 TIFF.""" + from .conftest import make_minimal_tiff + expected = np.linspace(-1.0, 1.0, 16, dtype=np.float64).reshape(4, 4) + tiff_data = make_minimal_tiff(4, 4, np.dtype('float64'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_f64.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.float64 + np.testing.assert_array_almost_equal(result, expected) + + def test_uint8_big_endian_no_swap_needed(self, tmp_path): + """uint8 big-endian needs no byte swap (single byte per sample).""" + from .conftest import make_minimal_tiff + expected = np.arange(16, dtype=np.uint8).reshape(4, 4) + tiff_data = make_minimal_tiff(4, 4, np.dtype('uint8'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_u8.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, expected) + + def test_big_endian_windowed(self, tmp_path): + """Windowed read of a big-endian TIFF.""" + from .conftest import make_minimal_tiff + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + tiff_data = make_minimal_tiff(8, 8, np.dtype('float32'), + pixel_data=expected, big_endian=True) + path = str(tmp_path / 'be_window.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path, window=(2, 3, 6, 7)) + np.testing.assert_array_equal(result, expected[2:6, 3:7]) + + def test_big_endian_via_public_api(self, tmp_path): + """read_geotiff handles big-endian files.""" + from .conftest import make_minimal_tiff + expected = np.arange(16, dtype=np.float32).reshape(4, 4) + tiff_data = make_minimal_tiff( + 4, 4, np.dtype('float32'), pixel_data=expected, + big_endian=True, + geo_transform=(-120.0, 45.0, 0.001, -0.001), epsg=4326) + path = str(tmp_path / 'be_api.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + assert da.attrs['crs'] == 4326 + np.testing.assert_array_equal(da.values, expected) + + +class TestExtraTags: + + def _make_tiff_with_extra_tags(self, tmp_path): + """Build a TIFF with Software (305) and DateTime (306) tags.""" + import struct + bo = '<' + width, height = 4, 4 + pixels = np.arange(16, dtype=np.float32).reshape(4, 4) + pixel_bytes = pixels.tobytes() + + tag_list = [] + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + def add_ascii(tag, text): + raw = text.encode('ascii') + b'\x00' + tag_list.append((tag, 2, len(raw), raw)) + + add_short(256, width) + add_short(257, height) + add_short(258, 32) + add_short(259, 1) + add_short(262, 1) + add_short(277, 1) + add_short(278, height) + add_long(273, 0) # placeholder + add_long(279, len(pixel_bytes)) + add_short(339, 3) # float + add_ascii(305, 'TestSoftware v1.0') + add_ascii(306, '2025:01:15 12:00:00') + + tag_list.sort(key=lambda t: t[0]) + num_entries = len(tag_list) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + overflow_start = ifd_start + ifd_size + + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + pixel_data_start = overflow_start + len(overflow_buf) + + patched = [] + for tag, typ, count, raw in tag_list: + if tag == 273: + patched.append((tag, typ, count, struct.pack(f'{bo}I', pixel_data_start))) + else: + patched.append((tag, typ, count, raw)) + tag_list = patched + + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + out = bytearray() + out.extend(b'II') + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + out.extend(struct.pack(f'{bo}H', num_entries)) + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + if len(raw) <= 4: + out.extend(raw.ljust(4, b'\x00')) + else: + ptr = overflow_start + tag_offsets[tag] + out.extend(struct.pack(f'{bo}I', ptr)) + out.extend(struct.pack(f'{bo}I', 0)) + out.extend(overflow_buf) + out.extend(pixel_bytes) + + path = str(tmp_path / 'extra_tags.tif') + with open(path, 'wb') as f: + f.write(bytes(out)) + return path, pixels + + def test_extra_tags_read(self, tmp_path): + """Extra tags are collected in attrs['extra_tags'].""" + path, _ = self._make_tiff_with_extra_tags(tmp_path) + da = read_geotiff(path) + + extra = da.attrs.get('extra_tags') + assert extra is not None + tag_ids = {t[0] for t in extra} + assert 305 in tag_ids # Software + assert 306 in tag_ids # DateTime + + def test_extra_tags_round_trip(self, tmp_path): + """Extra tags survive read -> write -> read.""" + path, pixels = self._make_tiff_with_extra_tags(tmp_path) + da = read_geotiff(path) + + out_path = str(tmp_path / 'roundtrip.tif') + write_geotiff(da, out_path, compression='none') + + da2 = read_geotiff(out_path) + + # Pixels should match + np.testing.assert_array_equal(da2.values, pixels) + + # Extra tags should survive + extra2 = da2.attrs.get('extra_tags') + assert extra2 is not None + tag_map = {t[0]: t[3] for t in extra2} + assert 305 in tag_map + assert 'TestSoftware v1.0' in str(tag_map[305]) + assert 306 in tag_map + assert '2025:01:15' in str(tag_map[306]) + + def test_no_extra_tags(self, tmp_path): + """Files with only managed tags have no extra_tags attr.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_extra.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert 'extra_tags' not in da.attrs + + +class TestGDALMetadata: + + def test_parse_gdal_metadata_xml(self): + """XML parsing extracts dataset and per-band items.""" + from xrspatial.geotiff._geotags import _parse_gdal_metadata + xml = ( + '\n' + ' Generic\n' + ' 100.5\n' + ' -5.2\n' + ' green\n' + '\n' + ) + meta = _parse_gdal_metadata(xml) + assert meta['DataType'] == 'Generic' + assert meta[('STATISTICS_MAX', 0)] == '100.5' + assert meta[('STATISTICS_MIN', 0)] == '-5.2' + assert meta[('BAND_NAME', 1)] == 'green' + + def test_build_gdal_metadata_xml(self): + """Dict serializes back to valid XML.""" + from xrspatial.geotiff._geotags import ( + _build_gdal_metadata_xml, _parse_gdal_metadata) + meta = { + 'DataType': 'Generic', + ('STATS_MAX', 0): '42.0', + ('STATS_MIN', 0): '-1.0', + } + xml = _build_gdal_metadata_xml(meta) + assert '' in xml + assert 'Generic' in xml + assert 'sample="0"' in xml + # Round-trip through parser + reparsed = _parse_gdal_metadata(xml) + assert reparsed == meta + + def test_round_trip_via_file(self, tmp_path): + """GDAL metadata survives write -> read.""" + meta = { + 'DataType': 'Elevation', + ('STATISTICS_MAXIMUM', 0): '2500.0', + ('STATISTICS_MINIMUM', 0): '100.0', + ('STATISTICS_MEAN', 0): '1200.5', + } + from xrspatial.geotiff._geotags import _build_gdal_metadata_xml + xml = _build_gdal_metadata_xml(meta) + + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'gdal_meta.tif') + write(arr, path, compression='none', tiled=False, + gdal_metadata_xml=xml) + + da = read_geotiff(path) + assert 'gdal_metadata' in da.attrs + assert 'gdal_metadata_xml' in da.attrs + result_meta = da.attrs['gdal_metadata'] + assert result_meta['DataType'] == 'Elevation' + assert result_meta[('STATISTICS_MAXIMUM', 0)] == '2500.0' + assert result_meta[('STATISTICS_MEAN', 0)] == '1200.5' + + def test_dataarray_attrs_round_trip(self, tmp_path): + """GDAL metadata from DataArray attrs is preserved.""" + meta = {'Source': 'test', ('BAND', 0): 'dem'} + da = xr.DataArray( + np.ones((4, 4), dtype=np.float32), + dims=['y', 'x'], + attrs={'gdal_metadata': meta}, + ) + path = str(tmp_path / 'da_meta.tif') + write_geotiff(da, path, compression='none') + + result = read_geotiff(path) + assert result.attrs['gdal_metadata']['Source'] == 'test' + assert result.attrs['gdal_metadata'][('BAND', 0)] == 'dem' + + def test_no_metadata_no_attrs(self, tmp_path): + """Files without GDAL metadata don't get the attrs.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_meta.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert 'gdal_metadata' not in da.attrs + assert 'gdal_metadata_xml' not in da.attrs + + def test_real_file_metadata(self): + """Real USGS file has GDAL metadata with statistics.""" + import os + path = '../rtxpy/examples/USGS_one_meter_x65y454_NY_LongIsland_Z18_2014.tif' + if not os.path.exists(path): + pytest.skip("Real test files not available") + + da = read_geotiff(path) + meta = da.attrs.get('gdal_metadata') + assert meta is not None + assert 'DataType' in meta + assert ('STATISTICS_MAXIMUM', 0) in meta + + def test_real_file_round_trip(self): + """GDAL metadata survives real-file round-trip.""" + import os, tempfile + path = '../rtxpy/examples/USGS_one_meter_x65y454_NY_LongIsland_Z18_2014.tif' + if not os.path.exists(path): + pytest.skip("Real test files not available") + + da = read_geotiff(path) + orig_meta = da.attrs['gdal_metadata'] + + out = os.path.join(tempfile.mkdtemp(), 'rt.tif') + write_geotiff(da, out, compression='deflate', tiled=False) + + da2 = read_geotiff(out) + for k, v in orig_meta.items(): + assert da2.attrs['gdal_metadata'].get(k) == v, f"Mismatch on {k}" + + +class TestResolution: + + def test_write_read_dpi(self, tmp_path): + """Resolution tags round-trip through write and read.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'dpi.tif') + write(arr, path, compression='none', tiled=False, + x_resolution=300.0, y_resolution=300.0, resolution_unit=2) + + da = read_geotiff(path) + assert da.attrs['x_resolution'] == pytest.approx(300.0, rel=0.01) + assert da.attrs['y_resolution'] == pytest.approx(300.0, rel=0.01) + assert da.attrs['resolution_unit'] == 'inch' + + def test_write_read_cm(self, tmp_path): + """Centimeter resolution unit.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'dpi_cm.tif') + write(arr, path, compression='none', tiled=False, + x_resolution=118.0, y_resolution=118.0, resolution_unit=3) + + da = read_geotiff(path) + assert da.attrs['x_resolution'] == pytest.approx(118.0, rel=0.01) + assert da.attrs['resolution_unit'] == 'centimeter' + + def test_no_resolution_no_attrs(self, tmp_path): + """Files without resolution tags don't get resolution attrs.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_dpi.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert 'x_resolution' not in da.attrs + assert 'y_resolution' not in da.attrs + assert 'resolution_unit' not in da.attrs + + def test_dataarray_attrs_round_trip(self, tmp_path): + """Resolution attrs on DataArray are preserved through write/read.""" + da = xr.DataArray( + np.ones((4, 4), dtype=np.float32), + dims=['y', 'x'], + attrs={'x_resolution': 72.0, 'y_resolution': 72.0, + 'resolution_unit': 'inch'}, + ) + path = str(tmp_path / 'da_dpi.tif') + write_geotiff(da, path, compression='none') + + result = read_geotiff(path) + assert result.attrs['x_resolution'] == pytest.approx(72.0, rel=0.01) + assert result.attrs['y_resolution'] == pytest.approx(72.0, rel=0.01) + assert result.attrs['resolution_unit'] == 'inch' + + def test_unit_none(self, tmp_path): + """ResolutionUnit=1 (no unit) round-trips as 'none'.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_unit.tif') + write(arr, path, compression='none', tiled=False, + x_resolution=1.0, y_resolution=1.0, resolution_unit=1) + + da = read_geotiff(path) + assert da.attrs['resolution_unit'] == 'none' + + +# ----------------------------------------------------------------------- +# Overview resampling methods +# ----------------------------------------------------------------------- + +class TestOverviewResampling: + + def test_mean_default(self, tmp_path): + """Default mean resampling produces correct 2x2 block averages.""" + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[1, 3, 5, 7], + [2, 4, 6, 8], + [10, 20, 30, 40], + [10, 20, 30, 40]], dtype=np.float32) + ov = _make_overview(arr, 'mean') + assert ov.shape == (2, 2) + # (1+3+2+4)/4 = 2.5 + assert ov[0, 0] == pytest.approx(2.5) + + def test_nearest(self, tmp_path): + """Nearest resampling picks top-left pixel of each 2x2 block.""" + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[10, 20, 30, 40], + [50, 60, 70, 80], + [90, 100, 110, 120], + [130, 140, 150, 160]], dtype=np.uint8) + ov = _make_overview(arr, 'nearest') + assert ov.shape == (2, 2) + assert ov[0, 0] == 10 + assert ov[0, 1] == 30 + assert ov[1, 0] == 90 + assert ov[1, 1] == 110 + + def test_min(self, tmp_path): + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[10, 1, 5, 3], + [20, 2, 6, 4], + [30, 3, 7, 5], + [40, 4, 8, 6]], dtype=np.float32) + ov = _make_overview(arr, 'min') + assert ov[0, 0] == pytest.approx(1.0) + assert ov[0, 1] == pytest.approx(3.0) + + def test_max(self, tmp_path): + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[10, 1, 5, 3], + [20, 2, 6, 4], + [30, 3, 7, 5], + [40, 4, 8, 6]], dtype=np.float32) + ov = _make_overview(arr, 'max') + assert ov[0, 0] == pytest.approx(20.0) + assert ov[1, 1] == pytest.approx(8.0) + + def test_median(self, tmp_path): + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[1, 2, 10, 20], + [3, 100, 30, 40], + [0, 0, 0, 0], + [0, 0, 0, 0]], dtype=np.float32) + ov = _make_overview(arr, 'median') + assert ov.shape == (2, 2) + # median of [1, 2, 3, 100] = 2.5 + assert ov[0, 0] == pytest.approx(2.5) + + def test_mode(self, tmp_path): + """Mode picks the most common value in each 2x2 block.""" + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[1, 1, 2, 3], + [1, 2, 2, 2], + [5, 5, 5, 6], + [5, 7, 6, 6]], dtype=np.uint8) + ov = _make_overview(arr, 'mode') + assert ov[0, 0] == 1 # 1 appears 3 times + assert ov[0, 1] == 2 # 2 appears 3 times + assert ov[1, 0] == 5 # 5 appears 3 times + assert ov[1, 1] == 6 # 6 appears 3 times + + def test_mean_with_nan(self, tmp_path): + """Mean resampling ignores NaN values.""" + from xrspatial.geotiff._writer import _make_overview + arr = np.array([[np.nan, 2, 4, 6], + [1, 3, np.nan, 8], + [10, 20, 30, 40], + [10, 20, 30, 40]], dtype=np.float32) + ov = _make_overview(arr, 'mean') + # nanmean([nan, 2, 1, 3]) = 2.0 + assert ov[0, 0] == pytest.approx(2.0) + + def test_multiband(self, tmp_path): + """Resampling works on 3D (multi-band) arrays.""" + from xrspatial.geotiff._writer import _make_overview + arr = np.zeros((4, 4, 3), dtype=np.uint8) + arr[:, :, 0] = 100 + arr[:, :, 1] = 200 + arr[:, :, 2] = 50 + ov = _make_overview(arr, 'mean') + assert ov.shape == (2, 2, 3) + assert ov[0, 0, 0] == 100 + assert ov[0, 0, 1] == 200 + assert ov[0, 0, 2] == 50 + + def test_cog_round_trip_nearest(self, tmp_path): + """COG with nearest resampling writes and reads back.""" + arr = np.arange(256, dtype=np.float32).reshape(16, 16) + path = str(tmp_path / 'cog_nearest.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=8, + cog=True, overview_levels=[1], overview_resampling='nearest') + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_cog_round_trip_mode(self, tmp_path): + """COG with mode resampling for classified data.""" + arr = np.array([[0, 0, 1, 1, 2, 2, 3, 3], + [0, 0, 1, 1, 2, 2, 3, 3], + [4, 4, 5, 5, 6, 6, 7, 7], + [4, 4, 5, 5, 6, 6, 7, 7], + [0, 0, 1, 1, 2, 2, 3, 3], + [0, 0, 1, 1, 2, 2, 3, 3], + [4, 4, 5, 5, 6, 6, 7, 7], + [4, 4, 5, 5, 6, 6, 7, 7]], dtype=np.uint8) + path = str(tmp_path / 'cog_mode.tif') + write(arr, path, compression='deflate', tiled=True, tile_size=4, + cog=True, overview_levels=[1], overview_resampling='mode') + + # Full res should be exact + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + # Overview should have mode-reduced values + ov, _ = read_to_array(path, overview_level=1) + assert ov.shape == (4, 4) + assert ov[0, 0] == 0 + assert ov[0, 1] == 1 + + def test_write_geotiff_api(self, tmp_path): + """overview_resampling kwarg works through the public API.""" + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'api_nearest.tif') + write_geotiff(arr, path, compression='deflate', + cog=True, overview_resampling='nearest') + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_invalid_method(self): + from xrspatial.geotiff._writer import _make_overview + arr = np.ones((4, 4), dtype=np.float32) + with pytest.raises(ValueError, match="Unknown overview resampling"): + _make_overview(arr, 'bicubic_spline') + + +# ----------------------------------------------------------------------- +# BigTIFF write +# ----------------------------------------------------------------------- + +class TestBigTIFF: + + def test_bigtiff_header_written(self, tmp_path): + """Force BigTIFF via internal threshold by mocking; test header parsing.""" + # We can't easily create a >4GB file in tests, but we can verify + # the BigTIFF path works by writing a small file with bigtiff=True + # through the internal API. + from xrspatial.geotiff._writer import _assemble_tiff, _write_stripped + from xrspatial.geotiff._compression import COMPRESSION_NONE + from xrspatial.geotiff._geotags import GeoTransform + + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + rel_off, bc, chunks = _write_stripped(arr, COMPRESSION_NONE, False) + parts = [(arr, 4, 4, rel_off, bc, chunks)] + + file_bytes = _assemble_tiff( + 4, 4, arr.dtype, COMPRESSION_NONE, False, False, 256, + parts, None, None, None, is_cog=False, raster_type=1) + + # Standard TIFF: magic 42 + header = parse_header(file_bytes) + assert not header.is_bigtiff + + def test_bigtiff_read_write_round_trip(self, tmp_path): + """Test that BigTIFF files produced internally can be read back.""" + from xrspatial.geotiff._writer import ( + _assemble_tiff, _write_stripped, _assemble_standard_layout, + ) + from xrspatial.geotiff._compression import COMPRESSION_NONE + from xrspatial.geotiff._dtypes import numpy_to_tiff_dtype, SHORT, LONG, DOUBLE + from xrspatial.geotiff._header import ( + TAG_IMAGE_WIDTH, TAG_IMAGE_LENGTH, TAG_BITS_PER_SAMPLE, + TAG_COMPRESSION, TAG_PHOTOMETRIC, TAG_SAMPLES_PER_PIXEL, + TAG_SAMPLE_FORMAT, TAG_ROWS_PER_STRIP, + TAG_STRIP_OFFSETS, TAG_STRIP_BYTE_COUNTS, + ) + + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + rel_off, bc, chunks = _write_stripped(arr, COMPRESSION_NONE, False) + bits_per_sample, sample_format = numpy_to_tiff_dtype(arr.dtype) + + tags = [ + (TAG_IMAGE_WIDTH, LONG, 1, 8), + (TAG_IMAGE_LENGTH, LONG, 1, 8), + (TAG_BITS_PER_SAMPLE, SHORT, 1, bits_per_sample), + (TAG_COMPRESSION, SHORT, 1, 1), + (TAG_PHOTOMETRIC, SHORT, 1, 1), + (TAG_SAMPLES_PER_PIXEL, SHORT, 1, 1), + (TAG_SAMPLE_FORMAT, SHORT, 1, sample_format), + (TAG_ROWS_PER_STRIP, SHORT, 1, 8), + (TAG_STRIP_OFFSETS, LONG, len(rel_off), rel_off), + (TAG_STRIP_BYTE_COUNTS, LONG, len(bc), bc), + ] + + parts = [(arr, 8, 8, rel_off, bc, chunks)] + file_bytes = _assemble_standard_layout( + 16, [tags], parts, bigtiff=True) + + path = str(tmp_path / 'bigtiff.tif') + with open(path, 'wb') as f: + f.write(file_bytes) + + header = parse_header(file_bytes) + assert header.is_bigtiff + + result, _ = read_to_array(path) + np.testing.assert_array_equal(result, arr) + + def test_force_bigtiff_via_public_api(self, tmp_path): + """bigtiff=True on write_geotiff forces BigTIFF even for small files.""" + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + path = str(tmp_path / 'forced_bigtiff.tif') + write_geotiff(arr, path, compression='none', bigtiff=True) + + with open(path, 'rb') as f: + header = parse_header(f.read(16)) + assert header.is_bigtiff + + result = read_geotiff(path) + np.testing.assert_array_equal(result.values, arr) + + def test_small_file_stays_classic(self, tmp_path): + """Small files default to classic TIFF (bigtiff=None auto-detects).""" + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + path = str(tmp_path / 'classic.tif') + write_geotiff(arr, path, compression='none') + + with open(path, 'rb') as f: + header = parse_header(f.read(16)) + assert not header.is_bigtiff + + def test_force_bigtiff_false_stays_classic(self, tmp_path): + """bigtiff=False forces classic TIFF.""" + arr = np.arange(16, dtype=np.float32).reshape(4, 4) + path = str(tmp_path / 'forced_classic.tif') + write_geotiff(arr, path, compression='none', bigtiff=False) + + with open(path, 'rb') as f: + header = parse_header(f.read(16)) + assert not header.is_bigtiff + + +# ----------------------------------------------------------------------- +# Sub-byte bit depths (1-bit, 4-bit, 12-bit) +# ----------------------------------------------------------------------- + +def _make_sub_byte_tiff(width, height, bps, pixel_values): + """Build a minimal TIFF with sub-byte BitsPerSample. + + pixel_values: 2D array of unpacked integer values. + Data is packed MSB-first into bytes according to bps. + """ + import struct + bo = '<' + dtype_np = np.dtype('uint8') if bps <= 8 else np.dtype('uint16') + + # Pack pixel values into bytes + flat = pixel_values.ravel() + if bps == 1: + packed = np.packbits(flat.astype(np.uint8)) + elif bps == 4: + n = len(flat) + packed_len = (n + 1) // 2 + packed = np.zeros(packed_len, dtype=np.uint8) + for i in range(n): + if i % 2 == 0: + packed[i // 2] |= (flat[i] & 0x0F) << 4 + else: + packed[i // 2] |= flat[i] & 0x0F + packed = packed + elif bps == 12: + n = len(flat) + n_pairs = n // 2 + remainder = n % 2 + packed_len = n_pairs * 3 + (2 if remainder else 0) + packed = np.zeros(packed_len, dtype=np.uint8) + for i in range(n_pairs): + v0 = int(flat[i * 2]) + v1 = int(flat[i * 2 + 1]) + off = i * 3 + packed[off] = (v0 >> 4) & 0xFF + packed[off + 1] = ((v0 & 0x0F) << 4) | ((v1 >> 8) & 0x0F) + packed[off + 2] = v1 & 0xFF + if remainder: + v = int(flat[-1]) + off = n_pairs * 3 + packed[off] = (v >> 4) & 0xFF + packed[off + 1] = (v & 0x0F) << 4 + else: + raise ValueError(f"Unsupported bps: {bps}") + + pixel_bytes = packed.tobytes() + + # Build tags + tag_list = [] + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + + add_short(256, width) + add_short(257, height) + add_short(258, bps) + add_short(259, 1) # no compression + add_short(262, 1) # BlackIsZero (works for all bit depths) + add_short(277, 1) + add_short(278, height) + add_long(273, 0) # strip offset placeholder + add_long(279, len(pixel_bytes)) + if bps <= 8: + add_short(339, 1) # UINT + else: + add_short(339, 1) + + tag_list.sort(key=lambda t: t[0]) + num_entries = len(tag_list) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + overflow_buf = bytearray() + tag_offsets = {} + overflow_start = ifd_start + ifd_size + + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + pixel_data_start = overflow_start + len(overflow_buf) + + # Patch strip offset + patched = [] + for tag, typ, count, raw in tag_list: + if tag == 273: + patched.append((tag, typ, count, struct.pack(f'{bo}I', pixel_data_start))) + else: + patched.append((tag, typ, count, raw)) + tag_list = patched + + # Rebuild overflow after patching + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + out = bytearray() + out.extend(b'II') + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + out.extend(struct.pack(f'{bo}H', num_entries)) + + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + if len(raw) <= 4: + out.extend(raw.ljust(4, b'\x00')) + else: + ptr = overflow_start + tag_offsets[tag] + out.extend(struct.pack(f'{bo}I', ptr)) + + out.extend(struct.pack(f'{bo}I', 0)) + out.extend(overflow_buf) + out.extend(pixel_bytes) + + return bytes(out), pixel_values + + +class TestSubByteBitDepths: + + def test_1bit_bilevel(self, tmp_path): + """Read a 1-bit bilevel TIFF.""" + pixels = np.array([[1, 0, 1, 0, 1, 0, 1, 0], + [0, 1, 0, 1, 0, 1, 0, 1], + [1, 1, 0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 0, 0, 1, 1]], dtype=np.uint8) + tiff_data, expected = _make_sub_byte_tiff(8, 4, 1, pixels) + path = str(tmp_path / '1bit.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.uint8 + assert result.shape == (4, 8) + np.testing.assert_array_equal(result, expected) + + def test_1bit_non_byte_aligned_width(self, tmp_path): + """1-bit image whose width is not a multiple of 8.""" + pixels = np.array([[1, 0, 1], + [0, 1, 0]], dtype=np.uint8) + tiff_data, expected = _make_sub_byte_tiff(3, 2, 1, pixels) + path = str(tmp_path / '1bit_3wide.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.shape == (2, 3) + np.testing.assert_array_equal(result, expected) + + def test_4bit_nibble(self, tmp_path): + """Read a 4-bit TIFF.""" + pixels = np.array([[0, 1, 2, 3], + [4, 5, 6, 7], + [8, 9, 10, 11], + [12, 13, 14, 15]], dtype=np.uint8) + tiff_data, expected = _make_sub_byte_tiff(4, 4, 4, pixels) + path = str(tmp_path / '4bit.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.uint8 + assert result.shape == (4, 4) + np.testing.assert_array_equal(result, expected) + + def test_4bit_odd_width(self, tmp_path): + """4-bit image with odd width (partial byte at row end).""" + pixels = np.array([[1, 2, 3], + [4, 5, 6]], dtype=np.uint8) + tiff_data, expected = _make_sub_byte_tiff(3, 2, 4, pixels) + path = str(tmp_path / '4bit_odd.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.shape == (2, 3) + np.testing.assert_array_equal(result, expected) + + def test_12bit(self, tmp_path): + """Read a 12-bit TIFF.""" + pixels = np.array([[0, 100, 2048, 4095], + [1000, 2000, 3000, 4000]], dtype=np.uint16) + tiff_data, expected = _make_sub_byte_tiff(4, 2, 12, pixels) + path = str(tmp_path / '12bit.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.dtype == np.uint16 + assert result.shape == (2, 4) + np.testing.assert_array_equal(result, expected) + + def test_unpack_bits_codec_directly(self): + """Test unpack_bits on known packed data.""" + from xrspatial.geotiff._compression import unpack_bits + + # 1-bit: byte 0xA5 = 10100101 -> [1,0,1,0,0,1,0,1] + data = np.array([0xA5], dtype=np.uint8) + result = unpack_bits(data, 1, 8) + np.testing.assert_array_equal(result, [1, 0, 1, 0, 0, 1, 0, 1]) + + # 4-bit: byte 0x3C = 0011_1100 -> [3, 12] + data = np.array([0x3C], dtype=np.uint8) + result = unpack_bits(data, 4, 2) + np.testing.assert_array_equal(result, [3, 12]) + + +# ----------------------------------------------------------------------- +# Planar configuration (separate planes) +# ----------------------------------------------------------------------- + +def _make_planar_tiff(width, height, bands, dtype=np.uint8, tiled=False, + tile_size=4): + """Build a minimal planar-config TIFF (PlanarConfiguration=2) by hand. + + Each band's pixel data is stored as a separate set of strips (or tiles). + Band values: band 0 gets pixel values 10+pixel_idx, band 1 gets 20+, + band 2 gets 30+, etc. + """ + import struct + bo = '<' + + dtype = np.dtype(dtype) + bps = dtype.itemsize * 8 + if dtype.kind == 'f': + sf = 3 + elif dtype.kind == 'i': + sf = 2 + else: + sf = 1 + + # Build per-band pixel arrays + band_arrays = [] + for b in range(bands): + base = (b + 1) * 10 + arr = np.arange(width * height, dtype=dtype).reshape(height, width) + dtype.type(base) + band_arrays.append(arr) + + if tiled: + import math + tw = th = tile_size + tiles_across = math.ceil(width / tw) + tiles_down = math.ceil(height / th) + tiles_per_band = tiles_across * tiles_down + + # Build tile data: all tiles for band 0, then band 1, etc. + tile_blobs = [] + for b in range(bands): + for tr in range(tiles_down): + for tc in range(tiles_across): + tile = np.zeros((th, tw), dtype=dtype) + r0, c0 = tr * th, tc * tw + r1 = min(r0 + th, height) + c1 = min(c0 + tw, width) + tile[:r1 - r0, :c1 - c0] = band_arrays[b][r0:r1, c0:c1] + tile_blobs.append(tile.tobytes()) + + pixel_bytes = b''.join(tile_blobs) + tile_byte_counts = [len(t) for t in tile_blobs] + num_offsets = len(tile_blobs) + else: + # Strips: 1 strip per band (whole image), one set per band + strip_blobs = [] + for b in range(bands): + strip_blobs.append(band_arrays[b].tobytes()) + pixel_bytes = b''.join(strip_blobs) + strip_byte_counts = [len(s) for s in strip_blobs] + num_offsets = bands + + # Build tags + tag_list = [] + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + def add_shorts(tag, vals): + tag_list.append((tag, 3, len(vals), struct.pack(f'{bo}{len(vals)}H', *vals))) + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + def add_longs(tag, vals): + tag_list.append((tag, 4, len(vals), struct.pack(f'{bo}{len(vals)}I', *vals))) + + add_short(256, width) + add_short(257, height) + add_shorts(258, [bps] * bands) + add_short(259, 1) # no compression + add_short(262, 2 if bands >= 3 else 1) # RGB or BlackIsZero + add_short(277, bands) + add_short(284, 2) # PlanarConfiguration = Separate + add_shorts(339, [sf] * bands) + + if tiled: + add_short(322, tile_size) + add_short(323, tile_size) + add_longs(324, [0] * num_offsets) # placeholder + add_longs(325, tile_byte_counts) + else: + add_short(278, height) # RowsPerStrip = full image + add_longs(273, [0] * num_offsets) # placeholder + add_longs(279, strip_byte_counts) + + tag_list.sort(key=lambda t: t[0]) + + # Layout + num_entries = len(tag_list) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + + # Collect overflow + overflow_buf = bytearray() + tag_offsets = {} + overflow_start = ifd_start + ifd_size + + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + pixel_data_start = overflow_start + len(overflow_buf) + + # Patch offsets + offset_tag = 324 if tiled else 273 + patched = [] + for tag, typ, count, raw in tag_list: + if tag == offset_tag: + if tiled: + offs = [] + pos = 0 + for blob in tile_blobs: + offs.append(pixel_data_start + pos) + pos += len(blob) + new_raw = struct.pack(f'{bo}{num_offsets}I', *offs) + else: + offs = [] + pos = 0 + for blob in strip_blobs: + offs.append(pixel_data_start + pos) + pos += len(blob) + new_raw = struct.pack(f'{bo}{num_offsets}I', *offs) + patched.append((tag, typ, count, new_raw)) + else: + patched.append((tag, typ, count, raw)) + tag_list = patched + + # Rebuild overflow + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + # Serialize + out = bytearray() + out.extend(b'II') + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + out.extend(struct.pack(f'{bo}H', num_entries)) + + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + if len(raw) <= 4: + out.extend(raw.ljust(4, b'\x00')) + else: + ptr = overflow_start + tag_offsets[tag] + out.extend(struct.pack(f'{bo}I', ptr)) + + out.extend(struct.pack(f'{bo}I', 0)) # next IFD + out.extend(overflow_buf) + out.extend(pixel_bytes) + + # Build expected output for verification + expected = np.stack(band_arrays, axis=2) + return bytes(out), expected + + +# ----------------------------------------------------------------------- +# Palette / indexed color (ColorMap tag 320) +# ----------------------------------------------------------------------- + +def _make_palette_tiff(width, height, bps, pixel_values, palette_rgb): + """Build a palette-color TIFF (Photometric=3 + ColorMap tag). + + palette_rgb: list of (R, G, B) tuples, uint16 values (0-65535). + """ + import struct + bo = '<' + n_colors = len(palette_rgb) + assert n_colors == (1 << bps), f"Palette must have {1 << bps} entries for {bps}-bit" + + # Pack pixel data + flat = pixel_values.ravel().astype(np.uint8) + if bps == 8: + pixel_bytes = flat.tobytes() + elif bps == 4: + n = len(flat) + packed_len = (n + 1) // 2 + packed = np.zeros(packed_len, dtype=np.uint8) + for i in range(n): + if i % 2 == 0: + packed[i // 2] |= (flat[i] & 0x0F) << 4 + else: + packed[i // 2] |= flat[i] & 0x0F + pixel_bytes = packed.tobytes() + else: + pixel_bytes = flat.tobytes() + + # Build ColorMap: [R0..R_{n-1}, G0..G_{n-1}, B0..B_{n-1}] + r_vals = [c[0] for c in palette_rgb] + g_vals = [c[1] for c in palette_rgb] + b_vals = [c[2] for c in palette_rgb] + cmap_values = r_vals + g_vals + b_vals + + tag_list = [] + def add_short(tag, val): + tag_list.append((tag, 3, 1, struct.pack(f'{bo}H', val))) + def add_long(tag, val): + tag_list.append((tag, 4, 1, struct.pack(f'{bo}I', val))) + def add_shorts(tag, vals): + tag_list.append((tag, 3, len(vals), struct.pack(f'{bo}{len(vals)}H', *vals))) + + add_short(256, width) + add_short(257, height) + add_short(258, bps) + add_short(259, 1) # no compression + add_short(262, 3) # Photometric = Palette + add_short(277, 1) # SamplesPerPixel = 1 + add_short(278, height) + add_long(273, 0) # StripOffsets placeholder + add_long(279, len(pixel_bytes)) + add_shorts(320, cmap_values) # ColorMap + add_short(339, 1) # SampleFormat = UINT + + tag_list.sort(key=lambda t: t[0]) + num_entries = len(tag_list) + ifd_start = 8 + ifd_size = 2 + 12 * num_entries + 4 + overflow_start = ifd_start + ifd_size + + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + pixel_data_start = overflow_start + len(overflow_buf) + + patched = [] + for tag, typ, count, raw in tag_list: + if tag == 273: + patched.append((tag, typ, count, struct.pack(f'{bo}I', pixel_data_start))) + else: + patched.append((tag, typ, count, raw)) + tag_list = patched + + overflow_buf = bytearray() + tag_offsets = {} + for tag, typ, count, raw in tag_list: + if len(raw) > 4: + tag_offsets[tag] = len(overflow_buf) + overflow_buf.extend(raw) + if len(overflow_buf) % 2: + overflow_buf.append(0) + else: + tag_offsets[tag] = None + + out = bytearray() + out.extend(b'II') + out.extend(struct.pack(f'{bo}H', 42)) + out.extend(struct.pack(f'{bo}I', ifd_start)) + out.extend(struct.pack(f'{bo}H', num_entries)) + + for tag, typ, count, raw in tag_list: + out.extend(struct.pack(f'{bo}HHI', tag, typ, count)) + if len(raw) <= 4: + out.extend(raw.ljust(4, b'\x00')) + else: + ptr = overflow_start + tag_offsets[tag] + out.extend(struct.pack(f'{bo}I', ptr)) + + out.extend(struct.pack(f'{bo}I', 0)) + out.extend(overflow_buf) + out.extend(pixel_bytes) + + return bytes(out) + + +class TestPalette: + + def test_palette_8bit_read(self, tmp_path): + """Read an 8-bit palette TIFF and verify pixel indices.""" + # 4-color palette: red, green, blue, white + palette = [ + (65535, 0, 0), # 0 = red + (0, 65535, 0), # 1 = green + (0, 0, 65535), # 2 = blue + (65535, 65535, 65535),# 3 = white + ] + [(0, 0, 0)] * 252 # pad to 256 entries for 8-bit + + pixels = np.array([[0, 1, 2, 3], + [3, 2, 1, 0]], dtype=np.uint8) + + tiff_data = _make_palette_tiff(4, 2, 8, pixels, palette) + path = str(tmp_path / 'palette8.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + # Should return raw index values + assert da.dtype == np.uint8 + np.testing.assert_array_equal(da.values, pixels) + + # Should have cmap and colormap_rgba in attrs + assert 'cmap' in da.attrs + assert 'colormap_rgba' in da.attrs + + # Verify the palette colors + rgba = da.attrs['colormap_rgba'] + assert len(rgba) == 256 + assert rgba[0] == pytest.approx((1.0, 0.0, 0.0, 1.0)) + assert rgba[1] == pytest.approx((0.0, 1.0, 0.0, 1.0)) + assert rgba[2] == pytest.approx((0.0, 0.0, 1.0, 1.0)) + + def test_palette_4bit(self, tmp_path): + """Read a 4-bit palette TIFF.""" + palette = [(i * 4369, i * 4369, i * 4369) for i in range(16)] + pixels = np.array([[0, 5, 10, 15], + [1, 6, 11, 3]], dtype=np.uint8) + + tiff_data = _make_palette_tiff(4, 2, 4, pixels, palette) + path = str(tmp_path / 'palette4.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + assert da.dtype == np.uint8 + np.testing.assert_array_equal(da.values, pixels) + assert 'cmap' in da.attrs + assert len(da.attrs['colormap_rgba']) == 16 + + def test_palette_cmap_works_with_plot(self, tmp_path): + """Verify the colormap can be used with matplotlib.""" + from matplotlib.colors import ListedColormap + + palette = [ + (65535, 0, 0), + (0, 65535, 0), + (0, 0, 65535), + (65535, 65535, 0), + ] + [(0, 0, 0)] * 252 + + pixels = np.array([[0, 1], [2, 3]], dtype=np.uint8) + tiff_data = _make_palette_tiff(2, 2, 8, pixels, palette) + path = str(tmp_path / 'palette_plot.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + cmap = da.attrs['cmap'] + assert isinstance(cmap, ListedColormap) + + # Verify color mapping at known indices + assert cmap(0)[:3] == pytest.approx((1.0, 0.0, 0.0), abs=0.01) + assert cmap(1 / 255)[:3] == pytest.approx((0.0, 1.0, 0.0), abs=0.01) + + def test_xrs_plot_with_palette(self, tmp_path): + """da.xrs.plot() uses the embedded colormap.""" + import matplotlib + matplotlib.use('Agg') + import xrspatial.accessor # register .xrs accessor + + palette = [ + (65535, 0, 0), + (0, 65535, 0), + (0, 0, 65535), + (65535, 65535, 65535), + ] + [(0, 0, 0)] * 252 + + pixels = np.array([[0, 1, 2, 3], + [3, 2, 1, 0]], dtype=np.uint8) + tiff_data = _make_palette_tiff(4, 2, 8, pixels, palette) + path = str(tmp_path / 'plot_palette.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + artist = da.xrs.plot() + assert artist is not None + import matplotlib.pyplot as plt + plt.close('all') + + def test_xrs_plot_no_palette(self, tmp_path): + """da.xrs.plot() falls through to normal plot for non-palette data.""" + import matplotlib + matplotlib.use('Agg') + import xrspatial.accessor + + arr = np.random.RandomState(42).rand(4, 4).astype(np.float32) + path = str(tmp_path / 'no_palette.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + artist = da.xrs.plot() + assert artist is not None + import matplotlib.pyplot as plt + plt.close('all') + + def test_plot_geotiff_deprecated(self, tmp_path): + """plot_geotiff still works as deprecated wrapper.""" + import matplotlib + matplotlib.use('Agg') + import xrspatial.accessor + from xrspatial.geotiff import plot_geotiff + + palette = [(65535, 0, 0), (0, 65535, 0)] + [(0, 0, 0)] * 254 + pixels = np.array([[0, 1], [1, 0]], dtype=np.uint8) + tiff_data = _make_palette_tiff(2, 2, 8, pixels, palette) + path = str(tmp_path / 'deprecated.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + artist = plot_geotiff(da) + assert artist is not None + import matplotlib.pyplot as plt + plt.close('all') + + def test_non_palette_no_cmap(self, tmp_path): + """Non-palette TIFFs should not have a cmap attr.""" + arr = np.ones((4, 4), dtype=np.float32) + path = str(tmp_path / 'no_palette.tif') + write(arr, path, compression='none', tiled=False) + + da = read_geotiff(path) + assert 'cmap' not in da.attrs + assert 'colormap_rgba' not in da.attrs + + +class TestPlanarConfig: + + def test_planar_strips_rgb(self, tmp_path): + """Read a 3-band planar-stripped TIFF.""" + tiff_data, expected = _make_planar_tiff(4, 6, 3, np.uint8) + path = str(tmp_path / 'planar_strip.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.shape == (6, 4, 3) + np.testing.assert_array_equal(result, expected) + + def test_planar_strips_2band(self, tmp_path): + """Read a 2-band planar-stripped TIFF.""" + tiff_data, expected = _make_planar_tiff(5, 4, 2, np.uint16) + path = str(tmp_path / 'planar_2band.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.shape == (4, 5, 2) + np.testing.assert_array_equal(result, expected) + + def test_planar_tiles_rgb(self, tmp_path): + """Read a 3-band planar-tiled TIFF.""" + tiff_data, expected = _make_planar_tiff( + 8, 8, 3, np.uint8, tiled=True, tile_size=4) + path = str(tmp_path / 'planar_tiled.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path) + assert result.shape == (8, 8, 3) + np.testing.assert_array_equal(result, expected) + + def test_planar_windowed(self, tmp_path): + """Windowed read of a planar-stripped TIFF.""" + tiff_data, expected = _make_planar_tiff(8, 8, 3, np.uint8) + path = str(tmp_path / 'planar_window.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path, window=(2, 1, 6, 5)) + np.testing.assert_array_equal(result, expected[2:6, 1:5, :]) + + def test_planar_band_selection(self, tmp_path): + """Selecting a single band from a planar TIFF.""" + tiff_data, expected = _make_planar_tiff(4, 4, 3, np.uint8) + path = str(tmp_path / 'planar_band.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + result, _ = read_to_array(path, band=1) + assert result.shape == (4, 4) + np.testing.assert_array_equal(result, expected[:, :, 1]) + + def test_planar_via_public_api(self, tmp_path): + """read_geotiff on a planar file returns correct DataArray.""" + from xrspatial.geotiff import read_geotiff + tiff_data, expected = _make_planar_tiff(4, 4, 3, np.uint8) + path = str(tmp_path / 'planar_api.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + da = read_geotiff(path) + assert 'band' in da.dims + assert da.shape == (4, 4, 3) + np.testing.assert_array_equal(da.values, expected) + + +# ----------------------------------------------------------------------- +# Dask lazy reads +# ----------------------------------------------------------------------- + +class TestDaskReads: + + def test_dask_basic(self, tmp_path): + """read_geotiff_dask returns a dask-backed DataArray.""" + import dask.array as da + from xrspatial.geotiff import read_geotiff_dask + + arr = np.arange(256, dtype=np.float32).reshape(16, 16) + path = str(tmp_path / 'dask_test.tif') + write(arr, path, compression='none', tiled=False) + + result = read_geotiff_dask(path, chunks=8) + assert isinstance(result.data, da.Array) + assert result.shape == (16, 16) + + # Compute and compare + computed = result.compute() + np.testing.assert_array_equal(computed.values, arr) + + def test_dask_coords(self, tmp_path): + """Dask read preserves coordinates and CRS.""" + from xrspatial.geotiff import read_geotiff_dask + from xrspatial.geotiff._geotags import GeoTransform + + arr = np.ones((8, 8), dtype=np.float32) + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + path = str(tmp_path / 'dask_geo.tif') + write(arr, path, geo_transform=gt, crs_epsg=4326, + compression='none', tiled=False) + + result = read_geotiff_dask(path, chunks=4) + assert result.attrs['crs'] == 4326 + assert len(result.coords['y']) == 8 + assert len(result.coords['x']) == 8 + + def test_dask_nodata(self, tmp_path): + """Nodata masking applied per-chunk.""" + from xrspatial.geotiff import read_geotiff_dask + + arr = np.array([[1.0, -9999.0], [-9999.0, 2.0], + [3.0, 4.0], [5.0, -9999.0]], dtype=np.float32) + path = str(tmp_path / 'dask_nodata.tif') + write(arr, path, compression='none', tiled=False, nodata=-9999.0) + + result = read_geotiff_dask(path, chunks=2) + computed = result.compute() + assert np.isnan(computed.values[0, 1]) + assert np.isnan(computed.values[1, 0]) + assert computed.values[0, 0] == 1.0 + + def test_dask_chunk_tuple(self, tmp_path): + """Chunks as (row, col) tuple.""" + from xrspatial.geotiff import read_geotiff_dask + + arr = np.arange(200, dtype=np.float32).reshape(10, 20) + path = str(tmp_path / 'dask_tuple.tif') + write(arr, path, compression='deflate', tiled=False) + + result = read_geotiff_dask(path, chunks=(5, 10)) + computed = result.compute() + np.testing.assert_array_equal(computed.values, arr) diff --git a/xrspatial/geotiff/tests/test_geotags.py b/xrspatial/geotiff/tests/test_geotags.py new file mode 100644 index 00000000..4bb366c8 --- /dev/null +++ b/xrspatial/geotiff/tests/test_geotags.py @@ -0,0 +1,109 @@ +"""Tests for GeoTIFF tag interpretation.""" +from __future__ import annotations + +import numpy as np +import pytest + +from xrspatial.geotiff._geotags import ( + GeoInfo, + GeoTransform, + build_geo_tags, + extract_geo_info, + GEOKEY_GEOGRAPHIC_TYPE, + GEOKEY_MODEL_TYPE, + GEOKEY_PROJECTED_CS_TYPE, + GEOKEY_RASTER_TYPE, + MODEL_TYPE_GEOGRAPHIC, + MODEL_TYPE_PROJECTED, + RASTER_PIXEL_IS_AREA, + TAG_GEO_KEY_DIRECTORY, + TAG_GDAL_NODATA, + TAG_MODEL_PIXEL_SCALE, + TAG_MODEL_TIEPOINT, +) +from xrspatial.geotiff._header import parse_all_ifds, parse_header +from .conftest import make_minimal_tiff + + +class TestGeoTransform: + def test_defaults(self): + gt = GeoTransform() + assert gt.origin_x == 0.0 + assert gt.origin_y == 0.0 + assert gt.pixel_width == 1.0 + assert gt.pixel_height == -1.0 + + +class TestExtractGeoInfo: + def test_with_tiepoint_and_scale(self): + data = make_minimal_tiff( + 4, 4, np.dtype('float32'), + geo_transform=(-120.0, 45.0, 0.001, -0.001), + epsg=4326, + ) + header = parse_header(data) + ifds = parse_all_ifds(data, header) + assert len(ifds) == 1 + + geo = extract_geo_info(ifds[0], data, header.byte_order) + assert geo.transform.origin_x == pytest.approx(-120.0) + assert geo.transform.origin_y == pytest.approx(45.0) + assert geo.transform.pixel_width == pytest.approx(0.001) + assert geo.transform.pixel_height == pytest.approx(-0.001) + assert geo.crs_epsg == 4326 + assert geo.model_type == MODEL_TYPE_GEOGRAPHIC + + def test_projected_crs(self): + data = make_minimal_tiff( + 4, 4, np.dtype('float32'), + geo_transform=(500000.0, 4500000.0, 30.0, -30.0), + epsg=32610, + ) + header = parse_header(data) + ifds = parse_all_ifds(data, header) + geo = extract_geo_info(ifds[0], data, header.byte_order) + assert geo.crs_epsg == 32610 + assert geo.model_type == MODEL_TYPE_PROJECTED + + def test_no_geo_tags(self): + data = make_minimal_tiff(4, 4, np.dtype('float32')) + header = parse_header(data) + ifds = parse_all_ifds(data, header) + geo = extract_geo_info(ifds[0], data, header.byte_order) + assert geo.crs_epsg is None + # Default transform + assert geo.transform.pixel_width == 1.0 + + +class TestBuildGeoTags: + def test_basic(self): + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + tags = build_geo_tags(gt, crs_epsg=4326, nodata=-9999.0) + + assert TAG_MODEL_PIXEL_SCALE in tags + scale = tags[TAG_MODEL_PIXEL_SCALE] + assert scale[0] == pytest.approx(0.001) + assert scale[1] == pytest.approx(0.001) + + assert TAG_MODEL_TIEPOINT in tags + tp = tags[TAG_MODEL_TIEPOINT] + assert tp[3] == pytest.approx(-120.0) + assert tp[4] == pytest.approx(45.0) + + assert TAG_GEO_KEY_DIRECTORY in tags + assert TAG_GDAL_NODATA in tags + assert tags[TAG_GDAL_NODATA] == '-9999.0' + + def test_no_crs(self): + gt = GeoTransform(0.0, 0.0, 1.0, -1.0) + tags = build_geo_tags(gt, crs_epsg=None, nodata=None) + assert TAG_MODEL_PIXEL_SCALE in tags + assert TAG_GEO_KEY_DIRECTORY in tags + assert TAG_GDAL_NODATA not in tags + + def test_projected_crs_geokey(self): + gt = GeoTransform(500000.0, 4500000.0, 30.0, -30.0) + tags = build_geo_tags(gt, crs_epsg=32610) + geokeys = tags[TAG_GEO_KEY_DIRECTORY] + # Flatten and check that ProjectedCSType is present + assert 3072 in geokeys # GEOKEY_PROJECTED_CS_TYPE diff --git a/xrspatial/geotiff/tests/test_header.py b/xrspatial/geotiff/tests/test_header.py new file mode 100644 index 00000000..ff16116b --- /dev/null +++ b/xrspatial/geotiff/tests/test_header.py @@ -0,0 +1,123 @@ +"""Tests for TIFF header and IFD parsing.""" +from __future__ import annotations + +import struct + +import numpy as np +import pytest + +from xrspatial.geotiff._header import ( + IFD, + TIFFHeader, + parse_all_ifds, + parse_header, + parse_ifd, + TAG_IMAGE_WIDTH, + TAG_IMAGE_LENGTH, + TAG_BITS_PER_SAMPLE, + TAG_COMPRESSION, +) +from .conftest import make_minimal_tiff + + +class TestParseHeader: + def test_little_endian(self): + data = make_minimal_tiff(4, 4) + header = parse_header(data) + assert header.byte_order == '<' + assert not header.is_bigtiff + assert header.first_ifd_offset == 8 + + def test_big_endian(self): + data = make_minimal_tiff(4, 4, big_endian=True) + header = parse_header(data) + assert header.byte_order == '>' + assert not header.is_bigtiff + + def test_invalid_bom(self): + with pytest.raises(ValueError, match="Invalid TIFF byte order"): + parse_header(b'XX\x00\x2a\x00\x00\x00\x08') + + def test_invalid_magic(self): + with pytest.raises(ValueError, match="Invalid TIFF magic"): + parse_header(b'II\x00\x99\x00\x00\x00\x08') + + def test_too_short(self): + with pytest.raises(ValueError, match="Not enough data"): + parse_header(b'II\x00') + + +class TestParseIFD: + def test_basic_tags(self): + data = make_minimal_tiff(10, 20, np.dtype('uint16')) + header = parse_header(data) + ifd = parse_ifd(data, header.first_ifd_offset, header) + + assert ifd.width == 10 + assert ifd.height == 20 + assert ifd.bits_per_sample == 16 + assert ifd.compression == 1 # uncompressed + assert ifd.samples_per_pixel == 1 + + def test_float32_tags(self): + data = make_minimal_tiff(8, 8, np.dtype('float32')) + header = parse_header(data) + ifd = parse_ifd(data, header.first_ifd_offset, header) + + assert ifd.bits_per_sample == 32 + assert ifd.sample_format == 3 # float + + def test_strip_layout(self): + data = make_minimal_tiff(4, 4) + header = parse_header(data) + ifd = parse_ifd(data, header.first_ifd_offset, header) + + assert not ifd.is_tiled + assert ifd.strip_offsets is not None + assert ifd.strip_byte_counts is not None + + def test_next_ifd_zero(self): + data = make_minimal_tiff(4, 4) + header = parse_header(data) + ifd = parse_ifd(data, header.first_ifd_offset, header) + assert ifd.next_ifd_offset == 0 + + +class TestParseAllIFDs: + def test_single_ifd(self): + data = make_minimal_tiff(4, 4) + header = parse_header(data) + ifds = parse_all_ifds(data, header) + assert len(ifds) == 1 + assert ifds[0].width == 4 + + def test_tiled_ifd(self): + data = make_minimal_tiff( + 8, 8, np.dtype('float32'), + pixel_data=np.arange(64, dtype=np.float32).reshape(8, 8), + tiled=True, tile_size=4, + ) + header = parse_header(data) + ifds = parse_all_ifds(data, header) + assert len(ifds) == 1 + assert ifds[0].is_tiled + assert ifds[0].tile_width == 4 + assert ifds[0].tile_height == 4 + + +class TestIFDProperties: + def test_nodata_str(self): + ifd = IFD() + assert ifd.nodata_str is None + + def test_defaults(self): + ifd = IFD() + assert ifd.width == 0 + assert ifd.height == 0 + assert ifd.bits_per_sample == 8 + assert ifd.compression == 1 + assert ifd.predictor == 1 + assert ifd.samples_per_pixel == 1 + assert ifd.photometric == 1 + assert ifd.planar_config == 1 + assert not ifd.is_tiled diff --git a/xrspatial/geotiff/tests/test_reader.py b/xrspatial/geotiff/tests/test_reader.py new file mode 100644 index 00000000..7be32370 --- /dev/null +++ b/xrspatial/geotiff/tests/test_reader.py @@ -0,0 +1,117 @@ +"""Tests for the TIFF reader.""" +from __future__ import annotations + +import os +import tempfile + +import numpy as np +import pytest + +from xrspatial.geotiff._reader import read_to_array, _read_strips, _read_tiles +from xrspatial.geotiff._header import parse_header, parse_all_ifds +from xrspatial.geotiff._dtypes import tiff_dtype_to_numpy +from xrspatial.geotiff._geotags import extract_geo_info +from .conftest import make_minimal_tiff + + +class TestReadStrips: + def test_float32_sequential(self): + """Read a simple float32 stripped TIFF and verify pixel values.""" + expected = np.arange(16, dtype=np.float32).reshape(4, 4) + data = make_minimal_tiff(4, 4, np.dtype('float32'), pixel_data=expected) + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + dtype = tiff_dtype_to_numpy(ifd.bits_per_sample, ifd.sample_format) + + arr = _read_strips(data, ifd, header, dtype) + np.testing.assert_array_equal(arr, expected) + + def test_uint16(self): + expected = np.arange(20, dtype=np.uint16).reshape(4, 5) + data = make_minimal_tiff(5, 4, np.dtype('uint16'), pixel_data=expected) + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + dtype = tiff_dtype_to_numpy(ifd.bits_per_sample, ifd.sample_format) + + arr = _read_strips(data, ifd, header, dtype) + np.testing.assert_array_equal(arr, expected) + + def test_windowed_read(self): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + data = make_minimal_tiff(8, 8, np.dtype('float32'), pixel_data=expected) + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + dtype = tiff_dtype_to_numpy(ifd.bits_per_sample, ifd.sample_format) + + window = (2, 3, 6, 7) # rows 2-5, cols 3-6 + arr = _read_strips(data, ifd, header, dtype, window=window) + np.testing.assert_array_equal(arr, expected[2:6, 3:7]) + + +class TestReadTiles: + def test_tiled_float32(self): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + data = make_minimal_tiff( + 8, 8, np.dtype('float32'), + pixel_data=expected, + tiled=True, + tile_size=4, + ) + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + dtype = tiff_dtype_to_numpy(ifd.bits_per_sample, ifd.sample_format) + + arr = _read_tiles(data, ifd, header, dtype) + np.testing.assert_array_equal(arr, expected) + + def test_tiled_windowed(self): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + data = make_minimal_tiff( + 8, 8, np.dtype('float32'), + pixel_data=expected, + tiled=True, + tile_size=4, + ) + + header = parse_header(data) + ifds = parse_all_ifds(data, header) + ifd = ifds[0] + dtype = tiff_dtype_to_numpy(ifd.bits_per_sample, ifd.sample_format) + + window = (1, 2, 5, 6) + arr = _read_tiles(data, ifd, header, dtype, window=window) + np.testing.assert_array_equal(arr, expected[1:5, 2:6]) + + +class TestReadToArray: + def test_local_file(self, tmp_path): + expected = np.arange(16, dtype=np.float32).reshape(4, 4) + tiff_data = make_minimal_tiff(4, 4, np.dtype('float32'), pixel_data=expected) + path = str(tmp_path / 'test.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + arr, geo_info = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_geo_info(self, tmp_path): + tiff_data = make_minimal_tiff( + 4, 4, np.dtype('float32'), + geo_transform=(-120.0, 45.0, 0.001, -0.001), + epsg=4326, + ) + path = str(tmp_path / 'geo_test.tif') + with open(path, 'wb') as f: + f.write(tiff_data) + + arr, geo_info = read_to_array(path) + assert geo_info.crs_epsg == 4326 + assert geo_info.transform.origin_x == pytest.approx(-120.0) diff --git a/xrspatial/geotiff/tests/test_writer.py b/xrspatial/geotiff/tests/test_writer.py new file mode 100644 index 00000000..a016f49f --- /dev/null +++ b/xrspatial/geotiff/tests/test_writer.py @@ -0,0 +1,104 @@ +"""Tests for the GeoTIFF writer.""" +from __future__ import annotations + +import numpy as np +import pytest + +from xrspatial.geotiff._geotags import GeoTransform +from xrspatial.geotiff._writer import write, _make_overview +from xrspatial.geotiff._reader import read_to_array + + +class TestMakeOverview: + def test_2x_decimation(self): + arr = np.arange(64, dtype=np.float32).reshape(8, 8) + ov = _make_overview(arr) + assert ov.shape == (4, 4) + # Check first value: mean of top-left 2x2 block + expected = np.mean([0, 1, 8, 9]) + assert ov[0, 0] == pytest.approx(expected) + + def test_integer_rounding(self): + arr = np.array([[1, 2, 3, 4], + [5, 6, 7, 8]], dtype=np.uint8) + ov = _make_overview(arr) + assert ov.shape == (1, 2) + assert ov.dtype == np.uint8 + + +class TestWriteRoundTrip: + def test_uncompressed_stripped(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'uncompressed.tif') + write(expected, path, compression='none', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_deflate_stripped(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'deflate.tif') + write(expected, path, compression='deflate', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_uncompressed_tiled(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'tiled.tif') + write(expected, path, compression='none', tiled=True, tile_size=4) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_deflate_tiled(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'deflate_tiled.tif') + write(expected, path, compression='deflate', tiled=True, tile_size=4) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_lzw_stripped(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'lzw.tif') + write(expected, path, compression='lzw', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_uint16(self, tmp_path): + expected = np.arange(100, dtype=np.uint16).reshape(10, 10) + path = str(tmp_path / 'uint16.tif') + write(expected, path, compression='none', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + def test_with_geo_info(self, tmp_path): + expected = np.ones((4, 4), dtype=np.float32) + gt = GeoTransform(-120.0, 45.0, 0.001, -0.001) + path = str(tmp_path / 'geo.tif') + write(expected, path, geo_transform=gt, crs_epsg=4326, + nodata=-9999.0, compression='none', tiled=False) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + assert geo.crs_epsg == 4326 + assert geo.transform.origin_x == pytest.approx(-120.0) + assert geo.transform.pixel_width == pytest.approx(0.001) + + def test_predictor_deflate(self, tmp_path): + expected = np.arange(64, dtype=np.float32).reshape(8, 8) + path = str(tmp_path / 'predictor.tif') + write(expected, path, compression='deflate', tiled=False, predictor=True) + + arr, geo = read_to_array(path) + np.testing.assert_array_equal(arr, expected) + + +class TestWriteInvalidInput: + def test_unsupported_compression(self, tmp_path): + arr = np.zeros((4, 4), dtype=np.float32) + with pytest.raises(ValueError, match="Unsupported compression"): + write(arr, str(tmp_path / 'bad.tif'), compression='jpeg') diff --git a/xrspatial/reproject/_crs_utils.py b/xrspatial/reproject/_crs_utils.py index a4eb5be6..fa5d699d 100644 --- a/xrspatial/reproject/_crs_utils.py +++ b/xrspatial/reproject/_crs_utils.py @@ -35,11 +35,21 @@ def _detect_source_crs(raster): """Auto-detect the CRS of a DataArray. Fallback chain: - 1. ``raster.rio.crs`` (rioxarray) - 2. ``raster.attrs['crs']`` - 3. None + 1. ``raster.attrs['crs']`` (EPSG int from xrspatial.geotiff) + 2. ``raster.attrs['crs_wkt']`` (WKT string from xrspatial.geotiff) + 3. ``raster.rio.crs`` (rioxarray, if installed) + 4. None """ - # rioxarray + # attrs (xrspatial.geotiff convention) + crs_attr = raster.attrs.get('crs') + if crs_attr is not None: + return _resolve_crs(crs_attr) + + crs_wkt = raster.attrs.get('crs_wkt') + if crs_wkt is not None: + return _resolve_crs(crs_wkt) + + # rioxarray fallback try: rio_crs = raster.rio.crs if rio_crs is not None: @@ -47,11 +57,6 @@ def _detect_source_crs(raster): except Exception: pass - # attrs - crs_attr = raster.attrs.get('crs') - if crs_attr is not None: - return _resolve_crs(crs_attr) - return None