forked from rigley007/OpenPrivML
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path main.py
More file actions
22 lines (16 loc) · 742 Bytes
/
main.py
File metadata and controls
22 lines (16 loc) · 742 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import config as cfg
from imagenet10_dataloader import get_data_loaders
from adv_image import Adv_Gen
from regular_generator import conv_generator, Generator
from pre_model_extractor import model_extractor
def main():
    """Build the training components and train the adversarial generator.

    Selects the compute device, obtains the data loaders, builds a
    feature extractor and an auto-encoder generator, then hands them to
    ``Adv_Gen`` and trains for ``cfg.epochs`` epochs.
    """
    print("CUDA Available: ", torch.cuda.is_available())
    # Use the GPU only when the config enables it AND one is actually present.
    use_gpu = cfg.use_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if use_gpu else "cpu")

    # Only the training loader is used here; the validation loader is ignored.
    train_loader, _ = get_data_loaders()

    # NOTE(review): the positional args look like (architecture name,
    # layer index/count, pretrained flag) -- confirm in pre_model_extractor.
    feature_ext = model_extractor('resnet18', 5, True)

    # Two different auto-encoders are provided; replace the line below with
    # ``generator = Generator(3, 3)`` to use the alternative one.
    generator = conv_generator()

    adv_gen = Adv_Gen(device, feature_ext, generator)
    adv_gen.train(train_loader, cfg.epochs)


if __name__ == '__main__':
    main()