Skip to content

Commit 9eba59f

Browse files
committed
refactor losses
1 parent 99e371d commit 9eba59f

2 files changed

Lines changed: 81 additions & 90 deletions

File tree

emulator/src/core/losses.py

Lines changed: 74 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
1+
import abc
12
import torch
23
import torch.nn as nn
34
import logging
45
import gpytorch
56

6-
import xarray as xr
7-
87
from pytorch_lightning.utilities import rank_zero_only
98

10-
import numpy as np
11-
129
# import problems from utils
1310
def get_logger(name=__name__, level=logging.INFO) -> logging.Logger:
1411
"""Initializes multi-GPU-friendly python logger."""
@@ -93,24 +90,64 @@ def forward(self, pred, y):
9390

9491
return error
9592

96-
class NRMSELoss_s_ClimateBench(nn.Module):
93+
# Parent class with latitude weights and some other functions
94+
class ClimateSetLoss(nn.Module):
95+
def __init__(self):
96+
super().__init__()
97+
# weighting to account for decreasing grid-cell area towards poles
98+
def get_latitude_weights(self, lat_size: int) -> torch.Tensor:
99+
""" Returns latitude weights for a given number of data points along the latitude axis (y.shape[-2]).
100+
The weights are 1 at the equator and decrease towards the poles.
101+
102+
Parameters:
103+
lat_size (int): How many latitude data points should be considered
104+
Returns:
105+
torch.Tensor: Latitude weights
106+
"""
107+
# we are not using -90, 90, but a small offset to make sure none of the datapoints get weights of 0
108+
lats = torch.linspace(-89.75, 89.75, lat_size)
109+
weights = torch.cos((torch.pi * lats) / 180)
110+
weights = weights.unsqueeze(-1)
111+
return weights
112+
113+
def weighted_global_mean(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
114+
""" Get a weighted global mean of a tensor.
115+
Parameters:
116+
x (torch.Tensor): Values that are to be averaged
117+
weights (torch.Tensor): Latitude weights for x
118+
Returns:
119+
torch.Tensor: Single number representing the global mean
120+
"""
121+
# sum_lat(sum_lon(x * weights)) / N_lat * N_lon
122+
# i.e.: sum(sum(x * weights)) / (96 * 144)
123+
return torch.mean(x * weights, dim=(-2, -1)) # dims order does not matter
124+
125+
def check_lat_lon(self, pred: torch.Tensor, y: torch.Tensor):
126+
""" Function that checks whether latitude and longitude are behaving as expected.
127+
Parameters:
128+
pred (torch.Tensor): Predictions
129+
y (torch.Tensor): Targets
130+
"""
131+
# Expected shape: [4, 12, 96, 144] -> [batch, time, latitude, longitude]
132+
if (pred.shape[-1] == 1) or (y[-1].shape == 1):
133+
raise ValueError("Loss function: Last dimension (values/channels) must be squeezed away")
134+
135+
if (pred.shape[-1] < pred.shape[-2]) or (y.shape[-1] < y.shape[-2]):
136+
raise ValueError("There are more latitude than longitude grid cells. Check if you swapped longitude and latitude.")
137+
138+
139+
class NRMSELoss_s_ClimateBench(ClimateSetLoss):
97140
"""
98141
Spatial normalized weighted RMSE taken from Climate Bench.
99142
Weighting to account for decreasing grid size towards the poles.
100143
"""
101-
102144
def __init__(self):
103145
super().__init__()
104146
self.mse = nn.MSELoss(reduction="none")
105147

106-
def forward(self, pred, y):
107-
# weighting to account for decreasing grid-cell area towards poles
108-
# latitude weights
109-
lat_size = y.shape[-2]
110-
lats = torch.linspace(-89.75, 89.75, lat_size)
111-
weights = torch.cos((torch.pi * lats) / 180)
112-
113-
weights = weights.unsqueeze(-1)
148+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
149+
self.check_lat_lon(pred, y)
150+
weights = self.get_latitude_weights(y.shape[-2])
114151
weights = weights.to(device)
115152

116153
# nrmses = sqrt((weights * (pred_mean_t - y_mean_n_t)**2)_mean_s) / ((weights*y)_mean_s)_mean_t_n
@@ -127,28 +164,18 @@ def forward(self, pred, y):
127164

128165
return nrmse_s
129166

130-
def weighted_global_mean(self, x, weights):
131-
# sum_lat(sum_lon(x * weights)) / N_lat * N_lon
132-
# i.e.: sum(sum(x * weights)) / (96 * 144)
133-
return torch.mean(x * weights, dim=(-2, -1)) # dims order does not matter
134-
135-
class NRMSELoss_g_ClimateBench(nn.Module):
167+
class NRMSELoss_g_ClimateBench(ClimateSetLoss):
136168
"""
137169
Spatial normalized weighted RMSE taken from Climate Bench.
138170
Weighting to account for decreasing grid size towards the pole.
139171
"""
140-
141172
def __init__(self):
142173
super().__init__()
143174
self.mse = nn.MSELoss(reduction="none")
144175

145-
def forward(self, pred, y):
146-
#latitude weighting to account for decreasing grid-cell area towards pole
147-
lat_size = y.shape[-2]
148-
lats = torch.linspace(-89.75, 89.75, lat_size)
149-
# same like np.cos(np.deg2rad(lats))
150-
weights = torch.cos((torch.pi * lats) / 180)
151-
weights = weights.unsqueeze(-1)
176+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
177+
self.check_lat_lon(pred, y)
178+
weights = self.get_latitude_weights(y.shape[-2])
152179
weights = weights.to(device)
153180

154181
# nrmseg = sqrt(
@@ -164,83 +191,67 @@ def forward(self, pred, y):
164191
)
165192
/ self.weighted_global_mean(y, weights).mean(dim=(0, 1))
166193
)
167-
# TODO understand: the values are in the same range like nrmse_s - why do we need to adapt them?
168194
return nrmse_g
169195

