
RL-Based Learning Rate Scheduler

This project implements and compares different learning rate schedulers for deep learning models, with a focus on Reinforcement Learning (RL)-based schedulers that dynamically adapt learning rates during training.

Overview

Traditional learning rate schedulers follow predetermined patterns (e.g., step decay, cosine annealing), while RL-based schedulers learn to adjust the learning rate based on the current training state. This project provides:

  • RL-based schedulers: PPO and A2C agents that learn optimal learning rate policies
  • Traditional schedulers: Step LR and Cosine Annealing for baseline comparison
  • Comprehensive evaluation: on the CIFAR-10 dataset with various CNN architectures
  • Hyperparameter tuning: Optuna-based optimization for finding best scheduler parameters
  • Experiment tracking: Neptune.ai integration for monitoring and reproducibility

Quick Start

Prerequisites

  • Python 3.8+
  • CUDA-compatible GPU (optional but recommended)
  • Neptune.ai account (optional, for experiment tracking)

Installation

  1. Clone the repository:
git clone https://github.com/KananMahammadli/rl-learning-rate-scheduler.git
cd rl-learning-rate-scheduler
  2. Create a virtual environment:
python -m venv env
source env/bin/activate  # On Windows: env\Scripts\activate
  3. Install dependencies:
pip install -r requirements.txt
  4. (Optional) Set up Neptune.ai for experiment tracking:
neptune login
# Follow the prompts to enter your API token

Reproducing Results

1. Training with Best Parameters

To train a model using the best parameters found for each scheduler:

# PPO scheduler on CIFAR-10 with ResNet18
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type ppo --log_to_neptune

# Traditional schedulers for comparison
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type step --log_to_neptune
python train_pl.py --dataset cifar10 --model resnet18 --scheduler_type cosine --log_to_neptune

Available options:

  • --dataset: cifar10
  • --model: resnet18, resnet34, resnet50, vgg16, simple_cnn_cifar10
  • --scheduler_type: ppo, a2c, step, cosine
  • --log_to_neptune: Enable Neptune.ai logging
  • --save_rl_policy: Save the trained RL policy for later use

2. Hyperparameter Tuning

To find optimal parameters for a scheduler:

# Tune PPO scheduler for ResNet18 on CIFAR-10
python tune.py --dataset cifar10 --model resnet18 --scheduler_type ppo --n_trials 100 --log_to_neptune

# Tune with pruning for faster optimization
python tune.py --dataset cifar10 --model resnet18 --scheduler_type ppo --n_trials 100 --pruning --log_to_neptune

# Tune traditional scheduler
python tune.py --dataset cifar10 --model resnet18 --scheduler_type step --n_trials 50 --log_to_neptune

Tuning options:

  • --n_trials: Number of Optuna trials (default: 100)
  • --study_name: Custom name for the Optuna study
  • --log_to_neptune: Log each trial to Neptune.ai
  • --pruning: Enable early stopping of unpromising trials

Project Architecture

Core Components

  1. RL Environment (rl_env.py), sketched after this list:

    • State: current LR, loss trends, training progress
    • Action: discrete LR choices (configurable granularity)
    • Reward: based on validation loss improvement
  2. RL Scheduler Adapter (rl_scheduler_adapter.py):

    • Wraps RL agents (PPO/A2C) as PyTorch schedulers
    • Handles state updates and action execution
  3. Lightning Module (model_pl.py):

    • Integrates schedulers with PyTorch Lightning
    • Manages training loop and metrics logging
  4. Data Module (data_module.py):

    • Standardized data loading for MNIST and CIFAR-10
    • Configurable batch sizes and augmentations
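
The sketch below shows what the state/action/reward design of item 1 can look like as a Gymnasium-style environment (the interface libraries such as Stable-Baselines3 expect). It is a minimal, self-contained toy, not the repository's actual rl_env.py: the class name, the synthetic loss signal, and the exact observation layout are illustrative assumptions.

import numpy as np
import gymnasium as gym

class LRSchedulerEnvSketch(gym.Env):
    """Toy LR-scheduling env: pick an LR from a log-spaced grid,
    get rewarded for (simulated) validation-loss improvement."""

    def __init__(self, min_lr=1e-6, max_lr=1e-2, n_actions=20, max_steps=100):
        super().__init__()
        self.action_space = gym.spaces.Discrete(n_actions)  # discrete LR choices
        self.observation_space = gym.spaces.Box(            # [log10 LR, last loss, progress]
            low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)
        self.lr_grid = np.geomspace(min_lr, max_lr, n_actions)
        self.max_steps = max_steps

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.t, self.lr, self.prev_loss = 0, float(self.lr_grid[-1]), 1.0
        return self._obs(), {}

    def step(self, action):
        self.lr = float(self.lr_grid[action])
        # A real env would run a validation pass here; we fake a loss
        # that shrinks only when the chosen LR is in a sensible band.
        loss = self.prev_loss * (0.95 if 1e-4 <= self.lr <= 1e-2 else 1.01)
        reward = self.prev_loss - loss  # reward = validation-loss improvement
        self.prev_loss = loss
        self.t += 1
        return self._obs(), reward, self.t >= self.max_steps, False, {}

    def _obs(self):
        return np.array(
            [np.log10(self.lr), self.prev_loss, self.t / self.max_steps],
            dtype=np.float32)

With this interface in place an agent trains directly against it, e.g. PPO("MlpPolicy", LRSchedulerEnvSketch()).learn(total_timesteps=48985) in Stable-Baselines3.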

Configuration System

All hyperparameters are centralized in config.py:

# Training configuration
TrainingConfig(
    num_epochs=100,
    batch_size=64,
    initial_lr=1e-2,
    log_every_n_steps=50,
    use_gpu=True,
    gpu_ids=[0]  # Specify GPU devices
)

# Best found parameters for each scheduler
SCHEDULER_BEST_PARAMS = {
    'ppo': {
        'min_lr': 1e-6,
        'max_lr': 1e-2,
        'n_actions': 20,
        'total_timesteps': 48985,
        'rl_model_params': {
            'learning_rate': 0.00065,
            'n_steps': 1392,
            'batch_size': 116,
            # ... more PPO hyperparameters
        }
    },
    # ... parameters for other schedulers
}

Understanding the Results

Metrics Tracked

  • Training metrics: loss, accuracy, precision, recall, F1-score
  • Validation metrics: same as training, computed each epoch
  • Test metrics: final evaluation on held-out test set
  • RL-specific metrics: reward, current learning rate
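
Whether the Lightning module computes these with torchmetrics is an assumption, but that is the idiomatic route; a minimal sketch for 10-class CIFAR-10:

import torch
from torchmetrics import Accuracy, F1Score

acc = Accuracy(task="multiclass", num_classes=10)  # updated per batch
f1 = F1Score(task="multiclass", num_classes=10)    # computed per epoch

preds = torch.randn(8, 10).softmax(dim=-1)  # fake batch of class probabilities
labels = torch.randint(0, 10, (8,))
acc.update(preds, labels); f1.update(preds, labels)
print(acc.compute(), f1.compute())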

Comparing Schedulers

The RL-based schedulers learn to:

  1. Start with higher learning rates for faster initial convergence
  2. Reduce learning rate when validation loss plateaus
  3. Make fine adjustments based on training dynamics

Traditional schedulers follow fixed patterns regardless of actual training progress.
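
In plain PyTorch terms, both baselines are fixed functions of the epoch index alone; the step_size, gamma, and T_max values below are illustrative, not the repository's tuned settings:

import torch
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)

sched = StepLR(opt, step_size=30, gamma=0.1)  # LR *= 0.1 every 30 epochs, no matter what
# or: CosineAnnealingLR(opt, T_max=100)       # smooth cosine decay over 100 epochs

for epoch in range(100):
    # ... train one epoch ...
    sched.step()  # advances the schedule purely by epoch count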

Visualization

With Neptune.ai logging enabled, you can:

  • Compare learning rate schedules across experiments
  • Analyze convergence patterns
  • Track RL agent behavior (rewards, state evolution)
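
Comparing schedules across experiments reduces to logging the LR as a per-run series; a sketch with the current neptune client (the project and field names here are placeholders):

import neptune

run = neptune.init_run(project="your-workspace/rl-lr-scheduler")  # placeholder project
for lr in (1e-2, 1e-2, 5e-3, 1e-3):  # toy LR trace, one value per epoch
    run["train/lr"].append(lr)
run["agent/reward"].append(0.12)     # RL-specific metrics go in the same run
run.stop()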

Advanced Usage

Custom Models

Add a new model in model_cnn.py:

import torch.nn as nn

def get_custom_model(num_classes):
    # Define your architecture; e.g. a tiny CNN for 32x32 RGB inputs
    return nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(32, num_classes),
    )

# Register in CNN_MODEL_CONFIGS in config.py

Custom RL Environments

Modify the state representation or reward function in rl_env.py:

class CustomLRSchedulerEnv(LRSchedulerEnv):
    def _get_state(self):
        # Add custom state features,
        # e.g. gradient norms or a window of recent LRs and losses
        pass

    def _calculate_reward(self):
        # Implement custom reward logic,
        # e.g. penalize oscillating LRs on top of loss improvement
        pass

Multi-GPU Training

Configure multi-GPU training in config.py:

TrainingConfig(
    use_gpu=True,
    gpu_ids=[0, 1, 2, 3]  # Use 4 GPUs
)

The training script automatically uses DDP (Distributed Data Parallel) strategy for multi-GPU setups.
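
In Lightning terms, a multi-entry gpu_ids maps to roughly the following (a sketch; the exact import path depends on the installed Lightning version):

from pytorch_lightning import Trainer

# One process per listed device, gradients synchronized via DDP
trainer = Trainer(accelerator="gpu", devices=[0, 1, 2, 3], strategy="ddp")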

Performance Tips

  1. Initial Learning Rate: Set initial_lr in TrainingConfig based on your model/dataset
  2. RL Action Space: Adjust n_actions for finer/coarser LR control
  3. RL Training: Increase total_timesteps for better RL policy learning
  4. Early Stopping: Configure patience in EarlyStoppingConfig to prevent overfitting
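
All four knobs live in config.py. A hypothetical tweak (the SCHEDULER_BEST_PARAMS keys appear earlier in this README; EarlyStoppingConfig's exact fields are an assumption):

SCHEDULER_BEST_PARAMS['ppo']['n_actions'] = 40             # finer LR grid (tip 2)
SCHEDULER_BEST_PARAMS['ppo']['total_timesteps'] = 100_000  # longer RL training (tip 3)
# Tips 1 and 4 are plain config fields, e.g.:
# TrainingConfig(initial_lr=5e-3, ...) and EarlyStoppingConfig(patience=10)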

License

This project is licensed under the MIT License - see the LICENSE file for details.

Acknowledgments

Citation

If you use this code in your research, please cite:

@software{rl_lr_scheduler,
  title = {RL-Based Learning Rate Scheduler},
  author = {Kanan Mahammadli},
  year = {2025},
  url = {https://github.com/KananMahammadli/rl-learning-rate-scheduler.git}
}
