Skip to content

Commit 2baddd1

Browse files
authored
Merge pull request ContinualAI#1370 from lrzpellegrini/porting_distributed_training_pt2
Add base elements to support distributed comms. Add supports_distributed plugin flag.
2 parents bea1cdb + 0f37679 commit 2baddd1

46 files changed

Lines changed: 2289 additions & 233 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/environment-update.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,11 @@ jobs:
6161
run: |
6262
python -m unittest discover tests &&
6363
echo "Running checkpointing tests..." &&
64-
bash ./tests/checkpointing/test_checkpointing.sh
64+
bash ./tests/checkpointing/test_checkpointing.sh &&
65+
echo "Running distributed training tests..." &&
66+
cd tests &&
67+
PYTHONPATH=.. python run_dist_tests.py &&
68+
cd ..
6569
- name: checkout avalanche-docker repo
6670
if: always()
6771
uses: actions/checkout@v3

.github/workflows/unit-test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,5 +58,9 @@ jobs:
5858
PYTHONPATH=. python examples/eval_plugin.py &&
5959
echo "Running checkpointing tests..." &&
6060
bash ./tests/checkpointing/test_checkpointing.sh &&
61+
echo "Running distributed training tests..." &&
62+
cd tests &&
63+
PYTHONPATH=.. python run_dist_tests.py &&
64+
cd .. &&
6165
echo "While running unit tests, the following datasets were downloaded:" &&
6266
ls ~/.avalanche/data

avalanche/benchmarks/classic/cmnist.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@
2828
)
2929
from avalanche.benchmarks.datasets.external_datasets.mnist import \
3030
get_mnist_dataset
31-
from ..utils import make_classification_dataset, DefaultTransformGroups
32-
from ..utils.data import make_avalanche_dataset
31+
from avalanche.benchmarks.utils import (
32+
make_classification_dataset,
33+
DefaultTransformGroups,
34+
)
35+
from avalanche.benchmarks.utils.data import make_avalanche_dataset
3336

