-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutil.py
More file actions
86 lines (71 loc) · 2.78 KB
/
util.py
File metadata and controls
86 lines (71 loc) · 2.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
'''
@file: util.py
This file contains all the utility functions required for the BTP.
@author: Rukmangadh Sai Myana
@mail: rukman.sai@gmail.com
'''
import os
import torch
def set_defaults(given_args, default_args):
    '''
    Fill in defaults for arguments whose given value is None.

    Needed because the same command-line flag can have different default
    values depending on the context it is used in, so the defaults are
    applied after parsing rather than at flag-definition time.

    @param given_args: dict of argument values parsed from the flags.
    @param default_args: dict of fall-back values for those arguments.
    @returns: given_args, mutated in place with the defaults filled in.
    '''
    unset = [name for name, value in given_args.items() if value is None]
    for name in unset:
        given_args[name] = default_args[name]
    return given_args
def get_summary_dir(args):
    '''
    Return the directory for the summary writer to store the tensorboard
    summary.

    Experiments live under <logdir>/<dataset>/<task><task_number>/ as
    experiment_1, experiment_2, ... A fresh run gets the next free
    experiment number; a resumed run (args.resume) reuses the latest one.

    @param args: The arguments passed as flags from the command prompt.
        Must provide logdir, dataset, task, task_number and resume.
    @returns: The path to the directory where the summary is gonna be stored
        for the current experiment. Only experiment_1 is created on disk
        here; later directories are left for the summary writer to create.
    '''
    base_dir = os.path.join(args.logdir, args.dataset,
                            args.task + str(args.task_number))
    first_dir = os.path.join(base_dir, 'experiment_1')
    # no experiment recorded yet -> start with experiment_1
    if not os.path.isdir(first_dir):
        os.makedirs(first_dir)
        return first_dir
    # Find the numerically largest experiment index. The previous
    # implementation used max(os.listdir(...)) and bumped the last
    # character, which compares lexicographically and breaks once
    # indices reach 10 ('experiment_9' > 'experiment_10').
    indices = []
    for name in os.listdir(base_dir):
        prefix, _, suffix = name.rpartition('_')
        if prefix == 'experiment' and suffix.isdigit():
            indices.append(int(suffix))
    latest = max(indices)
    # resuming training means no new logging directory
    new_index = latest if args.resume else latest + 1
    return os.path.join(base_dir, 'experiment_' + str(new_index))
def save_best_model(save_dir, model, metrics, use_metric='accuracy'):
    '''
    Save the model in the given directory if it's better than the older model
    in the directory (compared on `use_metric`, higher is better).

    @param save_dir: The directory to save the model in.
    @param model: The model to save.
    @param metrics: dict mapping metric name -> value for the model.
    @param use_metric: The metric to be used for comparing the models.
        Default is 'accuracy'.
    '''
    filepath = os.path.join(save_dir, 'best_model.pt')
    # older best model file exists -> only overwrite if strictly better
    if os.path.exists(filepath):
        # map_location='cpu' so the stored metric can be read even when
        # the checkpoint was written on a GPU machine and we are on CPU.
        checkpoint = torch.load(filepath, map_location='cpu')
        if metrics[use_metric] <= checkpoint['metrics'][use_metric]:
            return
    # save file contains the state dict and the metrics it achieved
    torch.save({
        'model_state_dict': model.state_dict(),
        'metrics': metrics,
    }, filepath)