
Commit c79f971

refactored bad training code for teacher model
1 parent 5f3b970 commit c79f971

16 files changed

Lines changed: 345 additions & 51 deletions

README.md

Lines changed: 24 additions & 1 deletion

@@ -31,4 +31,27 @@ More details coming soon!

- `evaluation/`: Evaluation metrics and analysis
- `tests/`: Unit tests
- `notebooks/`: Jupyter notebooks for exploration
- `scripts/`: Utility scripts

## Training

To run on `weftdrive`:

```bash
nohup /srv/gpurun.pl python src/senti_synth/cli/01_train_teacher.py configs/teacher/sst2_hf.yaml > ~/scratch/senti_synth/logs/$(date +%Y%m%d_%H%M).log 2>&1 &
```

### Setting up on weftdrive

1. SSH into weftdrive: `ssh paramkapur@weftdrive.private.reed.edu`
2. Clone the repository: `git clone https://github.com/paramkpr/senti_synth.git`
3. Set up conda: run `/srv/conda/bin/conda init`, then `source ~/.bashrc`
4. Activate the conda environment: `conda activate deep-learning`
   1. Check which packages are installed: `conda list`
   2. Install the project's dependencies: `pip install -r requirements.txt`
   3. Install the project itself: `pip install -e .`
5. Copy `data/clean` to `weftdrive:~/scratch/data/clean`: `scp -r data/clean weftdrive:~/scratch/data/`
   1. Ensure the config file points to the correct path: `dataset_path: "~/scratch/data/clean"`. A quick way to verify the copied data loads is sketched after this list.
6. Set up W&B:
   1. `export WANDB_API_KEY="..."`
   2. `python -m wandb login`
7. Run the training script: `nohup /srv/gpurun.pl python src/senti_synth/cli/01_train_teacher.py configs/teacher/sst2_hf.yaml > ~/scratch/senti_synth/logs/$(date +%Y%m%d_%H%M).log 2>&1 &`
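A quick check that the copied dataset is usable (a sketch, not part of the repo): `src/data.py` requires the `train`/`val`/`sanity`/`test` splits and `text`/`labels` columns, and `load_from_disk` may not expand `~`, hence the explicit `expanduser`.

```python
# Sanity check for step 5.1: confirm the copied dataset loads and has
# the splits and columns that ClassificationDataModule expects.
import os
from datasets import load_from_disk

ds = load_from_disk(os.path.expanduser("~/scratch/data/clean"))
assert all(s in ds for s in ["train", "val", "sanity", "test"]), list(ds.keys())
assert all(c in ds["train"].column_names for c in ["text", "labels"])
print({split: ds[split].num_rows for split in ds})
```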

configs/deberta_large_sst2.yaml

Lines changed: 0 additions & 20 deletions
This file was deleted.
File renamed without changes.

configs/teacher/sst2_hf.yaml

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@

```yaml
model:
  model_name: "microsoft/deberta-v3-base"
  num_labels: 2
  use_fast_tokenizer: true

data:
  dataset_path: "~/scratch/data/clean"  # Local path loaded via datasets.load_from_disk
  max_len: 32
  train_split: "train"
  validation_split: "val"
  test_split: "test"

training:
  output_dir: "runs/teacher/deberta_v3_base"  # Specific output for this run
  overwrite_output_dir: true
  run_name: "teacher_sst2_deberta_v3_base_run"  # Optional W&B/TensorBoard run name

  # Reporting
  report_to: "wandb"
  wandb_project: "senti_synth_teacher"

  # Batching & Epochs
  per_device_train_batch_size: 16
  per_device_eval_batch_size: 32
  gradient_accumulation_steps: 1
  num_train_epochs: 3

  # Optimizer & Scheduler
  learning_rate: 3e-5
  warmup_ratio: 0.1

  # Logging, Saving, Evaluation
  logging_steps: 50
  eval_steps: 200  # Evaluate every N steps
  save_steps: 200  # Save a checkpoint every N steps
  save_total_limit: 2  # Keep only the best and the latest checkpoints
  load_best_model_at_end: true  # Load the best model found during training
  metric_for_best_model: "eval_f1"  # Metric that determines the "best" model
  greater_is_better: true

  # Hardware & Performance
  fp16: true  # Set to false if the GPU doesn't support FP16 or it causes issues

  # Callbacks
  use_early_stopping: true
  early_stopping_patience: 3
  early_stopping_threshold: 0.001  # Minimum improvement needed to reset patience

  # Optional: evaluate on the test set after training
  do_test_eval: true
```
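One pitfall worth noting when hand-editing this file: PyYAML's `safe_load` parses `3e-5` as a *string*, not a float (YAML 1.1 float syntax requires a dot in the mantissa), while `0.1` and `3.0e-5` parse as floats. A minimal sketch, independent of the repo code, showing the behavior and a defensive cast:

```python
# PyYAML scientific-notation gotcha: "3e-5" lacks a dot, so it stays a string.
import yaml

cfg = yaml.safe_load("learning_rate: 3e-5\nwarmup_ratio: 0.1")
print(type(cfg["learning_rate"]))  # <class 'str'>
print(type(cfg["warmup_ratio"]))   # <class 'float'>

lr = float(cfg["learning_rate"])   # cast defensively before use
```

Spelling the value as `3.0e-5` in the YAML avoids the cast entirely.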

notebooks/eda_sst_2.ipynb

Lines changed: 0 additions & 12 deletions
```diff
@@ -475,18 +475,6 @@
   "display_name": "Python (sentisynth)",
   "language": "python",
   "name": "auctionn"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.13.3"
  }
 },
 "nbformat": 4,
```

notebooks/teacher.ipynb

Lines changed: 0 additions & 12 deletions
```diff
@@ -165,18 +165,6 @@
   "display_name": "Python3.11 (sentisynth)",
   "language": "python",
   "name": "auctionn"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.11.2"
  }
 },
 "nbformat": 4,
```
File renamed without changes.

src/cli/01_train_teacher.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@

```python
import typer
import os
import yaml
import logging
from pathlib import Path

import numpy as np
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import torch
from transformers import DataCollatorWithPadding, IntervalStrategy, TrainingArguments, Trainer

from src.models import build_teacher
from src.data import ClassificationDataModule


logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

app = typer.Typer()


def compute_metrics(p):
    """Computes metrics for HF Trainer."""
    preds = np.argmax(p.predictions, axis=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')  # Assuming binary
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }


@app.command()
def main(config_path: Path = typer.Argument(..., help="Path to YAML config")):
    cfg = yaml.safe_load(config_path.read_text())

    # --- SETUP W&B ---
    run_name = cfg['training'].get("run_name", f"teacher_train_{Path(cfg['training']['output_dir']).name}")
    report_to = cfg['training'].get("report_to", "none")  # Default to no reporting
    if report_to == "wandb":
        project_name = cfg['training'].get("wandb_project", "senti_synth_teacher")
        os.environ.pop("WANDB_DISABLED", None)  # Ensure it's enabled if requested
        os.environ["WANDB_PROJECT"] = project_name
        logger.info(f"Reporting to W&B project: {project_name}")
    else:
        os.environ["WANDB_DISABLED"] = "true"  # Explicitly disable
        logger.info("W&B reporting disabled.")

    # --- BUILD MODEL ---
    model, tokenizer = build_teacher(cfg['model'])

    # --- SETUP DATA ---
    data_module = ClassificationDataModule(cfg['data'], tokenizer)
    data_module.setup()
    train_dataset = data_module.get_train_dataset()
    eval_dataset = data_module.get_eval_dataset()

    # --- SETUP TRAINER ---
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args_dict = {
        "output_dir": cfg['training']['output_dir'],
        "overwrite_output_dir": cfg['training'].get("overwrite_output_dir", True),
        "do_train": True,
        "do_eval": eval_dataset is not None,  # Only do eval if eval_dataset exists
        "per_device_train_batch_size": cfg['training'].get("per_device_train_batch_size", 8),
        "per_device_eval_batch_size": cfg['training'].get("per_device_eval_batch_size", 16),
        "gradient_accumulation_steps": cfg['training'].get("gradient_accumulation_steps", 1),
        "num_train_epochs": cfg['training'].get("num_train_epochs", 3),
        "learning_rate": cfg['training'].get("learning_rate", 5e-5),
        "warmup_ratio": cfg['training'].get("warmup_ratio", 0.1),
        "fp16": cfg['training'].get("fp16", torch.cuda.is_available()),  # Enable FP16 by default if available
        "logging_dir": cfg['training'].get("logging_dir", f"{cfg['training']['output_dir']}/logs"),
        "logging_steps": cfg['training'].get("logging_steps", 100),
        "eval_strategy": IntervalStrategy.STEPS if eval_dataset is not None else IntervalStrategy.NO,
        "eval_steps": cfg['training'].get("eval_steps", 500),
        "save_strategy": IntervalStrategy.STEPS,
        "save_steps": cfg['training'].get("save_steps", 500),
        "save_total_limit": cfg['training'].get("save_total_limit", 2),
        "load_best_model_at_end": cfg['training'].get("load_best_model_at_end", eval_dataset is not None),  # Only if eval is done
        "metric_for_best_model": cfg['training'].get("metric_for_best_model", "eval_f1" if eval_dataset else None),
        "greater_is_better": cfg['training'].get("greater_is_better", True),
        "report_to": [report_to] if report_to != "none" else [],
        "run_name": run_name,
        "label_names": ["labels"],  # Standard practice
        "remove_unused_columns": False,  # We already removed them in the data module
        "ddp_find_unused_parameters": cfg['training'].get("ddp_find_unused_parameters", False),
    }

    training_args = TrainingArguments(**training_args_dict)
    logger.info(f"Training arguments: {training_args}. FP16 Enabled: {training_args.fp16}")

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if eval_dataset is not None else None,
    )

    # --- TRAIN ---
    logger.info("Training model...")
    train_result = trainer.train()
    logger.info(f"Training results: {train_result}")

    # Save final model & metrics
    logger.info(f"Saving best model to {training_args.output_dir}")
    trainer.save_model()  # Saves the best model because load_best_model_at_end=True
    trainer.save_state()

    # Log final metrics
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)

    # Evaluate on test set if available
    test_dataset = data_module.get_test_dataset()
    if test_dataset and cfg['training'].get("do_test_eval", True):
        logger.info("Evaluating on test set...")
        test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
        trainer.log_metrics("test", test_metrics)
        trainer.save_metrics("test", test_metrics)
        logger.info(f"Test set evaluation complete: {test_metrics}")

    logger.info("Script finished successfully.")


if __name__ == "__main__":
    app()
```
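As a quick illustration of what `compute_metrics` produces, here is a toy run (a sketch, not repo code; the namedtuple stands in for `transformers.EvalPrediction`, and `compute_metrics` is assumed to be pasted alongside or otherwise importable):

```python
# Toy check of compute_metrics on three examples.
import numpy as np
from collections import namedtuple

EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])

logits = np.array([[0.2, 0.8],    # argmax 1, label 1 -> true positive
                   [1.5, -0.5],   # argmax 0, label 0 -> true negative
                   [0.4, 0.6]])   # argmax 1, label 0 -> false positive
labels = np.array([1, 0, 0])

print(compute_metrics(EvalPrediction(logits, labels)))
# {'accuracy': 0.667, 'f1': 0.667, 'precision': 0.5, 'recall': 1.0}
```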

src/data.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@

```python
import logging
from datasets import load_from_disk, DatasetDict
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)


class ClassificationDataModule:
    """
    Data module for classification tasks. Handles standard text classification setup.
    Used by Teacher (training/eval) and Student (eval).
    """
    def __init__(self, cfg: dict, tokenizer: AutoTokenizer):
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.dataset_path = cfg.get("dataset_path", None)
        self.max_len = cfg.get("max_len", 128)

        self.tokenized_datasets = None

        self.required_splits = ["train", "val", "sanity", "test"]
        self.required_columns = ["text", "labels"]

    def _load_clean_dataset(self) -> DatasetDict:
        logger.info(f"Loading dataset from: {self.dataset_path}")
        dataset = load_from_disk(self.dataset_path)

        missing_splits = [s for s in self.required_splits if s not in dataset]
        missing_cols = [c for c in self.required_columns if c not in dataset["train"].column_names]

        if missing_splits:
            raise ValueError(f"Dataset missing splits: {missing_splits}")
        if missing_cols:
            raise ValueError(f"Dataset missing columns: {missing_cols}")

        return dataset

    def _tokenize_function(self, examples):
        """Tokenization function for map."""
        # Ensure the correct text column is used
        return self.tokenizer(
            examples["text"],
            truncation=True,
            padding=False,  # Trainer handles padding with the data collator
            max_length=self.max_len
        )

    def setup(self):
        """Loads and tokenizes the dataset."""
        if self.tokenized_datasets:
            return

        raw_datasets = self._load_clean_dataset()
        self.tokenized_datasets = raw_datasets.map(
            self._tokenize_function,
            batched=True,
            remove_columns=[c for c in raw_datasets["train"].column_names
                            if c not in ["input_ids", "attention_mask", "labels"]]
        )

        logger.info(f"Loaded and tokenized datasets with max length: {self.max_len}")
        logger.info(f"Columns in tokenized datasets: {self.tokenized_datasets['train'].column_names}")

    def get_train_dataset(self):
        if not self.tokenized_datasets:
            self.setup()
        return self.tokenized_datasets["train"]

    def get_eval_dataset(self):
        if not self.tokenized_datasets:
            self.setup()
        return self.tokenized_datasets["val"]

    def get_sanity_dataset(self):
        if not self.tokenized_datasets:
            self.setup()
        return self.tokenized_datasets["sanity"]

    def get_test_dataset(self):
        if not self.tokenized_datasets:
            self.setup()
        return self.tokenized_datasets["test"]
```
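A minimal usage sketch (not repo code), assuming a `DatasetDict` saved with `save_to_disk` that contains the four required splits; `~` is expanded explicitly since `load_from_disk` may not do so:

```python
# Wire the data module to a tokenizer and inspect the tokenized output.
import os
from transformers import AutoTokenizer
from src.data import ClassificationDataModule

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")
cfg = {"dataset_path": os.path.expanduser("~/scratch/data/clean"), "max_len": 32}

dm = ClassificationDataModule(cfg, tokenizer)
dm.setup()
train = dm.get_train_dataset()
print(train.column_names)  # input_ids, attention_mask, labels (plus token_type_ids for some tokenizers)
```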

src/models.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@

```python
"""
Model factory for the teacher model.
"""
import logging
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)


def build_teacher(cfg: dict):
    """
    Builds and returns the teacher model and tokenizer using Hugging Face.

    Args:
        cfg (dict): Configuration dictionary for the model, expecting keys like:
            - model_name (str): Hugging Face model identifier.
            - num_labels (int): Number of classification labels.

    Returns:
        tuple: (model, tokenizer)
    """
    model_name = cfg.get("model_name", "microsoft/deberta-v3-base")
    num_labels = cfg.get("num_labels", 2)
    use_fast_tokenizer = cfg.get("use_fast_tokenizer", True)

    logger.info(f"Loading teacher model: {model_name} with {num_labels} labels.")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels
    )

    logger.info(f"Loading tokenizer for: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=use_fast_tokenizer)

    return model, tokenizer
```
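A quick way to exercise the factory against the new config (a sketch, not repo code):

```python
# Build the teacher from the config's model section and smoke-test a forward pass.
import yaml
from pathlib import Path
from src.models import build_teacher

cfg = yaml.safe_load(Path("configs/teacher/sst2_hf.yaml").read_text())
model, tokenizer = build_teacher(cfg["model"])

enc = tokenizer("a quick smoke test", return_tensors="pt")
print(model(**enc).logits.shape)  # torch.Size([1, 2]) for num_labels: 2
```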
