@@ -201,7 +201,7 @@ def partition_metrics(
201201 outfile = None ,
202202 val_sites = None ,
203203 test_sites = None ,
204-
204+ train_sites = None ,
205205):
206206 """
207207 calculate metrics for a certain group (or no group at all) for a given
@@ -222,8 +222,9 @@ def partition_metrics(
222222 names and dict values are the id values. These are added as columns to the
223223 metrics information
224224 :param outfile: [str] file where the metrics should be written
225- :param val_sites: [list] sites to exclude from training metrics
225+ :param val_sites: [list] sites to exclude from training and test metrics
226226 :param test_sites: [list] sites to exclude from validation and training metrics
227+ :param train_sites: [list] sites to exclude from test metrics
227228 :return: [pd dataframe] the condensed metrics
228229 """
229230 var_data = fmt_preds_obs (preds , obs_file , spatial_idx_name ,
@@ -240,6 +241,10 @@ def partition_metrics(
240241 # mask out test sites from val partition
241242 if test_sites and partition == 'val' :
242243 data = data [~ data [spatial_idx_name ].isin (test_sites )]
244+ if train_sites and partition == 'tst' :
245+ data = data [~ data [spatial_idx_name ].isin (train_sites )]
246+ if val_sites and partition == 'tst' :
247+ data = data [~ data [spatial_idx_name ].isin (val_sites )]
243248
244249 if not group :
245250 metrics = calc_metrics (data )
@@ -286,6 +291,7 @@ def combined_metrics(
286291 pred_tst = None ,
287292 val_sites = None ,
288293 test_sites = None ,
294+ train_sites = None ,
289295 spatial_idx_name = "seg_id_nat" ,
290296 time_idx_name = "date" ,
291297 group = None ,
@@ -349,7 +355,8 @@ def combined_metrics(
349355 id_dict = id_dict ,
350356 group = group ,
351357 val_sites = val_sites ,
352- test_sites = test_sites )
358+ test_sites = test_sites ,
359+ train_sites = train_sites )
353360 df_all .extend ([metrics ])
354361
355362 df_all = pd .concat (df_all , axis = 0 )
0 commit comments