This project implements and compares different learning rate schedulers for deep learning models, with a focus on Reinforcement Learning (RL)-based schedulers that dynamically adapt learning rates during training.
Traditional learning rate schedulers follow predetermined patterns (e.g., step decay, cosine annealing), while RL-based schedulers learn to adjust the learning rate based on the current training state. This project provides:
- RL-based schedulers: PPO and A2C agents that learn optimal learning rate policies
- Traditional schedulers: Step LR and Cosine Annealing for baseline comparison
- Comprehensive evaluation: Experiments on the CIFAR-10 dataset with various CNN architectures
- Hyperparameter tuning: Optuna-based optimization for finding best scheduler parameters
- Experiment tracking: Neptune.ai integration for monitoring and reproducibility
- Python 3.8+
- CUDA-compatible GPU (optional but recommended)
- Neptune.ai account (optional, for experiment tracking)
- Clone the repository:

```bash
git clone https://github.com/KananMahammadli/rl-learning-rate-scheduler.git
cd rl-learning-rate-scheduler
```

- Create a virtual environment:

```bash
python -m venv env
source env/bin/activate  # On Windows: env\Scripts\activate
```

- Install dependencies:

```bash
pip install -r requirements.txt
```

- (Optional) Set up Neptune.ai for experiment tracking:

```bash
neptune login
# Follow the prompts to enter your API token
```

To train a model using the best found parameters for each scheduler:
```bash
# PPO scheduler on CIFAR-10 with ResNet18
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type ppo --log_to_neptune

# Traditional schedulers for comparison
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type step --log_to_neptune
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type cosine --log_to_neptune
```

Available options:

- `--dataset`: `cifar10`
- `--model`: `resnet18`, `resnet34`, `resnet50`, `vgg16`, `simple_cnn_cifar10`
- `--scheduler_type`: `ppo`, `a2c`, `step`, `cosine`
- `--log_to_neptune`: Enable Neptune.ai logging
- `--save_rl_policy`: Save the trained RL policy for later use
To find optimal parameters for a scheduler:
```bash
# Tune PPO scheduler for ResNet18 on CIFAR-10
python tune.py --dataset cifar10 --model resnet18 --scheduler_type ppo --n_trials 100 --log_to_neptune

# Tune with pruning for faster optimization
python tune.py --dataset cifar10 --model resnet18 --scheduler_type ppo --n_trials 100 --pruning --log_to_neptune

# Tune traditional scheduler
python tune.py --dataset cifar10 --model resnet18 --scheduler_type step --n_trials 50 --log_to_neptune
```

Tuning options:

- `--n_trials`: Number of Optuna trials (default: 100)
- `--study_name`: Custom name for the Optuna study
- `--log_to_neptune`: Log each trial to Neptune.ai
- `--pruning`: Enable early stopping of unpromising trials (see the sketch below)
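For orientation, an Optuna tuning loop with pruning generally follows the pattern below. This is a minimal sketch, not the project's actual `tune.py`: the search space, epoch count, and `train_one_epoch` stub are placeholder assumptions.

```python
import optuna

NUM_EPOCHS = 10  # illustrative


def train_one_epoch(params, epoch):
    # Placeholder for a real train/validate cycle; returns a fake val loss
    return 1.0 / (epoch + 1) + params["min_lr"]


def objective(trial):
    # Illustrative search space; the real one lives in tune.py
    params = {
        "min_lr": trial.suggest_float("min_lr", 1e-7, 1e-4, log=True),
        "max_lr": trial.suggest_float("max_lr", 1e-3, 1e-1, log=True),
        "n_actions": trial.suggest_int("n_actions", 5, 50),
    }
    val_loss = float("inf")
    for epoch in range(NUM_EPOCHS):
        val_loss = train_one_epoch(params, epoch)
        trial.report(val_loss, step=epoch)  # expose intermediate values
        if trial.should_prune():            # this is what --pruning enables
            raise optuna.TrialPruned()
    return val_loss


study = optuna.create_study(direction="minimize",
                            pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
```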
- **RL Environment** (`rl_env.py`) — see the sketch after this list:
  - State: current LR, loss trends, training progress
  - Action: discrete LR choices (configurable granularity)
  - Reward: based on validation loss improvement
- **RL Scheduler Adapter** (`rl_scheduler_adapter.py`) — also sketched below:
  - Wraps RL agents (PPO/A2C) as PyTorch schedulers
  - Handles state updates and action execution
- **Lightning Module** (`model_pl.py`):
  - Integrates schedulers with PyTorch Lightning
  - Manages training loop and metrics logging
- **Data Module** (`data_module.py`):
  - Standardized data loading for MNIST and CIFAR-10
  - Configurable batch sizes and augmentations
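To make the environment component concrete, here is a minimal sketch of the idea in Gymnasium form. It is simplified relative to `rl_env.py`: the exact state features, reward shaping, and the `_run_training_step` stub are assumptions for illustration.

```python
import gymnasium as gym
import numpy as np


class MinimalLREnv(gym.Env):
    """Toy LR-scheduling env: state = (log10 LR, last val loss, progress)."""

    def __init__(self, n_actions=20, min_lr=1e-6, max_lr=1e-2, max_steps=100):
        super().__init__()
        # Discrete action = index into a log-spaced grid of learning rates
        self.lr_grid = np.logspace(np.log10(min_lr), np.log10(max_lr), n_actions)
        self.action_space = gym.spaces.Discrete(n_actions)
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(3,),
                                                dtype=np.float32)
        self.max_steps = max_steps

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        self.t, self.prev_loss, self.lr = 0, 1.0, float(self.lr_grid[-1])
        return self._get_state(), {}

    def step(self, action):
        self.lr = float(self.lr_grid[action])
        loss = self._run_training_step(self.lr)  # stand-in for train/validate
        reward = self.prev_loss - loss           # reward = val loss improvement
        self.prev_loss, self.t = loss, self.t + 1
        return self._get_state(), reward, self.t >= self.max_steps, False, {}

    def _get_state(self):
        return np.array([np.log10(self.lr), self.prev_loss,
                         self.t / self.max_steps], dtype=np.float32)

    def _run_training_step(self, lr):
        # Placeholder dynamics so the sketch runs standalone
        return self.prev_loss * 0.99
```

An SB3 agent can then be trained against such an environment with, e.g., `PPO("MlpPolicy", env).learn(total_timesteps=...)`.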
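And the adapter idea, reduced to its core: query the trained policy for a learning rate and write it into the optimizer's parameter groups. A sketch under the same assumptions as above; the real `rl_scheduler_adapter.py` additionally manages state construction and bookkeeping.

```python
import numpy as np


class MinimalRLSchedulerAdapter:
    """Wraps a trained SB3 agent so it can drive an optimizer's learning rate."""

    def __init__(self, optimizer, agent, lr_grid):
        self.optimizer = optimizer
        self.agent = agent      # e.g. a trained stable_baselines3 PPO model
        self.lr_grid = lr_grid  # same LR grid the agent was trained on

    def step(self, state):
        # deterministic=True: take the policy's best action rather than sampling
        action, _ = self.agent.predict(np.asarray(state, dtype=np.float32),
                                       deterministic=True)
        new_lr = float(self.lr_grid[int(action)])
        for group in self.optimizer.param_groups:
            group["lr"] = new_lr
        return new_lr
```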
All hyperparameters are centralized in `config.py`:

```python
# Training configuration
TrainingConfig(
    num_epochs=100,
    batch_size=64,
    initial_lr=1e-2,
    log_every_n_steps=50,
    use_gpu=True,
    gpu_ids=[0]  # Specify GPU devices
)
```

```python
# Best found parameters for each scheduler
SCHEDULER_BEST_PARAMS = {
    'ppo': {
        'min_lr': 1e-6,
        'max_lr': 1e-2,
        'n_actions': 20,
        'total_timesteps': 48985,
        'rl_model_params': {
            'learning_rate': 0.00065,
            'n_steps': 1392,
            'batch_size': 116,
            # ... more PPO hyperparameters
        }
    },
    # ... parameters for other schedulers
}
```

- Training metrics: loss, accuracy, precision, recall, F1-score
- Validation metrics: same as training, computed each epoch
- Test metrics: final evaluation on held-out test set
- RL-specific metrics: reward, current learning rate
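As a reference point for how such classification metrics are typically computed, here is a self-contained torchmetrics snippet; the project's actual metric setup lives in `model_pl.py` and may differ in names and options.

```python
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import (MulticlassAccuracy, MulticlassF1Score,
                                         MulticlassPrecision, MulticlassRecall)

num_classes = 10  # CIFAR-10
metrics = MetricCollection({
    "accuracy": MulticlassAccuracy(num_classes=num_classes),
    "precision": MulticlassPrecision(num_classes=num_classes),
    "recall": MulticlassRecall(num_classes=num_classes),
    "f1": MulticlassF1Score(num_classes=num_classes),
})

logits = torch.randn(8, num_classes)           # fake batch of model outputs
targets = torch.randint(0, num_classes, (8,))  # fake labels
print(metrics(logits, targets))                # {"accuracy": ..., "f1": ..., ...}
```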
The RL-based schedulers learn to:
- Start with higher learning rates for faster initial convergence
- Reduce learning rate when validation loss plateaus
- Make fine adjustments based on training dynamics
Traditional schedulers follow fixed patterns regardless of actual training progress.
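For contrast, this is all a fixed schedule amounts to in PyTorch (the hyperparameter values here are illustrative):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

# Step decay: multiply the LR by gamma every step_size epochs
scheduler = StepLR(opt, step_size=30, gamma=0.1)
# Or cosine annealing from the initial LR down to eta_min over T_max epochs:
# scheduler = CosineAnnealingLR(opt, T_max=100, eta_min=1e-6)

for epoch in range(100):
    # ... train one epoch, ending with opt.step() ...
    scheduler.step()  # advances on a fixed timetable, whatever the loss did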
With Neptune.ai logging enabled, you can:
- Compare learning rate schedules across experiments
- Analyze convergence patterns
- Track RL agent behavior (rewards, state evolution)
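In code, Neptune logging boils down to appending values to namespaced fields. A sketch using the current `neptune` client, assuming you have already run `neptune login`; the project name and history values are placeholders:

```python
import neptune

run = neptune.init_run(project="your-workspace/rl-lr-scheduler")  # placeholder
run["params/scheduler_type"] = "ppo"

history = [(0.9, 1e-2, 0.05), (0.7, 5e-3, 0.20)]  # dummy (loss, lr, reward) rows
for loss, lr, reward in history:
    run["train/loss"].append(loss)
    run["train/lr"].append(lr)
    run["rl/reward"].append(reward)

run.stop()
```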
Add a new model in `model_cnn.py`:

```python
def get_custom_model(num_classes):
    # Define your model architecture
    return CustomModel(num_classes)

# Register in CNN_MODEL_CONFIGS in config.py
```

Modify the state representation or reward function in `rl_env.py`:
```python
class CustomLRSchedulerEnv(LRSchedulerEnv):
    def _get_state(self):
        # Add custom state features
        pass

    def _calculate_reward(self):
        # Implement custom reward logic
        pass
```

Configure multi-GPU training in `config.py`:
```python
TrainingConfig(
    use_gpu=True,
    gpu_ids=[0, 1, 2, 3]  # Use 4 GPUs
)
```

The training script automatically uses the DDP (Distributed Data Parallel) strategy for multi-GPU setups.
- **Initial Learning Rate**: Set `initial_lr` in `TrainingConfig` based on your model/dataset
- **RL Action Space**: Adjust `n_actions` for finer/coarser LR control
- **RL Training**: Increase `total_timesteps` for better RL policy learning
- **Early Stopping**: Configure patience in `EarlyStoppingConfig` to prevent overfitting (sketched below)
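For the last tip, the patience value ultimately feeds something like Lightning's `EarlyStopping` callback; the monitored metric name below is an assumption:

```python
from pytorch_lightning.callbacks import EarlyStopping

early_stop = EarlyStopping(
    monitor="val_loss",  # assumed metric key
    patience=10,         # epochs without improvement before training stops
    mode="min",
)
# Then: pl.Trainer(callbacks=[early_stop], ...)
```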
This project is licensed under the MIT License - see the LICENSE file for details.
- Stable Baselines3 for RL implementations
- PyTorch Lightning for training framework
- Neptune.ai for experiment tracking
- Optuna for hyperparameter optimization
If you use this code in your research, please cite:
```bibtex
@software{rl_lr_scheduler,
  title = {RL-Based Learning Rate Scheduler},
  author = {Kanan Mahammadli},
  year = {2025},
  url = {https://github.com/KananMahammadli/rl-learning-rate-scheduler.git}
}
```