forked from sthalles/SimCLR
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
53 lines (40 loc) · 1.38 KB
/
utils.py
File metadata and controls
53 lines (40 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import torch
# Seed NumPy's global RNG once at import so NumPy-based randomness is repeatable.
np.random.seed(0)

# Shared cosine-similarity modules: ``cos1d`` reduces over dim 1, ``cos2d`` over dim 2.
cos1d, cos2d = (torch.nn.CosineSimilarity(dim=axis) for axis in (1, 2))
def get_negative_mask(batch_size, device=None):
    """Build a boolean mask that keeps only distinct pairs as negatives.

    The mask has one row per anchor ``i`` and ``2 * batch_size`` columns.
    In row ``i``, columns ``i`` and ``i + batch_size`` are cleared (``False``).
    This removes the similarity score of the equal/similar image pair, so only
    distinct pairs have their scores passed on as negative examples. Every
    other entry is ``True``.

    Args:
        batch_size: Number of anchors (rows).
        device: Optional torch device on which to allocate the mask.
            Defaults to PyTorch's default device, matching the previous
            behaviour.

    Returns:
        A ``torch.bool`` tensor of shape ``(batch_size, 2 * batch_size)``.
    """
    negative_mask = torch.ones(
        (batch_size, 2 * batch_size), dtype=torch.bool, device=device
    )
    # Clear both excluded columns for every row in one vectorized step
    # instead of looping in Python over the batch.
    rows = torch.arange(batch_size, device=device)
    negative_mask[rows, rows] = False
    negative_mask[rows, rows + batch_size] = False
    return negative_mask
def _dot_simililarity_dim1(x, y):
# x shape: (N, 1, C)
# y shape: (N, C, 1)
# v shape: (N, 1, 1)
v = torch.bmm(x.unsqueeze(1), y.unsqueeze(2)) #
return v
def _dot_simililarity_dim2(x, y):
v = torch.tensordot(x.unsqueeze(1), y.T.unsqueeze(0), dims=2)
# x shape: (N, 1, C)
# y shape: (1, C, 2N)
# v shape: (N, 2N)
return v
def _cosine_simililarity_dim1(x, y):
    """Return the cosine similarity of ``x`` and ``y`` computed along dimension 1.

    Delegates to the module-level ``cos1d`` instance of
    ``torch.nn.CosineSimilarity(dim=1)``.
    """
    return cos1d(x, y)
def _cosine_simililarity_dim2(x, y):
    """Return the pairwise cosine similarity between rows of ``x`` and rows of ``y``.

    ``x`` is broadcast as ``(N, 1, C)`` and ``y`` as ``(1, 2N, C)``. The
    module-level ``cos2d`` (``CosineSimilarity(dim=2)``) then reduces over the
    feature axis, producing an ``(N, 2N)`` result.
    """
    anchors = x.unsqueeze(1)     # (N, 1, C)
    candidates = y.unsqueeze(0)  # (1, 2N, C)
    return cos2d(anchors, candidates)
def get_similarity_function(use_cosine_similarity):
    """Select the pair of similarity helpers.

    Args:
        use_cosine_similarity: A truthy value selects the cosine-similarity
            helpers; a falsy value selects the raw dot-product helpers.

    Returns:
        A tuple ``(row_wise_fn, pairwise_fn)``, i.e. the ``*_dim1`` and
        ``*_dim2`` variants of the chosen similarity.
    """
    if not use_cosine_similarity:
        return _dot_simililarity_dim1, _dot_simililarity_dim2
    return _cosine_simililarity_dim1, _cosine_simililarity_dim2