# In this file we build a mini CLIP, with the goal of understanding
# the architecture of the real CLIP, which is the core of AGI.
import torch
import torch.nn as nn
import torch.nn.functional as F
#--- create patch embedding ---#
class PathEmbedding(nn.Module):
    def __init__(self, img_size=32, paths_size=8, in_channels=3, embd_dim=128):
        super().__init__()
        self.num_paths = (img_size // paths_size) ** 2
        # A strided convolution splits the image into non-overlapping patches
        # and projects each patch to an embd_dim-dimensional embedding.
        self.project = nn.Conv2d(
            in_channels, embd_dim,
            kernel_size=paths_size, stride=paths_size
        )

    def forward(self, x):
        x = self.project(x)     # (B, embd_dim, H/P, W/P)
        x = x.flatten(2)        # (B, embd_dim, num_paths)
        x = x.transpose(1, 2)   # (B, num_paths, embd_dim)
        return x
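
# A quick shape check for the patch embedding (illustrative values, assuming the
# 32x32 RGB defaults above; not part of the original file):
#   emb = PathEmbedding()                   # (32 // 8) ** 2 = 16 patches
#   out = emb(torch.randn(2, 3, 32, 32))    # out.shape == torch.Size([2, 16, 128])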
#--- ViT block ---#
class VitBlock(nn.Module):
    def __init__(self, embd_dim, num_heads):
        super().__init__()
        self.att_head = nn.MultiheadAttention(embd_dim, num_heads, batch_first=True)
        # Feed-forward network: expand to 4*embd_dim, GELU, project back to embd_dim.
        self.mlp = nn.Sequential(
            nn.Linear(embd_dim, embd_dim * 4),
            nn.GELU(),
            nn.Linear(embd_dim * 4, embd_dim)
        )
        self.norm1 = nn.LayerNorm(embd_dim)
        self.norm2 = nn.LayerNorm(embd_dim)

    def forward(self, x):
        # Pre-norm transformer block with residual connections.
        normed = self.norm1(x)
        att, _ = self.att_head(normed, normed, normed)
        x = x + att
        x = x + self.mlp(self.norm2(x))
        return x
#--- create mini ViT image encoder ---#
class VitImageEncoder(nn.Module):
    def __init__(self, embd_dim=128, path_size=8, num_heads=4, depth=4, img_size=32):
        super().__init__()
        self.path_embedding = PathEmbedding(img_size, path_size, 3, embd_dim)
        # Learnable class token; its final state summarizes the whole image.
        self.cls = nn.Parameter(torch.randn(1, 1, embd_dim))
        # Learnable positional embeddings for the class token plus all patches.
        self.pos_path = nn.Parameter(torch.randn(1, 1 + self.path_embedding.num_paths, embd_dim))
        self.blocks = nn.Sequential(
            *[VitBlock(embd_dim, num_heads) for _ in range(depth)]
        )
        self.norm = nn.LayerNorm(embd_dim)

    def forward(self, x):
        B = x.size(0)
        x = self.path_embedding(x)         # (B, num_paths, embd_dim)
        cls = self.cls.expand(B, -1, -1)   # (B, 1, embd_dim)
        x = torch.cat((cls, x), dim=1)     # prepend the class token
        x = x + self.pos_path
        x = self.blocks(x)
        out = self.norm(x[:, 0])           # class-token representation
        return F.normalize(out, dim=-1)    # unit-length image features
#--- create text encoder ---#
class TextEncoder(nn.Module):
    def __init__(self, embd_dim, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embd_dim)

    def forward(self, x):
        # x: integer token indices of shape (batch_size, sequence_length)
        embd = self.embedding(x)   # (batch_size, sequence_length, embd_dim)
        # Mean-pool over the sequence to get a single vector per text.
        embd = embd.mean(dim=1)    # (batch_size, embd_dim)
        return F.normalize(embd, dim=-1)
#--- create mini CLIP ---#
class MiniCLIP(nn.Module):
    def __init__(self, image_encoder, text_encoder, embd_dim):
        super().__init__()
        self.ViTimage_encoder = image_encoder
        self.text_encoder = text_encoder
        # Learnable temperature, initialized to log(1/0.07) as in the CLIP paper.
        self.logits_scale = nn.Parameter(torch.log(torch.tensor(1 / 0.07)))

    def forward(self, image_input, text_input):
        # Both encoders already return L2-normalized features.
        image_features = self.ViTimage_encoder(image_input)
        text_features = self.text_encoder(text_input)
        scale = self.logits_scale.exp()
        # Scaled cosine-similarity matrix: logits[i, j] pairs image i with text j.
        logits = scale * image_features @ text_features.T
        return logits
#--- symmetric CLIP loss ---#
def lossClip(logits):
    # Matching pairs lie on the diagonal, so target class i for row/column i.
    labels = torch.arange(logits.size(0)).to(logits.device)
    i_loss = F.cross_entropy(logits, labels)     # image-to-text direction
    t_loss = F.cross_entropy(logits.T, labels)   # text-to-image direction
    return (i_loss + t_loss) / 2
#------ create model ------#
embd_dim = 128
vocab_size = 512
# embd_dim must be divisible by num_heads, so use 8 heads (128 / 8 = 16 per head).
image_encoder_instance = VitImageEncoder(embd_dim=embd_dim, num_heads=8)
text_encoder_instance = TextEncoder(embd_dim, vocab_size)
model = MiniCLIP(image_encoder_instance, text_encoder_instance, embd_dim)
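
# A minimal smoke-test sketch; the batch size, image size, and caption length
# below are illustrative assumptions, not taken from the original file.
if __name__ == "__main__":
    images = torch.randn(4, 3, 32, 32)              # 4 random 32x32 RGB images
    texts = torch.randint(0, vocab_size, (4, 16))   # 4 random 16-token captions
    logits = model(images, texts)                   # (4, 4) similarity matrix
    loss = lossClip(logits)
    print(logits.shape, loss.item())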