
Checkpoint conversion fails on v5litepod TPU - Orbax OCDBT atomic rename ENOENT #3061

@jaisong123

Description


Environment

  • TPU: v5litepod-32 (32 chips, 8 hosts)
  • JAX: 0.9.0
  • orbax-checkpoint: 0.11.32
  • MaxText: Latest from main branch
  • Python: 3.12

Problem

The HuggingFace-to-MaxText checkpoint conversion script (llama_or_mistral_ckpt.py) fails on a v5litepod-32 TPU. The per-layer weight conversion completes successfully, but the subsequent Orbax checkpoint save fails during its atomic file rename.

Related orbax issue: google/orbax#2837

Error

ValueError: NOT_FOUND: Error writing "params.params.decoder.decoder_norm.scale/c/0" in OCDBT database
Failed to rename fd: 14 ".../__lock" to: "..."
[OS error 2: ENOENT No such file or directory]
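For context, "OS error 2" is the ordinary POSIX ENOENT from the rename syscall: the __lock source file is already gone by the time the atomic-rename commit step runs. A minimal stdlib sketch of the same failure mode (this is not Orbax code, just an illustration of the rename semantics):

```python
import errno
import os
import tempfile

tmpdir = tempfile.mkdtemp()
src = os.path.join(tmpdir, "__lock")   # never created: mimics the vanished lock file
dst = os.path.join(tmpdir, "committed")

err = None
try:
    os.replace(src, dst)               # atomic rename, analogous to the commit step
except FileNotFoundError as exc:
    err = exc

print(f"OS error {err.errno}: {os.strerror(err.errno)}")
# OS error 2: No such file or directory
```

Whatever deletes or moves the lock file first (a racing host in the multi-host job, or non-POSIX rename semantics on the gs:// destination) would produce exactly this error.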

Steps to Reproduce

Following the official TPU recipe from tpu-recipes:

# 1. Create TPU
gcloud compute tpus tpu-vm create my-tpu \
    --zone=us-central1-a \
    --accelerator-type=v5litepod-32 \
    --version=tpu-ubuntu2204-base

# 2. Install MaxText
pip install -e .

# 3. Download model
huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B

# 4. Convert (FAILS HERE)
JAX_PLATFORMS=cpu python3 llama_or_mistral_ckpt.py \
    --base-model-path=/path/to/model \
    --maxtext-model-path=gs://bucket/checkpoint \
    --model-size=llama3-8b \
    --huggingface-checkpoint=True

Observations

  1. Layer conversion completes: layers: 100%|██████████| 32/32
  2. Orbax save fails during atomic file operations
  3. Checkpoint is left incomplete (~11GB but missing tree metadata)
  4. generate_param_only_checkpoint.py (Step 6) fails with "No structure could be identified"
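To confirm observation 3 without re-running Step 6, the checkpoint can be copied locally (e.g. with gsutil) and summarized. A stdlib-only sketch; the exact metadata filenames Orbax writes are version-dependent, so treat the listing as diagnostic rather than authoritative (the manifest.ocdbt name below is purely illustrative):

```python
import os
import tempfile

def summarize_checkpoint(root: str):
    """Return (total_bytes, sorted top-level entries) for a local checkpoint copy."""
    total = 0
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            total += os.path.getsize(os.path.join(dirpath, name))
    return total, sorted(os.listdir(root))

# Example against a throwaway directory standing in for a downloaded checkpoint:
root = tempfile.mkdtemp()
with open(os.path.join(root, "manifest.ocdbt"), "wb") as f:  # illustrative filename
    f.write(b"\0" * 128)

total, entries = summarize_checkpoint(root)
print(total, entries)   # 128 ['manifest.ocdbt']
```

In the failing run the total comes out around 11GB, but the tree-structure metadata that generate_param_only_checkpoint.py needs is absent, which matches the "No structure could be identified" error.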

Questions

  1. Is there a known workaround for this issue on multi-host TPUs?
  2. Does the official TPU recipe work on v5litepod-32, or only on specific TPU types?
  3. Are there pre-converted MaxText checkpoints available for Llama 3 8B models?

Additional Context

  • The official TPU recipe at tpu-recipes/inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B references a bash setup.sh script that doesn't exist in MaxText
  • The recipe recommends JAX 0.5.0 but MaxText's pyproject.toml requires JAX ≥0.8.1
  • This may indicate the documentation is out of date

Impact

This is a complete blocker for running the MaxText + JetStream serving stack with custom models on v5litepod TPUs.
