Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 52 additions & 5 deletions xrspatial/geotiff/_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,15 +866,59 @@ def packbits_compress(data: bytes) -> bytes:
pass


def _splice_jpeg_tables(tile_data: bytes,
jpeg_tables: bytes | None) -> bytes:
"""Splice a JPEGTables stream into a tile's JPEG fragment.

GDAL-style tiled JPEG TIFFs store DQT/DHT tables once in tag 347
(an abbreviated JPEG: SOI + tables + EOI) and each tile is a JPEG
fragment whose own DQT/DHT segments were stripped. To make a tile
self-contained, drop the tables stream's leading SOI and trailing
EOI and insert what remains after the tile's SOI marker.

Both buffers must start with SOI (FF D8). If either does not, the
tile data is returned unchanged so libjpeg sees its original input
and raises a meaningful error.
"""
if not jpeg_tables:
return tile_data
if len(tile_data) < 2 or tile_data[0] != 0xFF or tile_data[1] != 0xD8:
return tile_data
if len(jpeg_tables) < 4:
return tile_data
if jpeg_tables[0] != 0xFF or jpeg_tables[1] != 0xD8:
return tile_data
# Strip SOI from the tables stream, and EOI if present at the end.
tables_body = jpeg_tables[2:]
if len(tables_body) >= 2 and tables_body[-2] == 0xFF and tables_body[-1] == 0xD9:
tables_body = tables_body[:-2]
return tile_data[:2] + tables_body + tile_data[2:]


def jpeg_decompress(data: bytes, width: int = 0, height: int = 0,
samples: int = 1) -> bytes:
"""Decompress JPEG tile/strip data. Requires Pillow."""
samples: int = 1, jpeg_tables: bytes | None = None) -> bytes:
"""Decompress JPEG tile/strip data. Requires Pillow.

Parameters
----------
data : bytes
Raw JPEG bytes from one TIFF strip or tile. May be a fragment
when ``jpeg_tables`` is supplied (GDAL tiled JPEG).
jpeg_tables : bytes, optional
Contents of TIFF tag 347 (JPEGTables). If supplied, the shared
DQT/DHT segments are spliced into ``data`` before decoding so
the resulting stream is a complete JPEG.
"""
if not JPEG_AVAILABLE:
raise ImportError(
"Pillow is required to read JPEG-compressed TIFFs. "
"Install it with: pip install Pillow")
import io
if jpeg_tables:
data = _splice_jpeg_tables(data, jpeg_tables)
img = Image.open(io.BytesIO(data))
# libjpeg already converts YCbCr->RGB during decode, so rely on the
# mode Pillow returns. Calling .convert() unnecessarily would copy.
return np.asarray(img).tobytes()


Expand Down Expand Up @@ -1089,7 +1133,8 @@ def lz4_compress(data: bytes, level: int = 0) -> bytes:


def decompress(data, compression: int, expected_size: int = 0,
width: int = 0, height: int = 0, samples: int = 1) -> np.ndarray:
width: int = 0, height: int = 0, samples: int = 1,
jpeg_tables: bytes | None = None) -> np.ndarray:
"""Decompress tile/strip data based on TIFF compression tag.

Parameters
Expand All @@ -1116,8 +1161,10 @@ def decompress(data, compression: int, expected_size: int = 0,
elif compression == COMPRESSION_PACKBITS:
return np.frombuffer(packbits_decompress(data), dtype=np.uint8)
elif compression == COMPRESSION_JPEG:
return np.frombuffer(jpeg_decompress(data, width, height, samples),
dtype=np.uint8)
return np.frombuffer(
jpeg_decompress(data, width, height, samples,
jpeg_tables=jpeg_tables),
dtype=np.uint8)
elif compression == COMPRESSION_ZSTD:
return np.frombuffer(zstd_decompress(data), dtype=np.uint8)
elif compression == COMPRESSION_JPEG2000:
Expand Down
29 changes: 29 additions & 0 deletions xrspatial/geotiff/_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
TAG_COLORMAP = 320
TAG_EXTRA_SAMPLES = 338
TAG_SAMPLE_FORMAT = 339
TAG_JPEG_TABLES = 347
TAG_GDAL_METADATA = 42112
TAG_GDAL_NODATA = 42113

Expand Down Expand Up @@ -166,6 +167,34 @@ def photometric(self) -> int:
def planar_config(self) -> int:
return self.get_value(TAG_PLANAR_CONFIG, 1)

@property
def jpeg_tables(self) -> bytes | None:
"""JPEGTables tag (347): shared DQT/DHT segments for tiled JPEG.

