Skip to content

Commit 0a0bab0

Browse files
committed
improve format for relation extraction
1 parent 1bb2ac8 commit 0a0bab0

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

renard/pipeline/relation_extraction.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,24 @@
2626
#: (subject, relation, object)
2727
Relation = tuple[Character, str, Character]
2828

#: special tokens used to linearize relation triplets for seq2seq
#: training (added to the tokenizer and matched back at parse time)
seq2seq_special_tokens = ["<triplet>", "<subj>", "<rel>", "<obj>", "</triplet>"]
2931

3032
def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
    """Tokenize one ARF dataset row into model inputs and seq2seq labels.

    :param example: a dataset row with a ``"relations"`` field (a
        string-encoded list of relation dicts, possibly empty/None) and a
        ``"chunk"`` field holding the source text (possibly None).
    :param tokenizer: tokenizer used for both the prompt and the target.
    :return: a :class:`BatchEncoding` with an extra ``"relations"`` key
        holding the parsed relation dicts.
    """
    # "relations" is stored as a Python-literal string; fall back to an
    # empty list when the field is empty or None.
    parsed_relations = ast.literal_eval(example["relations"] or "[]")

    def as_triplet_string(relation: dict) -> str:
        # Linearize a single relation with the seq2seq special tokens.
        subj = relation["entity1"]
        rel = relation["relation"]
        obj = relation["entity2"]
        return f"<triplet> <subj> {subj} <rel> {rel} <obj> {obj} </triplet>"

    target_text = " ".join(as_triplet_string(rel) for rel in parsed_relations)

    chunk = example["chunk"] or ""
    # Prepend BOS to the prompt and append EOS to the labels so the model
    # learns explicit sequence boundaries.
    prompt = tokenizer.bos_token + GenerativeRelationExtractor.task_prompt(chunk)
    batch = tokenizer(prompt, text_target=target_text + tokenizer.eos_token)
    batch["relations"] = parsed_relations

    return batch
@@ -101,6 +108,7 @@ def train_model_on_ARF(
101108
assert not tokenizer is None
102109
tokenizer.pad_token = tokenizer.eos_token
103110
pad_token_i = tokenizer.encode(tokenizer.pad_token)[0]
111+
tokenizer.add_special_tokens({"additional_special_tokens": seq2seq_special_tokens})
104112

105113
dataset = load_ARF_dataset(tokenizer)
106114

@@ -123,7 +131,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
123131
targs,
124132
train_dataset=dataset["train"],
125133
eval_dataset=dataset["test"],
126-
data_collator=DataCollatorForSeq2Seq(tokenizer),
134+
data_collator=DataCollatorForSeq2Seq(tokenizer, model),
127135
compute_metrics=compute_metrics,
128136
)
129137
trainer.train()
@@ -203,7 +211,15 @@ def task_prompt(text: str) -> str:
203211

204212
@staticmethod
205213
def parse_text_relations(text_relations: str) -> list[tuple[str, str, str]]:
206-
return re.findall(r"\(([^,]+), ([^,]+), ([^,]+)\)", text_relations)
214+
triplets = re.findall(
215+
r"<triplet> ?<subj>([^<]+)<rel>([^<]+)<obj>([^<]+)</triplet>",
216+
text_relations,
217+
)
218+
triplets = [
219+
(subj.strip(" "), rel.strip(" "), obj.strip(" "))
220+
for subj, rel, obj in triplets
221+
]
222+
return triplets
207223

208224
@staticmethod
209225
def identify_character(

0 commit comments

Comments
 (0)