Skip to content

Commit 3ef3010

Browse files
authored
Add files via upload
1 parent 5048894 commit 3ef3010

6 files changed

Lines changed: 711 additions & 0 deletions

File tree

CFG.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as np
2+
3+
class CFG:
4+
target_cols = ['camp', 'corylus', 'dust', 'grim', 'qrob', 'qsub', 'cont']
5+
target_size = len(target_cols)
6+
prob_cols = ['p_' + i for i in target_cols]
7+
cols_mva = ['Area (ABD)', 'Area (Filled)', 'Aspect Ratio', 'Biovolume (Cylinder)',
8+
'Biovolume (P. Spheroid)', 'Circle Fit',
9+
'Circularity', 'Circularity (Hu)', 'Compactness', 'Convex Perimeter',
10+
'Convexity', 'Diameter (ABD)', 'Diameter (ESD)', 'Edge Gradient',
11+
'Elongation', 'Feret Angle Max', 'Feret Angle Min', 'Fiber Curl',
12+
'Fiber Straightness', 'Geodesic Aspect Ratio', 'Geodesic Length',
13+
'Geodesic Thickness', 'Intensity', 'Length', 'Particles Per Chain',
14+
'Perimeter', 'Roughness', 'Sigma Intensity', 'Sum Intensity',
15+
'Symmetry', 'Transparency', 'Volume (ABD)', 'Volume (ESD)', 'Width']
16+
17+
size = 128
18+
n_fold = 1
19+
num_workers = 8
20+
batch_size = 512
21+
model_name = 'resnet18'
22+
if_pretrained = True
23+
24+
lr = 1e-4
25+
epochs = 40
26+
27+
run_umap_test = True
28+
29+
save_model = True
30+
load_model = False
31+
model_name_saved = 'ICELEARNING_net'
32+
OUTPUT_DIR = 'saved_model/'
33+
save_conf_matrix = 'confusion_matrix_test_dataset.pdf'
34+
35+
save_inference_csv_files = True

data.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from CFG import CFG
2+
import torch
3+
from torch.utils.data import Dataset
4+
import albumentations as A
5+
import cv2
6+
from albumentations.pytorch import ToTensorV2
7+
8+
# ====================================================
9+
# Dataset
10+
# ====================================================
11+
class ParticleDataset(Dataset):
12+
13+
def __init__(self, df, transform=None):
14+
15+
self.df = df
16+
self.imgpaths = df['imgpaths'].to_numpy()
17+
self.labels = df[CFG.target_cols].to_numpy()
18+
self.transform = transform
19+
self.X_features = df[CFG.cols_mva].to_numpy()
20+
21+
def __len__(self):
22+
return len(self.df)
23+
24+
def __getitem__(self, idx):
25+
26+
imgpath = self.imgpaths[idx]
27+
28+
image = cv2.imread(imgpath, cv2.IMREAD_GRAYSCALE)
29+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
30+
31+
if self.transform:
32+
augmented = self.transform(image=image)
33+
image = augmented['image']
34+
35+
label = torch.tensor(self.labels[idx]).float()
36+
xfeatures = torch.from_numpy(self.X_features[idx]).float()
37+
38+
# print(type(image), type(label), type(xfeatures))
39+
return image, label, imgpath, xfeatures
40+
41+
# ====================================================
42+
# Transformations
43+
# ====================================================
44+
def get_transforms(*, data):
45+
if data == 'train':
46+
return A.Compose([
47+
A.Flip(p=0.5),
48+
A.Resize(CFG.size, CFG.size),
49+
A.Normalize(mean=[94., 94., 94.], std=[12., 12., 12.], max_pixel_value=1.0),
50+
ToTensorV2()
51+
])
52+
elif data == 'valid':
53+
return A.Compose([
54+
A.Resize(CFG.size, CFG.size),
55+
A.Normalize(mean=[94., 94., 94.], std=[12., 12., 12.], max_pixel_value=1.0),
56+
ToTensorV2()
57+
])

