-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
22 lines (19 loc) · 795 Bytes
/
model.py
File metadata and controls
22 lines (19 loc) · 795 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.nn as nn
class Mlp(nn.Module):
    """Three-layer fully connected network with ReLU after the first two layers.

    The final layer is linear (no activation). With the default arguments the
    layer sizes are 1024 -> 2048 -> 2048 -> 1280, matching the original
    hard-coded configuration.

    Args:
        in_dim: Number of input features expected by the first layer.
        hidden_dim: Number of features in both hidden layers.
        out_dim: Number of features produced by the last layer.
    """

    def __init__(self, in_dim: int = 1024, hidden_dim: int = 2048, out_dim: int = 1280):
        super().__init__()
        # Attribute names define the parameter names in state_dict(), so they
        # are kept exactly as before (fc1, fc2, fc3).
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, out_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1 -> ReLU -> fc2 -> ReLU -> fc3 to ``x`` and return the result.

        ``x`` must have ``in_dim`` features in its last dimension, as required
        by ``nn.Linear``.
        """
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x
def caption_generation(image_feature, model, tokenizer, device):
    """Generate a lower-cased, single-sentence caption conditioned on ``image_feature``.

    A fixed text prompt is tokenized, its ``input_ids`` are moved to ``device``
    and handed to ``model.generate`` together with ``image_feature`` (passed as
    the ``prefix`` keyword). The first returned sequence is decoded, and the
    text between the first and second ``':'`` up to the first ``'.'`` is
    returned in lower case. Surrounding whitespace is not stripped.

    Args:
        image_feature: Value forwarded unchanged to ``model.generate`` as
            ``prefix``. NOTE(review): presumably an image embedding already on
            ``device`` -- confirm against the model's ``generate`` signature.
        model: Object exposing ``generate(input_ids, 40, prefix=..., ...)``.
        tokenizer: Callable returning a mapping with ``"input_ids"``; must also
            provide ``decode``.
        device: Target passed to ``input_ids.to(...)``.

    Returns:
        The extracted caption string. If the decoded text contains no ``':'``,
        the whole decoded text is used as the caption source instead of
        raising ``IndexError``.
    """
    prompt = "prefix prefix prefix prefix prefix:"
    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"].to(device)

    # The positional 40 is kept as-is; it is presumably the maximum generation
    # length, but the generate() signature is not visible here -- TODO confirm.
    sequences = model.generate(
        input_ids,
        40,
        prefix=image_feature,
        do_sample=False,
        num_beams=5,
    )
    decoded = tokenizer.decode(sequences[0])

    # Segment 0 is whatever precedes the first ':' (the echoed prompt when the
    # model returns it); segment 1 is the generated caption text. Fall back to
    # the full text when no ':' is present rather than failing on the index.
    segments = decoded.split(':')
    caption = segments[1] if len(segments) > 1 else segments[0]

    # Keep only the first sentence.
    return caption.split('.')[0].lower()