forked from SamsungLabs/time-aware-awb
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtesting.py
More file actions
112 lines (97 loc) · 4.69 KB
/
testing.py
File metadata and controls
112 lines (97 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Copyright (c) 2025 Samsung Electronics Co., Ltd.
Author(s):
Mahmoud Afifi (m.afifi1@samsung.com, m.3afifi@gmail.com)
Licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License, (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at https://creativecommons.org/licenses/by-nc/4.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
For conditions of distribution and use, see the accompanying LICENSE.md file.
Testing script.
"""
import argparse
import os

import torch
from torch.utils.data import DataLoader

import utils
from model import IllumEstimator
from dataloader import IlluminantEstimationDataLoader
from utils import *
def get_args() -> argparse.Namespace:
    """Parse the command-line arguments of the testing script.

    The accepted values of ``--set_name`` and ``--device`` are taken from
    ``TESTING_SET_NAMES`` and ``DEVICES``; neither is defined in this file
    (presumably both come from the ``utils`` star import -- confirm).

    Returns:
        The parsed namespace with the attributes ``dataset_path``, ``set_name``,
        ``model_path``, ``output_path``, ``device``, ``without_mask`` and
        ``exp_name``.
    """
    parser = argparse.ArgumentParser(
        description='Script for testing a trained model.'
    )
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to the dataset directory. The directory must contain subdirectories:'
                             ' "test" and "val".')
    parser.add_argument('--set_name', type=str, required=True, choices=TESTING_SET_NAMES,
                        help='Set name. Options: "test" and "val".')
    parser.add_argument('--model_path', type=str, default='./models',
                        help='Path to the directory of trained models.')
    parser.add_argument('--output_path', type=str, default='./results',
                        help='Path to the output directory for saving the results.')
    # This script only runs inference, so the help text must not mention training.
    parser.add_argument('--device', type=str, choices=DEVICES, default='gpu',
                        help='Device to use for testing. Options: "cpu" or "gpu".')
    parser.add_argument('--without_mask', action='store_true',
                        help='To use testing/validation images without mask.')
    parser.add_argument('--exp_name', type=str, default='awb',
                        help='Experiment name.')
    return parser.parse_args()
def testing(model: IllumEstimator, data: DataLoader):
"""Testing trained model."""
model.eval()
filenames = []
estimated_illums = []
for i, batch in enumerate(data):
capture_data = batch['capture_data'].to(device=device)
hist = batch['hist'].to(device=device)
filenames.append(batch['filename'])
est_illum = model(hist, capture_data).detach().cpu().numpy().tolist()
estimated_illums.append(est_illum)
return estimated_illums, filenames
if __name__ == '__main__':
    # Script entry point: load the saved experiment configuration and weights,
    # run the model over the requested split, and write one JSON result per sample.
    args = get_args()

    # Translate the split name into the two boolean flags passed to the dataset.
    if args.set_name == 'val':
        valid, test = True, False
    elif args.set_name == 'test':
        valid, test = False, True
    else:
        raise ValueError('Invalid set_name.')

    # 'gpu' silently falls back to the CPU when CUDA is not available.
    if args.device == 'gpu':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device('cpu')

    # The experiment's JSON config supplies the settings forwarded to the dataset below.
    configs = utils.read_json_file(os.path.join(args.model_path, f'config-{args.exp_name}.json'))
    capture_data = configs['capture_data']
    user_preference = configs['user_preference']
    sub_folder = 'pref_illum' if user_preference else 'gt_illum'
    norm_values = configs['norm_values']
    hist_boundaries = configs['hist_boundaries']
    hist_bins = configs['hist_bins']
    target_size = configs['target_size']

    test_dataset = IlluminantEstimationDataLoader(
        data_dir=args.dataset_path, capture_data=capture_data, target_size=target_size,
        valid=valid, test=test, normalization_data=norm_values, hist_bins=hist_bins,
        hist_boundaries=hist_boundaries, without_mask=args.without_mask)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)

    # Peek at the first batch to size the model's capture-data input; the size
    # stays 0 when the loader yields nothing.
    first_batch = next(iter(test_loader), None)
    capture_data_size = 0 if first_batch is None else first_batch['capture_data'][0, ...].numel()

    model = IllumEstimator(in_channels=capture_data_size, hist_channels=4).to(device=device)
    folder = 'without_mask' if args.without_mask else 'with_mask'
    model_path = os.path.join(args.model_path, f'model-{args.exp_name}.pt')
    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from a CUDA model can still be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.print_num_of_params()

    estimated_illums, filenames = testing(model, test_loader)

    output_path = os.path.join(args.output_path, folder, sub_folder, f'{args.exp_name}-{args.set_name}')
    os.makedirs(output_path, exist_ok=True)
    # With batch_size=1 every entry holds a single sample, hence the [0] indexing.
    for name, est_illum in zip(filenames, estimated_illums):
        utils.write_json_file({'est_illum': est_illum[0]}, os.path.join(output_path, name[0]))
    print('Done!')