Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import random
import shutil
import sys
import time
import warnings
from enum import Enum
Expand Down Expand Up @@ -85,6 +86,7 @@

def main():
args = parser.parse_args()
args.lr_supplied = learning_rate_arg_supplied(sys.argv[1:])

if args.seed is not None:
random.seed(args.seed)
Expand Down Expand Up @@ -223,6 +225,9 @@ def main_worker(gpu, ngpus_per_node, args):
best_acc1 = best_acc1.to(args.gpu)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
if args.lr_supplied:
for param_group in optimizer.param_groups:
param_group['lr'] = args.lr
scheduler.load_state_dict(checkpoint['scheduler'])
print("=> loaded checkpoint '{}' (epoch {})"
.format(args.resume, checkpoint['epoch']))
Expand Down Expand Up @@ -431,6 +436,16 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
if is_best:
shutil.copyfile(filename, 'model_best.pth.tar')


def learning_rate_arg_supplied(argv):
    """Return True if *argv* explicitly sets the learning rate.

    Recognised spellings are ``--lr`` and ``--learning-rate``, either as a
    separate token (``--lr 0.1``) or in ``--flag=value`` form, plus the
    unambiguous prefix abbreviations argparse accepts by default
    (for example ``--learning-r 0.1``).

    Args:
        argv: Iterable of raw command-line tokens, without the program name.

    Returns:
        bool: True when a learning-rate flag appears before any bare ``--``
        separator, otherwise False.
    """
    lr_flags = ('--lr', '--learning-rate')
    for arg in argv:
        if arg == '--':
            # argparse treats everything after a bare "--" as positional,
            # so a later "--lr" token is a value, not the option.
            return False
        # Drop an inline "=value" so both spellings are compared the same way.
        flag = arg.partition('=')[0]
        # A token longer than "--" that is a prefix of a known flag is either
        # the flag itself or an abbreviation of it.
        # NOTE(review): assumes the parser defines no other option whose full
        # name is itself a prefix of these flags (e.g. "--l") -- confirm
        # against the parser definition.
        if len(flag) > 2 and any(name.startswith(flag) for name in lr_flags):
            return True
    return False


class Summary(Enum):
NONE = 0
AVERAGE = 1
Expand Down
8 changes: 8 additions & 0 deletions run_python_examples.sh
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ function imagenet() {
cp sample/train/n/* sample/val/n/
fi
uv run main.py --epochs 1 sample/ || error "imagenet example failed"
uv run main.py --epochs 2 --resume checkpoint.pth.tar --lr 0.05 sample/ || error "imagenet resume example failed"
uv run python - <<'PY' || error "imagenet resume learning rate check failed"
import torch

checkpoint = torch.load("checkpoint.pth.tar", map_location="cpu")
for param_group in checkpoint["optimizer"]["param_groups"]:
assert param_group["lr"] == 0.05
PY
uv run main.py --epochs 1 --gpu 0 sample/ || error "imagenet example failed"
}

Expand Down