diff --git a/src/shapefile.py b/src/shapefile.py index 9fccf32..89c52a4 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -22,13 +22,13 @@ import time import warnings import zipfile -from collections.abc import Container, Iterable, Iterator, Mapping, Reversible, Sequence +from collections.abc import Container, Iterable, Iterator, Reversible, Sequence from contextlib import AbstractContextManager, ExitStack from datetime import date, datetime from os import PathLike from pathlib import Path from struct import Struct, calcsize, error, pack, unpack -from types import MappingProxyType, TracebackType +from types import TracebackType from typing import ( IO, Any, @@ -285,32 +285,57 @@ def _truncate_utf8_str( @functools.cache def _BOM_and_dbf_decoded_pad_bytes( + pad_byte: Literal[b" ", b"\x00"], encoding: str = "utf8", -) -> tuple[bytes, Mapping[str, bytes]]: +) -> tuple[bytes, dict[str, bytes], dict[str, bytes], dict[str, bytes]]: try: BOM = "".encode(encoding) except UnicodeEncodeError: BOM = b"" - tuples: list[tuple[str, bytes]] = [] - for pad_byte_str, N in {b" ": 5, b"\x00": 5, b" \x00": 2}.items(): - # Max code unit size under UTF-8, UTF-16, and UTF-32 is 4 bytes. - for n in range(1, N): - pad_bytes = pad_byte_str * n + N: int # code-unit size in bytes (possible length of + # byte strings, that a single code point could encode to) + if encoding.lower().startswith("utf32"): + N = 4 + elif encoding.lower().startswith("utf16"): + # Null bytes and ascii spaces don't encode to Surrogate-pairs + N = 2 + else: + # Both Ascii and UTF-8 handled here (UTF-8 is backward compatible with ascii) + N = 1 + + def decoded_code_points_and_bytes( + pad_byte_strs: Iterable[bytes], + ) -> dict[str, bytes]: + retval = {} + for pad_bytes in pad_byte_strs: try: s: str = (BOM + pad_bytes).decode(encoding) except UnicodeDecodeError: continue - tuples.append((s, pad_bytes)) - break - return BOM, MappingProxyType(dict(tuples)) + retval[s] = pad_bytes + return retval + + # Max code unit size under UTF-8, UTF-16, and UTF-32 is 4 bytes. + if pad_byte == b"\x00": + # Just checking the field, in which asii spaces are technically valid + # even though PyShp historically has converted them to underscores + return BOM, {}, decoded_code_points_and_bytes([b"\x00" * N]), {} + else: + pad_byte_strs = [b" " * i + b"\x00" * (N - i) for i in range(N + 1)] + + all_ascii_spaces = decoded_code_points_and_bytes([b" "]) + mixed = decoded_code_points_and_bytes(pad_byte_strs) + all_null_bytes = decoded_code_points_and_bytes([b"\x00"]) + + return BOM, all_ascii_spaces, mixed, all_null_bytes def _encode_dbf_string( s: str, size: int, - decode: Decoder | None, - pad_byte: bytes, + pad_byte: Literal[b" ", b"\x00"], + decode: Decoder | None = None, encoding: str = "utf8", encodingErrors: str = "strict", strict: bool = True, @@ -337,7 +362,7 @@ def _encode_dbf_string( if len(encoded) <= size: if i <= N - 1: msg = ( - f"Dropped {N - i} code points (e.g. characters)! " + f"Dropped {N - i} out of {N} code points (e.g. characters)! " f"{s} was truncated to {trimmed} (discarding: {s[i:]}), " f"in order to encode it under {size} bytes for the field or field name. " f"Used: {encoding=} and {encodingErrors=}. " @@ -358,21 +383,43 @@ def _encode_dbf_string( f"to a short enough byte string, using {encoding=}, {encodingErrors=} ({BOM=!r})" ) - _BOM, decoded_pad_bytes = _BOM_and_dbf_decoded_pad_bytes(encoding) + _BOM, all_first, mixed, all_last = _BOM_and_dbf_decoded_pad_bytes( + pad_byte, encoding + ) + already_warned = False - for suffix, pad_bytes in decoded_pad_bytes.items(): - if s.endswith(suffix): - msg = ( - f"Under the given encoding: {encoding}, " - f" the text (field name or 'C' or 'M' field): {s!r} " - f" ends with {suffix!r}, which " - f"encodes to the pad bytes: {pad_bytes!r}. " - "The real end of the actual data may be earlier. " - ) - if strict: - raise DbfStringDataLoss(msg) - warnings.warn(msg, category=PossibleDataLoss) - break + def check_and_trim(decoded_pad_bytes: dict[str, bytes]) -> None: + + nonlocal trimmed, already_warned + + for suffix, pad_bytes in decoded_pad_bytes.items(): + if not suffix: + continue + if len(suffix) >= 2: + raise ValueError( + f"Multiple code points: {suffix} encoded to: {pad_bytes!r} under {encoding=}" + ) + if trimmed.endswith(suffix): + msg = ( + f"Under the given encoding: {encoding}, after truncation to {size} bytes," + f" the remaining text (field name or 'C' or 'M' field): {trimmed!r} " + f" ends with {suffix!r}, which " + f"encodes to the pad bytes: {pad_bytes!r}. " + "The real end of the actual data may be earlier. " + ) + if strict: + raise DbfStringDataLoss(msg) + if not already_warned: + warnings.warn(msg, category=PossibleDataLoss) + already_warned = True + if len(set(pad_bytes)) == 1: # all same byte => strip all code points + trimmed = trimmed.rstrip(suffix) + else: + trimmed = trimmed.removesuffix(suffix) + + check_and_trim(all_last) + check_and_trim(mixed) + check_and_trim(all_first) if len(encoded) < size: padded = encoded.ljust(size, pad_byte) @@ -384,7 +431,8 @@ def _encode_dbf_string( with warnings.catch_warnings(): warnings.simplefilter("ignore") - # TODO: Fuzz test this to see what it actually catches. + # TODO: Fuzz test this to see what it actually catches, + # as it makes encoding much slower. decoded = decode( b=padded, encoding=encoding, @@ -407,7 +455,7 @@ def _encode_dbf_string( def _try_to_decode_dbf_name_or_text_field( b: bytes, - pad_bytes: bytes, # Pad bytes will be trimmed (from the R of b) in their order in the byte-string + pad_bytes: bytes, # Pad bytes will be trimmed from the RHS (end) of b. encoding: str = "utf8", encodingErrors: str = "strict", ) -> str: @@ -537,12 +585,12 @@ def from_unchecked( if "\x00" in name: msg = ( - "Field names should not contain null characters " + "Field names ought not contain null characters, " "as null bytes are used for padding in the header. " f"Got: {name=} " ) if strict: - raise dbfFileException(msg) + raise DbfStringDataLoss(msg) warnings.warn(msg, category=PossibleDataLoss) try: @@ -592,6 +640,7 @@ def from_unchecked( return inst @classmethod + @functools.cache def trim_name_until_encodable( cls, name: str, @@ -4216,7 +4265,12 @@ def __init__( self.recNum = 0 self._is_utf8 = encoding.replace("-", "").replace("_", "").lower() == "utf8" - self._BOM, self._decoded_pad_bytes = _BOM_and_dbf_decoded_pad_bytes(encoding) + ( + self._BOM, + self._decoded_ascii_spaces, + self._decoded_mixed_bytes, + self._decoded_null_bytes, + ) = _BOM_and_dbf_decoded_pad_bytes(b" ", encoding) def field( # Types of args should match *Field @@ -4434,14 +4488,21 @@ def _record(self, record: list[RecordValue]) -> None: ) if self.strict: raise DbfStringDataLoss(msg) - warnings.warn(msg) + warnings.warn(msg, category=PossibleDataLoss) + + depadded = trimmed + for byte in [b"\x00", b" "]: + try: + decoded_pad_byte = byte.decode( + self.encoding, self.encodingErrors + ) + except UnicodeDecodeError: + continue + depadded = depadded.rstrip(decoded_pad_byte) - # TODO: Handle decoded_pad_bytes longer than 1 - pad_bytes = "".join(self._decoded_pad_bytes) - depadded = trimmed.rstrip(pad_bytes) if len(depadded) < len(trimmed): msg = ( - f"Trimmed: {trimmed}, stringified: {str_val} of data: {value} " + f"Trimmed: {trimmed!r}, stringified: {str_val!r} of data: {value!r} " f"ends in decoded pad bytes or decoded null bytes. " "Data encoded as null bytes and pad bytes will probably not " "be recovered by applications reading the Shapefile or dbf file " diff --git a/tests/hypothesis_tests.py b/tests/hypothesis_tests.py index 8fd077f..ea95e8b 100644 --- a/tests/hypothesis_tests.py +++ b/tests/hypothesis_tests.py @@ -556,8 +556,9 @@ def _dbf_fields_strategy(draw, encoding: str) -> dict[str, str | int]: text( alphabet=characters( codec=encoding, - exclude_categories=["C"], # Z - Whitespace, C - Control chars++ - exclude_characters=[" "], + # https://en.wikipedia.org/wiki/Unicode_character_property#General_Category + exclude_categories=["Cs", "Co", "Cn"], # Cs - surrogates + # exclude_characters=[" "], ), min_size=1, max_size=10, @@ -582,38 +583,42 @@ def encodings_and_dbf_fields(draw): def _get_fields_context(fields, codec, strict=False): for field in fields: - if len(field["name"].encode(codec)) > 10: - return pytest.warns(shp.PossibleDataLoss) - if not strict and " " in field: - return pytest.warns(shp.PossibleDataLoss) - return contextlib.nullcontext() + if (len(field["name"].encode(codec)) > 10 or + "\x00" in field["name"] or + (" " in field["name"] and not strict) + ): + if strict: + return pytest.raises(shp.DbfStringDataLoss), True + return pytest.warns(shp.PossibleDataLoss), False + return contextlib.nullcontext(), False @pytest.mark.hypothesis @settings(suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large]) @given(encoding_and_dbf_field=encodings_and_dbf_fields()) -def test_dbf_Field_roundtrips( - encoding_and_dbf_field: dict, -) -> None: +def test_dbf_Field_roundtrips(encoding_and_dbf_field: dict) -> None: + encoding, field_kwargs = encoding_and_dbf_field - L = len(field_kwargs["name"].encode(encoding)) - context = _get_fields_context([field_kwargs], encoding, strict=False) + w_context, error_expected = _get_fields_context([field_kwargs], encoding, strict=True) - with context: + with w_context: expected = shp.Field.from_unchecked( encoding=encoding, - strict=False, + strict=True, **field_kwargs, ) + encoded = expected.encode_field_descriptor(strict=True) + if error_expected: + return stream = io.BytesIO() - - encoded = expected.encode_field_descriptor(strict=True) stream.write(encoded) stream.seek(0) + actual = shp.Field.from_byte_stream( stream, encoding=encoding, ) + assert isinstance(actual, shp.Field) assert actual.name == expected.name assert actual[1:] == expected[1:] @@ -681,18 +686,24 @@ def dbf_encoding_fields_and_records( return encoding, fields, records -def _assert_reader_matches_expected_fields(r, fields, written_records, writer_strict): - for f_r, f_w in itertools.zip_longest(r.data_fields, fields): - actual_field_dict = f_r._asdict() - actual_name = actual_field_dict["name"] +def _assert_reader_matches_expected_fields(r, expected_fields, writer_strict): + assert len(expected_fields) == len(r.data_fields), f"{expected_fields=}, {r.data_fields=}" + + for f_r, f_w in zip(r.data_fields, expected_fields): + expected_name = f_w["name"] if not writer_strict: - actual_name = actual_name.replace(" ", "_") - assert f_w["name"].startswith(actual_name) + expected_name = expected_name.replace(" ", "_") + expected_name = expected_name.rstrip("\x00") + assert expected_name.startswith(f_r.name), f"{expected_name=}, {f_r.name=}" + actual_field_dict = f_r._asdict() for k in ("field_type", "size", "decimal"): assert actual_field_dict[k] == f_w[k], f"{k=}, {actual_field_dict[k]=}, {f_w[k]=}" def _assert_reader_matches_expected_records(r, fields, written_records): - for exp_rec, actual_rec in itertools.zip_longest(written_records, r.records()): + actual_records = r.records() + expected_records = [rec for rec in written_records if rec is not None] + assert len(expected_records) == len(actual_records), f"{expected_records=}, {actual_records=}" + for exp_rec, actual_rec in zip(expected_records, actual_records): for expected, actual, field in itertools.zip_longest(exp_rec, actual_rec, fields): field_type = field["field_type"] decimal = field["decimal"] @@ -706,33 +717,62 @@ def _assert_reader_matches_expected_records(r, fields, written_records): assert actual == expected, f"{actual=}, {expected=}, {field_type=}, {type(actual)=}, {type(expected)=}" +def _write_fields_and_records_to_strict(w, fields, records): + + field_indices, written_records = set(), [] + + + for i, field in enumerate(fields): + try: + w.field(**field) + except shp.DbfStringDataLoss: + pass + else: + field_indices.add(i) + + if not field_indices: + return None, None + + + for record in records: + rec_list = [ + val + for i, val in enumerate(record) + if i in field_indices + ] + try: + w.record(*rec_list) + except shp.DbfStringDataLoss: + written_records.append(None) + else: + written_records.append(rec_list) + + + written_fields = [field for i, field in enumerate(fields) if i in field_indices] + + return written_fields, written_records + @pytest.mark.hypothesis @given(codec_fields_and_records=dbf_encoding_fields_and_records()) def test_dbf_reader_writer_roundtrip(codec_fields_and_records)-> None: codec, fields, records = codec_fields_and_records stream = io.BytesIO() - fields_context = _get_fields_context(fields, codec, strict=False) - written_records = [] - with shp.DbfWriter(dbf=stream, encoding=codec, strict=False) as dbf_w: - - # Only use strict = False to write fields, so that we still - # test the corresponding record values for any fields - # whose name was truncated. - with fields_context: - for field in fields: - dbf_w.field(**field) - dbf_w.strict = True - for record in records: - try: - dbf_w.record(*record) - except shp.DbfStringDataLoss: - pass - else: - written_records.append(record) + + # pytest.raises and pytest.warns can obscure other + # exceptions inside them + w = shp.DbfWriter(dbf=stream, encoding=codec, strict=True) + + written_fields, written_records = _write_fields_and_records_to_strict(w, fields, records) + + if not written_fields or written_records is None: + return + + w.close() + with shp.DbfReader(dbf=stream, encoding=codec) as r: - _assert_reader_matches_expected_fields(r, fields, written_records, False) - _assert_reader_matches_expected_records(r, fields, written_records) + _assert_reader_matches_expected_fields(r, written_fields, True) + _assert_reader_matches_expected_records(r, written_fields, written_records) @composite @@ -742,7 +782,7 @@ def codes_codecs_fields_shapes_and_records(draw): N = len(shapes) records = [draw(records_strategy) for _ in range(N)] - return code, encoding, fields, list(zip(shapes, records)) + return code, encoding, fields, shapes, records @pytest.mark.hypothesis @@ -750,32 +790,30 @@ def codes_codecs_fields_shapes_and_records(draw): @given(codes_codecs_fields_shapes_and_records=codes_codecs_fields_shapes_and_records()) def test_shapefile_reader_writer_roundtrip(codes_codecs_fields_shapes_and_records)-> None: - code_ex, encoding, fields_ex, shapes_and_records = codes_codecs_fields_shapes_and_records + code_ex, encoding, fields, shapes, records = codes_codecs_fields_shapes_and_records streams = {"shp" : io.BytesIO(), "shx" : io.BytesIO(), "dbf" : io.BytesIO(),} + w = shp.Writer(shapeType = code_ex, encoding=encoding, strict=True, **streams) expected_shapes = [] - expected_records = [] - fields_context = _get_fields_context(fields=fields_ex, codec=encoding, strict=False) - - with shp.Writer(shapeType = code_ex, encoding=encoding, strict=False, **streams) as w: - with fields_context: - for field in fields_ex: - w.field(**field) - - # Only use strict = False to write fields, so that we still - # test the corresponding record values for any fields - # whose name was truncated. - w.strict=True - for shape, record in shapes_and_records: - try: - w.record(*record) - except shp.DbfStringDataLoss: - continue - w.shape(shape) - expected_shapes.append(shape) - expected_records.append(record) + + written_fields, written_records = _write_fields_and_records_to_strict(w, fields, records) + + if not written_fields: + try: + w.close() + except shp.dbfFileException: + pass + return + + for shape, written_record in zip(shapes, written_records): + if written_record is None: + continue + w.shape(shape) + expected_shapes.append(shape) + + w.close() with shp.Reader(encoding=encoding, **streams) as r: - _assert_reader_matches_expected_fields(r, fields_ex, expected_records, False) - _assert_reader_matches_expected_records(r, fields_ex, expected_records) + _assert_reader_matches_expected_fields(r, written_fields, True) + _assert_reader_matches_expected_records(r, written_fields, written_records) _assert_reader_matches_expected_shapes(r, code_ex, expected_shapes) \ No newline at end of file