[tx] General implementation of trainable Hyper Connections #1008
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| from flax import nnx | ||
| import jax | ||
| from jax import numpy as jnp | ||
|
|
||
| from tx.layers.util import Param | ||
| from tx.layers.layernorm import RMSNorm | ||
|
|
||
|
|
||
| class Connector(nnx.Module): | ||
| """ | ||
| Implementation of Manifold constrained HyperConnections (https://arxiv.org/pdf/2512.24880) | ||
|
|
||
| Weights initialized with identity mapping; Default behaviour equates to residual networks. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| hidden_dim: int, | ||
| expansion_rate: int, | ||
| *, | ||
| trainable: bool = False, | ||
| sinkhorn_iters: int = 20, | ||
| eps: float = 1e-5, | ||
| dtype: jnp.dtype, | ||
| rngs: nnx.Rngs, | ||
| ) -> None: | ||
| self.hidden_dim = hidden_dim | ||
| self.expansion_rate = expansion_rate | ||
| self.trainable = trainable | ||
| self.sinkhorn_iters = sinkhorn_iters | ||
| self.eps = eps | ||
| n = expansion_rate | ||
| C = hidden_dim | ||
|
|
||
| self.norm = RMSNorm(hidden_dim, eps=eps, dtype=dtype, rngs=rngs) | ||
|
|
||
| self.phi_pre = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) | ||
| self.phi_post = Param(n * C, n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) | ||
| self.phi_res = Param(n * C, n * n, dtype=dtype, kernel_init=nnx.initializers.normal(stddev=0.02), rngs=rngs) | ||
|
|
||
| # Initialize biases for identity-like behavior: | ||
| # H_pre = 1/n (uniform aggregation), H_post = 1 (full distribution), M = I (identity mixing) | ||
|
|
||
| # H_pre = sigmoid(b_pre) = 1/n => b_pre = logit(1/n) | ||
| target_h_pre = jnp.array(1.0 / n, dtype=dtype) | ||
| clamped = jnp.clip(target_h_pre, 1e-6, 1.0 - 1e-6) | ||
| logit_1_over_n = jnp.log(clamped) - jnp.log(1.0 - clamped) | ||
| self.b_pre = nnx.Param(jnp.full((n,), logit_1_over_n, dtype=dtype)) | ||
|
|
||
| # H_post = 2 * sigmoid(b_post) = 1 => b_post = 0 | ||
| self.b_post = nnx.Param(jnp.zeros((n,), dtype=dtype)) | ||
|
|
||
| # M = sinkhorn(exp(b_res)) = I => b_res = large diagonal matrix | ||
| self.b_res = nnx.Param(10.0 * jnp.eye(n, dtype=dtype)) | ||
|
|
||
| self.alpha_pre = nnx.Param(jnp.array(0.0, dtype=dtype)) | ||
| self.alpha_post = nnx.Param(jnp.array(0.0, dtype=dtype)) | ||
| self.alpha_res = nnx.Param(jnp.array(0.0, dtype=dtype)) | ||
|
|
||
| def _sinkhorn_knopp(self, M: jax.Array) -> jax.Array: | ||
| M = jnp.exp(M) | ||
| for _ in range(self.sinkhorn_iters): | ||
| M = M / (M.sum(axis=-1, keepdims=True) + self.eps) | ||
| M = M / (M.sum(axis=-2, keepdims=True) + self.eps) | ||
| return M | ||
|
|
||
| def _get_params(self): | ||
| """Get all connector params, with stop_gradient applied if not trainable.""" | ||
| sg = (lambda x: x) if self.trainable else jax.lax.stop_gradient | ||
| return ( | ||
| sg(self.alpha_pre[...]), sg(self.alpha_post[...]), sg(self.alpha_res[...]), | ||
| sg(self.phi_pre[...]), sg(self.phi_post[...]), sg(self.phi_res[...]), | ||
| sg(self.b_pre[...]), sg(self.b_post[...]), sg(self.b_res[...]), | ||
| sg(self.norm.weight[...]), | ||
| ) | ||
|
|
||
| def pre(self, x: jax.Array) -> jax.Array: | ||
| *batch_dims, n, C = x.shape | ||
|
|
||
| x_flat = x.reshape(*batch_dims, n * C) | ||
| rms = jnp.sqrt(jnp.mean(x_flat * x_flat, axis=-1, keepdims=True) + self.eps) | ||
| x_norm = x_flat / rms | ||
|
|
||
| (alpha_pre, alpha_post, alpha_res, phi_pre, phi_post, phi_res, | ||
| b_pre, b_post, b_res, norm_weight) = self._get_params() | ||
|
|
||
| tilde_H_pre = alpha_pre * (x_norm @ phi_pre) + b_pre | ||
| tilde_H_post = alpha_post * (x_norm @ phi_post) + b_post | ||
| tilde_H_res = alpha_res * (x_norm @ phi_res).reshape(*batch_dims, n, n) + b_res | ||
|
|
||
| H_pre = jax.nn.sigmoid(tilde_H_pre) | ||
| self._H_post = 2.0 * jax.nn.sigmoid(tilde_H_post) | ||
| self._M = self._sinkhorn_knopp(tilde_H_res) | ||
|
|
||
| x_agg = jnp.einsum("...i,...ic->...c", H_pre, x) | ||
| rms_norm = jnp.sqrt(jnp.mean(x_agg**2, axis=-1, keepdims=True) + self.norm.eps) | ||
| x_normed = norm_weight * x_agg / rms_norm | ||
|
|
||
| return x_normed | ||
|
|
||
| def post(self, residual: jax.Array, output: jax.Array) -> jax.Array: | ||
| y_dist = self._H_post[..., None] * output[..., None, :] | ||
| x_mixed = jnp.einsum("...ij,...jc->...ic", self._M, residual) | ||
| return x_mixed + y_dist | ||
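For reference, here is the connector math as implemented above (reconstructed from the code; notation loosely follows the mHC paper). For an expanded stream $x \in \mathbb{R}^{n \times C}$, let $\hat{x} = \mathrm{vec}(x) / \mathrm{RMS}(\mathrm{vec}(x))$. The dynamic connection weights are

$$
\mathcal{H}_{\mathrm{pre}} = \sigma\!\left(\alpha_{\mathrm{pre}}\,\hat{x}\,\varphi_{\mathrm{pre}} + b_{\mathrm{pre}}\right) \in \mathbb{R}^{n}, \qquad
\mathcal{H}_{\mathrm{post}} = 2\,\sigma\!\left(\alpha_{\mathrm{post}}\,\hat{x}\,\varphi_{\mathrm{post}} + b_{\mathrm{post}}\right) \in \mathbb{R}^{n},
$$

$$
\mathcal{M} = \mathrm{Sinkhorn}\!\left(\exp\!\left(\alpha_{\mathrm{res}}\,\mathrm{mat}_{n \times n}(\hat{x}\,\varphi_{\mathrm{res}}) + b_{\mathrm{res}}\right)\right) \in \mathbb{R}^{n \times n},
$$

so `pre` returns $h = \mathrm{RMSNorm}\big(\sum_i \mathcal{H}_{\mathrm{pre},i}\, x_i\big)$ and `post` returns $x'_i = \sum_j \mathcal{M}_{ij}\, x_j + \mathcal{H}_{\mathrm{post},i}\, y$ for sublayer output $y$. At initialization every $\alpha$ is zero, so $\mathcal{H}_{\mathrm{pre}} = \tfrac{1}{n}\mathbf{1}$, $\mathcal{H}_{\mathrm{post}} = \mathbf{1}$, and $\mathcal{M} \approx I$: each stream receives a plain pre-norm residual update.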
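And a minimal usage sketch of the `pre`/`post` sandwich (hypothetical shapes and a stand-in sublayer, assuming the repo's `tx` package is importable; not part of the PR):

```python
import jax.numpy as jnp
from flax import nnx

from tx.layers.connectors import Connector  # the module added in this PR

C, n = 8, 4                                  # hidden dim, expansion rate (illustrative)
conn = Connector(C, n, dtype=jnp.float32, rngs=nnx.Rngs(0))

def toy_sublayer(h):                         # stand-in for attention / MLP
    return 0.5 * h

x = jnp.ones((2, 3, n, C))                   # expanded residual stream: (batch, seq, n, C)
residual = x
h = conn.pre(x)                              # aggregate n streams -> (batch, seq, C), RMS-normed
y = toy_sublayer(h)
out = conn.post(residual, y)                 # mix streams with M, distribute y via H_post
assert out.shape == x.shape

# Identity-at-init check: with all alphas at 0, H_post = 1 and M ~= I, so each
# stream gets a plain residual update, out_i ~= x_i + y.
assert jnp.allclose(out, residual + y[..., None, :], atol=1e-2)
```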
`tx/layers/layernorm.py` — change RMSNorm's weight initialization from `normal()` to `ones_init()`:

```diff
@@ -16,7 +16,7 @@ class RMSNorm(nnx.Module):
     def __init__(self, size: int, *, eps: float = 1e-6, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
         self.eps = eps
         self.weight = Param(
-            size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.normal(), jax.P(None)), rngs=rngs
+            size, dtype=dtype, kernel_init=nnx.with_partitioning(nnx.initializers.ones_init(), jax.P(None)), rngs=rngs
         )

     def __call__(self, x: jax.Array) -> jax.Array:
```

Review discussion (Contributor, author):

> Temporary, testing.

> https://docs.pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html — Torch also initializes to one by default.

> Due to adapter indexing, the norm ended up re-implemented in the connector layer itself, so this change can be removed. But taking Torch as the baseline, `ones_init` still fits better.
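For intuition on why ones is the conventional default: the RMSNorm weight is a per-channel gain applied after normalization, so a weight of one passes the normalized activations through unchanged, while a near-zero random init crushes them at the start of training. A self-contained sketch (plain `jnp`, not the repo's `Param` wrapper):

```python
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    # y = weight * x / RMS(x)
    return weight * x / jnp.sqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

x = jnp.array([3.0, -4.0, 12.0, -5.0])
print(rms_norm(x, jnp.ones(4)))         # pure normalization: output has unit RMS
print(rms_norm(x, 0.01 * jnp.ones(4)))  # scale of the nnx.initializers.normal() default
                                        # (stddev=0.01): output ~100x too small
```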
DeepseekV3 model — wire the connectors into the decoder layer (the per-layer `input_layernorm`/`post_attention_layernorm` are dropped in favor of the connector's built-in norm):

```diff
@@ -7,6 +7,7 @@
 from tx.layers.rotary_embedding import get_rope
 from tx.layers.util import Param, prepare_routing, shard_map_ep
 from tx.layers.layernorm import RMSNorm
+from tx.layers.connectors import Connector
 from tx.models.configs import DeepseekV3Config
 from tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
 from tx.utils.generator import GeneratorMixin, KVCache
@@ -417,17 +418,28 @@
 class DeepseekV3DecoderLayer(nnx.Module):

-    def __init__(self, config: DeepseekV3Config, layer_idx: int, *, dtype: jnp.dtype, rngs: nnx.Rngs) -> None:
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
+    def __init__(
+        self,
+        config: DeepseekV3Config,
+        layer_idx: int,
+        *,
+        dtype: jnp.dtype,
+        rngs: nnx.Rngs,
+        expansion_rate: int = 1,
+    ) -> None:
         self.self_attn = DeepseekV3Attention(config, dtype=dtype, rngs=rngs)
         self.layer_idx = layer_idx
         self.num_layers = config.num_hidden_layers
+        self.expansion_rate = expansion_rate

         # Use dense MLP for initial layers, MoE for the rest
         if layer_idx >= config.first_k_dense_replace:
             self.mlp = DeepseekV3MoE(config, dtype=dtype, rngs=rngs)
         else:
             self.mlp = DeepseekV3MLP(config, dtype=dtype, rngs=rngs)

+        self.attn_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)
+        self.mlp_connector = Connector(config.hidden_size, expansion_rate, dtype=dtype, rngs=rngs)

     def __call__(
         self,
         hidden_states: jax.Array,
@@ -437,21 +449,28 @@
         adapter_indices: jax.Array | None = None,
         kv_cache: tuple[jax.Array, jax.Array] | None = None,
     ) -> tuple[jax.Array, tuple[jax.Array, jax.Array]]:
+        n = self.expansion_rate
+        if self.layer_idx == 0:
+            hidden_states = jnp.repeat(hidden_states[..., None, :], n, axis=-2)

         residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.attn_connector.pre(hidden_states)
         hidden_states, updated_cache = self.self_attn(
             hidden_states,
             attention_mask=attention_mask,
             positions=positions,
             adapter_indices=adapter_indices,
             kv_cache=kv_cache,
         )
-        hidden_states = residual + hidden_states
+        hidden_states = self.attn_connector.post(residual, hidden_states)

         residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp_connector.pre(hidden_states)
         mlp_output = self.mlp(hidden_states, adapter_indices=adapter_indices)
-        hidden_states = residual + mlp_output
+        hidden_states = self.mlp_connector.post(residual, mlp_output)

+        if self.layer_idx == self.num_layers - 1:
+            hidden_states = hidden_states.sum(axis=-2)

         return hidden_states, updated_cache
@@ -500,7 +519,7 @@
         for layer_idx, layer in enumerate(self.layers):
             if output_hidden_states:
-                all_hidden_states.append(hidden_states)
+                all_hidden_states.append(hidden_states.squeeze())

             hidden_states, (k, v) = layer(
                 hidden_states,
```

Review comment on the `squeeze()` line:

> Contributor: A more robust approach is to aggregate across the expansion dimension, for example by taking the mean.
>
> (Suggested change; conversation marked as resolved by devin-ai-integration[bot], now outdated.)
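To make the reviewer's point concrete: `squeeze()` with no arguments removes every size-1 axis, so it only undoes the expansion when `n == 1`, and it will also silently drop a batch axis of size 1. A mean over the expansion axis is shape-safe for any `n`. A hypothetical shape check:

```python
import jax.numpy as jnp

h = jnp.ones((1, 5, 4, 64))      # (batch=1, seq, n=4, hidden)
print(h.squeeze().shape)         # (5, 4, 64) -- batch axis silently dropped, n axis kept
print(h.mean(axis=-2).shape)     # (1, 5, 64) -- expansion aggregated, batch preserved

h1 = jnp.ones((2, 5, 1, 64))     # n = 1
print(h1.squeeze().shape)        # (2, 5, 64) -- works only because n == 1
print(h1.mean(axis=-2).shape)    # (2, 5, 64) -- robust for any n
```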