DFT-D3 wrapper #219

Open
GardevoirX wants to merge 26 commits into metatensor:main from GardevoirX:DFT-D3-pytorch

GardevoirX (Contributor) commented May 4, 2026

This PR introduces a DFT-D3 wrapper that can be wrapped around MLIPs in AtomisticModel. The wrapper currently supports only D3(BJ).

The wrapper calculates the energy correction for the input structure and adds it to the energy calculated by the wrapped MLIP. The D3-corrected forces and stresses (the latter only with full PBC) are then obtained automatically through auto-differentiation.
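
For readers unfamiliar with the correction, the two-body D3(BJ) term can be sketched as below. This is a minimal pure-Python illustration and not the code in this PR: the function name `d3bj_pair_energy` and the C6/C8 values are hypothetical, the three-body term is left out, and all quantities are assumed to share one consistent unit system. The damping parameters are the ones used in the parity script posted later in this thread.

```python
import math


def d3bj_pair_energy(r, c6, c8, s6, s8, a1, a2):
    """Two-body D3(BJ) energy of a single atom pair at distance r."""
    r0 = math.sqrt(c8 / c6)  # pair-specific critical radius
    f = a1 * r0 + a2         # Becke-Johnson damping length
    e6 = s6 * c6 / (r**6 + f**6)
    e8 = s8 * c8 / (r**8 + f**8)
    return -(e6 + e8)


# Damping parameters as in the parity script; C6/C8 are made-up numbers.
params = {"s6": 1.0, "s8": 0.7875, "a1": 0.4289, "a2": 4.4407}
c6, c8 = 20.0, 600.0
energies = [d3bj_pair_energy(r, c6, c8, **params) for r in (0.0, 5.0, 50.0)]
```

Because the damping length `f` keeps the denominators finite, the energy stays bounded at short range, and at long range it approaches `-s6 * c6 / r**6`.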

I have tested this wrapper on the MAD dataset. The energies, forces, and stresses are in good agreement with another wrapper implemented with the DFT-D3 provided in the nvalchemiops package.

The overhead is about 15%-35% for s models on systems ranging from 600 to 4k atoms. It comes mainly from the neighbor list calculation; the actual overheads of the forward and backward passes are only ~3% and ~8%.

Remaining issues:

  • The D3 parameters shared across different functionals are currently downloaded (through this function) and passed to the wrapper by the user. Maybe we need to incorporate this into the wrapper.
  • The D3 parameters are in atomic units, so a unit conversion helper may be needed.
  • Maybe also add the correction to the non-conservative forces and stresses.
  • Sometimes we might want to apply the correction to only a part of the atoms in the system (requested by @cesaremalosso, Ref: https://pubs.acs.org/doi/full/10.1021/acs.jpclett.3c00856).
  • Should we update the interaction range in ModelCapability? Currently it is updated.
  • Should we force users to use the same cutoff as the DFT-D3 calculation when calculating the coordination numbers of atoms?
  • Correctly cite the packages used during development (tad-dftd3, nvalchemiops) (maybe).
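
On the unit conversion point, a helper could look roughly like the sketch below. The function names are hypothetical and not part of the wrapper; the conversion factors are the CODATA 2018 values for the Hartree and the Bohr radius, and the target units (eV, Å) are an assumption about what the wrapped model uses.

```python
HARTREE_TO_EV = 27.211386245988     # CODATA 2018
BOHR_TO_ANGSTROM = 0.529177210903   # CODATA 2018


def c_n_to_ev_angstrom(c_n, order):
    """Convert a C_n coefficient from Hartree*Bohr^n to eV*Angstrom^n."""
    return c_n * HARTREE_TO_EV * BOHR_TO_ANGSTROM**order


def length_to_angstrom(length_bohr):
    """Convert a length (for example a damping radius) from Bohr to Angstrom."""
    return length_bohr * BOHR_TO_ANGSTROM


# A C6 coefficient of 1 Hartree*Bohr^6 expressed in eV*Angstrom^6.
c6_ev = c_n_to_ev_angstrom(1.0, 6)
```

The power of the length factor must follow the order of the coefficient (6 for C6, 8 for C8), which is the step that is easiest to get wrong when converting by hand.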

Contributor (creator of pull-request) checklist

  • Tests updated (for new features and bugfixes)?
  • Documentation updated (for new features)?
  • Issue referenced (for PRs that solve an issue)?

Reviewer checklist

  • CHANGELOG updated with public API or any other important changes?

@GardevoirX GardevoirX requested a review from Luthaf May 4, 2026 11:12
options.engine_cutoff("angstrom"),
cell=system.cell,
pbc=system.pbc,
max_neighbors=16384,

Member

Why is this required? and why this value?

Contributor Author

When testing on the MAD dataset, there were some errors related to the number of neighbors. Chatty raised this value twice to get everything to pass. I think it chose this value because it is a power of two (16384 = 2^14). If we want to know which structure led to the error, I can run more tests.

Member

Yes please! I think we could set this to a similar value as the one we use in vesin, which is based on the cutoff + number of atoms.
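
As a starting point for that discussion, one way to derive a capacity from the cutoff and the number of atoms is sketched below. This is a hypothetical heuristic, not the formula used in vesin: it assumes a periodic cell with a well-defined volume, assumes `max_neighbors` is a per-atom capacity, and the `safety` and `minimum` parameters are invented for illustration.

```python
import math


def estimate_max_neighbors(n_atoms, volume, cutoff, safety=2.0, minimum=64):
    """Estimate a per-atom neighbor capacity from the mean number density."""
    density = n_atoms / volume
    # Expected number of atoms inside a sphere of radius `cutoff`.
    expected = density * (4.0 / 3.0) * math.pi * cutoff**3
    return max(minimum, math.ceil(safety * expected))


# 600 atoms in a 20 x 20 x 20 cell with a 15.0 cutoff.
capacity = estimate_max_neighbors(n_atoms=600, volume=8000.0, cutoff=15.0)
```

A density-based estimate can still be exceeded by strongly inhomogeneous structures, which is why the safety factor is there; finding the MAD structure that triggered the error would show how large it needs to be.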

Luthaf (Member) commented May 6, 2026

> The D3 parameters shared across different functionals is now downloaded (through this function) and passed to the wrapper by the user. Maybe we need to incorporate this into the wrapper

Yes, ideally running this code should not require internet access.

> The D3 parameters are in atomic units, so maybe a unit conversion helper is needed

Isn't this implemented?

> Maybe also the correction to the non-conservative forces and stresses

Yes!

> Sometimes we might want to put correction to only a part of the atoms in the system (requested by @cesaremalosso )

This should be doable, but how would it look when working with heterogeneous systems?

> Should we update the interaction range in ModelCapability? Currently it's updated

Yes!

> Should we force users to use the same cutoff as the DFT-D3 calculation for calculating the coordination numbers of atoms?

Why would someone want to do or not do this? I don't quite understand the implications.

@GardevoirX GardevoirX requested a review from Luthaf May 8, 2026 12:53
Luthaf (Member) left a comment

This will need more review; for now I am just sending the notes I took during our discussion.

GardevoirX (Contributor Author) commented

Here is the parity plot of the current version of the D3 wrapper against simple-dftd3:
[image: parity plot of energies, force components, and stress components]
and the script used to produce it:

#!/usr/bin/env python3
from __future__ import annotations

import argparse
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import ase.io
import numpy as np
import torch

from metatomic.torch import load_atomistic_model
from metatomic.torch.dftd3 import DFTD3
from metatomic_ase import MetatomicCalculator


ROOT = Path('/home/qxu/repos/metatomic')
DAMPING = {'s6': 1.0, 's8': 0.7875, 'a1': 0.4289, 'a2': 4.4407}
CUTOFF = 15.0
CN_CUTOFF = 15.0


def _stress_allowed(atoms) -> bool:
    return bool(np.all(atoms.pbc)) and abs(atoms.cell.volume) > 0.0


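def _batched(frames, batch_size: int, max_atoms: int):
    """Yield lists of (index, atoms) pairs. A new batch starts whenever the
    stress availability changes, the atom budget would be exceeded, or the
    batch reaches ``batch_size`` frames."""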
    batch = []
    atoms_in_batch = 0
    stress_allowed = False

    for frame in frames:
        frame_stress_allowed = _stress_allowed(frame[1])
        n_atoms = len(frame[1])

        if batch and frame_stress_allowed != stress_allowed:
            yield batch
            batch = []
            atoms_in_batch = 0
        if batch and max_atoms > 0 and atoms_in_batch + n_atoms > max_atoms:
            yield batch
            batch = []
            atoms_in_batch = 0

        batch.append(frame)
        atoms_in_batch += n_atoms
        stress_allowed = frame_stress_allowed

        if len(batch) == batch_size:
            yield batch
            batch = []
            atoms_in_batch = 0

    if batch:
        yield batch


def _simple_dftd3(atoms):
    from dftd3.ase import DFTD3 as SimpleDFTD3

    atoms = atoms.copy()
    atoms.calc = SimpleDFTD3(
        damping='d3bj',
        params_tweaks={**DAMPING, 's9': 0.0},
        realspace_cutoff={'disp2': CUTOFF, 'disp3': CUTOFF, 'cn': CN_CUTOFF},
    )

    energy = atoms.get_potential_energy()
    forces = atoms.get_forces()
    stress = atoms.get_stress(voigt=False) if _stress_allowed(atoms) else None
    return energy, forces, stress


def _metrics(reference: np.ndarray, value: np.ndarray) -> str:
    delta = np.asarray(value, dtype=float).reshape(-1) - np.asarray(
        reference, dtype=float
    ).reshape(-1)
    if delta.size == 0:
        return 'not available'
    return (
        f'RMSE={np.sqrt(np.mean(delta * delta)):.3e}  '
        f'MAE={np.mean(np.abs(delta)):.3e}  '
        f'max={np.max(np.abs(delta)):.3e}'
    )


def _plot_panel(ax, reference, value, title: str, max_points: int):
    reference = np.asarray(reference, dtype=float).reshape(-1)
    value = np.asarray(value, dtype=float).reshape(-1)
    finite = np.isfinite(reference) & np.isfinite(value)
    reference = reference[finite]
    value = value[finite]

    if reference.size == 0:
        ax.set_title(title)
        ax.text(0.5, 0.5, 'not available', ha='center', va='center')
        ax.set_axis_off()
        return

    if reference.size > max_points:
        rng = np.random.default_rng(0)
        keep = rng.choice(reference.size, size=max_points, replace=False)
        reference_plot = reference[keep]
        value_plot = value[keep]
    else:
        reference_plot = reference
        value_plot = value

    lo = min(reference.min(), value.min())
    hi = max(reference.max(), value.max())
    pad = 0.03 * (hi - lo if hi > lo else max(abs(hi), 1.0))
    lo -= pad
    hi += pad

    ax.scatter(reference_plot, value_plot, s=6, alpha=0.35, linewidths=0)
    ax.plot([lo, hi], [lo, hi], color='black', linewidth=1)
    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_title(title)
    ax.set_xlabel('simple-dftd3')
    ax.set_ylabel('metatomic wrapper correction')
    ax.text(
        0.04,
        0.96,
        f'N={reference.size}\n{_metrics(reference, value)}',
        transform=ax.transAxes,
        va='top',
        ha='left',
        fontsize=9,
        bbox={'facecolor': 'white', 'alpha': 0.75, 'edgecolor': 'none'},
    )


def _plot(
    path: Path,
    ref_energy,
    value_energy,
    ref_forces,
    value_forces,
    ref_stress,
    value_stress,
    max_points: int,
):
    import matplotlib

    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 3, figsize=(14, 4.2), constrained_layout=True)
    _plot_panel(axes[0], ref_energy, value_energy, 'Energy (eV)', max_points)
    _plot_panel(axes[1], ref_forces, value_forces, 'Force components (eV/A)', max_points)
    _plot_panel(axes[2], ref_stress, value_stress, 'Stress components (eV/A^3)', max_points)
    fig.savefig(path, dpi=220)
    plt.close(fig)


def main() -> int:
    parser = argparse.ArgumentParser()
    parser.add_argument('--xyz', default=str(ROOT / 'mad-train.xyz'))
    parser.add_argument('--model', default=str(ROOT / 'pet-mad-s-v1.0.2.pt'))
    parser.add_argument('--output', default='dftd3-simple-parity.png')
    parser.add_argument('--npz', default='dftd3-simple-parity.npz')
    parser.add_argument('--device', choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--batch-size', type=int, default=200)
    parser.add_argument('--max-atoms-per-batch', type=int, default=1800)
    parser.add_argument('--reference-workers', type=int, default=8)
    parser.add_argument('--max-points', type=int, default=200_000)
    args = parser.parse_args()

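    # Cap native thread pools at one thread each, since the reference
    # calculation already runs in parallel workers. Note: libraries imported
    # at the top of the file may have read these variables already.
    os.environ['OMP_NUM_THREADS'] = '1'
    os.environ['OPENBLAS_NUM_THREADS'] = '1'
    os.environ['MKL_NUM_THREADS'] = '1'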

    model = load_atomistic_model(args.model)
    wrapped = DFTD3.wrap(
        model,
        damping_params={'energy': DAMPING},
        cutoff=CUTOFF,
        cn_cutoff=CN_CUTOFF,
    )
    base_calc = MetatomicCalculator(
        model,
        device=args.device,
        check_consistency=False,
        non_conservative=False,
        uncertainty_threshold=None,
    )
    wrapped_calc = MetatomicCalculator(
        wrapped,
        device=args.device,
        check_consistency=False,
        non_conservative=False,
        uncertainty_threshold=None,
    )

    frame_ids = []
    stress_frame_ids = []
    reference_energy = []
    reference_forces = []
    reference_stress = []
    wrapper_energy = []
    wrapper_forces = []
    wrapper_stress = []
    reference_time = 0.0
    wrapper_time = 0.0

    frames = enumerate(ase.io.iread(args.xyz, ':'))
    batches = _batched(frames, args.batch_size, args.max_atoms_per_batch)

    for batch in batches:
        indices = [index for index, _ in batch]
        atoms_batch = [atoms for _, atoms in batch]
        frame_ids.extend(indices)
        batch_has_stress = all(_stress_allowed(atoms) for atoms in atoms_batch)

        if args.device == 'cuda':
            torch.cuda.synchronize()
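        # Time the base and wrapped models together: the correction is
        # obtained below as the difference of their outputs.
        start = time.perf_counter()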
        base = base_calc.compute_energy(atoms_batch, compute_forces_and_stresses=True)
        wrapped_out = wrapped_calc.compute_energy(
            atoms_batch, compute_forces_and_stresses=True
        )
        if args.device == 'cuda':
            torch.cuda.synchronize()
        wrapper_time += time.perf_counter() - start

        wrapper_energy.append(
            np.asarray(wrapped_out['energy'], dtype=float).reshape(-1)
            - np.asarray(base['energy'], dtype=float).reshape(-1)
        )
        wrapper_forces.append(
            np.concatenate(wrapped_out['forces'], axis=0)
            - np.concatenate(base['forces'], axis=0)
        )
        if batch_has_stress:
            wrapper_stress.append(
                np.asarray(wrapped_out['stress'], dtype=float)
                - np.asarray(base['stress'], dtype=float)
            )

        start = time.perf_counter()
        with ThreadPoolExecutor(max_workers=args.reference_workers) as pool:
            references = list(pool.map(_simple_dftd3, atoms_batch))
        reference_time += time.perf_counter() - start

        for index, (energy, forces, stress) in zip(indices, references, strict=True):
            reference_energy.append(energy)
            reference_forces.append(forces)
            if stress is not None:
                stress_frame_ids.append(index)
                reference_stress.append(stress.reshape(1, 3, 3))

    reference_energy = np.asarray(reference_energy)
    reference_forces = np.concatenate(reference_forces, axis=0)
    reference_stress = (
        np.concatenate(reference_stress, axis=0)
        if reference_stress
        else np.empty((0, 3, 3))
    )
    wrapper_energy = np.concatenate(wrapper_energy, axis=0)
    wrapper_forces = np.concatenate(wrapper_forces, axis=0)
    wrapper_stress = (
        np.concatenate(wrapper_stress, axis=0)
        if wrapper_stress
        else np.empty((0, 3, 3))
    )

    print(f'frames={len(frame_ids)}')
    print(f'stress_frames={len(stress_frame_ids)}')
    print(f'simple_dftd3_time_s={reference_time:.3f}')
    print(f'wrapper_minus_base_time_s={wrapper_time:.3f}')
    print('energy:', _metrics(reference_energy, wrapper_energy))
    print('forces:', _metrics(reference_forces, wrapper_forces))
    print('stress:', _metrics(reference_stress, wrapper_stress))

    _plot(
        Path(args.output),
        reference_energy,
        wrapper_energy,
        reference_forces,
        wrapper_forces,
        reference_stress,
        wrapper_stress,
        args.max_points,
    )
    np.savez_compressed(
        args.npz,
        frame_ids=np.asarray(frame_ids, dtype=np.int64),
        stress_frame_ids=np.asarray(stress_frame_ids, dtype=np.int64),
        reference_energy=reference_energy,
        wrapper_energy=wrapper_energy,
        reference_forces=reference_forces,
        wrapper_forces=wrapper_forces,
        reference_stress=reference_stress,
        wrapper_stress=wrapper_stress,
        simple_dftd3_time_s=np.asarray(reference_time),
        wrapper_minus_base_time_s=np.asarray(wrapper_time),
    )
    print(f'plot: {args.output}')
    print(f'data: {args.npz}')
    return 0


if __name__ == '__main__':
    raise SystemExit(main())

To run this script, simple-dftd3 has to be compiled from source, because this line needs to be modified:
https://github.com/dftd3/simple-dftd3/blob/6f0b06fbfa8653a23ca55c453772ce3af4420706/src/dftd3/cutoff.f90#L127
Change spread(periodic, 2, 3) to spread(periodic, 1, 3), which, IMHO, is the correct implementation.
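
To make the difference between the two calls concrete, the sketch below emulates the Fortran `SPREAD` intrinsic for a rank-1 source in plain Python. The `spread` function here is a hypothetical stand-in written for this comment; it only shows how the copies are laid out and does not by itself settle which layout simple-dftd3 needs.

```python
def spread(source, dim, ncopies):
    """Emulate Fortran SPREAD for a rank-1 source, as nested lists [i][j]."""
    n = len(source)
    if dim == 1:
        # result(i, j) = source(j): the copies are stacked along dimension 1
        return [[source[j] for j in range(n)] for _ in range(ncopies)]
    if dim == 2:
        # result(i, j) = source(i): the copies are stacked along dimension 2
        return [[source[i] for _ in range(ncopies)] for i in range(n)]
    raise ValueError("dim must be 1 or 2 for a rank-1 source")


periodic = [True, True, False]  # periodic along the first two axes only
along_dim1 = spread(periodic, 1, 3)
along_dim2 = spread(periodic, 2, 3)
```

The two results are transposes of each other, so they coincide for fully periodic or fully non-periodic systems and differ only for mixed boundary conditions such as slabs.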

Comment on lines +78 to +81
``selected_atoms`` is supported with the usual domain-decomposition
convention: the D3 environment is computed with all atoms in each
:class:`System`, while pair energies are split equally between the two pair
endpoints and only the shares belonging to selected atoms are added.

Member

Doesn't this mean that this code could support sample_kind = atom?

Contributor Author

Because the split of the pair energy is arbitrary, I prefer to keep the calculation of per-atom energies disabled. I only implemented this to enable usage in LAMMPS.
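
The convention described in the docstring can be illustrated with a small pure-Python sketch. The function name `selected_energy` and the pair energies are hypothetical; the real wrapper works on tensors and neighbor lists rather than Python lists.

```python
def selected_energy(pairs, pair_energies, selected_atoms):
    """Sum the half-shares of each pair energy that belong to selected atoms."""
    selected = set(selected_atoms)
    total = 0.0
    for (i, j), energy in zip(pairs, pair_energies):
        half = 0.5 * energy  # each endpoint owns half of the pair energy
        if i in selected:
            total += half
        if j in selected:
            total += half
    return total


pairs = [(0, 1), (0, 2), (1, 2)]
pair_energies = [-0.4, -0.2, -0.6]  # made-up values
domain_a = selected_energy(pairs, pair_energies, [0])
domain_b = selected_energy(pairs, pair_energies, [1, 2])
```

With this split, the contributions of disjoint selections add up to the total energy, which is what a domain-decomposed engine needs, even though the share assigned to any single atom remains a convention.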

Comment on lines +311 to +316
self._neighbor_list = NeighborListOptions(
cutoff=self._neighbor_cutoff,
full_list=False,
strict=True,
requestor="DFTD3",
)

Member

Shouldn't the code also request a NL for CN?

Contributor Author

No, it just requests a NL for the largest cutoff and filters the results to get the NL for the smaller cutoff.
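
In plain Python, the filtering step amounts to something like the sketch below. The names `filter_neighbors`, `cn_pairs`, and the sample distances are hypothetical; the wrapper itself operates on the neighbor list tensors, and whether the comparison is strict is an assumption here.

```python
def filter_neighbors(pairs, distances, cutoff):
    """Keep only the pairs whose distance is within a smaller cutoff."""
    kept = [(pair, d) for pair, d in zip(pairs, distances) if d <= cutoff]
    return [pair for pair, _ in kept], [d for _, d in kept]


# Pairs from a list built with the larger (dispersion) cutoff of 15.0.
pairs = [(0, 1), (0, 2), (1, 2)]
distances = [2.0, 9.5, 14.0]

# Subset used for the coordination numbers with a smaller cutoff of 10.0.
cn_pairs, cn_distances = filter_neighbors(pairs, distances, cutoff=10.0)
```

Requesting a single list avoids a second neighbor search; this only works because every pair within the smaller cutoff is necessarily present in the list built for the larger one.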
