Skip to content

Commit a496b6f

Browse files
Fix: pyright error: Type "floating[Any]" is not assignable to return type "ndarray[Unknown, Unknown]" (#765)
1 parent 5423efe commit a496b6f

5 files changed

Lines changed: 13 additions & 11 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ on:
66

77
jobs:
88
build:
9-
runs-on: ubuntu-latest
9+
runs-on: ubuntu-22.04
1010
strategy:
1111
matrix:
1212
python-version: ["3.7", "3.8", "3.12"]

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ repos:
2121
# Python
2222
- repo: https://github.com/astral-sh/ruff-pre-commit
2323
# Ruff version.
24-
rev: v0.8.1
24+
rev: v0.8.2
2525
hooks:
2626
- id: ruff
2727
args: ["--fix"]

dpdata/stat.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22

33
from abc import ABCMeta, abstractmethod
44
from functools import lru_cache
5+
from typing import Any
56

67
import numpy as np
78

89
from dpdata.system import LabeledSystem, MultiSystems
910

1011

11-
def mae(errors: np.ndarray) -> np.float64:
12+
def mae(errors: np.ndarray) -> np.floating[Any]:
1213
"""Compute the mean absolute error (MAE).
1314
1415
Parameters
@@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64:
1819
1920
Returns
2021
-------
21-
np.float64
22+
floating[Any]
2223
mean absolute error (MAE)
2324
"""
2425
return np.mean(np.abs(errors))
2526

2627

27-
def rmse(errors: np.ndarray) -> np.float64:
28+
def rmse(errors: np.ndarray) -> np.floating[Any]:
2829
"""Compute the root mean squared error (RMSE).
2930
3031
Parameters
@@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64:
3435
3536
Returns
3637
-------
37-
np.float64
38+
floating[Any]
3839
root mean squared error (RMSE)
3940
"""
4041
return np.sqrt(np.mean(np.square(errors)))
@@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray:
7475
"""Force errors."""
7576

7677
@property
77-
def e_mae(self) -> np.float64:
78+
def e_mae(self) -> np.floating[Any]:
7879
"""Energy MAE."""
7980
return mae(self.e_errors)
8081

8182
@property
82-
def e_rmse(self) -> np.float64:
83+
def e_rmse(self) -> np.floating[Any]:
8384
"""Energy RMSE."""
8485
return rmse(self.e_errors)
8586

8687
@property
87-
def f_mae(self) -> np.float64:
88+
def f_mae(self) -> np.floating[Any]:
8889
"""Force MAE."""
8990
return mae(self.f_errors)
9091

9192
@property
92-
def f_rmse(self) -> np.float64:
93+
def f_rmse(self) -> np.floating[Any]:
9394
"""Force RMSE."""
9495
return rmse(self.f_errors)
9596

dpdata/system.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,6 +1049,7 @@ def remove_atom_names(self, atom_names: str | list[str]):
10491049
atom_idx = self.data["atom_types"] == idx
10501050
removed_atom_idx.append(atom_idx)
10511051
picked_atom_idx = ~np.any(removed_atom_idx, axis=0)
1052+
assert not isinstance(picked_atom_idx, np.bool_)
10521053
new_sys = self.pick_atom_idx(picked_atom_idx)
10531054
# let's remove atom_names
10541055
# firstly, rearrange atom_names and put these atom_names in the end

tests/test_abacus_pw_scf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_noforcestress_job(self):
158158
# check below will not throw error
159159
system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf")
160160
# check the returned force is empty
161-
self.assertFalse(system_ch4.data["forces"])
161+
self.assertFalse(system_ch4.data["forces"].size)
162162
self.assertTrue("virials" not in system_ch4.data)
163163
# test append self
164164
system_ch4.append(system_ch4)

0 commit comments

Comments
 (0)