170-
def weighted_global_mean(self, x, weights):
171-
# sum_lat(sum_lon(x * weights)) / N_lat * N_lon
172-
# i.e.: sum(sum(x * weights)) / (144 * 96)
173-
return torch.mean(x * weights, dim=(-2, -1))
174-
175-
class NRMSELoss_ClimateBench(nn.Module):
196+
class NRMSELoss_ClimateBench(ClimateSetLoss):
176197
"""
177198
Combination of global weighted and spatially weighted nrmse.
178-
179199
"""
180-
181200
def __init__(self, alpha: int = 5):
182201
super().__init__()
183202

184203
self.nrmse_g = NRMSELoss_g_ClimateBench()
185204
self.nrmse_s = NRMSELoss_s_ClimateBench()
186205
self.alpha = alpha
187206

188-
def forward(self, pred, y):
207+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
208+
self.check_lat_lon(pred, y)
189209
nrmseg = self.nrmse_g(pred, y)
190210
nrmses = self.nrmse_s(pred, y)
191211
nrmse = nrmses + self.alpha * nrmseg
192212
return nrmse
193213

194-
class LLWeighted_RMSELoss_WeatherBench(nn.Module):
195-
214+
class LLWeighted_RMSELoss_WeatherBench(ClimateSetLoss):
196215
"""
197216
Weighted RMSE taken from Weather Bench.
198217
Weighting to account for decreasing grid sizes towards the pole.
199218
200219
rmse = mean over forecasts and time of torch.sqrt(mean over lon/lat of L(lat_j) * MSE(pred, y))
201220
weights = cos(latitude)/cos(latitude).mean()
202221
"""
203-
204222
def __init__(self):
205223
super().__init__()
206224

207225
self.mse = nn.MSELoss(reduction="none")
208226

