Skip to content
This repository was archived by the owner on Jun 2, 2023. It is now read-only.

Commit 4f1500a

Browse files
authored
Merge pull request #201 from SimonTopp/main
Specify train sites to remove from test metrics
2 parents b25a45e + 8102332 commit 4f1500a

2 files changed

Lines changed: 14 additions & 6 deletions

File tree

river_dl/evaluate.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def partition_metrics(
201201
outfile=None,
202202
val_sites=None,
203203
test_sites=None,
204-
204+
train_sites=None,
205205
):
206206
"""
207207
calculate metrics for a certain group (or no group at all) for a given
@@ -222,8 +222,9 @@ def partition_metrics(
222222
names and dict values are the id values. These are added as columns to the
223223
metrics information
224224
:param outfile: [str] file where the metrics should be written
225-
:param val_sites: [list] sites to exclude from training metrics
225+
:param val_sites: [list] sites to exclude from training and test metrics
226226
:param test_sites: [list] sites to exclude from validation and training metrics
227+
:param train_sites: [list] sites to exclude from test metrics
227228
:return: [pd dataframe] the condensed metrics
228229
"""
229230
var_data = fmt_preds_obs(preds, obs_file, spatial_idx_name,
@@ -240,6 +241,10 @@ def partition_metrics(
240241
# mask out test sites from val partition
241242
if test_sites and partition=='val':
242243
data = data[~data[spatial_idx_name].isin(test_sites)]
244+
if train_sites and partition=='tst':
245+
data = data[~data[spatial_idx_name].isin(train_sites)]
246+
if val_sites and partition=='tst':
247+
data = data[~data[spatial_idx_name].isin(val_sites)]
243248

244249
if not group:
245250
metrics = calc_metrics(data)
@@ -286,6 +291,7 @@ def combined_metrics(
286291
pred_tst=None,
287292
val_sites=None,
288293
test_sites=None,
294+
train_sites=None,
289295
spatial_idx_name="seg_id_nat",
290296
time_idx_name="date",
291297
group=None,
@@ -349,7 +355,8 @@ def combined_metrics(
349355
id_dict=id_dict,
350356
group=group,
351357
val_sites = val_sites,
352-
test_sites = test_sites)
358+
test_sites = test_sites,
359+
train_sites=train_sites)
353360
df_all.extend([metrics])
354361

355362
df_all = pd.concat(df_all, axis=0)

river_dl/torch_utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,19 +238,20 @@ def predict_torch(x_data, model, batch_size):
238238
@param device: [str] cuda or cpu
239239
@return: [tensor] predicted values
240240
"""
241-
device = next(model.parameters()).device
241+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
242+
243+
model.to(device)
242244
data = []
243245
for i in range(len(x_data)):
244246
data.append(torch.from_numpy(x_data[i]).float())
245247

246248
dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, pin_memory=True)
247-
model.to(device)
248249
model.eval()
249250
predicted = []
250251
for iter, x in enumerate(dataloader):
251252
trainx = x.to(device)
252253
with torch.no_grad():
253-
output = model(trainx.to(device)).cpu()
254+
output = model(trainx).detach().cpu()
254255
predicted.append(output)
255256
predicted = torch.cat(predicted, dim=0)
256257
return predicted

0 commit comments

Comments
 (0)