Skip to content

Commit 8c989b3

Browse files
committed
added generator fine tune config
1 parent 3e12110 commit 8c989b3

3 files changed

Lines changed: 89 additions & 1 deletion

File tree

configs/generator/sst2_hf.yaml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
model:
2+
model_name: "gpt2" # you can swap in "gpt2-medium" or "EleutherAI/pythia-70m" etc.
3+
use_fast_tokenizer: true
4+
block_size: 128 # Maximum sequence length after tokenisation
5+
6+
data:
7+
dataset_path: "./data/clean/" # Use HF dataset identifier
8+
max_len: 32
9+
train_split: "train"
10+
validation_split: "val"
11+
test_split: "test"
12+
13+
training:
14+
# ── bookkeeping ────────────────────────────────────────────────────────────
15+
output_dir: "runs/generator/gpt2_sst2"
16+
overwrite_output_dir: true
17+
run_name: "generator_sst2_gpt2"
18+
19+
report_to: "wandb"
20+
wandb_project: "senti_synth_generator"
21+
22+
# ── batch size & epochs ────────────────────────────────────────────────────
23+
per_device_train_batch_size: 32 # fits comfortably on 24 GB VRAM
24+
per_device_eval_batch_size: 64
25+
gradient_accumulation_steps: 1
26+
num_train_epochs: 3 # SST‑2 is tiny; 2–3 epochs suffice
27+
28+
# ── precision & speed ──────────────────────────────────────────────────────
29+
fp16: true # enable mixed precision
30+
bf16: false # turn off to avoid dual precision modes
31+
# torch_dtype: "auto" # (optional) lets HF pick fastest dtype
32+
33+
# ── optimiser & scheduler ─────────────────────────────────────────────────
34+
learning_rate: 5e-5 # good starting LR for GPT‑2 on small corpora
35+
warmup_ratio: 0.1
36+
37+
# ── misc performance knobs ────────────────────────────────────────────────
38+
dataloader_num_workers: 4
39+
gradient_checkpointing: true # big memory win on GPT‑style decoders
40+
max_grad_norm: 1.0
41+
42+
# ── logging, saving, early stop ───────────────────────────────────────────
43+
logging_steps: 100
44+
eval_steps: 500
45+
save_steps: 500
46+
save_total_limit: 3
47+
load_best_model_at_end: true
48+
metric_for_best_model: "eval_loss"
49+
greater_is_better: false
50+
51+
use_early_stopping: true
52+
early_stopping_patience: 2
53+
early_stopping_threshold: 0.0005
54+
55+
do_test_eval: true

src/cli/02_fine_tune_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def main(config_path: Path = type.Argument(..., help="Path to YAML config")):
9898
trainer.save_metrics("train", metrics)
9999

100100
# Evaluate on test set if available
101-
test_dataset = data_module.get_test_dataset()
101+
test_dataset = data_module.get_sanity_dataset()
102102
if test_dataset and cfg['training'].get("do_test_eval", True):
103103
logger.info("Evaluating on test set...")
104104
test_metrics = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")

src/nb.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
#!/usr/bin/env python
"""
Quick demo: load a fine-tuned DeBERTa-v3 classifier and run one inference.

Install deps first:
    pip install "transformers>=4.40" torch --upgrade
"""
import torch
from transformers import pipeline  # AutoTokenizer/AutoModel imports were unused; dropped

# -------------------------------------------------------------------
# 1) where did Trainer save your model?
# -------------------------------------------------------------------
MODEL_DIR = "runs/teacher/deberta_v3_base/"  # <= change me!

# -------------------------------------------------------------------
# 2) easiest: high-level pipeline
# -------------------------------------------------------------------
# Pick the best available device at runtime instead of hard-coding "mps",
# which crashes on any machine without an Apple-silicon GPU.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

clf = pipeline(
    task="text-classification",
    model=MODEL_DIR,
    tokenizer=MODEL_DIR,  # Trainer saves the tokenizer alongside the model
    device=device,
)

example = "This movie was absolutely trash!"
print("\nPipeline result:")
# Output shape: [{'label': ..., 'score': ...}]. Label names come from the
# fine-tuned model's config; this negative review should score as the
# negative label (the old comment's 'POSITIVE, 0.97' example was misleading).
print(clf(example))

0 commit comments

Comments
 (0)