1- from typing import Dict , Literal , Union
1+ from __future__ import annotations
2+
3+ from typing import TYPE_CHECKING , Dict , Literal
24
35import matplotlib .pyplot as plt
46import numpy as np
57import pandas as pd
6- import torch
78from matplotlib .lines import Line2D
89from scipy import stats
910from scipy .stats import gaussian_kde
1011
12+ if TYPE_CHECKING :
13+ from torch import Tensor
1114from nff .data import to_tensor
1215from nff .utils import cuda
1316
1821
1922
2023def plot_parity (
21- results : Dict [str , Union [ list , torch . Tensor ] ],
22- targets : Dict [str , Union [ list , torch . Tensor ] ],
24+ results : Dict [str , list | Tensor ],
25+ targets : Dict [str , list | Tensor ],
2326 figname : str ,
2427 plot_type : Literal ["hexbin" , "scatter" ] = "hexbin" ,
2528 energy_key : str = "energy" ,
@@ -29,13 +32,13 @@ def plot_parity(
2932 """Perform a parity plot between the results and the targets.
3033
3134 Args:
32- results (dict) : dictionary containing the results
33- targets (dict) : dictionary containing the targets
34- figname (str) : name of the figure
35- plot_type (str) : type of plot to use, either "hexbin" or "scatter"
36- energy_key (str) : key for the energy
37- force_key (str) : key for the forces
38- units (dict) : dictionary containing the units of the keys
35+ results: dictionary containing the results
36+ targets: dictionary containing the targets
37+ figname: name of the figure
38+ plot_type: type of plot to use, either "hexbin" or "scatter"
39+ energy_key: key for the energy
40+ force_key: key for the forces
41+ units: dictionary containing the units of the keys
3942
4043 Returns:
4144 float: MAE of the energy
@@ -113,8 +116,8 @@ def plot_parity(
113116
114117
115118def plot_err_var (
116- err : Union [ torch . Tensor , np .ndarray ] ,
117- var : Union [ torch . Tensor , np .ndarray ] ,
119+ err : Tensor | np .ndarray ,
120+ var : Tensor | np .ndarray ,
118121 figname : str ,
119122 units : str = "eV/Å" ,
120123 x_min : float = 0.0 ,
@@ -128,20 +131,17 @@ def plot_err_var(
128131 """Plot the error vs variance of the forces.
129132
130133 Args:
131- err (torch.Tensor): error of the forces
132- var (torch.Tensor): variance of the forces
133- figname (str): name of the figure
134- units (str): units of the error and variance
135- x_min (float): minimum value of the x-axis
136- x_max (float): maximum value of the x-axis
137- y_min (float): minimum value of the y-axis
138- y_max (float): maximum value of the y-axis
139- sample_frac (float): fraction of the data to sample for the plot
140- num_bins (int): number of bins to use for binning
141- cb_format (str): format of the colorbar
142-
143- Returns:
144- None
134+ err: error of the forces
135+ var: variance of the forces
136+ figname: name of the figure
137+ units: units of the error and variance
138+ x_min: minimum value of the x-axis
139+ x_max: maximum value of the x-axis
140+ y_min: minimum value of the y-axis
141+ y_max: maximum value of the y-axis
142+ sample_frac: fraction of the data to sample for the plot
143+ num_bins: number of bins to use for binning
144+ cb_format: format of the colorbar
145145 """
146146 fig , ax = plt .subplots (1 , 1 , figsize = (6 , 6 ), dpi = mpl_settings .DPI )
147147
@@ -166,12 +166,7 @@ def plot_err_var(
166166 x = pd .Series (var )
167167 y = pd .Series (err )
168168
169- kernel = gaussian_kde (
170- np .vstack ([x .sample (n = len (x ), random_state = 2 ), y .sample (n = len (y ), random_state = 2 )])
171- )
172- kernel = gaussian_kde (
173- np .vstack ([x .sample (n = len (x ), random_state = 2 ), y .sample (n = len (y ), random_state = 2 )])
174- )
169+ kernel = gaussian_kde (np .vstack ([x .sample (n = len (x ), random_state = 2 ), y .sample (n = len (y ), random_state = 2 )]))
175170 c = kernel (np .vstack ([x , y ]))
176171 hb = ax .scatter (
177172 var ,
0 commit comments