DFT-D3 wrapper#219
Conversation
| options.engine_cutoff("angstrom"), | ||
| cell=system.cell, | ||
| pbc=system.pbc, | ||
| max_neighbors=16384, |
There was a problem hiding this comment.
Why is this required? and why this value?
There was a problem hiding this comment.
When testing on mad dataset, there were some errors related to the number of neighbors. Chatty lifted this value twice to get everything passed. It chose this value because it satisfies 2^n, I think. If we are interested in which structure led to the error, I can do more tests
There was a problem hiding this comment.
Yes please! I think we could set this to a similar value as the one we use in vesin, which is based on the cutoff + number of atoms.
Yes, ideally running this code should not require internet access
Isn't this implemented?
Yes!
This should be doable, but how would it look when working with heterogeneous systems?
Yes!
Why would someone want to do or not do this? I don't quite understand the implications. |
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>
…d create a script for downloading and parsing the parameters
…ded more comments
Luthaf
left a comment
There was a problem hiding this comment.
This will need more review, just sending the notes I took during our discussion
|
Here I upload the parity plot of the current version of d3 wrapper against #!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import ase.io
import numpy as np
import torch
from metatomic.torch import load_atomistic_model
from metatomic.torch.dftd3 import DFTD3
from metatomic_ase import MetatomicCalculator
ROOT = Path('/home/qxu/repos/metatomic')
DAMPING = {'s6': 1.0, 's8': 0.7875, 'a1': 0.4289, 'a2': 4.4407}
CUTOFF = 15.0
CN_CUTOFF = 15.0
def _stress_allowed(atoms) -> bool:
return bool(np.all(atoms.pbc)) and abs(atoms.cell.volume) > 0.0
def _batched(frames, batch_size: int, max_atoms: int):
batch = []
atoms_in_batch = 0
stress_allowed = False
for frame in frames:
frame_stress_allowed = _stress_allowed(frame[1])
n_atoms = len(frame[1])
if batch and frame_stress_allowed != stress_allowed:
yield batch
batch = []
atoms_in_batch = 0
if batch and max_atoms > 0 and atoms_in_batch + n_atoms > max_atoms:
yield batch
batch = []
atoms_in_batch = 0
batch.append(frame)
atoms_in_batch += n_atoms
stress_allowed = frame_stress_allowed
if len(batch) == batch_size:
yield batch
batch = []
atoms_in_batch = 0
if batch:
yield batch
def _simple_dftd3(atoms):
from dftd3.ase import DFTD3 as SimpleDFTD3
atoms = atoms.copy()
atoms.calc = SimpleDFTD3(
damping='d3bj',
params_tweaks={**DAMPING, 's9': 0.0},
realspace_cutoff={'disp2': CUTOFF, 'disp3': CUTOFF, 'cn': CN_CUTOFF},
)
energy = atoms.get_potential_energy()
forces = atoms.get_forces()
stress = atoms.get_stress(voigt=False) if _stress_allowed(atoms) else None
return energy, forces, stress
def _metrics(reference: np.ndarray, value: np.ndarray) -> str:
delta = np.asarray(value, dtype=float).reshape(-1) - np.asarray(
reference, dtype=float
).reshape(-1)
if delta.size == 0:
return 'not available'
return (
f'RMSE={np.sqrt(np.mean(delta * delta)):.3e} '
f'MAE={np.mean(np.abs(delta)):.3e} '
f'max={np.max(np.abs(delta)):.3e}'
)
def _plot_panel(ax, reference, value, title: str, max_points: int):
reference = np.asarray(reference, dtype=float).reshape(-1)
value = np.asarray(value, dtype=float).reshape(-1)
finite = np.isfinite(reference) & np.isfinite(value)
reference = reference[finite]
value = value[finite]
if reference.size == 0:
ax.set_title(title)
ax.text(0.5, 0.5, 'not available', ha='center', va='center')
ax.set_axis_off()
return
if reference.size > max_points:
rng = np.random.default_rng(0)
keep = rng.choice(reference.size, size=max_points, replace=False)
reference_plot = reference[keep]
value_plot = value[keep]
else:
reference_plot = reference
value_plot = value
lo = min(reference.min(), value.min())
hi = max(reference.max(), value.max())
pad = 0.03 * (hi - lo if hi > lo else max(abs(hi), 1.0))
lo -= pad
hi += pad
ax.scatter(reference_plot, value_plot, s=6, alpha=0.35, linewidths=0)
ax.plot([lo, hi], [lo, hi], color='black', linewidth=1)
ax.set_xlim(lo, hi)
ax.set_ylim(lo, hi)
ax.set_title(title)
ax.set_xlabel('simple-dftd3')
ax.set_ylabel('metatomic wrapper correction')
ax.text(
0.04,
0.96,
f'N={reference.size}\n{_metrics(reference, value)}',
transform=ax.transAxes,
va='top',
ha='left',
fontsize=9,
bbox={'facecolor': 'white', 'alpha': 0.75, 'edgecolor': 'none'},
)
def _plot(path: Path, ref_energy, value_energy, ref_forces, value_forces, ref_stress, value_stress, max_points: int):
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2), constrained_layout=True)
_plot_panel(axes[0], ref_energy, value_energy, 'Energy (eV)', max_points)
_plot_panel(axes[1], ref_forces, value_forces, 'Force components (eV/A)', max_points)
_plot_panel(axes[2], ref_stress, value_stress, 'Stress components (eV/A^3)', max_points)
fig.savefig(path, dpi=220)
plt.close(fig)
def main() -> int:
parser = argparse.ArgumentParser()
parser.add_argument('--xyz', default=str(ROOT / 'mad-train.xyz'))
parser.add_argument('--model', default=str(ROOT / 'pet-mad-s-v1.0.2.pt'))
parser.add_argument('--output', default='dftd3-simple-parity.png')
parser.add_argument('--npz', default='dftd3-simple-parity.npz')
parser.add_argument('--device', choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--batch-size', type=int, default=200)
parser.add_argument('--max-atoms-per-batch', type=int, default=1800)
parser.add_argument('--reference-workers', type=int, default=8)
parser.add_argument('--max-points', type=int, default=200_000)
args = parser.parse_args()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
model = load_atomistic_model(args.model)
wrapped = DFTD3.wrap(
model,
damping_params={'energy': DAMPING},
cutoff=CUTOFF,
cn_cutoff=CN_CUTOFF,
)
base_calc = MetatomicCalculator(
model,
device=args.device,
check_consistency=False,
non_conservative=False,
uncertainty_threshold=None,
)
wrapped_calc = MetatomicCalculator(
wrapped,
device=args.device,
check_consistency=False,
non_conservative=False,
uncertainty_threshold=None,
)
frame_ids = []
stress_frame_ids = []
reference_energy = []
reference_forces = []
reference_stress = []
wrapper_energy = []
wrapper_forces = []
wrapper_stress = []
reference_time = 0.0
wrapper_time = 0.0
frames = enumerate(ase.io.iread(args.xyz, ':'))
batches = _batched(frames, args.batch_size, args.max_atoms_per_batch)
for batch in batches:
indices = [index for index, _ in batch]
atoms_batch = [atoms for _, atoms in batch]
frame_ids.extend(indices)
batch_has_stress = all(_stress_allowed(atoms) for atoms in atoms_batch)
if args.device == 'cuda':
torch.cuda.synchronize()
start = time.perf_counter()
base = base_calc.compute_energy(atoms_batch, compute_forces_and_stresses=True)
wrapped_out = wrapped_calc.compute_energy(
atoms_batch, compute_forces_and_stresses=True
)
if args.device == 'cuda':
torch.cuda.synchronize()
wrapper_time += time.perf_counter() - start
wrapper_energy.append(
np.asarray(wrapped_out['energy'], dtype=float).reshape(-1)
- np.asarray(base['energy'], dtype=float).reshape(-1)
)
wrapper_forces.append(
np.concatenate(wrapped_out['forces'], axis=0)
- np.concatenate(base['forces'], axis=0)
)
if batch_has_stress:
wrapper_stress.append(
np.asarray(wrapped_out['stress'], dtype=float)
- np.asarray(base['stress'], dtype=float)
)
start = time.perf_counter()
with ThreadPoolExecutor(max_workers=args.reference_workers) as pool:
references = list(pool.map(_simple_dftd3, atoms_batch))
reference_time += time.perf_counter() - start
for index, (energy, forces, stress) in zip(indices, references, strict=True):
reference_energy.append(energy)
reference_forces.append(forces)
if stress is not None:
stress_frame_ids.append(index)
reference_stress.append(stress.reshape(1, 3, 3))
reference_energy = np.asarray(reference_energy)
reference_forces = np.concatenate(reference_forces, axis=0)
reference_stress = (
np.concatenate(reference_stress, axis=0)
if reference_stress
else np.empty((0, 3, 3))
)
wrapper_energy = np.concatenate(wrapper_energy, axis=0)
wrapper_forces = np.concatenate(wrapper_forces, axis=0)
wrapper_stress = (
np.concatenate(wrapper_stress, axis=0)
if wrapper_stress
else np.empty((0, 3, 3))
)
print(f'frames={len(frame_ids)}')
print(f'stress_frames={len(stress_frame_ids)}')
print(f'simple_dftd3_time_s={reference_time:.3f}')
print(f'wrapper_minus_base_time_s={wrapper_time:.3f}')
print('energy:', _metrics(reference_energy, wrapper_energy))
print('forces:', _metrics(reference_forces, wrapper_forces))
print('stress:', _metrics(reference_stress, wrapper_stress))
_plot(
Path(args.output),
reference_energy,
wrapper_energy,
reference_forces,
wrapper_forces,
reference_stress,
wrapper_stress,
args.max_points,
)
np.savez_compressed(
args.npz,
frame_ids=np.asarray(frame_ids, dtype=np.int64),
stress_frame_ids=np.asarray(stress_frame_ids, dtype=np.int64),
reference_energy=reference_energy,
wrapper_energy=wrapper_energy,
reference_forces=reference_forces,
wrapper_forces=wrapper_forces,
reference_stress=reference_stress,
wrapper_stress=wrapper_stress,
simple_dftd3_time_s=np.asarray(reference_time),
wrapper_minus_base_time_s=np.asarray(wrapper_time),
)
print(f'plot: {args.output}')
print(f'data: {args.npz}')
return 0
if __name__ == '__main__':
raise SystemExit(main())To run this script, one should compile simple-dftd3 from source, because one would have to modify this line |
| ``selected_atoms`` is supported with the usual domain-decomposition | ||
| convention: the D3 environment is computed with all atoms in each | ||
| :class:`System`, while pair energies are split equally between the two pair | ||
| endpoints and only the shares belonging to selected atoms are added. |
There was a problem hiding this comment.
Doesn't this mean that this code could support sample_kind = atom?
There was a problem hiding this comment.
Because the split of the pair energy is arbitrary, I prefer to disable the calculation of the per-atom energies. I just want to enable the usage in lammps by implementing this
| self._neighbor_list = NeighborListOptions( | ||
| cutoff=self._neighbor_cutoff, | ||
| full_list=False, | ||
| strict=True, | ||
| requestor="DFTD3", | ||
| ) |
There was a problem hiding this comment.
Shouldn't the code also request a NL for CN?
There was a problem hiding this comment.
No it just requests a NL for the largest cutoff, and screens the results to get the NL of the smaller cutoff
Co-authored-by: Guillaume Fraux <luthaf@luthaf.fr>

This PR aims to introduce a DFT-D3 wrapper, which can be wrapped outside of MLIPs in
AtomisticModel. The wrapper currently only supports D3(BJ).This wrapper calculates the energy correction of the input structure, and add it to the energy calculated by the wrapped MLIP. In this way, and through the auto-differentiation, the D3-corrected forces and stresses (if with full-PBC) can be automatically calculated.
I have tested this wrapper on MAD dataset. The energies, forces, and stresses are in good agreement with another wrapper implemented with the DFT-D3 provided in the nvalchemiops package.
The overhead is about 15%-35% for s models on systems ranging from 600 atoms to 4k atoms. Mainly from the neighbor list calculation, the actual overheads of forward and backward are only ~3% and ~8%.
Remaining issues:
ModelCapability? Currently it's updatedContributor (creator of pull-request) checklist
Reviewer checklist