1+ import typer
2+ import os
3+ import yaml
4+ import logging
5+ from pathlib import Path
6+
7+ import numpy as np
8+ from sklearn .metrics import precision_recall_fscore_support , accuracy_score
9+ import torch
10+ from transformers import DataCollatorWithPadding , IntervalStrategy , TrainingArguments , Trainer
11+
12+ from src .models import build_teacher
13+ from src .data import ClassificationDataModule
14+
15+
# Configure root logging once at import time; all loggers in this process
# inherit the level and format set here.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Typer CLI application; commands are registered below via @app.command().
app = typer.Typer()
20+
21+
def compute_metrics(p):
    """Compute classification metrics for the HF Trainer.

    Args:
        p: an ``EvalPrediction``-like object exposing ``predictions``
           (class logits/scores, shape ``[N, num_classes]``) and
           ``label_ids`` (gold labels, shape ``[N]``).

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    predicted_labels = np.argmax(p.predictions, axis=1)
    gold_labels = p.label_ids
    # NOTE(review): 'binary' averaging assumes a two-class task — confirm
    # against the data module's label space.
    precision, recall, f1, _ = precision_recall_fscore_support(gold_labels, predicted_labels, average='binary')
    return {
        'accuracy': accuracy_score(gold_labels, predicted_labels),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
34+
35+
@app.command()
def main(config_path: Path = typer.Argument(..., help="Path to YAML config")):
    """Fine-tune a teacher classification model from a YAML config.

    Loads the config, optionally wires up Weights & Biases reporting, builds
    the model/tokenizer and datasets, trains with the HF ``Trainer``, saves
    the best checkpoint, and evaluates on a held-out test set if one exists.

    Args:
        config_path: path to a YAML file with ``model``, ``data`` and
            ``training`` sections.
    """
    # BUG FIX: was `type.Argument(...)` — `type` is the builtin, which has no
    # `Argument` attribute, so the module failed at import time. `typer.Argument`
    # is the intended CLI declaration.
    cfg = yaml.safe_load(config_path.read_text())

    # --- SETUP W&B ---
    run_name = cfg['training'].get("run_name", f"teacher_train_{Path(cfg['training']['output_dir']).name}")
    report_to = cfg['training'].get("report_to", "none")  # Default to no reporting
    if report_to == "wandb":
        project_name = cfg['training'].get("wandb_project", "senti_synth_teacher")
        os.environ.pop("WANDB_DISABLED", None)  # Ensure it's enabled if requested
        os.environ["WANDB_PROJECT"] = project_name
        logger.info(f"Reporting to W&B project: {project_name}")
    else:
        os.environ["WANDB_DISABLED"] = "true"  # Explicitly disable
        logger.info("W&B reporting disabled.")

    # --- BUILD MODEL ---
    model, tokenizer = build_teacher(cfg['model'])

    # --- SETUP DATA ---
    data_module = ClassificationDataModule(cfg['data'], tokenizer)
    data_module.setup()
    train_dataset = data_module.get_train_dataset()
    eval_dataset = data_module.get_eval_dataset()
    has_eval = eval_dataset is not None

    # --- SETUP TRAINER ---
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args_dict = {
        "output_dir": cfg['training']['output_dir'],
        "overwrite_output_dir": cfg['training'].get("overwrite_output_dir", True),
        "do_train": True,
        "do_eval": has_eval,  # Only do eval if eval_dataset exists
        "per_device_train_batch_size": cfg['training'].get("per_device_train_batch_size", 8),
        "per_device_eval_batch_size": cfg['training'].get("per_device_eval_batch_size", 16),
        "gradient_accumulation_steps": cfg['training'].get("gradient_accumulation_steps", 1),
        "num_train_epochs": cfg['training'].get("num_train_epochs", 3),
        "learning_rate": cfg['training'].get("learning_rate", 5e-5),
        "warmup_ratio": cfg['training'].get("warmup_ratio", 0.1),
        "fp16": cfg['training'].get("fp16", torch.cuda.is_available()),  # Enable FP16 if available by default
        "logging_dir": cfg['training'].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
        "logging_steps": cfg['training'].get("logging_steps", 100),
        "eval_strategy": IntervalStrategy.STEPS if has_eval else IntervalStrategy.NO,
        "eval_steps": cfg['training'].get("eval_steps", 500),
        "save_strategy": IntervalStrategy.STEPS,
        "save_steps": cfg['training'].get("save_steps", 500),
        "save_total_limit": cfg['training'].get("save_total_limit", 2),
        "load_best_model_at_end": cfg['training'].get("load_best_model_at_end", has_eval),  # Only if eval is done
        # CONSISTENCY FIX: was `if eval_dataset` (truthiness) — a zero-length
        # dataset would be falsy while every other guard uses `is not None`.
        "metric_for_best_model": cfg['training'].get("metric_for_best_model", "eval_f1" if has_eval else None),
        "greater_is_better": cfg['training'].get("greater_is_better", True),
        "report_to": [report_to] if report_to != "none" else [],
        "run_name": run_name,
        "label_names": ["labels"],  # Standard practice
        "remove_unused_columns": False,  # We already removed them in data module
        "ddp_find_unused_parameters": cfg['training'].get("ddp_find_unused_parameters", False),
    }

    training_args = TrainingArguments(**training_args_dict)
    logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if has_eval else None,
    )

    # --- TRAIN ---
    logger.info("Training model...")
    train_result = trainer.train()
    logger.info(f"Training results: {train_result}")

    # Save final model & metrics
    logger.info(f"Saving best model to {training_args.output_dir}")
    trainer.save_model()  # Saves the best model due to load_best_model_at_end=True
    trainer.save_state()

    # Log final metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Evaluate on test set if available
    test_dataset = data_module.get_test_dataset()
    # CONSISTENCY FIX: explicit `is not None` instead of truthiness, matching
    # the eval_dataset guards above.
    if test_dataset is not None and cfg['training'].get("do_test_eval", True):
        logger.info("Evaluating on test set...")
        test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_metrics)
        trainer.save_metrics("test", test_metrics)
        logger.info(f"Test set evaluation complete: {test_metrics}")

    logger.info("Script finished successfully.")