#: A relation triple in (subject, relation, object) order.
Relation = tuple[Character, str, Character]

# Marker tokens used to serialize relation triples for seq2seq training.
# These are registered with the tokenizer as additional special tokens
# before training (see train_model_on_ARF).
seq2seq_special_tokens = ["<triplet>", "<subj>", "<rel>", "<obj>", "</triplet>"]
def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
    """Tokenize one ARF dataset row into a seq2seq training example.

    The input text is the task prompt built from ``example["chunk"]``,
    prefixed with the tokenizer's BOS token; the target is the row's
    relations serialized in ``<triplet> <subj> … <rel> … <obj> … </triplet>``
    markup, suffixed with the EOS token so generation learns to stop.

    Args:
        example: A dataset row. Assumes ``example["relations"]`` is a Python
            literal (string) encoding a list of dicts with keys ``entity1``,
            ``relation`` and ``entity2``, and ``example["chunk"]`` is the
            source text — TODO confirm against the dataset loader.
        tokenizer: The tokenizer used for both input and target encoding.

    Returns:
        A ``BatchEncoding`` with the usual tokenizer fields plus a
        ``"relations"`` entry carrying the parsed relation dicts.
    """
    # "relations" may be None/empty; fall back to an empty list literal.
    relations = ast.literal_eval(example["relations"] or "[]")

    def format_rel(rel: dict) -> str:
        # Serialize one relation dict into the seq2seq triplet markup
        # (mirrors the regex in GenerativeRelationExtractor.parse_text_relations).
        return "<triplet> <subj> {} <rel> {} <obj> {} </triplet>".format(
            rel["entity1"], rel["relation"], rel["entity2"]
        )

    labels = " ".join(map(format_rel, relations))

    text = example["chunk"] or ""
    batch = tokenizer(
        tokenizer.bos_token + GenerativeRelationExtractor.task_prompt(text),
        text_target=labels + tokenizer.eos_token,
    )
    # Keep the structured relations alongside the encodings for metrics.
    batch["relations"] = relations

    return batch
@@ -101,6 +108,7 @@ def train_model_on_ARF(
101108 assert not tokenizer is None
102109 tokenizer .pad_token = tokenizer .eos_token
103110 pad_token_i = tokenizer .encode (tokenizer .pad_token )[0 ]
111+ tokenizer .add_special_tokens ({"additional_special_tokens" : seq2seq_special_tokens })
104112
105113 dataset = load_ARF_dataset (tokenizer )
106114
@@ -123,7 +131,7 @@ def compute_metrics(eval_preds: EvalPrediction) -> dict[str, float]:
123131 targs ,
124132 train_dataset = dataset ["train" ],
125133 eval_dataset = dataset ["test" ],
126- data_collator = DataCollatorForSeq2Seq (tokenizer ),
134+ data_collator = DataCollatorForSeq2Seq (tokenizer , model ),
127135 compute_metrics = compute_metrics ,
128136 )
129137 trainer .train ()
@@ -203,7 +211,15 @@ def task_prompt(text: str) -> str:
203211
204212 @staticmethod
205213 def parse_text_relations (text_relations : str ) -> list [tuple [str , str , str ]]:
206- return re .findall (r"\(([^,]+), ([^,]+), ([^,]+)\)" , text_relations )
214+ triplets = re .findall (
215+ r"<triplet> ?<subj>([^<]+)<rel>([^<]+)<obj>([^<]+)</triplet>" ,
216+ text_relations ,
217+ )
218+ triplets = [
219+ (subj .strip (" " ), rel .strip (" " ), obj .strip (" " ))
220+ for subj , rel , obj in triplets
221+ ]
222+ return triplets
207223
208224 @staticmethod
209225 def identify_character (
0 commit comments