
Commit 8048778

[tx] Simplify LoRAMixin for linear and embedding case (#983)
We simplify LoRAMixin by disentangling the linear and embedding cases and handling them in LoRAMixin and LoRAEmbed, respectively. This is in preparation for supporting LoRA for the flattened layout (tokens, *dims) in addition to (batch, seq_len, *dims), and also for #969.
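
As an illustration of the pattern described above, here is a minimal sketch with simplified signatures (the `Toy*` names are hypothetical, not the actual tx classes): the mixin provides the linear default, and the embedding subclass overrides it with a gather.

```python
import jax


class ToyLoRAMixin:
    def _apply_lora_weight(self, lora_weight, x_sorted, adapter_indices_sorted, group_sizes):
        # Linear default: per-adapter grouped matmul, x @ weight.
        return jax.lax.ragged_dot(x_sorted, lora_weight, group_sizes)


class ToyLoRAEmbed(ToyLoRAMixin):
    def _apply_lora_weight(self, lora_weight, x_sorted, adapter_indices_sorted, group_sizes):
        # Embedding override: gather rows instead of multiplying, weight[adapter, token_id, :].
        return lora_weight[adapter_indices_sorted, x_sorted, :]
```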
1 parent 911f8ca commit 8048778

1 file changed: skyrl-tx/tx/layers/lora.py (30 additions & 9 deletions)
@@ -60,6 +60,22 @@ def init_lora(
             rngs=rngs,
         )
 
+    def _apply_lora_weight(
+        self,
+        lora_weight: jax.Array,
+        x_sorted: jax.Array,
+        adapter_indices_sorted: jax.Array,
+        group_sizes: jax.Array,
+    ) -> jax.Array:
+        """Apply a LoRA weight matrix to input. Default is linear case: x @ weight.
+
+        Subclasses (e.g., LoRAEmbed) override this for different computation patterns.
+        """
+        assert lora_weight.ndim == 3
+        assert x_sorted.ndim == 2  # (tokens, in_features)
+        assert x_sorted.shape[1] == lora_weight.shape[1]
+        return jax.lax.ragged_dot(x_sorted, lora_weight, group_sizes)
+
     def apply_lora(
         self,
         x: jax.Array,
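
The linear default above leans on `jax.lax.ragged_dot`. A small self-contained example, with made-up shapes and values rather than anything from the repo, shows what the grouped matmul computes: rows of `x_sorted` are split into contiguous per-adapter groups, and each group is multiplied by its adapter's weight.

```python
import jax
import jax.numpy as jnp

num_adapters, in_features, rank = 2, 4, 3
x_sorted = jnp.arange(5 * in_features, dtype=jnp.float32).reshape(5, in_features)  # 5 tokens, sorted by adapter
lora_A = jnp.ones((num_adapters, in_features, rank))                                # one A matrix per adapter
group_sizes = jnp.array([2, 3], dtype=jnp.int32)                                    # 2 tokens for adapter 0, 3 for adapter 1

out = jax.lax.ragged_dot(x_sorted, lora_A, group_sizes)  # (5, rank)

# Reference computation: slice each group and matmul with its adapter's A.
ref = jnp.concatenate([x_sorted[:2] @ lora_A[0], x_sorted[2:] @ lora_A[1]])
assert jnp.allclose(out, ref)
```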
@@ -73,8 +89,6 @@ def apply_lora(
             raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")
 
         (batch_size, seq_len, *dims) = x.shape
-        assert len(self.lora_A.shape) == 3
-        assert len(dims) == 0 if isinstance(self, nnx.Embed) else tuple(dims) == self.lora_A[...].shape[1:-1]
         assert adapter_indices.shape[0] == batch_size
 
         x_flat = x.reshape(-1, *dims)
@@ -85,13 +99,8 @@ def apply_lora(
             x_flat, adapter_indices_expanded, self.max_lora_adapters, adapter_indices=adapter_indices_expanded
         )
 
-        # Apply LoRA using ragged_dot: x @ A @ B
-        if isinstance(self, nnx.Embed):
-            # Embedding path: A[x]
-            intermediate = self.lora_A[...][adapter_indices_sorted, x_sorted, :]
-        else:
-            # Linear path: x @ A
-            intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A[...], group_sizes)
+        # Apply LoRA: x @ A @ B (or A[x] @ B for embeddings)
+        intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
         lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)
 
         # Unsort, reshape, scale
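
The `x_sorted`, `adapter_indices_sorted`, and `group_sizes` consumed here are produced by a sort-by-adapter step earlier in `apply_lora` (only its argument list is visible in this hunk). A hedged sketch of how such a step can be implemented follows; the actual helper in tx may differ.

```python
import jax.numpy as jnp


def sort_tokens_by_adapter(x_flat, adapter_indices, max_lora_adapters):
    # JAX sorts are stable, so tokens of the same adapter end up contiguous
    # and in their original relative order.
    order = jnp.argsort(adapter_indices)
    x_sorted = x_flat[order]
    adapter_indices_sorted = adapter_indices[order]
    # One group per adapter slot; the sizes sum to the number of tokens,
    # which is what jax.lax.ragged_dot expects.
    group_sizes = jnp.bincount(adapter_indices, length=max_lora_adapters).astype(jnp.int32)
    return x_sorted, adapter_indices_sorted, group_sizes, order
```

Inverting `order` afterwards restores the original token order, which corresponds to the "Unsort, reshape, scale" step above.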
@@ -141,6 +150,18 @@ def __init__(
             rngs=rngs,
         )
 
+    def _apply_lora_weight(
+        self,
+        lora_weight: jax.Array,
+        x_sorted: jax.Array,
+        adapter_indices_sorted: jax.Array,
+        group_sizes: jax.Array,
+    ) -> jax.Array:
+        """For embeddings, lookup in weight instead of matmul: weight[adapter, token_id, :]."""
+        assert lora_weight.ndim == 3
+        assert x_sorted.ndim == 1  # (tokens,) integer indices
+        return lora_weight[adapter_indices_sorted, x_sorted, :]
+
     def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
         base_out = super().__call__(x)
         return self.apply_lora(x, base_out, adapter_indices)
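
To make the embedding override concrete, the toy example below (shapes and values are invented) checks that the `(adapter, token_id)` gather is equivalent to multiplying a one-hot encoding of the token ids by that token's adapter-specific A matrix, i.e. it is still x @ A, just computed as a lookup.

```python
import jax
import jax.numpy as jnp

vocab_size, rank, num_adapters = 6, 3, 2
lora_A = jnp.arange(num_adapters * vocab_size * rank, dtype=jnp.float32).reshape(
    num_adapters, vocab_size, rank
)

token_ids = jnp.array([4, 1, 5])  # x_sorted in the embedding case: integer token ids
adapters = jnp.array([0, 0, 1])   # adapter_indices_sorted

lookup = lora_A[adapters, token_ids, :]  # (3, rank): the override's gather

# Reference: one-hot encode the ids and contract against each token's adapter-specific A.
one_hot = jax.nn.one_hot(token_ids, vocab_size)
matmul = jnp.einsum("tv,tvr->tr", one_hot, lora_A[adapters])
assert jnp.allclose(lookup, matmul)
```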
