-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_us_single.py
More file actions
26 lines (21 loc) · 1014 Bytes
/
train_us_single.py
File metadata and controls
26 lines (21 loc) · 1014 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from training.FetalAbdominalTrainer import FetalAbdominalTrainer
import torch
import os
from utils.root_path import results_single_path
from dataset.loader import get_data_loader
def checkpoint_path(root, data_type, model_type, num_clients, data_split_seed,
                    num_test, client_name):
    """Return the checkpoint file path for one locally trained client.

    The layout is
    ``<root>/<data_type>_<num_clients>_<data_split_seed>_<num_test>/exp_<model_type>/<client_name>.pt``:
    one directory per data-split setup, one sub-directory per model type,
    and one ``.pt`` file per client.
    """
    setup_name = f"{data_type}_{num_clients}_{data_split_seed}_{num_test}"
    exp_name = f"exp_{model_type}"
    return os.path.join(root, setup_name, exp_name, f"{client_name}.pt")


def main(data_type="HAM", model_type="transunet_b16", num_clients=4,
         data_split_seed=42, num_test=0.5, client_name='client-4',
         batch_size=4, local_epochs=200):
    """Train a single client's model on its own data split and save it.

    The defaults reproduce the original hard-coded configuration, so running
    the file as a script behaves as before.

    Args:
        data_type: Dataset identifier passed to ``get_data_loader`` and the
            trainer (other values used previously: "fetalAbdominal", "XRay").
        model_type: Architecture name passed to the trainer (other values
            used previously: "unet", "mednca").
        num_clients: Number of clients, forwarded to ``get_data_loader``.
        data_split_seed: Seed forwarded to ``get_data_loader``; presumably
            controls the client split — confirm in ``dataset.loader``.
        num_test: Forwarded to ``get_data_loader``; presumably the test
            fraction of each client's data — confirm in ``dataset.loader``.
        client_name: Which client's split to train on.
        batch_size: Mini-batch size of the data loader.
        local_epochs: Number of calls to ``trainer.train_epoch``.

    Returns:
        The path the trained model was saved to.
    """
    save_path = checkpoint_path(results_single_path, data_type, model_type,
                                num_clients, data_split_seed, num_test,
                                client_name)
    # save_model's implementation is not visible here. Create the target
    # directory before training so a missing folder (or a permission problem)
    # surfaces immediately instead of after the full training run.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    ldr = get_data_loader(data_type, client_name, num_clients, num_test,
                          data_split_seed, batch_size=batch_size, shuffle=True)
    trainer = FetalAbdominalTrainer(ldr, data_type, model_type, client_name)
    for _ in range(local_epochs):
        trainer.train_epoch(loading_bar=True)

    trainer.save_model(save_path)
    return save_path


if __name__ == '__main__':
    main()