Skip to content

Commit 079fb46

Browse files
committed
fix training for relation extraction model
1 parent d5ff35d commit 079fb46

1 file changed

Lines changed: 7 additions & 10 deletions

File tree

renard/pipeline/relation_extraction.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
 13  13       PreTrainedModel,
 14  14       EvalPrediction,
 15  15       pipeline as hg_pipeline,
     16 +     BatchEncoding,
 16  17   )
 17  18   from more_itertools import flatten
 18  19   from transformers.pipelines.pt_utils import KeyDataset
@@ -26,23 +27,19 @@
 26  27   Relation = tuple[Character, str, Character]
 27  28
 28  29
 29     - def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> dict:
 30     -     example["relations"] = ast.literal_eval(example["relations"] or "[]")
     30 + def _load_ARF_line(example: dict, tokenizer: PreTrainedTokenizerFast) -> BatchEncoding:
     31 +     relations = ast.literal_eval(example["relations"] or "[]")
 31  32
 32  33       def format_rel(rel: dict) -> str:
 33  34           return "({}, {}, {})".format(rel["entity1"], rel["relation"], rel["entity2"])
 34  35
 35     -     labels = " ".join(map(format_rel, example["relations"]))
 36     -     with tokenizer.as_target_tokenizer():
 37     -         labels_batch = tokenizer(labels)
 38     -     example["labels"] = labels_batch["input_ids"]
     36 +     labels = "[" + ",".join(map(format_rel, relations)) + "]"
 39  37
 40  38       text = example["chunk"] or ""
 41     -     example["input_ids"] = tokenizer(GenerativeRelationExtractor.task_prompt(text))[
 42     -         "input_ids"
 43     -     ]
     39 +     batch = tokenizer(GenerativeRelationExtractor.task_prompt(text), text_target=labels)
     40 +     batch["relations"] = relations
 44  41
 45     -     return example
     42 +     return batch
 46  43
 47  44
 48  45   def load_ARF_dataset(tokenizer: PreTrainedTokenizerFast) -> HGDataset:

0 commit comments

Comments (0)