Skip to content

Commit 0751f5e

Browse files
committed
ner_utils.train_ner_model now allows a custom trainer class
1 parent 33626a7 commit 0751f5e

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

renard/ner_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ def train_ner_model(
337337
targs: TrainingArguments,
338338
train_split: str = "train",
339339
valid_split: str = "valid",
340+
trainer_class: type[Trainer] = Trainer,
340341
) -> PreTrainedModel:
341342
"""Train a NER model on the given dataset.
342343
@@ -347,6 +348,8 @@ def train_ner_model(
347348
trainer.
348349
:param train_split: split of the dataset used for train.
349350
:param valid_split: split of the dataset used for validation.
351+
:param trainer_class: trainer class to use. Can be used to
352+
override the default huggingface trainer.
350353
"""
351354
from transformers import DataCollatorForTokenClassification
352355

@@ -366,12 +369,11 @@ def train_ner_model(
366369
label2id={label: i for i, label in enumerate(label_lst)},
367370
)
368371

369-
trainer = Trainer(
372+
trainer = trainer_class(
370373
model,
371374
targs,
372375
train_dataset=dataset[train_split],
373376
eval_dataset=dataset[valid_split],
374-
# data_collator=DataCollatorForTokenClassificationWithBatchEncoding(tokenizer),
375377
data_collator=DataCollatorForTokenClassification(tokenizer),
376378
tokenizer=tokenizer,
377379
)

0 commit comments

Comments
 (0)