"""
FFCV loader
"""
import enum
from os import environ
import ast
import logging
from multiprocessing import cpu_count
from re import sub
from typing import Any, Callable, Mapping, Sequence, Type, Union, Literal
from collections import defaultdict
from collections.abc import Collection
from copy import deepcopy
from enum import Enum, unique, auto
from ffcv.fields.base import Field
import torch as ch
import numpy as np
from .epoch_iterator import EpochIterator
from ..reader import Reader
from ..traversal_order.base import TraversalOrder
from ..traversal_order import Random, Sequential, QuasiRandom
from ..pipeline import Pipeline, PipelineSpec, Compiler
from ..pipeline.operation import Operation
from ..pipeline.graph import Graph
from ..memory_managers import (
ProcessCacheManager, OSCacheManager, MemoryManager
)
@unique
class OrderOption(Enum):
SEQUENTIAL = auto()
RANDOM = auto()
QUASI_RANDOM = auto()
ORDER_TYPE = Union[
    TraversalOrder,
    Literal[OrderOption.SEQUENTIAL,
            OrderOption.RANDOM,
            OrderOption.QUASI_RANDOM]
]
ORDER_MAP: Mapping[ORDER_TYPE, Type[TraversalOrder]] = {
OrderOption.RANDOM: Random,
OrderOption.SEQUENTIAL: Sequential,
OrderOption.QUASI_RANDOM: QuasiRandom
}
DEFAULT_PROCESS_CACHE = int(environ.get('FFCV_DEFAULT_CACHE_PROCESS', "0"))
DEFAULT_OS_CACHE = not DEFAULT_PROCESS_CACHE
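# The default caching strategy can be switched without editing code by setting the
# FFCV_DEFAULT_CACHE_PROCESS environment variable before this module is imported.
# A sketch (the script name is illustrative):
#
#   FFCV_DEFAULT_CACHE_PROCESS=1 python train.py   # process cache by default
#   python train.py                                 # OS page cache by default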
class Loader:
"""FFCV loader class that can be used as a drop-in replacement
for standard (e.g. PyTorch) data loaders.
Parameters
----------
fname: str
Full path to the location of the dataset (.beton file format).
batch_size : int
Batch size.
    num_workers : int
        Number of workers used for data loading. Consider using the number of physical cores rather than the number of hardware threads if you only use JIT-compiled augmentations, as they usually don't benefit from hyper-threading.
    os_cache : bool
        Leverages the operating system's page cache for caching purposes. This is beneficial when there is enough memory to cache the dataset and/or when multiple processes on the same machine are training on the same dataset. See https://docs.ffcv.io/performance_guide.html for more information.
    order : Union[OrderOption, TraversalOrder]
        Traversal order, one of: SEQUENTIAL, RANDOM, QUASI_RANDOM, or a custom TraversalOrder.
        QUASI_RANDOM is a random order that tries to be as uniform as possible while minimizing the amount of data read from disk. Note that it is mostly useful when `os_cache=False`.
distributed : bool
For distributed training (multiple GPUs). Emulates the behavior of DistributedSampler from PyTorch.
seed : int
Random seed for batch ordering.
    indices : Sequence[int]
        Subset of the dataset to iterate over, given as a sequence of sample indices.
    pipelines : Mapping[str, Sequence[Union[Operation, torch.nn.Module]]]
        Dictionary defining, for each field, the sequence of decoders and transforms to apply.
        Fields with missing entries will use the default pipeline, which consists of the default decoder and `ToTensor()`,
        but a field can also be disabled explicitly by passing `None` as its pipeline.
    custom_fields : Mapping[str, Type[Field]]
        Dictionary informing the loader of the types associated with fields that use a custom type.
    drop_last : bool
        Drop the last non-full batch in each iteration.
batches_ahead : int
Number of batches prepared in advance; balances latency and memory.
    recompile : bool
        Recompile the data processing code at every epoch. This is necessary if the implementations of some augmentations are expected to change during training.
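
    Examples
    --------
    A minimal usage sketch, assuming a `.beton` file containing `image` and
    `label` fields (the path, field names, and pipelines are illustrative,
    not requirements)::

        from ffcv.loader import Loader, OrderOption
        from ffcv.fields.decoders import SimpleRGBImageDecoder, IntDecoder
        from ffcv.transforms import ToTensor

        loader = Loader('/path/to/dataset.beton',
                        batch_size=128,
                        num_workers=8,
                        order=OrderOption.RANDOM,
                        pipelines={
                            'image': [SimpleRGBImageDecoder(), ToTensor()],
                            'label': [IntDecoder(), ToTensor()],
                        })

        for images, labels in loader:
            ...  # one training step per batch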
"""
def __init__(self,
fname: str,
batch_size: int,
num_workers: int = -1,
os_cache: bool = DEFAULT_OS_CACHE,
order: Union[ORDER_TYPE, TraversalOrder] = OrderOption.SEQUENTIAL,
distributed: bool = False,
seed: int = None, # For ordering of samples
indices: Sequence[int] = None, # For subset selection
pipelines: Mapping[str,
Sequence[Union[Operation, ch.nn.Module]]] = {},
custom_fields: Mapping[str, Type[Field]] = {},
drop_last: bool = True,
batches_ahead: int = 3,
recompile: bool = False, # Recompile at every epoch
):
if distributed and order != OrderOption.SEQUENTIAL and (seed is None):
            logging.warning('No ordering seed was specified with distributed=True. '
                            'Setting seed to 0 to match PyTorch distributed sampler.')
seed = 0
elif seed is None:
tinfo = np.iinfo('int32')
seed = np.random.randint(0, tinfo.max)
        # We store the original user arguments to be able to pass them to the
        # filtered versions of the dataset (see `filter` below)
self._args = {
'fname': fname,
'batch_size': batch_size,
'num_workers': num_workers,
'os_cache': os_cache,
'order': order,
'distributed': distributed,
'seed': seed,
'indices': deepcopy(indices),
'pipelines': deepcopy(pipelines),
'drop_last': drop_last,
'batches_ahead': batches_ahead,
'recompile': recompile
}
self.fname: str = fname
self.batch_size: int = batch_size
self.batches_ahead = batches_ahead
self.seed: int = seed
self.reader: Reader = Reader(self.fname, custom_fields)
self.num_workers: int = num_workers
self.drop_last: bool = drop_last
self.distributed: bool = distributed
self.code = None
self.recompile = recompile
if self.num_workers < 1:
self.num_workers = cpu_count()
Compiler.set_num_threads(self.num_workers)
if indices is None:
self.indices = np.arange(self.reader.num_samples, dtype='uint64')
else:
self.indices = np.array(indices)
if os_cache:
self.memory_manager: MemoryManager = OSCacheManager(self.reader)
else:
self.memory_manager: MemoryManager = ProcessCacheManager(
self.reader)
if order in ORDER_MAP:
self.traversal_order: TraversalOrder = ORDER_MAP[order](self)
elif issubclass(order, TraversalOrder):
self.traversal_order: TraversalOrder = order(self)
else:
raise ValueError(f"Order {order} is not a supported order type or a subclass of TraversalOrder")
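        # Compiled low-level read function; the pipeline graph uses it to fetch
        # raw sample data from the memory manager's backing storage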
memory_read = self.memory_manager.compile_reader()
self.next_epoch: int = 0
self.pipelines = {}
self.pipeline_specs = {}
self.field_name_to_f_ix = {}
custom_pipeline_specs = {}
# Creating PipelineSpec objects from the pipeline dict passed
# by the user
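        # For example (illustrative field name), a user entry such as
        #   {'label': [IntDecoder(), ToTensor()]}
        # is wrapped into PipelineSpec('label', decoder=None, transforms=[...]) below,
        # while passing None for a field skips it entirely (disabled field)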
for output_name, spec in pipelines.items():
if isinstance(spec, PipelineSpec):
pass
elif isinstance(spec, Sequence):
spec = PipelineSpec(output_name, decoder=None, transforms=spec)
elif spec is None:
continue # This is a disabled field
else:
msg = f"The pipeline for {output_name} has to be "
                msg += "either a PipelineSpec or a sequence of operations"
raise ValueError(msg)
custom_pipeline_specs[output_name] = spec
# Adding the default pipelines
for f_ix, (field_name, field) in enumerate(self.reader.handlers.items()):
self.field_name_to_f_ix[field_name] = f_ix
if field_name not in custom_pipeline_specs:
# We add the default pipeline
if field_name not in pipelines:
self.pipeline_specs[field_name] = PipelineSpec(field_name)
else:
self.pipeline_specs[field_name] = custom_pipeline_specs[field_name]
# We add the custom fields after the default ones
# This is to preserve backwards compatibility and make sure the order
# is intuitive
for field_name, spec in custom_pipeline_specs.items():
if field_name not in self.pipeline_specs:
self.pipeline_specs[field_name] = spec
self.graph = Graph(self.pipeline_specs, self.reader.handlers,
self.field_name_to_f_ix, self.reader.metadata,
memory_read)
self.generate_code()
self.first_traversal_order = self.next_traversal_order()
def next_traversal_order(self):
return self.traversal_order.sample_order(self.next_epoch)
def __iter__(self):
Compiler.set_num_threads(self.num_workers)
order = self.next_traversal_order()
selected_order = order[:len(self) * self.batch_size]
self.next_epoch += 1
# Compile at the first epoch
if self.code is None or self.recompile:
self.generate_code()
return EpochIterator(self, selected_order)
    def filter(self, field_name: str, condition: Callable[[Any], bool]) -> 'Loader':
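        """Return a new Loader restricted to the samples for which `condition`
        returns True on the value of `field_name`.

        A usage sketch (the field name and predicate are illustrative)::

            small_loader = loader.filter('label', lambda label: label < 10)
        """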
new_args = {**self._args}
pipelines = {}
# Disabling all the other fields
for other_field_name in self.reader.handlers.keys():
pipelines[other_field_name] = None
# We reuse the original pipeline for the field we care about
try:
pipelines[field_name] = new_args['pipelines'][field_name]
except KeyError:
            # We keep the default one if the user didn't set up a custom one
            del pipelines[field_name]
new_args['pipelines'] = pipelines
# We use sequential order for speed and to know which index we are
# filtering
new_args['order'] = OrderOption.SEQUENTIAL
new_args['drop_last'] = False
sub_loader = Loader(**new_args)
selected_indices = []
# Iterate through the loader and test the user defined condition
for i, (batch,) in enumerate(sub_loader):
for j, sample in enumerate(batch):
sample_id = i * self.batch_size + j
if condition(sample):
selected_indices.append(sample_id)
final_args = {**self._args}
final_args['indices'] = np.array(selected_indices)
return Loader(**final_args)
def __len__(self):
next_order = self.first_traversal_order
if self.drop_last:
return len(next_order) // self.batch_size
else:
return int(np.ceil(len(next_order) / self.batch_size))
def generate_code(self):
queries, code = self.graph.collect_requirements()
self.code = self.graph.codegen_all(code)