-
Notifications
You must be signed in to change notification settings - Fork 73
Expand file tree
/
Copy pathutils.py
More file actions
149 lines (129 loc) · 5.12 KB
/
utils.py
File metadata and controls
149 lines (129 loc) · 5.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
import re
def prepare_dataset_for_eval(dataset_name, output_file, data_dir='../data'):
    """Load the ground-truth dataset and the model outputs to evaluate.

    Args:
        dataset_name: Short dataset identifier; one of cwq, webqsp, grailqa,
            simpleqa, qald, webquestions, trex, zeroshotre, creak.
        output_file: Path to the JSON file holding the model outputs.
        data_dir: Directory containing the ground-truth JSON files. Defaults
            to ``'../data'``, the location previously hard-coded.

    Returns:
        A tuple ``(datas, question_string, output_datas)``. ``datas`` is the
        parsed ground-truth JSON. ``question_string`` is the record key that
        holds the question text for this dataset. ``output_datas`` is the
        parsed model-output JSON.

    Exits the process with status -1, after printing the list of supported
    names, when ``dataset_name`` is not recognised.
    """
    import os
    import sys

    # dataset name -> (ground-truth file name in data_dir, question key)
    registry = {
        'cwq': ('cwq.json', 'question'),
        'webqsp': ('WebQSP.json', 'RawQuestion'),
        'grailqa': ('grailqa.json', 'question'),
        'simpleqa': ('SimpleQA.json', 'question'),
        'qald': ('qald_10-en.json', 'question'),
        'webquestions': ('WebQuestions.json', 'question'),
        'trex': ('T-REX.json', 'input'),
        'zeroshotre': ('Zero_Shot_RE.json', 'input'),
        'creak': ('creak.json', 'sentence'),
    }
    if dataset_name not in registry:
        print("dataset not found, you should pick from {cwq, webqsp, grailqa, simpleqa, qald, webquestions, trex, zeroshotre, creak}.")
        # sys.exit is always available; the bare `exit` builtin is only
        # injected by the `site` module and is absent under `python -S`.
        sys.exit(-1)
    file_name, question_string = registry[dataset_name]
    with open(os.path.join(data_dir, file_name), encoding='utf-8') as f:
        datas = json.load(f)
    with open(output_file, encoding='utf-8') as f:
        output_datas = json.load(f)
    return datas, question_string, output_datas
def align(dataset_name, question_string, data, ground_truth_datas):
    """Collect the gold answers for one model output.

    Finds the first ground-truth record whose ``question_string`` field equals
    ``data[question_string]``. It then extracts that record's answers in the
    dataset-specific way below.

    Args:
        dataset_name: Dataset identifier (same names as
            ``prepare_dataset_for_eval``).
        question_string: Key holding the question text in both ``data`` and
            the ground-truth records.
        data: One model-output record.
        ground_truth_datas: Iterable of ground-truth records. It is NOT
            modified.

    Returns:
        A list of unique answers in unspecified order (de-duplicated via a
        set). An unrecognised ``dataset_name`` yields an empty list.

    Raises:
        IndexError: if no ground-truth record has a matching question. This
            is the same exception type the earlier ``[...][0]`` lookup raised.
    """
    question = data[question_string]
    origin_data = next(
        (j for j in ground_truth_datas if j[question_string] == question), None
    )
    if origin_data is None:
        raise IndexError(
            'no ground-truth record with {!r} == {!r}'.format(question_string, question)
        )

    answer_list = []
    if dataset_name == 'cwq':
        answers = origin_data['answers'] if 'answers' in origin_data else origin_data['answer']
        for answer in answers:
            # Extend a fresh list instead of appending to answer['aliases'],
            # which mutated the caller's ground truth on every call.
            answer_list.extend(answer['aliases'])
            answer_list.append(answer['answer'])
    elif dataset_name == 'webqsp':
        for parse in origin_data['Parses']:
            for name in parse['Answers']:
                entity = name['EntityName']
                # Fall back to the raw argument (e.g. a literal value) when no
                # entity name is given.
                answer_list.append(name['AnswerArgument'] if entity is None else entity)
    elif dataset_name == 'grailqa':
        for answer in origin_data['answer']:
            if 'entity_name' in answer:
                answer_list.append(answer['entity_name'])
            else:
                answer_list.append(answer['answer_argument'])
    elif dataset_name == 'simpleqa':
        answer_list.append(origin_data['answer'])
    elif dataset_name == 'qald':
        # 'answer' is indexed by its own keys, i.e. it is a mapping;
        # collect its values.
        answer_list.extend(origin_data['answer'].values())
    elif dataset_name == 'webquestions':
        answer_list.extend(origin_data['answers'])
    elif dataset_name in ('trex', 'zeroshotre'):
        answer_list.append(origin_data['answer'])
    elif dataset_name == 'creak':
        answer_list.append(origin_data['label'])
    return list(set(answer_list))
