diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index 180e5458fb..4761703a19 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -277,15 +277,14 @@ def compute_input_stats( f"numb_fparam > 0 but no fparam data is provided " f"for system {ii}." ) - cat_data = np.concatenate( - [frame["fparam"] for frame in sampled], axis=0 - ) - cat_data = np.reshape(cat_data, [-1, self.numb_fparam]) + xp_fp = array_api_compat.array_namespace(sampled[0]["fparam"]) + cat_data = xp_fp.concat([frame["fparam"] for frame in sampled], axis=0) + cat_data = xp_fp.reshape(cat_data, (-1, self.numb_fparam)) fparam_stats = [ StatItem( number=cat_data.shape[0], - sum=np.sum(cat_data[:, ii]), - squared_sum=np.sum(cat_data[:, ii] ** 2), + sum=float(xp_fp.sum(cat_data[:, ii])), + squared_sum=float(xp_fp.sum(cat_data[:, ii] ** 2)), ) for ii in range(self.numb_fparam) ] @@ -335,22 +334,23 @@ def compute_input_stats( f"numb_aparam > 0 but no aparam data is provided " f"for system {ii}." ) + xp_ap = array_api_compat.array_namespace(sampled[0]["aparam"]) sys_sumv = [] sys_sumv2 = [] sys_sumn = [] for ss_ in [frame["aparam"] for frame in sampled]: - ss = np.reshape(ss_, [-1, self.numb_aparam]) - sys_sumv.append(np.sum(ss, axis=0)) - sys_sumv2.append(np.sum(ss * ss, axis=0)) + ss = xp_ap.reshape(ss_, (-1, self.numb_aparam)) + sys_sumv.append(xp_ap.sum(ss, axis=0)) + sys_sumv2.append(xp_ap.sum(ss * ss, axis=0)) sys_sumn.append(ss.shape[0]) - sumv = np.sum(np.stack(sys_sumv), axis=0) - sumv2 = np.sum(np.stack(sys_sumv2), axis=0) + sumv = xp_ap.sum(xp_ap.stack(sys_sumv), axis=0) + sumv2 = xp_ap.sum(xp_ap.stack(sys_sumv2), axis=0) sumn = sum(sys_sumn) aparam_stats = [ StatItem( number=sumn, - sum=sumv[ii], - squared_sum=sumv2[ii], + sum=float(sumv[ii]), + squared_sum=float(sumv2[ii]), ) for ii in range(self.numb_aparam) ] diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 9dc83ac9d6..2e05754a76 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -68,6 +68,7 @@ def model_call_from_call_lower( fparam: Array | None = None, aparam: Array | None = None, do_atomic_virial: bool = False, + coord_corr_for_virial: Array | None = None, ) -> dict[str, Array]: """Return model prediction from lower interface. 
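A minimal, self-contained sketch (independent of the patch itself) of the array-API pattern the hunks above adopt: the namespace ``xp`` is resolved from the data via ``array_api_compat.array_namespace``, so one implementation of the statistics serves NumPy, array_api_strict, and torch inputs alike. The ``"fparam"`` key and the ``float(...)`` scalar extraction follow the patch; the helper name and the mean/std reduction are illustrative only.

import array_api_compat
import numpy as np


def fparam_mean_std(sampled, numb_fparam):
    # Resolve the namespace from the input rather than hard-coding np.*
    xp = array_api_compat.array_namespace(sampled[0]["fparam"])
    cat = xp.concat([frame["fparam"] for frame in sampled], axis=0)
    cat = xp.reshape(cat, (-1, numb_fparam))
    n = cat.shape[0]
    # float(...) brings 0-d array results back to Python scalars,
    # mirroring the StatItem fields in the hunks above.
    mean = [float(xp.sum(cat[:, ii])) / n for ii in range(numb_fparam)]
    msq = [float(xp.sum(cat[:, ii] ** 2)) / n for ii in range(numb_fparam)]
    std = [max(q - m * m, 0.0) ** 0.5 for q, m in zip(msq, mean)]
    return mean, std


# The same function runs unchanged on any array-API namespace:
frames = [{"fparam": np.random.default_rng(ii).normal(size=(4, 2))} for ii in range(3)]
print(fparam_mean_std(frames, 2))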
@@ -119,14 +120,33 @@ def model_call_from_call_lower( distinguish_types=False, ) extended_coord = extended_coord.reshape(nframes, -1, 3) + if coord_corr_for_virial is not None: + xp = array_api_compat.array_namespace(coord_corr_for_virial) + # mapping: nf x nall -> nf x nall x 1, then tile to nf x nall x 3 + mapping_idx = xp.tile( + xp.reshape(mapping, (nframes, -1, 1)), + (1, 1, 3), + ) + extended_coord_corr = xp.take_along_axis( + coord_corr_for_virial, + mapping_idx, + axis=1, + ) + else: + extended_coord_corr = None + call_lower_kwargs: dict[str, Any] = { + "fparam": fp, + "aparam": ap, + "do_atomic_virial": do_atomic_virial, + } + if extended_coord_corr is not None: + call_lower_kwargs["extended_coord_corr"] = extended_coord_corr model_predict_lower = call_lower( extended_coord, extended_atype, nlist, mapping, - fparam=fp, - aparam=ap, - do_atomic_virial=do_atomic_virial, + **call_lower_kwargs, ) model_predict = communicate_extended_output( model_predict_lower, @@ -237,6 +257,7 @@ def call_common( fparam: Array | None = None, aparam: Array | None = None, do_atomic_virial: bool = False, + coord_corr_for_virial: Array | None = None, ) -> dict[str, Array]: """Return model prediction. @@ -255,6 +276,9 @@ def call_common( atomic parameter. nf x nloc x nda do_atomic_virial If calculate the atomic virial. + coord_corr_for_virial + The coordinate correction for the virial. + shape: nf x (nloc x 3) Returns ------- @@ -279,6 +303,7 @@ def call_common( fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, + coord_corr_for_virial=coord_corr_for_virial, ) model_predict = self._output_type_cast(model_predict, input_prec) return model_predict @@ -292,6 +317,7 @@ def call_common_lower( fparam: Array | None = None, aparam: Array | None = None, do_atomic_virial: bool = False, + extended_coord_corr: Array | None = None, ) -> dict[str, Array]: """Return model prediction. Lower interface that takes extended atomic coordinates and types, nlist, and mapping @@ -314,6 +340,9 @@ def call_common_lower( atomic parameter. nf x nloc x nda do_atomic_virial whether calculate atomic virial + extended_coord_corr + coordinate correction for the virial in the extended region. + nf x (nall x 3) Returns ------- @@ -341,6 +370,7 @@ def call_common_lower( fparam=fp, aparam=ap, do_atomic_virial=do_atomic_virial, + extended_coord_corr=extended_coord_corr, ) model_predict = self._output_type_cast(model_predict, input_prec) return model_predict @@ -354,6 +384,7 @@ def forward_common_atomic( fparam: Array | None = None, aparam: Array | None = None, do_atomic_virial: bool = False, + extended_coord_corr: Array | None = None, ) -> dict[str, Array]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index b560366457..8f96e965b0 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -164,6 +164,7 @@ def get_spin_model(data: dict) -> SpinModel: data : dict The data to construct the model.
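+    Notes
+    -----
+    The input ``data`` is deep-copied before use, since ``type_map`` is
+    extended in place with the virtual ``*_spin`` types.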
""" + data = copy.deepcopy(data) # include virtual spin and placeholder types data["type_map"] += [item + "_spin" for item in data["type_map"]] spin = Spin( diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index eda8318abc..a943434713 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -10,7 +10,7 @@ Any, ) -import numpy as np +import array_api_compat from deepmd.dpmodel.array_api import ( Array, @@ -67,17 +67,44 @@ def __init__( self.virtual_scale_mask = self.spin.get_virtual_scale_mask() self.spin_mask = self.spin.get_spin_mask() + def _to_xp(self, arr: Any, xp: Any, ref_arr: Any) -> Any: + """Convert a numpy array to the same namespace as ref_arr.""" + return xp.asarray(arr, device=array_api_compat.device(ref_arr)) + def process_spin_input( self, coord: Array, atype: Array, spin: Array - ) -> tuple[Array, Array]: - """Generate virtual coordinates and types, concat into the input.""" + ) -> tuple[Array, Array, Array]: + """Generate virtual coordinates and types, concat into the input. + + Returns + ------- + coord_spin : Array + Concatenated coordinates with shape (nframes, 2*nloc, 3). + atype_spin : Array + Concatenated atom types with shape (nframes, 2*nloc). + coord_corr : Array + Coordinate correction for virial with shape (nframes, 2*nloc, 3). + """ + xp = array_api_compat.array_namespace(coord) nframes, nloc = coord.shape[:-1] - atype_spin = np.concatenate([atype, atype + self.ntypes_real], axis=-1) - virtual_coord = coord + spin * self.virtual_scale_mask[atype].reshape( - [nframes, nloc, 1] + atype_spin = xp.concat([atype, atype + self.ntypes_real], axis=-1) + vsm = self._to_xp(self.virtual_scale_mask, xp, coord) + spin_dist = spin * xp.reshape(vsm[atype], (nframes, nloc, 1)) + virtual_coord = coord + spin_dist + coord_spin = xp.concat([coord, virtual_coord], axis=-2) + # for spin virial correction + coord_corr = xp.concat( + [ + xp.zeros( + coord.shape, + dtype=coord.dtype, + device=array_api_compat.device(coord), + ), + -spin_dist, + ], + axis=-2, ) - coord_spin = np.concatenate([coord, virtual_coord], axis=-2) - return coord_spin, atype_spin + return coord_spin, atype_spin, coord_corr def process_spin_input_lower( self, @@ -86,26 +113,53 @@ def process_spin_input_lower( extended_spin: Array, nlist: Array, mapping: Array | None = None, - ) -> tuple[Array, Array]: + ) -> tuple[Array, Array, Array, Array | None, Array]: """ Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`. - Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: + + Returns + ------- + extended_coord_updated : Array + Updated coordinates with virtual atoms, shape (nframes, 2*nall, 3). + extended_atype_updated : Array + Updated atom types with virtual atoms, shape (nframes, 2*nall). + nlist_updated : Array + Updated neighbor list including virtual atoms. + mapping_updated : Array or None + Updated mapping indices, or None if input mapping is None. + extended_coord_corr : Array + Coordinate correction for virial with shape (nframes, 2*nall, 3). + + Notes + ----- + The final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order: - [:, :nloc]: original nloc real atoms. - [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms. - [:, nloc + nloc: nloc + nall]: ghost real atoms. - [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms. 
""" + xp = array_api_compat.array_namespace(extended_coord) nframes, nall = extended_coord.shape[:2] nloc = nlist.shape[1] - virtual_extended_coord = ( - extended_coord - + extended_spin - * self.virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) + vsm = self._to_xp(self.virtual_scale_mask, xp, extended_coord) + extended_spin_dist = extended_spin * xp.reshape( + vsm[extended_atype], (nframes, nall, 1) ) + virtual_extended_coord = extended_coord + extended_spin_dist virtual_extended_atype = extended_atype + self.ntypes_real extended_coord_updated = self.concat_switch_virtual( extended_coord, virtual_extended_coord, nloc ) + # for spin virial correction + extended_coord_corr = self.concat_switch_virtual( + xp.zeros( + extended_coord.shape, + dtype=extended_coord.dtype, + device=array_api_compat.device(extended_coord), + ), + -extended_spin_dist, + nloc, + ) extended_atype_updated = self.concat_switch_virtual( extended_atype, virtual_extended_atype, nloc ) @@ -121,6 +175,7 @@ def process_spin_input_lower( extended_atype_updated, nlist_updated, mapping_updated, + extended_coord_corr, ) def process_spin_output( @@ -131,18 +186,20 @@ def process_spin_output( virtual_scale: bool = True, ) -> tuple[Array, Array]: """Split the output both real and virtual atoms, and scale the latter.""" + xp = array_api_compat.array_namespace(out_tensor) nframes, nloc_double = out_tensor.shape[:2] nloc = nloc_double // 2 if virtual_scale: - virtual_scale_mask = self.virtual_scale_mask + mask = self._to_xp(self.virtual_scale_mask, xp, out_tensor) else: - virtual_scale_mask = self.spin_mask - atomic_mask = virtual_scale_mask[atype].reshape([nframes, nloc, 1]) - out_real, out_mag = np.split(out_tensor, [nloc], axis=1) + mask = self._to_xp(self.spin_mask, xp, out_tensor) + atomic_mask = xp.reshape(mask[atype], (nframes, nloc, 1)) + out_real, out_mag = out_tensor[:, :nloc], out_tensor[:, nloc:] if add_mag: out_real = out_real + out_mag - out_mag = (out_mag.reshape([nframes, nloc, -1]) * atomic_mask).reshape( - out_mag.shape + out_mag = xp.reshape( + xp.reshape(out_mag, (nframes, nloc, -1)) * atomic_mask, + out_mag.shape, ) return out_real, out_mag, atomic_mask > 0.0 @@ -155,21 +212,22 @@ def process_spin_output_lower( virtual_scale: bool = True, ) -> tuple[Array, Array]: """Split the extended output of both real and virtual atoms with switch, and scale the latter.""" + xp = array_api_compat.array_namespace(extended_out_tensor) nframes, nall_double = extended_out_tensor.shape[:2] nall = nall_double // 2 if virtual_scale: - virtual_scale_mask = self.virtual_scale_mask + mask = self._to_xp(self.virtual_scale_mask, xp, extended_out_tensor) else: - virtual_scale_mask = self.spin_mask - atomic_mask = virtual_scale_mask[extended_atype].reshape([nframes, nall, 1]) - extended_out_real = np.concatenate( + mask = self._to_xp(self.spin_mask, xp, extended_out_tensor) + atomic_mask = xp.reshape(mask[extended_atype], (nframes, nall, 1)) + extended_out_real = xp.concat( [ extended_out_tensor[:, :nloc], extended_out_tensor[:, nloc + nloc : nloc + nall], ], axis=1, ) - extended_out_mag = np.concatenate( + extended_out_mag = xp.concat( [ extended_out_tensor[:, nloc : nloc + nloc], extended_out_tensor[:, nloc + nall :], @@ -178,75 +236,81 @@ def process_spin_output_lower( ) if add_mag: extended_out_real = extended_out_real + extended_out_mag - extended_out_mag = ( - extended_out_mag.reshape([nframes, nall, -1]) * atomic_mask - ).reshape(extended_out_mag.shape) + extended_out_mag = xp.reshape( + xp.reshape(extended_out_mag, 
(nframes, nall, -1)) * atomic_mask, + extended_out_mag.shape, + ) return extended_out_real, extended_out_mag, atomic_mask > 0.0 @staticmethod def extend_nlist(extended_atype: Array, nlist: Array) -> Array: - nframes, nloc, nnei = nlist.shape + xp = array_api_compat.array_namespace(nlist) + nframes, nloc, _nnei = nlist.shape nall = extended_atype.shape[1] nlist_mask = nlist != -1 - nlist[nlist == -1] = 0 - nlist_shift = nlist + nall - nlist[~nlist_mask] = -1 - nlist_shift[~nlist_mask] = -1 - self_real = ( - np.arange(0, nloc, dtype=nlist.dtype) - .reshape(1, -1, 1) - .repeat(nframes, axis=0) + # Use xp.where instead of in-place boolean indexing + nlist_safe = xp.where(nlist_mask, nlist, xp.zeros_like(nlist)) + nlist_shift = xp.where(nlist_mask, nlist_safe + nall, -1 * xp.ones_like(nlist)) + # Restore nlist with -1 for masked entries (non-mutating) + nlist = xp.where(nlist_mask, nlist, -1 * xp.ones_like(nlist)) + self_real = xp.reshape( + xp.arange( + 0, nloc, dtype=nlist.dtype, device=array_api_compat.device(nlist) + ), + (1, nloc, 1), + ) * xp.ones( + (nframes, 1, 1), dtype=nlist.dtype, device=array_api_compat.device(nlist) ) self_spin = self_real + nall # real atom's neighbors: self spin + real neighbor + virtual neighbor # nf x nloc x (1 + nnei + nnei) - real_nlist = np.concatenate([self_spin, nlist, nlist_shift], axis=-1) + real_nlist = xp.concat([self_spin, nlist, nlist_shift], axis=-1) # spin atom's neighbors: real + real neighbor + virtual neighbor # nf x nloc x (1 + nnei + nnei) - spin_nlist = np.concatenate([self_real, nlist, nlist_shift], axis=-1) + spin_nlist = xp.concat([self_real, nlist, nlist_shift], axis=-1) # nf x (nloc + nloc) x (1 + nnei + nnei) - extended_nlist = np.concatenate([real_nlist, spin_nlist], axis=-2) - # update the index for switch - first_part_index = (nloc <= extended_nlist) & (extended_nlist < nall) - second_part_index = (nall <= extended_nlist) & (extended_nlist < (nall + nloc)) - extended_nlist[first_part_index] += nloc - extended_nlist[second_part_index] -= nall - nloc + extended_nlist = xp.concat([real_nlist, spin_nlist], axis=-2) + # update the index for switch using xp.where instead of in-place ops + first_part_mask = (nloc <= extended_nlist) & (extended_nlist < nall) + second_part_mask = (nall <= extended_nlist) & (extended_nlist < (nall + nloc)) + extended_nlist = xp.where( + first_part_mask, extended_nlist + nloc, extended_nlist + ) + extended_nlist = xp.where( + second_part_mask, extended_nlist - (nall - nloc), extended_nlist + ) return extended_nlist @staticmethod def concat_switch_virtual( extended_tensor: Array, extended_tensor_virtual: Array, nloc: int ) -> Array: - nframes, nall = extended_tensor.shape[:2] - out_shape = list(extended_tensor.shape) - out_shape[1] *= 2 - extended_tensor_updated = np.zeros( - out_shape, - dtype=extended_tensor.dtype, + xp = array_api_compat.array_namespace(extended_tensor) + return xp.concat( + [ + extended_tensor[:, :nloc], + extended_tensor_virtual[:, :nloc], + extended_tensor[:, nloc:], + extended_tensor_virtual[:, nloc:], + ], + axis=1, ) - extended_tensor_updated[:, :nloc] = extended_tensor[:, :nloc] - extended_tensor_updated[:, nloc : nloc + nloc] = extended_tensor_virtual[ - :, :nloc - ] - extended_tensor_updated[:, nloc + nloc : nloc + nall] = extended_tensor[ - :, nloc: - ] - extended_tensor_updated[:, nloc + nall :] = extended_tensor_virtual[:, nloc:] - return extended_tensor_updated.reshape(out_shape) @staticmethod def expand_aparam(aparam: Array, nloc: int) -> Array: """Expand the atom parameters for 
virtual atoms if necessary.""" + xp = array_api_compat.array_namespace(aparam) nframes, natom, numb_aparam = aparam.shape if natom == nloc: # good pass elif natom < nloc: # for spin with virtual atoms - aparam = np.concatenate( + aparam = xp.concat( [ aparam, - np.zeros( + xp.zeros( [nframes, nloc - natom, numb_aparam], dtype=aparam.dtype, + device=array_api_compat.device(aparam), ), ], axis=1, @@ -258,6 +322,64 @@ def expand_aparam(aparam: Array, nloc: int) -> Array: ) return aparam + def compute_or_load_stat( + self, + sampled_func: Callable[[], list[dict[str, Any]]], + stat_file_path: Any | None = None, + preset_observed_type: list[str] | None = None, + ) -> None: + """Compute or load the statistics parameters of the model. + + Parameters + ---------- + sampled_func + The lazy sampled function to get data frames from different data systems. + stat_file_path + The dictionary of paths to the statistics files. + preset_observed_type + The preset observed types. + """ + + @functools.lru_cache + def spin_sampled_func() -> list[dict[str, Any]]: + sampled = sampled_func() + spin_sampled = [] + for sys in sampled: + coord_updated, atype_updated, _ = self.process_spin_input( + sys["coord"], sys["atype"], sys["spin"] + ) + tmp_dict = { + "coord": coord_updated, + "atype": atype_updated, + } + if "aparam" in sys: + tmp_dict["aparam"] = self.expand_aparam( + sys["aparam"], atype_updated.shape[1] + ) + if "natoms" in sys: + natoms = sys["natoms"] + xp = array_api_compat.array_namespace(natoms) + tmp_dict["natoms"] = xp.concat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1 + ) + for item_key in sys: + if item_key not in [ + "coord", + "atype", + "spin", + "natoms", + "aparam", + ]: + tmp_dict[item_key] = sys[item_key] + spin_sampled.append(tmp_dict) + return spin_sampled + + self.backbone_model.compute_or_load_stat( + spin_sampled_func, + stat_file_path, + preset_observed_type=preset_observed_type, + ) + def get_type_map(self) -> list[str]: """Get the type map.""" tmap = self.backbone_model.get_type_map() @@ -479,10 +601,13 @@ def call_common( The keys are defined by the `ModelOutputDef`. 
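+        Notes
+        -----
+        Only the leading ``nloc`` entries of the main output variable are
+        returned; its magnetic counterpart is reported under the
+        ``*_derv_r_mag`` key, and, when the model is c-differentiable, the
+        virial is evaluated with the spin correction supplied through
+        ``coord_corr_for_virial``.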
""" + xp = array_api_compat.array_namespace(coord) nframes, nloc = atype.shape[:2] - coord = coord.reshape(nframes, nloc, 3) - spin = spin.reshape(nframes, nloc, 3) - coord_updated, atype_updated = self.process_spin_input(coord, atype, spin) + coord = xp.reshape(coord, (nframes, nloc, 3)) + spin = xp.reshape(spin, (nframes, nloc, 3)) + coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input( + coord, atype, spin + ) if aparam is not None: aparam = self.expand_aparam(aparam, nloc * 2) model_ret = self.backbone_model.call_common( @@ -492,12 +617,13 @@ def call_common( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + coord_corr_for_virial=coord_corr_for_virial, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: model_output_type.pop(model_output_type.index("mask")) var_name = model_output_type[0] - model_ret[f"{var_name}"] = np.split(model_ret[f"{var_name}"], [nloc], axis=1)[0] + model_ret[f"{var_name}"] = model_ret[f"{var_name}"][:, :nloc] if ( self.backbone_model.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None @@ -524,8 +650,10 @@ def call_common( ) # Always compute mask_mag from atom types (even when forces are unavailable) if "mask_mag" not in model_ret: + xp = array_api_compat.array_namespace(atype) nframes_m, nloc_m = atype.shape[:2] - atomic_mask = self.virtual_scale_mask[atype].reshape([nframes_m, nloc_m, 1]) + vsm = self._to_xp(self.virtual_scale_mask, xp, atype) + atomic_mask = xp.reshape(vsm[atype], (nframes_m, nloc_m, 1)) model_ret["mask_mag"] = atomic_mask > 0.0 return model_ret @@ -591,7 +719,15 @@ def call( ): model_predict["force"] = model_ret[f"{var_name}_derv_r"].squeeze(-2) model_predict["force_mag"] = model_ret[f"{var_name}_derv_r_mag"].squeeze(-2) - # not support virial by far + if ( + self.backbone_model.do_grad_c(var_name) + and model_ret.get(f"{var_name}_derv_c_redu") is not None + ): + model_predict["virial"] = model_ret[f"{var_name}_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret.get(f"{var_name}_derv_c") is not None: + model_predict["atom_virial"] = model_ret[f"{var_name}_derv_c"].squeeze( + -2 + ) return model_predict def call_common_lower( @@ -641,6 +777,7 @@ def call_common_lower( extended_atype_updated, nlist_updated, mapping_updated, + extended_coord_corr, ) = self.process_spin_input_lower( extended_coord, extended_atype, extended_spin, nlist, mapping=mapping ) @@ -654,12 +791,13 @@ def call_common_lower( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extended_coord_corr=extended_coord_corr, ) model_output_type = self.backbone_model.model_output_type() if "mask" in model_output_type: model_output_type.pop(model_output_type.index("mask")) var_name = model_output_type[0] - model_ret[f"{var_name}"] = np.split(model_ret[f"{var_name}"], [nloc], axis=1)[0] + model_ret[f"{var_name}"] = model_ret[f"{var_name}"][:, :nloc] if ( self.backbone_model.do_grad_r(var_name) and model_ret.get(f"{var_name}_derv_r") is not None @@ -689,10 +827,10 @@ def call_common_lower( ) # Always compute mask_mag from atom types (even when forces are unavailable) if "mask_mag" not in model_ret: + xp = array_api_compat.array_namespace(extended_atype) nall = extended_atype.shape[1] - atomic_mask = self.virtual_scale_mask[extended_atype].reshape( - [nframes, nall, 1] - ) + vsm = self._to_xp(self.virtual_scale_mask, xp, extended_atype) + atomic_mask = xp.reshape(vsm[extended_atype], (nframes, nall, 1)) model_ret["mask_mag"] = atomic_mask > 0.0 return model_ret 
@@ -764,7 +902,15 @@ def call_lower( model_predict["extended_force_mag"] = model_ret[ f"{var_name}_derv_r_mag" ].squeeze(-2) - # not support virial by far + if ( + self.backbone_model.do_grad_c(var_name) + and model_ret.get(f"{var_name}_derv_c_redu") is not None + ): + model_predict["virial"] = model_ret[f"{var_name}_derv_c_redu"].squeeze(-2) + if do_atomic_virial and model_ret.get(f"{var_name}_derv_c") is not None: + model_predict["extended_virial"] = model_ret[ + f"{var_name}_derv_c" + ].squeeze(-2) return model_predict def translated_output_def(self) -> dict[str, Any]: @@ -789,4 +935,9 @@ def translated_output_def(self) -> dict[str, Any]: output_def["force"].squeeze(-2) output_def["force_mag"] = deepcopy(out_def_data[f"{var_name}_derv_r_mag"]) output_def["force_mag"].squeeze(-2) + if self.backbone_model.do_grad_c(var_name): + output_def["virial"] = deepcopy(out_def_data[f"{var_name}_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data[f"{var_name}_derv_c"]) + output_def["atom_virial"].squeeze(-2) return output_def diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py index 533181e250..4522e25586 100644 --- a/deepmd/jax/model/base_model.py +++ b/deepmd/jax/model/base_model.py @@ -25,6 +25,7 @@ def forward_common_atomic( fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, + extended_coord_corr: jnp.ndarray | None = None, ) -> dict[str, jnp.ndarray]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -110,6 +111,10 @@ def eval_output( assert vdef.r_differentiable # avr: [nf, *def, nall, 3, 3] avr = jnp.einsum("f...ai,faj->f...aij", ff, extended_coord) + if extended_coord_corr is not None: + avr = avr + jnp.einsum( + "f...ai,faj->f...aij", ff, extended_coord_corr + ) # the correction sums to zero, which does not contribute to global virial if do_atomic_virial: diff --git a/deepmd/jax/model/dp_model.py b/deepmd/jax/model/dp_model.py index 5545b5505b..3e96eb6689 100644 --- a/deepmd/jax/model/dp_model.py +++ b/deepmd/jax/model/dp_model.py @@ -55,6 +55,7 @@ def forward_common_atomic( fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, + extended_coord_corr: jnp.ndarray | None = None, ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, @@ -65,6 +66,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extended_coord_corr=extended_coord_corr, ) def format_nlist( diff --git a/deepmd/jax/model/dp_zbl_model.py b/deepmd/jax/model/dp_zbl_model.py index d9be671a46..7751d22a1f 100644 --- a/deepmd/jax/model/dp_zbl_model.py +++ b/deepmd/jax/model/dp_zbl_model.py @@ -37,6 +37,7 @@ def forward_common_atomic( fparam: jnp.ndarray | None = None, aparam: jnp.ndarray | None = None, do_atomic_virial: bool = False, + extended_coord_corr: jnp.ndarray | None = None, ) -> dict[str, jnp.ndarray]: return forward_common_atomic( self, @@ -47,6 +48,7 @@ def forward_common_atomic( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + extended_coord_corr=extended_coord_corr, ) def format_nlist( diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index cc6404675e..a7b83c0fe0 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -491,13 +491,23 @@ def spin_sampled_func() -> list[dict[str, Any]]: "coord": coord_updated, "atype": atype_updated, } + if "aparam" in sys: + tmp_dict["aparam"] = 
self.expand_aparam( + sys["aparam"], atype_updated.shape[1] + ) if "natoms" in sys: natoms = sys["natoms"] tmp_dict["natoms"] = torch.cat( [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 ) for item_key in sys.keys(): - if item_key not in ["coord", "atype", "spin", "natoms"]: + if item_key not in [ + "coord", + "atype", + "spin", + "natoms", + "aparam", + ]: tmp_dict[item_key] = sys[item_key] spin_sampled.append(tmp_dict) return spin_sampled @@ -671,7 +681,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) output_def["virial"].squeeze(-2) output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) return output_def def forward( @@ -703,7 +713,7 @@ def forward( if self.backbone_model.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) return model_predict @torch.jit.export @@ -744,6 +754,6 @@ def forward_lower( model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( - -3 + -2 ) return model_predict diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py index 7197e39634..c6888aed0b 100644 --- a/deepmd/pt_expt/model/__init__.py +++ b/deepmd/pt_expt/model/__init__.py @@ -27,6 +27,9 @@ from .property_model import ( PropertyModel, ) +from .spin_ener_model import ( + SpinEnergyModel, +) __all__ = [ "BaseModel", @@ -36,6 +39,7 @@ "EnergyModel", "PolarModel", "PropertyModel", + "SpinEnergyModel", "get_model", "make_hessian_model", ] diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 02e67d0f27..bf4f911c0c 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -212,6 +212,7 @@ def forward_common_atomic( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + extended_coord_corr: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: atomic_ret = self.atomic_model.forward_common_atomic( extended_coord, @@ -228,6 +229,7 @@ def forward_common_atomic( do_atomic_virial=do_atomic_virial, create_graph=self.training, mask=atomic_ret.get("mask"), + extended_coord_corr=extended_coord_corr, ) # Hessian computation (mirrors JAX's forward_common_atomic). 
# Produces hessian on extended coords [nf, *def, nall, 3, nall, 3], diff --git a/deepmd/pt_expt/model/spin_ener_model.py b/deepmd/pt_expt/model/spin_ener_model.py new file mode 100644 index 0000000000..e96d0fbaf1 --- /dev/null +++ b/deepmd/pt_expt/model/spin_ener_model.py @@ -0,0 +1,206 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) +from typing import ( + Any, +) + +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from .spin_model import ( + SpinModel, +) + + +class SpinEnergyModel(SpinModel): + """A spin model for energy.""" + + model_type = "ener" + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + "mask_mag": out_def_data["mask_mag"], + } + if self.do_grad_r("energy"): + output_def["force"] = deepcopy(out_def_data["energy_derv_r"]) + output_def["force"].squeeze(-2) + output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"]) + output_def["force_mag"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"]) + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"]) + output_def["atom_virial"].squeeze(-2) + return output_def + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + spin: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.call_common( + coord, + atype, + spin, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2) + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) + return model_predict + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.call_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["extended_mask_mag"] = model_ret["mask_mag"] + if self.backbone_model.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force_mag"] = model_ret[ + "energy_derv_r_mag" + ].squeeze(-2) + if self.backbone_model.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -2 + ) + return 
model_predict + + def forward_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + **make_fx_kwargs: Any, + ) -> torch.nn.Module: + """Trace ``forward_lower`` into an exportable module. + + Delegates to ``forward_common_lower_exportable`` for tracing, + then translates the internal keys to the ``forward_lower`` + convention. + + Parameters + ---------- + extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam)`` + and returns a dict with the same keys as ``forward_lower``. + """ + traced = self.forward_common_lower_exportable( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + **make_fx_kwargs, + ) + + # Translate internal keys to forward_lower convention. + # Capture model config at trace time via closures. + do_grad_r = self.backbone_model.do_grad_r("energy") + do_grad_c = self.backbone_model.do_grad_c("energy") + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + model_ret = traced( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) + model_predict: dict[str, torch.Tensor] = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + model_predict["extended_mask_mag"] = model_ret["mask_mag"] + if do_grad_r: + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + model_predict["extended_force_mag"] = model_ret[ + "energy_derv_r_mag" + ].squeeze(-2) + if do_grad_c: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-2) + return model_predict + + return make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) diff --git a/deepmd/pt_expt/model/spin_model.py b/deepmd/pt_expt/model/spin_model.py new file mode 100644 index 0000000000..259ff10698 --- /dev/null +++ b/deepmd/pt_expt/model/spin_model.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.model.spin_model import SpinModel as SpinModelDP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.utils.spin import ( + Spin, +) + +from .make_model import ( + make_model, +) +from .model import ( + BaseModel, +) + + +@torch_module +class SpinModel(SpinModelDP): + def __getattr__(self, name: str) -> Any: + """Get attribute from the wrapped model. + + In torch.nn.Module, submodules are stored in _modules, not __dict__. 
+ Override the dpmodel version to use torch.nn.Module's __getattr__ + first (which checks _parameters, _buffers, _modules), then fall + back to backbone_model delegation for arbitrary attributes. + """ + try: + return torch.nn.Module.__getattr__(self, name) + except AttributeError: + pass + # backbone_model is in _modules, access via _modules directly + # to avoid re-entering __getattr__ + modules = self.__dict__.get("_modules", {}) + backbone = modules.get("backbone_model") + if backbone is not None: + return getattr(backbone, name) + raise AttributeError(name) + + def forward_common_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + **make_fx_kwargs: Any, + ) -> torch.nn.Module: + """Trace ``call_common_lower`` into an exportable module. + + Uses ``make_fx`` to trace through ``torch.autograd.grad``, + decomposing the backward pass into primitive ops. The returned + module can be passed directly to ``torch.export.export``. + + The output uses internal key names (e.g. ``energy``, + ``energy_redu``, ``energy_derv_r``) so that subclasses can + apply their own key translation on top. + + Parameters + ---------- + extended_coord, extended_atype, extended_spin, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, extended_spin, nlist, + mapping, fparam, aparam)`` and returns a dict with the same + keys as ``call_common_lower``. 
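+        Examples
+        --------
+        A sketch of the intended flow, assuming sample tensors are at hand::
+
+            traced = model.forward_common_lower_exportable(
+                extended_coord, extended_atype, extended_spin, nlist,
+                mapping, fparam=fparam, aparam=aparam,
+                tracing_mode="symbolic",
+            )
+            ep = torch.export.export(
+                traced,
+                (extended_coord, extended_atype, extended_spin, nlist,
+                 mapping, fparam, aparam),
+            )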
+ """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_spin: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model.forward_common_lower( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + extended_spin, + nlist, + mapping, + fparam, + aparam, + ) + + def forward_common_lower( + self, *args: Any, **kwargs: Any + ) -> dict[str, torch.Tensor]: + """Forward common lower delegates to call_common_lower().""" + return self.call_common_lower(*args, **kwargs) + + @classmethod + def deserialize(cls, data: dict) -> "SpinModel": + from deepmd.dpmodel.atomic_model import ( + DPEnergyAtomicModel, + ) + + backbone_model_obj = make_model( + DPEnergyAtomicModel, T_Bases=(BaseModel,) + ).deserialize(data["backbone_model"]) + spin = Spin.deserialize(data["spin"]) + return cls( + backbone_model=backbone_model_obj, + spin=spin, + ) diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py index 265a5a268d..381eb86831 100644 --- a/deepmd/pt_expt/model/transform_output.py +++ b/deepmd/pt_expt/model/transform_output.py @@ -140,6 +140,7 @@ def fit_output_to_model_output( do_atomic_virial: bool = False, create_graph: bool = True, mask: torch.Tensor | None = None, + extended_coord_corr: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: """Transform the output of the fitting network to the model output. @@ -176,6 +177,12 @@ def fit_output_to_model_output( model_ret[kk_derv_r] = dr if vdef.c_differentiable: assert dc is not None + if extended_coord_corr is not None: + dc_corr = ( + dr.squeeze(-2).unsqueeze(-1) + @ extended_coord_corr.unsqueeze(-2).to(dr.dtype) + ).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005 + dc = dc + dc_corr model_ret[kk_derv_c] = dc model_ret[kk_derv_c + "_redu"] = torch.sum( model_ret[kk_derv_c].to(redu_prec), dim=1 diff --git a/source/tests/consistent/model/test_spin_ener.py b/source/tests/consistent/model/test_spin_ener.py index f6019732cf..90adb34b8f 100644 --- a/source/tests/consistent/model/test_spin_ener.py +++ b/source/tests/consistent/model/test_spin_ener.py @@ -7,6 +7,9 @@ import numpy as np +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.model.model import get_model as get_model_dp from deepmd.dpmodel.model.spin_model import SpinModel as SpinModelDP from deepmd.dpmodel.utils.nlist import ( @@ -21,11 +24,14 @@ ) from ..common import ( + INSTALLED_ARRAY_API_STRICT, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, ) from .common import ( ModelTest, + compare_variables_recursive, ) if INSTALLED_PT: @@ -35,6 +41,15 @@ from deepmd.pt.utils.utils import to_torch_tensor as numpy_to_torch else: SpinEnergyModelPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.common import to_torch_array as pt_expt_numpy_to_torch + from deepmd.pt_expt.model.spin_ener_model import ( + SpinEnergyModel as SpinEnergyModelPTExpt, + ) +else: + SpinEnergyModelPTExpt = None +if INSTALLED_ARRAY_API_STRICT: + import array_api_strict from deepmd.utils.argcheck import ( model_args, @@ -65,6 +80,8 @@ "resnet_dt": True, "precision": "float64", "seed": 1, + "numb_fparam": 2, + "numb_aparam": 3, }, "spin": { "use_spin": [True, False, 
False], @@ -82,15 +99,18 @@ def data(self) -> dict: dp_class = SpinModelDP pt_class = SpinEnergyModelPT pd_class = None - pt_expt_class = None + pt_expt_class = SpinEnergyModelPTExpt jax_class = None args = model_args() skip_tf = True skip_jax = True - skip_pt_expt = True skip_pd = True + @property + def skip_pt_expt(self): + return not INSTALLED_PT_EXPT + def get_reference_backend(self): """Get the reference backend. @@ -98,6 +118,8 @@ def get_reference_backend(self): """ if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt: + return self.RefBackend.PT_EXPT if not self.skip_dp: return self.RefBackend.DP raise ValueError("No available reference") @@ -109,6 +131,9 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is SpinEnergyModelPT: return get_model_pt(data) + elif cls is SpinEnergyModelPTExpt: + dp_model = get_model_dp(data) + return SpinEnergyModelPTExpt.deserialize(dp_model.serialize()) return cls(**data, **self.additional_data) def setUp(self) -> None: @@ -167,6 +192,17 @@ def setUp(self) -> None: ], dtype=GLOBAL_NP_FLOAT_PRECISION, ).reshape(1, -1, 3) + nframes = 1 + nloc = 6 + numb_fparam = 2 + numb_aparam = 3 + rng = np.random.default_rng(42) + self.fparam = rng.normal(size=(nframes, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + self.aparam = rng.normal(size=(nframes, nloc, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: raise NotImplementedError("no TF in this test") @@ -177,6 +213,8 @@ def eval_dp(self, dp_obj: Any) -> Any: self.atype, self.spin, box=self.box, + fparam=self.fparam, + aparam=self.aparam, ) def eval_pt(self, pt_obj: Any) -> Any: @@ -187,6 +225,23 @@ def eval_pt(self, pt_obj: Any) -> Any: numpy_to_torch(self.atype), numpy_to_torch(self.spin), box=numpy_to_torch(self.box), + fparam=numpy_to_torch(self.fparam), + aparam=numpy_to_torch(self.aparam), + ).items() + } + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + coord_tensor = pt_expt_numpy_to_torch(self.coords) + coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() + for kk, vv in pt_expt_obj( + coord_tensor, + pt_expt_numpy_to_torch(self.atype), + pt_expt_numpy_to_torch(self.spin), + box=pt_expt_numpy_to_torch(self.box), + fparam=pt_expt_numpy_to_torch(self.fparam), + aparam=pt_expt_numpy_to_torch(self.aparam), ).items() } @@ -203,6 +258,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["mask_mag"].ravel(), SKIP_FLAG, SKIP_FLAG, + SKIP_FLAG, ) elif backend is self.RefBackend.PT: return ( @@ -211,6 +267,16 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["mask_mag"].ravel(), ret["force"].ravel(), ret["force_mag"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["mask_mag"].ravel(), + ret["force"].ravel(), + ret["force_mag"].ravel(), + ret["virial"].ravel(), ) raise ValueError(f"Unknown backend: {backend}") @@ -224,14 +290,21 @@ def data(self) -> dict: dp_class = SpinModelDP pt_class = SpinEnergyModelPT pd_class = None - pt_expt_class = None + pt_expt_class = SpinEnergyModelPTExpt jax_class = None + array_api_strict_class = SpinModelDP args = model_args() skip_tf = True skip_jax = True - skip_pt_expt = True skip_pd = True + # The backbone model (make_model) is not yet array_api_strict compatible + # at the full model call_lower level (indexing issues in descriptor/fitting). 
+ skip_array_api_strict = True + + @property + def skip_pt_expt(self): + return not INSTALLED_PT_EXPT def get_reference_backend(self): """Get the reference backend. @@ -240,6 +313,8 @@ def get_reference_backend(self): """ if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt: + return self.RefBackend.PT_EXPT if not self.skip_dp: return self.RefBackend.DP raise ValueError("No available reference") @@ -251,6 +326,9 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is SpinEnergyModelPT: return get_model_pt(data) + elif cls is SpinEnergyModelPTExpt: + dp_model = get_model_dp(data) + return SpinEnergyModelPTExpt.deserialize(dp_model.serialize()) return cls(**data, **self.additional_data) def setUp(self) -> None: @@ -339,6 +417,18 @@ def setUp(self) -> None: axis=1, ) + nframes = 1 + nloc = 6 + numb_fparam = 2 + numb_aparam = 3 + rng = np.random.default_rng(42) + self.fparam = rng.normal(size=(nframes, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + self.aparam = rng.normal(size=(nframes, nloc, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: raise NotImplementedError("no TF in this test") @@ -349,6 +439,8 @@ def eval_dp(self, dp_obj: Any) -> Any: self.extended_spin, self.nlist, self.mapping, + fparam=self.fparam, + aparam=self.aparam, ) def eval_pt(self, pt_obj: Any) -> Any: @@ -360,6 +452,38 @@ def eval_pt(self, pt_obj: Any) -> Any: numpy_to_torch(self.extended_spin), numpy_to_torch(self.nlist), numpy_to_torch(self.mapping), + fparam=numpy_to_torch(self.fparam), + aparam=numpy_to_torch(self.aparam), + ).items() + } + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + extended_coord_tensor = pt_expt_numpy_to_torch(self.extended_coord) + extended_coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() + for kk, vv in pt_expt_obj.forward_lower( + extended_coord_tensor, + pt_expt_numpy_to_torch(self.extended_atype), + pt_expt_numpy_to_torch(self.extended_spin), + pt_expt_numpy_to_torch(self.nlist), + pt_expt_numpy_to_torch(self.mapping), + fparam=pt_expt_numpy_to_torch(self.fparam), + aparam=pt_expt_numpy_to_torch(self.aparam), + ).items() + } + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + return { + kk: to_numpy_array(vv) if hasattr(vv, "__array_namespace__") else vv + for kk, vv in array_api_strict_obj.call_lower( + array_api_strict.asarray(self.extended_coord), + array_api_strict.asarray(self.extended_atype), + array_api_strict.asarray(self.extended_spin), + array_api_strict.asarray(self.nlist), + array_api_strict.asarray(self.mapping), + fparam=array_api_strict.asarray(self.fparam), + aparam=array_api_strict.asarray(self.aparam), ).items() } @@ -376,6 +500,7 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["extended_mask_mag"].ravel(), SKIP_FLAG, SKIP_FLAG, + SKIP_FLAG, ) elif backend is self.RefBackend.PT: return ( @@ -384,5 +509,462 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["extended_mask_mag"].ravel(), ret["extended_force"].ravel(), ret["extended_force_mag"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["extended_mask_mag"].ravel(), + ret["extended_force"].ravel(), + ret["extended_force_mag"].ravel(), + ret["virial"].ravel(), + ) + elif backend is self.RefBackend.ARRAY_API_STRICT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + 
ret["extended_mask_mag"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + SKIP_FLAG, ) raise ValueError(f"Unknown backend: {backend}") + + +@unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") +class TestSpinEnerComputeOrLoadStat(unittest.TestCase): + """Test that compute_or_load_stat produces identical statistics on dp, pt, and pt_expt + for spin models. + """ + + def setUp(self) -> None: + data = model_args().normalize_value( + copy.deepcopy(SPIN_DATA), + trim_pattern="_*", + ) + + self._model_data = data + + # Build dp model, then deserialize into pt and pt_expt to share weights + self.dp_model = get_model_dp(data) + serialized = self.dp_model.serialize() + self.pt_model = SpinEnergyModelPT.deserialize(serialized) + self.pt_expt_model = SpinEnergyModelPTExpt.deserialize(serialized) + + # Test coords / atype / box for forward evaluation + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 0, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.spin = np.array( + [ + 0.50, + 0.30, + 0.20, + 0.40, + 0.25, + 0.15, + 0.10, + 0.05, + 0.08, + 0.12, + 0.07, + 0.09, + 0.45, + 0.35, + 0.28, + 0.11, + 0.06, + 0.03, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + + # Mock training data for compute_or_load_stat + natoms = 6 + nframes = 3 + numb_fparam = 2 + numb_aparam = 3 + rng = np.random.default_rng(42) + coords_stat = rng.normal(size=(nframes, natoms, 3)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + atype_stat = np.array([[0, 0, 1, 0, 1, 1]] * nframes, dtype=np.int32) + box_stat = np.tile( + np.eye(3, dtype=GLOBAL_NP_FLOAT_PRECISION).reshape(1, 3, 3) * 13.0, + (nframes, 1, 1), + ) + # natoms: [total, real, type0_count, type1_count, type2_count] + natoms_stat = np.array([[natoms, natoms, 3, 3, 0]] * nframes, dtype=np.int32) + energy_stat = rng.normal(size=(nframes, 1)).astype(GLOBAL_NP_FLOAT_PRECISION) + spin_stat = rng.normal(size=(nframes, natoms, 3)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + fparam_stat = rng.normal(size=(nframes, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + aparam_stat = rng.normal(size=(nframes, natoms, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + + # dp sample (numpy) + np_sample = { + "coord": coords_stat, + "atype": atype_stat, + "atype_ext": atype_stat, + "box": box_stat, + "natoms": natoms_stat, + "energy": energy_stat, + "find_energy": np.float32(1.0), + "spin": spin_stat, + "fparam": fparam_stat, + "find_fparam": np.float32(1.0), + "aparam": aparam_stat, + "find_aparam": np.float32(1.0), + } + # pt / pt_expt sample (torch tensors) + pt_sample = { + "coord": numpy_to_torch(coords_stat), + "atype": numpy_to_torch(atype_stat), + "atype_ext": numpy_to_torch(atype_stat), + "box": numpy_to_torch(box_stat), + "natoms": numpy_to_torch(natoms_stat), + "energy": numpy_to_torch(energy_stat), + "find_energy": np.float32(1.0), + "spin": numpy_to_torch(spin_stat), + "fparam": numpy_to_torch(fparam_stat), + "find_fparam": np.float32(1.0), + "aparam": numpy_to_torch(aparam_stat), + "find_aparam": np.float32(1.0), + } + + self.np_sampled = [np_sample] + self.pt_sampled = [pt_sample] + + # fparam/aparam for eval + nframes_eval = 1 + nloc_eval = 6 + self.fparam = rng.normal(size=(nframes_eval, 
numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + self.aparam = rng.normal(size=(nframes_eval, nloc_eval, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ) + + def _eval_dp(self) -> dict: + return self.dp_model( + self.coords, + self.atype, + self.spin, + box=self.box, + fparam=self.fparam, + aparam=self.aparam, + ) + + def _eval_pt(self) -> dict: + return { + kk: torch_to_numpy(vv) + for kk, vv in self.pt_model( + numpy_to_torch(self.coords), + numpy_to_torch(self.atype), + numpy_to_torch(self.spin), + box=numpy_to_torch(self.box), + fparam=numpy_to_torch(self.fparam), + aparam=numpy_to_torch(self.aparam), + ).items() + } + + def _eval_pt_expt(self) -> dict: + coord_t = pt_expt_numpy_to_torch(self.coords) + coord_t.requires_grad_(True) + return { + k: v.detach().cpu().numpy() + for k, v in self.pt_expt_model( + coord_t, + pt_expt_numpy_to_torch(self.atype), + pt_expt_numpy_to_torch(self.spin), + box=pt_expt_numpy_to_torch(self.box), + fparam=pt_expt_numpy_to_torch(self.fparam), + aparam=pt_expt_numpy_to_torch(self.aparam), + ).items() + } + + def test_compute_stat(self) -> None: + # 1. Pre-stat forward consistency + dp_ret0 = self._eval_dp() + pt_ret0 = self._eval_pt() + pe_ret0 = self._eval_pt_expt() + for key in ("energy", "atom_energy"): + np.testing.assert_allclose( + dp_ret0[key], + pt_ret0[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"Pre-stat dp vs pt mismatch in {key}", + ) + np.testing.assert_allclose( + dp_ret0[key], + pe_ret0[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"Pre-stat dp vs pt_expt mismatch in {key}", + ) + + # 2. Run compute_or_load_stat on all three backends + from copy import ( + deepcopy, + ) + + self.dp_model.compute_or_load_stat(lambda: deepcopy(self.np_sampled)) + self.pt_model.compute_or_load_stat(lambda: deepcopy(self.pt_sampled)) + self.pt_expt_model.compute_or_load_stat(lambda: deepcopy(self.pt_sampled)) + + # 3. Serialize all three and compare @variables + dp_ser = self.dp_model.serialize() + pt_ser = self.pt_model.serialize() + pe_ser = self.pt_expt_model.serialize() + compare_variables_recursive(dp_ser, pt_ser) + compare_variables_recursive(dp_ser, pe_ser) + + # 4. Post-stat forward consistency + dp_ret1 = self._eval_dp() + pt_ret1 = self._eval_pt() + pe_ret1 = self._eval_pt_expt() + for key in ("energy", "atom_energy"): + np.testing.assert_allclose( + dp_ret1[key], + pt_ret1[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"Post-stat dp vs pt mismatch in {key}", + ) + np.testing.assert_allclose( + dp_ret1[key], + pe_ret1[key], + rtol=1e-10, + atol=1e-10, + err_msg=f"Post-stat dp vs pt_expt mismatch in {key}", + ) + + def test_load_stat_from_file(self) -> None: + import tempfile + from pathlib import ( + Path, + ) + + import h5py + + from deepmd.utils.path import ( + DPPath, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + # Create separate stat files for each backend + dp_h5 = str((Path(tmpdir) / "dp_stat.h5").resolve()) + pt_h5 = str((Path(tmpdir) / "pt_stat.h5").resolve()) + pe_h5 = str((Path(tmpdir) / "pe_stat.h5").resolve()) + for p in (dp_h5, pt_h5, pe_h5): + with h5py.File(p, "w"): + pass + + # 1. 
Compute stats and save to file + self.dp_model.compute_or_load_stat( + lambda: self.np_sampled, stat_file_path=DPPath(dp_h5, "a") + ) + self.pt_model.compute_or_load_stat( + lambda: self.pt_sampled, stat_file_path=DPPath(pt_h5, "a") + ) + self.pt_expt_model.compute_or_load_stat( + lambda: self.pt_sampled, stat_file_path=DPPath(pe_h5, "a") + ) + + # Save the computed serializations as reference + dp_ser_computed = self.dp_model.serialize() + pt_ser_computed = self.pt_model.serialize() + pe_ser_computed = self.pt_expt_model.serialize() + + # 2. Build fresh models from the same initial weights + dp_model2 = get_model_dp(self._model_data) + pt_model2 = SpinEnergyModelPT.deserialize(dp_model2.serialize()) + pe_model2 = SpinEnergyModelPTExpt.deserialize(dp_model2.serialize()) + + # 3. Load stats from file (should NOT call the sampled func) + def raise_error(): + raise RuntimeError("Should load from file, not recompute") + + dp_model2.compute_or_load_stat( + raise_error, stat_file_path=DPPath(dp_h5, "a") + ) + pt_model2.compute_or_load_stat( + raise_error, stat_file_path=DPPath(pt_h5, "a") + ) + pe_model2.compute_or_load_stat( + raise_error, stat_file_path=DPPath(pe_h5, "a") + ) + + # 4. Loaded models should match the computed ones + dp_ser_loaded = dp_model2.serialize() + pt_ser_loaded = pt_model2.serialize() + pe_ser_loaded = pe_model2.serialize() + compare_variables_recursive(dp_ser_computed, dp_ser_loaded) + compare_variables_recursive(pt_ser_computed, pt_ser_loaded) + compare_variables_recursive(pe_ser_computed, pe_ser_loaded) + + # 5. Cross-backend consistency after loading + compare_variables_recursive(dp_ser_loaded, pt_ser_loaded) + compare_variables_recursive(dp_ser_loaded, pe_ser_loaded) + + +@unittest.skipUnless(INSTALLED_PT and INSTALLED_PT_EXPT, "PT and PT_EXPT are required") +class TestSpinEnerModelAPIs(unittest.TestCase): + """Test consistency of model-level APIs between dp, pt, and pt_expt for spin models. + + Both models are constructed from the same serialized weights + (dpmodel -> serialize -> pt/pt_expt deserialize) so that numerical outputs + can be compared directly. 
+ """ + + def setUp(self) -> None: + data = model_args().normalize_value( + copy.deepcopy(SPIN_DATA), + trim_pattern="_*", + ) + self.dp_model = get_model_dp(data) + serialized = self.dp_model.serialize() + self.pt_model = SpinEnergyModelPT.deserialize(serialized) + self.pt_expt_model = SpinEnergyModelPTExpt.deserialize(serialized) + + def test_translated_output_def(self) -> None: + """translated_output_def should return the same keys on dp, pt, and pt_expt.""" + dp_def = self.dp_model.translated_output_def() + pt_def = self.pt_model.translated_output_def() + pt_expt_def = self.pt_expt_model.translated_output_def() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + self.assertEqual(set(dp_def.keys()), set(pt_expt_def.keys())) + for key in dp_def: + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + self.assertEqual(dp_def[key].shape, pt_expt_def[key].shape) + + def test_model_output_def(self) -> None: + """model_output_def should return the same keys and shapes on dp and pt.""" + dp_def = self.dp_model.model_output_def().get_data() + pt_def = self.pt_model.model_output_def().get_data() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def: + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_model_output_type(self) -> None: + """model_output_type should return the same list on dp and pt.""" + self.assertEqual( + self.dp_model.model_output_type(), + self.pt_model.model_output_type(), + ) + + def test_do_grad_r(self) -> None: + """do_grad_r should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_r("energy"), + self.pt_model.do_grad_r("energy"), + ) + self.assertTrue(self.dp_model.do_grad_r("energy")) + + def test_do_grad_c(self) -> None: + """do_grad_c should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_c("energy"), + self.pt_model.do_grad_c("energy"), + ) + self.assertTrue(self.dp_model.do_grad_c("energy")) + + def test_get_rcut(self) -> None: + """get_rcut should return the same value on dp, pt, and pt_expt.""" + self.assertEqual(self.dp_model.get_rcut(), self.pt_model.get_rcut()) + self.assertEqual(self.dp_model.get_rcut(), self.pt_expt_model.get_rcut()) + self.assertAlmostEqual(self.dp_model.get_rcut(), 4.0) + + def test_get_type_map(self) -> None: + """get_type_map should return the same list on dp, pt, and pt_expt.""" + self.assertEqual(self.dp_model.get_type_map(), self.pt_model.get_type_map()) + self.assertEqual( + self.dp_model.get_type_map(), self.pt_expt_model.get_type_map() + ) + self.assertEqual(self.dp_model.get_type_map(), ["O", "H", "B"]) + + def test_get_ntypes(self) -> None: + """get_ntypes should return the same value on dp, pt, and pt_expt.""" + self.assertEqual(self.dp_model.get_ntypes(), self.pt_model.get_ntypes()) + self.assertEqual(self.dp_model.get_ntypes(), self.pt_expt_model.get_ntypes()) + self.assertEqual(self.dp_model.get_ntypes(), 3) + + def test_get_sel_type(self) -> None: + """get_sel_type should return the same list on dp and pt.""" + self.assertEqual(self.dp_model.get_sel_type(), self.pt_model.get_sel_type()) + + def test_get_dim_fparam(self) -> None: + """get_dim_fparam should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_dim_fparam(), self.pt_model.get_dim_fparam()) + self.assertEqual(self.dp_model.get_dim_fparam(), 2) + + def test_get_dim_aparam(self) -> None: + """get_dim_aparam should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_dim_aparam(), 
self.pt_model.get_dim_aparam()) + self.assertEqual(self.dp_model.get_dim_aparam(), 3) + + def test_get_nnei(self) -> None: + """get_nnei should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_nnei(), self.pt_model.get_nnei()) + + def test_get_nsel(self) -> None: + """get_nsel should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_nsel(), self.pt_model.get_nsel()) + + def test_is_aparam_nall(self) -> None: + """is_aparam_nall should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.is_aparam_nall(), self.pt_model.is_aparam_nall()) + + def test_has_spin(self) -> None: + """has_spin should return True on all backends.""" + self.assertTrue(self.dp_model.has_spin()) + self.assertTrue(self.pt_model.has_spin()) + self.assertTrue(self.pt_expt_model.has_spin()) + + def test_get_model_def_script(self) -> None: + """get_model_def_script should return the same value on dp, pt, and pt_expt.""" + dp_val = self.dp_model.get_model_def_script() + pt_val = self.pt_model.get_model_def_script() + pe_val = self.pt_expt_model.get_model_def_script() + self.assertEqual(dp_val, pt_val) + self.assertEqual(dp_val, pe_val) + + def test_get_min_nbor_dist(self) -> None: + """get_min_nbor_dist should return the same value on dp, pt, and pt_expt.""" + dp_val = self.dp_model.get_min_nbor_dist() + pt_val = self.pt_model.get_min_nbor_dist() + pe_val = self.pt_expt_model.get_min_nbor_dist() + self.assertEqual(dp_val, pt_val) + self.assertEqual(dp_val, pe_val) diff --git a/source/tests/pt_expt/model/test_spin_ener_model.py b/source/tests/pt_expt/model/test_spin_ener_model.py new file mode 100644 index 0000000000..8789f944e9 --- /dev/null +++ b/source/tests/pt_expt/model/test_spin_ener_model.py @@ -0,0 +1,481 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.pt_expt.model.spin_ener_model import ( + SpinEnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + +dtype = torch.float64 + +SPIN_DATA = { + "type_map": ["O", "H", "B"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + "spin": { + "use_spin": [True, False, False], + "virtual_scale": [0.3140], + }, +} + + +def finite_difference(f, x, delta=1e-6): + in_shape = x.shape + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape) + for idx in np.ndindex(*in_shape): + diff = np.zeros(in_shape) + diff[idx] += delta + y1p = f(x + diff) + y1n = f(x - diff) + res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) + return res + + +def stretch_box(old_coord, old_box, new_box): + ocoord = old_coord.reshape(-1, 3) + obox = old_box.reshape(3, 3) + nbox = new_box.reshape(3, 3) + ncoord = ocoord @ np.linalg.inv(obox) @ nbox + return ncoord.reshape(old_coord.shape) + + +def _make_model(): + dp_model = get_model_dp(SPIN_DATA) + model = SpinEnergyModel.deserialize(dp_model.serialize()).to(env.DEVICE) + model.eval() + return model + + +def eval_model(model, coord, cell, atype, spin): + """Evaluate the pt_expt SpinEnergyModel.""" + nframes = coord.shape[0] + if len(atype.shape) == 1: + atype = atype.unsqueeze(0).expand(nframes, -1) + 
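# Inputs are promoted to float64 on env.DEVICE; requires_grad on the +    # coordinates enables the autograd-derived force/virial outputs. +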
coord_input = coord.to(dtype=dtype, device=env.DEVICE) + cell_input = cell.reshape(nframes, 9).to(dtype=dtype, device=env.DEVICE) + atype_input = atype.to(dtype=torch.long, device=env.DEVICE) + spin_input = spin.to(dtype=dtype, device=env.DEVICE) + coord_input.requires_grad_(True) + result = model(coord_input, atype_input, spin_input, cell_input) + return result + + +class TestSpinEnerModelOutputKeys(unittest.TestCase): + def test_output_keys(self) -> None: + """Test that SpinEnergyModel produces expected output keys.""" + model = _make_model() + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([6, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = torch.rand([6, 3], dtype=dtype, device="cpu", generator=generator) * 0.5 + + result = eval_model( + model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spin.unsqueeze(0), + ) + self.assertIn("energy", result) + self.assertIn("atom_energy", result) + self.assertIn("force", result) + self.assertIn("force_mag", result) + self.assertIn("mask_mag", result) + self.assertIn("virial", result) + + def test_output_shapes(self) -> None: + """Test that output shapes are correct.""" + model = _make_model() + natoms = 6 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = ( + torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + * 0.5 + ) + + result = eval_model( + model, + coord.unsqueeze(0), + cell.unsqueeze(0), + atype, + spin.unsqueeze(0), + ) + self.assertEqual(result["energy"].shape, (1, 1)) + self.assertEqual(result["atom_energy"].shape, (1, natoms, 1)) + self.assertEqual(result["force"].shape, (1, natoms, 3)) + self.assertEqual(result["force_mag"].shape, (1, natoms, 3)) + self.assertEqual(result["mask_mag"].shape, (1, natoms, 1)) + self.assertEqual(result["virial"].shape, (1, 9)) + + +class TestSpinEnerModelSerialize(unittest.TestCase): + def test_serialize_deserialize(self) -> None: + """Test serialize/deserialize round-trip.""" + model = _make_model() + serialized = model.serialize() + model2 = SpinEnergyModel.deserialize(serialized).to(env.DEVICE) + model2.eval() + + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([6, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = torch.rand([6, 3], dtype=dtype, device="cpu", generator=generator) * 0.5 + + ret1 = eval_model( + model, coord.unsqueeze(0), cell.unsqueeze(0), atype, spin.unsqueeze(0) + ) + ret2 = eval_model( + model2, coord.unsqueeze(0), cell.unsqueeze(0), atype, spin.unsqueeze(0) + ) + + for key in ["energy", "atom_energy", "force", "force_mag", "mask_mag"]: + np.testing.assert_allclose( + ret1[key].detach().cpu().numpy(), + ret2[key].detach().cpu().numpy(), + rtol=1e-10, + 
atol=1e-10, + err_msg=f"Mismatch in {key} after round-trip", + ) + + +class TestSpinEnerModelDPConsistency(unittest.TestCase): + def test_dp_consistency(self) -> None: + """Test numerical consistency with dpmodel.""" + dp_model = get_model_dp(SPIN_DATA) + pt_model = SpinEnergyModel.deserialize(dp_model.serialize()).to(env.DEVICE) + pt_model.eval() + + coords_np = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 0.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=np.float64, + ).reshape(1, -1, 3) + atype_np = np.array([[0, 0, 1, 0, 1, 1]], dtype=np.int32) + box_np = np.array( + [13.0, 0, 0, 0, 13.0, 0, 0, 0, 13.0], dtype=np.float64 + ).reshape(1, 9) + spin_np = np.array( + [ + 0.50, + 0.30, + 0.20, + 0.40, + 0.25, + 0.15, + 0.10, + 0.05, + 0.08, + 0.12, + 0.07, + 0.09, + 0.45, + 0.35, + 0.28, + 0.11, + 0.06, + 0.03, + ], + dtype=np.float64, + ).reshape(1, -1, 3) + + dp_ret = dp_model(coords_np, atype_np, spin_np, box=box_np) + + pt_ret = eval_model( + pt_model, + torch.tensor(coords_np), + torch.tensor(box_np), + torch.tensor(atype_np.squeeze(), dtype=torch.int64), + torch.tensor(spin_np), + ) + + np.testing.assert_allclose( + dp_ret["energy"], + pt_ret["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + dp_ret["atom_energy"], + pt_ret["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + dp_ret["mask_mag"], + pt_ret["mask_mag"].detach().cpu().numpy(), + ) + + +class ForceTest: + def test(self) -> None: + places = 5 + delta = 1e-5 + natoms = 6 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = ( + torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + * 0.5 + ) + coord = coord.numpy() + + def np_infer_coord(coord): + result = eval_model( + self.model, + torch.tensor(coord, device=env.DEVICE).unsqueeze(0), + cell.unsqueeze(0), + atype, + spin.unsqueeze(0), + ) + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "force_mag", "virial"] + } + return ret + + def ff_coord(_coord): + return np_infer_coord(_coord)["energy"] + + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + + +class VirialTest: + def test(self) -> None: + places = 5 + delta = 1e-4 + natoms = 6 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = ( + torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + * 0.5 + ) + coord = coord.numpy() + cell = cell.numpy() + + def np_infer(new_cell): + result = eval_model( + self.model, + torch.tensor( + stretch_box(coord, cell, new_cell), device="cpu" + ).unsqueeze(0), + torch.tensor(new_cell, device="cpu").unsqueeze(0), + atype, + spin.unsqueeze(0), + ) + ret = { + key: 
result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "virial"] + } + return ret + + def ff(bb): + return np_infer(bb)["energy"] + + fdv = ( + -(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell) + .squeeze() + .reshape(9) + ) + rfv = np_infer(cell)["virial"] + np.testing.assert_almost_equal(fdv, rfv, decimal=places) + + +class TestSpinEnerModelForce(unittest.TestCase, ForceTest): + def setUp(self) -> None: + self.model = _make_model() + + +class TestSpinEnerModelVirial(unittest.TestCase, VirialTest): + def setUp(self) -> None: + self.model = _make_model() + + +class TestSpinEnerModelExportable(unittest.TestCase): + def test_forward_lower_exportable(self) -> None: + """Test that SpinEnergyModel.forward_lower_exportable works with make_fx and torch.export.""" + from deepmd.dpmodel.utils import ( + build_neighbor_list, + extend_coord_with_ghosts, + normalize_coord, + ) + + model = _make_model() + natoms = 6 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.tensor([0, 0, 1, 0, 1, 1], dtype=torch.int64) + spin = ( + torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + * 0.5 + ) + + # Build extended inputs (use original sel, not backbone's doubled sel) + rcut = model.get_rcut() + sel = SPIN_DATA["descriptor"]["sel"] + coord_np = coord.unsqueeze(0).numpy() + atype_np = atype.unsqueeze(0).numpy() + box_np = cell.reshape(1, 9).numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, natoms, 3), + box_np.reshape(1, 3, 3), + ) + ext_coord, ext_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + box_np, + rcut, + ) + nlist = build_neighbor_list( + ext_coord, ext_atype, natoms, rcut, sel, distinguish_types=True + ) + ext_coord = ext_coord.reshape(1, -1, 3) + # Extend spin to ghost atoms using mapping + spin_np = spin.unsqueeze(0).numpy() + ext_spin = np.take_along_axis( + spin_np, + np.repeat(mapping[:, :, np.newaxis], 3, axis=2), + axis=1, + ) + + ext_coord_t = torch.tensor(ext_coord, dtype=dtype, device=env.DEVICE) + ext_atype_t = torch.tensor(ext_atype, dtype=torch.int64, device=env.DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=env.DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=env.DEVICE) + ext_spin_t = torch.tensor(ext_spin, dtype=dtype, device=env.DEVICE) + + output_keys = ( + "energy", + "extended_force", + "extended_force_mag", + "virial", + ) + + # --- eager reference --- + ret_eager = model.forward_lower( + ext_coord_t.requires_grad_(True), + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + ) + for key in output_keys: + self.assertIn(key, ret_eager, f"Missing key {key} in eager result") + + # --- trace with make_fx --- + traced = model.forward_lower_exportable( + ext_coord_t, + ext_atype_t, + ext_spin_t, + nlist_t, + mapping_t, + ) + self.assertIsInstance(traced, torch.nn.Module) + + # --- export with torch.export --- + exported = torch.export.export( + traced, + (ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, None, None), + strict=False, + ) + self.assertIsNotNone(exported) + + # --- verify traced matches eager --- + ret_traced = traced( + ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, None, None + ) + for key in output_keys: + 
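# The traced graph must reproduce the eager numerics (fp64 tolerances below). +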
np.testing.assert_allclose( + ret_eager[key].detach().cpu().numpy(), + ret_traced[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager: {key}", + ) + + # --- verify exported matches eager --- + ret_exported = exported.module()( + ext_coord_t, ext_atype_t, ext_spin_t, nlist_t, mapping_t, None, None + ) + for key in output_keys: + np.testing.assert_allclose( + ret_eager[key].detach().cpu().numpy(), + ret_exported[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"exported vs eager: {key}", + ) + + +if __name__ == "__main__": + unittest.main()
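The virial check in VirialTest above rests on the identity virial = -(dE/dh)^T h, where h is the cell matrix and the fractional coordinates are held fixed under the deformation (which is what stretch_box implements); this quantity equals sum_i f_i ⊗ r_i. A minimal self-contained sketch of that identity with a toy pair energy, where everything below is illustrative and independent of deepmd:

    import numpy as np

    def finite_difference(f, x, delta=1e-6):
        # central differences, same scheme as the test helper above
        res = np.empty(np.shape(f(x)) + x.shape)
        for idx in np.ndindex(*x.shape):
            d = np.zeros(x.shape)
            d[idx] = delta
            res[(Ellipsis, *idx)] = (f(x + d) - f(x - d)) / (2 * delta)
        return res

    def stretch_box(coord, old_box, new_box):
        # keep fractional coordinates fixed while deforming the cell
        return coord @ np.linalg.inv(old_box) @ new_box

    rng = np.random.default_rng(0)
    cell = 5.0 * np.eye(3) + 0.1 * rng.random((3, 3))
    coord = rng.random((4, 3)) @ cell

    def energy_of_coord(c):
        # smooth toy pair energy (no PBC): 0.5 * sum_ij |r_i - r_j|^2
        d = c[:, None, :] - c[None, :, :]
        return 0.5 * np.sum(d * d)

    def energy_of_cell(h):
        return energy_of_coord(stretch_box(coord, cell, h))

    force = -finite_difference(energy_of_coord, coord)  # f_i = -dE/dr_i
    virial_from_cell = -(finite_difference(energy_of_cell, cell).T @ cell)
    virial_from_force = force.T @ coord  # sum_i f_i ⊗ r_i
    np.testing.assert_allclose(
        virial_from_cell, virial_from_force, rtol=1e-6, atol=1e-6
    )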
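The exportability test follows the standard torch.export round-trip used above: trace, export with strict=False, then compare exported.module() against eager execution. The same pattern in miniature, with a toy module (illustrative only):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # a smooth scalar function of the inputs, like an energy head
            return (x * x).sum(dim=-1)

    mod = Toy()
    x = torch.randn(4, 3, dtype=torch.float64)
    exported = torch.export.export(mod, (x,), strict=False)
    torch.testing.assert_close(exported.module()(x), mod(x))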