-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathbuild_data_tfrecord.py
More file actions
40 lines (34 loc) · 1.62 KB
/
build_data_tfrecord.py
File metadata and controls
40 lines (34 loc) · 1.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#encoding:utf-8
import os
import argparse
from pytfeager.utils.data_utils import train_val_split,convert_image_record
from pytfeager.config import densenet_config as config
def main(test_size=None, random_state=0):
    """Split the image dataset and write train/validation TFRecord files.

    Reads the raw data location ``config.PATH/config.DATA_PATH``, splits it
    with ``train_val_split``, maps the string labels to integer ids through
    ``config.CLASS_DICT`` and writes ``train.tfrecords`` (training split) and
    ``test.tfrecords`` (validation split) into
    ``config.PATH/config.FEATURES_PATH``, creating that directory if needed.

    Args:
        test_size: Forwarded unchanged to ``train_val_split``. When ``None``
            (the default) the value is taken from the module-level ``args``
            dict produced by the command-line parser, which preserves the
            original script behaviour of ``main()``.
        random_state: Seed forwarded to ``train_val_split``; defaults to 0,
            the value previously hard-coded.

    Raises:
        ValueError: if ``test_size`` is ``None`` and no parsed command-line
            ``args`` exist (e.g. the module was imported, not run).
        KeyError: if a label is missing from ``config.CLASS_DICT``.
    """
    if test_size is None:
        # Fall back to the CLI arguments parsed under the ``__main__`` guard;
        # they are only defined when this file is executed as a script.
        cli_args = globals().get('args')
        if cli_args is None:
            raise ValueError(
                'test_size must be given when main() is not run as a script')
        test_size = cli_args['test_size']

    # Split the data into training and validation subsets.
    data_path = os.path.join(config.PATH, config.DATA_PATH)
    (train_paths, train_labels), (val_paths, val_labels) = train_val_split(
        data_path, test_size=test_size, random_state=random_state)

    # Convert the string labels to integer class ids.
    class_to_int = config.CLASS_DICT
    train_labels = [class_to_int[tag] for tag in train_labels]
    val_labels = [class_to_int[tag] for tag in val_labels]

    # Output directory for the TFRecord files. exist_ok=True avoids the
    # race between checking for the directory and creating it.
    feature_path = os.path.join(config.PATH, config.FEATURES_PATH)
    os.makedirs(feature_path, exist_ok=True)
    train_tfrecord = os.path.join(feature_path, 'train.tfrecords')
    test_tfrecord = os.path.join(feature_path, 'test.tfrecords')

    # Write the training split to a TFRecord file.
    convert_image_record(img_paths=train_paths,
                         img_labels=train_labels,
                         tfrecord_file_name=train_tfrecord,
                         is_train=True)
    # Write the validation split to a TFRecord file.
    # NOTE(review): the validation split is also written with is_train=True;
    # confirm against convert_image_record whether is_train=False was intended.
    convert_image_record(img_paths=val_paths,
                         img_labels=val_labels,
                         tfrecord_file_name=test_tfrecord,
                         is_train=True)
def int_or_float(text):
    """Parse a command-line value as ``int`` when possible, else ``float``.

    ``--test_size`` used to be declared with ``type=int`` while defaulting to
    ``0.2``, so a fractional value such as ``-s 0.3`` was rejected by
    argparse. Integer text still yields an ``int``, so existing integer usage
    is unchanged.

    Raises:
        ValueError: if ``text`` is neither an integer nor a float literal
            (argparse turns this into a usage error).
    """
    try:
        return int(text)
    except ValueError:
        return float(text)


if __name__ == '__main__':
    ap = argparse.ArgumentParser()
    ap.add_argument('-s', '--test_size', type=int_or_float, default=0.2,
                    help='test data size')
    # main() reads this module-level dict when called without arguments.
    args = vars(ap.parse_args())
    main()