def check_string(string):
    """Report whether ``string`` contains at least one opening curly brace."""
    return any(ch == "{" for ch in string)
def clean_results(string):
    """Extract the answer enclosed in the first ``{...}`` pair of a response.

    A leading ``{yes}`` marker (case-insensitive) is dropped first, so
    ``"{Yes} {Paris}"`` yields ``"Paris"``.

    Returns:
        The text between the first ``{`` and the next ``}`` after it. If no
        closing brace follows, everything after the ``{`` is returned. If the
        string contains no ``{`` at all, returns ``"NULL"``.
    """
    prefix = '{yes}'
    if string.lower().startswith(prefix):
        string = string[len(prefix):]
    start = string.find("{")
    if start == -1:
        return "NULL"
    start += 1
    # Look for "}" only after the opening brace. Searching from index 0
    # returned -1 for an unclosed brace, which silently chopped the last
    # character. An earlier stray "}" likewise produced an empty answer.
    end = string.find("}", start)
    if end == -1:
        return string[start:]
    return string[start:end]
def check_refuse(string):
    """Return True if the response contains a refusal/hedging marker word.

    Matching is a case-insensitive substring test against "however" and
    "sorry".
    """
    lowered = string.lower()
    for marker in ("however", "sorry"):
        if marker in lowered:
            return True
    return False
def exact_match(response, answers):
    """Return True if ``response`` loosely matches any gold answer.

    Both sides are normalized by stripping, removing all spaces and
    lowercasing. They match when one normalized string contains the other,
    which includes equality.

    Empty normalized strings never match. "" is a substring of every string,
    so an empty response (e.g. from a ``{}`` answer) or an empty gold alias
    would otherwise always count as correct.
    """
    def _normalize(text):
        return text.strip().replace(" ", "").lower()

    clean_result = _normalize(response)
    if not clean_result:
        return False
    for answer in answers:
        clean_answer = _normalize(answer)
        if clean_answer and (clean_answer in clean_result or clean_result in clean_answer):
            return True
    return False
def save_result2json(dataset_name, num_right, num_error, total_nums, method, output_file=None):
    """Write the evaluation metrics of one run to a JSON file.

    Args:
        dataset_name: Dataset identifier, stored in the file and used in the
            default file name.
        num_right: Number of samples judged correct.
        num_error: Number of samples counted as errors.
        total_nums: Total number of evaluated samples. When it is 0, the
            exact-match score is reported as 0.0.
        method: Name of the evaluated method, stored verbatim.
        output_file: Destination path. Defaults to
            ``'ToG_<dataset_name>_results.json'`` in the current working
            directory, the previously hard-coded location.
    """
    if output_file is None:
        output_file = 'ToG_{}_results.json'.format(dataset_name)
    # Guard against an empty evaluation set, which used to raise
    # ZeroDivisionError.
    score = num_right / total_nums if total_nums else 0.0
    results_data = {
        'dataset': dataset_name,
        'method': method,
        'Exact Match': float(score),
        'Right Samples': num_right,
        # NOTE: the misspelled key is kept as-is so existing readers of these
        # result files keep working.
        'Error Sampels': num_error,
    }
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results_data, f, ensure_ascii=False, indent=4)
def extract_content(s):
    """Return the answer text found inside curly braces in ``s``.

    All non-greedy ``{...}`` groups are collected. If there are at least two
    and the first is "yes" (case-insensitive), the second is the answer.
    Otherwise the first group is. Returns 'NULL' when no braced group exists.
    """
    braced = re.findall(r'\{(.*?)\}', s)
    if not braced:
        return 'NULL'
    if len(braced) > 1 and braced[0].lower() == 'yes':
        return braced[1]
    return braced[0]