-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathget_task_embeddings.py
More file actions
110 lines (94 loc) · 4.68 KB
/
get_task_embeddings.py
File metadata and controls
110 lines (94 loc) · 4.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import pickle

from hydra.utils import to_absolute_path
from libero.libero import benchmark
from omegaconf import DictConfig, OmegaConf
from transformers import AutoModel, AutoTokenizer
def get_task_embs(cfg, descriptions):
    """Encode natural-language task descriptions into fixed-size embeddings.

    Supported ``cfg.task_embedding_format`` values:
      - ``"one-hot"``: descriptions are replaced by synthetic ``"Task i"``
        strings, then encoded with BERT (pooler output).
      - ``"bert"``: BERT ``pooler_output``.
      - ``"gpt2"``: hidden state of the final token position.
      - ``"clip"``: CLIP text features.
      - ``"roberta"``: RoBERTa ``pooler_output``.

    Side effect: writes the embedding width into
    ``cfg.policy.language_encoder.network_kwargs.input_size``.

    Args:
        cfg: config providing ``task_embedding_format``,
            ``data.max_word_len``, and a ``policy.language_encoder
            .network_kwargs`` node (plus ``task_embedding_one_hot_offset``
            for the one-hot format).
        descriptions: list of task description strings.

    Returns:
        A detached ``torch.Tensor`` of shape ``(num_tasks, emb_dim)``.

    Raises:
        ValueError: if ``cfg.task_embedding_format`` is not one of the
            supported values (the original code would die with a confusing
            ``NameError`` instead).
    """

    def _tokenize(tz):
        # Shared tokenizer call: pad every description to the same fixed
        # length and return PyTorch tensors.
        # NOTE(review): no truncation flag, matching the original call
        # sites — descriptions longer than max_word_len stay untruncated.
        return tz(
            text=descriptions,  # the sentences to be encoded
            add_special_tokens=True,  # add model-specific special tokens
            max_length=cfg.data.max_word_len,  # fixed sequence length
            padding="max_length",
            return_attention_mask=True,
            return_tensors="pt",  # return PyTorch tensors
        )

    if cfg.task_embedding_format == "one-hot":
        # offset defaults to 1; if we have pretrained another model, this
        # offset starts from the pretrained number of tasks + 1
        offset = cfg.task_embedding_one_hot_offset
        descriptions = [f"Task {i+offset}" for i in range(len(descriptions))]

    if cfg.task_embedding_format in ("bert", "one-hot"):
        tz = AutoTokenizer.from_pretrained(
            "bert-base-cased", cache_dir=to_absolute_path("./bert")
        )
        model = AutoModel.from_pretrained(
            "bert-base-cased", cache_dir=to_absolute_path("./bert")
        )
        tokens = _tokenize(tz)
        # [CLS] pooler output as the sentence embedding.
        task_embs = model(**tokens)["pooler_output"].detach()
    elif cfg.task_embedding_format == "gpt2":
        tz = AutoTokenizer.from_pretrained("gpt2")
        tz.pad_token = tz.eos_token  # GPT-2 ships no pad token by default
        model = AutoModel.from_pretrained("gpt2")
        tokens = _tokenize(tz)
        # Use the hidden state at the last position as the embedding.
        task_embs = model(**tokens)["last_hidden_state"].detach()[:, -1]
    elif cfg.task_embedding_format == "clip":
        tz = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
        model = AutoModel.from_pretrained("openai/clip-vit-base-patch32")
        tokens = _tokenize(tz)
        task_embs = model.get_text_features(**tokens).detach()
    elif cfg.task_embedding_format == "roberta":
        tz = AutoTokenizer.from_pretrained("roberta-base")
        # NOTE(review): RoBERTa already has a <pad> token; overriding it
        # with eos is kept only to preserve the original behavior — confirm.
        tz.pad_token = tz.eos_token
        model = AutoModel.from_pretrained("roberta-base")
        tokens = _tokenize(tz)
        task_embs = model(**tokens)["pooler_output"].detach()
    else:
        raise ValueError(
            f"Unsupported task_embedding_format: {cfg.task_embedding_format!r}"
        )

    # Record the embedding width so the downstream policy config matches.
    cfg.policy.language_encoder.network_kwargs.input_size = task_embs.shape[-1]
    return task_embs
def create_cfg_for_libero(task_embedding_format, max_word_len=25):
    """Build a minimal OmegaConf config accepted by ``get_task_embs``.

    Args:
        task_embedding_format: embedding backend name, e.g. ``"one-hot"``,
            ``"bert"``, ``"gpt2"``, ``"clip"``, or ``"roberta"``.
        max_word_len: maximum tokenized length per description; the default
            of 25 matches the previously hard-coded value.

    Returns:
        A ``DictConfig`` with ``task_embedding_format``,
        ``data.max_word_len``, and an empty
        ``policy.language_encoder.network_kwargs`` node that
        ``get_task_embs`` fills in with ``input_size``.
    """
    # One nested literal replaces the original four-step node-by-node build.
    return DictConfig(
        {
            "task_embedding_format": task_embedding_format,
            "data": {"max_word_len": max_word_len},
            "policy": {"language_encoder": {"network_kwargs": {}}},
        }
    )
# Export CLIP text embeddings for every LIBERO task suite, one pickle per
# suite, mapping task name -> (1, emb_dim) embedding tensor.
for task in ["libero_object", "libero_spatial", "libero_10", "libero_goal", "libero_90"]:
    # Instantiate the benchmark suite for this task family.
    task_suite = benchmark.get_benchmark_dict()[task]()
    cfg = create_cfg_for_libero("clip")
    # libero_90 contains 90 tasks; every other suite contains 10.
    n_tasks = 90 if task == "libero_90" else 10
    task_names = [task_suite.get_task(i).name for i in range(n_tasks)]
    descriptions = [task_suite.get_task(i).language for i in range(n_tasks)]
    task_embs = get_task_embs(cfg, descriptions)
    # Slice with i:i+1 so each value keeps a leading batch dimension.
    tasks = {name: task_embs[i : i + 1] for i, name in enumerate(task_names)}
    # Make sure the output directory exists before writing (the original
    # crashed here if task_embeddings/ was missing).
    os.makedirs("task_embeddings", exist_ok=True)
    with open(os.path.join("task_embeddings", task + ".pkl"), "wb") as f:
        pickle.dump(tasks, f)  # serialize the name -> embedding mapping