-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvacation_dataset.py
More file actions
154 lines (130 loc) · 4.45 KB
/
vacation_dataset.py
File metadata and controls
154 lines (130 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
VacationDataset class for loading and transforming vacation images.
Note: this class ended up unused, because the ImageFolder class was used instead.
Attributes
----------
annotations : pandas.DataFrame
DataFrame containing image file names and corresponding labels.
root_dir : str
Directory with all the images.
transform : callable, optional
Optional transform to be applied on a sample.
target_transform : callable, optional
Optional transform to be applied on the target.
Methods
-------
__len__()
Returns the number of images in the dataset.
__getitem__(idx)
Returns the image and label at the specified index.
Parameters
----------
annotations_file : str
Path to the CSV file with annotations.
root_dir : str
Directory with all the images.
transform : callable, optional
Optional transform to be applied on a sample.
target_transform : callable, optional
Optional transform to be applied on the target.
Initialize the VacationDataset.
Args:
annotations_file (str): Path to the CSV file with annotations.
root_dir (str): Directory with all the images.
transform (callable, optional): Optional transform to be applied on a sample.
target_transform (callable, optional): Optional transform to be applied on the target.
Return the number of images in the dataset.
Returns:
int: Number of images in the dataset.
Get the image and label at the specified index.
Args:
idx (int): Index of the image and label to retrieve.
Returns:
tuple: (image, label) where image is the loaded image tensor and label is the corresponding label.
"""
from pathlib import Path
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
class VacationDataset(Dataset):
    """
    Map-style dataset of vacation images described by a CSV annotations file.

    The annotations are read with ``pandas.read_csv`` (default options) and
    are addressed by column *position*, not by column name:

    * column 0 -- image file name
    * column 1 -- class name, used as a sub-directory of ``root_dir``
    * column 2 -- label returned together with the image

    Each image is therefore looked up at
    ``root_dir / <class name> / <file name>``.

    Parameters
    ----------
    annotations_file : str
        Path to the CSV file with annotations (anything accepted by
        ``pandas.read_csv``).
    root_dir : str
        Directory that contains one sub-directory per class.
    transform : callable, optional
        Transform applied to the loaded image.
    target_transform : callable, optional
        Transform applied to the label.
    loader : callable, optional
        Keyword-only. Function that receives the image path as a ``str`` and
        returns the loaded image. When omitted, ``torchvision.io.read_image``
        is used.

    Attributes
    ----------
    annotations : pandas.DataFrame
        DataFrame containing image file names, class names and labels.
    root_dir : str
        Directory with all the images.
    transform : callable or None
        Transform applied to the loaded image.
    target_transform : callable or None
        Transform applied to the label.
    loader : callable or None
        Custom image loader, or ``None`` for the default.

    Methods
    -------
    __len__()
        Returns the number of rows in the annotations.
    __getitem__(idx)
        Loads and returns the image and label at the specified index.
    """

    def __init__(self, annotations_file, root_dir, transform=None,
                 target_transform=None, *, loader=None):
        """
        Read the annotations and remember the directory and transforms.

        Parameters
        ----------
        annotations_file : str
            Path to the CSV file containing annotations.
        root_dir : str
            Directory with all the images.
        transform : callable, optional
            Transform applied to the loaded image.
        target_transform : callable, optional
            Transform applied to the label.
        loader : callable, optional
            Keyword-only replacement for the default image reader.
        """
        self.annotations = pd.read_csv(annotations_file)
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __len__(self):
        """
        Return the number of samples in the dataset.

        Returns
        -------
        int
            Number of rows in the annotations DataFrame.
        """
        return len(self.annotations)

    def __getitem__(self, idx):
        """
        Load the image and label stored at row ``idx``.

        Parameters
        ----------
        idx : int
            Positional row index into the annotations.

        Returns
        -------
        tuple
            ``(image, label)`` after the optional transforms were applied.
        """
        file_name = self.annotations.iloc[idx, 0]
        class_name = self.annotations.iloc[idx, 1]
        # pandas parses an all-numeric class column (e.g. directories named
        # by year) as integers; coerce to text so the path can still be built.
        img_path = str(Path(self.root_dir, str(class_name), Path(file_name)))

        # The default reader is resolved lazily, only when no custom loader
        # was supplied.
        load = read_image if self.loader is None else self.loader
        image = load(img_path)

        label = self.annotations.iloc[idx, 2]

        # Compare against None rather than relying on truthiness: a callable
        # that happens to be falsy (e.g. one defining __len__) must still run.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label