Fix LoRA sharding compatibility with JAX 0.9+ #1040

Open

jaredquincy wants to merge 1 commit into NovaSky-AI:main from jaredquincy:fix/jax09-lora-sharding

Conversation

@jaredquincy
Contributor

Summary

JAX 0.9 requires explicit out_sharding for gather/repeat operations on tensor-parallel sharded arrays. Without this fix, LoRA training and inference crash with ShardingTypeError on any model using tensor parallelism.

Three root causes in tx/layers/lora.py:

  1. LoRAEmbed.__call__: nnx.Embed.__call__ uses jnp.take() which does not accept out_sharding. Replaced with explicit .at[x].get(out_sharding=...). The embedding PartitionSpec is now stored during __init__ since traced values inside JIT lack the .sharding attribute.

  2. LoRAMixin.apply_lora: jnp.repeat(adapter_indices, seq_len) needs out_sharding=P(None) on JAX 0.9+.

  3. LoRA A embedding gather and LoRAExpert expert index repeat: same out_sharding requirement (see the condensed sketch after this list).
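
A condensed sketch of the two patterns applied here, simplified from tx/layers/lora.py: the helper names and the "tp" axis are illustrative, and P is jax.sharding.PartitionSpec.

import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

def embed_lookup(embedding, x, embed_spec):
    # Fixes 1 and 3: replace jnp.take (which has no out_sharding parameter)
    # with an indexed gather whose output PartitionSpec is stated explicitly.
    # embed_spec is the embedding PartitionSpec cached during __init__,
    # e.g. P("tp", None); the gather output reuses only its last-dim entry.
    out_spec = P(*([None] * x.ndim), embed_spec[1])
    return embedding.at[x].get(out_sharding=out_spec)

def expand_adapter_indices(adapter_indices, seq_len):
    # Fix 2: on JAX 0.9+, repeating a sharded index array needs an explicit
    # (here replicated) output spec.
    return jnp.repeat(adapter_indices, seq_len, out_sharding=P(None))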

Error reproduced

jax._src.interpreters.pxla.ShardingTypeError:
  Use `.at[...].get(out_sharding=)` to provide output PartitionSpec for the gather indexing
  operand=ShapedArray(bfloat16[151936@tp,1024]), indices=ShapedArray(int32[1,32,1])
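
For context, a minimal reproduction sketch (not taken from the PR) using JAX's explicit-sharding mode; it assumes at least two local devices, and the "tp" axis name and shapes simply mirror the error above.

import jax
import jax.numpy as jnp
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Single-axis mesh with explicit (sharding-in-types) semantics.
mesh = jax.make_mesh((jax.device_count(),), ("tp",), axis_types=(AxisType.Explicit,))
jax.sharding.set_mesh(mesh)

# Embedding table sharded over the vocab dimension; token ids replicated.
table = jax.device_put(
    jnp.zeros((151936, 1024), dtype=jnp.bfloat16),
    NamedSharding(mesh, P("tp", None)),
)
ids = jnp.zeros((1, 32), dtype=jnp.int32)

@jax.jit
def lookup(table, ids):
    # jnp.take(table, ids, axis=0) trips ShardingTypeError here because the
    # gather's output PartitionSpec is unspecified; the indexed form states it.
    return table.at[ids].get(out_sharding=P(None, None, None))

print(lookup(table, ids).shape)  # (1, 32, 1024)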

Test plan

  • Tested embedding gather fix against Qwen/Qwen3-0.6B on A100 80GB with JAX 0.9.0.1
  • Verified healthz, create_model, forward_backward (wire format), optim_step, save_weights, load_weights, and asample endpoints all accept requests
  • Full forward-backward execution with LoRA training is still blocked by cascading sharding issues in prepare_routing; a separate fix may be needed


@gemini-code-assist bot left a comment


Code Review

This pull request addresses compatibility issues with JAX 0.9+ by explicitly providing out_sharding for gather and repeat operations, which is now required for sharded arrays. The changes are correct and necessary to prevent crashes with tensor parallelism. I have one suggestion to improve the robustness of sharding specification handling in LoRAEmbed.

Comment on lines +151 to +154
embedding = self.embedding.value
spec = self._embed_sharding_spec
out_spec = P(*([None] * x.ndim), spec[1])
base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)


high

The current implementation caches the sharding spec in __init__ and reuses it in __call__. While this is necessary for JIT-compiled contexts where .sharding is unavailable on traced values, it could lead to using a stale sharding spec if the embedding's sharding is modified after initialization (e.g., via jax.device_put).

A more robust approach is to dynamically access the live sharding spec when possible and only fall back to the cached value inside a JIT trace. This ensures the correct sharding is always used, preventing potential errors or silent correctness issues.

Suggested change

    # Current lines (+151 to +154):
    embedding = self.embedding.value
    spec = self._embed_sharding_spec
    out_spec = P(*([None] * x.ndim), spec[1])
    base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)

    # Suggested replacement:
    embedding = self.embedding.value
    try:
        # Access live sharding spec when not in a JIT context
        spec = embedding.sharding.spec
    except AttributeError:
        # Fall back to the cached spec inside a JIT trace
        spec = self._embed_sharding_spec
    out_spec = P(*([None] * x.ndim), spec[1])
    base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)

@pcmoritz
Collaborator

pcmoritz commented Feb 7, 2026

@jaredquincy Thanks a lot for submitting this PR; unfortunately, I'm not able to reproduce the issue you are describing. I'm using the current main branch and also JAX 0.9.0.1.

There were a few fixes recently (#965, #966) that should address JAX 0.9 compatibility, and we also run JAX 0.9 in CI. Let me know if the commit you tried already has them included (and which commit you used).

If you give me a concrete commit, we can also exchange uv.lock files so we can exactly reproduce each other's environment and see what is going on :)

@pcmoritz added the tx label on Feb 7, 2026