Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions deepmd/backend/pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from collections.abc import (
Callable,
)
from typing import (
TYPE_CHECKING,
ClassVar,
)

from deepmd.backend.backend import (
Backend,
)

if TYPE_CHECKING:
from argparse import (
Namespace,
)

from deepmd.infer.deep_eval import (
DeepEvalBackend,
)
from deepmd.utils.neighbor_stat import (
NeighborStat,
)


@Backend.register("pretrained")
class PretrainedBackend(Backend):
Comment thread
njzjz marked this conversation as resolved.
"""Internal virtual backend for ``*.pretrained`` alias dispatch.
Comment thread
njzjz marked this conversation as resolved.
Outdated

This backend is not intended to be selected explicitly by users as a real
compute backend (such as TensorFlow/PyTorch/Paddle/JAX). It only bridges
``*.pretrained`` aliases into the regular deep-eval loading path.
"""

name = "Pretrained"
features: ClassVar[Backend.Feature] = Backend.Feature.DEEP_EVAL
suffixes: ClassVar[list[str]] = [".pretrained"]

def is_available(self) -> bool:
return True

@property
def entry_point_hook(self) -> Callable[["Namespace"], None]:
raise NotImplementedError("Unsupported backend: pretrained")

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
from deepmd.pretrained.deep_eval import (
PretrainedDeepEvalBackend,
)

return PretrainedDeepEvalBackend

@property
def neighbor_stat(self) -> type["NeighborStat"]:
raise NotImplementedError("Unsupported backend: pretrained")

@property
def serialize_hook(self) -> Callable[[str], dict]:
raise NotImplementedError("Unsupported backend: pretrained")

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
raise NotImplementedError("Unsupported backend: pretrained")
5 changes: 5 additions & 0 deletions deepmd/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.pretrained.entrypoints import (
pretrained_entrypoint,
)


def main(args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -97,5 +100,7 @@ def main(args: argparse.Namespace) -> None:
convert_backend(**dict_args)
elif args.command == "show":
show(**dict_args)
elif args.command == "pretrained":
pretrained_entrypoint(args)
else:
raise ValueError(f"Unknown command: {args.command}")
33 changes: 33 additions & 0 deletions deepmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from deepmd.backend.backend import (
Backend,
)
from deepmd.pretrained.registry import (
available_model_names,
)

try:
from deepmd._version import version as __version__
Expand Down Expand Up @@ -942,6 +945,35 @@ def main_parser() -> argparse.ArgumentParser:
],
nargs="+",
)

# pretrained
parser_pretrained = subparsers.add_parser(
"pretrained",
parents=[parser_log],
help="Manage builtin pretrained models",
formatter_class=RawTextArgumentDefaultsHelpFormatter,
)
pretrained_subparsers = parser_pretrained.add_subparsers(
dest="pretrained_command",
required=True,
)
parser_pretrained_download = pretrained_subparsers.add_parser(
"download",
help="Download one pretrained model",
)

parser_pretrained_download.add_argument(
"MODEL",
choices=available_model_names(),
help="Pretrained model name",
)
parser_pretrained_download.add_argument(
"--cache-dir",
default=None,
type=str,
help="Optional cache directory for pretrained model files",
)

return parser


Expand Down Expand Up @@ -997,6 +1029,7 @@ def main(args: list[str] | None = None) -> None:
"gui",
"convert-backend",
"show",
"pretrained",
):
# common entrypoints
from deepmd.entrypoints.main import main as deepmd_main
Expand Down
2 changes: 2 additions & 0 deletions deepmd/pretrained/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Pretrained model helpers for DeePMD-kit."""
178 changes: 178 additions & 0 deletions deepmd/pretrained/deep_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeepEval adapter for `*.pretrained` model aliases."""

from __future__ import (
annotations,
)

from pathlib import (
Path,
)
from typing import (
TYPE_CHECKING,
Any,
)

from deepmd.infer.deep_eval import (
DeepEval,
DeepEvalBackend,
)
from deepmd.pretrained.download import (
resolve_model_path,
)

