[tx] [WIP] Support LoRA in the unembedding layer #969

Closed

pcmoritz wants to merge 1 commit into NovaSky-AI:main from pcmoritz:tx-lora-unembed

Conversation

@pcmoritz
Collaborator

Not yet ready to merge but that's the general direction.

pcmoritz added the tx label Jan 27, 2026
@vercel

vercel bot commented Jan 27, 2026

@pcmoritz is attempting to deploy a commit to the Tyler's projects Team on Vercel.

A member of the Team first needs to authorize it.

Contributor

gemini-code-assist bot left a comment

Code Review

This pull request adds support for LoRA in the unembedding layer, which is a great addition for models with tied embeddings. The implementation in LoRAEmbed.T correctly applies the transposed LoRA weights, and the check that previously prevented this is correctly removed. My main feedback is about code duplication between the new _unembed function and the existing apply_lora method: refactoring the common logic for handling multi-adapter inputs would improve maintainability. Since this is a work in progress, it is a good time to consider such a refactoring.

Comment on lines +151 to +181
    def _unembed(hidden_states: jax.Array, adapter_indices: jax.Array | None = None) -> jax.Array:
        # Tied-weights unembedding: project back to vocabulary logits via the transposed embedding.
        base_out = hidden_states @ self.embedding[...].T

        if self.max_lora_adapters == 0 or adapter_indices is None:
            return base_out

        if self.lora_A is None or self.lora_B is None or self.lora_scaling is None:
            raise RuntimeError("LoRA parameters are not initialized. `init_lora` must be called.")

        batch_size, seq_len, hidden_dim = hidden_states.shape
        assert hidden_dim == self.embedding[...].shape[1]
        assert adapter_indices.shape[0] == batch_size

        # Flatten to (batch * seq_len, hidden_dim) and give every token its sequence's adapter index.
        hidden_flat = hidden_states.reshape(-1, hidden_dim)
        adapter_indices_expanded = jnp.repeat(adapter_indices, seq_len)

        # Group tokens by adapter so each group is multiplied by its own LoRA weights.
        hidden_sorted, group_sizes, unsort_indices, _ = prepare_routing(
            hidden_flat, adapter_indices_expanded, self.max_lora_adapters
        )

        # Apply LoRA using ragged_dot: x @ B^T @ A^T
        lora_B_T = jnp.swapaxes(self.lora_B[...], -1, -2)
        intermediate = jax.lax.ragged_dot(hidden_sorted, lora_B_T, group_sizes)
        lora_A_T = jnp.swapaxes(self.lora_A[...], -1, -2)
        lora_output_sorted = jax.lax.ragged_dot(intermediate, lora_A_T, group_sizes)

        # Restore the original token order and apply per-adapter scaling.
        lora_output = lora_output_sorted[unsort_indices].reshape(batch_size, seq_len, -1)
        lora_output = lora_output * self.lora_scaling[...][adapter_indices, None, None]
        return base_out + lora_output

    return _unembed
Contributor

medium

This implementation of LoRA for the unembedding layer has significant code duplication with the apply_lora method, particularly the logic for handling batched adapters (flattening, routing, sorting, and scaling). To improve maintainability, consider refactoring the common logic into a shared helper function. This function could handle the adapter routing and apply a generic sequence of matrix multiplications, which would be x @ A @ B for the forward pass and x @ B.T @ A.T for the unembedding pass.
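
A minimal sketch of the shared helper suggested above, written as a method on the same class. The name `_routed_lora_matmul` is hypothetical; `prepare_routing`, `jax.lax.ragged_dot`, and the attribute names are taken from this diff, and the same `jax`/`jnp` imports are assumed to be in scope.

    def _routed_lora_matmul(self, hidden_states: jax.Array, adapter_indices: jax.Array, mats) -> jax.Array:
        """Route tokens to their adapters and apply a chain of ragged matmuls.

        `mats` is a sequence of per-adapter weight stacks shaped
        (max_lora_adapters, in_dim, out_dim): e.g. (A, B) for the forward
        pass and (B^T, A^T) for the unembedding pass.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape
        hidden_flat = hidden_states.reshape(-1, hidden_dim)
        adapter_indices_expanded = jnp.repeat(adapter_indices, seq_len)

        # Sort tokens into per-adapter groups so ragged_dot can use one weight matrix per group.
        hidden_sorted, group_sizes, unsort_indices, _ = prepare_routing(
            hidden_flat, adapter_indices_expanded, self.max_lora_adapters
        )

        out = hidden_sorted
        for mat in mats:
            out = jax.lax.ragged_dot(out, mat, group_sizes)

        # Undo the routing sort and apply per-adapter scaling.
        out = out[unsort_indices].reshape(batch_size, seq_len, -1)
        return out * self.lora_scaling[...][adapter_indices, None, None]

With such a helper, the LoRA branch of `_unembed` above could reduce to something like:

    lora_output = self._routed_lora_matmul(
        hidden_states,
        adapter_indices,
        (jnp.swapaxes(self.lora_B[...], -1, -2), jnp.swapaxes(self.lora_A[...], -1, -2)),
    )
    return base_out + lora_output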

pcmoritz added a commit that referenced this pull request Jan 28, 2026
We simplify the LoRAMixin by disentangling the linear and embedding cases
and handling them in LoRAMixin and LoRAEmbed respectively. This is in
preparation for supporting LoRA for the flattened layout (tokens, *dims)
in addition to (batch, seq_len, *dims), and also for
#969
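
As context for the layout change mentioned in that commit message, here is a rough sketch of how both layouts could be normalized before adapter routing; the function name `_flatten_for_lora` is hypothetical and not taken from the commit.

    def _flatten_for_lora(x: jax.Array, adapter_indices: jax.Array):
        """Normalize input to (tokens, hidden) with one adapter index per token."""
        if x.ndim == 2:
            # Flattened layout (tokens, hidden): adapter_indices is assumed to already be per token.
            return x, adapter_indices
        # Batched layout (batch, seq_len, hidden): expand per-sequence indices to per-token indices.
        batch_size, seq_len, hidden_dim = x.shape
        x_flat = x.reshape(-1, hidden_dim)
        idx_flat = jnp.repeat(adapter_indices, seq_len)
        return x_flat, idx_flat
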
@pcmoritz
Collaborator Author

Replaced by #984

pcmoritz closed this Jan 28, 2026