HuggingFace safetensors load benchmark #3396

Open
mridul-sahu wants to merge 2 commits into google:main from mridul-sahu:feature/hf-safetensors-benchmark

@mridul-sahu

Collaborator

Description

Adds a load-only benchmark for HuggingFace-format safetensors checkpoints under
_src/testing/benchmarks/safetensor/, for A/B-comparing the safetensors load
path across revisions:

  • SafetensorLoadBenchmark generator + tiered model configs (GPT-2 through
    DeepSeek-V3 / Llama-3-405B, in leading-dim FSDP and inner-dim TP variants),
    staged by prepare.py, which downloads a model from the HF Hub, optionally
    mirrors it to GCS, and emits the per-tensor sharding spec read from the
    safetensors headers. Per-host SHA-256 digests serve as the load-correctness
    check.
  • The current (transient-array) safetensors loader now self-reports per-host
    read accounting via jax.monitoring
    (/jax/orbax/read/safetensors/bytes_read, num_reads). It has no TensorStore
    counters, so this is what lets the benchmark's per-host bytes/reads card
    populate — and lets a future loader change be A/B'd against this baseline on
    the same channel. Telemetry only; load behavior is unchanged.
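
For readers unfamiliar with the header-driven staging described above, here is a minimal sketch of how a per-tensor sharding spec could be derived from a safetensors header. It relies only on the documented file layout (an 8-byte little-endian header length, a JSON header, then raw tensor bytes). The helper names `build_safetensors`, `read_header`, and `sharding_spec`, the mesh axis name `"x"`, and the reading of "inner-dim" as the last axis are all assumptions for illustration, not the code in prepare.py.

```python
import json
import struct


def build_safetensors(tensors):
    """Serialize {name: (dtype, shape, raw_bytes)} into safetensors-format bytes."""
    header, payload, offset = {}, b"", 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {
            "dtype": dtype,
            "shape": list(shape),
            "data_offsets": [offset, offset + len(raw)],
        }
        payload += raw
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    # Layout: u64 little-endian header size, JSON header, tensor bytes.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + payload


def read_header(blob):
    """Return the per-tensor header entries without touching tensor data."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    header = json.loads(blob[8:8 + header_len])
    header.pop("__metadata__", None)  # optional free-form metadata entry
    return header


def sharding_spec(shape, variant, axis_name="x"):
    """Hypothetical spec: shard dim 0 for "fsdp", the last dim for "tp"."""
    if not shape:
        return ()  # scalars are left unsharded
    spec = [None] * len(shape)
    spec[0 if variant == "fsdp" else len(shape) - 1] = axis_name
    return tuple(spec)


# Tiny stand-in checkpoint: two float32 tensors filled with zeros.
blob = build_safetensors({
    "wte.weight": ("F32", (4, 2), bytes(32)),
    "ln.bias": ("F32", (2,), bytes(8)),
})
header = read_header(blob)
specs_fsdp = {n: sharding_spec(t["shape"], "fsdp") for n, t in header.items()}
specs_tp = {n: sharding_spec(t["shape"], "tp") for n, t in header.items()}
```

Because only the header is parsed, a spec like this can be produced without reading any tensor bytes, which keeps staging cheap even for the largest tiers.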

Builds on the per-name benchmark card work merged in #3395.

Type of change

  • Bug fix (non-breaking change that fixes an issue)
  • New feature (non-breaking change that adds functionality)
  • Breaking change (alters existing behavior or a public API)
  • Docs / tests / internal tooling (no user-facing behavior change)

Checklist

  • I have read the contribution guidelines.
  • Tests cover this change and pass locally.
  • Public APIs and non-trivial functions are documented.
  • N/A — internal benchmark tooling + loader telemetry, not user-facing.

A/B-comparison load benchmark for HuggingFace-format safetensors checkpoints:
the SafetensorLoadBenchmark generator, per-tensor sharding staged by prepare.py
(leading-dim FSDP / inner-dim TP), tiered model configs (GPT-2 through
DeepSeek-V3 / Llama-3-405B), per-host correctness digests, and a run_suite.sh
tier runner. The loader reports its per-host read accounting via jax.monitoring,
so per-host bytes/reads land in the card without TensorStore counters.

The current (transient-array) loader now emits
/jax/orbax/read/safetensors/bytes_read and num_reads per host, the same channel
the sharding-driven loader uses, so the load benchmark's per-host bytes/reads
card populates for the baseline too and an A/B run compares like-for-like.
storage_reads is omitted: each per-host bundle is one contiguous read, so it
would equal num_reads.
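
To illustrate the accounting described above, here is a minimal sketch of per-host read counters alongside a SHA-256 digest. It is not the loader's code: a plain dict stands in for the jax.monitoring channel, and `load_host_bundle`, the 1024-byte in-memory checkpoint, and the even four-way split are hypothetical.

```python
import hashlib
import io


def load_host_bundle(stream, start, length, counters):
    """Read one contiguous byte range, update counters, return its digest."""
    stream.seek(start)
    data = stream.read(length)
    counters["bytes_read"] += len(data)
    counters["num_reads"] += 1  # one contiguous read per bundle
    return hashlib.sha256(data).hexdigest()


# Stand-in checkpoint payload; real bundles would come from safetensors files.
payload = bytes(i % 251 for i in range(1024))
checkpoint = io.BytesIO(payload)

num_hosts = 4
bundle_size = len(payload) // num_hosts

per_host = {}
for host in range(num_hosts):
    counters = {"bytes_read": 0, "num_reads": 0}
    digest = load_host_bundle(checkpoint, host * bundle_size, bundle_size, counters)
    per_host[host] = {"counters": counters, "digest": digest}
```

With one contiguous read per host, a separate storage-level read count would always match num_reads, which is why a third counter adds nothing here. In an A/B run, matching per-host digests between the baseline and candidate loaders would indicate that both read the same bytes.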
github-actions bot added the "pull ready" label (Ready to be pulled from GitHub into Google) on Jun 22, 2026