if TYPE_CHECKING:
import numpy as np


class InvalidPretrainedAliasError(ValueError):
"""Raised when a pretrained alias string is malformed."""

def __init__(self, model_file: str) -> None:
super().__init__(f"Invalid pretrained alias: {model_file}")


def parse_pretrained_alias(model_file: str) -> str:
"""Extract model name from ``*.pretrained`` alias string."""
alias = Path(model_file).name
suffix = ".pretrained"
if not alias.lower().endswith(suffix):
raise InvalidPretrainedAliasError(model_file)

model_name = alias[: -len(suffix)]
if not model_name:
raise InvalidPretrainedAliasError(model_file)

return model_name


class PretrainedDeepEvalBackend(DeepEvalBackend):
"""Resolve alias and delegate to backend selected by resolved model path."""

def __init__(
self,
model_file: str,
output_def: object,
*args: object,
auto_batch_size: object = True,
neighbor_list: object | None = None,
**kwargs: object,
) -> None:
model_name = parse_pretrained_alias(model_file)
resolved = str(resolve_model_path(model_name))

# DeepEvalBackend.__new__ dispatches by resolved suffix (.pt/.pb/.dp...)
self._backend = DeepEvalBackend(
resolved,
output_def,
*args,
auto_batch_size=auto_batch_size,
neighbor_list=neighbor_list,
**kwargs,
)

def eval(
self,
coords: np.ndarray,
cells: np.ndarray | None,
atom_types: np.ndarray,
atomic: bool = False,
fparam: np.ndarray | None = None,
aparam: np.ndarray | None = None,
**kwargs: Any,
) -> dict[str, np.ndarray]:
return self._backend.eval(
coords,
cells,
atom_types,
atomic,
fparam=fparam,
aparam=aparam,
**kwargs,
)

def eval_descriptor(
self,
coords: np.ndarray,
cells: np.ndarray | None,
atom_types: np.ndarray,
fparam: np.ndarray | None = None,
aparam: np.ndarray | None = None,
efield: np.ndarray | None = None,
mixed_type: bool = False,
**kwargs: Any,
) -> np.ndarray:
return self._backend.eval_descriptor(
coords,
cells,
atom_types,
fparam=fparam,
aparam=aparam,
efield=efield,
mixed_type=mixed_type,
**kwargs,
)

def eval_fitting_last_layer(
self,
coords: np.ndarray,
cells: np.ndarray | None,
atom_types: np.ndarray,
fparam: np.ndarray | None = None,
aparam: np.ndarray | None = None,
**kwargs: Any,
) -> np.ndarray:
return self._backend.eval_fitting_last_layer(
coords,
cells,
atom_types,
fparam=fparam,
aparam=aparam,
**kwargs,
)

def get_rcut(self) -> float:
return self._backend.get_rcut()

def get_ntypes(self) -> int:
return self._backend.get_ntypes()

def get_type_map(self) -> list[str]:
return self._backend.get_type_map()

def get_dim_fparam(self) -> int:
return self._backend.get_dim_fparam()

def has_default_fparam(self) -> bool:
return self._backend.has_default_fparam()

def get_dim_aparam(self) -> int:
return self._backend.get_dim_aparam()

@property
def model_type(self) -> type[DeepEval]:
return self._backend.model_type

def get_sel_type(self) -> list[int]:
return self._backend.get_sel_type()

def get_numb_dos(self) -> int:
return self._backend.get_numb_dos()

def get_has_efield(self) -> bool:
return self._backend.get_has_efield()

def get_has_spin(self) -> bool:
return self._backend.get_has_spin()

def get_has_hessian(self) -> bool:
return self._backend.get_has_hessian()

def get_var_name(self) -> str:
return self._backend.get_var_name()

def get_ntypes_spin(self) -> int:
return self._backend.get_ntypes_spin()

def get_model(self) -> Any:
return self._backend.get_model()
Loading