Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 99 additions & 38 deletions src/shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@
import time
import warnings
import zipfile
from collections.abc import Container, Iterable, Iterator, Mapping, Reversible, Sequence
from collections.abc import Container, Iterable, Iterator, Reversible, Sequence
from contextlib import AbstractContextManager, ExitStack
from datetime import date, datetime
from os import PathLike
from pathlib import Path
from struct import Struct, calcsize, error, pack, unpack
from types import MappingProxyType, TracebackType
from types import TracebackType
from typing import (
IO,
Any,
Expand Down Expand Up @@ -285,32 +285,57 @@ def _truncate_utf8_str(

@functools.cache
def _BOM_and_dbf_decoded_pad_bytes(
pad_byte: Literal[b" ", b"\x00"],
encoding: str = "utf8",
) -> tuple[bytes, Mapping[str, bytes]]:
) -> tuple[bytes, dict[str, bytes], dict[str, bytes], dict[str, bytes]]:
try:
BOM = "".encode(encoding)
except UnicodeEncodeError:
BOM = b""

tuples: list[tuple[str, bytes]] = []
for pad_byte_str, N in {b" ": 5, b"\x00": 5, b" \x00": 2}.items():
# Max code unit size under UTF-8, UTF-16, and UTF-32 is 4 bytes.
for n in range(1, N):
pad_bytes = pad_byte_str * n
N: int # code-unit size in bytes (possible length of
# byte strings, that a single code point could encode to)
if encoding.lower().startswith("utf32"):
N = 4
elif encoding.lower().startswith("utf16"):
# Null bytes and ascii spaces don't encode to Surrogate-pairs
N = 2
else:
# Both Ascii and UTF-8 handled here (UTF-8 is backward compatible with ascii)
N = 1

def decoded_code_points_and_bytes(
pad_byte_strs: Iterable[bytes],
) -> dict[str, bytes]:
retval = {}
for pad_bytes in pad_byte_strs:
try:
s: str = (BOM + pad_bytes).decode(encoding)
except UnicodeDecodeError:
continue
tuples.append((s, pad_bytes))
break
return BOM, MappingProxyType(dict(tuples))
retval[s] = pad_bytes
return retval

# Max code unit size under UTF-8, UTF-16, and UTF-32 is 4 bytes.
if pad_byte == b"\x00":
# Just checking the field, in which asii spaces are technically valid
# even though PyShp historically has converted them to underscores
return BOM, {}, decoded_code_points_and_bytes([b"\x00" * N]), {}
else:
pad_byte_strs = [b" " * i + b"\x00" * (N - i) for i in range(N + 1)]

all_ascii_spaces = decoded_code_points_and_bytes([b" "])
mixed = decoded_code_points_and_bytes(pad_byte_strs)
all_null_bytes = decoded_code_points_and_bytes([b"\x00"])

return BOM, all_ascii_spaces, mixed, all_null_bytes


def _encode_dbf_string(
s: str,
size: int,
decode: Decoder | None,
pad_byte: bytes,
pad_byte: Literal[b" ", b"\x00"],
decode: Decoder | None = None,
encoding: str = "utf8",
encodingErrors: str = "strict",
strict: bool = True,
Expand All @@ -337,7 +362,7 @@ def _encode_dbf_string(
if len(encoded) <= size:
if i <= N - 1:
msg = (
f"Dropped {N - i} code points (e.g. characters)! "
f"Dropped {N - i} out of {N} code points (e.g. characters)! "
f"{s} was truncated to {trimmed} (discarding: {s[i:]}), "
f"in order to encode it under {size} bytes for the field or field name. "
f"Used: {encoding=} and {encodingErrors=}. "
Expand All @@ -358,21 +383,43 @@ def _encode_dbf_string(
f"to a short enough byte string, using {encoding=}, {encodingErrors=} ({BOM=!r})"
)

_BOM, decoded_pad_bytes = _BOM_and_dbf_decoded_pad_bytes(encoding)
_BOM, all_first, mixed, all_last = _BOM_and_dbf_decoded_pad_bytes(
pad_byte, encoding
)
already_warned = False

for suffix, pad_bytes in decoded_pad_bytes.items():
if s.endswith(suffix):
msg = (
f"Under the given encoding: {encoding}, "
f" the text (field name or 'C' or 'M' field): {s!r} "
f" ends with {suffix!r}, which "
f"encodes to the pad bytes: {pad_bytes!r}. "
"The real end of the actual data may be earlier. "
)
if strict:
raise DbfStringDataLoss(msg)
warnings.warn(msg, category=PossibleDataLoss)
break
def check_and_trim(decoded_pad_bytes: dict[str, bytes]) -> None:

nonlocal trimmed, already_warned

for suffix, pad_bytes in decoded_pad_bytes.items():
if not suffix:
continue
if len(suffix) >= 2:
raise ValueError(
f"Multiple code points: {suffix} encoded to: {pad_bytes!r} under {encoding=}"
)
if trimmed.endswith(suffix):
msg = (
f"Under the given encoding: {encoding}, after truncation to {size} bytes,"
f" the remaining text (field name or 'C' or 'M' field): {trimmed!r} "
f" ends with {suffix!r}, which "
f"encodes to the pad bytes: {pad_bytes!r}. "
"The real end of the actual data may be earlier. "
)
if strict:
raise DbfStringDataLoss(msg)
if not already_warned:
warnings.warn(msg, category=PossibleDataLoss)
already_warned = True
if len(set(pad_bytes)) == 1: # all same byte => strip all code points
trimmed = trimmed.rstrip(suffix)
else:
trimmed = trimmed.removesuffix(suffix)

check_and_trim(all_last)
check_and_trim(mixed)
check_and_trim(all_first)

if len(encoded) < size:
padded = encoded.ljust(size, pad_byte)
Expand All @@ -384,7 +431,8 @@ def _encode_dbf_string(

with warnings.catch_warnings():
warnings.simplefilter("ignore")
# TODO: Fuzz test this to see what it actually catches.
# TODO: Fuzz test this to see what it actually catches,
# as it makes encoding much slower.
decoded = decode(
b=padded,
encoding=encoding,
Expand All @@ -407,7 +455,7 @@ def _encode_dbf_string(

def _try_to_decode_dbf_name_or_text_field(
b: bytes,
pad_bytes: bytes, # Pad bytes will be trimmed (from the R of b) in their order in the byte-string
pad_bytes: bytes, # Pad bytes will be trimmed from the RHS (end) of b.
encoding: str = "utf8",
encodingErrors: str = "strict",
) -> str:
Expand Down Expand Up @@ -537,12 +585,12 @@ def from_unchecked(

if "\x00" in name:
msg = (
"Field names should not contain null characters "
"Field names ought not contain null characters, "
"as null bytes are used for padding in the header. "
f"Got: {name=} "
)
if strict:
raise dbfFileException(msg)
raise DbfStringDataLoss(msg)
warnings.warn(msg, category=PossibleDataLoss)

try:
Expand Down Expand Up @@ -592,6 +640,7 @@ def from_unchecked(
return inst

@classmethod
@functools.cache
def trim_name_until_encodable(
cls,
name: str,
Expand Down Expand Up @@ -4216,7 +4265,12 @@ def __init__(
self.recNum = 0
self._is_utf8 = encoding.replace("-", "").replace("_", "").lower() == "utf8"

self._BOM, self._decoded_pad_bytes = _BOM_and_dbf_decoded_pad_bytes(encoding)
(
self._BOM,
self._decoded_ascii_spaces,
self._decoded_mixed_bytes,
self._decoded_null_bytes,
) = _BOM_and_dbf_decoded_pad_bytes(b" ", encoding)

def field(
# Types of args should match *Field
Expand Down Expand Up @@ -4434,14 +4488,21 @@ def _record(self, record: list[RecordValue]) -> None:
)
if self.strict:
raise DbfStringDataLoss(msg)
warnings.warn(msg)
warnings.warn(msg, category=PossibleDataLoss)

depadded = trimmed
for byte in [b"\x00", b" "]:
try:
decoded_pad_byte = byte.decode(
self.encoding, self.encodingErrors
)
except UnicodeDecodeError:
continue
depadded = depadded.rstrip(decoded_pad_byte)

# TODO: Handle decoded_pad_bytes longer than 1
pad_bytes = "".join(self._decoded_pad_bytes)
depadded = trimmed.rstrip(pad_bytes)
if len(depadded) < len(trimmed):
msg = (
f"Trimmed: {trimmed}, stringified: {str_val} of data: {value} "
f"Trimmed: {trimmed!r}, stringified: {str_val!r} of data: {value!r} "
f"ends in decoded pad bytes or decoded null bytes. "
"Data encoded as null bytes and pad bytes will probably not "
"be recovered by applications reading the Shapefile or dbf file "
Expand Down
Loading
Loading