3437
_default_mnist_train_transform = Compose(
3538
[Normalize((0.1307,), (0.3081,))]

avalanche/benchmarks/datasets/lvis_dataset/lvis_dataset.py

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,20 +159,20 @@ def __getitem__(self, index):
159159
"""
160160
img_id = self.img_ids[index]
161161
img_dict: LVISImgEntry = self.lvis_api.load_imgs(ids=[img_id])[0]
162-
annotation_dicts = self.targets[index]
162+
annotation_dicts: LVISImgTargets = self.targets[index]
163163

164164
# Transform from LVIS dictionary to torchvision-style target
165-
num_objs = len(annotation_dicts)
165+
num_objs = annotation_dicts["bbox"].shape[0]
166166

167167
boxes = []
168168
labels = []
169169
for i in range(num_objs):
170-
xmin = annotation_dicts[i]["bbox"][0]
171-
ymin = annotation_dicts[i]["bbox"][1]
172-
xmax = xmin + annotation_dicts[i]["bbox"][2]
173-
ymax = ymin + annotation_dicts[i]["bbox"][3]
170+
xmin = annotation_dicts["bbox"][i][0]
171+
ymin = annotation_dicts["bbox"][i][1]
172+
xmax = xmin + annotation_dicts["bbox"][i][2]
173+
ymax = ymin + annotation_dicts["bbox"][i][3]
174174
boxes.append([xmin, ymin, xmax, ymax])
175-
labels.append(annotation_dicts[i]["category_id"])
175+
labels.append(annotation_dicts["category_id"][i])
176176

177177
if len(boxes) > 0:
178178
boxes = torch.as_tensor(boxes, dtype=torch.float32)
@@ -183,7 +183,7 @@ def __getitem__(self, index):
183183
image_id = torch.tensor([img_id])
184184
areas = []
185185
for i in range(num_objs):
186-
areas.append(annotation_dicts[i]["area"])
186+
areas.append(annotation_dicts["area"][i])
187187
areas = torch.as_tensor(areas, dtype=torch.float32)
188188
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
189189

@@ -233,7 +233,17 @@ class LVISAnnotationEntry(TypedDict):
233233
category_id: int
234234

235235

236-
class LVISDetectionTargets(Sequence[List[LVISAnnotationEntry]]):
236+
class LVISImgTargets(TypedDict):
237+
id: torch.Tensor
238+
area: torch.Tensor
239+
segmentation: List[List[List[float]]]
240+
image_id: torch.Tensor
241+
bbox: torch.Tensor
242+
category_id: torch.Tensor
243+
labels: torch.Tensor
244+
245+
246+
class LVISDetectionTargets(Sequence[List[LVISImgTargets]]):
237247
def __init__(
238248
self,
239249
lvis_api: LVIS,
@@ -254,7 +264,28 @@ def __getitem__(self, index):
254264
annotation_dicts: List[LVISAnnotationEntry] = self.lvis_api.load_anns(
255265
annotation_ids
256266
)
257-
return annotation_dicts
267+
268+
n_annotations = len(annotation_dicts)
269+
270+
category_tensor = torch.empty((n_annotations,), dtype=torch.long)
271+
target_dict: LVISImgTargets = {
272+
'bbox': torch.empty((n_annotations, 4), dtype=torch.float32),
273+
'category_id': category_tensor,
274+
'id': torch.empty((n_annotations,), dtype=torch.long),
275+
'area': torch.empty((n_annotations,), dtype=torch.float32),
276+
'image_id': torch.full((1,), img_id, dtype=torch.long),
277+
'segmentation': [],
278+
'labels': category_tensor # Alias of category_id
279+
}
280+
281+
for ann_idx, annotation in enumerate(annotation_dicts):
282+
target_dict['bbox'][ann_idx] = torch.as_tensor(annotation['bbox'])
283+
target_dict['category_id'][ann_idx] = annotation['category_id']
284+
target_dict['id'][ann_idx] = annotation['id']
285+
target_dict['area'][ann_idx] = annotation['area']
286+
target_dict['segmentation'].append(annotation['segmentation'])
287+
288+
return target_dict
258289

259290

260291
def _test_to_tensor(a, b):
@@ -316,5 +347,6 @@ def _plot_detection_sample(img: Image.Image, target):
316347
"LvisDataset",
317348
"LVISImgEntry",
318349
"LVISAnnotationEntry",
350+
"LVISImgTargets",
319351
"LVISDetectionTargets",
320352
]

avalanche/benchmarks/utils/data_loader.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
1818

1919
import torch
20-
from torch.utils.data import RandomSampler, DistributedSampler
20+
from torch.utils.data import RandomSampler, DistributedSampler, Dataset
2121
from torch.utils.data.dataloader import DataLoader
2222

2323
from avalanche.benchmarks.utils.collate_functions import (
@@ -31,6 +31,7 @@
3131
)
3232
from avalanche.benchmarks.utils.data import AvalancheDataset
3333
from avalanche.benchmarks.utils.data_attribute import DataAttribute
34+
from avalanche.distributed.distributed_helper import DistributedHelper
3435

3536
_default_collate_mbatches_fn = classification_collate_mbatches_fn
3637

@@ -284,14 +285,14 @@ def __init__(
284285
self.collate_mbatches = collate_mbatches
285286

286287
for data in self.datasets:
287-
if _DistributedHelper.is_distributed and distributed_sampling:
288+
if DistributedHelper.is_distributed and distributed_sampling:
288289
seed = torch.randint(
289290
0,
290-
2 ** 32 - 1 - _DistributedHelper.world_size,
291+
2 ** 32 - 1 - DistributedHelper.world_size,
291292
(1,),
292293
dtype=torch.int64,
293294
)
294-
seed += _DistributedHelper.rank
295+
seed += DistributedHelper.rank
295296
generator = torch.Generator()
296297
generator.manual_seed(int(seed))
297298
else:
@@ -584,11 +585,11 @@ def _get_batch_sizes(
584585

585586

586587
def _make_data_loader(
587-
dataset,
588-
distributed_sampling,
589-
data_loader_args,
590-
batch_size,
591-
force_no_workers=False,
588+
dataset: Dataset,
589+
distributed_sampling: bool,
590+
data_loader_args: Dict[str, Any],
591+
batch_size: int,
592+
force_no_workers: bool = False,
592593
):
593594
data_loader_args = data_loader_args.copy()
594595

@@ -601,14 +602,22 @@ def _make_data_loader(
601602
if 'prefetch_factor' in data_loader_args:
602603
data_loader_args['prefetch_factor'] = 2
603604

604-
if _DistributedHelper.is_distributed and distributed_sampling:
605+
if DistributedHelper.is_distributed and distributed_sampling:
606+
# Note: shuffle only goes in the sampler, while
607+
# drop_last must be passed to both the sampler
608+
# and the DataLoader
609+
drop_last = data_loader_args.pop("drop_last", False)
605610
sampler = DistributedSampler(
606611
dataset,
607-
shuffle=data_loader_args.pop("shuffle", False),
608-
drop_last=data_loader_args.pop("drop_last", False),
612+
shuffle=data_loader_args.pop("shuffle", True),
613+
drop_last=drop_last,
609614
)
610615
data_loader = DataLoader(
611-
dataset, sampler=sampler, batch_size=batch_size, **data_loader_args
616+
dataset,
617+
sampler=sampler,
618+
batch_size=batch_size,
619+
drop_last=drop_last,
620+
**data_loader_args
612621
)
613622
else:
614623
sampler = None
@@ -619,15 +628,6 @@ def _make_data_loader(
619628
return data_loader, sampler
620629

621630

622-
class __DistributedHelperPlaceholder:
623-
is_distributed = False
624-
world_size = 1
625-
rank = 0
626-
627-
628-
_DistributedHelper = __DistributedHelperPlaceholder()
629-
630-
631631
__all__ = [
632632
"detection_collate_fn",
633633
"detection_collate_mbatches_fn",

avalanche/core.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC
2-
from typing import TypeVar, Generic
2+
from typing import Optional, Type, TypeVar, Generic
33
from typing import TYPE_CHECKING
44

55
if TYPE_CHECKING:
@@ -27,8 +27,16 @@ class BasePlugin(Generic[Template], ABC):
2727
and loggers.
2828
"""
2929

30+
supports_distributed: bool = False
31+
"""
32+
A flag describing whether this plugin supports distributed training.
33+
"""
34+
3035
def __init__(self):
31-
pass
36+
"""
37+
Inizializes an instance of a supervised plugin.
38+
"""
39+
super().__init__()
3240

3341
def before_training(self, strategy: Template, *args, **kwargs):
3442
"""Called before `train` by the `BaseTemplate`."""
@@ -68,13 +76,26 @@ def after_eval(self, strategy: Template, *args, **kwargs) -> CallbackResult:
6876
"""Called after `eval` by the `BaseTemplate`."""
6977
pass
7078

79+
def __init_subclass__(
80+
cls,
81+
supports_distributed: bool = False,
82+
**kwargs) -> None:
83+
cls.supports_distributed = supports_distributed
84+
return super().__init_subclass__(**kwargs)
85+
7186

7287
class BaseSGDPlugin(BasePlugin[Template], ABC):
7388
"""ABC for BaseSGDTemplate plugins.
7489
7590
See `BaseSGDTemplate` for complete description of the train/eval loop.
7691
"""
7792

93+
def __init__(self):
94+
"""
95+
Initializes an instance of a base SGD plugin.
96+
"""
97+
super().__init__()
98+
7899
def before_training_epoch(
79100
self, strategy: Template, *args, **kwargs
80101
) -> CallbackResult:
@@ -193,7 +214,11 @@ class SupervisedPlugin(BaseSGDPlugin[Template], ABC):
193214
194215
See `BaseTemplate` for complete description of the train/eval loop.
195216
"""
196-
pass
217+
def __init__(self):
218+
"""
219+
Initializes an instance of a supervised plugin.
220+
"""
221+
super().__init__()
197222

198223

199224
class SupervisedMetaLearningPlugin(SupervisedPlugin[Template], ABC):

avalanche/distributed/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .distributed_helper import *

0 commit comments

Comments (0)