From ce33e48e805d75470d1fd99355967db363345fa8 Mon Sep 17 00:00:00 2001
From: Shy4n7
Date: Thu, 7 May 2026 01:12:44 +0530
Subject: [PATCH] Preserve explicit ImageNet resume learning rate

When --resume is combined with an explicit --lr/--learning-rate, apply
the command-line value to the restored optimizer instead of keeping the
learning rate stored in the checkpoint.  argparse abbreviations of
--learning-rate (for example "--learning 0.05") count as explicit, and
arguments after a bare "--" separator are ignored because argparse
treats them as positional.

Signed-off-by: Shy4n7
---
 imagenet/main.py       | 27 +++++++++++++++++++++++++++
 run_python_examples.sh |  8 ++++++++
 2 files changed, 35 insertions(+)

diff --git a/imagenet/main.py b/imagenet/main.py
index 42e5e9fa64..7826be7269 100644
--- a/imagenet/main.py
+++ b/imagenet/main.py
@@ -2,6 +2,7 @@
 import os
 import random
 import shutil
+import sys
 import time
 import warnings
 from enum import Enum
@@ -85,6 +86,7 @@
 
 def main():
     args = parser.parse_args()
+    args.lr_supplied = learning_rate_arg_supplied(sys.argv[1:])
 
     if args.seed is not None:
         random.seed(args.seed)
@@ -223,6 +225,9 @@ def main_worker(gpu, ngpus_per_node, args):
                 best_acc1 = best_acc1.to(args.gpu)
             model.load_state_dict(checkpoint['state_dict'])
             optimizer.load_state_dict(checkpoint['optimizer'])
+            if args.lr_supplied:
+                for param_group in optimizer.param_groups:
+                    param_group['lr'] = args.lr
             scheduler.load_state_dict(checkpoint['scheduler'])
             print("=> loaded checkpoint '{}' (epoch {})"
                   .format(args.resume, checkpoint['epoch']))
@@ -431,6 +436,28 @@ def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
     if is_best:
         shutil.copyfile(filename, 'model_best.pth.tar')
 
+
+def learning_rate_arg_supplied(argv):
+    """Return True if ``argv`` sets the learning rate explicitly.
+
+    Recognizes ``--lr`` and ``--learning-rate`` in both the ``--opt VALUE``
+    and ``--opt=VALUE`` forms, plus the prefixes of ``--learning-rate``
+    (e.g. ``--learning 0.05``) that argparse accepts as abbreviations.
+    Scanning stops at a bare ``--`` because argparse treats every argument
+    after it as positional.
+    """
+    for arg in argv:
+        if arg == '--':
+            return False
+        name = arg.partition('=')[0]
+        # Anything longer than a bare '--' that prefixes '--learning-rate'
+        # is an abbreviation; assumes no other option is spelled that way.
+        is_abbrev = len(name) > 2 and '--learning-rate'.startswith(name)
+        if name == '--lr' or is_abbrev:
+            return True
+    return False
+
+
 class Summary(Enum):
     NONE = 0
     AVERAGE = 1
diff --git a/run_python_examples.sh b/run_python_examples.sh
index caa58fc3a3..5c08c74aaa 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -81,6 +81,14 @@ function imagenet() {
     cp sample/train/n/* sample/val/n/
   fi
   uv run main.py --epochs 1 sample/ || error "imagenet example failed"
+  uv run main.py --epochs 2 --resume checkpoint.pth.tar --lr 0.05 sample/ || error "imagenet resume example failed"
+  uv run python - <<'PY' || error "imagenet resume learning rate check failed"
+import torch
+
+checkpoint = torch.load("checkpoint.pth.tar", map_location="cpu")
+for param_group in checkpoint["optimizer"]["param_groups"]:
+    assert param_group["lr"] == 0.05
+PY
   uv run main.py --epochs 1 --gpu 0 sample/ || error "imagenet example failed"
 }
 