Fix BFloat16 conversion error in eager checkpoint loading #3347
phu0ngng wants to merge 2 commits into AI-Hypercomputer:main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Thanks for the fix!
I believe this error depends on the transformers version:
- With older transformers versions, `from_pretrained` loads bf16 checkpoints in fp32 by default, so `v.numpy()` passes.
- With newer transformers versions, the bf16 checkpoint is loaded as-is, and hence `v.numpy()` errors. Yes, your fix `v.float().numpy()` will help.
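The dtype behavior above can be reproduced without transformers at all; a minimal sketch, assuming PyTorch is installed (variable names are illustrative, not from the PR):

```python
# Minimal repro sketch: torch.Tensor.numpy() rejects bfloat16 tensors.
import torch

v = torch.ones(2, 2, dtype=torch.bfloat16)  # stand-in for a bf16 checkpoint weight

raised = False
try:
    v.numpy()  # raises TypeError: Got unsupported ScalarType BFloat16
except TypeError:
    raised = True

# The fix: upcast to float32 first, which numpy supports.
arr = v.float().numpy()
```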
Lastly, we are working more on loading & type conversion for `to_maxtext` in #3184.
```diff
- # Convert all to numpy immediately in eager mode
+ # Convert all to numpy immediately in eager mode.
+ # torch.Tensor.numpy() does not support bfloat16, so cast to float32 first.
+ import torch  # pylint: disable=g-import-not-at-top
```
move this import to the top?
We can, but then torch would be required in the lazy loading path as well, even though it is not needed there. I think we should keep the current implementation to avoid a torch requirement for lazy loading.
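The deferred-import approach discussed in this thread could be sketched as a small helper (the name `to_numpy_safe` is hypothetical; the PR inlines this logic rather than defining a function):

```python
def to_numpy_safe(v):
    """Convert a torch tensor to a numpy array, upcasting bfloat16 to float32.

    torch is imported lazily here so that the lazy checkpoint loading path,
    which never calls this helper, does not require torch to be installed.
    """
    import torch  # pylint: disable=g-import-not-at-top

    if v.dtype == torch.bfloat16:
        v = v.float()  # torch.Tensor.numpy() does not support bfloat16
    return v.numpy()
```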
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Thanks for your patience! We revamped the loading/conversion/save pipeline for `to_maxtext`. The current hf -> maxtext behavior would be:
Could you verify if the latest code solves your original issue? Thanks!
Description
`to_maxtext.py` crashes with `TypeError: Got unsupported ScalarType BFloat16` when converting bf16 HuggingFace models (e.g. Qwen3, Llama 3) in eager mode. PyTorch's `.numpy()` doesn't support bfloat16 tensors. This fix casts bf16 tensors to float32 before the numpy conversion. The lazy loading path is unaffected (safetensors handles this internally).
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.