[tx] General implementation of trainable Hyper Connections#1008
[tx] General implementation of trainable Hyper Connections#1008tanmaysachan wants to merge 40 commits intoNovaSky-AI:mainfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.
My review found a couple of issues:
- An unused
trainableparameter in theConnectorclass which should be removed for clarity. - A bug in
DeepseekV3Modelwhen handling intermediate hidden states forexpansion_rate > 1, wheresqueeze()is used incorrectly.
Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.
skyrl-tx/tx/models/deepseekv3.py
Outdated
| for layer_idx, layer in enumerate(self.layers): | ||
| if output_hidden_states: | ||
| all_hidden_states.append(hidden_states) | ||
| all_hidden_states.append(hidden_states.squeeze()) |
There was a problem hiding this comment.
hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.
A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
| all_hidden_states.append(hidden_states.squeeze()) | |
| all_hidden_states.append(hidden_states.mean(axis=-2)) |
| self.eps = eps | ||
| self.weight = Param( | ||
| size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs | ||
| size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs |
There was a problem hiding this comment.
Temporary, testing
There was a problem hiding this comment.
https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html
Torch also initalizes to one by default
There was a problem hiding this comment.
Due to adapter indexing, ended up re-implementing norm in the connector layer itself - this change can be removed. But considering torch as the baseline, ones_init fits better still
|
This looks very elegant, thanks a lot for putting it together! Have you tried to do any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy, as well as how much slowdown it incurs :) |
|
Just waiting for the weekend to give it a spin 😅 I'll give Qwen0.6B a shot on an A/H100 |
|
Sounds great! I'm putting together the 0.3.0 release at the moment, so it will probably need to wait then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyways, if somebody wants to try it out, they can just apply the diff themselves given how simple this is :) |
|
Did some analysis on the step times for each on Qwen 0.6B (on a 5060Ti) Expansion rate as 1 does cause a hit to the average step time (about 0.3s slower, baseline has a step time of 2.1s vs 2.4s). An easy fix would be to just short circuit the entire thing for expansion rate = 1. For expansion rate = 4, the step time was around 3.17s, so about 46% slower. |
skyrl-tx/tx/tinker/backends/jax.py
Outdated
| """Compute full gradients, apply optimizer update, and reset accumulated grads.""" | ||
| optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index)) | ||
| return accumulated_grads.reset_adapter(adapter_index) | ||
| if global_optimizer is not None and self.has_global_trainables: | ||
| global_optimizer.update(global_params, global_accumulated_grads.get_mean()) | ||
| global_accumulated_grads = global_accumulated_grads.reset() | ||
| return accumulated_grads.reset_adapter(adapter_index), global_accumulated_grads |
There was a problem hiding this comment.
🔴 Global optimizer updated with zero gradients on second adapter's optim_step
When multiple LoRA adapters are active, the shared global optimizer receives spurious zero-gradient updates, corrupting its Adam state.
Root Cause
In compute_grads_and_update (jax.py:531-536), the global optimizer is updated and the global accumulated gradients are reset unconditionally on every call:
if global_optimizer is not None and self.has_global_trainables:
global_optimizer.update(global_params, global_accumulated_grads.get_mean())
global_accumulated_grads = global_accumulated_grads.reset()Since optim_step is called once per adapter (jax.py:773-809), with two adapters the sequence is:
optim_step(adapter_1)→ updates global optimizer with real mean gradients, resetsglobal_accumulated_gradsto zerooptim_step(adapter_2)→ updates global optimizer again withget_mean()of the now-zeroed gradients (all zeros), resets again
The second zero-gradient update corrupts Adam's internal state:
- First moments decay:
m_t = β₁ · m_{t-1} + (1-β₁) · 0— momentum decays toward zero - Second moments decay:
v_t = β₂ · v_{t-1} + (1-β₂) · 0— variance estimate shrinks - Step counter increments, affecting bias correction
Impact: Global trainable parameters (connectors) receive incorrect optimizer updates that degrade training quality, with severity proportional to the number of adapters.
Prompt for agents
The global optimizer should only be updated once per training iteration, not once per adapter. Currently in compute_grads_and_update (jax.py:531-536), the global optimizer is updated and global accumulated gradients are reset on every call, but optim_step is called once per adapter. Fix this by either: (1) tracking whether global grads have already been applied in this iteration and skipping if already done (e.g., check global_accumulated_grads.count > 0 before updating), or (2) decoupling the global optimizer step from the per-adapter optim_step so it runs exactly once per training iteration. Option (1) is simpler: guard the global optimizer update with a check like `if global_accumulated_grads.count > 0` before calling global_optimizer.update.
Was this helpful? React with 👍 or 👎 to provide feedback.
| def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array: | ||
| if adapter_indices is None: | ||
| return jnp.zeros((batch_size,), dtype=jnp.int32) | ||
| return adapter_indices.astype(jnp.int32) |
There was a problem hiding this comment.
🟡 LoRAConnector broken when max_lora_adapters=0 — indexing into 0-sized parameter arrays returns wrong values
When a model is created with max_lora_adapters=0 (e.g., tx/run/train.py:80), the LoRAConnector creates all parameter arrays with a first dimension of 0. When pre() or post() is called, _get_adapter_indices returns jnp.zeros((B,), dtype=jnp.int32), and _get_params indexes into these 0-sized arrays, producing zero-filled results instead of the identity-preserving values.
Detailed Explanation
Unlike LoRAMixin.apply_lora which short-circuits when max_lora_adapters == 0 (lora.py:85), LoRAConnector has no such guard. When max_lora_adapters=0:
self.b_prehas shape(0, n),self.b_reshas shape(0, n, n), etc._get_adapter_indices(B, None)returnsjnp.zeros((B,))atconnectors.py:66_get_paramsindexes into 0-sized arrays atconnectors.py:71-80— JAX clips out-of-bounds indices and returns zeros- In
pre():b_pre=0→H_pre = sigmoid(0) = 0.5instead of1/n - In
post():b_res=0→M = sinkhorn(zeros)produces a uniform1/nmatrix instead of identity
For the default expansion_rate=1, the impact on pre is masked by RMSNorm (the 0.5 scale cancels during normalization), and post still produces the correct residual + output. So the default case is approximately correct. However, for expansion_rate > 1 with max_lora_adapters=0, the connector would produce completely wrong outputs (uniform mixing instead of identity passthrough).
This path is exercised in production via tx/run/train.py:80 which uses max_lora_adapters=0.
Prompt for agents
Add a guard in LoRAConnector to handle the max_lora_adapters=0 case. The simplest approach is to add a check at the start of pre() and post() methods that bypasses the connector logic when max_lora_adapters is 0, falling back to identity behavior: pre() should return x.sum(axis=-2) / n (or equivalently the mean), and post() should return residual + output[..., None, :] (broadcasting output into the expansion dimension). Alternatively, ensure the constructor always creates at least 1 adapter slot (with identity initialization) even when max_lora_adapters=0, similar to how the default adapter_index=0 is used when adapter_indices is None.
Was this helpful? React with 👍 or 👎 to provide feedback.
Addressed |
|
I just merged in main, so we have #1142 and can have a look at the gradient norm. I can also make a patch to add the gradient norm of the mHC parameters vs. everything else so we can monitor that as well. Maybe it is best to keep the patch out of this PR for now and first refactor main a little more to make this easier :) |
|
This should be the diff to include the mHC gradient norm, however the norm is currently zero for me, so either this diff is wrong or the gradient norm is actually zero for some reason :) diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index f836026e..c15299ad 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -37,6 +37,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from transformers import AutoTokenizer, PretrainedConfig
from tx.models.configs import Qwen3Config
+from tx.layers.connectors import is_connector_path
from tx.layers.lora import clear_lora_adapter, init_lora_adapter
from tx.tinker import types
from tx.tinker.backends.backend import AbstractBackend
@@ -498,12 +499,17 @@ class JaxBackendImpl(AbstractBackend):
lora_params: nnx.State,
optimizer: nnx.Optimizer,
adapter_index: jax.Array,
- ) -> tuple[AccumulatedGradients, jax.Array]:
+ ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]:
"""Compute full gradients, apply optimizer update, and reset accumulated grads."""
mean_grads = accumulated_grads.get_mean(adapter_index)
grad_norm = optax.global_norm(mean_grads)
+ mhc_grads = jax.tree.map_with_path(
+ lambda path, g: g if is_connector_path(path) else jnp.zeros_like(g),
+ mean_grads,
+ )
+ mhc_grad_norm = optax.global_norm(mhc_grads)
optimizer.update(lora_params, mean_grads)
- return accumulated_grads.reset_adapter(adapter_index), grad_norm
+ return accumulated_grads.reset_adapter(adapter_index), grad_norm, mhc_grad_norm
if self.config.enforce_eager:
self._compute_grads_and_update = compute_grads_and_update
@@ -766,16 +772,27 @@ class JaxBackendImpl(AbstractBackend):
# JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads
with jax.set_mesh(self.mesh):
- self.accumulated_grads, grad_norm = self._compute_grads_and_update(
+ self.accumulated_grads, grad_norm, mhc_grad_norm = self._compute_grads_and_update(
self.accumulated_grads,
self.lora_params,
optimizer,
jnp.int32(adapter_index),
)
- grad_norm = float(jax.device_get(grad_norm))
- logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index}), grad_norm={grad_norm}")
- return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": grad_norm, "skyrl.ai/learning_rate": learning_rate})
+ grad_norm, mhc_grad_norm = jax.device_get((grad_norm, mhc_grad_norm))
+ grad_norm = float(grad_norm)
+ mhc_grad_norm = float(mhc_grad_norm)
+ logger.info(
+ f"Applied optimizer step for model {model_id} "
+ f"(adapter {adapter_index}), grad_norm={grad_norm}, mhc_grad_norm={mhc_grad_norm}"
+ )
+ return types.OptimStepOutput(
+ metrics={
+ "skyrl.ai/grad_norm": grad_norm,
+ "skyrl.ai/mhc_gradient_norm": mhc_grad_norm,
+ "skyrl.ai/learning_rate": learning_rate,
+ }
+ )
def sample(
self, |
|
Ah, the problem was actually on my local setup, I hadn't incorporated your JaxBackendConfig change into my workflow yet so my code change got overwritten -- the mHC gradients are flowing for me now :) |
|
These are the best settings I have found so far (one thing it does is break the symmetry between the different streams), but the gains are still relatively small, we can likely find something better. Ideally there would be an initialization that is a little more similar to what we do with LoRA (one adapter zero, the other random, so the model is the same as the base initially, but also there is gradient signal). Choosing b_res a large multiple of the identity matrix feels a little more forced. If you have any thoughts let me know :) [the FP32 changes are not needed I think, I just wanted to make sure there is no problem with the numerics] diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py
index 013cfe53..892052b2 100644
--- a/skyrl-tx/tx/layers/connectors.py
+++ b/skyrl-tx/tx/layers/connectors.py
@@ -12,11 +12,32 @@ def is_connector_path(path: tuple[Any, ...]) -> bool:
return any(name in normalized_path for name in ("attn_connector", "mlp_connector"))
+def _logit(x: jax.Array) -> jax.Array:
+ """Inverse sigmoid: logit(x) = log(x / (1-x))."""
+ x = jnp.clip(x, 1e-6, 1.0 - 1e-6)
+ return jnp.log(x) - jnp.log(1.0 - x)
+
+
+def default_b_pre(n: int, dtype: jnp.dtype = jnp.float32) -> jax.Array:
+ """H_pre = sigmoid(b_pre) = 1/n: uniform aggregation across streams."""
+ return _logit(jnp.array(1.0 / n, dtype=dtype))
+
+
+def default_b_post(n: int, dtype: jnp.dtype = jnp.float32) -> jax.Array:
+ """H_post spectrum from 0 to 2: stream 0 is pure memory, stream n-1 is full update.
+
+ Mean H_post = 1.0, preserving standard residual behavior on average.
+ """
+ return _logit(jnp.linspace(0.0, 1.0, n, dtype=dtype))
+
+
class LoRAConnector(nnx.Module):
"""
Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880)
- Weights initialized with identity mapping; Default behaviour equates to residual networks.
+ Streams are initialized with a spectrum of roles: stream 0 acts as pure memory
+ (preserving the residual), stream n-1 carries the full sublayer update, and
+ intermediate streams carry various mixtures. Mean behavior equals standard residual.
"""
def __init__(
@@ -40,41 +61,32 @@ class LoRAConnector(nnx.Module):
n = expansion_rate
C = hidden_dim
- self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=dtype))
+ # All connector parameters are stored in FP32 for numerical stability.
+ # Phi matrices are zero-initialized so that alpha * x @ 0 + bias = bias at init.
+ self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=jnp.float32))
self.phi_pre = Param(
- max_lora_adapters, n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs
+ max_lora_adapters, n * C, n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs
)
self.phi_post = Param(
- max_lora_adapters, n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs
+ max_lora_adapters, n * C, n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs
)
self.phi_res = Param(
- max_lora_adapters,
- n * C,
- n * n,
- dtype=dtype,
- kernel_init=nnx.initializers.normal(stddev=0.02),
- rngs=rngs,
+ max_lora_adapters, n * C, n * n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs,
)
- # Initialize biases for identity-like behavior:
- # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing)
-
- # H_pre = sigmoid(b_pre) = 1/n => b_pre = logit(1/n)
- target_h_pre = jnp.array(1.0 / n, dtype=dtype)
- clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
- inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
- self.b_pre = nnx.Param(jnp.full((max_lora_adapters, n), inv_sigmoid, dtype=dtype))
+ # H_pre = sigmoid(b_pre) = 1/n: uniform aggregation across streams
+ self.b_pre = nnx.Param(jnp.full((max_lora_adapters, n), default_b_pre(n), dtype=jnp.float32))
- # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0
- self.b_post = nnx.Param(jnp.zeros((max_lora_adapters, n), dtype=dtype))
+ # H_post = 2 * sigmoid(b_post): spectrum from 0 to 2, creating stream diversity.
+ # Stream 0 = pure memory, stream n-1 = full update, mean = 1 (standard residual).
+ self.b_post = nnx.Param(jnp.broadcast_to(default_b_post(n), (max_lora_adapters, n)))
- # Large identity matrix -> heavily biases Sinkhorn() to product an identity matrix
- self.b_res = nnx.Param(jnp.broadcast_to(10.0 * jnp.eye(n, dtype=dtype), (max_lora_adapters, n, n)))
+ # M ≈ I: identity mixing via Sinkhorn
+ self.b_res = nnx.Param(jnp.broadcast_to(3.0 * jnp.eye(n, dtype=jnp.float32), (max_lora_adapters, n, n)))
- # Alpha = 0 so phi matrices don't contribute initially
- self.alpha_pre = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
- self.alpha_post = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
- self.alpha_res = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
+ self.alpha_pre = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
+ self.alpha_post = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
+ self.alpha_res = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array:
if adapter_indices is None:
@@ -92,10 +104,11 @@ class LoRAConnector(nnx.Module):
return jax.lax.fori_loop(0, iters, step, M)
def _norm(self, x_flat: jax.Array, adapter_indices: jax.Array) -> jax.Array:
- """Separate norm from layernorm.RMSNorm due to adapter indexing and trainability"""
+ """Separate norm from layernorm.RMSNorm due to adapter indexing and trainability."""
input_norm_weight = self.input_norm_weight[adapter_indices]
- rms = jnp.sqrt(jnp.mean(x_flat**2, axis=-1, keepdims=True) + self.eps)
- return (input_norm_weight[:, None, :] * x_flat) / rms
+ x_float = x_flat.astype(jnp.float32)
+ rms = jnp.sqrt(jnp.mean(x_float**2, axis=-1, keepdims=True) + self.eps)
+ return (input_norm_weight[:, None, :] * x_float) / rms
def pre(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> tuple[jax.Array, jax.Array]:
B, T, n, C = x.shape
@@ -105,7 +118,7 @@ class LoRAConnector(nnx.Module):
adapter_indices = self._get_adapter_indices(B, adapter_indices)
x_flat = x.reshape(B, T, n * C)
- x_norm = self._norm(x_flat, adapter_indices)
+ x_norm = self._norm(x_flat, adapter_indices) # float32
alpha_pre = self.alpha_pre[adapter_indices]
phi_pre = self.phi_pre[adapter_indices]
@@ -114,10 +127,10 @@ class LoRAConnector(nnx.Module):
tilde_H_pre = alpha_pre[:, None, None] * pre_logits + b_pre[:, None, :]
H_pre = jax.nn.sigmoid(tilde_H_pre)
- x_agg = (H_pre[..., None] * x).sum(axis=-2)
+ x_agg = (H_pre[..., None] * x.astype(jnp.float32)).sum(axis=-2)
# Return residual norm for future use by post()
- return x_agg, x_norm
+ return x_agg.astype(x.dtype), x_norm
def post(
self,
@@ -148,6 +161,6 @@ class LoRAConnector(nnx.Module):
H_post = 2.0 * jax.nn.sigmoid(tilde_H_post)
M = self._sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters)
- y_dist = H_post[..., None] * output[..., None, :]
- x_mixed = M @ residual
- return x_mixed + y_dist
+ y_dist = H_post[..., None] * output.astype(jnp.float32)[..., None, :]
+ x_mixed = M @ residual.astype(jnp.float32)
+ return (x_mixed + y_dist).astype(residual.dtype)
diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index d9660bfd..83170086 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -3,7 +3,7 @@ import jax
from jax import numpy as jnp
from tx.utils.models import filter_lora, get_adapter_idx
-from tx.layers.connectors import is_connector_path
+from tx.layers.connectors import default_b_post, default_b_pre, is_connector_path
from tx.layers.util import Param, prepare_routing, ragged_dot
from tx.models.types import ModelForCausalLM
from tx.tinker.types import LoraConfig
@@ -354,23 +354,20 @@ def init_lora_adapter(model: ModelForCausalLM, adapter_index: int, lora_config:
if is_connector_path(path):
connector_slot = value[idx]
if key_name in {"alpha_pre", "alpha_post", "alpha_res"}:
- return value.at[idx].set(jnp.zeros_like(connector_slot))
+ return value.at[idx].set(jnp.full_like(connector_slot, 0.01))
if key_name == "input_norm_weight":
return value.at[idx].set(jnp.ones_like(connector_slot))
if key_name in {"phi_pre", "phi_post", "phi_res"}:
- new_phi = nnx.initializers.normal(stddev=0.02)(rngs.params(), connector_slot.shape, value.dtype)
- return value.at[idx].set(new_phi)
+ return value.at[idx].set(jnp.zeros_like(connector_slot))
if key_name == "b_pre":
n = connector_slot.shape[-1]
- target_h_pre = jnp.array(1.0 / n, dtype=value.dtype)
- clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
- inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
- return value.at[idx].set(jnp.full(connector_slot.shape, inv_sigmoid, dtype=value.dtype))
+ return value.at[idx].set(jnp.full(connector_slot.shape, default_b_pre(n, value.dtype), dtype=value.dtype))
if key_name == "b_post":
- return value.at[idx].set(jnp.zeros_like(connector_slot))
+ n = connector_slot.shape[-1]
+ return value.at[idx].set(jnp.broadcast_to(default_b_post(n, value.dtype), connector_slot.shape))
if key_name == "b_res":
n = connector_slot.shape[-1]
- eye = 10.0 * jnp.eye(n, dtype=value.dtype)
+ eye = 3.0 * jnp.eye(n, dtype=value.dtype)
return value.at[idx].set(jnp.broadcast_to(eye, connector_slot.shape))
return value
@@ -409,22 +406,21 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
# remains behaviorally neutral for mHC before being reinitialized.
if is_connector_path(path):
connector_slot = value[idx]
- if key in {"alpha_pre", "alpha_post", "alpha_res", "b_post"}:
+ if key in {"alpha_pre", "alpha_post", "alpha_res"}:
return value.at[idx].set(jnp.zeros_like(connector_slot))
if key in {"phi_pre", "phi_post", "phi_res"}:
- # Keep clear deterministic and neutral: alpha=0 makes phi inactive.
return value.at[idx].set(jnp.zeros_like(connector_slot))
if key == "input_norm_weight":
return value.at[idx].set(jnp.ones_like(connector_slot))
if key == "b_pre":
n = connector_slot.shape[-1]
- target_h_pre = jnp.array(1.0 / n, dtype=value.dtype)
- clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
- inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
- return value.at[idx].set(jnp.full(connector_slot.shape, inv_sigmoid, dtype=value.dtype))
+ return value.at[idx].set(jnp.full(connector_slot.shape, default_b_pre(n, value.dtype), dtype=value.dtype))
+ if key == "b_post":
+ n = connector_slot.shape[-1]
+ return value.at[idx].set(jnp.broadcast_to(default_b_post(n, value.dtype), connector_slot.shape))
if key == "b_res":
n = connector_slot.shape[-1]
- eye = 10.0 * jnp.eye(n, dtype=value.dtype)
+ eye = 3.0 * jnp.eye(n, dtype=value.dtype)
return value.at[idx].set(jnp.broadcast_to(eye, connector_slot.shape))
return value.at[idx].set(0.0)
if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index f836026e..c15299ad 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -37,6 +37,7 @@ from pydantic import BaseModel, Field, TypeAdapter
from transformers import AutoTokenizer, PretrainedConfig
from tx.models.configs import Qwen3Config
+from tx.layers.connectors import is_connector_path
from tx.layers.lora import clear_lora_adapter, init_lora_adapter
from tx.tinker import types
from tx.tinker.backends.backend import AbstractBackend
@@ -498,12 +499,17 @@ class JaxBackendImpl(AbstractBackend):
lora_params: nnx.State,
optimizer: nnx.Optimizer,
adapter_index: jax.Array,
- ) -> tuple[AccumulatedGradients, jax.Array]:
+ ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]:
"""Compute full gradients, apply optimizer update, and reset accumulated grads."""
mean_grads = accumulated_grads.get_mean(adapter_index)
grad_norm = optax.global_norm(mean_grads)
+ mhc_grads = jax.tree.map_with_path(
+ lambda path, g: g if is_connector_path(path) else jnp.zeros_like(g),
+ mean_grads,
+ )
+ mhc_grad_norm = optax.global_norm(mhc_grads)
optimizer.update(lora_params, mean_grads)
- return accumulated_grads.reset_adapter(adapter_index), grad_norm
+ return accumulated_grads.reset_adapter(adapter_index), grad_norm, mhc_grad_norm
if self.config.enforce_eager:
self._compute_grads_and_update = compute_grads_and_update
@@ -766,16 +772,27 @@ class JaxBackendImpl(AbstractBackend):
# JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads
with jax.set_mesh(self.mesh):
- self.accumulated_grads, grad_norm = self._compute_grads_and_update(
+ self.accumulated_grads, grad_norm, mhc_grad_norm = self._compute_grads_and_update(
self.accumulated_grads,
self.lora_params,
optimizer,
jnp.int32(adapter_index),
)
- grad_norm = float(jax.device_get(grad_norm))
- logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index}), grad_norm={grad_norm}")
- return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": grad_norm, "skyrl.ai/learning_rate": learning_rate})
+ grad_norm, mhc_grad_norm = jax.device_get((grad_norm, mhc_grad_norm))
+ grad_norm = float(grad_norm)
+ mhc_grad_norm = float(mhc_grad_norm)
+ logger.info(
+ f"Applied optimizer step for model {model_id} "
+ f"(adapter {adapter_index}), grad_norm={grad_norm}, mhc_grad_norm={mhc_grad_norm}"
+ )
+ return types.OptimStepOutput(
+ metrics={
+ "skyrl.ai/grad_norm": grad_norm,
+ "skyrl.ai/mhc_gradient_norm": mhc_grad_norm,
+ "skyrl.ai/learning_rate": learning_rate,
+ }
+ )
def sample(
self, |
|
@pcmoritz Thanks for running experiments with it! I was thinking something along the lines of havine b_res initialized as product of 2 randomly initialized square matrices A and A^-1 followed by sinkhorn. Since b_res is quite small, this would be negligible parameter count wise, but this would be a departure from core mHC |
|
For clarity - should clear_adapter and init_adapter methods be implemented as static by the adapter classes themselves? I can make the change |
|
@tanmaysachan Yes, that would be much better, there are a few more small changes I'd like to try (like shifting |
This is in preparation for merging #1008 and to make it easier to introduce metrics. <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1191" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
|
Thanks a lot for all the updates, I'll do the rest (already merged a PR that cleans things up a little #1191) :) |
skyrl-tx/tx/layers/connectors.py
Outdated
| C = hidden_dim | ||
|
|
||
| # Phi matrices are zero-initialized so that alpha * x @ 0 + bias = bias at init. | ||
| self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=dtype)) |
There was a problem hiding this comment.
I'm curious, why did you make the RMSNorm per adapter and trainable? That seems wrong, we should probably just use the RMSNorm from the base model :) [I don't think any of the LoRA codes out there make the RMSNorm trainable]
There was a problem hiding this comment.
Actually I think I misunderstood the code and you are doing the right thing :)
There was a problem hiding this comment.
Sorry for going back on forth on this, but I think the actually correct implementation would be to pass the input norm parameters from the model to the constructor of LoRAConnector and use it for the normalization below and keep it non-trainable. It will be slightly redundant to apply the norm twice, but I think for code clarity that's fine for now (there is more optimizations to do anyways). Let me know about your thoughts, I'll give that a shot :)
There was a problem hiding this comment.
Thanks for the change!
My layernorm change was around making the entire block including the norms trainable - but yeah if its something like lora that shouldn't be the case




Addresses #952
This PR is a general implementation of Hyper connections.
This is supposed to be an extension like Lora, where the default case mimics a standard residual connection with identity mappings.
Default case - Trainable is false. Expansion rate is 1.
[edit] we now bypass this case entirely for a regular residual network.
For expansion rate > 1
These matrices preserve identity mapping. So expansion rate > 1 but untrainable still results in the the same outputs.
Todos
Future work