Fix LoRA sharding compatibility with JAX 0.9+ #1040
jaredquincy wants to merge 1 commit into NovaSky-AI:main from
Conversation
JAX 0.9 requires explicit out_sharding for gather/repeat operations on tensor-parallel sharded arrays. Without this, LoRA training and inference crash with ShardingTypeError on any model using TP > 1.

Three root causes fixed:

1. LoRAEmbed.__call__: nnx.Embed.__call__ uses jnp.take(), which does not accept out_sharding. Replace it with an explicit .at[x].get() and store the embedding sharding spec during __init__ (traced values inside JIT lack the .sharding attribute).
2. LoRAMixin.apply_lora: jnp.repeat(adapter_indices, seq_len) needs out_sharding=P(None) on JAX 0.9+ (see the sketch below).
3. The LoRA A embedding gather and the LoRAExpert expert index repeat also need explicit out_sharding for the same reason.

Tested against Qwen/Qwen3-0.6B on an A100 80GB with JAX 0.9.0.1.
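As a rough illustration of the repeat fix in point 2, here is a minimal sketch; the helper name expand_adapter_indices is hypothetical and not part of tx/layers/lora.py, and it assumes adapter_indices is a 1-D array of per-sequence adapter ids and seq_len is a static int.

```python
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P


def expand_adapter_indices(adapter_indices, seq_len):
    # Replicate each per-sequence adapter id across the sequence dimension.
    # Per this PR, jnp.repeat on a sharded input needs an explicit
    # out_sharding on JAX 0.9+; the replicated result is unsharded, hence P(None).
    return jnp.repeat(adapter_indices, seq_len, out_sharding=P(None))
```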
Code Review
This pull request addresses compatibility issues with JAX 0.9+ by explicitly providing out_sharding for gather and repeat operations, which is now required for sharded arrays. The changes are correct and necessary to prevent crashes with tensor parallelism. I have one suggestion to improve the robustness of sharding specification handling in LoRAEmbed.
```python
embedding = self.embedding.value
spec = self._embed_sharding_spec
out_spec = P(*([None] * x.ndim), spec[1])
base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
```
The current implementation caches the sharding spec in __init__ and reuses it in __call__. While this is necessary for JIT-compiled contexts where .sharding is unavailable on traced values, it could lead to using a stale sharding spec if the embedding's sharding is modified after initialization (e.g., via jax.device_put).
A more robust approach is to dynamically access the live sharding spec when possible and only fall back to the cached value inside a JIT trace. This ensures the correct sharding is always used, preventing potential errors or silent correctness issues.
Suggested change:

```diff
-embedding = self.embedding.value
-spec = self._embed_sharding_spec
-out_spec = P(*([None] * x.ndim), spec[1])
-base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
+embedding = self.embedding.value
+try:
+    # Access live sharding spec when not in a JIT context
+    spec = embedding.sharding.spec
+except AttributeError:
+    # Fallback to cached spec inside a JIT trace
+    spec = self._embed_sharding_spec
+out_spec = P(*([None] * x.ndim), spec[1])
+base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
```
@jaredquincy Thanks a lot for submitting this PR, unfortunately I'm not able to reproduce the issue you are describing. I'm using the current main. There were a few fixes recently (#965, #966) that should fix Jax 0.9 compatibility, and we also run Jax 0.9 in the CI. Let me know if the commit you tried already has them included (and which commit you used). If you give me a concrete commit, we can also exchange
Summary
JAX 0.9 requires explicit out_sharding for gather/repeat operations on tensor-parallel sharded arrays. Without this fix, LoRA training and inference crash with ShardingTypeError on any model using tensor parallelism.

Three root causes in tx/layers/lora.py:

1. LoRAEmbed.__call__: nnx.Embed.__call__ uses jnp.take(), which does not accept out_sharding. Replaced with an explicit .at[x].get(out_sharding=...). The embedding PartitionSpec is now stored during __init__ since traced values inside JIT lack the .sharding attribute (a sketch of this pattern follows the list).
2. LoRAMixin.apply_lora: jnp.repeat(adapter_indices, seq_len) needs out_sharding=P(None) on JAX 0.9+.
3. LoRA A embedding gather and LoRAExpert expert index repeat: same out_sharding requirement.
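For illustration only, here is a minimal sketch of the pattern from point 1, not the actual tx/layers/lora.py code: the class name ShardedEmbed, its constructor arguments, and the P(None, None) fallback are assumptions; the point is caching the embedding PartitionSpec eagerly and gathering with an explicit out_sharding.

```python
import jax.numpy as jnp
from flax import nnx
from jax.sharding import PartitionSpec as P


class ShardedEmbed(nnx.Module):  # illustrative stand-in, not the real LoRAEmbed
    def __init__(self, num_embeddings: int, features: int, *, dtype=jnp.float32, rngs: nnx.Rngs):
        self.dtype = dtype
        self.embedding = nnx.Param(
            nnx.initializers.normal(stddev=0.02)(rngs.params(), (num_embeddings, features))
        )
        # Traced values inside jit have no usable .sharding, so record the
        # PartitionSpec eagerly; fall back to a replicated spec if none is set.
        sharding = getattr(self.embedding.value, "sharding", None)
        self._embed_sharding_spec = (
            sharding.spec if hasattr(sharding, "spec") else P(None, None)
        )

    def __call__(self, x):
        embedding = self.embedding.value
        spec = self._embed_sharding_spec
        # Token dimensions stay replicated; the feature dimension keeps the
        # same placement as the embedding table's second axis.
        out_spec = P(*([None] * x.ndim), spec[1])
        return embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
```

In the real module the spec would come from the model's tensor-parallel mesh; the fallback here only keeps the sketch runnable on a single device.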
Error reproduced

Test plan
prepare_routing — separate fix may be needed)