-
Notifications
You must be signed in to change notification settings - Fork 173
Expand file tree
/
Copy pathdataset.py
More file actions
117 lines (97 loc) · 3.84 KB
/
dataset.py
File metadata and controls
117 lines (97 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import tensorflow as tf
import os
from tensorflow import keras
from keras import layers
from keras.utils import load_img
from keras.utils import array_to_img
from keras.utils import img_to_array
from keras.preprocessing import image_dataset_from_directory
from IPython.display import display
# Reference
""" Title: Image Super-Resolution using an Efficient Sub-Pixel CNN
Author: Xingyu Long
Date: 28/07/2020
Availability: https://keras.io/examples/vision/super_resolution_sub_pixel/"""
# Set the crop / downsampling geometry for the super-resolution pipeline.
crop_width_size = 256
crop_height_size = 248
upscale_factor = 4 # ratio used to downsample original images for training and to upscale images at prediction time
# Low-resolution input dimensions are the crop size divided by the upscale factor.
input_height_size = crop_height_size // upscale_factor
input_width_size = crop_width_size // upscale_factor
batch_size = 8
# Directory containing the training dataset (machine-specific absolute path).
training_dir = "D:/temporary_workspace/comp3710_project/PatternAnalysis_2023_Shan_Jiang/recognition/SuperResolutionShanJiang/train_dataset"
# Build the training split: 80% of the images in training_dir.
# The seed is fixed so the train/validation partition is reproducible
# and consistent with the validation split below.
train_ds = image_dataset_from_directory(
    directory=training_dir,
    label_mode=None,  # unlabeled data: the dataset yields images only
    image_size=(crop_height_size, crop_width_size),
    batch_size=batch_size,
    validation_split=0.2,
    subset="training",
    seed=1337,
)
# Build the validation split: the remaining 20% of the images.
# Same seed as the training split so the two subsets do not overlap.
valid_ds = image_dataset_from_directory(
    directory=training_dir,
    label_mode=None,  # unlabeled data: the dataset yields images only
    image_size=(crop_height_size, crop_width_size),
    batch_size=batch_size,
    validation_split=0.2,
    subset="validation",
    seed=1337,
)
# Rescale pixel intensities from [0, 255] into [0, 1].
def scaling(input_image):
    """Return *input_image* with pixel values normalized to [0, 1]."""
    return input_image / 255.0
# Apply the [0, 1] rescaling to every image in both datasets.
train_ds = train_ds.map(scaling)
valid_ds = valid_ds.map(scaling)
# Create the low-resolution model input: keep only the luminance (Y)
# channel of the image and downsample it to the target input size.
def process_input(input, input_height_size, input_width_size):
    """Convert an RGB image to its Y (luminance) channel and downsample it.

    Args:
        input: RGB image tensor with channels in the last dimension.
        input_height_size: height of the downsampled output.
        input_width_size: width of the downsampled output.

    Returns:
        The Y channel resized to (input_height_size, input_width_size)
        using "area" interpolation.
    """
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    # Only luminance is used for super-resolution; chrominance (U, V) is discarded.
    y = tf.split(input, 3, axis=last_dimension_axis)[0]
    return tf.image.resize(y, [input_height_size, input_width_size], method="area")
# Create the high-resolution training target: the luminance (Y) channel
# of the full-size image.
def process_target(input):
    """Convert an RGB image to its Y (luminance) channel at full resolution.

    Args:
        input: RGB image tensor with channels in the last dimension.

    Returns:
        The Y channel of *input*, same spatial size as the input.
    """
    input = tf.image.rgb_to_yuv(input)
    last_dimension_axis = len(input.shape) - 1
    # Only luminance is used for super-resolution; chrominance (U, V) is discarded.
    y = tf.split(input, 3, axis=last_dimension_axis)[0]
    return y
# Turn each training image into a (low-resolution input, high-resolution
# target) pair, then prefetch batches to overlap loading with training.
train_ds = train_ds.map(
    lambda x: (
        process_input(x, input_height_size, input_width_size),
        process_target(x),
    )
).prefetch(buffer_size=32)
# Turn each validation image into a (low-resolution input, high-resolution
# target) pair, then prefetch batches to overlap loading with evaluation.
valid_ds = valid_ds.map(
    lambda x: (
        process_input(x, input_height_size, input_width_size),
        process_target(x),
    )
).prefetch(buffer_size=32)
# Directory containing the held-out test images (machine-specific path).
test_path = 'D:/temporary_workspace/comp3710_project/PatternAnalysis_2023_Shan_Jiang/recognition/SuperResolutionShanJiang/test_dataset'
# Collect the full path of every .jpeg file in the test directory,
# sorted so the ordering is deterministic across runs.
test_img_paths = sorted(
    os.path.join(test_path, fname)
    for fname in os.listdir(test_path)
    if fname.endswith(".jpeg")
)
# Accessor for the sorted list of test-image file paths.
def get_test_img_paths():
    """Return the sorted list of paths to the test images."""
    return test_img_paths
# Directory containing images to run super-resolution prediction on
# (machine-specific path).
prediction_path = "D:/temporary_workspace/comp3710_project/PatternAnalysis_2023_Shan_Jiang/recognition/SuperResolutionShanJiang/prediction_dataset"
# Collect the full path of every .jpeg file in the prediction directory,
# sorted so the ordering is deterministic. Stored in its own variable
# instead of rebinding prediction_path (the original overwrote the
# directory string with the list), mirroring test_path / test_img_paths.
prediction_img_paths = sorted(
    [
        os.path.join(prediction_path, fname)
        for fname in os.listdir(prediction_path)
        if fname.endswith(".jpeg")
    ]
)
# Accessor for the sorted list of prediction-image file paths.
def get_prediction_img_paths():
    """Return the sorted list of paths to the images to be predicted."""
    return prediction_img_paths