@@ -337,6 +337,7 @@ def train_ner_model(
337337 targs : TrainingArguments ,
338338 train_split : str = "train" ,
339339 valid_split : str = "valid" ,
340+ trainer_class : type [Trainer ] = Trainer ,
340341) -> PreTrainedModel :
341342 """Train a NER model on the given dataset.
342343
@@ -347,6 +348,8 @@ def train_ner_model(
347348 trainer.
348349 :param train_split: split of the dataset used for train.
349350 :param valid_split: split of the dataset used for validation.
351+ :param trainer_class: trainer class to use. Can be used to
352+ override the default huggingface trainer.
350353 """
351354 from transformers import DataCollatorForTokenClassification
352355
@@ -366,12 +369,11 @@ def train_ner_model(
366369 label2id = {label : i for i , label in enumerate (label_lst )},
367370 )
368371
369- trainer = Trainer (
372+ trainer = trainer_class (
370373 model ,
371374 targs ,
372375 train_dataset = dataset [train_split ],
373376 eval_dataset = dataset [valid_split ],
374- # data_collator=DataCollatorForTokenClassificationWithBatchEncoding(tokenizer),
375377 data_collator = DataCollatorForTokenClassification (tokenizer ),
376378 tokenizer = tokenizer ,
377379 )
0 commit comments