From c4ed1f906d2c758aa77f5e2e518b613dcd8d429c Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:45:05 +0800 Subject: [PATCH 1/9] feat(pt/dpmodel): add sequential_update for dpa3 --- deepmd/dpmodel/descriptor/dpa3.py | 18 + deepmd/dpmodel/descriptor/repflows.py | 451 ++++++++++++++++++ deepmd/pt/model/descriptor/dpa3.py | 1 + deepmd/pt/model/descriptor/repflow_layer.py | 411 ++++++++++++++++ deepmd/pt/model/descriptor/repflows.py | 3 + deepmd/utils/argcheck.py | 14 + .../tests/consistent/descriptor/test_dpa3.py | 10 + source/tests/pt/model/test_dpa3.py | 9 + 8 files changed, 917 insertions(+) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 5f5aea50e5..88f56213ee 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -170,6 +170,12 @@ class RepFlowArgs: In the dynamic selection case, neighbor-scale normalization will use `e_sel / sel_reduce_factor` or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, accommodating larger selection numbers. + sequential_update : bool, optional + Whether to use sequential update mode within each repflow layer. + When True, updates are applied sequentially: edge self → angle self (using updated edge) + → edge angle (using updated angle) → node (using final edge), + instead of the default parallel mode where all updates use original embeddings. + Currently only supports ``update_style='res_residual'`` and requires ``update_angle=True``. 
""" def __init__( @@ -201,6 +207,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, ) -> None: self.n_dim = n_dim self.e_dim = e_dim @@ -231,6 +238,15 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update + if self.sequential_update: + if self.update_style != "res_residual": + raise ValueError( + "sequential_update only supports update_style='res_residual', " + f"got '{self.update_style}'!" + ) + if not self.update_angle: + raise ValueError("sequential_update requires update_angle=True!") def __getitem__(self, key: str) -> Any: if hasattr(self, key): @@ -266,6 +282,7 @@ def serialize(self) -> dict: "use_exp_switch": self.use_exp_switch, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, } @classmethod @@ -404,6 +421,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 30637dc75a..5788840279 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -230,6 +230,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, seed: int | list[int] | None = None, trainable: bool = True, @@ -268,6 +269,7 @@ def __init__( self.use_dynamic_sel = use_dynamic_sel self.use_loc_mapping = use_loc_mapping 
self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" @@ -339,6 +341,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, @@ -757,6 +760,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "use_loc_mapping": self.use_loc_mapping, # variables "edge_embd": self.edge_embd.serialize(), @@ -905,6 +909,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -954,8 +959,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" 
+ ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -1342,6 +1354,418 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _call_sequential( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + h2: Array, + angle_ebd: Array, + nlist: Array, + nlist_mask: Array, + sw: Array, + a_nlist_mask: Array, + a_sw: Array, + nei_node_ebd: Array, + n2e_index: Array, + n_ext2e_index: Array, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int, + ) -> tuple[Array, Array, Array]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. + """ + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + edge_info = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + 
) + ) + + # Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_ebd_for_angle, + xp.zeros_like(edge_ebd_for_angle), + ) + + if not self.optim_update: + node_for_angle_info = ( + xp.tile( + xp.reshape( + node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + ), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), + n2a_index, + axis=0, + ) + ) + edge_for_angle_k = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), + ), + (1, 1, self.a_sel, 1, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eik2a_index, + axis=0, + ) + ) + edge_for_angle_j = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), + ), + (1, 1, 1, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eij2a_index, + axis=0, + ) + ) + edge_for_angle_info = xp.concat( + [edge_for_angle_k, edge_for_angle_j], axis=-1 + ) + angle_info 
= xp.concat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) + # ==================================================================== + if not self.optim_update: + angle_info_s2 = xp.concat( + [a_updated, node_for_angle_info, edge_for_angle_info], axis=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update + ) + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 + ) + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd.dtype, + device=array_api_compat.device(edge_ebd), + ), + ], + axis=2, + ) + else: + weighted_edge_angle_update = edge_angle_update * xp.expand_dims( + a_sw, axis=-1 + ) + 
padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), + padding_edge_angle_update, + edge_ebd, + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + else: + edge_info_updated = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + e_updated, + ], + axis=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * xp.expand_dims(sw, axis=-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * xp.expand_dims(sw, axis=-1) + + node_edge_update = ( + (xp.sum(node_edge_update, axis=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + 
xp.reshape( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ), + (nb, nloc, node_edge_update.shape[-1]), + ) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # ==================================================================== + n_update_list: list[Array] = [node_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[Array] = [] + node_sym_list.append( + symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def call( self, node_ebd_ext: Array, # nf x nall x n_dim @@ -1446,6 +1870,32 @@ def call( ) ) + if self.sequential_update and 
self.update_angle: + return self._call_sequential( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[Array] = [node_ebd] e_update_list: list[Array] = [edge_ebd] a_update_list: list[Array] = [angle_ebd] @@ -1907,6 +2357,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/dpa3.py b/deepmd/pt/model/descriptor/dpa3.py index 0c6982afe5..a5f79280fa 100644 --- a/deepmd/pt/model/descriptor/dpa3.py +++ b/deepmd/pt/model/descriptor/dpa3.py @@ -167,6 +167,7 @@ def init_subclass_params(sub_data: Any, sub_class: Any) -> Any: use_exp_switch=self.repflow_args.use_exp_switch, use_dynamic_sel=self.repflow_args.use_dynamic_sel, sel_reduce_factor=self.repflow_args.sel_reduce_factor, + sequential_update=self.repflow_args.sequential_update, use_loc_mapping=use_loc_mapping, exclude_types=exclude_types, env_protection=env_protection, diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 338f48b060..57c9368839 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -53,6 +53,7 @@ def __init__( optim_update: bool = True, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, smooth_edge_update: bool = False, activation_function: str = "silu", update_style: str = "res_residual", @@ -102,8 +103,15 @@ def __init__( self.smooth_edge_update = smooth_edge_update 
self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update self.dynamic_e_sel = self.nnei / self.sel_reduce_factor self.dynamic_a_sel = self.a_sel / self.sel_reduce_factor + if self.sequential_update and self.update_style != "res_residual": + raise NotImplementedError( + "sequential_update only supports update_style='res_residual'!" + ) + if self.sequential_update and not self.update_angle: + raise NotImplementedError("sequential_update requires update_angle=True!") assert update_residual_init in [ "norm", @@ -694,6 +702,383 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update + def _forward_sequential( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + angle_ebd: torch.Tensor, + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + a_sw: torch.Tensor, + nei_node_ebd: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + nb: int, + nloc: int, + nnei: int, + nall: int, + n_edge: int | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sequential update path: edge_self → angle_self → edge_angle → node. + + Only supports update_style='res_residual'. 
+ """ + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + assert self.angle_self_linear is not None + + # ==================================================================== + # Phase 1: Edge self update (uses original node_ebd, edge_ebd) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + edge_self_update = self.act(self.edge_self_linear(edge_info)) + else: + edge_self_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) + ) + + # Apply edge self residual: edge_ebd_s1 = edge_ebd + e_residual[0] * edge_self_update + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # ==================================================================== + # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) + # ==================================================================== + # Prepare edge for angle from edge_ebd_s1 (updated edge) + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd_s1 + + if not 
self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + ) + + # Initialize for JIT: these are only used in non-optim_update path + node_for_angle_info = angle_ebd # placeholder, overwritten below + edge_for_angle_info = angle_ebd # placeholder, overwritten below + + if not self.optim_update: + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) + ) + edge_for_angle_k = ( + torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + ) + edge_for_angle_j = ( + torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + ) + edge_for_angle_info = torch.cat( + [edge_for_angle_k, edge_for_angle_j], dim=-1 + ) + angle_info = torch.cat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + angle_self_update = self.act(self.angle_self_linear(angle_info)) + else: + angle_self_update = self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) + + # Apply angle self residual: a_updated = angle_ebd + a_residual[0] * angle_self_update + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # ==================================================================== + # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) + # 
==================================================================== + if not self.optim_update: + # Rebuild angle_info with updated angle (a_updated) + angle_info_s2 = torch.cat( + [a_updated, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) + else: + edge_angle_update = self.act( + self.optim_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) + ) + + # Reduce edge angle update over angle dimension + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update + ) + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd.dtype, + device=edge_ebd.device, + ), + ], + dim=2, + ) + else: + assert n_edge is not None + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, + eij2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" 
+ ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + ) + + edge_angle_processed = self.act( + self.edge_angle_linear2(padding_edge_angle_update) + ) + + # Apply edge angle residual on top of edge_ebd_s1 (no recomputation) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # ==================================================================== + # Phase 4: Node edge message (uses e_updated) + # ==================================================================== + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info_updated = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + else: + edge_info_updated = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + e_updated, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info_updated) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + e_updated, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + e_updated, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) + + # ==================================================================== + # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) + # 
==================================================================== + n_update_list: list[torch.Tensor] = [node_ebd] + + # node self mlp (uses original node_ebd) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym using e_updated + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + e_updated, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + e_updated, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + n_update_list.append(node_sym) + + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + def forward( self, node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode @@ -783,6 +1168,31 @@ def forward( ) ) + if self.sequential_update and self.update_angle: + return self._forward_sequential( + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + nei_node_ebd, + n2e_index, + n_ext2e_index, + n2a_index, + 
eij2a_index, + eik2a_index, + nb, + nloc, + nnei, + nall, + n_edge, + ) + n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1220,6 +1630,7 @@ def serialize(self) -> dict: "smooth_edge_update": self.smooth_edge_update, "use_dynamic_sel": self.use_dynamic_sel, "sel_reduce_factor": self.sel_reduce_factor, + "sequential_update": self.sequential_update, "node_self_mlp": self.node_self_mlp.serialize(), "node_sym_linear": self.node_sym_linear.serialize(), "node_edge_linear": self.node_edge_linear.serialize(), diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 433897860f..7c16ab3c7a 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -219,6 +219,7 @@ def __init__( use_exp_switch: bool = False, use_dynamic_sel: bool = False, sel_reduce_factor: float = 10.0, + sequential_update: bool = False, use_loc_mapping: bool = True, optim_update: bool = True, seed: int | list[int] | None = None, @@ -258,6 +259,7 @@ def __init__( self.use_exp_switch = use_exp_switch self.use_dynamic_sel = use_dynamic_sel self.sel_reduce_factor = sel_reduce_factor + self.sequential_update = sequential_update if self.use_dynamic_sel and not self.smooth_edge_update: raise NotImplementedError( "smooth_edge_update must be True when use_dynamic_sel is True!" 
@@ -329,6 +331,7 @@ def __init__( optim_update=self.optim_update, use_dynamic_sel=self.use_dynamic_sel, sel_reduce_factor=self.sel_reduce_factor, + sequential_update=self.sequential_update, smooth_edge_update=self.smooth_edge_update, seed=child_seed(child_seed(seed, 1), ii), trainable=trainable, diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index b12bc7ef6f..70a7985702 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1550,6 +1550,13 @@ def dpa3_repflow_args() -> list[Argument]: "or `a_sel / sel_reduce_factor` instead of the raw `e_sel` or `a_sel` values, " "accommodating larger selection numbers." ) + doc_sequential_update = ( + "Whether to use sequential update mode within each repflow layer. " + "When True, updates are applied sequentially: edge self → angle self (using updated edge) " + "→ edge angle (using updated angle) → node (using final edge), " + "instead of the default parallel mode where all updates use original embeddings. " + "Currently only supports update_style='res_residual'." 
+ ) return [ # repflow args @@ -1680,6 +1687,13 @@ def dpa3_repflow_args() -> list[Argument]: default=10.0, doc=doc_sel_reduce_factor, ), + Argument( + "sequential_update", + bool, + optional=True, + default=False, + doc=doc_sequential_update, + ), ] diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index bca0759f5c..b980c584a1 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -79,6 +79,7 @@ (1,), # n_multi_edge_message ("float64",), # precision (False, True), # add_chg_spin_ebd + (False, True), # sequential_update ) class TestDPA3(CommonTest, DescriptorTest, unittest.TestCase): @property @@ -99,6 +100,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -130,6 +132,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor @@ -160,6 +163,7 @@ def skip_pt(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_pt @@ -181,6 +185,7 @@ def skip_pd(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -202,6 +207,7 @@ def skip_dp(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return CommonTest.skip_dp @@ -223,6 +229,7 @@ def skip_tf(self) -> bool: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return True @@ -288,6 +295,7 @@ def setUp(self) -> None: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -394,6 +402,7 @@ def 
rtol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -421,6 +430,7 @@ def atol(self) -> float: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 diff --git a/source/tests/pt/model/test_dpa3.py b/source/tests/pt/model/test_dpa3.py index 12b0be4532..d66eab9dea 100644 --- a/source/tests/pt/model/test_dpa3.py +++ b/source/tests/pt/model/test_dpa3.py @@ -56,6 +56,7 @@ def test_consistency( prec, ect, add_chg_spin, + seq_upd, ) in itertools.product( [True, False], # update_angle ["res_residual"], # update_style @@ -67,7 +68,11 @@ def test_consistency( ["float64"], # precision [False], # use_econf_tebd [False, True], # add_chg_spin_ebd + [False, True], # sequential_update ): + # sequential_update only works with update_angle=True + if seq_upd and not ua: + continue dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) if prec == "float64": @@ -93,6 +98,7 @@ def test_consistency( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl @@ -177,6 +183,7 @@ def test_jit( nme, prec, ect, + seq_upd, ) in itertools.product( [True], # update_angle ["res_residual"], # update_style @@ -187,6 +194,7 @@ def test_jit( [1, 2], # n_multi_edge_message ["float64"], # precision [False], # use_econf_tebd + [False, True], # sequential_update ): dtype = PRECISION_DICT[prec] rtol, atol = get_tols(prec) @@ -211,6 +219,7 @@ def test_jit( update_style=rus, update_residual_init=ruri, smooth_edge_update=True, + sequential_update=seq_upd, ) # dpa3 new impl From 8523d88e4b9cc6c202dbefeeea490df6bf64421f Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 20:59:35 +0800 Subject: [PATCH 2/9] fix comments --- 
deepmd/dpmodel/descriptor/dpa3.py | 3 +- deepmd/dpmodel/descriptor/repflows.py | 772 +++++++----------- deepmd/pt/model/descriptor/repflow_layer.py | 710 ++++++---------- deepmd/utils/argcheck.py | 2 +- .../tests/consistent/descriptor/test_dpa3.py | 8 +- 5 files changed, 563 insertions(+), 932 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa3.py b/deepmd/dpmodel/descriptor/dpa3.py index 88f56213ee..52b2ca9c2c 100644 --- a/deepmd/dpmodel/descriptor/dpa3.py +++ b/deepmd/dpmodel/descriptor/dpa3.py @@ -175,7 +175,8 @@ class RepFlowArgs: When True, updates are applied sequentially: edge self → angle self (using updated edge) → edge angle (using updated angle) → node (using final edge), instead of the default parallel mode where all updates use original embeddings. - Currently only supports ``update_style='res_residual'``. + Currently only supports ``update_style='res_residual'`` and requires ``update_angle=True``; + otherwise, a ``ValueError`` will be raised during initialization. """ def __init__( diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 5788840279..1af1d15bd9 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1354,42 +1354,20 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _call_sequential( + def _compute_edge_self_update( self, xp: object, node_ebd: Array, node_ebd_ext: Array, edge_ebd: Array, - h2: Array, - angle_ebd: Array, - nlist: Array, - nlist_mask: Array, - sw: Array, - a_nlist_mask: Array, - a_sw: Array, nei_node_ebd: Array, + nlist: Array, n2e_index: Array, n_ext2e_index: Array, - n2a_index: Array, - eij2a_index: Array, - eik2a_index: Array, nb: int, nloc: int, - nnei: int, - nall: int, - n_edge: int, - ) -> tuple[Array, Array, Array]: - """Sequential update path: edge_self → angle_self → edge_angle → node. - - Only supports update_style='res_residual'. 
- """ - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - - # ==================================================================== - # Phase 1: Edge self update (uses original node_ebd, edge_ebd) - # ==================================================================== + ) -> Array: + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: edge_info = xp.concat( @@ -1416,9 +1394,9 @@ def _call_sequential( ], axis=-1, ) - edge_self_update = self.act(self.edge_self_linear(edge_info)) + return self.act(self.edge_self_linear(edge_info)) else: - edge_self_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, @@ -1437,24 +1415,26 @@ def _call_sequential( ) ) - # Apply edge self residual - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # ==================================================================== - # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) - # ==================================================================== + def _prepare_angle_embeddings( + self, + xp: object, + node_ebd: Array, + edge_ebd: Array, + a_nlist_mask: Array, + ) -> tuple[Array, Array]: + """Prepare compressed node/edge embeddings for angle computation.""" if self.a_compress_rate != 0: if not self.a_compress_use_split: assert self.a_compress_n_linear is not None assert self.a_compress_e_linear is not None node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd_s1 + edge_ebd_for_angle = edge_ebd if not self.use_dynamic_sel: 
edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] @@ -1463,12 +1443,34 @@ def _call_sequential( edge_ebd_for_angle, xp.zeros_like(edge_ebd_for_angle), ) + return node_ebd_for_angle, edge_ebd_for_angle + + def _compute_angle_update( + self, + xp: object, + angle_ebd: Array, + node_ebd_for_angle: Array, + edge_ebd_for_angle: Array, + feat: str, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute angle-based update (for edge_angle or angle_self). + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ if not self.optim_update: node_for_angle_info = ( xp.tile( xp.reshape( - node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) + node_ebd_for_angle, + (nb, nloc, 1, 1, self.n_a_compress_dim), ), (1, 1, self.a_sel, self.a_sel, 1), ) @@ -1515,59 +1517,72 @@ def _call_sequential( angle_info = xp.concat( [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 ) - angle_self_update = self.act(self.angle_self_linear(angle_info)) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", + else: + return self.act( + self.optim_angle_update( + 
angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) ) - ) - # Apply angle self residual - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # ==================================================================== - # Phase 3: Edge angle update (uses updated angle a_updated, updated edge_ebd_s1) - # ==================================================================== - if not self.optim_update: - angle_info_s2 = xp.concat( - [a_updated, node_for_angle_info, edge_for_angle_info], axis=-1 - ) - edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) + def _compute_edge_angle_reduction( + self, + xp: object, + edge_angle_update: Array, + edge_ebd_fallback: Array, + a_sw: Array, + a_nlist_mask: Array, + nb: int, + nloc: int, + n_edge: int, + eij2a_index: Array, + ) -> Array: + """Reduce edge angle update over angle dimension, pad, and apply linear2. - # Reduce edge angle update over angle dimension + Parameters + ---------- + edge_ebd_fallback : Array + Edge embedding used for non-smooth padding fallback. 
+ """ + assert self.edge_angle_linear2 is not None if not self.use_dynamic_sel: weighted_edge_angle_update = ( a_sw[:, :, :, xp.newaxis, xp.newaxis] @@ -1582,8 +1597,8 @@ def _call_sequential( reduced_edge_angle_update, xp.zeros( (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd.dtype, - device=array_api_compat.device(edge_ebd), + dtype=edge_ebd_fallback.dtype, + device=array_api_compat.device(edge_ebd_fallback), ), ], axis=2, @@ -1618,34 +1633,41 @@ def _call_sequential( padding_edge_angle_update = xp.where( xp.expand_dims(full_mask, axis=-1), padding_edge_angle_update, - edge_ebd, + edge_ebd_fallback, ) - edge_angle_processed = self.act( - self.edge_angle_linear2(padding_edge_angle_update) - ) + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - # Apply edge angle residual on top of edge_ebd_s1 - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed - - # ==================================================================== - # Phase 4: Node edge message (uses e_updated) - # ==================================================================== + def _compute_node_edge_message( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + nei_node_ebd: Array, + sw: Array, + nlist: Array, + n2e_index: Array, + n_ext2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node edge message and reduce over neighbor dimension.""" if not self.optim_update: if not self.use_dynamic_sel: - edge_info_updated = xp.concat( + edge_info = xp.concat( [ xp.tile( xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), (1, 1, self.nnei, 1), ), nei_node_ebd, - e_updated, + edge_ebd, ], axis=-1, ) else: - edge_info_updated = xp.concat( + edge_info = xp.concat( [ xp.take( xp.reshape(node_ebd, (-1, self.n_dim)), @@ -1653,19 +1675,19 @@ def _call_sequential( axis=0, ), nei_node_ebd, - e_updated, + edge_ebd, ], axis=-1, ) node_edge_update = self.act( - self.node_edge_linear(edge_info_updated) + 
self.node_edge_linear(edge_info) ) * xp.expand_dims(sw, axis=-1) else: node_edge_update = self.act( self.optim_edge_update( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, nlist, "node", ) @@ -1673,7 +1695,7 @@ def _call_sequential( else self.optim_edge_update_dynamic( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, n2e_index, n_ext2e_index, "node", @@ -1696,21 +1718,25 @@ def _call_sequential( / self.dynamic_e_sel ) ) + return node_edge_update - # ==================================================================== - # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) - # ==================================================================== - n_update_list: list[Array] = [node_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym using e_updated + def _compute_node_sym( + self, + xp: object, + edge_ebd: Array, + nei_node_ebd: Array, + h2: Array, + nlist_mask: Array, + sw: Array, + n2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node symmetrization update (grrg + drrd).""" node_sym_list: list[Array] = [] node_sym_list.append( symmetrization_op( - e_updated, + edge_ebd, h2, nlist_mask, sw, @@ -1718,7 +1744,7 @@ def _call_sequential( ) if not self.use_dynamic_sel else symmetrization_op_dynamic( - e_updated, + edge_ebd, h2, sw, owner=n2e_index, @@ -1750,21 +1776,7 @@ def _call_sequential( axis_neuron=self.axis_neuron, ) ) - node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) - n_update_list.append(node_sym) - - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = xp.reshape( - node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) - ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) - else: - n_update_list.append(node_edge_update) - - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, 
a_updated + return self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) def call( self, @@ -1870,160 +1882,146 @@ def call( ) ) + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.sequential_update and self.update_angle: - return self._call_sequential( + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( xp, - node_ebd, - node_ebd_ext, - edge_ebd, - h2, angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist_mask, - a_sw, - nei_node_ebd, - n2e_index, - n_ext2e_index, + node_for_a, + edge_for_a, + "angle", n2a_index, eij2a_index, eik2a_index, nb, nloc, - nnei, - nall, - n_edge, ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update - n_update_list: list[Array] = [node_ebd] - e_update_list: list[Array] = [edge_ebd] - a_update_list: list[Array] = [angle_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym (grrg + drrd) - node_sym_list: list[Array] = [] - node_sym_list.append( - symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + xp, + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - 
axis_neuron=self.axis_neuron, + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, ) - ) - node_sym_list.append( - symmetrization_op( + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + xp, + e_updated, nei_node_ebd, h2, nlist_mask, sw, - self.axis_neuron, + n2e_index, + nb, + nloc, ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + e_updated, nei_node_ebd, - h2, sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) - ) - node_sym = self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) - n_update_list.append(node_sym) - if not self.optim_update: - if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = xp.concat( - [ - xp.tile( - xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), - (1, 1, self.nnei, 1), - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, + n_update_list: list[Array] = [node_ebd, node_self_mlp, node_sym] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, + (nb, nloc, self.n_multi_edge_message, self.n_dim), ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) else: - # n_edge x (n_dim * 2 + e_dim) - edge_info = xp.concat( - [ - xp.take( - xp.reshape(node_ebd, (-1, self.n_dim)), - n2e_index, - axis=0, - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, - ) - else: - edge_info = None + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") - # node 
edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * xp.expand_dims(sw, axis=-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * xp.expand_dims(sw, axis=-1) + return n_updated, e_updated, a_updated - node_edge_update = ( - (xp.sum(node_edge_update, axis=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - xp.reshape( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ), - (nb, nloc, node_edge_update.shape[-1]), - ) - / self.dynamic_e_sel - ) + # === Parallel update path === + n_update_list: list[Array] = [node_ebd] + e_update_list: list[Array] = [edge_ebd] + a_update_list: list[Array] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + xp, + edge_ebd, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, + ) + n_update_list.append(node_sym) + + # node edge message + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) if self.n_multi_edge_message > 1: # nb x nloc x h x n_dim @@ -2038,234 +2036,58 @@ def call( n_updated = self.list_update(n_update_list, "node") # edge self message - if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else 
self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) - ) e_update_list.append(edge_self_update) if self.update_angle: assert self.angle_self_linear is not None assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = xp.where( - xp.expand_dims(a_nlist_mask, axis=-1), - edge_ebd_for_angle, - xp.zeros_like(edge_ebd_for_angle), - ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - xp.tile( - xp.reshape( - node_ebd_for_angle, (nb, nloc, 1, 1, self.n_a_compress_dim) - ), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), - n2a_index, - axis=0, - ) - ) - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), - ), - (1, 1, self.a_sel, 1, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, - eik2a_index, - axis=0, - ) - ) - # nb x nloc x a_nnei x (a_nnei) x e_dim - edge_for_angle_j = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, 
self.a_sel, 1, self.e_a_compress_dim), - ), - (1, 1, 1, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, - eij2a_index, - axis=0, - ) - ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) - edge_for_angle_info = xp.concat( - [edge_for_angle_k, edge_for_angle_j], axis=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = xp.concat(angle_info_list, axis=-1) - else: - angle_info = None + node_for_a, edge_for_a = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd, a_nlist_mask + ) # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw[:, :, :, xp.newaxis, xp.newaxis] - * a_sw[:, :, xp.newaxis, :, xp.newaxis] - * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = xp.sum( - weighted_edge_angle_update, axis=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = xp.concat( - [ - reduced_edge_angle_update, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd.dtype, - device=array_api_compat.device(edge_ebd), - ), - ], - axis=2, - ) - else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * xp.expand_dims( - a_sw, 
axis=-1 - ) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = xp.concat( - [ - a_nlist_mask, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel), - dtype=a_nlist_mask.dtype, - device=array_api_compat.device(a_nlist_mask), - ), - ], - axis=-1, - ) - padding_edge_angle_update = xp.where( - xp.expand_dims(full_mask, axis=-1), - padding_edge_angle_update, - edge_ebd, - ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + edge_angle_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) + edge_angle_processed = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) 
a_update_list.append(angle_self_update) else: # update edge_ebd diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 57c9368839..a094105487 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -702,41 +702,17 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _forward_sequential( + def _compute_edge_self_update( self, node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, - h2: torch.Tensor, - angle_ebd: torch.Tensor, - nlist: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - a_nlist_mask: torch.Tensor, - a_sw: torch.Tensor, nei_node_ebd: torch.Tensor, + nlist: torch.Tensor, n2e_index: torch.Tensor, n_ext2e_index: torch.Tensor, - n2a_index: torch.Tensor, - eij2a_index: torch.Tensor, - eik2a_index: torch.Tensor, - nb: int, - nloc: int, - nnei: int, - nall: int, - n_edge: int | None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sequential update path: edge_self → angle_self → edge_angle → node. - - Only supports update_style='res_residual'. 
- """ - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - assert self.angle_self_linear is not None - - # ==================================================================== - # Phase 1: Edge self update (uses original node_ebd, edge_ebd) - # ==================================================================== + ) -> torch.Tensor: + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: edge_info = torch.cat( @@ -758,9 +734,9 @@ def _forward_sequential( ], dim=-1, ) - edge_self_update = self.act(self.edge_self_linear(edge_info)) + return self.act(self.edge_self_linear(edge_info)) else: - edge_self_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, @@ -779,36 +755,50 @@ def _forward_sequential( ) ) - # Apply edge self residual: edge_ebd_s1 = edge_ebd + e_residual[0] * edge_self_update - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # ==================================================================== - # Phase 2: Angle self update (uses original node_ebd, updated edge_ebd_s1) - # ==================================================================== - # Prepare edge for angle from edge_ebd_s1 (updated edge) + def _prepare_angle_embeddings( + self, + node_ebd: torch.Tensor, + edge_ebd: torch.Tensor, + a_nlist_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare compressed node/edge embeddings for angle computation.""" if self.a_compress_rate != 0: if not self.a_compress_use_split: assert self.a_compress_n_linear is not None assert self.a_compress_e_linear is not None node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd_s1) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) else: node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd_s1[..., : self.e_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : 
self.e_a_compress_dim] else: node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd_s1 + edge_ebd_for_angle = edge_ebd if not self.use_dynamic_sel: edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] edge_ebd_for_angle = torch.where( a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 ) + return node_ebd_for_angle, edge_ebd_for_angle - # Initialize for JIT: these are only used in non-optim_update path - node_for_angle_info = angle_ebd # placeholder, overwritten below - edge_for_angle_info = angle_ebd # placeholder, overwritten below + def _compute_angle_update( + self, + angle_ebd: torch.Tensor, + node_ebd_for_angle: torch.Tensor, + edge_ebd_for_angle: torch.Tensor, + feat: str, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + ) -> torch.Tensor: + """Compute angle-based update (for edge_angle or angle_self). + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. + """ if not self.optim_update: node_for_angle_info = ( torch.tile( @@ -838,60 +828,71 @@ def _forward_sequential( angle_info = torch.cat( [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 ) - angle_self_update = self.act(self.angle_self_linear(angle_info)) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) ) - if not self.use_dynamic_sel - else 
self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", + else: + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) ) - ) - # Apply angle self residual: angle_ebd_s2 = angle_ebd + a_residual[0] * angle_self_update - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # ==================================================================== - # Phase 3: Edge angle update (uses updated angle_ebd_s2, updated edge_ebd_s1) - # ==================================================================== - if not self.optim_update: - # Rebuild angle_info with updated angle (a_updated) - angle_info_s2 = torch.cat( - [a_updated, node_for_angle_info, edge_for_angle_info], dim=-1 - ) - edge_angle_update = self.act(self.edge_angle_linear1(angle_info_s2)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) + def _compute_edge_angle_reduction( + self, + edge_angle_update: torch.Tensor, + edge_ebd_fallback: torch.Tensor, + a_sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + nb: int, + nloc: int, + n_edge: int | None, + eij2a_index: torch.Tensor, + ) -> torch.Tensor: + """Reduce edge angle update over angle dimension, pad, and apply linear2. - # Reduce edge angle update over angle dimension + Parameters + ---------- + edge_ebd_fallback : torch.Tensor + Edge embedding used for non-smooth padding fallback. 
+ """ + assert self.edge_angle_linear2 is not None if not self.use_dynamic_sel: weighted_edge_angle_update = ( a_sw.unsqueeze(-1).unsqueeze(-1) @@ -906,8 +907,8 @@ def _forward_sequential( reduced_edge_angle_update, torch.zeros( [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, + dtype=edge_ebd_fallback.dtype, + device=edge_ebd_fallback.device, ), ], dim=2, @@ -939,49 +940,57 @@ def _forward_sequential( dim=-1, ) padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd + full_mask.unsqueeze(-1), + padding_edge_angle_update, + edge_ebd_fallback, ) - edge_angle_processed = self.act( - self.edge_angle_linear2(padding_edge_angle_update) - ) - - # Apply edge angle residual on top of edge_ebd_s1 (no recomputation) - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - # ==================================================================== - # Phase 4: Node edge message (uses e_updated) - # ==================================================================== + def _compute_node_edge_message( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + sw: torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node edge message and reduce over neighbor dimension.""" if not self.optim_update: if not self.use_dynamic_sel: - edge_info_updated = torch.cat( + edge_info = torch.cat( [ torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), nei_node_ebd, - e_updated, + edge_ebd, ], dim=-1, ) else: - edge_info_updated = torch.cat( + edge_info = torch.cat( [ torch.index_select( node_ebd.reshape(-1, self.n_dim), 0, n2e_index ), nei_node_ebd, - e_updated, + edge_ebd, ], dim=-1, ) node_edge_update = self.act( - 
self.node_edge_linear(edge_info_updated) + self.node_edge_linear(edge_info) ) * sw.unsqueeze(-1) else: node_edge_update = self.act( self.optim_edge_update( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, nlist, "node", ) @@ -989,7 +998,7 @@ def _forward_sequential( else self.optim_edge_update_dynamic( node_ebd, node_ebd_ext, - e_updated, + edge_ebd, n2e_index, n_ext2e_index, "node", @@ -1009,21 +1018,24 @@ def _forward_sequential( / self.dynamic_e_sel ) ) + return node_edge_update - # ==================================================================== - # Phase 5: Node updates (node_self, node_sym with e_updated, node_edge) - # ==================================================================== - n_update_list: list[torch.Tensor] = [node_ebd] - - # node self mlp (uses original node_ebd) - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) - - # node sym using e_updated + def _compute_node_sym( + self, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + n2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node symmetrization update (grrg + drrd).""" node_sym_list: list[torch.Tensor] = [] node_sym_list.append( self.symmetrization_op( - e_updated, + edge_ebd, h2, nlist_mask, sw, @@ -1031,7 +1043,7 @@ def _forward_sequential( ) if not self.use_dynamic_sel else self.symmetrization_op_dynamic( - e_updated, + edge_ebd, h2, sw, owner=n2e_index, @@ -1063,21 +1075,7 @@ def _forward_sequential( axis_neuron=self.axis_neuron, ) ) - node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) - n_update_list.append(node_sym) - - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = node_edge_update.view( - nb, nloc, self.n_multi_edge_message, self.n_dim - ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[..., head_index, :]) - else: - 
n_update_list.append(node_edge_update) - - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, a_updated + return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) def forward( self, @@ -1168,31 +1166,95 @@ def forward( ) ) + # Edge self update (always from original embeddings) + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) + if self.sequential_update and self.update_angle: - return self._forward_sequential( - node_ebd, - node_ebd_ext, - edge_ebd, - h2, + # === Sequential update path === + # Phase 1: Apply edge self residual + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: Angle self (uses updated edge_ebd_s1) + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist_mask, - a_sw, - nei_node_ebd, - n2e_index, - n_ext2e_index, + node_for_a, + edge_for_a, + "angle", n2a_index, eij2a_index, eik2a_index, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) + edge_angle_update = self._compute_angle_update( + a_updated, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, nb, nloc, - nnei, - nall, n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed + + # Phase 4+5: Node updates (uses e_updated) + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + node_sym = self._compute_node_sym( + e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, 
+ nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) + n_update_list: list[torch.Tensor] = [ + node_ebd, + node_self_mlp, + node_sym, + ] + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + + return n_updated, e_updated, a_updated + + # === Parallel update path === n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1202,118 +1264,24 @@ def forward( n_update_list.append(node_self_mlp) # node sym (grrg + drrd) - node_sym_list: list[torch.Tensor] = [] - node_sym_list.append( - self.symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) + node_sym = self._compute_node_sym( + edge_ebd, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc ) - node_sym_list.append( - self.symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym = self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) n_update_list.append(node_sym) - if not self.optim_update: - if not self.use_dynamic_sel: - # nb x nloc x nnei x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - 
nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - # n_edge x (n_dim * 2 + e_dim) - edge_info = torch.cat( - [ - torch.index_select( - node_ebd.reshape(-1, self.n_dim), 0, n2e_index - ), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - edge_info = None - # node edge message - # nb x nloc x nnei x (h * n_dim) - if not self.optim_update: - assert edge_info is not None - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * sw.unsqueeze(-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * sw.unsqueeze(-1) - node_edge_update = ( - (torch.sum(node_edge_update, dim=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) - / self.dynamic_e_sel - ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, ) - if self.n_multi_edge_message > 1: # nb x nloc x h x n_dim node_edge_update_mul_head = node_edge_update.view( @@ -1327,211 +1295,51 @@ def forward( n_updated = self.list_update(n_update_list, "node") # edge self message - if not self.optim_update: - assert edge_info is not None - edge_self_update = self.act(self.edge_self_linear(edge_info)) - else: - edge_self_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) - ) e_update_list.append(edge_self_update) if self.update_angle: assert self.angle_self_linear is not None assert self.edge_angle_linear1 is not None 
assert self.edge_angle_linear2 is not None - # get angle info - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - # use the first a_compress_dim dim for node and edge - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - # nb x nloc x a_nnei x e_dim - edge_ebd_for_angle = torch.where( - a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 - ) - if not self.optim_update: - # nb x nloc x a_nnei x a_nnei x n_dim [OR] n_angle x n_dim - node_for_angle_info = ( - torch.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else torch.index_select( - node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), - 0, - n2a_index, - ) - ) - - # nb x nloc x (a_nnei) x a_nnei x e_dim [OR] n_angle x e_dim - edge_for_angle_k = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) - ) - # nb x nloc x a_nnei x (a_nnei) x e_dim [OR] n_angle x e_dim - edge_for_angle_j = ( - torch.tile( - edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1) - ) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) - ) - # nb x nloc x a_nnei x a_nnei x (e_dim + e_dim) [OR] n_angle x (e_dim + e_dim) - edge_for_angle_info = torch.cat( - [edge_for_angle_k, edge_for_angle_j], dim=-1 - ) - angle_info_list = [angle_ebd] - angle_info_list.append(node_for_angle_info) - 
angle_info_list.append(edge_for_angle_info) - # nb x nloc x a_nnei x a_nnei x (a + n_dim + e_dim*2) or (a + a/c + a/c) - # [OR] - # n_angle x (a + n_dim + e_dim*2) or (a + a/c + a/c) - angle_info = torch.cat(angle_info_list, dim=-1) - else: - angle_info = None + node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd, edge_ebd, a_nlist_mask + ) # edge angle message - # nb x nloc x a_nnei x a_nnei x e_dim [OR] n_angle x e_dim - if not self.optim_update: - assert angle_info is not None - edge_angle_update = self.act(self.edge_angle_linear1(angle_info)) - else: - edge_angle_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - - if not self.use_dynamic_sel: - # nb x nloc x a_nnei x a_nnei x e_dim - weighted_edge_angle_update = ( - a_sw.unsqueeze(-1).unsqueeze(-1) - * a_sw.unsqueeze(-2).unsqueeze(-1) - * edge_angle_update - ) - # nb x nloc x a_nnei x e_dim - reduced_edge_angle_update = torch.sum( - weighted_edge_angle_update, dim=-2 - ) / (self.a_sel**0.5) - # nb x nloc x nnei x e_dim - padding_edge_angle_update = torch.concat( - [ - reduced_edge_angle_update, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd.dtype, - device=edge_ebd.device, - ), - ], - dim=2, - ) - else: - # n_angle x e_dim - weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) - # n_edge x e_dim - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) - - if not self.smooth_edge_update: - # will be deprecated in the future - # not support dynamic index, will pass anyway - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" 
- ) - full_mask = torch.concat( - [ - a_nlist_mask, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel], - dtype=a_nlist_mask.dtype, - device=a_nlist_mask.device, - ), - ], - dim=-1, - ) - padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), padding_edge_angle_update, edge_ebd - ) - e_update_list.append( - self.act(self.edge_angle_linear2(padding_edge_angle_update)) + edge_angle_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_processed = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, ) + e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message - # nb x nloc x a_nnei x a_nnei x dim_a - if not self.optim_update: - assert angle_info is not None - angle_self_update = self.act(self.angle_self_linear(angle_info)) - else: - angle_self_update = self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + angle_self_update = self._compute_angle_update( + angle_ebd, + node_for_a, + edge_for_a, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + ) a_update_list.append(angle_self_update) else: # update edge_ebd diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 70a7985702..74e8aad13e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1555,7 +1555,7 @@ def dpa3_repflow_args() -> list[Argument]: "When True, updates are applied sequentially: edge self → angle self (using updated edge) " "→ edge angle (using updated angle) → node (using final edge), " "instead of the default parallel mode where all updates use original 
embeddings. " - "Currently only supports update_style='res_residual'." + "Currently only supports update_style='res_residual' and requires update_angle=True." ) return [ diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index b980c584a1..c158d81a93 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -67,8 +67,8 @@ ("const",), # update_residual_init ([], [[0, 1]]), # exclude_types (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate + (1,), # a_compress_rate + (2,), # a_compress_e_rate (True,), # a_compress_use_split (True, False), # optim_update (True, False), # edge_init_use_dist @@ -444,8 +444,8 @@ def atol(self) -> float: ("const",), # update_residual_init ([], [[0, 1]]), # exclude_types (True,), # update_angle - (0, 1), # a_compress_rate - (1, 2), # a_compress_e_rate + (1,), # a_compress_rate + (2,), # a_compress_e_rate (True,), # a_compress_use_split (True, False), # optim_update (True, False), # edge_init_use_dist From 5fe472e7c7a83e8066c4e6099d562386fb78c3a2 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 21:11:00 +0800 Subject: [PATCH 3/9] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 6b93b849c0..c3254f68fe 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -81,6 +81,7 @@ "n_multi_edge_message", "precision", "add_chg_spin_ebd", + "sequential_update", ) @@ -100,6 +101,7 @@ "n_multi_edge_message": 1, "precision": "float64", "add_chg_spin_ebd": False, + "sequential_update": False, } @@ -131,6 +133,7 @@ def dpa3_case(**overrides: Any) -> tuple: dpa3_case(edge_init_use_dist=False), 
dpa3_case(use_exp_switch=False), dpa3_case(use_dynamic_sel=False), + dpa3_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. dpa3_case( exclude_types=[[0, 1]], @@ -142,6 +145,7 @@ def dpa3_case(**overrides: Any) -> tuple: use_dynamic_sel=False, use_loc_mapping=False, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -169,6 +173,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: dpa3_descriptor_api_case(edge_init_use_dist=False), dpa3_descriptor_api_case(use_exp_switch=False), dpa3_descriptor_api_case(use_dynamic_sel=False), + dpa3_descriptor_api_case(sequential_update=False), # One mixed high-risk path to keep interactions covered. dpa3_descriptor_api_case( exclude_types=[[0, 1]], @@ -181,6 +186,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: use_loc_mapping=False, fix_stat_std=0.0, add_chg_spin_ebd=True, + sequential_update=True, ), ) @@ -268,6 +274,7 @@ def skip_pt(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_pt @@ -290,6 +297,7 @@ def skip_pd(self) -> bool: _precision, add_chg_spin_ebd, sequential_update, + _sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -311,6 +319,7 @@ def skip_dp(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return CommonTest.skip_dp @@ -332,6 +341,7 @@ def skip_tf(self) -> bool: _n_multi_edge_message, _precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param return True @@ -398,6 +408,7 @@ def setUp(self) -> None: _precision, add_chg_spin_ebd, sequential_update, + _sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True self.fparam = ( @@ -504,6 +515,7 @@ def rtol(self) -> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-10 @@ -531,6 +543,7 @@ def atol(self) 
-> float: _n_multi_edge_message, precision, _add_chg_spin_ebd, + _sequential_update, ) = self.param if precision == "float64": return 1e-6 # need to fix in the future, see issue https://github.com/deepmodeling/deepmd-kit/issues/3786 @@ -567,6 +580,7 @@ def data(self) -> dict: n_multi_edge_message, precision, add_chg_spin_ebd, + sequential_update, ) = self.param return { "ntypes": self.ntypes, @@ -598,6 +612,7 @@ def data(self) -> dict: "update_style": "res_residual", "update_residual": 0.1, "update_residual_init": update_residual_init, + "sequential_update": sequential_update, } ), # kwargs for descriptor From b4d43ee0e4e70edaf31213d5787066e13778eb66 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 May 2026 21:11:29 +0800 Subject: [PATCH 4/9] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index c3254f68fe..2a4e2e2b20 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -173,7 +173,7 @@ def dpa3_descriptor_api_case(**overrides: Any) -> tuple: dpa3_descriptor_api_case(edge_init_use_dist=False), dpa3_descriptor_api_case(use_exp_switch=False), dpa3_descriptor_api_case(use_dynamic_sel=False), - dpa3_descriptor_api_case(sequential_update=False), + dpa3_descriptor_api_case(sequential_update=True), # One mixed high-risk path to keep interactions covered. 
dpa3_descriptor_api_case( exclude_types=[[0, 1]], From daf68ba739817b8729437200c65afdbed22e2eb9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sat, 9 May 2026 16:03:30 +0800 Subject: [PATCH 5/9] Update test_dpa3.py --- source/tests/consistent/descriptor/test_dpa3.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 2a4e2e2b20..2aa0fd931b 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -296,7 +296,6 @@ def skip_pd(self) -> bool: _n_multi_edge_message, _precision, add_chg_spin_ebd, - sequential_update, _sequential_update, ) = self.param return True if add_chg_spin_ebd else CommonTest.skip_pd @@ -407,7 +406,6 @@ def setUp(self) -> None: _n_multi_edge_message, _precision, add_chg_spin_ebd, - sequential_update, _sequential_update, ) = self.param # fparam for charge=5, spin=1 when add_chg_spin_ebd is True From 1369a46d9fcfaa20352d148f36dac7a2dade0b34 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 10 May 2026 16:45:49 +0800 Subject: [PATCH 6/9] Update repflow_layer.py --- deepmd/pt/model/descriptor/repflow_layer.py | 270 ++++++++++++-------- 1 file changed, 166 insertions(+), 104 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index a094105487..f9c8e0c270 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -1150,6 +1150,24 @@ def forward( n_edge = h2.shape[0] del a_nlist # may be used in the future + if self.sequential_update and self.update_angle: + return self._forward_sequential( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + edge_index, + angle_index, + nb, + nloc, + n_edge, + ) + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] 
n2a_index, eij2a_index, eik2a_index = ( angle_index[0], @@ -1166,95 +1184,6 @@ def forward( ) ) - # Edge self update (always from original embeddings) - edge_self_update = self._compute_edge_self_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - nlist, - n2e_index, - n_ext2e_index, - ) - - if self.sequential_update and self.update_angle: - # === Sequential update path === - # Phase 1: Apply edge self residual - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # Phase 2: Angle self (uses updated edge_ebd_s1) - node_for_a, edge_for_a = self._prepare_angle_embeddings( - node_ebd, edge_ebd_s1, a_nlist_mask - ) - angle_self_update = self._compute_angle_update( - angle_ebd, - node_for_a, - edge_for_a, - "angle", - n2a_index, - eij2a_index, - eik2a_index, - ) - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) - edge_angle_update = self._compute_angle_update( - a_updated, - node_for_a, - edge_for_a, - "edge", - n2a_index, - eij2a_index, - eik2a_index, - ) - edge_angle_processed = self._compute_edge_angle_reduction( - edge_angle_update, - edge_ebd_s1, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, - ) - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed - - # Phase 4+5: Node updates (uses e_updated) - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - node_sym = self._compute_node_sym( - e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc - ) - node_edge_update = self._compute_node_edge_message( - node_ebd, - node_ebd_ext, - e_updated, - nei_node_ebd, - sw, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, - ) - - n_update_list: list[torch.Tensor] = [ - node_ebd, - node_self_mlp, - node_sym, - ] - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = node_edge_update.view( - nb, nloc, self.n_multi_edge_message, self.n_dim - ) - for head_index in range(self.n_multi_edge_message): - 
n_update_list.append(node_edge_update_mul_head[..., head_index, :]) - else: - n_update_list.append(node_edge_update) - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, a_updated - - # === Parallel update path === n_update_list: list[torch.Tensor] = [node_ebd] e_update_list: list[torch.Tensor] = [edge_ebd] a_update_list: list[torch.Tensor] = [angle_ebd] @@ -1295,6 +1224,15 @@ def forward( n_updated = self.list_update(n_update_list, "node") # edge self message + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) e_update_list.append(edge_self_update) if self.update_angle: @@ -1302,39 +1240,40 @@ def forward( assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - node_for_a, edge_for_a = self._prepare_angle_embeddings( + node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( node_ebd, edge_ebd, a_nlist_mask ) # edge angle message edge_angle_update = self._compute_angle_update( angle_ebd, - node_for_a, - edge_for_a, + node_ebd_for_angle, + edge_ebd_for_angle, "edge", n2a_index, eij2a_index, eik2a_index, ) - edge_angle_processed = self._compute_edge_angle_reduction( - edge_angle_update, - edge_ebd, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, + e_update_list.append( + self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) ) - e_update_list.append(edge_angle_processed) # update edge_ebd e_updated = self.list_update(e_update_list, "edge") # angle self message angle_self_update = self._compute_angle_update( angle_ebd, - node_for_a, - edge_for_a, + node_ebd_for_angle, + edge_ebd_for_angle, "angle", n2a_index, eij2a_index, @@ -1349,6 +1288,129 @@ def forward( a_updated = self.list_update(a_update_list, "angle") return n_updated, e_updated, a_updated + def _forward_sequential( + self, + 
node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + angle_ebd: torch.Tensor, + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + a_sw: torch.Tensor, + edge_index: torch.Tensor, + angle_index: torch.Tensor, + nb: int, + nloc: int, + n_edge: int | None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sequential update path: edge_self -> angle_self -> edge_angle -> node. + + Each phase consumes the most recent update of its inputs (instead of + the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for + all branches). Residuals are applied immediately after each phase. + """ + assert self.update_style == "res_residual" + assert self.update_angle + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + node_ebd = node_ebd_ext[:, :nloc, :] + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[0], + angle_index[1], + angle_index[2], + ) + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index + ) + ) + + # Phase 1: edge self update -> apply residual immediately. + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: angle self message uses the updated edge embedding. 
+ node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( + node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: edge angle message uses the updated angle embedding. + edge_angle_update = self._compute_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_reduced = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_reduced + + # Phase 4: node updates use the fully updated edge embedding. + n_update_list: list[torch.Tensor] = [node_ebd] + n_update_list.append(self.act(self.node_self_mlp(node_ebd))) + n_update_list.append( + self._compute_node_sym( + e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + ) + ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim + ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + return n_updated, e_updated, a_updated + @torch.jit.export def list_update_res_avg( self, From 9cb9ce142d376dffb7b6c400c75e324c6a9bc604 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 10 May 2026 16:58:57 +0800 Subject: [PATCH 7/9] Update repflow_layer.py --- 
deepmd/pt/model/descriptor/repflow_layer.py | 1218 +++++++++---------- 1 file changed, 609 insertions(+), 609 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index f9c8e0c270..15cfa84536 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -702,265 +702,351 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _compute_edge_self_update( + def forward( self, - node_ebd: torch.Tensor, - node_ebd_ext: torch.Tensor, - edge_ebd: torch.Tensor, - nei_node_ebd: torch.Tensor, - nlist: torch.Tensor, - n2e_index: torch.Tensor, - n_ext2e_index: torch.Tensor, - ) -> torch.Tensor: - """Compute edge self update.""" - if not self.optim_update: - if not self.use_dynamic_sel: - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - else: - edge_info = torch.cat( - [ - torch.index_select( - node_ebd.reshape(-1, self.n_dim), 0, n2e_index - ), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - return self.act(self.edge_self_linear(edge_info)) + node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode + edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim + h2: torch.Tensor, # nf x nloc x nnei x 3 + angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: torch.Tensor, # nf x nloc x nnei + nlist_mask: torch.Tensor, # nf x nloc x nnei + sw: torch.Tensor, # switch func, nf x nloc x nnei + a_nlist: torch.Tensor, # nf x nloc x a_nnei + a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei + a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei + edge_index: torch.Tensor, # 2 x n_edge + angle_index: torch.Tensor, # 3 x n_angle + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node 
embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + edge_index : Optional for dynamic sel, 2 x n_edge + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : Optional for dynamic sel, 3 x n_angle + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. 
+ """ + nb, nloc, nnei = nlist.shape + nall = node_ebd_ext.shape[1] + node_ebd = node_ebd_ext[:, :nloc, :] + assert (nb, nloc) == node_ebd.shape[:2] + if not self.use_dynamic_sel: + assert (nb, nloc, nnei, 3) == h2.shape + n_edge = None else: - return self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) + # n_edge = int(nlist_mask.sum().item()) + # assert (n_edge, 3) == h2.shape + n_edge = h2.shape[0] + del a_nlist # may be used in the future + + if self.sequential_update and self.update_angle: + return self._forward_sequential( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + edge_index, + angle_index, + nb, + nloc, + n_edge, ) - def _prepare_angle_embeddings( - self, - node_ebd: torch.Tensor, - edge_ebd: torch.Tensor, - a_nlist_mask: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Prepare compressed node/edge embeddings for angle computation.""" - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - edge_ebd_for_angle = edge_ebd + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[0], + angle_index[1], + angle_index[2], + ) - if not self.use_dynamic_sel: - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - edge_ebd_for_angle = torch.where( - a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 + # nb x nloc x nnei x n_dim [OR] n_edge x n_dim + 
nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index ) - return node_ebd_for_angle, edge_ebd_for_angle + ) - def _compute_angle_update( - self, - angle_ebd: torch.Tensor, - node_ebd_for_angle: torch.Tensor, - edge_ebd_for_angle: torch.Tensor, - feat: str, - n2a_index: torch.Tensor, - eij2a_index: torch.Tensor, - eik2a_index: torch.Tensor, - ) -> torch.Tensor: - """Compute angle-based update (for edge_angle or angle_self). + n_update_list: list[torch.Tensor] = [node_ebd] + e_update_list: list[torch.Tensor] = [edge_ebd] + a_update_list: list[torch.Tensor] = [angle_ebd] - Parameters - ---------- - feat : str - "edge" for edge_angle_linear1, "angle" for angle_self_linear. - """ - if not self.optim_update: - node_for_angle_info = ( - torch.tile( - node_ebd_for_angle.unsqueeze(2).unsqueeze(2), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else torch.index_select( - node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), - 0, - n2a_index, - ) + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + edge_ebd, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + ) + n_update_list.append(node_sym) + + # node edge message + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + # nb x nloc x h x n_dim + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim ) - edge_for_angle_k = ( - torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + for head_index in range(self.n_multi_edge_message): + 
n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") + + # edge self message + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) + e_update_list.append(edge_self_update) + + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( + node_ebd, edge_ebd, a_nlist_mask ) - edge_for_angle_j = ( - torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) - if not self.use_dynamic_sel - else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + + # edge angle message + edge_angle_update = self._compute_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + n2a_index, + eij2a_index, + eik2a_index, ) - edge_for_angle_info = torch.cat( - [edge_for_angle_k, edge_for_angle_j], dim=-1 + e_update_list.append( + self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) ) - angle_info = torch.cat( - [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # angle self message + angle_self_update = self._compute_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + n2a_index, + eij2a_index, + eik2a_index, ) - if feat == "edge": - assert self.edge_angle_linear1 is not None - return self.act(self.edge_angle_linear1(angle_info)) - else: - assert self.angle_self_linear is not None - return self.act(self.angle_self_linear(angle_info)) + a_update_list.append(angle_self_update) else: - if feat == "edge": - return self.act( - self.optim_angle_update( - 
angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - else: - return self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") - def _compute_edge_angle_reduction( + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated + + def _forward_sequential( self, - edge_angle_update: torch.Tensor, - edge_ebd_fallback: torch.Tensor, - a_sw: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + h2: torch.Tensor, + angle_ebd: torch.Tensor, + nlist: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, a_nlist_mask: torch.Tensor, + a_sw: torch.Tensor, + edge_index: torch.Tensor, + angle_index: torch.Tensor, nb: int, nloc: int, n_edge: int | None, - eij2a_index: torch.Tensor, - ) -> torch.Tensor: - """Reduce edge angle update over angle dimension, pad, and apply linear2. + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Sequential update path: edge_self -> angle_self -> edge_angle -> node. - Parameters - ---------- - edge_ebd_fallback : torch.Tensor - Edge embedding used for non-smooth padding fallback. + Each phase consumes the most recent update of its inputs (instead of + the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for + all branches). Residuals are applied immediately after each phase. 
""" + assert self.update_style == "res_residual" + assert self.update_angle + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - if not self.use_dynamic_sel: - weighted_edge_angle_update = ( - a_sw.unsqueeze(-1).unsqueeze(-1) - * a_sw.unsqueeze(-2).unsqueeze(-1) - * edge_angle_update - ) - reduced_edge_angle_update = torch.sum( - weighted_edge_angle_update, dim=-2 - ) / (self.a_sel**0.5) - padding_edge_angle_update = torch.concat( - [ - reduced_edge_angle_update, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel, self.e_dim], - dtype=edge_ebd_fallback.dtype, - device=edge_ebd_fallback.device, - ), - ], - dim=2, + node_ebd = node_ebd_ext[:, :nloc, :] + n2e_index, n_ext2e_index = edge_index[0], edge_index[1] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[0], + angle_index[1], + angle_index[2], + ) + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index ) - else: - assert n_edge is not None - weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) + ) + + # Phase 1: edge self update -> apply residual immediately. + edge_self_update = self._compute_edge_self_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + ) + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: angle self message uses the updated edge embedding. 
+ node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( + node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update - if not self.smooth_edge_update: - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" - ) - full_mask = torch.concat( - [ - a_nlist_mask, - torch.zeros( - [nb, nloc, self.nnei - self.a_sel], - dtype=a_nlist_mask.dtype, - device=a_nlist_mask.device, - ), - ], - dim=-1, + # Phase 3: edge angle message uses the updated angle embedding. + edge_angle_update = self._compute_angle_update( + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + ) + edge_angle_reduced = self._compute_edge_angle_reduction( + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_reduced + + # Phase 4: node updates use the fully updated edge embedding. 
+ n_update_list: list[torch.Tensor] = [node_ebd] + n_update_list.append(self.act(self.node_self_mlp(node_ebd))) + n_update_list.append( + self._compute_node_sym( + e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc ) - padding_edge_angle_update = torch.where( - full_mask.unsqueeze(-1), - padding_edge_angle_update, - edge_ebd_fallback, + ) + node_edge_update = self._compute_node_edge_message( + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = node_edge_update.view( + nb, nloc, self.n_multi_edge_message, self.n_dim ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + else: + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + return n_updated, e_updated, a_updated - return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - - def _compute_node_edge_message( + def _compute_edge_self_update( self, node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, nei_node_ebd: torch.Tensor, - sw: torch.Tensor, nlist: torch.Tensor, n2e_index: torch.Tensor, n_ext2e_index: torch.Tensor, - nb: int, - nloc: int, ) -> torch.Tensor: - """Compute node edge message and reduce over neighbor dimension.""" + """Compute edge self update.""" if not self.optim_update: if not self.use_dynamic_sel: edge_info = torch.cat( @@ -982,434 +1068,348 @@ def _compute_node_edge_message( ], dim=-1, ) - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * sw.unsqueeze(-1) + return self.act(self.edge_self_linear(edge_info)) else: - node_edge_update = self.act( + return self.act( self.optim_edge_update( node_ebd, node_ebd_ext, edge_ebd, nlist, - "node", + "edge", ) if not self.use_dynamic_sel else self.optim_edge_update_dynamic( node_ebd, node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - 
"node", - ) - ) * sw.unsqueeze(-1) - - node_edge_update = ( - (torch.sum(node_edge_update, dim=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) - / self.dynamic_e_sel - ) - ) - return node_edge_update - - def _compute_node_sym( - self, - edge_ebd: torch.Tensor, - nei_node_ebd: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - n2e_index: torch.Tensor, - nb: int, - nloc: int, - ) -> torch.Tensor: - """Compute node symmetrization update (grrg + drrd).""" - node_sym_list: list[torch.Tensor] = [] - node_sym_list.append( - self.symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym_list.append( - self.symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) - - def forward( - self, - node_ebd_ext: torch.Tensor, # nf x nall x n_dim [OR] nf x nloc x n_dim when not parallel_mode - edge_ebd: torch.Tensor, # nf x nloc x nnei x e_dim - h2: torch.Tensor, # nf x nloc x nnei x 3 - angle_ebd: torch.Tensor, # nf x nloc x a_nnei x a_nnei x a_dim - nlist: torch.Tensor, # nf x nloc x nnei - nlist_mask: torch.Tensor, # nf x nloc x nnei - sw: torch.Tensor, # switch func, nf x nloc x nnei - a_nlist: torch.Tensor, # nf x nloc x a_nnei - a_nlist_mask: torch.Tensor, # nf x nloc x a_nnei - a_sw: 
torch.Tensor, # switch func, nf x nloc x a_nnei - edge_index: torch.Tensor, # 2 x n_edge - angle_index: torch.Tensor, # 3 x n_angle - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Parameters - ---------- - node_ebd_ext : nf x nall x n_dim - Extended node embedding. - edge_ebd : nf x nloc x nnei x e_dim - Edge embedding. - h2 : nf x nloc x nnei x 3 - Pair-atom channel, equivariant. - angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim - Angle embedding. - nlist : nf x nloc x nnei - Neighbor list. (padded neis are set to 0) - nlist_mask : nf x nloc x nnei - Masks of the neighbor list. real nei 1 otherwise 0 - sw : nf x nloc x nnei - Switch function. - a_nlist : nf x nloc x a_nnei - Neighbor list for angle. (padded neis are set to 0) - a_nlist_mask : nf x nloc x a_nnei - Masks of the neighbor list for angle. real nei 1 otherwise 0 - a_sw : nf x nloc x a_nnei - Switch function for angle. - edge_index : Optional for dynamic sel, 2 x n_edge - n2e_index : n_edge - Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). - n_ext2e_index : n_edge - Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, 3 x n_angle - n2a_index : n_angle - Broadcast indices from extended node(j) to angle(ijk). - eij2a_index : n_angle - Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). - eik2a_index : n_angle - Broadcast indices from extended edge(ik) to angle(ijk). - - Returns - ------- - n_updated: nf x nloc x n_dim - Updated node embedding. - e_updated: nf x nloc x nnei x e_dim - Updated edge embedding. - a_updated : nf x nloc x a_nnei x a_nnei x a_dim - Updated angle embedding. 
- """ - nb, nloc, nnei = nlist.shape - nall = node_ebd_ext.shape[1] - node_ebd = node_ebd_ext[:, :nloc, :] - assert (nb, nloc) == node_ebd.shape[:2] - if not self.use_dynamic_sel: - assert (nb, nloc, nnei, 3) == h2.shape - n_edge = None - else: - # n_edge = int(nlist_mask.sum().item()) - # assert (n_edge, 3) == h2.shape - n_edge = h2.shape[0] - del a_nlist # may be used in the future - - if self.sequential_update and self.update_angle: - return self._forward_sequential( - node_ebd_ext, - edge_ebd, - h2, - angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist_mask, - a_sw, - edge_index, - angle_index, - nb, - nloc, - n_edge, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) ) - n2e_index, n_ext2e_index = edge_index[0], edge_index[1] - n2a_index, eij2a_index, eik2a_index = ( - angle_index[0], - angle_index[1], - angle_index[2], - ) + def _prepare_angle_embeddings( + self, + node_ebd: torch.Tensor, + edge_ebd: torch.Tensor, + a_nlist_mask: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare compressed node/edge embeddings for angle computation.""" + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd - # nb x nloc x nnei x n_dim [OR] n_edge x n_dim - nei_node_ebd = ( - _make_nei_g1(node_ebd_ext, nlist) - if not self.use_dynamic_sel - else torch.index_select( - node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = torch.where( + a_nlist_mask.unsqueeze(-1), edge_ebd_for_angle, 0.0 ) - ) - - n_update_list: 
list[torch.Tensor] = [node_ebd] - e_update_list: list[torch.Tensor] = [edge_ebd] - a_update_list: list[torch.Tensor] = [angle_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) + return node_ebd_for_angle, edge_ebd_for_angle - # node sym (grrg + drrd) - node_sym = self._compute_node_sym( - edge_ebd, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc - ) - n_update_list.append(node_sym) + def _compute_angle_update( + self, + angle_ebd: torch.Tensor, + node_ebd_for_angle: torch.Tensor, + edge_ebd_for_angle: torch.Tensor, + feat: str, + n2a_index: torch.Tensor, + eij2a_index: torch.Tensor, + eik2a_index: torch.Tensor, + ) -> torch.Tensor: + """Compute angle-based update (for edge_angle or angle_self). - # node edge message - node_edge_update = self._compute_node_edge_message( - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - sw, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, - ) - if self.n_multi_edge_message > 1: - # nb x nloc x h x n_dim - node_edge_update_mul_head = node_edge_update.view( - nb, nloc, self.n_multi_edge_message, self.n_dim + Parameters + ---------- + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. 
+ """ + if not self.optim_update: + node_for_angle_info = ( + torch.tile( + node_ebd_for_angle.unsqueeze(2).unsqueeze(2), + (1, 1, self.a_sel, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else torch.index_select( + node_ebd_for_angle.reshape(-1, self.n_a_compress_dim), + 0, + n2a_index, + ) ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[..., head_index, :]) + edge_for_angle_k = ( + torch.tile(edge_ebd_for_angle.unsqueeze(2), (1, 1, self.a_sel, 1, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eik2a_index) + ) + edge_for_angle_j = ( + torch.tile(edge_ebd_for_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1)) + if not self.use_dynamic_sel + else torch.index_select(edge_ebd_for_angle, 0, eij2a_index) + ) + edge_for_angle_info = torch.cat( + [edge_for_angle_k, edge_for_angle_j], dim=-1 + ) + angle_info = torch.cat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], dim=-1 + ) + if feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) else: - n_update_list.append(node_edge_update) - # update node_ebd - n_updated = self.list_update(n_update_list, "node") - - # edge self message - edge_self_update = self._compute_edge_self_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - nlist, - n2e_index, - n_ext2e_index, - ) - e_update_list.append(edge_self_update) - - if self.update_angle: - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + 
eij2a_index, + eik2a_index, + "edge", + ) + ) + else: + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) - node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( - node_ebd, edge_ebd, a_nlist_mask - ) + def _compute_edge_angle_reduction( + self, + edge_angle_update: torch.Tensor, + edge_ebd_fallback: torch.Tensor, + a_sw: torch.Tensor, + a_nlist_mask: torch.Tensor, + nb: int, + nloc: int, + n_edge: int | None, + eij2a_index: torch.Tensor, + ) -> torch.Tensor: + """Reduce edge angle update over angle dimension, pad, and apply linear2. - # edge angle message - edge_angle_update = self._compute_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - n2a_index, - eij2a_index, - eik2a_index, + Parameters + ---------- + edge_ebd_fallback : torch.Tensor + Edge embedding used for non-smooth padding fallback. 
+ """ + assert self.edge_angle_linear2 is not None + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw.unsqueeze(-1).unsqueeze(-1) + * a_sw.unsqueeze(-2).unsqueeze(-1) + * edge_angle_update ) - e_update_list.append( - self._compute_edge_angle_reduction( - edge_angle_update, - edge_ebd, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, - ) + reduced_edge_angle_update = torch.sum( + weighted_edge_angle_update, dim=-2 + ) / (self.a_sel**0.5) + padding_edge_angle_update = torch.concat( + [ + reduced_edge_angle_update, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel, self.e_dim], + dtype=edge_ebd_fallback.dtype, + device=edge_ebd_fallback.device, + ), + ], + dim=2, ) - # update edge_ebd - e_updated = self.list_update(e_update_list, "edge") - - # angle self message - angle_self_update = self._compute_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - n2a_index, + else: + assert n_edge is not None + weighted_edge_angle_update = edge_angle_update * a_sw.unsqueeze(-1) + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, eij2a_index, - eik2a_index, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" 
+ ) + full_mask = torch.concat( + [ + a_nlist_mask, + torch.zeros( + [nb, nloc, self.nnei - self.a_sel], + dtype=a_nlist_mask.dtype, + device=a_nlist_mask.device, + ), + ], + dim=-1, + ) + padding_edge_angle_update = torch.where( + full_mask.unsqueeze(-1), + padding_edge_angle_update, + edge_ebd_fallback, ) - a_update_list.append(angle_self_update) - else: - # update edge_ebd - e_updated = self.list_update(e_update_list, "edge") - # update angle_ebd - a_updated = self.list_update(a_update_list, "angle") - return n_updated, e_updated, a_updated + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - def _forward_sequential( + def _compute_node_edge_message( self, + node_ebd: torch.Tensor, node_ebd_ext: torch.Tensor, edge_ebd: torch.Tensor, - h2: torch.Tensor, - angle_ebd: torch.Tensor, - nlist: torch.Tensor, - nlist_mask: torch.Tensor, + nei_node_ebd: torch.Tensor, sw: torch.Tensor, - a_nlist_mask: torch.Tensor, - a_sw: torch.Tensor, - edge_index: torch.Tensor, - angle_index: torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, nb: int, nloc: int, - n_edge: int | None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Sequential update path: edge_self -> angle_self -> edge_angle -> node. 
+ ) -> torch.Tensor: + """Compute node edge message and reduce over neighbor dimension.""" + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) - Each phase consumes the most recent update of its inputs (instead of - the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for - all branches). Residuals are applied immediately after each phase. - """ - assert self.update_style == "res_residual" - assert self.update_angle - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - node_ebd = node_ebd_ext[:, :nloc, :] - n2e_index, n_ext2e_index = edge_index[0], edge_index[1] - n2a_index, eij2a_index, eik2a_index = ( - angle_index[0], - angle_index[1], - angle_index[2], - ) - nei_node_ebd = ( - _make_nei_g1(node_ebd_ext, nlist) + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) if not self.use_dynamic_sel - else torch.index_select( - node_ebd_ext.reshape(-1, self.n_dim), 0, n_ext2e_index + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel ) ) + return node_edge_update - # Phase 1: edge self update -> apply residual immediately. 
- edge_self_update = self._compute_edge_self_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - nlist, - n2e_index, - n_ext2e_index, - ) - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # Phase 2: angle self message uses the updated edge embedding. - node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( - node_ebd, edge_ebd_s1, a_nlist_mask - ) - angle_self_update = self._compute_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - n2a_index, - eij2a_index, - eik2a_index, - ) - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # Phase 3: edge angle message uses the updated angle embedding. - edge_angle_update = self._compute_angle_update( - a_updated, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - n2a_index, - eij2a_index, - eik2a_index, - ) - edge_angle_reduced = self._compute_edge_angle_reduction( - edge_angle_update, - edge_ebd_s1, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, - ) - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_reduced - - # Phase 4: node updates use the fully updated edge embedding. 
- n_update_list: list[torch.Tensor] = [node_ebd] - n_update_list.append(self.act(self.node_self_mlp(node_ebd))) - n_update_list.append( - self._compute_node_sym( - e_updated, nei_node_ebd, h2, nlist_mask, sw, n2e_index, nb, nloc + def _compute_node_sym( + self, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + n2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node symmetrization update (grrg + drrd).""" + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, ) ) - node_edge_update = self._compute_node_edge_message( - node_ebd, - node_ebd_ext, - e_updated, - nei_node_ebd, - sw, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, - ) - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = node_edge_update.view( - nb, nloc, self.n_multi_edge_message, self.n_dim + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[..., head_index, :]) - else: - n_update_list.append(node_edge_update) - n_updated = self.list_update(n_update_list, "node") - return n_updated, e_updated, a_updated + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) @torch.jit.export def list_update_res_avg( From 
a54dce522bf5f893c245022aff3afb6f1cde8c95 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 10 May 2026 17:11:35 +0800 Subject: [PATCH 8/9] Update repflow_layer.py --- deepmd/pt/model/descriptor/repflow_layer.py | 260 ++++++++++---------- 1 file changed, 130 insertions(+), 130 deletions(-) diff --git a/deepmd/pt/model/descriptor/repflow_layer.py b/deepmd/pt/model/descriptor/repflow_layer.py index 15cfa84536..0c94219152 100644 --- a/deepmd/pt/model/descriptor/repflow_layer.py +++ b/deepmd/pt/model/descriptor/repflow_layer.py @@ -1036,6 +1036,136 @@ def _forward_sequential( n_updated = self.list_update(n_update_list, "node") return n_updated, e_updated, a_updated + def _compute_node_sym( + self, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + h2: torch.Tensor, + nlist_mask: torch.Tensor, + sw: torch.Tensor, + n2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node symmetrization update (grrg + drrd).""" + node_sym_list: list[torch.Tensor] = [] + node_sym_list.append( + self.symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + node_sym_list.append( + self.symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else self.symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + return self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) + + def _compute_node_edge_message( + self, + node_ebd: torch.Tensor, + node_ebd_ext: torch.Tensor, + edge_ebd: torch.Tensor, + nei_node_ebd: torch.Tensor, + sw: 
torch.Tensor, + nlist: torch.Tensor, + n2e_index: torch.Tensor, + n_ext2e_index: torch.Tensor, + nb: int, + nloc: int, + ) -> torch.Tensor: + """Compute node edge message and reduce over neighbor dimension.""" + if not self.optim_update: + if not self.use_dynamic_sel: + edge_info = torch.cat( + [ + torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + else: + edge_info = torch.cat( + [ + torch.index_select( + node_ebd.reshape(-1, self.n_dim), 0, n2e_index + ), + nei_node_ebd, + edge_ebd, + ], + dim=-1, + ) + node_edge_update = self.act( + self.node_edge_linear(edge_info) + ) * sw.unsqueeze(-1) + else: + node_edge_update = self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "node", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "node", + ) + ) * sw.unsqueeze(-1) + + node_edge_update = ( + (torch.sum(node_edge_update, dim=-2) / self.nnei) + if not self.use_dynamic_sel + else ( + aggregate( + node_edge_update, + n2e_index, + average=False, + num_owner=nb * nloc, + ).reshape(nb, nloc, node_edge_update.shape[-1]) + / self.dynamic_e_sel + ) + ) + return node_edge_update + def _compute_edge_self_update( self, node_ebd: torch.Tensor, @@ -1281,136 +1411,6 @@ def _compute_edge_angle_reduction( return self.act(self.edge_angle_linear2(padding_edge_angle_update)) - def _compute_node_edge_message( - self, - node_ebd: torch.Tensor, - node_ebd_ext: torch.Tensor, - edge_ebd: torch.Tensor, - nei_node_ebd: torch.Tensor, - sw: torch.Tensor, - nlist: torch.Tensor, - n2e_index: torch.Tensor, - n_ext2e_index: torch.Tensor, - nb: int, - nloc: int, - ) -> torch.Tensor: - """Compute node edge message and reduce over neighbor dimension.""" - if not self.optim_update: - if not self.use_dynamic_sel: - edge_info = torch.cat( - [ - torch.tile(node_ebd.unsqueeze(-2), [1, 1, self.nnei, 1]), - nei_node_ebd, - 
edge_ebd, - ], - dim=-1, - ) - else: - edge_info = torch.cat( - [ - torch.index_select( - node_ebd.reshape(-1, self.n_dim), 0, n2e_index - ), - nei_node_ebd, - edge_ebd, - ], - dim=-1, - ) - node_edge_update = self.act( - self.node_edge_linear(edge_info) - ) * sw.unsqueeze(-1) - else: - node_edge_update = self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "node", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "node", - ) - ) * sw.unsqueeze(-1) - - node_edge_update = ( - (torch.sum(node_edge_update, dim=-2) / self.nnei) - if not self.use_dynamic_sel - else ( - aggregate( - node_edge_update, - n2e_index, - average=False, - num_owner=nb * nloc, - ).reshape(nb, nloc, node_edge_update.shape[-1]) - / self.dynamic_e_sel - ) - ) - return node_edge_update - - def _compute_node_sym( - self, - edge_ebd: torch.Tensor, - nei_node_ebd: torch.Tensor, - h2: torch.Tensor, - nlist_mask: torch.Tensor, - sw: torch.Tensor, - n2e_index: torch.Tensor, - nb: int, - nloc: int, - ) -> torch.Tensor: - """Compute node symmetrization update (grrg + drrd).""" - node_sym_list: list[torch.Tensor] = [] - node_sym_list.append( - self.symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym_list.append( - self.symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else self.symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - return 
self.act(self.node_sym_linear(torch.cat(node_sym_list, dim=-1))) - @torch.jit.export def list_update_res_avg( self, From 54236c6768d3999c6d385c09fd7f3326b06029da Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Sun, 10 May 2026 17:38:38 +0800 Subject: [PATCH 9/9] Update repflows.py --- deepmd/dpmodel/descriptor/repflows.py | 1277 +++++++++++++------------ 1 file changed, 672 insertions(+), 605 deletions(-) diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 1af1d15bd9..f328ee75fa 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -1354,289 +1354,449 @@ def optim_edge_update_dynamic( result_update = bias + sub_node_update + sub_edge_update + sub_node_ext_update return result_update - def _compute_edge_self_update( + def call( self, - xp: object, - node_ebd: Array, - node_ebd_ext: Array, - edge_ebd: Array, - nei_node_ebd: Array, - nlist: Array, - n2e_index: Array, - n_ext2e_index: Array, - nb: int, - nloc: int, - ) -> Array: - """Compute edge self update.""" - if not self.optim_update: - if not self.use_dynamic_sel: - edge_info = xp.concat( - [ - xp.tile( - xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), - (1, 1, self.nnei, 1), - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, - ) - else: - edge_info = xp.concat( - [ - xp.take( - xp.reshape(node_ebd, (-1, self.n_dim)), - n2e_index, - axis=0, - ), - nei_node_ebd, - edge_ebd, - ], - axis=-1, - ) - return self.act(self.edge_self_linear(edge_info)) + node_ebd_ext: Array, # nf x nall x n_dim + edge_ebd: Array, # nf x nloc x nnei x e_dim + h2: Array, # nf x nloc x nnei x 3 + angle_ebd: Array, # nf x nloc x a_nnei x a_nnei x a_dim + nlist: Array, # nf x nloc x nnei + nlist_mask: Array, # nf x nloc x nnei + sw: Array, # switch func, nf x nloc x nnei + a_nlist: Array, # nf x nloc x a_nnei + a_nlist_mask: Array, # nf x nloc x a_nnei + a_sw: Array, # switch func, nf x nloc x a_nnei + edge_index: Array, 
# 2 x n_edge + angle_index: Array, # 3 x n_angle + ) -> tuple[Array, Array, Array]: + """ + Parameters + ---------- + node_ebd_ext : nf x nall x n_dim + Extended node embedding. + edge_ebd : nf x nloc x nnei x e_dim + Edge embedding. + h2 : nf x nloc x nnei x 3 + Pair-atom channel, equivariant. + angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim + Angle embedding. + nlist : nf x nloc x nnei + Neighbor list. (padded neis are set to 0) + nlist_mask : nf x nloc x nnei + Masks of the neighbor list. real nei 1 otherwise 0 + sw : nf x nloc x nnei + Switch function. + a_nlist : nf x nloc x a_nnei + Neighbor list for angle. (padded neis are set to 0) + a_nlist_mask : nf x nloc x a_nnei + Masks of the neighbor list for angle. real nei 1 otherwise 0 + a_sw : nf x nloc x a_nnei + Switch function for angle. + edge_index : Optional for dynamic sel, 2 x n_edge + n2e_index : n_edge + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). + n_ext2e_index : n_edge + Broadcast indices from extended node(j) to edge(ij). + angle_index : Optional for dynamic sel, 3 x n_angle + n2a_index : n_angle + Broadcast indices from extended node(j) to angle(ijk). + eij2a_index : n_angle + Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). + eik2a_index : n_angle + Broadcast indices from extended edge(ik) to angle(ijk). + + Returns + ------- + n_updated: nf x nloc x n_dim + Updated node embedding. + e_updated: nf x nloc x nnei x e_dim + Updated edge embedding. + a_updated : nf x nloc x a_nnei x a_nnei x a_dim + Updated angle embedding. 
+ """ + xp = array_api_compat.array_namespace( + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nlist, + nlist_mask, + sw, + a_nlist, + a_nlist_mask, + a_sw, + edge_index, + angle_index, + ) + nb, nloc, nnei = nlist.shape + nall = node_ebd_ext.shape[1] + # int cannot jit; do not run it when self.use_dynamic_sel == False + n_edge = ( + int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 + ) + node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc) + assert (nb, nloc) == node_ebd.shape[:2] + if not self.use_dynamic_sel: + assert (nb, nloc, nnei) == h2.shape[:3] else: - return self.act( - self.optim_edge_update( - node_ebd, - node_ebd_ext, - edge_ebd, - nlist, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_edge_update_dynamic( - node_ebd, - node_ebd_ext, - edge_ebd, - n2e_index, - n_ext2e_index, - "edge", - ) + assert (n_edge, 3) == h2.shape + del a_nlist # may be used in the future + + n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :] + n2a_index, eij2a_index, eik2a_index = ( + angle_index[0, :], + angle_index[1, :], + angle_index[2, :], + ) + + # nb x nloc x nnei x n_dim [OR] n_edge x n_dim + nei_node_ebd = ( + _make_nei_g1(node_ebd_ext, nlist) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_ext, (-1, self.n_dim)), n_ext2e_index, axis=0 ) + ) - def _prepare_angle_embeddings( - self, - xp: object, - node_ebd: Array, - edge_ebd: Array, - a_nlist_mask: Array, - ) -> tuple[Array, Array]: - """Prepare compressed node/edge embeddings for angle computation.""" - if self.a_compress_rate != 0: - if not self.a_compress_use_split: - assert self.a_compress_n_linear is not None - assert self.a_compress_e_linear is not None - node_ebd_for_angle = self.a_compress_n_linear(node_ebd) - edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) - else: - node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] - edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] - else: - node_ebd_for_angle = node_ebd - 
edge_ebd_for_angle = edge_ebd + if self.sequential_update and self.update_angle: + return self._call_sequential( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + h2, + angle_ebd, + nei_node_ebd, + nlist, + nlist_mask, + sw, + a_nlist_mask, + a_sw, + n2e_index, + n_ext2e_index, + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + n_edge, + ) - if not self.use_dynamic_sel: - edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] - edge_ebd_for_angle = xp.where( - xp.expand_dims(a_nlist_mask, axis=-1), - edge_ebd_for_angle, - xp.zeros_like(edge_ebd_for_angle), + n_update_list: list[Array] = [node_ebd] + e_update_list: list[Array] = [edge_ebd] + a_update_list: list[Array] = [angle_ebd] + + # node self mlp + node_self_mlp = self.act(self.node_self_mlp(node_ebd)) + n_update_list.append(node_self_mlp) + + # node sym (grrg + drrd) + node_sym = self._compute_node_sym( + xp, + edge_ebd, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, + ) + n_update_list.append(node_sym) + + # node edge message + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + # nb x nloc x h x n_dim + node_edge_update_mul_head = xp.reshape( + node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) ) - return node_ebd_for_angle, edge_ebd_for_angle + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) + else: + n_update_list.append(node_edge_update) + # update node_ebd + n_updated = self.list_update(n_update_list, "node") - def _compute_angle_update( - self, - xp: object, - angle_ebd: Array, - node_ebd_for_angle: Array, - edge_ebd_for_angle: Array, - feat: str, - n2a_index: Array, - eij2a_index: Array, - eik2a_index: Array, - nb: int, - nloc: int, - ) -> Array: - """Compute angle-based update (for edge_angle or angle_self). 
+ # edge self message + edge_self_update = self._compute_edge_self_update( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + e_update_list.append(edge_self_update) - Parameters - ---------- - feat : str - "edge" for edge_angle_linear1, "angle" for angle_self_linear. - """ - if not self.optim_update: - node_for_angle_info = ( - xp.tile( - xp.reshape( - node_ebd_for_angle, - (nb, nloc, 1, 1, self.n_a_compress_dim), - ), - (1, 1, self.a_sel, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), - n2a_index, - axis=0, - ) + if self.update_angle: + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None + assert self.edge_angle_linear2 is not None + + node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd, a_nlist_mask ) - edge_for_angle_k = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), - ), - (1, 1, self.a_sel, 1, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, - eik2a_index, - axis=0, - ) + + # edge angle message + edge_angle_update = self._compute_angle_update( + xp, + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) - edge_for_angle_j = ( - xp.tile( - xp.reshape( - edge_ebd_for_angle, - (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), - ), - (1, 1, 1, self.a_sel, 1), - ) - if not self.use_dynamic_sel - else xp.take( - edge_ebd_for_angle, + e_update_list.append( + self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, eij2a_index, - axis=0, ) ) - edge_for_angle_info = xp.concat( - [edge_for_angle_k, edge_for_angle_j], axis=-1 - ) - angle_info = xp.concat( - [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 + # 
update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # angle self message + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, ) - if feat == "edge": - assert self.edge_angle_linear1 is not None - return self.act(self.edge_angle_linear1(angle_info)) - else: - assert self.angle_self_linear is not None - return self.act(self.angle_self_linear(angle_info)) + a_update_list.append(angle_self_update) else: - if feat == "edge": - return self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "edge", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "edge", - ) - ) - else: - return self.act( - self.optim_angle_update( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - "angle", - ) - if not self.use_dynamic_sel - else self.optim_angle_update_dynamic( - angle_ebd, - node_ebd_for_angle, - edge_ebd_for_angle, - n2a_index, - eij2a_index, - eik2a_index, - "angle", - ) - ) + # update edge_ebd + e_updated = self.list_update(e_update_list, "edge") + + # update angle_ebd + a_updated = self.list_update(a_update_list, "angle") + return n_updated, e_updated, a_updated - def _compute_edge_angle_reduction( + def _call_sequential( self, xp: object, - edge_angle_update: Array, - edge_ebd_fallback: Array, - a_sw: Array, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + h2: Array, + angle_ebd: Array, + nei_node_ebd: Array, + nlist: Array, + nlist_mask: Array, + sw: Array, a_nlist_mask: Array, + a_sw: Array, + n2e_index: Array, + n_ext2e_index: Array, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: Array, nb: int, nloc: int, n_edge: int, - eij2a_index: Array, - ) -> Array: - """Reduce edge angle update over angle dimension, pad, and apply linear2. 
+ ) -> tuple[Array, Array, Array]: + """Sequential update path: edge_self -> angle_self -> edge_angle -> node. - Parameters - ---------- - edge_ebd_fallback : Array - Edge embedding used for non-smooth padding fallback. + Each phase consumes the most recent update of its inputs (instead of + the parallel path which uses the same ``edge_ebd``/``angle_ebd`` for + all branches). Residuals are applied immediately after each phase. """ + assert self.update_style == "res_residual" + assert self.update_angle + assert self.angle_self_linear is not None + assert self.edge_angle_linear1 is not None assert self.edge_angle_linear2 is not None - if not self.use_dynamic_sel: - weighted_edge_angle_update = ( - a_sw[:, :, :, xp.newaxis, xp.newaxis] - * a_sw[:, :, xp.newaxis, :, xp.newaxis] - * edge_angle_update - ) - reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( - self.a_sel**0.5 + + # Phase 1: edge self update -> apply residual immediately. + edge_self_update = self._compute_edge_self_update( + xp, + node_ebd, + node_ebd_ext, + edge_ebd, + nei_node_ebd, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update + + # Phase 2: angle self message uses the updated edge embedding. + node_ebd_for_angle, edge_ebd_for_angle = self._prepare_angle_embeddings( + xp, node_ebd, edge_ebd_s1, a_nlist_mask + ) + angle_self_update = self._compute_angle_update( + xp, + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) + a_updated = angle_ebd + self.a_residual[0] * angle_self_update + + # Phase 3: edge angle message uses the updated angle embedding. 
+ edge_angle_update = self._compute_angle_update( + xp, + a_updated, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + n2a_index, + eij2a_index, + eik2a_index, + nb, + nloc, + ) + edge_angle_reduced = self._compute_edge_angle_reduction( + xp, + edge_angle_update, + edge_ebd_s1, + a_sw, + a_nlist_mask, + nb, + nloc, + n_edge, + eij2a_index, + ) + e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_reduced + + # Phase 4: node updates use the fully updated edge embedding. + n_update_list: list[Array] = [node_ebd] + n_update_list.append(self.act(self.node_self_mlp(node_ebd))) + n_update_list.append( + self._compute_node_sym( + xp, + e_updated, + nei_node_ebd, + h2, + nlist_mask, + sw, + n2e_index, + nb, + nloc, ) - padding_edge_angle_update = xp.concat( - [ - reduced_edge_angle_update, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel, self.e_dim), - dtype=edge_ebd_fallback.dtype, - device=array_api_compat.device(edge_ebd_fallback), - ), - ], - axis=2, + ) + node_edge_update = self._compute_node_edge_message( + xp, + node_ebd, + node_ebd_ext, + e_updated, + nei_node_ebd, + sw, + nlist, + n2e_index, + n_ext2e_index, + nb, + nloc, + ) + if self.n_multi_edge_message > 1: + node_edge_update_mul_head = xp.reshape( + node_edge_update, + (nb, nloc, self.n_multi_edge_message, self.n_dim), ) + for head_index in range(self.n_multi_edge_message): + n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) else: - weighted_edge_angle_update = edge_angle_update * xp.expand_dims( - a_sw, axis=-1 - ) - padding_edge_angle_update = aggregate( - weighted_edge_angle_update, - eij2a_index, - average=False, - num_owner=n_edge, - ) / (self.dynamic_a_sel**0.5) + n_update_list.append(node_edge_update) + n_updated = self.list_update(n_update_list, "node") + return n_updated, e_updated, a_updated - if not self.smooth_edge_update: - if self.use_dynamic_sel: - raise NotImplementedError( - "smooth_edge_update must be True when use_dynamic_sel is True!" 
- ) - full_mask = xp.concat( - [ - a_nlist_mask, - xp.zeros( - (nb, nloc, self.nnei - self.a_sel), - dtype=a_nlist_mask.dtype, - device=array_api_compat.device(a_nlist_mask), - ), - ], - axis=-1, + def _compute_node_sym( + self, + xp: object, + edge_ebd: Array, + nei_node_ebd: Array, + h2: Array, + nlist_mask: Array, + sw: Array, + n2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute node symmetrization update (grrg + drrd).""" + node_sym_list: list[Array] = [] + node_sym_list.append( + symmetrization_op( + edge_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, ) - padding_edge_angle_update = xp.where( - xp.expand_dims(full_mask, axis=-1), - padding_edge_angle_update, - edge_ebd_fallback, + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + edge_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, ) - - return self.act(self.edge_angle_linear2(padding_edge_angle_update)) + ) + node_sym_list.append( + symmetrization_op( + nei_node_ebd, + h2, + nlist_mask, + sw, + self.axis_neuron, + ) + if not self.use_dynamic_sel + else symmetrization_op_dynamic( + nei_node_ebd, + h2, + sw, + owner=n2e_index, + num_owner=nb * nloc, + nb=nb, + nloc=nloc, + scale_factor=self.dynamic_e_sel ** (-0.5), + axis_neuron=self.axis_neuron, + ) + ) + return self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) def _compute_node_edge_message( self, @@ -1715,387 +1875,294 @@ def _compute_node_edge_message( ), (nb, nloc, node_edge_update.shape[-1]), ) - / self.dynamic_e_sel + / self.dynamic_e_sel + ) + ) + return node_edge_update + + def _compute_edge_self_update( + self, + xp: object, + node_ebd: Array, + node_ebd_ext: Array, + edge_ebd: Array, + nei_node_ebd: Array, + nlist: Array, + n2e_index: Array, + n_ext2e_index: Array, + nb: int, + nloc: int, + ) -> Array: + """Compute edge self update.""" + if not self.optim_update: + if not 
self.use_dynamic_sel: + edge_info = xp.concat( + [ + xp.tile( + xp.reshape(node_ebd, (nb, nloc, 1, self.n_dim)), + (1, 1, self.nnei, 1), + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + else: + edge_info = xp.concat( + [ + xp.take( + xp.reshape(node_ebd, (-1, self.n_dim)), + n2e_index, + axis=0, + ), + nei_node_ebd, + edge_ebd, + ], + axis=-1, + ) + return self.act(self.edge_self_linear(edge_info)) + else: + return self.act( + self.optim_edge_update( + node_ebd, + node_ebd_ext, + edge_ebd, + nlist, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_edge_update_dynamic( + node_ebd, + node_ebd_ext, + edge_ebd, + n2e_index, + n_ext2e_index, + "edge", + ) ) - ) - return node_edge_update - def _compute_node_sym( + def _prepare_angle_embeddings( self, xp: object, + node_ebd: Array, edge_ebd: Array, - nei_node_ebd: Array, - h2: Array, - nlist_mask: Array, - sw: Array, - n2e_index: Array, + a_nlist_mask: Array, + ) -> tuple[Array, Array]: + """Prepare compressed node/edge embeddings for angle computation.""" + if self.a_compress_rate != 0: + if not self.a_compress_use_split: + assert self.a_compress_n_linear is not None + assert self.a_compress_e_linear is not None + node_ebd_for_angle = self.a_compress_n_linear(node_ebd) + edge_ebd_for_angle = self.a_compress_e_linear(edge_ebd) + else: + node_ebd_for_angle = node_ebd[..., : self.n_a_compress_dim] + edge_ebd_for_angle = edge_ebd[..., : self.e_a_compress_dim] + else: + node_ebd_for_angle = node_ebd + edge_ebd_for_angle = edge_ebd + + if not self.use_dynamic_sel: + edge_ebd_for_angle = edge_ebd_for_angle[..., : self.a_sel, :] + edge_ebd_for_angle = xp.where( + xp.expand_dims(a_nlist_mask, axis=-1), + edge_ebd_for_angle, + xp.zeros_like(edge_ebd_for_angle), + ) + return node_ebd_for_angle, edge_ebd_for_angle + + def _compute_angle_update( + self, + xp: object, + angle_ebd: Array, + node_ebd_for_angle: Array, + edge_ebd_for_angle: Array, + feat: str, + n2a_index: Array, + eij2a_index: Array, + eik2a_index: 
Array, nb: int, nloc: int, ) -> Array: - """Compute node symmetrization update (grrg + drrd).""" - node_sym_list: list[Array] = [] - node_sym_list.append( - symmetrization_op( - edge_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - edge_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - node_sym_list.append( - symmetrization_op( - nei_node_ebd, - h2, - nlist_mask, - sw, - self.axis_neuron, - ) - if not self.use_dynamic_sel - else symmetrization_op_dynamic( - nei_node_ebd, - h2, - sw, - owner=n2e_index, - num_owner=nb * nloc, - nb=nb, - nloc=nloc, - scale_factor=self.dynamic_e_sel ** (-0.5), - axis_neuron=self.axis_neuron, - ) - ) - return self.act(self.node_sym_linear(xp.concat(node_sym_list, axis=-1))) + """Compute angle-based update (for edge_angle or angle_self). - def call( - self, - node_ebd_ext: Array, # nf x nall x n_dim - edge_ebd: Array, # nf x nloc x nnei x e_dim - h2: Array, # nf x nloc x nnei x 3 - angle_ebd: Array, # nf x nloc x a_nnei x a_nnei x a_dim - nlist: Array, # nf x nloc x nnei - nlist_mask: Array, # nf x nloc x nnei - sw: Array, # switch func, nf x nloc x nnei - a_nlist: Array, # nf x nloc x a_nnei - a_nlist_mask: Array, # nf x nloc x a_nnei - a_sw: Array, # switch func, nf x nloc x a_nnei - edge_index: Array, # 2 x n_edge - angle_index: Array, # 3 x n_angle - ) -> tuple[Array, Array, Array]: - """ Parameters ---------- - node_ebd_ext : nf x nall x n_dim - Extended node embedding. - edge_ebd : nf x nloc x nnei x e_dim - Edge embedding. - h2 : nf x nloc x nnei x 3 - Pair-atom channel, equivariant. - angle_ebd : nf x nloc x a_nnei x a_nnei x a_dim - Angle embedding. - nlist : nf x nloc x nnei - Neighbor list. (padded neis are set to 0) - nlist_mask : nf x nloc x nnei - Masks of the neighbor list. 
real nei 1 otherwise 0 - sw : nf x nloc x nnei - Switch function. - a_nlist : nf x nloc x a_nnei - Neighbor list for angle. (padded neis are set to 0) - a_nlist_mask : nf x nloc x a_nnei - Masks of the neighbor list for angle. real nei 1 otherwise 0 - a_sw : nf x nloc x a_nnei - Switch function for angle. - edge_index : Optional for dynamic sel, 2 x n_edge - n2e_index : n_edge - Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). - n_ext2e_index : n_edge - Broadcast indices from extended node(j) to edge(ij). - angle_index : Optional for dynamic sel, 3 x n_angle - n2a_index : n_angle - Broadcast indices from extended node(j) to angle(ijk). - eij2a_index : n_angle - Broadcast indices from extended edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). - eik2a_index : n_angle - Broadcast indices from extended edge(ik) to angle(ijk). - - Returns - ------- - n_updated: nf x nloc x n_dim - Updated node embedding. - e_updated: nf x nloc x nnei x e_dim - Updated edge embedding. - a_updated : nf x nloc x a_nnei x a_nnei x a_dim - Updated angle embedding. + feat : str + "edge" for edge_angle_linear1, "angle" for angle_self_linear. 
""" - xp = array_api_compat.array_namespace( - node_ebd_ext, - edge_ebd, - h2, - angle_ebd, - nlist, - nlist_mask, - sw, - a_nlist, - a_nlist_mask, - a_sw, - edge_index, - angle_index, - ) - nb, nloc, nnei = nlist.shape - nall = node_ebd_ext.shape[1] - # int cannot jit; do not run it when self.use_dynamic_sel == False - n_edge = ( - int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0 - ) - node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc) - assert (nb, nloc) == node_ebd.shape[:2] - if not self.use_dynamic_sel: - assert (nb, nloc, nnei) == h2.shape[:3] - else: - assert (n_edge, 3) == h2.shape - del a_nlist # may be used in the future - - n2e_index, n_ext2e_index = edge_index[0, :], edge_index[1, :] - n2a_index, eij2a_index, eik2a_index = ( - angle_index[0, :], - angle_index[1, :], - angle_index[2, :], - ) - - # nb x nloc x nnei x n_dim [OR] n_edge x n_dim - nei_node_ebd = ( - _make_nei_g1(node_ebd_ext, nlist) - if not self.use_dynamic_sel - else xp.take( - xp.reshape(node_ebd_ext, (-1, self.n_dim)), n_ext2e_index, axis=0 - ) - ) - - # Edge self update (always from original embeddings) - edge_self_update = self._compute_edge_self_update( - xp, - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, - ) - - if self.sequential_update and self.update_angle: - # === Sequential update path === - # Phase 1: Apply edge self residual - edge_ebd_s1 = edge_ebd + self.e_residual[0] * edge_self_update - - # Phase 2: Angle self (uses updated edge_ebd_s1) - node_for_a, edge_for_a = self._prepare_angle_embeddings( - xp, node_ebd, edge_ebd_s1, a_nlist_mask - ) - angle_self_update = self._compute_angle_update( - xp, - angle_ebd, - node_for_a, - edge_for_a, - "angle", - n2a_index, - eij2a_index, - eik2a_index, - nb, - nloc, + if not self.optim_update: + node_for_angle_info = ( + xp.tile( + xp.reshape( + node_ebd_for_angle, + (nb, nloc, 1, 1, self.n_a_compress_dim), + ), + (1, 1, self.a_sel, self.a_sel, 1), + 
) + if not self.use_dynamic_sel + else xp.take( + xp.reshape(node_ebd_for_angle, (-1, self.n_a_compress_dim)), + n2a_index, + axis=0, + ) ) - a_updated = angle_ebd + self.a_residual[0] * angle_self_update - - # Phase 3: Edge angle (uses updated angle a_updated + edge_ebd_s1) - edge_angle_update = self._compute_angle_update( - xp, - a_updated, - node_for_a, - edge_for_a, - "edge", - n2a_index, - eij2a_index, - eik2a_index, - nb, - nloc, + edge_for_angle_k = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, 1, self.a_sel, self.e_a_compress_dim), + ), + (1, 1, self.a_sel, 1, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eik2a_index, + axis=0, + ) ) - edge_angle_processed = self._compute_edge_angle_reduction( - xp, - edge_angle_update, - edge_ebd_s1, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, + edge_for_angle_j = ( + xp.tile( + xp.reshape( + edge_ebd_for_angle, + (nb, nloc, self.a_sel, 1, self.e_a_compress_dim), + ), + (1, 1, 1, self.a_sel, 1), + ) + if not self.use_dynamic_sel + else xp.take( + edge_ebd_for_angle, + eij2a_index, + axis=0, + ) ) - e_updated = edge_ebd_s1 + self.e_residual[1] * edge_angle_processed - - # Phase 4+5: Node updates (uses e_updated) - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - node_sym = self._compute_node_sym( - xp, - e_updated, - nei_node_ebd, - h2, - nlist_mask, - sw, - n2e_index, - nb, - nloc, + edge_for_angle_info = xp.concat( + [edge_for_angle_k, edge_for_angle_j], axis=-1 ) - node_edge_update = self._compute_node_edge_message( - xp, - node_ebd, - node_ebd_ext, - e_updated, - nei_node_ebd, - sw, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, + angle_info = xp.concat( + [angle_ebd, node_for_angle_info, edge_for_angle_info], axis=-1 ) - - n_update_list: list[Array] = [node_ebd, node_self_mlp, node_sym] - if self.n_multi_edge_message > 1: - node_edge_update_mul_head = xp.reshape( - node_edge_update, - (nb, nloc, self.n_multi_edge_message, self.n_dim), + if 
feat == "edge": + assert self.edge_angle_linear1 is not None + return self.act(self.edge_angle_linear1(angle_info)) + else: + assert self.angle_self_linear is not None + return self.act(self.angle_self_linear(angle_info)) + else: + if feat == "edge": + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "edge", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "edge", + ) ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) else: - n_update_list.append(node_edge_update) - n_updated = self.list_update(n_update_list, "node") - - return n_updated, e_updated, a_updated - - # === Parallel update path === - n_update_list: list[Array] = [node_ebd] - e_update_list: list[Array] = [edge_ebd] - a_update_list: list[Array] = [angle_ebd] - - # node self mlp - node_self_mlp = self.act(self.node_self_mlp(node_ebd)) - n_update_list.append(node_self_mlp) + return self.act( + self.optim_angle_update( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + "angle", + ) + if not self.use_dynamic_sel + else self.optim_angle_update_dynamic( + angle_ebd, + node_ebd_for_angle, + edge_ebd_for_angle, + n2a_index, + eij2a_index, + eik2a_index, + "angle", + ) + ) - # node sym (grrg + drrd) - node_sym = self._compute_node_sym( - xp, - edge_ebd, - nei_node_ebd, - h2, - nlist_mask, - sw, - n2e_index, - nb, - nloc, - ) - n_update_list.append(node_sym) + def _compute_edge_angle_reduction( + self, + xp: object, + edge_angle_update: Array, + edge_ebd_fallback: Array, + a_sw: Array, + a_nlist_mask: Array, + nb: int, + nloc: int, + n_edge: int, + eij2a_index: Array, + ) -> Array: + """Reduce edge angle update over angle dimension, pad, and apply linear2. 
- # node edge message - node_edge_update = self._compute_node_edge_message( - xp, - node_ebd, - node_ebd_ext, - edge_ebd, - nei_node_ebd, - sw, - nlist, - n2e_index, - n_ext2e_index, - nb, - nloc, - ) - if self.n_multi_edge_message > 1: - # nb x nloc x h x n_dim - node_edge_update_mul_head = xp.reshape( - node_edge_update, (nb, nloc, self.n_multi_edge_message, self.n_dim) + Parameters + ---------- + edge_ebd_fallback : Array + Edge embedding used for non-smooth padding fallback. + """ + assert self.edge_angle_linear2 is not None + if not self.use_dynamic_sel: + weighted_edge_angle_update = ( + a_sw[:, :, :, xp.newaxis, xp.newaxis] + * a_sw[:, :, xp.newaxis, :, xp.newaxis] + * edge_angle_update ) - for head_index in range(self.n_multi_edge_message): - n_update_list.append(node_edge_update_mul_head[:, :, head_index, :]) - else: - n_update_list.append(node_edge_update) - # update node_ebd - n_updated = self.list_update(n_update_list, "node") - - # edge self message - e_update_list.append(edge_self_update) - - if self.update_angle: - assert self.angle_self_linear is not None - assert self.edge_angle_linear1 is not None - assert self.edge_angle_linear2 is not None - - node_for_a, edge_for_a = self._prepare_angle_embeddings( - xp, node_ebd, edge_ebd, a_nlist_mask + reduced_edge_angle_update = xp.sum(weighted_edge_angle_update, axis=-2) / ( + self.a_sel**0.5 ) - - # edge angle message - edge_angle_update = self._compute_angle_update( - xp, - angle_ebd, - node_for_a, - edge_for_a, - "edge", - n2a_index, - eij2a_index, - eik2a_index, - nb, - nloc, + padding_edge_angle_update = xp.concat( + [ + reduced_edge_angle_update, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel, self.e_dim), + dtype=edge_ebd_fallback.dtype, + device=array_api_compat.device(edge_ebd_fallback), + ), + ], + axis=2, ) - edge_angle_processed = self._compute_edge_angle_reduction( - xp, - edge_angle_update, - edge_ebd, - a_sw, - a_nlist_mask, - nb, - nloc, - n_edge, - eij2a_index, + else: + 
weighted_edge_angle_update = edge_angle_update * xp.expand_dims( + a_sw, axis=-1 ) - e_update_list.append(edge_angle_processed) - # update edge_ebd - e_updated = self.list_update(e_update_list, "edge") - - # angle self message - angle_self_update = self._compute_angle_update( - xp, - angle_ebd, - node_for_a, - edge_for_a, - "angle", - n2a_index, + padding_edge_angle_update = aggregate( + weighted_edge_angle_update, eij2a_index, - eik2a_index, - nb, - nloc, + average=False, + num_owner=n_edge, + ) / (self.dynamic_a_sel**0.5) + + if not self.smooth_edge_update: + if self.use_dynamic_sel: + raise NotImplementedError( + "smooth_edge_update must be True when use_dynamic_sel is True!" + ) + full_mask = xp.concat( + [ + a_nlist_mask, + xp.zeros( + (nb, nloc, self.nnei - self.a_sel), + dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), + ), + ], + axis=-1, + ) + padding_edge_angle_update = xp.where( + xp.expand_dims(full_mask, axis=-1), + padding_edge_angle_update, + edge_ebd_fallback, ) - a_update_list.append(angle_self_update) - else: - # update edge_ebd - e_updated = self.list_update(e_update_list, "edge") - # update angle_ebd - a_updated = self.list_update(a_update_list, "angle") - return n_updated, e_updated, a_updated + return self.act(self.edge_angle_linear2(padding_edge_angle_update)) def list_update_res_avg( self,