Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ def __init__(
if hasattr(self.fitting_net, "reinit_exclude"):
self.fitting_net.reinit_exclude(self.atom_exclude_types)
self.type_map = type_map
self.add_chg_spin_ebd: bool = getattr(
self.descriptor, "add_chg_spin_ebd", False
)
super().init_out_stat()

def fitting_output_def(self) -> FittingOutputDef:
Expand Down Expand Up @@ -180,11 +183,38 @@ def forward_atomic(
"""
nframes, nloc, nnei = nlist.shape
atype = xp_take_first_n(extended_atype, 1, nloc)

# Handle default fparam if fitting net supports it
if (
Comment thread
iProzd marked this conversation as resolved.
hasattr(self.fitting_net, "get_dim_fparam")
and self.fitting_net.get_dim_fparam() > 0
and fparam is None
):
# use default fparam
from deepmd.dpmodel.array_api import (
array_api_compat,
)

default_fparam = self.fitting_net.get_default_fparam()
assert default_fparam is not None
xp = array_api_compat.array_namespace(extended_coord)
default_fparam_array = xp.asarray(
default_fparam,
dtype=extended_coord.dtype,
device=array_api_compat.device(extended_coord),
)
fparam_input_for_des = xp.tile(
xp.reshape(default_fparam_array, (1, -1)), (nframes, 1)
)
else:
fparam_input_for_des = fparam
Comment thread
iProzd marked this conversation as resolved.

descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
)
Comment thread
iProzd marked this conversation as resolved.
ret = self.fitting_net(
descriptor,
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> Array:
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.

Expand Down
70 changes: 70 additions & 0 deletions deepmd/dpmodel/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)
from deepmd.dpmodel.utils.network import (
NativeLayer,
get_activation_fn,
)
from deepmd.dpmodel.utils.seed import (
child_seed,
Expand Down Expand Up @@ -356,6 +357,7 @@ def __init__(
use_tebd_bias: bool = False,
use_loc_mapping: bool = True,
type_map: list[str] | None = None,
add_chg_spin_ebd: bool = False,
) -> None:
super().__init__()

Expand Down Expand Up @@ -410,6 +412,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
)

self.use_econf_tebd = use_econf_tebd
self.add_chg_spin_ebd = add_chg_spin_ebd
self.use_tebd_bias = use_tebd_bias
self.use_loc_mapping = use_loc_mapping
self.type_map = type_map
Expand All @@ -428,6 +431,38 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any:
)
self.concat_output_tebd = concat_output_tebd
self.precision = precision

if self.add_chg_spin_ebd:
self.cs_activation_fn = get_activation_fn(activation_function)
# -100 ~ 100 is a conservative bound
self.chg_embedding = TypeEmbedNet(
ntypes=200,
neuron=[self.tebd_dim],
padding=True,
activation_function="Linear",
precision=precision,
seed=child_seed(seed, 3),
)
# 100 is a conservative upper bound
self.spin_embedding = TypeEmbedNet(
ntypes=100,
neuron=[self.tebd_dim],
padding=True,
activation_function="Linear",
precision=precision,
seed=child_seed(seed, 4),
)
self.mix_cs_mlp = NativeLayer(
2 * self.tebd_dim,
self.tebd_dim,
precision=precision,
seed=child_seed(seed, 5),
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
self.chg_embedding = None
self.spin_embedding = None
self.mix_cs_mlp = None

self.exclude_types = exclude_types
self.env_protection = env_protection
self.trainable = trainable
Expand Down Expand Up @@ -579,6 +614,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
"""Compute the descriptor.

Expand Down Expand Up @@ -625,6 +661,27 @@ def call(
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
(nframes, nall, self.tebd_dim),
)

if self.add_chg_spin_ebd:
assert fparam is not None
assert self.chg_embedding is not None
assert self.spin_embedding is not None
chg_tebd = self.chg_embedding.call()
spin_tebd = self.spin_embedding.call()
charge = xp.astype(fparam[:, 0], xp.int64) + 100
spin = xp.astype(fparam[:, 1], xp.int64)
chg_ebd = xp.reshape(
xp.take(chg_tebd, xp.reshape(charge, (-1,)), axis=0),
(nframes, self.tebd_dim),
)
spin_ebd = xp.reshape(
xp.take(spin_tebd, xp.reshape(spin, (-1,)), axis=0),
(nframes, self.tebd_dim),
)
cs_cat = xp.concat([chg_ebd, spin_ebd], axis=-1)
sys_cs_embd = self.cs_activation_fn(self.mix_cs_mlp.call(cs_cat))
node_ebd_ext = node_ebd_ext + xp.expand_dims(sys_cs_embd, axis=1)

node_ebd_inp = node_ebd_ext[:, :nloc, :]
# repflows
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
Expand Down Expand Up @@ -655,9 +712,14 @@ def serialize(self) -> dict:
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"use_loc_mapping": self.use_loc_mapping,
"add_chg_spin_ebd": self.add_chg_spin_ebd,
"type_map": self.type_map,
"type_embedding": self.type_embedding.serialize(),
}
if self.add_chg_spin_ebd:
data["chg_embedding"] = self.chg_embedding.serialize()
data["spin_embedding"] = self.spin_embedding.serialize()
data["mix_cs_mlp"] = self.mix_cs_mlp.serialize()
repflow_variable = {
"edge_embd": repflows.edge_embd.serialize(),
"angle_embd": repflows.angle_embd.serialize(),
Expand All @@ -684,10 +746,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA3":
data.pop("type")
repflow_variable = data.pop("repflow_variable").copy()
type_embedding = data.pop("type_embedding")
chg_embedding = data.pop("chg_embedding", None)
spin_embedding = data.pop("spin_embedding", None)
mix_cs_mlp = data.pop("mix_cs_mlp", None)
data["repflow"] = RepFlowArgs(**data.pop("repflow_args"))
obj = cls(**data)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)

if obj.add_chg_spin_ebd and chg_embedding is not None:
obj.chg_embedding = TypeEmbedNet.deserialize(chg_embedding)
obj.spin_embedding = TypeEmbedNet.deserialize(spin_embedding)
obj.mix_cs_mlp = NativeLayer.deserialize(mix_cs_mlp)

# deserialize repflow
statistic_repflows = repflow_variable.pop("@variables")
env_mat = repflow_variable.pop("env_mat")
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> tuple[
Array,
Array | None,
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def fwd(
extended_atype: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> Array:
"""Calculate descriptor."""
pass
Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> Array:
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> Array:
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> tuple[Array, Array]:
"""Compute the descriptor.

Expand Down
1 change: 1 addition & 0 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def call(
atype_ext: Array,
nlist: Array,
mapping: Array | None = None,
fparam: Array | None = None,
) -> tuple[Array, Array]:
"""Compute the descriptor.

Expand Down
11 changes: 9 additions & 2 deletions deepmd/jax/descriptor/dpa3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
flax_version,
nnx,
)
from deepmd.jax.utils.network import (
NativeLayer,
)
from deepmd.jax.utils.type_embed import (
TypeEmbedNet,
)
Expand All @@ -40,8 +43,12 @@ def __setattr__(self, name: str, value: Any) -> None:
value = nnx.data(value)
elif name in {"repflows"}:
value = DescrptBlockRepflows.deserialize(value.serialize())
elif name in {"type_embedding"}:
value = TypeEmbedNet.deserialize(value.serialize())
elif name in {"type_embedding", "chg_embedding", "spin_embedding"}:
if value is not None:
value = TypeEmbedNet.deserialize(value.serialize())
elif name in {"mix_cs_mlp"}:
if value is not None:
value = NativeLayer.deserialize(value.serialize())
else:
pass
return super().__setattr__(name, value)
20 changes: 20 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def __init__(
if hasattr(self.fitting_net, "reinit_exclude"):
self.fitting_net.reinit_exclude(self.atom_exclude_types)
super().init_out_stat()
self.add_chg_spin_ebd: bool = getattr(
self.descriptor, "add_chg_spin_ebd", False
)
self.enable_eval_descriptor_hook = False
self.enable_eval_fitting_last_layer_hook = False
self.eval_descriptor_list = []
Expand Down Expand Up @@ -269,12 +272,29 @@ def forward_atomic(
atype = extended_atype[:, :nloc]
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)

# Handle default fparam if fitting net supports it
if (
hasattr(self.fitting_net, "get_dim_fparam")
and self.fitting_net.get_dim_fparam() > 0
and fparam is None
):
# use default fparam
default_fparam_tensor = self.fitting_net.get_default_fparam()
assert default_fparam_tensor is not None
fparam_input_for_des = torch.tile(
default_fparam_tensor.unsqueeze(0), [nframes, 1]
)
else:
fparam_input_for_des = fparam
Comment thread
iProzd marked this conversation as resolved.

descriptor, rot_mat, g2, h2, sw = self.descriptor(
extended_coord,
extended_atype,
nlist,
mapping=mapping,
comm_dict=comm_dict,
fparam=fparam_input_for_des if self.add_chg_spin_ebd else None,
)
Comment thread
iProzd marked this conversation as resolved.
assert descriptor is not None
if self.enable_eval_descriptor_hook:
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def forward(
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
comm_dict: dict[str, torch.Tensor] | None = None,
fparam: torch.Tensor | None = None,
) -> tuple[
torch.Tensor,
torch.Tensor | None,
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ def forward(
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
comm_dict: dict[str, torch.Tensor] | None = None,
fparam: torch.Tensor | None = None,
Comment thread
iProzd marked this conversation as resolved.
) -> tuple[
torch.Tensor,
torch.Tensor | None,
Expand Down
Loading
Loading