GDAL-tiled ``compress=JPEG`` TIFFs store the quantization and
Huffman tables once in this tag; each tile's payload is a JPEG
fragment that needs the tables spliced in before libjpeg can
decode it. Returns the raw bytes of the abbreviated JPEG stream
(SOI ... DQT/DHT ... EOI), or None if absent.
"""
v = self.get_value(TAG_JPEG_TABLES)
if v is None:
return None
if isinstance(v, (bytes, bytearray)):
return bytes(v)
# BYTE arrays may surface as a tuple/list of ints
if isinstance(v, (tuple, list)):
return bytes(v)
# A single-byte tag value comes back as an int; wrap it in a
# one-element bytes object. Plain ``bytes(v)`` would (incorrectly)
# allocate v zero bytes -- a malformed file with a huge int here
# could otherwise blow up memory.
if isinstance(v, int):
return bytes([v & 0xFF])
raise TypeError(
f"unexpected JPEGTables tag value type: {type(v).__name__}"
)

@property
def x_resolution(self) -> float | None:
"""XResolution tag (282), or None if absent."""
Expand Down
26 changes: 20 additions & 6 deletions xrspatial/geotiff/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ def _packed_byte_count(pixel_count: int, bps: int) -> int:

def _decode_strip_or_tile(data_slice, compression, width, height, samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order='<'):
byte_order='<', jpeg_tables=None):
"""Decompress, apply predictor, unpack sub-byte, and reshape a strip/tile.

