[tx] [WIP] Support LoRA in the unembedding layer #969
pcmoritz wants to merge 1 commit into NovaSky-AI:main
Conversation
@pcmoritz is attempting to deploy a commit to the Tyler's projects Team on Vercel. A member of the Team first needs to authorize it.
Code Review
This pull request adds support for LoRA in the unembedding layer, which is a great addition for models with tied embeddings. The implementation in LoRAEmbed.T correctly applies the transposed LoRA weights, and the check that previously prevented LoRA here has been removed accordingly. My main feedback is about code duplication between the new _unembed function and the existing apply_lora method: refactoring the common logic for handling multi-adapter inputs would improve the code's maintainability. Since this is a work in progress, it's a good time to consider such a refactoring.
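For context on the transposed factors in the diff below: with tied embeddings the unembedding is `h @ E.T`, so a LoRA update `scaling * A @ B` to the embedding matrix shows up on the logits as `scaling * (h @ B.T) @ A.T`, which is the factor order the snippet uses. A minimal numerical sketch of that identity; the factor shapes `A: (vocab, r)` and `B: (r, hidden)` are assumptions inferred from how the snippet contracts dimensions, not taken from the PR:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
k1, k2, k3, k4 = random.split(key, 4)

vocab, hidden, r = 11, 7, 3
E = random.normal(k1, (vocab, hidden))   # tied embedding / unembedding matrix
A = random.normal(k2, (vocab, r))        # assumed LoRA factor shapes
B = random.normal(k3, (r, hidden))
h = random.normal(k4, (5, hidden))       # a few hidden states
scaling = 0.5

# Unembedding against the LoRA-adapted embedding matrix ...
logits_full = h @ (E + scaling * A @ B).T
# ... equals the base logits plus the transposed-factor correction.
logits_split = h @ E.T + scaling * (h @ B.T) @ A.T

assert jnp.allclose(logits_full, logits_split, atol=1e-5)
```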
```python
def _unembed(hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
    # Base logits from the tied (transposed) embedding matrix.
    base_out = hidden_states @ self.embedding[...].T

    if self.max_lora_adapters == 0 or adapter_indices is None:
        return base_out

    if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
        raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

    batch_size, seq_len, hidden_dim = hidden_states.shape
    assert hidden_dim == self.embedding[...].shape[1]
    assert adapter_indices.shape[0] == batch_size

    # Flatten to (tokens, hidden) and expand the per-sequence adapter ids to per-token ids.
    hidden_flat = hidden_states.reshape(-1, hidden_dim)
    adapter_indices_expanded = jnp.repeat(adapter_indices, seq_len)

    # Group tokens by adapter so ragged_dot can process each adapter's rows contiguously.
    hidden_sorted, group_sizes, unsort_indices, _ = prepare_routing(
        hidden_flat, adapter_indices_expanded, self.max_lora_adapters
    )

    # Apply LoRA using ragged_dot: x @ B^T @ A^T
    lora_B_T = jnp.swapaxes(self.lora_B[...], -1, -2)
    intermediate = jax.lax.ragged_dot(hidden_sorted, lora_B_T, group_sizes)
    lora_A_T = jnp.swapaxes(self.lora_A[...], -1, -2)
    lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_A_T, group_sizes)

    # Restore the original token order, reshape, and apply the per-adapter scaling.
    lora_output = lora_output_sorted[unsort_indices].reshape(batch_size, seq_len, -1)
    lora_output = lora_output * self.lora_scaling[...][adapter_indices, None, None]
    return base_out + lora_output

return _unembed
```
This implementation of LoRA for the unembedding layer has significant code duplication with the apply_lora method, particularly the logic for handling batched adapters (flattening, routing, sorting, and scaling). To improve maintainability, consider refactoring the common logic into a shared helper function. This function could handle the adapter routing and apply a generic sequence of matrix multiplications, which would be x @ A @ B for the forward pass and x @ B.T @ A.T for the unembedding pass.
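As a rough illustration of the suggested refactor (a sketch, not code from this PR): a layout-agnostic helper could take flattened activations, per-token adapter ids, and an ordered sequence of per-adapter matrices, so the two call sites would differ only in which matrices they pass. The routing below is a self-contained stand-in for `prepare_routing`, and the helper name is hypothetical:

```python
import jax
import jax.numpy as jnp


def apply_lora_ragged(x_flat: jax.Array,
                      adapter_indices: jax.Array,
                      mats: tuple[jax.Array, ...],
                      num_adapters: int) -> jax.Array:
    """Hypothetical shared helper (sketch).

    x_flat:          (tokens, d_in) activations, already flattened
    adapter_indices: (tokens,) integer adapter id per token
    mats:            per-adapter matrices, each (num_adapters, d_k, d_{k+1}),
                     applied left to right
    """
    # Stand-in for prepare_routing: group rows by adapter so each adapter's
    # tokens are contiguous, as jax.lax.ragged_dot requires.
    sort_indices = jnp.argsort(adapter_indices)
    unsort_indices = jnp.argsort(sort_indices)
    group_sizes = jnp.bincount(adapter_indices, length=num_adapters).astype(jnp.int32)

    out = x_flat[sort_indices]
    for mat in mats:
        out = jax.lax.ragged_dot(out, mat, group_sizes)

    # Undo the sort so the result lines up with the caller's token order.
    return out[unsort_indices]
```

Both call sites would then reduce to something like `apply_lora_ragged(x, ids, (lora_A, lora_B), n)` for the forward pass and `apply_lora_ragged(x, ids, (jnp.swapaxes(lora_B, -1, -2), jnp.swapaxes(lora_A, -1, -2)), n)` for the unembedding pass, with the per-adapter scaling applied by the caller; this assumes the factor shapes line up as the review comment's `x @ A @ B` suggests.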
We simplify the LoRAMixin by disentangling the linear and embedding cases and handling them in LoRAMixin and LoRAEmbed respectively. This is in preparation for supporting LoRA with the flattened (tokens, *dims) layout in addition to (batch, seq_len, *dims), and also for #969.
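A small sketch of how the two layouts could feed the same downstream routing (names are illustrative, not from the codebase); the only real difference is whether adapter ids need to be expanded from per-sequence to per-token:

```python
import jax
import jax.numpy as jnp


def flatten_batched(hidden_states: jax.Array, adapter_indices: jax.Array):
    """Batched layout: hidden_states is (batch, seq_len, hidden), one adapter id per sequence."""
    batch, seq_len, hidden = hidden_states.shape
    hidden_flat = hidden_states.reshape(-1, hidden)    # (batch * seq_len, hidden)
    per_token = jnp.repeat(adapter_indices, seq_len)   # (batch * seq_len,)
    return hidden_flat, per_token


def flatten_tokens(hidden_states: jax.Array, adapter_indices: jax.Array):
    """Flattened layout: hidden_states is already (tokens, hidden) with per-token adapter ids."""
    return hidden_states, adapter_indices
```

Everything after this point (routing, ragged_dot, unsorting, scaling) could then stay layout-agnostic.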
Replaced by #984
Not yet ready to merge but that's the general direction.