Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ repos:
# Python
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.8.1
rev: v0.8.2
hooks:
- id: ruff
args: ["--fix"]
Expand Down
17 changes: 9 additions & 8 deletions dpdata/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from abc import ABCMeta, abstractmethod
from functools import lru_cache
from typing import Any

import numpy as np

from dpdata.system import LabeledSystem, MultiSystems


def mae(errors: np.ndarray) -> np.float64:
def mae(errors: np.ndarray) -> np.floating[Any]:
"""Compute the mean absolute error (MAE).

Parameters
Expand All @@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64:

Returns
-------
np.float64
floating[Any]
mean absolute error (MAE)
"""
return np.mean(np.abs(errors))


def rmse(errors: np.ndarray) -> np.float64:
def rmse(errors: np.ndarray) -> np.floating[Any]:
"""Compute the root mean squared error (RMSE).

Parameters
Expand All @@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64:

Returns
-------
np.float64
floating[Any]
root mean squared error (RMSE)
"""
return np.sqrt(np.mean(np.square(errors)))
Expand Down Expand Up @@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray:
"""Force errors."""

@property
def e_mae(self) -> np.float64:
def e_mae(self) -> np.floating[Any]:
"""Energy MAE."""
return mae(self.e_errors)

@property
def e_rmse(self) -> np.float64:
def e_rmse(self) -> np.floating[Any]:
"""Energy RMSE."""
return rmse(self.e_errors)

@property
def f_mae(self) -> np.float64:
def f_mae(self) -> np.floating[Any]:
"""Force MAE."""
return mae(self.f_errors)

@property
def f_rmse(self) -> np.float64:
def f_rmse(self) -> np.floating[Any]:
"""Force RMSE."""
return rmse(self.f_errors)

Expand Down