feat: support aparam derivative in ener loss #5285
anyangml wants to merge 12 commits into deepmodeling:master
Conversation
for more information, see https://pre-commit.ci
Pull request overview
This PR adds support for aparam (atomic parameter) derivative loss training in the energy loss module of DeePMD-kit's PyTorch backend. It enables models to be trained against the derivative of the total energy with respect to atomic parameters (d(∑E)/d(aparam)), using labeled grad_aparam data.
Changes:
- Adds `start_pref_ap`/`limit_pref_ap` prefactor arguments to the energy loss argcheck configuration
- Implements the `aparam` gradient loss computation in `EnergyStdLoss.__init__` and `forward`, including serialization of the new prefactors
- Injects `numb_aparam` from the model into the loss parameters during training setup in `get_loss()`
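The core idea of the new loss term can be sketched in a few lines. This is a minimal illustration, not the actual `EnergyStdLoss` code: the function name, argument names, and tensor shapes are assumptions made for the example, and the real implementation schedules `pref_ap` between `start_pref_ap` and `limit_pref_ap` with the learning rate.

```python
import torch

def aparam_grad_loss(energy_fn, coord, aparam, grad_aparam_label, pref_ap):
    """Toy sketch of the aparam-derivative loss term (names/shapes assumed)."""
    aparam = aparam.detach().requires_grad_(True)
    # sum the energy over the batch so one autograd call yields
    # d(sum E)/d(aparam) for every atom at once
    energy = energy_fn(coord, aparam).sum()
    grad_aparam = torch.autograd.grad(energy, aparam, create_graph=True)[0]
    # mean-squared deviation from the labeled grad_aparam data,
    # scaled by the (scheduled) prefactor
    diff = grad_aparam - grad_aparam_label
    return pref_ap * torch.mean(diff * diff)
```

With `create_graph=True` the gradient itself stays differentiable, so this term can be combined with the other energy/force terms and backpropagated through during training.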
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| deepmd/utils/argcheck.py | Adds `start_pref_ap` and `limit_pref_ap` arguments to the `loss_ener` configuration schema |
| deepmd/pt/loss/ener.py | Implements aparam gradient loss in `EnergyStdLoss`: new init parameters, forward computation via `torch.autograd.grad`, label requirement, and serialization |
| deepmd/pt/train/training.py | Injects `numb_aparam` into loss parameters when aparam gradient loss is configured |
deepmd/pt/loss/ener.py
Outdated
    "enable_atom_ener_coeff": self.enable_atom_ener_coeff,
    "start_pref_gf": self.start_pref_gf,
    "limit_pref_gf": self.limit_pref_gf,
    "start_pref_ap": self.start_pref_ap,
    "limit_pref_ap": self.limit_pref_ap,
The numb_aparam field is stored as an instance attribute (self.numb_aparam) and is required to reconstruct the loss object when has_ap is true, but it is not included in the serialize() return dict. When deserialize() calls cls(**data), the numb_aparam argument will be absent, causing the reconstructed object to default to numb_aparam=0. If start_pref_ap or limit_pref_ap is non-zero in the serialized data, the __init__ will then raise RuntimeError("numb_aparam must be > 0 when aparam gradient loss is enabled"), making deserialization impossible for models that use this feature.
Additionally, since a new field is added to the serialized representation, the @version should be bumped (e.g. to 3) and check_version_compatibility in deserialize() updated to check_version_compatibility(data.pop("@version"), 3, 1) to reflect the schema change, as is the established convention in the codebase.
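The fix the review asks for can be illustrated with a stripped-down stand-in. Everything below is a hypothetical sketch: `TinyLoss` is not `EnergyStdLoss`, and `check_version_compatibility` is re-implemented here with an assumed signature only to make the example self-contained.

```python
def check_version_compatibility(version, maximum, minimum=1):
    # stand-in for deepmd's helper of the same name (signature assumed)
    if not (minimum <= version <= maximum):
        raise ValueError(
            f"incompatible serialization version {version}; "
            f"supported range is [{minimum}, {maximum}]"
        )

class TinyLoss:
    # minimal stand-in keeping only the fields at issue in the review
    def __init__(self, start_pref_ap=0.0, limit_pref_ap=0.0, numb_aparam=0):
        if (start_pref_ap != 0.0 or limit_pref_ap != 0.0) and numb_aparam <= 0:
            raise RuntimeError(
                "numb_aparam must be > 0 when aparam gradient loss is enabled"
            )
        self.start_pref_ap = start_pref_ap
        self.limit_pref_ap = limit_pref_ap
        self.numb_aparam = numb_aparam

    def serialize(self):
        return {
            "@version": 3,  # bumped because a new field was added
            "start_pref_ap": self.start_pref_ap,
            "limit_pref_ap": self.limit_pref_ap,
            "numb_aparam": self.numb_aparam,  # the field the review says is missing
        }

    @classmethod
    def deserialize(cls, data):
        data = dict(data)
        check_version_compatibility(data.pop("@version"), 3, 1)
        return cls(**data)
```

Without the `"numb_aparam"` entry in `serialize()`, the round trip would hit the `RuntimeError` whenever the aparam prefactors are non-zero, which is exactly the failure mode described above.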
    if whether_hessian(loss_params):
        loss_params["starter_learning_rate"] = start_lr
        return EnergyHessianStdLoss(**loss_params)
    elif loss_type == "ener":
        loss_params["starter_learning_rate"] = start_lr
        if (
            loss_params.get("start_pref_ap", 0.0) != 0.0
            or loss_params.get("limit_pref_ap", 0.0) != 0.0
        ):
            loss_params["numb_aparam"] = _model.get_dim_aparam()
When whether_hessian() returns True, the code falls into the first branch (line 1688-1690) and creates EnergyHessianStdLoss. The numb_aparam injection (lines 1693-1697) is in the elif loss_type == "ener" branch, which is skipped. As a result, if a user configures both start_pref_h > 0 and start_pref_ap != 0.0, the EnergyHessianStdLoss constructor will receive numb_aparam=0 (the default), causing the RuntimeError "numb_aparam must be > 0 when aparam gradient loss is enabled". The numb_aparam injection logic should be moved to cover both branches, or factored out into a shared helper.
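The suggested refactor could look like the sketch below. The helper name is hypothetical (it does not exist in the codebase); the idea is simply to run the injection before the branch on loss type so both `EnergyStdLoss` and `EnergyHessianStdLoss` receive `numb_aparam`.

```python
def inject_numb_aparam(loss_params, model):
    """Hypothetical shared helper: inject numb_aparam whenever either
    aparam prefactor is active, regardless of which energy-loss branch
    (plain or Hessian) is taken afterwards."""
    if (
        loss_params.get("start_pref_ap", 0.0) != 0.0
        or loss_params.get("limit_pref_ap", 0.0) != 0.0
    ):
        loss_params["numb_aparam"] = model.get_dim_aparam()
    return loss_params
```

Calling this once before the `whether_hessian()` check would remove the asymmetry the review points out.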
Codecov Report ❌ Patch coverage is
Additional details and impacted files

    @@ Coverage Diff @@
    ##           master    #5285      +/-   ##
    ==========================================
    + Coverage   82.16%   82.20%   +0.04%
    ==========================================
      Files         753      755       +2
      Lines       75865    76095     +230
      Branches     3648     3660      +12
    ==========================================
    + Hits        62335    62555     +220
    - Misses      12362    12368       +6
    - Partials     1168     1172       +4

View full report in Codecov by Sentry.
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Anyang Peng <137014849+anyangml@users.noreply.github.com>
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
    "limit_pref_gf": self.limit_pref_gf,
    "start_pref_ap": self.start_pref_ap,
    "limit_pref_ap": self.limit_pref_ap,
    "numb_generalized_coord": self.numb_generalized_coord,
numb_aparam is not serialized in the serialize() method. Without it, deserialize() will pass numb_aparam=0 to __init__, causing a RuntimeError when has_ap is True. Add "numb_aparam": self.numb_aparam to the serialized dict.
Suggested change:

      "numb_generalized_coord": self.numb_generalized_coord,
    + "numb_aparam": self.numb_aparam,
No description provided.