209-
def forward(self, pred, y):
210-
211-
lat_size = y.shape[-2]
212-
lats = torch.linspace(-89.75, 89.75, lat_size)
213-
weights = torch.cos((torch.pi * lats) / 180)
214-
weights = weights.unsqueeze(-1)
227+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
228+
self.check_lat_lon(pred, y)
229+
weights = self.get_latitude_weights(y.shape[-2])
230+
weights = weights.to(device)
215231
weights = weights.to(device)
216232

217233
#rmse_before = torch.sqrt(torch.mean(weights * self.mse(pred, y), dim=(-2, -1))).mean()
218234
rmse = torch.mean(torch.sqrt(torch.mean(weights * ((pred - y)**2), dim=([-2, -1]))))
219235

220236
return rmse
221237

222-
class LLweighted_MSELoss_Climax(nn.Module):
238+
class LLweighted_MSELoss_Climax(ClimateSetLoss):
223239
"""
224240
Latitude weighted mean squared error taken from ClimaX.
225241
Allows weighting the loss by the cosine of the latitude to account for gridding differences at the equator vs. the poles.
226242
Applied per variable.
227243
If given a mask, normalized by sum of that.
228-
229244
"""
230-
231245
def __init__(self, mask=None):
232246
super().__init__()
233247

234248
self.mse = nn.MSELoss(reduction="none")
235249
self.mask = mask
236250

237-
def forward(self, pred, y):
251+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
252+
self.check_lat_lon(pred, y)
238253
mse = self.mse(pred, y)
239-
240-
lat_size = y.shape[-2]
241-
lats = torch.linspace(-89.75, 89.75, lat_size)
242-
weights = torch.cos((torch.pi * lats) / 180)
243-
weights = weights.unsqueeze(-1)
254+
weights = self.get_latitude_weights(y.shape[-2])
244255
weights = weights.to(device)
245256

246257
# how they create the weights (does not work for us, results make no sense)
@@ -261,46 +272,32 @@ def forward(self, pred, y):
261272
return error
262273

263274

264-
class LLweighted_RMSELoss_Climax(nn.Module):
275+
class LLweighted_RMSELoss_Climax(ClimateSetLoss):
265276
"""
266277
Latitude weighted root mean squared error taken from ClimaX.
267278
Allows weighting the loss by the cosine of the latitude to account for gridding differences at the equator vs. the poles.
268279
Applied per variable.
269280
If given a mask, normalized by sum of that.
270281
"""
271-
272282
def __init__(self, mask=None):
273283
super().__init__()
274284

275285
self.mse = nn.MSELoss(reduction="none")
276286
self.mask = mask
277287

278-
def forward(self, pred, y):
288+
def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
279289
""" Latitude is expected to be on position -2
280290
"""
281-
lat_num_grid_cells = y.shape[-2]
282-
283-
# Expected shape: [4, 12, 96, 144] -> [batch, time, latitude, longitude]
284-
if (pred.shape[-1] == 1) or (y[-1].shape == 1):
285-
raise ValueError("Loss function: Last dimension (values/channels) must be squeezed away")
286-
287-
if (pred.shape[-1] < pred.shape[-2]):
288-
raise ValueError("There are more latitude than longitude grid cells. Check if you swapped longitude and latitude.")
289-
291+
self.check_lat_lon(pred, y)
290292
mse = self.mse(pred, y) # [batch, time, lat, lon]
291293

292-
latitudes = torch.linspace(-89.75, 89.75, lat_num_grid_cells)
293-
# torch.abs: -90 and + 90 get -0.000X as weight -> make all weights positive
294-
weights = torch.abs(torch.cos(torch.deg2rad(latitudes)))
295-
weights = weights.unsqueeze(-1)
294+
weights = self.get_latitude_weights(y.shape[-2])
295+
weights = weights.to(device)
296296

297297
# ClimaX creates weird weights by dividing them by the mean
298298
# this leads to the climax rmse and mse to be the exact same like the unweighted mse / rmse
299299
#weights = weights / weights.mean() # ignored in this code
300300

301-
# move weights to device
302-
weights = weights.to(device)
303-
304301
if self.mask is not None:
305302
raise NotImplementedError("Masking is not supported in the loss functions anymore.")
306303

