Skip to content

Commit 31fbcfc

Browse files
committed
MAINT & STY: update based on PR #34 comments and style fixes
1 parent 5fe46a2 commit 31fbcfc

8 files changed

Lines changed: 158 additions & 251 deletions

File tree

nff/analysis/loss_plot.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,38 +4,35 @@
44
from . import mpl_settings
55

66

7-
def plot_loss(energy_history, forces_history, figname, train_key="train", val_key="val"):
7+
def plot_loss(
8+
energy_history: dict,
9+
forces_history: dict,
10+
figname: str,
11+
train_key: str = "train",
12+
val_key: str = "val",
13+
) -> None:
814
"""Plot the loss history of the model.
9-
Args:
10-
energy_history (dict): energy loss history of the model for training and validation
11-
forces_history (dict): forces loss history of the model for training and validation
12-
figname (str): name of the figure
13-
1415
15-
Returns:
16-
None
16+
Args:
17+
energy_history: energy loss history of the model for training and validation
18+
forces_history: forces loss history of the model for training and validation
19+
figname: name of the figure
20+
train_key: key for training data in the history dictionary
21+
val_key: key for validation data in the history dictionary
1722
"""
1823
epochs = np.arange(1, len(energy_history[train_key]) + 1)
1924
fig, ax_fig = plt.subplots(1, 2, figsize=(5, 2.5), dpi=mpl_settings.DPI)
20-
ax_fig[0].semilogy(
21-
epochs, energy_history[train_key], label="train", color=mpl_settings.colors[1]
22-
)
25+
ax_fig[0].semilogy(epochs, energy_history[train_key], label="train", color=mpl_settings.colors[1])
2326
ax_fig[0].semilogy(epochs, energy_history[val_key], label="val", color=mpl_settings.colors[2])
2427
ax_fig[0].legend()
2528
ax_fig[0].set_xlabel("Epoch")
2629
ax_fig[0].set_ylabel("Loss")
27-
ax_fig[0].set_xlabel("Epoch")
28-
ax_fig[0].set_ylabel("Loss")
2930

30-
ax_fig[1].semilogy(
31-
epochs, forces_history[train_key], label="train", color=mpl_settings.colors[1]
32-
)
31+
ax_fig[1].semilogy(epochs, forces_history[train_key], label="train", color=mpl_settings.colors[1])
3332
ax_fig[1].semilogy(epochs, forces_history[val_key], label="val", color=mpl_settings.colors[2])
3433
ax_fig[1].legend()
3534
ax_fig[1].set_xlabel("Epoch")
3635
ax_fig[1].set_ylabel("Loss")
37-
ax_fig[1].set_xlabel("Epoch")
38-
ax_fig[1].set_ylabel("Loss")
3936

4037
plt.tight_layout()
4138
plt.savefig(f"{figname}.png")

nff/analysis/mpl_settings.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
from __future__ import annotations
2+
13
import json
24
from pathlib import Path
3-
from typing import List, Optional
5+
from typing import List
46

57
import matplotlib as mpl
68
import matplotlib.pyplot as plt
@@ -61,11 +63,11 @@
6163
plt.rcParams.update(custom_settings)
6264

6365

64-
def update_custom_settings(custom_settings: dict = custom_settings) -> None:
66+
def update_custom_settings(custom_settings: dict | None = custom_settings) -> None:
6567
"""Update the custom settings for Matplotlib.
6668
6769
Args:
68-
custom_settings (dict, optional): Custom settings for Matplotlib. Defaults to
70+
custom_settings: Custom settings for Matplotlib. Defaults to
6971
custom_settings.
7072
"""
7173
current_settings = plt.rcParams.copy()
@@ -77,10 +79,7 @@ def hex_to_rgb(value: str) -> list[float]:
7779
"""Converts hex to rgb colors.
7880
7981
Args:
80-
value (str): string of 6 characters representing a hex colour.
81-
82-
Returns:
83-
list: length 3 of RGB values
82+
value: string of 6 characters representing a hex color.
8483
"""
8584
value = value.strip("#") # removes hash symbol if present
8685
lv = len(value)
@@ -91,7 +90,7 @@ def rgb_to_dec(value: list[float]) -> list[float]:
9190
"""Converts rgb to decimal colors (i.e. divides each value by 256).
9291
9392
Args:
94-
value (list[float]): string of 6 characters representing a hex colour.
93+
value: string of 6 characters representing a hex color.
9594
9695
Returns:
9796
list: length 3 of RGB values
@@ -107,12 +106,9 @@ def get_continuous_cmap(
107106
each color in hex_list is mapped to the respective location in float_list.
108107
109108
Args:
110-
hex_list (list[str]): list of hex code strings
111-
float_list (list[float]): list of floats between 0 and 1, same length as hex_list. Must
112-
start with 0 and end with 1.
113-
114-
Returns:
115-
matplotlib.colors.LinearSegmentedColormap: continuous
109+
hex_list: list of hex code strings
110+
float_list: list of floats between 0 and 1, same length as hex_list.
111+
Must start with 0 and end with 1.
116112
"""
117113
rgb_list = [rgb_to_dec(hex_to_rgb(i)) for i in hex_list]
118114
if float_list:
@@ -122,9 +118,7 @@ def get_continuous_cmap(
122118

123119
cdict = dict()
124120
for num, col in enumerate(["red", "green", "blue"]):
125-
col_list = [
126-
[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))
127-
]
121+
col_list = [[float_list[i], rgb_list[i][num], rgb_list[i][num]] for i in range(len(float_list))]
128122
cdict[col] = col_list
129123
return mpl.colors.LinearSegmentedColormap("j_cmap", segmentdata=cdict, N=256)
130124

