forked from SamsungLabs/time-aware-awb
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdataset_preprocessing.py
More file actions
84 lines (76 loc) · 4.5 KB
/
dataset_preprocessing.py
File metadata and controls
84 lines (76 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
"""
Copyright (c) 2025 Samsung Electronics Co., Ltd.
Author(s):
Mahmoud Afifi (m.afifi1@samsung.com, m.3afifi@gmail.com)
Licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC BY-NC 4.0) License, (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at https://creativecommons.org/licenses/by-nc/4.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.
For conditions of distribution and use, see the accompanying LICENSE.md file.
This script is used to prepare the dataset for training and testing. It creates a downsampled version of the raw images
and applies masks to the training images. For validation and testing sets, it generates two versions: one with masks
applied and one without.
"""
import constants
import utils
import argparse
import os
import shutil
from numpy import expand_dims
def get_args():
    """Parse the command-line options of the S24 pre-processing script.

    Two string options are declared, both mandatory: the dataset location
    (`-dp` / `--dataset_path`) and the destination for the processed copy
    (`-op` / `--output_path`). The values are read from `sys.argv`.

    Returns:
        argparse.Namespace: holds the `dataset_path` and `output_path` attributes.
    """
    cli = argparse.ArgumentParser(description='Pre-process S24 dataset before training and testing.')

    # (short flag, long flag, help text) for each required path option.
    option_specs = (
        ('-dp', '--dataset_path',
         'Path to the dataset directory containing the train, val, and test folders of the S24 Raw-sRGB dataset.'),
        ('-op', '--output_path',
         'Path to the output directory where processed version of the dataset will be saved.'),
    )
    for short_flag, long_flag, help_text in option_specs:
        cli.add_argument(short_flag, long_flag, type=str, required=True, help=help_text)

    return cli.parse_args()
if __name__ == '__main__':
    # Script entry point. Three passes over the 'train', 'val' and 'test' splits:
    #   1. check that the input dataset has the expected folder layout,
    #   2. create the output folder tree,
    #   3. copy the 'data' folder, then resize every raw PNG and write the results.
    cli_args = get_args()
    src_root = cli_args.dataset_path
    dst_root = cli_args.output_path

    split_names = ['train', 'val', 'test']
    required_inputs = ['raw_images', 'masks', 'data']
    unmasked_folder = str(constants.WITHOUT_MASK_SUB_FOLDER)
    masked_folder = str(constants.WITH_MASK_SUB_FOLDER)
    output_folders = [unmasked_folder, masked_folder, 'data']

    # Pass 1: validate the whole input layout before anything is written.
    for split in split_names:
        split_src = os.path.join(src_root, split)
        if not os.path.isdir(split_src):
            raise FileNotFoundError(
                f'Invalid S24 dataset path: missing "{split}" directory in {src_root}. '
                'The dataset directory must include "train", "val", and "test" subdirectories.'
            )
        for needed in required_inputs:
            if os.path.isdir(os.path.join(src_root, split, needed)):
                continue
            raise FileNotFoundError(
                f'Invalid S24 dataset path: missing "{needed}" directory in {split_src}. '
                'Each set must include "raw_images", "masks", and "data" subdirectories.'
            )

    # Pass 2: build the output tree; the train split gets no "without mask" folder.
    os.makedirs(dst_root, exist_ok=True)
    for split in split_names:
        split_dst = os.path.join(dst_root, split)
        os.makedirs(split_dst, exist_ok=True)
        for folder in output_folders:
            if split != 'train' or folder != unmasked_folder:
                os.makedirs(os.path.join(dst_root, split, folder), exist_ok=True)

    # Pass 3: copy the 'data' folder and process the images of each split.
    for split in split_names:
        print(f'Processing {split}...')
        shutil.copytree(
            os.path.join(src_root, split, 'data'),
            os.path.join(dst_root, split, 'data'),
            dirs_exist_ok=True,
        )

        # Only files ending in '.png' or '.PNG' are processed.
        png_names = [
            name for name in os.listdir(os.path.join(src_root, split, 'raw_images'))
            if name.endswith(('.png', '.PNG'))
        ]
        total = len(png_names)

        for position, file_name in enumerate(png_names, start=1):
            print(f'Processing {split}: {position}/{total}...')

            # Raw image: resized to the target size with bicubic interpolation.
            raw = utils.imread(os.path.join(src_root, split, 'raw_images', file_name))
            small_raw = utils.imresize(
                raw, height=constants.TARGET_SIZE[0], width=constants.TARGET_SIZE[1],
                interpolation_method='bicubic')

            # Mask: looked up under the same file name and resized with nearest-neighbour.
            mask_in = utils.imread(os.path.join(src_root, split, 'masks', file_name))
            small_mask = utils.imresize(
                mask_in, height=constants.TARGET_SIZE[0], width=constants.TARGET_SIZE[1],
                interpolation_method='nearest')

            # The mask gets a trailing axis so it multiplies against the image.
            # NOTE(review): presumably the mask is 2-D with values in {0, 1}, so
            # masked pixels become zero - confirm against utils.imread.
            keep_weights = 1 - expand_dims(small_mask, axis=-1)
            masked = small_raw * keep_weights

            # Every split gets the masked version.
            utils.imwrite(masked, os.path.join(dst_root, split, masked_folder, file_name),
                          format='PNG-16')

            # Validation and test splits also get the unmasked resized image.
            if split in ['val', 'test']:
                utils.imwrite(small_raw, os.path.join(dst_root, split, unmasked_folder, file_name),
                              format='PNG-16')