Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions python/private/pypi/repack_whl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,21 @@ def main(sys_argv):
logging.debug(f"Found dist-info dir: {distinfo_dir}")
record_path = distinfo_dir / "RECORD"
record_contents = record_path.read_text() if record_path.exists() else ""
quote_files = all(line.startswith('"') for line in record_contents.splitlines())
distribution_prefix = distinfo_dir.with_suffix("").name

with _WhlFile(
args.output, mode="w", distribution_prefix=distribution_prefix
args.output,
mode="w",
distribution_prefix=distribution_prefix,
quote_all_filenames=quote_files,
) as out:
for p in _files_to_pack(patched_wheel_dir, record_contents):
rel_path = p.relative_to(patched_wheel_dir)
out.add_file(str(rel_path), p)

logging.debug(f"Writing RECORD file")
got_record = out.add_recordfile().decode("utf-8", "surrogateescape")
got_record = out.add_recordfile()

if got_record == record_contents:
logging.info(f"Created a whl file: {args.output}")
Expand Down
57 changes: 29 additions & 28 deletions tools/wheelmaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,17 @@ def __init__(
distribution_prefix: str,
strip_path_prefixes=None,
compression=zipfile.ZIP_DEFLATED,
quote_all_filenames: bool = False,
**kwargs,
):
self._distribution_prefix = distribution_prefix

self._strip_path_prefixes = strip_path_prefixes or []
# Entries for the RECORD file as (filename, hash, size) tuples.
self._record = []
# Entries for the RECORD file as (filename, digest, size) tuples.
self._record: list[tuple[str, str, str]] = []
# Whether to quote filenames in the RECORD file (for compatibility with
# some wheels like torch that have quoted filenames in their RECORD).
self.quote_all_filenames = quote_all_filenames

super().__init__(filename, mode=mode, compression=compression, **kwargs)

Expand Down Expand Up @@ -192,16 +196,15 @@ def add_string(self, filename, contents):
hash.update(contents)
self._add_to_record(filename, self._serialize_digest(hash), len(contents))

def _serialize_digest(self, hash):
def _serialize_digest(self, hash) -> str:
# https://www.python.org/dev/peps/pep-0376/#record
# "base64.urlsafe_b64encode(digest) with trailing = removed"
digest = base64.urlsafe_b64encode(hash.digest())
digest = b"sha256=" + digest.rstrip(b"=")
return digest
return digest.decode("utf-8", "surrogateescape")

def _add_to_record(self, filename, hash, size):
size = str(size).encode("ascii")
self._record.append((filename, hash, size))
def _add_to_record(self, filename: str, hash: str, size: int) -> None:
self._record.append((filename, hash, str(size)))

def _zipinfo(self, filename):
"""Construct deterministic ZipInfo entry for a file named filename"""
Expand All @@ -223,29 +226,27 @@ def _zipinfo(self, filename):
zinfo.compress_type = self.compression
return zinfo

def add_recordfile(self):
def _quote_filename(self, filename: str) -> str:
"""Return a possibly quoted filename for RECORD file."""
filename = filename.lstrip("/")
# Some RECORDs like torch have *all* filenames quoted and we must minimize diff.
# Otherwise, we quote only when necessary (e.g. for filenames with commas).
quoting = csv.QUOTE_ALL if self.quote_all_filenames else csv.QUOTE_MINIMAL
with io.StringIO() as buf:
csv.writer(buf, quoting=quoting).writerow([filename])
return buf.getvalue().strip()

def add_recordfile(self) -> str:
"""Write RECORD file to the distribution."""
record_path = self.distinfo_path("RECORD")
entries = self._record + [(record_path, b"", b"")]
with io.StringIO() as contents_io:
writer = csv.writer(contents_io, lineterminator="\n")
for filename, digest, size in entries:
if isinstance(filename, str):
filename = filename.lstrip("/")
writer.writerow(
(
(
c
if isinstance(c, str)
else c.decode("utf-8", "surrogateescape")
)
for c in (filename, digest, size)
)
)

contents = contents_io.getvalue()
self.add_string(record_path, contents)
return contents.encode("utf-8", "surrogateescape")
entries = self._record + [(record_path, "", "")]
entries = [
(self._quote_filename(fname), digest, size)
for fname, digest, size in entries
]
contents = "\n".join(",".join(entry) for entry in entries) + "\n"
self.add_string(record_path, contents)
return contents


class WheelMaker(object):
Expand Down