Skip to content

Commit 0e036f6

Browse files
committed
feat: add PyTorch MNIST training example with Kubeflow Trainer integration
1 parent 1fe3bd3 commit 0e036f6

1 file changed

Lines changed: 95 additions & 0 deletions

File tree

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import argparse
2+
import os
3+
import sys
4+
from kubeflow.trainer import CustomTrainer, TrainerClient
5+
6+
def train_fn():
7+
import torch
8+
import torch.distributed as dist
9+
import torch.nn.functional as F
10+
from torch import nn
11+
from torch.utils.data import DataLoader, DistributedSampler
12+
from torchvision import datasets, transforms
13+
14+
class Net(nn.Module):
15+
def __init__(self):
16+
super().__init__()
17+
self.conv1 = nn.Conv2d(1, 20, 5, 1)
18+
self.conv2 = nn.Conv2d(20, 50, 5, 1)
19+
self.fc1 = nn.Linear(4 * 4 * 50, 500)
20+
self.fc2 = nn.Linear(500, 10)
21+
22+
def forward(self, x):
23+
x = F.relu(self.conv1(x))
24+
x = F.max_pool2d(x, 2, 2)
25+
x = F.relu(self.conv2(x))
26+
x = F.max_pool2d(x, 2, 2)
27+
x = x.view(-1, 4 * 4 * 50)
28+
x = F.relu(self.fc1(x))
29+
return F.log_softmax(self.fc2(x), dim=1)
30+
31+
# Distributed setup
32+
# If we're just testing locally, we'll bypass the DDP init if WORLD_SIZE isn't set
33+
world_size = int(os.environ.get("WORLD_SIZE", 1))
34+
if world_size > 1:
35+
backend = "nccl" if torch.cuda.is_available() else "gloo"
36+
dist.init_process_group(backend=backend)
37+
local_rank = int(os.getenv("LOCAL_RANK", 0))
38+
device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
39+
model = nn.parallel.DistributedDataParallel(Net().to(device))
40+
sampler = DistributedSampler(datasets.FashionMNIST("./data", train=True, download=True, transform=transforms.ToTensor()))
41+
else:
42+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43+
model = Net().to(device)
44+
sampler = None
45+
46+
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
47+
48+
# Load data
49+
dataset = datasets.FashionMNIST("./data", train=True, download=True, transform=transforms.ToTensor())
50+
loader = DataLoader(dataset, batch_size=128, sampler=sampler, shuffle=(sampler is None))
51+
52+
print(f"Starting training on {device}...")
53+
for epoch in range(1):
54+
model.train()
55+
for batch_idx, (data, target) in enumerate(loader):
56+
data, target = data.to(device), target.to(device)
57+
optimizer.zero_grad()
58+
output = model(data)
59+
loss = F.nll_loss(output, target)
60+
loss.backward()
61+
optimizer.step()
62+
63+
if batch_idx % 20 == 0:
64+
print(f"Epoch: {epoch} | Batch: {batch_idx} | Loss: {loss.item():.4f}")
65+
66+
# Short circuit for test mode if we want, but let's run a full epoch
67+
if os.environ.get("KUBEFLOW_TRAINER_TEST"):
68+
if batch_idx >= 10: break
69+
70+
if world_size > 1:
71+
dist.destroy_process_group()
72+
73+
if __name__ == "__main__":
74+
parser = argparse.ArgumentParser()
75+
parser.add_argument("--nodes", type=int, default=1, help="Number of nodes")
76+
parser.add_argument("--test", action="store_true", help="Run a quick local test without Kubeflow")
77+
args = parser.parse_args()
78+
79+
if args.test:
80+
print("Running quick local test...")
81+
os.environ["KUBEFLOW_TRAINER_TEST"] = "1"
82+
train_fn()
83+
sys.exit(0)
84+
85+
client = TrainerClient()
86+
87+
job_name = client.train(
88+
trainer=CustomTrainer(
89+
func=train_fn,
90+
num_nodes=args.nodes,
91+
packages_to_install=["torch", "torchvision"]
92+
)
93+
)
94+
95+
print(f"Submitted TrainJob: {job_name}")

0 commit comments

Comments
 (0)