[tx] General implementation of trainable Hyper Connections #1008

Open

tanmaysachan wants to merge 40 commits into NovaSky-AI:main from tanmaysachan:tanmay/mhc

Conversation

@tanmaysachan
Contributor

@tanmaysachan tanmaysachan commented Feb 2, 2026

Addresses #952

This PR is a general implementation of Hyper Connections.

It is meant to work as an extension in the spirit of LoRA, where the default case mimics a standard residual connection with identity mappings.

Default case: trainable is false, expansion rate is 1.

[edit] we now bypass this case entirely and fall back to a regular residual network.

  1. H_res is a single value matrix [1]
  2. H_pre and H_post are vectors of [1, 1, 1, ...] that result in no-op matmuls

For expansion rate > 1

  1. H_res is initialized as identity of size nxn (n is the expansion rate)
  2. H_pre is [1/n, 1/n, ...]
  3. H_post is [1, 1, 1, ...]

These matrices preserve the identity mapping, so an expansion rate > 1 that is left untrainable still produces the same outputs (see the sketch below).
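
For illustration, here is a minimal, self-contained JAX sketch (not the PR's actual module) of why this initialization is a no-op: with H_res = I, H_pre = 1/n, and H_post = 1, the expanded streams collapse back to an ordinary residual update. The helper name and the explicit stream replication are assumptions made only for this example.

import jax.numpy as jnp

def identity_hyper_connection(x, sublayer, n):
    # x: (T, C) hidden states, sublayer: any callable (attention/MLP), n: expansion rate.
    H_res = jnp.eye(n)               # residual mixing matrix, initialized to identity
    H_pre = jnp.full((n,), 1.0 / n)  # uniform aggregation over the n streams
    H_post = jnp.ones((n,))          # distribute the sublayer output to every stream

    streams = jnp.broadcast_to(x, (n,) + x.shape)      # (n, T, C) replicated streams
    agg = jnp.einsum("n,ntc->tc", H_pre, streams)      # aggregate -> recovers x exactly
    out = sublayer(agg)                                # ordinary sublayer call
    mixed = jnp.einsum("nm,mtc->ntc", H_res, streams)  # identity mixing of the residuals
    return mixed + H_post[:, None, None] * out         # every stream equals x + sublayer(x)

# With n = 1 (the default), this reduces exactly to the standard residual x + sublayer(x).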

Todos

  • simplify rms integration - added elementwise_affine as a flag
  • Benchmark/ensure no regression for expansion_rate = 1 - minimal difference in step time when expansion rate is 1 and untrainable.

Future work

  • Fine tune on custom data with mHC + LoRA to see perf gains

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a general implementation of Hyper Connections as an extension to the transformer layers. The changes are mainly in tx/layers/connectors.py where the Connector module is defined, and in tx/models/deepseekv3.py to integrate it into the decoder layers.

My review found a couple of issues:

  • An unused trainable parameter in the Connector class which should be removed for clarity.
  • A bug in DeepseekV3Model when handling intermediate hidden states for expansion_rate > 1, where squeeze() is used incorrectly.

Overall, the implementation of the Hyper Connections logic seems to follow the intended pattern of pre/post processing around existing attention and MLP blocks. The changes are well-contained. Addressing the mentioned points will improve the robustness and clarity of the implementation.

 for layer_idx, layer in enumerate(self.layers):
     if output_hidden_states:
-        all_hidden_states.append(hidden_states)
+        all_hidden_states.append(hidden_states.squeeze())
Contributor


Severity: high

hidden_states.squeeze() is used here to process intermediate hidden states. This will only work correctly if expansion_rate is 1. For expansion_rate > 1, squeeze() will have no effect because the expansion dimension has size n > 1. This will result in appending a tensor with an incorrect shape (..., n, C) to all_hidden_states, which is inconsistent with other states and likely to cause issues downstream.

A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.

Suggested change
-        all_hidden_states.append(hidden_states.squeeze())
+        all_hidden_states.append(hidden_states.mean(axis=-2))
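
A quick standalone shape check (illustrative snippet, not from the PR) shows why squeeze() is a no-op here when the expansion dimension has size n > 1:

import jax.numpy as jnp

B, T, n, C = 2, 5, 4, 8                    # batch, sequence, expansion rate, hidden dim
hidden_states = jnp.zeros((B, T, n, C))

print(hidden_states.squeeze().shape)       # (2, 5, 4, 8): no size-1 axes, squeeze changes nothing
print(hidden_states.mean(axis=-2).shape)   # (2, 5, 8): aggregated over the expansion dimension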

@pcmoritz pcmoritz added the tx label Feb 2, 2026
 self.eps = eps
 self.weight = Param(
-    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+    size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
Contributor Author


Temporary, testing

Contributor Author

Due to adapter indexing, I ended up re-implementing the norm in the connector layer itself, so this change can be removed. But taking torch as the baseline, ones_init still fits better.

@pcmoritz
Collaborator

pcmoritz commented Feb 5, 2026

This looks very elegant, thanks a lot for putting it together! Have you tried any end-to-end runs yet / studied the performance, both in terms of learning dynamics / accuracy and in terms of how much slowdown it incurs? :)

@tanmaysachan
Contributor Author

Just waiting for the weekend to give it a spin 😅

I'll give Qwen0.6B a shot on an A/H100

@pcmoritz
Collaborator

pcmoritz commented Feb 5, 2026

Sounds great! I'm putting together the 0.3.0 release at the moment, so this will probably need to wait until then, but 0.3.1 should come relatively soon thereafter, so it is not a problem. I'll put a callout in the release blog anyway; if somebody wants to try it out, they can just apply the diff themselves given how simple it is :)


@tanmaysachan
Contributor Author

tanmaysachan commented Feb 11, 2026

Did some analysis of the step times on Qwen 0.6B (on a 5060Ti).

Expansion rate 1 does cause a hit to the average step time: about 0.3s slower, 2.4s vs. a 2.1s baseline. An easy fix would be to just short-circuit the whole path when the expansion rate is 1.

For expansion rate = 4, the step time was around 3.17s, so about 46% slower.

@tanmaysachan
Contributor Author

tanmaysachan commented Feb 11, 2026

[image: qwen_expansion_4_loss]

Loss plot for Qwen 0.6B with expansion_rate=4, max_lora_adapters=2, max_lora_rank=1.

(some more analysis todo)

Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 1 new potential issue.

Comment on lines 531 to 536
"""Compute full gradients, apply optimizer update, and reset accumulated grads."""
optimizer.update(lora_params, accumulated_grads.get_mean(adapter_index))
return accumulated_grads.reset_adapter(adapter_index)
if global_optimizer is not None and self.has_global_trainables:
global_optimizer.update(global_params, global_accumulated_grads.get_mean())
global_accumulated_grads = global_accumulated_grads.reset()
return accumulated_grads.reset_adapter(adapter_index), global_accumulated_grads
Contributor


🔴 Global optimizer updated with zero gradients on second adapter's optim_step

When multiple LoRA adapters are active, the shared global optimizer receives spurious zero-gradient updates, corrupting its Adam state.

Root Cause

In compute_grads_and_update (jax.py:531-536), the global optimizer is updated and the global accumulated gradients are reset unconditionally on every call:

if global_optimizer is not None and self.has_global_trainables:
    global_optimizer.update(global_params, global_accumulated_grads.get_mean())
    global_accumulated_grads = global_accumulated_grads.reset()

Since optim_step is called once per adapter (jax.py:773-809), with two adapters the sequence is:

  1. optim_step(adapter_1) → updates global optimizer with real mean gradients, resets global_accumulated_grads to zero
  2. optim_step(adapter_2) → updates global optimizer again with get_mean() of the now-zeroed gradients (all zeros), resets again

The second zero-gradient update corrupts Adam's internal state:

  • First moments decay: m_t = β₁ · m_{t-1} + (1-β₁) · 0 — momentum decays toward zero
  • Second moments decay: v_t = β₂ · v_{t-1} + (1-β₂) · 0 — variance estimate shrinks
  • Step counter increments, affecting bias correction

Impact: Global trainable parameters (connectors) receive incorrect optimizer updates that degrade training quality, with severity proportional to the number of adapters.
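
For intuition, here is a small self-contained optax example (not from the repo) showing that an all-zero gradient step is not a no-op for Adam; the moment estimates decay and the step counter advances, so the parameters still move:

import jax.numpy as jnp
import optax

params = {"w": jnp.ones(3)}
opt = optax.adam(1e-3)
state = opt.init(params)

real_grad = {"w": jnp.full(3, 0.5)}
updates, state = opt.update(real_grad, state, params)   # normal step (first adapter)
params = optax.apply_updates(params, updates)

zero_grad = {"w": jnp.zeros(3)}
updates, state = opt.update(zero_grad, state, params)   # spurious step (second adapter)
print(updates)  # non-zero updates despite a zero gradient: Adam's state was perturbed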

Prompt for agents
The global optimizer should only be updated once per training iteration, not once per adapter. Currently in compute_grads_and_update (jax.py:531-536), the global optimizer is updated and global accumulated gradients are reset on every call, but optim_step is called once per adapter. Fix this by either: (1) tracking whether global grads have already been applied in this iteration and skipping if already done (e.g., check global_accumulated_grads.count > 0 before updating), or (2) decoupling the global optimizer step from the per-adapter optim_step so it runs exactly once per training iteration. Option (1) is simpler: guard the global optimizer update with a check like `if global_accumulated_grads.count > 0` before calling global_optimizer.update.

Contributor

@devin-ai-integration devin-ai-integration bot left a comment


Devin Review found 2 new potential issues.

Comment on lines +64 to +67
def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array:
    if adapter_indices is None:
        return jnp.zeros((batch_size,), dtype=jnp.int32)
    return adapter_indices.astype(jnp.int32)
Contributor


🟡 LoRAConnector broken when max_lora_adapters=0 — indexing into 0-sized parameter arrays returns wrong values

When a model is created with max_lora_adapters=0 (e.g., tx/run/train.py:80), the LoRAConnector creates all parameter arrays with a first dimension of 0. When pre() or post() is called, _get_adapter_indices returns jnp.zeros((B,), dtype=jnp.int32), and _get_params indexes into these 0-sized arrays, producing zero-filled results instead of the identity-preserving values.

Detailed Explanation

Unlike LoRAMixin.apply_lora which short-circuits when max_lora_adapters == 0 (lora.py:85), LoRAConnector has no such guard. When max_lora_adapters=0:

  • self.b_pre has shape (0, n), self.b_res has shape (0, n, n), etc.
  • _get_adapter_indices(B, None) returns jnp.zeros((B,)) at connectors.py:66
  • _get_params indexes into 0-sized arrays at connectors.py:71-80 — JAX clips out-of-bounds indices and returns zeros
  • In pre(): b_pre = 0, so H_pre = sigmoid(0) = 0.5 instead of 1/n
  • In post(): b_res = 0, so M = sinkhorn(zeros) produces a uniform 1/n matrix instead of identity

For the default expansion_rate=1, the impact on pre is masked by RMSNorm (the 0.5 scale cancels during normalization), and post still produces the correct residual + output. So the default case is approximately correct. However, for expansion_rate > 1 with max_lora_adapters=0, the connector would produce completely wrong outputs (uniform mixing instead of identity passthrough).

This path is exercised in production via tx/run/train.py:80 which uses max_lora_adapters=0.

Prompt for agents
Add a guard in LoRAConnector to handle the max_lora_adapters=0 case. The simplest approach is to add a check at the start of pre() and post() methods that bypasses the connector logic when max_lora_adapters is 0, falling back to identity behavior: pre() should return x.sum(axis=-2) / n (or equivalently the mean), and post() should return residual + output[..., None, :] (broadcasting output into the expansion dimension). Alternatively, ensure the constructor always creates at least 1 adapter slot (with identity initialization) even when max_lora_adapters=0, similar to how the default adapter_index=0 is used when adapter_indices is None.
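
A minimal standalone sketch of the suggested bypass (hypothetical function names, not the PR's actual methods): with no adapters, pre() degenerates to uniform aggregation over the streams and post() to an ordinary residual add broadcast into every stream.

import jax.numpy as jnp

def pre_identity(x):
    # x: (B, T, n, C) expanded hidden states -> (B, T, C) uniform aggregation (mean over streams)
    return x.mean(axis=-2)

def post_identity(residual, output):
    # residual: (B, T, n, C), output: (B, T, C) -> residual plus output broadcast into each stream
    return residual + output[..., None, :]

# Usage: when max_lora_adapters == 0, call these instead of the learned pre()/post() paths.
B, T, n, C = 2, 3, 4, 8
x = jnp.ones((B, T, n, C))
agg = pre_identity(x)         # shape (2, 3, 8)
y = post_identity(x, agg)     # shape (2, 3, 4, 8), each stream equals x + agg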


@tanmaysachan
Contributor Author

tanmaysachan commented Feb 12, 2026

[image: qwen_loss_comparison]

Qwen 1.7B, without expansion and with expansion_rate=4 (roughly identical loss plots).
With training, mHC step times are about 93% higher than the regular setup.

@tanmaysachan
Contributor Author

The loss differences are on a similar scale to what is observed in the mHC paper.
[image]

Ground truth mHC analysis:
[image]

@tanmaysachan
Contributor Author

Btw, at the moment, the only way to activate mHC is by editing the code, right? We should probably expose it in --backend-config for now. It is slightly unfortunate that it needs to be configured on the backend, but for now that's ok I think. In the docstring of the field in JaxBackendConfig we should note that this is currently experimental (e.g. if we can configure it per request going forward that would be better, but for now what you currently have is good, certainly simpler, also the behavior will likely still change a bunch).

Addressed


@pcmoritz
Collaborator

I just merged in main, so we have #1142 and can have a look at the gradient norm. I can also make a patch to add the gradient norm of the mHC parameters vs. everything else so we can monitor that as well. Maybe it is best to keep the patch out of this PR for now and first refactor main a little more to make this easier :)

@pcmoritz
Collaborator

This should be the diff to include the mHC gradient norm, however the norm is currently zero for me, so either this diff is wrong or the gradient norm is actually zero for some reason :)

diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index f836026e..c15299ad 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -37,6 +37,7 @@ from pydantic import BaseModel, Field, TypeAdapter
 from transformers import AutoTokenizer, PretrainedConfig
 
 from tx.models.configs import Qwen3Config
+from tx.layers.connectors import is_connector_path
 from tx.layers.lora import clear_lora_adapter, init_lora_adapter
 from tx.tinker import types
 from tx.tinker.backends.backend import AbstractBackend
@@ -498,12 +499,17 @@ class JaxBackendImpl(AbstractBackend):
             lora_params: nnx.State,
             optimizer: nnx.Optimizer,
             adapter_index: jax.Array,
-        ) -> tuple[AccumulatedGradients, jax.Array]:
+        ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]:
             """Compute full gradients, apply optimizer update, and reset accumulated grads."""
             mean_grads = accumulated_grads.get_mean(adapter_index)
             grad_norm = optax.global_norm(mean_grads)
+            mhc_grads = jax.tree.map_with_path(
+                lambda path, g: g if is_connector_path(path) else jnp.zeros_like(g),
+                mean_grads,
+            )
+            mhc_grad_norm = optax.global_norm(mhc_grads)
             optimizer.update(lora_params, mean_grads)
-            return accumulated_grads.reset_adapter(adapter_index), grad_norm
+            return accumulated_grads.reset_adapter(adapter_index), grad_norm, mhc_grad_norm
 
         if self.config.enforce_eager:
             self._compute_grads_and_update = compute_grads_and_update
@@ -766,16 +772,27 @@ class JaxBackendImpl(AbstractBackend):
 
         # JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads
         with jax.set_mesh(self.mesh):
-            self.accumulated_grads, grad_norm = self._compute_grads_and_update(
+            self.accumulated_grads, grad_norm, mhc_grad_norm = self._compute_grads_and_update(
                 self.accumulated_grads,
                 self.lora_params,
                 optimizer,
                 jnp.int32(adapter_index),
             )
 
-        grad_norm = float(jax.device_get(grad_norm))
-        logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index}), grad_norm={grad_norm}")
-        return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": grad_norm, "skyrl.ai/learning_rate": learning_rate})
+        grad_norm, mhc_grad_norm = jax.device_get((grad_norm, mhc_grad_norm))
+        grad_norm = float(grad_norm)
+        mhc_grad_norm = float(mhc_grad_norm)
+        logger.info(
+            f"Applied optimizer step for model {model_id} "
+            f"(adapter {adapter_index}), grad_norm={grad_norm}, mhc_grad_norm={mhc_grad_norm}"
+        )
+        return types.OptimStepOutput(
+            metrics={
+                "skyrl.ai/grad_norm": grad_norm,
+                "skyrl.ai/mhc_gradient_norm": mhc_grad_norm,
+                "skyrl.ai/learning_rate": learning_rate,
+            }
+        )
 
     def sample(
         self,

@pcmoritz
Collaborator

pcmoritz commented Feb 18, 2026

Ah, the problem was actually on my local setup, I hadn't incorporated your JaxBackendConfig change into my workflow yet so my code change got overwritten -- the mHC gradients are flowing for me now :)

@pcmoritz
Collaborator

pcmoritz commented Feb 19, 2026

These are the best settings I have found so far (one thing they do is break the symmetry between the different streams), but the gains are still relatively small, so we can likely find something better. Ideally there would be an initialization that is a little more similar to what we do with LoRA (one adapter zero, the other random, so the model matches the base initially but there is still gradient signal). Choosing b_res as a large multiple of the identity matrix feels a little more forced. If you have any thoughts let me know :)

[the FP32 changes are not needed I think, I just wanted to make sure there is no problem with the numerics]

diff --git a/skyrl-tx/tx/layers/connectors.py b/skyrl-tx/tx/layers/connectors.py
index 013cfe53..892052b2 100644
--- a/skyrl-tx/tx/layers/connectors.py
+++ b/skyrl-tx/tx/layers/connectors.py
@@ -12,11 +12,32 @@ def is_connector_path(path: tuple[Any, ...]) -> bool:
     return any(name in normalized_path for name in ("attn_connector", "mlp_connector"))
 
 
+def _logit(x: jax.Array) -> jax.Array:
+    """Inverse sigmoid: logit(x) = log(x / (1-x))."""
+    x = jnp.clip(x, 1e-6, 1.0 - 1e-6)
+    return jnp.log(x) - jnp.log(1.0 - x)
+
+
+def default_b_pre(n: int, dtype: jnp.dtype = jnp.float32) -> jax.Array:
+    """H_pre = sigmoid(b_pre) = 1/n: uniform aggregation across streams."""
+    return _logit(jnp.array(1.0 / n, dtype=dtype))
+
+
+def default_b_post(n: int, dtype: jnp.dtype = jnp.float32) -> jax.Array:
+    """H_post spectrum from 0 to 2: stream 0 is pure memory, stream n-1 is full update.
+
+    Mean H_post = 1.0, preserving standard residual behavior on average.
+    """
+    return _logit(jnp.linspace(0.0, 1.0, n, dtype=dtype))
+
+
 class LoRAConnector(nnx.Module):
     """
     Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880)
 
-    Weights initialized with identity mapping; Default behaviour equates to residual networks.
+    Streams are initialized with a spectrum of roles: stream 0 acts as pure memory
+    (preserving the residual), stream n-1 carries the full sublayer update, and
+    intermediate streams carry various mixtures. Mean behavior equals standard residual.
     """
 
     def __init__(
@@ -40,41 +61,32 @@ class LoRAConnector(nnx.Module):
         n = expansion_rate
         C = hidden_dim
 
-        self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=dtype))
+        # All connector parameters are stored in FP32 for numerical stability.
+        # Phi matrices are zero-initialized so that alpha * x @ 0 + bias = bias at init.
+        self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=jnp.float32))
         self.phi_pre = Param(
-            max_lora_adapters, n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs
+            max_lora_adapters, n * C, n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs
         )
         self.phi_post = Param(
-            max_lora_adapters, n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs
+            max_lora_adapters, n * C, n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs
         )
         self.phi_res = Param(
-            max_lora_adapters,
-            n * C,
-            n * n,
-            dtype=dtype,
-            kernel_init=nnx.initializers.normal(stddev=0.02),
-            rngs=rngs,
+            max_lora_adapters, n * C, n * n, dtype=jnp.float32, kernel_init=nnx.initializers.zeros_init(), rngs=rngs,
         )
 
-        # Initialize biases for identity-like behavior:
-        # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing)
-
-        # H_pre = sigmoid(b_pre) = 1/n  =>  b_pre = logit(1/n)
-        target_h_pre = jnp.array(1.0 / n, dtype=dtype)
-        clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
-        inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
-        self.b_pre = nnx.Param(jnp.full((max_lora_adapters, n), inv_sigmoid, dtype=dtype))
+        # H_pre = sigmoid(b_pre) = 1/n: uniform aggregation across streams
+        self.b_pre = nnx.Param(jnp.full((max_lora_adapters, n), default_b_pre(n), dtype=jnp.float32))
 
-        # H_post = 2 * sigmoid(b_post) = 1  =>  b_post = 0
-        self.b_post = nnx.Param(jnp.zeros((max_lora_adapters, n), dtype=dtype))
+        # H_post = 2 * sigmoid(b_post): spectrum from 0 to 2, creating stream diversity.
+        # Stream 0 = pure memory, stream n-1 = full update, mean = 1 (standard residual).
+        self.b_post = nnx.Param(jnp.broadcast_to(default_b_post(n), (max_lora_adapters, n)))
 
-        # Large identity matrix -> heavily biases Sinkhorn() to product an identity matrix
-        self.b_res = nnx.Param(jnp.broadcast_to(10.0 * jnp.eye(n, dtype=dtype), (max_lora_adapters, n, n)))
+        # M ≈ I: identity mixing via Sinkhorn
+        self.b_res = nnx.Param(jnp.broadcast_to(3.0 * jnp.eye(n, dtype=jnp.float32), (max_lora_adapters, n, n)))
 
-        # Alpha = 0 so phi matrices don't contribute initially
-        self.alpha_pre = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
-        self.alpha_post = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
-        self.alpha_res = nnx.Param(jnp.zeros((max_lora_adapters,), dtype=dtype))
+        self.alpha_pre = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
+        self.alpha_post = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
+        self.alpha_res = nnx.Param(jnp.full((max_lora_adapters,), 0.01, dtype=jnp.float32))
 
     def _get_adapter_indices(self, batch_size: int, adapter_indices: jax.Array | None) -> jax.Array:
         if adapter_indices is None:
@@ -92,10 +104,11 @@ class LoRAConnector(nnx.Module):
         return jax.lax.fori_loop(0, iters, step, M)
 
     def _norm(self, x_flat: jax.Array, adapter_indices: jax.Array) -> jax.Array:
-        """Separate norm from layernorm.RMSNorm due to adapter indexing and trainability"""
+        """Separate norm from layernorm.RMSNorm due to adapter indexing and trainability."""
         input_norm_weight = self.input_norm_weight[adapter_indices]
-        rms = jnp.sqrt(jnp.mean(x_flat**2, axis=-1, keepdims=True) + self.eps)
-        return (input_norm_weight[:, None, :] * x_flat) / rms
+        x_float = x_flat.astype(jnp.float32)
+        rms = jnp.sqrt(jnp.mean(x_float**2, axis=-1, keepdims=True) + self.eps)
+        return (input_norm_weight[:, None, :] * x_float) / rms
 
     def pre(self, x: jax.Array, adapter_indices: jax.Array | None = None) -> tuple[jax.Array, jax.Array]:
         B, T, n, C = x.shape
@@ -105,7 +118,7 @@ class LoRAConnector(nnx.Module):
 
         adapter_indices = self._get_adapter_indices(B, adapter_indices)
         x_flat = x.reshape(B, T, n * C)
-        x_norm = self._norm(x_flat, adapter_indices)
+        x_norm = self._norm(x_flat, adapter_indices)  # float32
 
         alpha_pre = self.alpha_pre[adapter_indices]
         phi_pre = self.phi_pre[adapter_indices]
@@ -114,10 +127,10 @@ class LoRAConnector(nnx.Module):
         tilde_H_pre = alpha_pre[:, None, None] * pre_logits + b_pre[:, None, :]
 
         H_pre = jax.nn.sigmoid(tilde_H_pre)
-        x_agg = (H_pre[..., None] * x).sum(axis=-2)
+        x_agg = (H_pre[..., None] * x.astype(jnp.float32)).sum(axis=-2)
 
         # Return residual norm for future use by post()
-        return x_agg, x_norm
+        return x_agg.astype(x.dtype), x_norm
 
     def post(
         self,
@@ -148,6 +161,6 @@ class LoRAConnector(nnx.Module):
         H_post = 2.0 * jax.nn.sigmoid(tilde_H_post)
         M = self._sinkhorn_knopp(tilde_H_res, self.sinkhorn_iters)
 
-        y_dist = H_post[..., None] * output[..., None, :]
-        x_mixed = M @ residual
-        return x_mixed + y_dist
+        y_dist = H_post[..., None] * output.astype(jnp.float32)[..., None, :]
+        x_mixed = M @ residual.astype(jnp.float32)
+        return (x_mixed + y_dist).astype(residual.dtype)
diff --git a/skyrl-tx/tx/layers/lora.py b/skyrl-tx/tx/layers/lora.py
index d9660bfd..83170086 100644
--- a/skyrl-tx/tx/layers/lora.py
+++ b/skyrl-tx/tx/layers/lora.py
@@ -3,7 +3,7 @@ import jax
 from jax import numpy as jnp
 
 from tx.utils.models import filter_lora, get_adapter_idx
-from tx.layers.connectors import is_connector_path
+from tx.layers.connectors import default_b_post, default_b_pre, is_connector_path
 from tx.layers.util import Param, prepare_routing, ragged_dot
 from tx.models.types import ModelForCausalLM
 from tx.tinker.types import LoraConfig
@@ -354,23 +354,20 @@ def init_lora_adapter(model: ModelForCausalLM, adapter_index: int, lora_config:
         if is_connector_path(path):
             connector_slot = value[idx]
             if key_name in {"alpha_pre", "alpha_post", "alpha_res"}:
-                return value.at[idx].set(jnp.zeros_like(connector_slot))
+                return value.at[idx].set(jnp.full_like(connector_slot, 0.01))
             if key_name == "input_norm_weight":
                 return value.at[idx].set(jnp.ones_like(connector_slot))
             if key_name in {"phi_pre", "phi_post", "phi_res"}:
-                new_phi = nnx.initializers.normal(stddev=0.02)(rngs.params(), connector_slot.shape, value.dtype)
-                return value.at[idx].set(new_phi)
+                return value.at[idx].set(jnp.zeros_like(connector_slot))
             if key_name == "b_pre":
                 n = connector_slot.shape[-1]
-                target_h_pre = jnp.array(1.0 / n, dtype=value.dtype)
-                clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
-                inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
-                return value.at[idx].set(jnp.full(connector_slot.shape, inv_sigmoid, dtype=value.dtype))
+                return value.at[idx].set(jnp.full(connector_slot.shape, default_b_pre(n, value.dtype), dtype=value.dtype))
             if key_name == "b_post":
-                return value.at[idx].set(jnp.zeros_like(connector_slot))
+                n = connector_slot.shape[-1]
+                return value.at[idx].set(jnp.broadcast_to(default_b_post(n, value.dtype), connector_slot.shape))
             if key_name == "b_res":
                 n = connector_slot.shape[-1]
-                eye = 10.0 * jnp.eye(n, dtype=value.dtype)
+                eye = 3.0 * jnp.eye(n, dtype=value.dtype)
                 return value.at[idx].set(jnp.broadcast_to(eye, connector_slot.shape))
             return value
 
@@ -409,22 +406,21 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
         # remains behaviorally neutral for mHC before being reinitialized.
         if is_connector_path(path):
             connector_slot = value[idx]
-            if key in {"alpha_pre", "alpha_post", "alpha_res", "b_post"}:
+            if key in {"alpha_pre", "alpha_post", "alpha_res"}:
                 return value.at[idx].set(jnp.zeros_like(connector_slot))
             if key in {"phi_pre", "phi_post", "phi_res"}:
-                # Keep clear deterministic and neutral: alpha=0 makes phi inactive.
                 return value.at[idx].set(jnp.zeros_like(connector_slot))
             if key == "input_norm_weight":
                 return value.at[idx].set(jnp.ones_like(connector_slot))
             if key == "b_pre":
                 n = connector_slot.shape[-1]
-                target_h_pre = jnp.array(1.0 / n, dtype=value.dtype)
-                clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6)
-                inv_sigmoid = jnp.log(clamped) - jnp.log(1.0 - clamped)
-                return value.at[idx].set(jnp.full(connector_slot.shape, inv_sigmoid, dtype=value.dtype))
+                return value.at[idx].set(jnp.full(connector_slot.shape, default_b_pre(n, value.dtype), dtype=value.dtype))
+            if key == "b_post":
+                n = connector_slot.shape[-1]
+                return value.at[idx].set(jnp.broadcast_to(default_b_post(n, value.dtype), connector_slot.shape))
             if key == "b_res":
                 n = connector_slot.shape[-1]
-                eye = 10.0 * jnp.eye(n, dtype=value.dtype)
+                eye = 3.0 * jnp.eye(n, dtype=value.dtype)
                 return value.at[idx].set(jnp.broadcast_to(eye, connector_slot.shape))
             return value.at[idx].set(0.0)
         if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
diff --git a/skyrl-tx/tx/tinker/backends/jax.py b/skyrl-tx/tx/tinker/backends/jax.py
index f836026e..c15299ad 100644
--- a/skyrl-tx/tx/tinker/backends/jax.py
+++ b/skyrl-tx/tx/tinker/backends/jax.py
@@ -37,6 +37,7 @@ from pydantic import BaseModel, Field, TypeAdapter
 from transformers import AutoTokenizer, PretrainedConfig
 
 from tx.models.configs import Qwen3Config
+from tx.layers.connectors import is_connector_path
 from tx.layers.lora import clear_lora_adapter, init_lora_adapter
 from tx.tinker import types
 from tx.tinker.backends.backend import AbstractBackend
@@ -498,12 +499,17 @@ class JaxBackendImpl(AbstractBackend):
             lora_params: nnx.State,
             optimizer: nnx.Optimizer,
             adapter_index: jax.Array,
-        ) -> tuple[AccumulatedGradients, jax.Array]:
+        ) -> tuple[AccumulatedGradients, jax.Array, jax.Array]:
             """Compute full gradients, apply optimizer update, and reset accumulated grads."""
             mean_grads = accumulated_grads.get_mean(adapter_index)
             grad_norm = optax.global_norm(mean_grads)
+            mhc_grads = jax.tree.map_with_path(
+                lambda path, g: g if is_connector_path(path) else jnp.zeros_like(g),
+                mean_grads,
+            )
+            mhc_grad_norm = optax.global_norm(mhc_grads)
             optimizer.update(lora_params, mean_grads)
-            return accumulated_grads.reset_adapter(adapter_index), grad_norm
+            return accumulated_grads.reset_adapter(adapter_index), grad_norm, mhc_grad_norm
 
         if self.config.enforce_eager:
             self._compute_grads_and_update = compute_grads_and_update
@@ -766,16 +772,27 @@ class JaxBackendImpl(AbstractBackend):
 
         # JIT-compiled: compute full gradients, apply optimizer update, and reset accumulated grads
         with jax.set_mesh(self.mesh):
-            self.accumulated_grads, grad_norm = self._compute_grads_and_update(
+            self.accumulated_grads, grad_norm, mhc_grad_norm = self._compute_grads_and_update(
                 self.accumulated_grads,
                 self.lora_params,
                 optimizer,
                 jnp.int32(adapter_index),
             )
 
-        grad_norm = float(jax.device_get(grad_norm))
-        logger.info(f"Applied optimizer step for model {model_id} (adapter {adapter_index}), grad_norm={grad_norm}")
-        return types.OptimStepOutput(metrics={"skyrl.ai/grad_norm": grad_norm, "skyrl.ai/learning_rate": learning_rate})
+        grad_norm, mhc_grad_norm = jax.device_get((grad_norm, mhc_grad_norm))
+        grad_norm = float(grad_norm)
+        mhc_grad_norm = float(mhc_grad_norm)
+        logger.info(
+            f"Applied optimizer step for model {model_id} "
+            f"(adapter {adapter_index}), grad_norm={grad_norm}, mhc_grad_norm={mhc_grad_norm}"
+        )
+        return types.OptimStepOutput(
+            metrics={
+                "skyrl.ai/grad_norm": grad_norm,
+                "skyrl.ai/mhc_gradient_norm": mhc_grad_norm,
+                "skyrl.ai/learning_rate": learning_rate,
+            }
+        )
 
     def sample(
         self,

@tanmaysachan
Contributor Author

tanmaysachan commented Feb 19, 2026

@pcmoritz Thanks for running experiments with it!
I'm pretty compute-restricted, so I can't play around with it much to run experiments 😓
I'll try to do it on some small variants.

I was thinking of something along the lines of initializing b_res as the product of two randomly initialized square matrices A and A^-1, followed by sinkhorn. Since b_res is quite small, this would be negligible parameter-count-wise, but it would be a departure from core mHC.


@tanmaysachan
Contributor Author

For clarity - should the clear_adapter and init_adapter methods be implemented as static methods by the adapter classes themselves? I can make the change.

@pcmoritz
Collaborator

@tanmaysachan Yes, that would be much better. There are a few more small changes I'd like to try (like shifting is_lora_param into the ModelForCausalLM base class so it won't get repeated). Why don't you go ahead and move clear_adapter and init_adapter as static methods into LoRAConnector, and afterwards I'll do another pass to make the other cleanups :)

pcmoritz added a commit that referenced this pull request Feb 20, 2026
This is in preparation for merging #1008 and to make it easier to introduce metrics.

@pcmoritz
Collaborator

Thanks a lot for all the updates, I'll do the rest (already merged a PR that cleans things up a little #1191) :)

C = hidden_dim

# Phi matrices are zero-initialized so that alpha * x @ 0 + bias = bias at init.
self.input_norm_weight = nnx.Param(jnp.ones((max_lora_adapters, n * C), dtype=dtype))
Collaborator


I'm curious, why did you make the RMSNorm per adapter and trainable? That seems wrong, we should probably just use the RMSNorm from the base model :) [I don't think any of the LoRA codes out there make the RMSNorm trainable]

Collaborator


Actually I think I misunderstood the code and you are doing the right thing :)

Collaborator


Sorry for going back and forth on this, but I think the actually correct implementation would be to pass the input norm parameters from the model to the constructor of LoRAConnector, use them for the normalization below, and keep them non-trainable. It will be slightly redundant to apply the norm twice, but I think that's fine for code clarity for now (there are more optimizations to do anyway). Let me know your thoughts, I'll give that a shot :)

Contributor Author


Thanks for the change!
My layernorm change was about making the entire block, including the norms, trainable - but yeah, if it works like LoRA that shouldn't be the case.