Parameters
Expand All @@ -529,6 +529,12 @@ def _decode_strip_or_tile(data_slice, compression, width, height, samples,
'<' for little-endian, '>' for big-endian. When the file byte
order differs from the system's native order, pixel data is
byte-swapped after decompression.
jpeg_tables : bytes or None
Raw bytes of the file's JPEGTables tag (347), or None if the file
doesn't have one. GDAL-style tiled JPEG TIFFs store DQT/DHT tables
once in this tag and each tile is a JPEG fragment that depends on
them; the JPEG decoder splices the tables in before handing the
tile to libjpeg. Ignored for non-JPEG compressions.

Returns an array shaped (height, width) or (height, width, samples).
"""
Expand All @@ -539,7 +545,8 @@ def _decode_strip_or_tile(data_slice, compression, width, height, samples,
expected = pixel_count * bytes_per_sample

chunk = decompress(data_slice, compression, expected,
width=width, height=height, samples=samples)
width=width, height=height, samples=samples,
jpeg_tables=jpeg_tables)

# Validate the decompressed byte count. A truncated deflate stream or a
# buggy compressor can produce fewer or more bytes than expected. Without
Expand Down Expand Up @@ -654,6 +661,7 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader,
bps = bps[0]
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS
jpeg_tables = ifd.jpeg_tables

if offsets is None or byte_counts is None:
raise ValueError("Missing strip offsets or byte counts")
Expand Down Expand Up @@ -713,7 +721,8 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader,
strip_pixels = _decode_strip_or_tile(
strip_data, compression, width, strip_rows, 1,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)
byte_order=header.byte_order,
jpeg_tables=jpeg_tables)

src_r0 = max(r0 - strip_row, 0)
src_r1 = min(r1 - strip_row, strip_rows)
Expand All @@ -738,7 +747,8 @@ def _read_strips(data: bytes, ifd: IFD, header: TIFFHeader,
strip_pixels = _decode_strip_or_tile(
strip_data, compression, width, strip_rows, samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)
byte_order=header.byte_order,
jpeg_tables=jpeg_tables)

src_r0 = max(r0 - strip_row, 0)
src_r1 = min(r1 - strip_row, strip_rows)
Expand Down Expand Up @@ -790,6 +800,7 @@ def _read_tiles(data: bytes, ifd: IFD, header: TIFFHeader,
bps = bps[0]
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS
jpeg_tables = ifd.jpeg_tables

offsets = ifd.tile_offsets
byte_counts = ifd.tile_byte_counts
Expand Down Expand Up @@ -885,7 +896,8 @@ def _decode_one(job):
return _decode_strip_or_tile(
tile_data, compression, tw, th, tile_samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)
byte_order=header.byte_order,
jpeg_tables=jpeg_tables)

if use_parallel:
from concurrent.futures import ThreadPoolExecutor
Expand Down Expand Up @@ -1001,6 +1013,7 @@ def _read_cog_http(url: str, overview_level: int | None = None,
pred = ifd.predictor
bytes_per_sample = bps // 8
is_sub_byte = bps in SUB_BYTE_BPS
jpeg_tables = ifd.jpeg_tables

offsets = ifd.tile_offsets
byte_counts = ifd.tile_byte_counts
Expand Down Expand Up @@ -1067,7 +1080,8 @@ def _read_cog_http(url: str, overview_level: int | None = None,
tile_pixels = _decode_strip_or_tile(
tile_data, compression, tw, th, samples,
bps, bytes_per_sample, is_sub_byte, dtype, pred,
byte_order=header.byte_order)
byte_order=header.byte_order,
jpeg_tables=jpeg_tables)

y0 = tr * th
x0 = tc * tw
Expand Down
138 changes: 138 additions & 0 deletions xrspatial/geotiff/tests/test_jpeg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
"""Tests for JPEG compression support (issue #1050)."""
from __future__ import annotations

import importlib.util

import numpy as np
import pytest
import xarray as xr

from xrspatial.geotiff._compression import (
COMPRESSION_JPEG,
_splice_jpeg_tables,
jpeg_compress,
jpeg_decompress,
)
Expand Down Expand Up @@ -154,3 +157,138 @@ def test_to_geotiff_jpeg_rejected(self, tmp_path):
path = str(tmp_path / 'api_1050.tif')
with pytest.raises(ValueError, match="JPEGTables"):
to_geotiff(da, path, compression='jpeg', tile_size=16)


class TestJpegTablesSplice:
"""Verify the JPEGTables splice helper used for tiled JPEG TIFFs."""

def test_splice_reconstructs_complete_jpeg(self):
# Build a complete JPEG, then split it into a tables stream + a
# tile fragment. Splicing should recover a decodable stream.
from PIL import Image
import io

rng = np.random.RandomState(1502)
arr = rng.randint(50, 200, (16, 16, 3), dtype=np.uint8)
img = Image.fromarray(arr, mode='RGB')
buf = io.BytesIO()
img.save(buf, format='JPEG', quality=85)
full = buf.getvalue()

# Find the SOS marker (FF DA): everything before is tables.
sos = full.index(b'\xff\xda')
tables = b'\xff\xd8' + full[2:sos] + b'\xff\xd9'
tile_fragment = b'\xff\xd8' + full[sos:]

spliced = _splice_jpeg_tables(tile_fragment, tables)
decoded = Image.open(io.BytesIO(spliced))
decoded.load()
assert decoded.size == (16, 16)

def test_splice_passthrough_on_empty_tables(self):
payload = b'\xff\xd8\xff\xd9'
assert _splice_jpeg_tables(payload, b'') == payload
assert _splice_jpeg_tables(payload, None) == payload

def test_splice_passthrough_on_invalid_input(self):
# No SOI -> return unchanged so libjpeg's own error surfaces.
assert _splice_jpeg_tables(b'no soi', b'\xff\xd8\xff\xd9') == b'no soi'

def test_jpeg_decompress_accepts_jpeg_tables_kwarg(self):
from PIL import Image
import io

rng = np.random.RandomState(1502)
arr = rng.randint(50, 200, (16, 16, 3), dtype=np.uint8)
img = Image.fromarray(arr, mode='RGB')
buf = io.BytesIO()
img.save(buf, format='JPEG', quality=85)
full = buf.getvalue()
sos = full.index(b'\xff\xda')
tables = b'\xff\xd8' + full[2:sos] + b'\xff\xd9'
fragment = b'\xff\xd8' + full[sos:]

out = jpeg_decompress(fragment, 16, 16, samples=3, jpeg_tables=tables)
assert len(out) == 16 * 16 * 3


# rasterio-driven tests for issue #1502: GDAL writes tiled JPEG TIFFs
# whose per-tile fragments share DQT/DHT tables in tag 347. Skip the
# class -- not the whole module -- when rasterio is missing so the
# codec/splice unit tests above still run.


@pytest.mark.skipif(
importlib.util.find_spec('rasterio') is None,
reason='rasterio is required to write GDAL-style tiled JPEG TIFFs',
)
class TestGdalTiledJpegRead:
"""Read GDAL-style tiled JPEG TIFFs that use the JPEGTables tag."""

def _gradient_rgb(self, size=128):
# Smooth content keeps JPEG error low and detection of bugs easy.
y = np.linspace(20, 240, size, dtype=np.uint8)
x = np.linspace(20, 240, size, dtype=np.uint8)
r = np.broadcast_to(y[:, None], (size, size)).astype(np.uint8)
g = np.broadcast_to(x[None, :], (size, size)).astype(np.uint8)
b = np.full((size, size), 128, dtype=np.uint8)
return np.stack([r, g, b], axis=0) # rasterio wants (bands, H, W)

def test_tiled_ycbcr_jpeg(self, tmp_path):
import rasterio as rio
from xrspatial.geotiff._header import (
parse_header, parse_all_ifds, TAG_JPEG_TABLES,
)

size = 128
data = self._gradient_rgb(size)
path = str(tmp_path / 'tiled_jpeg_ycbcr_1502.tif')
with rio.open(
path, 'w', driver='GTiff', height=size, width=size, count=3,
dtype='uint8', tiled=True, blockxsize=64, blockysize=64,
compress='JPEG', photometric='YCBCR',
) as dst:
dst.write(data)

# Sanity: the file actually carries JPEGTables (tag 347).
with open(path, 'rb') as f:
blob = f.read()
hdr = parse_header(blob)
ifds = parse_all_ifds(blob, hdr)
assert TAG_JPEG_TABLES in ifds[0].entries
assert ifds[0].jpeg_tables is not None
assert ifds[0].jpeg_tables[:2] == b'\xff\xd8'

arr, _ = read_to_array(path)
assert arr.shape == (size, size, 3)
assert arr.dtype == np.uint8

# Compare to rasterio's own decode. JPEG at quality 75 + 4:2:0
# chroma subsampling shows ~1-3 absolute mean error on smooth
# gradients; allow a generous 5.
with rio.open(path) as src:
ref = src.read() # (bands, H, W)
ref = np.transpose(ref, (1, 2, 0))
assert np.abs(arr.astype(int) - ref.astype(int)).mean() < 5

def test_tiled_grayscale_jpeg(self, tmp_path):
import rasterio as rio

size = 96
y = np.linspace(20, 240, size, dtype=np.uint8)
gray = np.broadcast_to(y[:, None], (size, size)).astype(np.uint8)

path = str(tmp_path / 'tiled_jpeg_gray_1502.tif')
with rio.open(
path, 'w', driver='GTiff', height=size, width=size, count=1,
dtype='uint8', tiled=True, blockxsize=32, blockysize=32,
compress='JPEG',
) as dst:
dst.write(gray, 1)

arr, _ = read_to_array(path)
assert arr.shape == (size, size)

with rio.open(path) as src:
ref = src.read(1)
assert np.abs(arr.astype(int) - ref.astype(int)).mean() < 5
Loading