DeepSeek MHC feature #3065

Open

RissyRan wants to merge 1 commit into main from mhc_feature

Conversation

@RissyRan (Collaborator) commented Feb 2, 2026

Description

Initial implementation of DeepSeek MHC feature:

  • MHC paper: https://arxiv.org/pdf/2512.24880
  • Research implementations exist in two repos; however, I noticed they do not exactly match the paper, so I used them as references and implemented from the paper to better fit the MaxText integration (see the Sinkhorn sketch after this list).
    • Ref1 - weight init and Sinkhorn differ slightly, e.g. Sinkhorn uses the log-space version
    • Ref2 - has extra configs such as num_fracs and num_input_views, potentially carried over from HC
    • Moved the expansion_rate to the 3rd dimension to align with the Engram reference (highly likely for V4)
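
A minimal sketch (not the PR's code) contrasting the standard and log-space Sinkhorn normalizations mentioned above; the iteration count is an assumption, and the actual layer in src/MaxText/layers/mhc.py may differ:

import jax
import jax.numpy as jnp

def sinkhorn(m: jnp.ndarray, num_iters: int = 20) -> jnp.ndarray:
  """Standard Sinkhorn-Knopp: alternately normalize rows and columns
  of a non-negative (k, k) matrix toward a doubly stochastic matrix."""
  for _ in range(num_iters):
    m = m / jnp.sum(m, axis=-1, keepdims=True)  # row normalization
    m = m / jnp.sum(m, axis=-2, keepdims=True)  # column normalization
  return m

def sinkhorn_log(logits: jnp.ndarray, num_iters: int = 20) -> jnp.ndarray:
  """Log-space variant (as in Ref1): operate on logits with logsumexp
  for numerical stability, exponentiating only at the end."""
  for _ in range(num_iters):
    logits = logits - jax.nn.logsumexp(logits, axis=-1, keepdims=True)
    logits = logits - jax.nn.logsumexp(logits, axis=-2, keepdims=True)
  return jnp.exp(logits)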

Next: model/decoder layer integration using this deepseek-custom config, and end-to-end testing.

Tests

  • Unit tests for MHC feature: link
  • I didn't find a good way to test implementation details in the unit tests (only shapes are tested for the core part; a shape-check sketch follows this list), but I plan to test convergence/loss against a non-MHC workload during the upcoming model integration.
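
For reference, a minimal, self-contained shape-only check of the kind described above; expand and reduce_streams are hypothetical toy stand-ins, not the PR's actual functions:

import jax.numpy as jnp

def expand(x, expansion_rate):
  # Toy stand-in: broadcast the hidden stream to `expansion_rate` copies on the 3rd dim.
  return jnp.broadcast_to(x[:, :, None, :], x.shape[:2] + (expansion_rate, x.shape[-1]))

def reduce_streams(x):
  # Toy stand-in: average the streams back into a single hidden stream.
  return jnp.mean(x, axis=2)

def test_expand_reduce_shapes():
  b, s, d, k = 2, 8, 16, 4  # placeholder batch, seq, model dim, expansion rate
  x = jnp.ones((b, s, d))
  expanded = expand(x, expansion_rate=k)
  assert expanded.shape == (b, s, k, d)            # expansion_rate on the 3rd dimension
  assert reduce_streams(expanded).shape == (b, s, d)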

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@RissyRan (Collaborator, Author) left a comment


@gemini-cli /review

@github-actions bot commented Feb 2, 2026

🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment

📋 Review Summary

This pull request introduces the DeepSeek Manifold-Constrained Hyper Connections (mHC) feature. The implementation includes new common types, configuration updates, the core mhc layer, and unit tests. The unit tests provide good coverage for the expand/reduce functions and the sinkhorn algorithm.

🔍 General Feedback

  • The overall structure of the mHC implementation is clear, and the separation into common types, configuration, and a dedicated layer is well-organized.
  • The unit tests for the utility functions (expand, reduce, sinkhorn) are thorough and cover important properties.
  • A critical concern has been identified regarding the dependency of weight matrix initialization on batch and sequence dimensions, which should be addressed for robustness and scalability.

RissyRan marked this pull request as draft February 2, 2026 19:54
RissyRan marked this pull request as ready for review February 2, 2026 20:26
RissyRan changed the title from [WIP] DeepSeek MHC feature to DeepSeek MHC feature on Feb 2, 2026
@codecov bot commented Feb 3, 2026

Codecov Report

❌ Patch coverage is 97.53086% with 2 lines in your changes missing coverage. Please review.

Files with missing lines   | Patch % | Lines
src/MaxText/layers/mhc.py  | 97.36%  | 1 Missing and 1 partial ⚠️


@shuningjin (Collaborator) left a comment

Thanks for implementing the feature! There seems to be some ambiguity; I left a few questions. I may need another look.

in_axis = 0
out_axis = 1
weight_sharding_axis_name = ("activation_embed", None)
self.res_alpha = nnx.Param(
@shuningjin (Collaborator) commented:

General question: for a linear weight, how should we choose between linears.DenseGeneral and nnx.Param?

@RissyRan (Collaborator, Author) replied:

Those are different (see the sketch below):

  • nnx.Param is used inside the __init__ method of a custom nnx.Module class to define trainable tensors/weights.
  • linears.DenseGeneral is a pre-built NNX Module (a layer) that internally uses nnx.Param to define its weight matrix and bias vector.
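
A minimal sketch (assumptions, not MaxText code) illustrating the distinction; nnx.Linear stands in here for a pre-built layer like linears.DenseGeneral:

import jax.numpy as jnp
from flax import nnx

class RawParamLayer(nnx.Module):
  """Defines its trainable weight directly with nnx.Param inside __init__."""
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.kernel = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))
  def __call__(self, x):
    return x @ self.kernel.value

class PrebuiltLayer(nnx.Module):
  """Delegates weight creation to a pre-built layer, which holds its own nnx.Param(s)."""
  def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):
    self.proj = nnx.Linear(in_dim, out_dim, rngs=rngs)
  def __call__(self, x):
    return self.proj(x)

y = RawParamLayer(4, 8, rngs=nnx.Rngs(0))(jnp.ones((2, 4)))  # shape (2, 8)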

def res_mapping(self, x: jnp.ndarray):
  """Helper function for residual mapping."""
  # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (k*k)
  h_res = jnp.einsum("bsm,mn -> n", x, self.res_alpha[...])
@shuningjin (Collaborator) commented:

Q3-1. Could it be bsm,mn -> bsn?

  • the paper omits batch and seq; in Eq. 7, H_res is (k*d) @ (k*d, k*k) -> (k*k), then reshaped to a (k, k) matrix for Sinkhorn
  • with batch and seq, would we get a batch of b*s (k, k) matrices in Sinkhorn?

@RissyRan (Collaborator, Author) replied:

Good question. Let me take a deeper look. If so, we will need to align all pre/post mappings with the res mapping.

@RissyRan (Collaborator, Author) replied:

I think this is a bit confusing. For weights we shouldn't include batch & seq dims, e.g. for alpha and beta, and the paper annotates them as either 1xn or nxn (e.g. for beta).

For H, Gemini said it should also be batch/seq independent: https://screenshot.googleplex.com/AkWCwapW7MjKQqp (a small shape illustration of the two contractions follows below).
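
A small shape illustration (not from the PR) of the two contractions under discussion, with placeholder dimensions:

import jax.numpy as jnp

b, s, k, d = 2, 3, 4, 8              # placeholder batch, seq, expansion rate, hidden dim
x = jnp.ones((b, s, k * d))          # expanded residual stream
res_alpha = jnp.ones((k * d, k * k))

# PR code: sums over batch and seq as well as m, giving one (k*k,) vector
# shared by every token.
h_shared = jnp.einsum("bsm,mn -> n", x, res_alpha)
print(h_shared.shape)                # (16,) == (k*k,)

# Reviewer's reading of Eq. 7: keep batch and seq, giving b*s separate
# (k, k) matrices after reshape, one per token.
h_per_token = jnp.einsum("bsm,mn -> bsn", x, res_alpha)
print(h_per_token.shape)             # (2, 3, 16) == (b, s, k*k)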

def mapping(self, x: jnp.ndarray, alpha_scale: jnp.ndarray, alpha: jnp.ndarray, beta: jnp.ndarray, scale: int):
  """Helper function for both pre and post mappings."""
  # Apply projection: (b, s, k*d) @ (k*d, k) -> (k)
  h = jnp.einsum("bsm,mk -> k", x, alpha)
@shuningjin (Collaborator) commented:

Q3-2. Could it be bsm,mk -> bsk?

  • the paper omits batch and seq; in Eq. 7, H_pre is (k*d) @ (k*d, k) -> (k)

@RissyRan (Collaborator, Author) replied:

Similar to the above.

Labels: None yet
Projects: None yet
3 participants