Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions skyrl-tx/tx/layers/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ def init_lora(
rngs=rngs,
)

def _apply_lora_weight(
self,
lora_weight: jax.Array,
x_sorted: jax.Array,
adapter_indices_sorted: jax.Array,
Comment thread
pcmoritz marked this conversation as resolved.
group_sizes: jax.Array,
) -> jax.Array:
"""Apply a LoRA weight matrix to input. Default is linear case: x @ weight.

Subclasses (e.g., LoRAEmbed) override this for different computation patterns.
"""
assert lora_weight.ndim == 3
assert x_sorted.ndim == 2 # (tokens, in_features)
assert x_sorted.shape[1] == lora_weight.shape[1]
return jax.lax.ragged_dot(x_sorted, lora_weight, group_sizes)

def apply_lora(
self,
x: jax.Array,
Expand All @@ -73,8 +89,6 @@ def apply_lora(
raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

(batch_size, seq_len, *dims) = x.shape
assert len(self.lora_A.shape) == 3
assert len(dims) == 0 if isinstance(self, nnx.Embed) else tuple(dims) == self.lora_A[...].shape[1:-1]
assert adapter_indices.shape[0] == batch_size

x_flat = x.reshape(-1, *dims)
Expand All @@ -85,13 +99,8 @@ def apply_lora(
x_flat, adapter_indices_expanded, self.max_lora_adapters, adapter_indices=adapter_indices_expanded
)

# Apply LoRA using ragged_dot: x @ A @ B
if isinstance(self, nnx.Embed):
# Embedding path: A[x]
intermediate = self.lora_A[...][adapter_indices_sorted, x_sorted, :]
else:
# Linear path: x @ A
intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A[...], group_sizes)
# Apply LoRA: x @ A @ B (or A[x] @ B for embeddings)
intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)

# Unsort, reshape, scale
Expand Down Expand Up @@ -141,6 +150,18 @@ def __init__(
rngs=rngs,
)

def _apply_lora_weight(
self,
lora_weight: jax.Array,
x_sorted: jax.Array,
adapter_indices_sorted: jax.Array,
group_sizes: jax.Array,
Comment thread
pcmoritz marked this conversation as resolved.
) -> jax.Array:
"""For embeddings, lookup in weight instead of matmul: weight[adapter, token_id, :]."""
assert lora_weight.ndim == 3
assert x_sorted.ndim == 1 # (tokens,) integer indices
return lora_weight[adapter_indices_sorted, x_sorted, :]

def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
    """Embed ``x`` with the base table, then add the per-adapter LoRA delta."""
    return self.apply_lora(x, super().__call__(x), adapter_indices)
Expand Down
Loading