@@ -335,8 +332,7 @@ def forward(self, pred, y):
335332

336333
llmse_cx = LLweighted_MSELoss_Climax()
337334
llrmse_cx = LLweighted_RMSELoss_Climax()
338-
339-
# MSE: CHECKED
335+
340336
loss = mse(dummy, targets)
341337
print("MSE loss", loss, loss.size())
342338
# np_dummy = dummy.cpu().detach().numpy()
@@ -363,17 +359,9 @@ def forward(self, pred, y):
363359
loss = llrmse_cx(dummy, targets)
364360
print("CX rmse loss", loss, loss.size())
365361

366-
# TESTS for losses:
367-
# - with specific tensor of 1s + offset
368-
# - with specific random tensors
369-
# - make sure output size is only one number (except if several channels?)
370-
371-
# - compare losses for ones: rmse == nrmse_g
372-
# - with channels for different variables (make sure it's not breaking) / doing whatever is needed
373-
# - make sure WB and CX rmse losses are the same
374362

375363
# REFACTOR
376-
# - weight function should be one function (utils)
364+
# - weight function should be one function
377365

378366
# Same tests needed for metrics
379367

tests/test_scoring/test_losses.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222

2323
PRECISION_VALUE = 0.0005
2424

25+
# TODO add channel issue
26+
# TODO test weighting function
27+
# TODO test if numpy weights and torch weights are the same!
28+
2529
@pytest.fixture
2630
def rand_targets():
2731
return torch.rand(size=(batch_size, out_time, lat, lon))
@@ -87,7 +91,7 @@ def test_nrmse(rand_predics, rand_targets, ones_predics, ones_targets):
8791
expected_loss(loss_rand, 0.1142, 0.005)
8892

8993
def test_wb_rmse(rand_predics, rand_targets, ones_predics, ones_targets):
90-
error = LLWeighted_RMSELoss_WheatherBench()
94+
error = LLWeighted_RMSELoss_WeatherBench()
9195
loss_ones = error(ones_predics, ones_targets)
9296
loss_rand = error(rand_predics, rand_targets)
9397
expected_loss(loss_ones, 0.0795, PRECISION_VALUE)
@@ -112,7 +116,7 @@ def test_equality_ones(ones_predics, ones_targets):
112116
nrmse_s = NRMSELoss_s_ClimateBench()
113117
nrmse_g = NRMSELoss_g_ClimateBench()
114118
nrmse = NRMSELoss_ClimateBench()
115-
wb_rmse = LLWeighted_RMSELoss_WheatherBench()
119+
wb_rmse = LLWeighted_RMSELoss_WeatherBench()
116120
cx_rmse = LLweighted_RMSELoss_Climax()
117121

118122
rmse_loss_ones = rmse(ones_predics, ones_targets)
@@ -127,7 +131,7 @@ def test_equality_ones(ones_predics, ones_targets):
127131
assert (nrmse_s_loss_ones.item() + 5 * nrmse_g_loss_ones.item()) == pytest.approx(nrmse_loss_ones.item(), abs=PRECISION_VALUE)
128132

129133
def test_equality_rand(rand_predics, rand_targets):
130-
wb_rmse = LLWeighted_RMSELoss_WheatherBench()
134+
wb_rmse = LLWeighted_RMSELoss_WeatherBench()
131135
cx_rmse = LLweighted_RMSELoss_Climax()
132136
nrmse_s = NRMSELoss_s_ClimateBench()
133137
nrmse_g = NRMSELoss_g_ClimateBench()
@@ -143,4 +147,3 @@ def test_equality_rand(rand_predics, rand_targets):
143147
assert (nrmse_s_loss_rand.item() + 5 * nrmse_g_loss_rand.item()) == pytest.approx(nrmse_loss_rand.item(), abs=PRECISION_VALUE)
144148

145149

146-
# TODO add channel issue

0 commit comments

Comments
 (0)