-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_augmentation.py
More file actions
106 lines (84 loc) · 3.79 KB
/
data_augmentation.py
File metadata and controls
106 lines (84 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
CIFAR-10 data normalization reference:
https://github.com/Armour/pytorch-nn-practice/blob/master/utils/meanstd.py
"""
import random
import os
import numpy as np
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.sampler import SubsetRandomSampler
def fetch_dataloader(types, params):
    """
    Fetch and return train/dev dataloader with hyperparameters (params.subset_percent = 1.)

    Args:
        types: (str) 'train' returns the training loader; anything else returns the dev loader
        params: object with attributes `augmentation` ("yes"/other), `batch_size`,
            `num_workers`, and `cuda` (bool, used for pin_memory)

    Returns:
        torch.utils.data.DataLoader over the requested CIFAR-10 split
    """
    # CIFAR-10 channel-wise mean/std (see module docstring for the reference).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    # Train-set pipeline: random crops + horizontal flips when augmentation is on;
    # plain tensor conversion + normalization otherwise.
    if params.augmentation == "yes":
        train_transformer = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),  # randomly flip image horizontally
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transformer = transforms.Compose([transforms.ToTensor(), normalize])

    # Dev set is never augmented.
    dev_transformer = transforms.Compose([transforms.ToTensor(), normalize])

    # Both splits are instantiated unconditionally (matching the file's original
    # behavior: the archive is downloaded/verified for train and test alike).
    trainset = torchvision.datasets.CIFAR10(root='./data-cifar10', train=True,
                                            download=True, transform=train_transformer)
    devset = torchvision.datasets.CIFAR10(root='./data-cifar10', train=False,
                                          download=True, transform=dev_transformer)

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=params.batch_size, shuffle=True,
        num_workers=params.num_workers, pin_memory=params.cuda)
    devloader = torch.utils.data.DataLoader(
        devset, batch_size=params.batch_size, shuffle=False,
        num_workers=params.num_workers, pin_memory=params.cuda)

    return trainloader if types == 'train' else devloader
def fetch_subset_dataloader(types, params):
    """
    Use only a subset of dataset for KD training, depending on params.subset_percent

    Args:
        types: (str) 'train' returns the subsampled training loader; anything else
            returns the full dev loader
        params: object with attributes `augmentation` ("yes"/other), `subset_percent`
            (float in [0, 1]), `batch_size`, `num_workers`, and `cuda` (bool)

    Returns:
        torch.utils.data.DataLoader; the train loader samples only the first
        floor(subset_percent * len(trainset)) indices of a fixed shuffle
    """
    # CIFAR-10 channel-wise mean/std (see module docstring for the reference).
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    # using random crops and horizontal flip for train set
    if params.augmentation == "yes":
        train_transformer = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),  # randomly flip image horizontally
            transforms.ToTensor(),
            normalize,
        ])
    # data augmentation can be turned off
    else:
        train_transformer = transforms.Compose([transforms.ToTensor(), normalize])

    # transformer for dev set (never augmented)
    dev_transformer = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = torchvision.datasets.CIFAR10(root='./data-cifar10', train=True,
                                            download=True, transform=train_transformer)
    devset = torchvision.datasets.CIFAR10(root='./data-cifar10', train=False,
                                          download=True, transform=dev_transformer)

    trainset_size = len(trainset)
    indices = list(range(trainset_size))
    split = int(np.floor(params.subset_percent * trainset_size))

    # Shuffle with a PRIVATE seeded generator instead of np.random.seed(230) +
    # np.random.shuffle: the old code silently reset NumPy's *global* RNG state,
    # perturbing any other randomness in the process. RandomState(230).shuffle
    # yields the exact same permutation (the legacy global RNG is a RandomState),
    # so the selected subset is unchanged.
    rng = np.random.RandomState(230)
    rng.shuffle(indices)
    train_sampler = SubsetRandomSampler(indices[:split])

    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=params.batch_size, sampler=train_sampler,
        num_workers=params.num_workers, pin_memory=params.cuda)
    devloader = torch.utils.data.DataLoader(
        devset, batch_size=params.batch_size, shuffle=False,
        num_workers=params.num_workers, pin_memory=params.cuda)

    return trainloader if types == 'train' else devloader