Conversation
RissyRan left a comment
@gemini-cli /review
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces the DeepSeek Manifold-Constrained Hyper Connections (mHC) feature. The implementation includes new common types, configuration updates, the core mhc layer, and unit tests. The unit tests provide good coverage for the expand/reduce functions and the sinkhorn algorithm.
🔍 General Feedback
- The overall structure of the mHC implementation is clear, and the separation into common types, configuration, and a dedicated layer is well-organized.
- The unit tests for the utility functions (`expand`, `reduce`, `sinkhorn`) are thorough and cover important properties.
- A critical concern has been identified regarding the dependency of weight matrix initialization on batch and sequence dimensions, which should be addressed for robustness and scalability.
shuningjin left a comment
Thanks for implementing the feature! There seems to be some ambiguity; I left a few questions. I may need another look.
```python
in_axis = 0
out_axis = 1
weight_sharding_axis_name = ("activation_embed", None)
self.res_alpha = nnx.Param(
```
General question: for a linear weight, I'm wondering how to choose between `linears.DenseGeneral` and `nnx.Param`?
Those are different:
- `nnx.Param` is used inside the `__init__` method of your custom `nnx.Module` classes to define the trainable tensors/weights.
- `linears.DenseGeneral` is a pre-built NNX module (a layer) that uses `nnx.Param` internally to define its weight matrix and bias vector.
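A minimal sketch of the difference, using `nnx.Linear` as a stand-in for a pre-built layer like `linears.DenseGeneral` (the `ResMix` module, shapes, and initializer here are hypothetical, not from this PR):

```python
import jax.numpy as jnp
from flax import nnx


class ResMix(nnx.Module):
  """Hypothetical module contrasting a raw nnx.Param with a pre-built layer."""

  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    # nnx.Param: a bare trainable tensor; you pick the init and write the matmul yourself.
    self.alpha = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
    # nnx.Linear: a pre-built layer that creates its own kernel/bias nnx.Params internally.
    self.proj = nnx.Linear(din, dout, rngs=rngs)

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    return x @ self.alpha[...] + self.proj(x)


y = ResMix(8, 4, rngs=nnx.Rngs(0))(jnp.ones((2, 8)))  # shape (2, 4)
```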
```python
def res_mapping(self, x: jnp.ndarray):
  """Helper function for residual mapping."""
  # Apply projection: (b, s, k*d) @ (k*d, k*k) -> (k*k)
  h_res = jnp.einsum("bsm,mn -> n", x, self.res_alpha[...])
```
Q3-1. Could it be `bsm,mn -> bsn`?
- The paper omits batch and seq in eq. 7: `H_res` has shape `(k*d)`, so the projection is `(k*d) @ (k*d, k*k) -> (k*k)`, which is then reshaped to a `(k, k)` matrix for sinkhorn.
- With batch and seq, would we get a batch of `b*s` `(k, k)` matrices going into sinkhorn?
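A quick standalone shape check of the two contractions (arbitrary sizes, just to illustrate the question): `bsm,mn -> n` sums out batch and seq, giving one `(k*k,)` vector total, while `bsm,mn -> bsn` keeps one per token.

```python
import jax.numpy as jnp

b, s, k, d = 2, 3, 4, 8
x = jnp.ones((b, s, k * d))           # (b, s, k*d)
res_alpha = jnp.ones((k * d, k * k))  # (k*d, k*k)

summed = jnp.einsum("bsm,mn -> n", x, res_alpha)       # (16,): batch and seq are summed out
per_token = jnp.einsum("bsm,mn -> bsn", x, res_alpha)  # (2, 3, 16): one (k*k,) vector per token
```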
Good question. Let me take a deeper look. If so, we will need to align all pre/post mappings with the res mapping.
I think this is a little bit confusing: for weights (i.e. alpha and beta), we shouldn't include batch & seq dims, and the paper annotates them as either 1×n or n×n (e.g. for beta).
For H, Gemini said it should also be batch- and seq-independent: https://screenshot.googleplex.com/AkWCwapW7MjKQqp
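If the per-token form were kept instead, sinkhorn would have to run over a batch of b*s `(k, k)` matrices, which in JAX could be done with nested `vmap`. A rough sketch with a placeholder sinkhorn (not the PR's implementation; names and iteration count are made up):

```python
import jax
import jax.numpy as jnp


def sinkhorn(logits: jnp.ndarray, iters: int = 5) -> jnp.ndarray:
  """Placeholder Sinkhorn normalization for a single (k, k) matrix."""
  m = jnp.exp(logits)
  for _ in range(iters):
    m = m / m.sum(axis=1, keepdims=True)  # normalize rows
    m = m / m.sum(axis=0, keepdims=True)  # normalize columns
  return m


b, s, k = 2, 3, 4
h_res = jnp.zeros((b, s, k, k))  # hypothetical per-token (k, k) logits
# Nested vmap over batch and seq so each token gets its own (approximately) doubly stochastic matrix.
batched = jax.vmap(jax.vmap(sinkhorn))(h_res)  # (b, s, k, k)
```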
```python
def mapping(self, x: jnp.ndarray, alpha_scale: jnp.ndarray, alpha: jnp.ndarray, beta: jnp.ndarray, scale: int):
  """Helper function for both pre and post mappings."""
  # Apply projection: (b, s, k*d) @ (k*d, k) -> (k)
  h = jnp.einsum("bsm,mk -> k", x, alpha)
```
Q3-2. Could it be `bsm,mk -> bsk`?
- The paper omits batch and seq in eq. 7: `H_pre` has shape `(k*d)`, so the projection is `(k*d) @ (k*d, k) -> (k)`.
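The same kind of shape check applies to the pre/post mapping (again arbitrary sizes, not the PR's code): `bsm,mk -> k` produces a single global k-vector, while `bsm,mk -> bsk` produces one per token.

```python
import jax.numpy as jnp

b, s, k, d = 2, 3, 4, 8
x = jnp.ones((b, s, k * d))
alpha = jnp.ones((k * d, k))

global_vec = jnp.einsum("bsm,mk -> k", x, alpha)   # (4,)
per_token = jnp.einsum("bsm,mk -> bsk", x, alpha)  # (2, 3, 4)
```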
Description
Initial implementation of the DeepSeek mHC feature.
Next: model/decoder layer integration using this deepseek-custom config, and end-to-end testing.
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.