# train.py
import os
import torch
import torch.nn.functional as F
from torch.optim import Adam
from data_loader import get_loader
from model import GPT
from config import Config
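# Config is assumed to expose VOCAB_PATH, device, lr, epochs, vocab_size,
# and OUTPUT_DIR; all six attributes are referenced below.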
# Prepare data (vocab is returned by get_loader but unused in this script)
loader, vocab = get_loader(vocab_path=Config.VOCAB_PATH)

# Initialize model
model = GPT().to(Config.device)
optimizer = Adam(model.parameters(), lr=Config.lr)

# Training loop
for epoch in range(1, Config.epochs + 1):
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x = x.to(Config.device)
        y = y.to(Config.device)

        # Flatten (batch, seq) dims so cross-entropy is computed per token
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, Config.vocab_size), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch}/{Config.epochs}, Loss: {avg_loss:.4f}")

    # Save a checkpoint after each epoch
    os.makedirs(Config.OUTPUT_DIR, exist_ok=True)
    ckpt_path = os.path.join(Config.OUTPUT_DIR, f"model_epoch{epoch}.pt")
    torch.save(model.state_dict(), ckpt_path)
    print(f"Saved checkpoint: {ckpt_path}")