functs.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
3+
# ====================================================
4+
# Helper functions
5+
# ====================================================
6+
class AverageMeter(object):
7+
"""Computes and stores the average and current value"""
8+
9+
def __init__(self):
10+
self.reset()
11+
12+
def reset(self):
13+
self.val = 0
14+
self.avg = 0
15+
self.sum = 0
16+
self.count = 0
17+
self.history = []
18+
19+
def update(self, val, n=1):
20+
self.val = val
21+
self.sum += val * n
22+
self.count += n
23+
self.avg = self.sum / self.count
24+
self.history.append(self.sum / self.count)
25+
26+
def update_simplesum(self, val, n=1):
27+
self.val = val
28+
self.sum += val
29+
self.count += n
30+
self.avg = self.sum / self.count
31+
self.history.append(self.sum / self.count)
32+
33+
34+
def train_fn(train_loader, model, criterion, optimizer, scheduler, device):
35+
losses = AverageMeter()
36+
accuracies = AverageMeter()
37+
38+
model.train()
39+
40+
for step, (images, labels, paths, xfeatures) in enumerate(train_loader):
41+
images = images.to(device)
42+
labels = labels.to(device)
43+
xfeatures = xfeatures.to(device)
44+
45+
# with torch.set_grad_enabled(True):
46+
y_preds = model(images, xfeatures)
47+
loss = criterion(y_preds, labels)
48+
preds = (y_preds == y_preds.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
49+
50+
# statistics
51+
losses.update(loss.item(), images.size(0))
52+
how_many_correct = torch.sum(torch.all(torch.eq(preds, labels), dim=1))
53+
accuracies.update_simplesum(how_many_correct.item(), images.size(0))
54+
55+
# compute gradient and do SGD step
56+
optimizer.zero_grad()
57+
loss.backward()
58+
optimizer.step()
59+
60+
scheduler.step()
61+
62+
print(f'Train Loss: {losses.avg:.4f} Acc: {accuracies.avg:.4f}')
63+
64+
return losses.history, accuracies.history
65+
66+
67+
def valid_fn(valid_loader, model, criterion, device):
68+
losses = AverageMeter()
69+
accuracies = AverageMeter()
70+
71+
model.eval()
72+
73+
for step, (images, labels, paths, xfeatures) in enumerate(valid_loader):
74+
images = images.to(device)
75+
labels = labels.to(device)
76+
xfeatures = xfeatures.to(device)
77+
78+
# compute loss
79+
with torch.no_grad():
80+
y_preds = model(images, xfeatures)
81+
loss = criterion(y_preds, labels)
82+
preds = (y_preds == y_preds.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
83+
84+
# statistics
85+
losses.update(loss.item(), images.size(0))
86+
how_many_correct = torch.sum(torch.all(torch.eq(preds, labels), dim=1))
87+
accuracies.update_simplesum(how_many_correct.item(), images.size(0))
88+
89+
print(f'Val Loss: {losses.avg:.4f} Acc: {accuracies.avg:.4f}')
90+
91+
return losses.history, accuracies.history
92+

model.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from CFG import CFG
2+
import torch
3+
import torch.nn as nn
4+
import torchvision.models as models
5+
6+
# ====================================================
7+
# MODEL
8+
# ====================================================
9+
class CustomModel(nn.Module):
10+
def __init__(self, model_name=CFG.model_name, pretrained=CFG.if_pretrained):
11+
12+
super().__init__()
13+
14+
if model_name == 'resnet18':
15+
self.base = models.resnet18(weights='ResNet18_Weights.DEFAULT')
16+
if model_name == 'resnet34':
17+
self.base = models.resnet34(weights='ResNet34_Weights.DEFAULT')
18+
if model_name == 'resnet152':
19+
self.base = models.resnet152(weights='ResNet152_Weights.DEFAULT')
20+
if model_name == 'resnet101':
21+
self.base = timm.create_model(model_name)
22+
elif model_name == 'resnext50_32x4d':
23+
self.base = timm.create_model(model_name)
24+
25+
26+
n_features = self.base.fc.in_features # 512
27+
28+
self.base.fc = nn.Linear(n_features, 64)
29+
30+
self.bn1 = nn.BatchNorm1d(64)
31+
self.relu = nn.ReLU()
32+
self.dropout = nn.Dropout(p=0.2)
33+
34+
self.meta_net = nn.Sequential(nn.Linear(34, 128),
35+
nn.BatchNorm1d(128),
36+
nn.ReLU(),
37+
nn.Dropout(p=0.5),
38+
nn.Linear(128, 64),
39+
nn.BatchNorm1d(64),
40+
nn.ReLU(),
41+
nn.Dropout(p=0.5),
42+
nn.Linear(64, 32)
43+
)
44+
45+
self.fc3 = nn.Linear(96, 40)
46+
self.bn3 = nn.BatchNorm1d(40)
47+
48+
self.layer_out = nn.Linear(40, CFG.target_size)
49+
50+
def forward(self, imgs, metas):
51+
cnn1 = self.base(imgs)
52+
x = self.bn1(cnn1)
53+
x = self.relu(x)
54+
x = self.dropout(x)
55+
56+
meta_ = self.meta_net(metas)
57+
58+
x = torch.cat((x, meta_), 1)
59+
60+
x = self.fc3(x)
61+
x = self.bn3(x)
62+
x = self.relu(x)
63+
x = self.dropout(x)
64+
65+
x = self.layer_out(x)
66+
return x
67+
68+
# ====================================================
69+
# Hooks
70+
# ====================================================
71+
class Hook():
72+
def __init__(self, name, module, backward=False):
73+
74+
self.name = name
75+
76+
if backward == False:
77+
self.hook = module.register_forward_hook(self.hook_fn)
78+
else:
79+
self.hook = module.register_backward_hook(self.hook_fn)
80+
81+
def hook_fn(self, module, input, output):
82+
self.input = input
83+
self.output = output
84+
85+
def close(self):
86+
self.hook.remove()

0 commit comments

Comments
 (0)