-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
77 lines (68 loc) · 2.57 KB
/
utils.py
File metadata and controls
77 lines (68 loc) · 2.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def format_result(result, text, tag):
    """Turn inclusive ``[begin, end]`` index pairs into entity dictionaries.

    Args:
        result: iterable of 2-item ``(begin, end)`` pairs; ``end`` is inclusive.
        text: the sequence the indices refer to; sliced to recover the word.
        tag: entity type label copied into every output dictionary.

    Returns:
        A list with one dict per pair, holding ``start`` (== begin),
        ``stop`` (exclusive, == end + 1), ``word`` (``text[begin:end + 1]``)
        and ``type`` (== tag).
    """
    return [
        {
            "start": span_start,
            "stop": span_last + 1,  # convert the inclusive end to an exclusive stop
            "word": text[span_start:span_last + 1],
            "type": tag,
        }
        for span_start, span_last in result
    ]
def get_tags(path, tag, tag_map):
    """Extract the spans of one entity type from a sequence of tag ids.

    A span is reported for a run ``B-<tag> (I-<tag>)* E-<tag>``: the ``E`` tag
    must directly follow a ``B``/``I`` tag of the same entity type, with no
    other tag in between.

    Args:
        path: iterable of tag ids (values of ``tag_map``).
        tag: entity type name; looked up as ``"B-" + tag``, ``"I-" + tag``
            and ``"E-" + tag`` in ``tag_map``.
        tag_map: mapping from tag string to tag id.

    Returns:
        A list of ``[begin, end]`` index pairs, both ends inclusive, in
        order of occurrence.
    """
    begin_tag = tag_map.get("B-" + tag)
    mid_tag = tag_map.get("I-" + tag)
    end_tag = tag_map.get("E-" + tag)
    begin = -1   # start index of the span being built; -1 means no open span
    spans = []
    last = None  # tag id of the previous position (None before the first one)
    for index, current in enumerate(path):
        if current == begin_tag:
            # A new B always (re)starts a span, dropping any unfinished one.
            begin = index
        elif current == end_tag and last in (mid_tag, begin_tag) and begin > -1:
            spans.append([begin, index])
            # Close the span so a later I/E cannot re-emit it from the same B.
            begin = -1
        elif current != mid_tag:
            # Anything other than I-<tag> (O, S, another entity type's tags,
            # or an E that does not follow B/I) breaks the span in progress.
            begin = -1
        last = current
    return spans
def f1_score(tar_path, pre_path, tag, tag_map):
    """Score predicted tag paths against gold paths for one entity type.

    Spans are extracted from each gold/predicted path pair with ``get_tags``.
    A predicted span counts as correct when the same ``[begin, end]`` pair
    appears among the gold spans of the same sentence.

    Args:
        tar_path: iterable of gold tag-id sequences.
        pre_path: iterable of predicted tag-id sequences, paired with
            ``tar_path`` by ``zip`` (extra items in the longer one are ignored).
        tag: entity type name passed through to ``get_tags``.
        tag_map: mapping from tag string to tag id, passed to ``get_tags``.

    Returns:
        ``(recall, precision, f1)`` as floats. Each is 0.0 when its
        denominator is zero. The three values are also printed.
    """
    gold_total = 0.
    pred_total = 0.
    hits = 0.
    for gold_seq, pred_seq in zip(tar_path, pre_path):
        gold_spans = get_tags(gold_seq, tag, tag_map)
        pred_spans = get_tags(pred_seq, tag, tag_map)
        gold_total += len(gold_spans)
        pred_total += len(pred_spans)
        hits += sum(1 for span in pred_spans if span in gold_spans)

    recall = (hits / gold_total) if gold_total else 0.
    precision = (hits / pred_total) if pred_total else 0.
    denominator = precision + recall
    if denominator == 0:
        f1 = 0.
    else:
        f1 = (2 * precision * recall) / denominator
    print("\t{}\trecall {:.2f}\tprecision {:.2f}\tf1 {:.2f}".format(tag, recall, precision, f1))
    return recall, precision, f1
def path_to_entity(seq_of_word, seq_of_tag, ix_to_word, ix_to_tag, res=None):
    """Collect B/M/E-tagged entities from parallel word-id and tag-id sequences.

    An entity is a run ``B-X (M-X)* E-X`` where every tag shares the same
    suffix ``X`` (everything after the first character of the tag string).
    Any other tag, or an M/E whose suffix differs from the open entity's,
    discards the entity in progress.

    Args:
        seq_of_word: sequence of word ids.
        seq_of_tag: sequence of tag ids, indexed in parallel with ``seq_of_word``.
        ix_to_word: mapping from word id to word string.
        ix_to_tag: mapping from tag id to tag string (e.g. ``"B-LOC"``).
        res: optional list to append results to. A fresh list is created when
            omitted. The previous mutable default ``[]`` was shared across
            calls and leaked entities from earlier calls.

    Returns:
        ``res``, extended with one list per entity:
        ``[str(start_ix), "word/tag", ..., "word/tag", str(end_ix)]``.
    """
    if res is None:
        res = []
    entity = []
    entity_type = None  # suffix of the B tag that opened `entity`
    for ix, word_id in enumerate(seq_of_word):
        tag = ix_to_tag[seq_of_tag[ix]]
        prefix, suffix = tag[0], tag[1:]
        if prefix == 'B':
            entity = [str(ix), ix_to_word[word_id] + '/' + tag]  # start index first
            entity_type = suffix
        elif prefix in ('M', 'E') and entity and suffix == entity_type:
            # Track the type separately instead of re-parsing "word/tag":
            # splitting on '/' broke when the word itself contained '/'.
            entity.append(ix_to_word[word_id] + '/' + tag)
            if prefix == 'E':
                entity.append(str(ix))  # end index last
                res.append(entity)
                entity = []
        else:
            entity = []
    return res