26 changes: 21 additions & 5 deletions skyrl-tx/tx/layers/lora.py
@@ -1,6 +1,7 @@
from flax import nnx
import jax
from jax import numpy as jnp
from jax.sharding import PartitionSpec as P

from tx.utils.models import filter_lora
from tx.layers.util import Param, prepare_routing
@@ -74,7 +75,7 @@ def apply_lora(
assert adapter_indices.shape[0] == batch_size

x_flat = x.reshape(-1, *dims)
adapter_indices_expanded = jnp.repeat(adapter_indices, seq_len)
adapter_indices_expanded = jnp.repeat(adapter_indices, seq_len, out_sharding=P(None))

# Sort tokens to prepare for ragged_dot
x_sorted, group_sizes, unsort_indices, adapter_indices_sorted = prepare_routing(
@@ -83,8 +84,10 @@

# Apply LoRA using ragged_dot: x @ A @ B
if isinstance(self, nnx.Embed):
# Embedding path: A[x]
intermediate = self.lora_A.value[adapter_indices_sorted, x_sorted, :]
# Embedding path: A[x] — explicit out_sharding for JAX 0.9+ gather
intermediate = self.lora_A.value.at[adapter_indices_sorted, x_sorted, :].get(
out_sharding=P(None, None)
)
else:
# Linear path: x @ A
intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A.value, group_sizes)
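For readers unfamiliar with ragged_dot, here is a minimal standalone sketch (not part of this diff; shapes and values are made up) of how the linear path multiplies sorted tokens by per-adapter A matrices using group sizes:

import jax
from jax import numpy as jnp

# Illustrative shapes only: 6 tokens, hidden size 4, LoRA rank 8, 3 adapters.
x_sorted = jnp.ones((6, 4))         # tokens already sorted by adapter id
lora_A = jnp.ones((3, 4, 8))        # one A matrix per adapter
group_sizes = jnp.array([2, 1, 3])  # tokens per adapter, sums to 6

# Each contiguous block of rows in x_sorted is multiplied by its adapter's A.
intermediate = jax.lax.ragged_dot(x_sorted, lora_A, group_sizes)
print(intermediate.shape)  # (6, 8)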
@@ -126,6 +129,10 @@ def __init__(
), "LoRAEmbed layer needs sharding, you can specify it by using nnx.with_partitioning on the embedding_init"
sharding = self.embedding.value.sharding.spec

# Store sharding spec eagerly. During JIT, self.embedding.value is a
# traced ShapedArray without a .sharding attribute.
self._embed_sharding_spec = sharding

self.init_lora(
max_lora_adapters=max_lora_adapters,
max_lora_rank=max_lora_rank,
@@ -138,7 +145,13 @@ def __init__(
)

def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
base_out = super().__call__(x)
# nnx.Embed.__call__ uses jnp.take() which lacks out_sharding support
# for tensor-parallel sharded embedding tables on JAX 0.9+.
# Use explicit .at[].get(out_sharding=...) instead.
embedding = self.embedding.value
spec = self._embed_sharding_spec
out_spec = P(*([None] * x.ndim), spec[1])
base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
Comment on lines +151 to +154
Contributor


Severity: high

The current implementation caches the sharding spec in __init__ and reuses it in __call__. While this is necessary for JIT-compiled contexts where .sharding is unavailable on traced values, it could lead to using a stale sharding spec if the embedding's sharding is modified after initialization (e.g., via jax.device_put).

A more robust approach is to dynamically access the live sharding spec when possible and only fall back to the cached value inside a JIT trace. This ensures the correct sharding is always used, preventing potential errors or silent correctness issues.

Suggested change
-        embedding = self.embedding.value
-        spec = self._embed_sharding_spec
-        out_spec = P(*([None] * x.ndim), spec[1])
-        base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
+        embedding = self.embedding.value
+        try:
+            # Access live sharding spec when not in a JIT context
+            spec = embedding.sharding.spec
+        except AttributeError:
+            # Fallback to cached spec inside a JIT trace
+            spec = self._embed_sharding_spec
+        out_spec = P(*([None] * x.ndim), spec[1])
+        base_out = embedding.at[x].get(out_sharding=out_spec).astype(self.dtype)
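As a quick illustration of the staleness concern (separate from the tx module, runnable on a single device, with a made-up 1-D mesh axis named "model"):

import jax
import numpy as np
from jax import numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("model",))
table = jax.device_put(jnp.zeros((100, 16)), NamedSharding(mesh, P(None, "model")))

cached_spec = table.sharding.spec  # cached eagerly, as __init__ does
table = jax.device_put(table, NamedSharding(mesh, P("model", None)))  # re-sharded later

print(cached_spec)          # PartitionSpec(None, 'model') -- stale
print(table.sharding.spec)  # PartitionSpec('model', None) -- the live spec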

return self.apply_lora(x, base_out, adapter_indices)


Expand Down Expand Up @@ -242,7 +255,10 @@ def __call__(
raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

# Reconstruct expert indices from group_sizes
expert_indices = jnp.repeat(jnp.arange(self.num_experts), group_sizes, total_repeat_length=x.shape[0])
expert_indices = jnp.repeat(
jnp.arange(self.num_experts), group_sizes,
total_repeat_length=x.shape[0], out_sharding=P(None),
)

# Flatten (adapter, expert) into a single routing dimension.
flattened_indices = adapter_indices_sorted * self.num_experts + expert_indices
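For reference, a tiny standalone example (made-up values, not from the repo) of the group_sizes -> expert_indices reconstruction above; the out_sharding=P(None) argument added by this PR only controls output placement and is omitted here so the snippet runs on a single default device:

from jax import numpy as jnp

num_experts = 3
group_sizes = jnp.array([2, 0, 3])  # tokens routed to experts 0, 1, 2
expert_indices = jnp.repeat(
    jnp.arange(num_experts), group_sizes, total_repeat_length=int(group_sizes.sum())
)
print(expert_indices)  # [0 0 2 2 2]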