1+ import abc
12import torch
23import torch .nn as nn
34import logging
45import gpytorch
56
6- import xarray as xr
7-
87from pytorch_lightning .utilities import rank_zero_only
98
10- import numpy as np
11-
129# import problems from utils
1310def get_logger (name = __name__ , level = logging .INFO ) -> logging .Logger :
1411 """Initializes multi-GPU-friendly python logger."""
@@ -93,24 +90,64 @@ def forward(self, pred, y):
9390
9491 return error
9592
96- class NRMSELoss_s_ClimateBench (nn .Module ):
93+ # Parent class with latitude weights and some other functions
class ClimateSetLoss(nn.Module):
    """Base class for the latitude-weighted losses in this module.

    Bundles the helpers the concrete losses share: building cosine latitude
    weights, averaging over the two trailing (latitude, longitude) dimensions
    and sanity-checking the layout of the input tensors.
    """

    def __init__(self):
        super().__init__()

    # weighting to account for decreasing grid-cell area towards poles
    def get_latitude_weights(self, lat_size: int) -> torch.Tensor:
        """Return cosine latitude weights for ``lat_size`` latitude points.

        The latitudes are spread evenly over [-89.75, 89.75] degrees, so the
        weights are largest (close to 1) near the equator and decrease
        towards the poles.

        Parameters:
            lat_size (int): Number of data points along the latitude axis
                (callers pass ``y.shape[-2]``).
        Returns:
            torch.Tensor: Weights of shape ``(lat_size, 1)`` so that they
                broadcast over a trailing longitude dimension. The tensor is
                created on the default device; callers move it as needed.
        """
        # Not using -90 / 90 exactly: the small offset makes sure none of the
        # data points ends up with a weight of 0.
        lats = torch.linspace(-89.75, 89.75, lat_size)
        weights = torch.cos((torch.pi * lats) / 180)  # cos(deg2rad(lats))
        return weights.unsqueeze(-1)

    def weighted_global_mean(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        """Return the latitude-weighted mean over the two trailing dimensions.

        Parameters:
            x (torch.Tensor): Values to average; the last two dimensions are
                reduced.
            weights (torch.Tensor): Latitude weights broadcastable against x.
        Returns:
            torch.Tensor: ``x * weights`` averaged over dims -2 and -1. Any
                leading dimensions of ``x`` are kept.
        """
        # sum_lat(sum_lon(x * weights)) / (N_lat * N_lon)
        # e.g. for a 96 x 144 grid: sum(sum(x * weights)) / (96 * 144)
        return torch.mean(x * weights, dim=(-2, -1))  # dims order does not matter

    def check_lat_lon(self, pred: torch.Tensor, y: torch.Tensor):
        """Check that latitude and longitude sit on the last two dimensions.

        Parameters:
            pred (torch.Tensor): Predictions
            y (torch.Tensor): Targets
        Raises:
            ValueError: If either tensor still has a trailing dimension of
                size 1, or if either tensor has more points on dim -2
                (latitude) than on dim -1 (longitude).
        """
        # Expected shape: [4, 12, 96, 144] -> [batch, time, latitude, longitude]
        # Compare the SIZE of the last dimension for both tensors; comparing a
        # whole torch.Size object to 1 would never be true.
        if (pred.shape[-1] == 1) or (y.shape[-1] == 1):
            raise ValueError("Loss function: Last dimension (values/channels) must be squeezed away")

        if (pred.shape[-1] < pred.shape[-2]) or (y.shape[-1] < y.shape[-2]):
            raise ValueError("There are more latitude than longitude grid cells. Check if you swapped longitude and latitude.")
137+
138+
139+ class NRMSELoss_s_ClimateBench (ClimateSetLoss ):
97140 """
98141 Spatial normalized weighted RMSE taken from Climate Bench.
99142 Weighting to account for decreasing grid size towards the poles.
100143 """
101-
102144 def __init__ (self ):
103145 super ().__init__ ()
104146 self .mse = nn .MSELoss (reduction = "none" )
105147
106- def forward (self , pred , y ):
107- # weighting to account for decreasing grid-cell area towards poles
108- # latitude weights
109- lat_size = y .shape [- 2 ]
110- lats = torch .linspace (- 89.75 , 89.75 , lat_size )
111- weights = torch .cos ((torch .pi * lats ) / 180 )
112-
113- weights = weights .unsqueeze (- 1 )
148+ def forward (self , pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
149+ self .check_lat_lon (pred , y )
150+ weights = self .get_latitude_weights (y .shape [- 2 ])
114151 weights = weights .to (device )
115152
116153 # nrmses = sqrt((weights * (pred_mean_t - y_mean_n_t)**2)_mean_s) / ((weights*y)_mean_s)_mean_t_n
@@ -127,28 +164,18 @@ def forward(self, pred, y):
127164
128165 return nrmse_s
129166
130- def weighted_global_mean (self , x , weights ):
131- # sum_lat(sum_lon(x * weights)) / N_lat * N_lon
132- # i.e.: sum(sum(x * weights)) / (96 * 144)
133- return torch .mean (x * weights , dim = (- 2 , - 1 )) # dims order does not matter
134-
135- class NRMSELoss_g_ClimateBench (nn .Module ):
167+ class NRMSELoss_g_ClimateBench (ClimateSetLoss ):
136168 """
137169 Spatial normalized weighted RMSE taken from Climate Bench.
138170 Weighting to account for decreasing grid size towards the pole.
139171 """
140-
141172 def __init__ (self ):
142173 super ().__init__ ()
143174 self .mse = nn .MSELoss (reduction = "none" )
144175
145- def forward (self , pred , y ):
146- #latitude weighting to account for decreasing grid-cell area towards pole
147- lat_size = y .shape [- 2 ]
148- lats = torch .linspace (- 89.75 , 89.75 , lat_size )
149- # same like np.cos(np.deg2rad(lats))
150- weights = torch .cos ((torch .pi * lats ) / 180 )
151- weights = weights .unsqueeze (- 1 )
176+ def forward (self , pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
177+ self .check_lat_lon (pred , y )
178+ weights = self .get_latitude_weights (y .shape [- 2 ])
152179 weights = weights .to (device )
153180
154181 # nrmseg = sqrt(
@@ -164,83 +191,67 @@ def forward(self, pred, y):
164191 )
165192 / self .weighted_global_mean (y , weights ).mean (dim = (0 , 1 ))
166193 )
167- # TODO understand: the values are in the same range like nrmse_s - why do we need to adapt them?
168194 return nrmse_g
169195
170- def weighted_global_mean (self , x , weights ):
171- # sum_lat(sum_lon(x * weights)) / N_lat * N_lon
172- # i.e.: sum(sum(x * weights)) / (144 * 96)
173- return torch .mean (x * weights , dim = (- 2 , - 1 ))
174-
175- class NRMSELoss_ClimateBench (nn .Module ):
class NRMSELoss_ClimateBench(ClimateSetLoss):
    """
    Total ClimateBench NRMSE: the spatial NRMSE plus ``alpha`` times the
    global NRMSE.
    """

    def __init__(self, alpha: int = 5):
        super().__init__()

        # the two partial losses that get combined in forward()
        self.nrmse_g = NRMSELoss_g_ClimateBench()
        self.nrmse_s = NRMSELoss_s_ClimateBench()
        # factor applied to the global term before adding it to the spatial term
        self.alpha = alpha

    def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return ``nrmse_s(pred, y) + alpha * nrmse_g(pred, y)``.

        Parameters:
            pred (torch.Tensor): Predictions
            y (torch.Tensor): Targets
        Raises:
            ValueError: Propagated from ``check_lat_lon`` when the inputs do
                not pass the layout check.
        """
        self.check_lat_lon(pred, y)
        global_term = self.nrmse_g(pred, y)
        spatial_term = self.nrmse_s(pred, y)
        return spatial_term + self.alpha * global_term
193213
194- class LLWeighted_RMSELoss_WeatherBench (nn .Module ):
195-
class LLWeighted_RMSELoss_WeatherBench(ClimateSetLoss):
    """
    Weighted RMSE taken from Weather Bench.
    Weighting to account for decreasing grid sizes towards the pole.

    rmse = mean over all leading dims of sqrt( mean over (lat, lon) of w(lat_j) * (pred - y)**2 )
    w(lat_j) = cos(latitude_j), as returned by ``get_latitude_weights``
    (NOTE: the weights are not divided by their mean in this implementation).
    """

    def __init__(self):
        super().__init__()

        # Kept as an attribute; forward() computes the squared error directly.
        self.mse = nn.MSELoss(reduction="none")

    def forward(self, pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Compute the latitude-weighted RMSE between ``pred`` and ``y``.

        Parameters:
            pred (torch.Tensor): Predictions, latitude on dim -2 and
                longitude on dim -1.
            y (torch.Tensor): Targets with the same layout as ``pred``.
        Returns:
            torch.Tensor: Scalar tensor holding the loss.
        Raises:
            ValueError: Propagated from ``check_lat_lon`` when the inputs do
                not pass the layout check.
        """
        self.check_lat_lon(pred, y)
        weights = self.get_latitude_weights(y.shape[-2])
        # move the weights to the module-level `device` exactly once
        weights = weights.to(device)

        # The square root is taken per entry of the leading dims (after the
        # weighted mean over lat/lon); those RMSEs are then averaged.
        rmse = torch.mean(torch.sqrt(torch.mean(weights * ((pred - y) ** 2), dim=(-2, -1))))

        return rmse
221237
222- class LLweighted_MSELoss_Climax (nn . Module ):
238+ class LLweighted_MSELoss_Climax (ClimateSetLoss ):
223239 """
224240 Latitude weighted mean squared error taken from ClimaX.
225241 Allows to weight the loss by the cosine of the latitude to account for gridding differences at equator vs. poles.
226242 Applied per variable.
227243 If given a mask, normalized by sum of that.
228-
229244 """
230-
231245 def __init__ (self , mask = None ):
232246 super ().__init__ ()
233247
234248 self .mse = nn .MSELoss (reduction = "none" )
235249 self .mask = mask
236250
237- def forward (self , pred , y ):
251+ def forward (self , pred : torch .Tensor , y : torch .Tensor ) -> torch .Tensor :
252+ self .check_lat_lon (pred , y )
238253 mse = self .mse (pred , y )
239-
240- lat_size = y .shape [- 2 ]
241- lats = torch .linspace (- 89.75 , 89.75 , lat_size )
242- weights = torch .cos ((torch .pi * lats ) / 180 )
243- weights = weights .unsqueeze (- 1 )
254+ weights = self .get_latitude_weights (y .shape [- 2 ])
244255 weights = weights .to (device )
245256
246257 # how they create the weights (does not work for us, results make no sense)
@@ -261,46 +272,32 @@ def forward(self, pred, y):
261272 return error
262273
263274
264- class LLweighted_RMSELoss_Climax (nn . Module ):
275+ class LLweighted_RMSELoss_Climax (ClimateSetLoss ):
265276 """
266277 Latitude weighted root mean squared error taken from ClimaX.
267278 Allows to weight the loss by the cosine of the latitude to account for gridding differences at equator vs. poles.
268279 Applied per variable.
269280 If given a mask, normalized by sum of that.
270281 """
271-
272282 def __init__ (self , mask = None ):
273283 super ().__init__ ()
274284
275285 self .mse = nn .MSELoss (reduction = "none" )
276286 self .mask = mask
277287
278- def forward (self , pred , y ) :
288+ def forward (self , pred : torch . Tensor , y : torch . Tensor ) -> torch . Tensor :
279289 """ Latitude is expected to be on position -2
280290 """
281- lat_num_grid_cells = y .shape [- 2 ]
282-
283- # Expected shape: [4, 12, 96, 144] -> [batch, time, latitude, longitude]
284- if (pred .shape [- 1 ] == 1 ) or (y [- 1 ].shape == 1 ):
285- raise ValueError ("Loss function: Last dimension (values/channels) must be squeezed away" )
286-
287- if (pred .shape [- 1 ] < pred .shape [- 2 ]):
288- raise ValueError ("There are more latitude than longitude grid cells. Check if you swapped longitude and latitude." )
289-
291+ self .check_lat_lon (pred , y )
290292 mse = self .mse (pred , y ) # [batch, time, lat, lon]
291293
292- latitudes = torch .linspace (- 89.75 , 89.75 , lat_num_grid_cells )
293- # torch.abs: -90 and + 90 get -0.000X as weight -> make all weights positive
294- weights = torch .abs (torch .cos (torch .deg2rad (latitudes )))
295- weights = weights .unsqueeze (- 1 )
294+ weights = self .get_latitude_weights (y .shape [- 2 ])
295+ weights = weights .to (device )
296296
297297 # ClimaX creates weird weights by dividing them by the mean
298298 # this leads to the climax rmse and mse to be the exact same like the unweighted mse / rmse
299299 #weights = weights / weights.mean() # ignored in this code
300300
301- # move weights to device
302- weights = weights .to (device )
303-
304301 if self .mask is not None :
305302 raise NotImplementedError ("Masking is not supported in the loss functions anymore." )
306303
@@ -335,8 +332,7 @@ def forward(self, pred, y):
335332
336333 llmse_cx = LLweighted_MSELoss_Climax ()
337334 llrmse_cx = LLweighted_RMSELoss_Climax ()
338-
339- # MSE: CHECKED
335+
340336 loss = mse (dummy , targets )
341337 print ("MSE loss" , loss , loss .size ())
342338 # np_dummy = dummy.cpu().detach().numpy()
@@ -363,17 +359,9 @@ def forward(self, pred, y):
363359 loss = llrmse_cx (dummy , targets )
364360 print ("CX rmse loss" , loss , loss .size ())
365361
366- # TESTS for losses:
367- # - with specific tensor of 1s + offset
368- # - with specific random tensors
369- # - make sure output size is only one number (except if several channels?)
370-
371- # - compare losses for ones: rmse == nrmse_g
372- # - with channels for different variables (make sure it's not breaking) / doing whatever is needed
373- # - make sure WB and CX rmse losses are the same
374362
375363# REFACTOR
376- # - weight function should be one function (utils)
364+ # - weight function should be one function
377365
378366# Same tests needed for metrics
379367
0 commit comments