@@ -60,6 +60,22 @@ def init_lora(
             rngs=rngs,
         )

+    def _apply_lora_weight(
+        self,
+        lora_weight: jax.Array,
+        x_sorted: jax.Array,
+        adapter_indices_sorted: jax.Array,
+        group_sizes: jax.Array,
+    ) -> jax.Array:
70+ """Apply a LoRA weight matrix to input. Default is linear case: x @ weight.
71+
72+ Subclasses (e.g., LoRAEmbed) override this for different computation patterns.
73+ """
+        assert lora_weight.ndim == 3
+        assert x_sorted.ndim == 2  # (tokens, in_features)
+        assert x_sorted.shape[1] == lora_weight.shape[1]
+        return jax.lax.ragged_dot(x_sorted, lora_weight, group_sizes)
+
     def apply_lora(
         self,
         x: jax.Array,
@@ -73,8 +89,6 @@ def apply_lora(
             raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

         (batch_size, seq_len, *dims) = x.shape
-        assert len(self.lora_A.shape) == 3
-        assert len(dims) == 0 if isinstance(self, nnx.Embed) else tuple(dims) == self.lora_A[...].shape[1:-1]
         assert adapter_indices.shape[0] == batch_size

         x_flat = x.reshape(-1, *dims)
@@ -85,13 +99,8 @@ def apply_lora(
             x_flat, adapter_indices_expanded, self.max_lora_adapters, adapter_indices=adapter_indices_expanded
         )

-        # Apply LoRA using ragged_dot: x @ A @ B
-        if isinstance(self, nnx.Embed):
-            # Embedding path: A[x]
-            intermediate = self.lora_A[...][adapter_indices_sorted, x_sorted, :]
-        else:
-            # Linear path: x @ A
-            intermediate = jax.lax.ragged_dot(x_sorted, self.lora_A[...], group_sizes)
+        # Apply LoRA: x @ A @ B (or A[x] @ B for embeddings)
+        intermediate = self._apply_lora_weight(self.lora_A[...], x_sorted, adapter_indices_sorted, group_sizes)
         lora_output_sorted = jax.lax.ragged_dot(intermediate, self.lora_B[...], group_sizes)

         # Unsort, reshape, scale
@@ -141,6 +150,18 @@ def __init__(
             rngs=rngs,
         )

+    def _apply_lora_weight(
+        self,
+        lora_weight: jax.Array,
+        x_sorted: jax.Array,
+        adapter_indices_sorted: jax.Array,
+        group_sizes: jax.Array,
+    ) -> jax.Array:
160+ """For embeddings, lookup in weight instead of matmul: weight[adapter, token_id, :]."""
+        assert lora_weight.ndim == 3
+        assert x_sorted.ndim == 1  # (tokens,) integer indices
+        return lora_weight[adapter_indices_sorted, x_sorted, :]
+
     def __call__(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
         base_out = super().__call__(x)
         return self.apply_lora(x, base_out, adapter_indices)
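
For context on the pattern the default `_apply_lora_weight` relies on, here is a minimal self-contained sketch of the grouped product x @ A @ B via `jax.lax.ragged_dot`. All shapes and values below are made up for illustration; only `jax.lax.ragged_dot` itself is real JAX API. Rows of `x_sorted` must already be grouped contiguously by adapter id, and `group_sizes[i]` is the number of rows belonging to adapter `i`:

```python
import jax
import jax.numpy as jnp

num_adapters, rank, in_features, out_features = 2, 4, 8, 16
group_sizes = jnp.array([3, 5])  # rows of x_sorted per adapter; sums to 8

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x_sorted = jax.random.normal(k1, (8, in_features))  # tokens pre-sorted by adapter id
lora_A = jax.random.normal(k2, (num_adapters, in_features, rank))
lora_B = jax.random.normal(k3, (num_adapters, rank, out_features))

# ragged_dot contracts each contiguous row group with its own weight slice:
# rows 0:3 use lora_A[0], rows 3:8 use lora_A[1], all in a single call.
intermediate = jax.lax.ragged_dot(x_sorted, lora_A, group_sizes)  # (8, rank)
lora_out = jax.lax.ragged_dot(intermediate, lora_B, group_sizes)  # (8, out_features)
assert lora_out.shape == (8, out_features)
```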
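
And a sketch of what the `LoRAEmbed` override computes instead: a per-token gather `A[adapter, token_id, :]` from the 3-D table, since an embedding lookup has no matmul for `ragged_dot` to group. Again, the shapes and ids here are invented for the example:

```python
import jax
import jax.numpy as jnp

num_adapters, vocab_size, rank = 2, 10, 4
embed_A = jax.random.normal(jax.random.PRNGKey(1), (num_adapters, vocab_size, rank))

token_ids = jnp.array([1, 4, 2, 0, 3, 3, 1, 2])    # (tokens,) integer ids, sorted by adapter
adapter_ids = jnp.array([0, 0, 0, 1, 1, 1, 1, 1])  # adapter owning each token

# Advanced indexing pairs adapter_ids[i] with token_ids[i], yielding one
# rank-sized row per token; the second ragged_dot (with lora_B) then
# proceeds exactly as in the linear case.
intermediate = embed_A[adapter_ids, token_ids, :]  # (8, rank)
assert intermediate.shape == (8, rank)
```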