-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy patharcface.py
More file actions
153 lines (134 loc) · 5.44 KB
/
arcface.py
File metadata and controls
153 lines (134 loc) · 5.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import math
import timm
import torch
import torch.nn as nn
from torch.nn import functional as F
from config import CFG
import albumentations
from albumentations.pytorch.transforms import ToTensorV2
class ArcFaceModule(nn.Module):
    """ArcFace margin head (Deng et al., "ArcFace: Additive Angular Margin
    Loss for Deep Face Recognition").

    Computes cos(theta) between L2-normalized features and class weights,
    adds an angular margin to the target class, scales the logits, and
    returns both the logits and the cross-entropy loss.

    Args:
        in_features: dimensionality of the input embedding.
        out_features: number of classes.
        scale: multiplier applied to the final logits (feature-scale s).
        margin: additive angular margin m, in radians.
        easy_margin: if True, apply the margin only where cos(theta) > 0.
        ls_eps: label-smoothing epsilon applied to the one-hot targets.
    """

    def __init__(self, in_features, out_features, scale, margin, easy_margin=False, ls_eps=0.0):
        super(ArcFaceModule, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.scale = scale
        self.margin = margin
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.ls_eps = ls_eps
        # Precomputed constants for cos(theta + m) and the fallback branch.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.th = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, input, label):
        """Return (scaled logits, cross-entropy loss) for a batch.

        Args:
            input: (batch, in_features) embedding tensor.
            label: (batch,) integer class labels.
        """
        # cosine = X.W = ||X|| . ||W|| . cos(theta); with both normalized,
        # the dot product IS cos(theta).
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # BUGFIX: float error can push cosine**2 slightly above 1.0, making
        # sqrt produce NaN; clamp the operand into [0, 1] first.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        # phi = cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # Where theta + m would exceed pi, fall back to a linear penalty.
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # One-hot targets, created on the same device as the logits
        # (previously forced onto CFG.device, which breaks when the input
        # lives elsewhere).
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            # Label smoothing over the one-hot targets.
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # output = label == True ? phi : cosine
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # Scale the logits before the softmax inside cross-entropy.
        output *= self.scale
        return output, F.cross_entropy(output, label)
class ShopeeEncoderBackBone(nn.Module):
    """Image encoder: a timm backbone + optional FC bottleneck + a
    classification head (plain softmax Linear or ArcFace margin head).

    At training time `forward` returns the head's output; at inference time
    it returns the (un-projected or projected) embedding instead.

    Args:
        model_name: timm model identifier for the backbone.
        loss_fn: 'ArcFace' or 'softmax' — selects the final head.
        classes: number of output classes (unused directly; CFG.classes is
            read below — NOTE(review): `classes` param is shadowed by
            CFG.classes, kept for interface compatibility).
        fc_dim: embedding dimension of the FC bottleneck.
        pretrained: load pretrained backbone weights via timm.
        use_fc: add the dropout/fc/bn bottleneck on top of pooled features.
        isTraining: training mode flag (also gates dropout/fc in
            get_features).
    """

    def __init__(self,
                 model_name='tf_efficientnet_b3',
                 loss_fn='ArcFace',
                 classes=CFG.classes,
                 fc_dim=CFG.fc_dim,
                 pretrained=False,
                 use_fc=True,
                 isTraining=True
                 ):
        super(ShopeeEncoderBackBone, self).__init__()
        # Create bottleneck backbone network from a (pre)trained timm model;
        # strip its classifier/pooling so it yields a spatial feature map.
        self.backbone = timm.create_model(model_name, pretrained=pretrained)
        in_features = self.backbone.classifier.in_features
        self.backbone.classifier = nn.Identity()
        self.backbone.global_pool = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.use_fc = use_fc
        self.loss_fn = loss_fn
        self.isTraining = isTraining
        # Optional top FC layers: the embedding used at test time to
        # represent the entire image; also acts as a regularizer.
        if self.use_fc:
            self.dropout = nn.Dropout(0.2)
            self.fc = nn.Linear(in_features, fc_dim)
            self.bn = nn.BatchNorm1d(fc_dim)
            self.init_params()
            in_features = fc_dim
        # BUGFIX: self.loss_fn was redundantly assigned a second time here.
        if self.loss_fn == 'softmax':
            self.final = nn.Linear(in_features, CFG.classes)
        elif self.loss_fn == 'ArcFace':
            self.final = ArcFaceModule(in_features,
                                       CFG.classes,
                                       scale=30,
                                       margin=0.5,
                                       easy_margin=False,
                                       ls_eps=0.0)
        else:
            # Fail fast instead of leaving self.final undefined.
            raise ValueError(f"Unsupported loss_fn: {loss_fn!r}")

    def forward(self, image, label):
        """Training: return the head output (for ArcFace this is a
        (logits, loss) tuple). Inference: return the embedding."""
        features = self.get_features(image)
        if not self.isTraining:
            return features
        if self.loss_fn == 'ArcFace':
            # ArcFace head consumes the labels to apply the angular margin.
            return self.final(features, label)
        # BUGFIX: the softmax head is a plain Linear and does not take the
        # label; the original passed it unconditionally and crashed here.
        return self.final(features)

    def init_params(self):
        """Initialize the FC bottleneck (xavier weights, zero bias, unit BN)."""
        nn.init.xavier_normal_(self.fc.weight)
        nn.init.constant_(self.fc.bias, 0)
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)

    def get_features(self, inp):
        """Backbone -> global average pool -> (training only) FC bottleneck.

        NOTE(review): the dropout/fc/bn projection is applied only while
        training, so train and test embeddings have different dimensions
        (fc_dim vs backbone width) — this mirrors the original behavior;
        confirm it is intentional.
        """
        batch_dim = inp.shape[0]
        inp = self.backbone(inp)
        inp = self.pooling(inp).view(batch_dim, -1)
        if self.use_fc and self.isTraining:
            inp = self.dropout(inp)
            inp = self.fc(inp)
            inp = self.bn(inp)
        return inp
def get_test_transforms():
    """Inference-time preprocessing: resize to CFG.img_size, normalize with
    albumentations' ImageNet defaults, then convert to a torch tensor."""
    steps = [
        albumentations.Resize(CFG.img_size, CFG.img_size, always_apply=True),
        albumentations.Normalize(),
        ToTensorV2(p=1.0),
    ]
    return albumentations.Compose(steps)
def getAugmentation(IMG_SIZE=CFG.img_size, isTraining=CFG.isTraining):
    """Build the albumentations pipeline.

    Training: resize + flips + rotation + brightness jitter + ImageNet
    normalization. Inference: resize + default normalization only. Both end
    with conversion to a torch tensor.

    Args:
        IMG_SIZE: square side length images are resized to.
        isTraining: select the augmented (True) or plain (False) pipeline.
    """
    if not isTraining:
        return albumentations.Compose([
            albumentations.Resize(IMG_SIZE, IMG_SIZE, always_apply=True),
            albumentations.Normalize(),
            ToTensorV2(p=1.0),
        ])
    return albumentations.Compose([
        albumentations.Resize(IMG_SIZE, IMG_SIZE, always_apply=True),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.VerticalFlip(p=0.5),
        albumentations.Rotate(limit=120, p=0.75),
        # NOTE(review): RandomBrightness is deprecated in newer
        # albumentations releases (use RandomBrightnessContrast); kept as-is
        # for the pinned version this project targets.
        albumentations.RandomBrightness(limit=(0.09, 0.6), p=0.5),
        albumentations.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(p=1.0),
    ])