Checkpoint conversion fails on v5litepod TPU - Orbax OCDBT atomic rename ENOENT #3061
Environment
- TPU: v5litepod-32 (32 chips, 8 hosts)
- JAX: 0.9.0
- orbax-checkpoint: 0.11.32
- MaxText: Latest from main branch
- Python: 3.12
Problem
The HuggingFace-to-MaxText checkpoint conversion (llama_or_mistral_ckpt.py) fails on a v5litepod-32 TPU. The per-layer conversion completes successfully, but the Orbax checkpoint save then fails during the atomic file rename.
Related orbax issue: google/orbax#2837
Error
```
ValueError: NOT_FOUND: Error writing "params.params.decoder.decoder_norm.scale/c/0" in OCDBT database
Failed to rename fd: 14 ".../__lock" to: "..."
[OS error 2: ENOENT No such file or directory]
```
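For context on the failure mode, the commit pattern behind this error can be sketched in stdlib terms (the `.__lock` file name here is illustrative, not Orbax's actual layout): data is written under a temporary name and then atomically renamed into place. If the temporary file disappears before the rename, e.g. because another host's cleanup raced it, the rename fails with ENOENT, matching the "OS error 2" in the traceback:

```python
import errno
import os
import tempfile


def atomic_write(path: str, data: bytes) -> None:
    """Write to a temp file, then atomically rename it into place.

    A sketch of the write-then-commit step that fails above; the
    ".__lock" suffix is illustrative, not Orbax's real naming scheme.
    """
    tmp = path + ".__lock"
    with open(tmp, "wb") as f:
        f.write(data)
    os.replace(tmp, path)  # atomic on POSIX; raises if tmp is gone


with tempfile.TemporaryDirectory() as workdir:
    target = os.path.join(workdir, "shard")
    atomic_write(target, b"payload")
    assert open(target, "rb").read() == b"payload"

    # Simulate the race: the temp file vanishes before the rename.
    stale = target + ".__lock"
    try:
        os.replace(stale, target)
    except FileNotFoundError as exc:
        assert exc.errno == errno.ENOENT  # "OS error 2" from the log
```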
Steps to Reproduce
Following the official TPU recipe from tpu-recipes:
```shell
# 1. Create TPU
gcloud compute tpus tpu-vm create my-tpu \
  --zone=us-central1-a \
  --accelerator-type=v5litepod-32 \
  --version=tpu-ubuntu2204-base

# 2. Install MaxText
pip install -e .

# 3. Download model
huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-8B

# 4. Convert (FAILS HERE)
JAX_PLATFORMS=cpu python3 llama_or_mistral_ckpt.py \
  --base-model-path=/path/to/model \
  --maxtext-model-path=gs://bucket/checkpoint \
  --model-size=llama3-8b \
  --huggingface-checkpoint=True
```
Observations
- Layer conversion completes: layers: 100%|██████████| 32/32
- The Orbax save fails during the atomic file rename
- The checkpoint is left incomplete (~11GB of data, but missing tree metadata)
- generate_param_only_checkpoint.py (Step 6) then fails with "No structure could be identified"
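One workaround worth trying (untested on my side, paths illustrative): save the converted checkpoint to local disk first, where the atomic rename is a plain POSIX rename rather than a GCS operation, then copy the finished tree to the bucket in one pass:

```shell
# Convert to local disk instead of GCS (same flags as above,
# only --maxtext-model-path changed).
JAX_PLATFORMS=cpu python3 llama_or_mistral_ckpt.py \
  --base-model-path=/path/to/model \
  --maxtext-model-path=/tmp/checkpoint \
  --model-size=llama3-8b \
  --huggingface-checkpoint=True

# Copy the completed checkpoint tree to the bucket.
gsutil -m cp -r /tmp/checkpoint gs://bucket/checkpoint
```

This sidesteps the rename on GCS but obviously needs enough local disk for the ~11GB checkpoint on the host doing the save.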
Questions
- Is there a known workaround for this issue on multi-host TPUs?
- Does the official TPU recipe work on v5litepod-32, or only on specific TPU types?
- Are there pre-converted MaxText checkpoints available for Llama 3 8B models?
Additional Context
- The official TPU recipe at tpu-recipes/inference/trillium/JetStream-Maxtext/DeepSeek-R1-Distill-Llama-70B references bash setup.sh, which doesn't exist in MaxText
- The recipe recommends JAX 0.5.0, but MaxText's pyproject.toml requires JAX ≥0.8.1
- This suggests the documentation is out of date
Impact
Complete blocker for using MaxText + JetStream serving stack with custom models on v5litepod TPUs.