nff/analysis/parity_plot.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
1-
from typing import Dict, Literal, Union
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Dict, Literal
24

35
import matplotlib.pyplot as plt
46
import numpy as np
57
import pandas as pd
6-
import torch
78
from matplotlib.lines import Line2D
89
from scipy import stats
910
from scipy.stats import gaussian_kde
1011

12+
if TYPE_CHECKING:
13+
from torch import Tensor
1114
from nff.data import to_tensor
1215
from nff.utils import cuda
1316

@@ -18,8 +21,8 @@
1821

1922

2023
def plot_parity(
21-
results: Dict[str, Union[list, torch.Tensor]],
22-
targets: Dict[str, Union[list, torch.Tensor]],
24+
results: Dict[str, list | Tensor],
25+
targets: Dict[str, list | Tensor],
2326
figname: str,
2427
plot_type: Literal["hexbin", "scatter"] = "hexbin",
2528
energy_key: str = "energy",
@@ -29,13 +32,13 @@ def plot_parity(
2932
"""Perform a parity plot between the results and the targets.
3033
3134
Args:
32-
results (dict): dictionary containing the results
33-
targets (dict): dictionary containing the targets
34-
figname (str): name of the figure
35-
plot_type (str): type of plot to use, either "hexbin" or "scatter"
36-
energy_key (str): key for the energy
37-
force_key (str): key for the forces
38-
units (dict): dictionary containing the units of the keys
35+
results: dictionary containing the results
36+
targets: dictionary containing the targets
37+
figname: name of the figure
38+
plot_type: type of plot to use, either "hexbin" or "scatter"
39+
energy_key: key for the energy
40+
force_key: key for the forces
41+
units: dictionary containing the units of the keys
3942
4043
Returns:
4144
float: MAE of the energy
@@ -113,8 +116,8 @@ def plot_parity(
113116

114117

115118
def plot_err_var(
116-
err: Union[torch.Tensor, np.ndarray],
117-
var: Union[torch.Tensor, np.ndarray],
119+
err: Tensor | np.ndarray,
120+
var: Tensor | np.ndarray,
118121
figname: str,
119122
units: str = "eV/Å",
120123
x_min: float = 0.0,
@@ -128,20 +131,17 @@ def plot_err_var(
128131
"""Plot the error vs variance of the forces.
129132
130133
Args:
131-
err (torch.Tensor): error of the forces
132-
var (torch.Tensor): variance of the forces
133-
figname (str): name of the figure
134-
units (str): units of the error and variance
135-
x_min (float): minimum value of the x-axis
136-
x_max (float): maximum value of the x-axis
137-
y_min (float): minimum value of the y-axis
138-
y_max (float): maximum value of the y-axis
139-
sample_frac (float): fraction of the data to sample for the plot
140-
num_bins (int): number of bins to use for binning
141-
cb_format (str): format of the colorbar
142-
143-
Returns:
144-
None
134+
err: error of the forces
135+
var: variance of the forces
136+
figname: name of the figure
137+
units: units of the error and variance
138+
x_min: minimum value of the x-axis
139+
x_max: maximum value of the x-axis
140+
y_min: minimum value of the y-axis
141+
y_max: maximum value of the y-axis
142+
sample_frac: fraction of the data to sample for the plot
143+
num_bins: number of bins to use for binning
144+
cb_format: format of the colorbar
145145
"""
146146
fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=mpl_settings.DPI)
147147

@@ -166,12 +166,7 @@ def plot_err_var(
166166
x = pd.Series(var)
167167
y = pd.Series(err)
168168

169-
kernel = gaussian_kde(
170-
np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
171-
)
172-
kernel = gaussian_kde(
173-
np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)])
174-
)
169+
kernel = gaussian_kde(np.vstack([x.sample(n=len(x), random_state=2), y.sample(n=len(y), random_state=2)]))
175170
c = kernel(np.vstack([x, y]))
176171
hb = ax.scatter(
177172
var,

0 commit comments

Comments
 (0)