First, install robopianist:

```bash
bash <(curl -s https://raw.githubusercontent.com/google-research/robopianist/main/scripts/install_deps.sh) --no-soundfonts
conda create -n pianist python=3.10
conda activate pianist
pip install --upgrade robopianist
```

From the same conda environment:
- Install JAX
- Run `pip install -r requirements.txt`
To train an SAC policy to play the first 10 seconds of Crossing Field with the task parameters used in the robopianist paper:

```bash
WANDB_DIR=/tmp/robopianist/ MUJOCO_GL=glfw XLA_PYTHON_CLIENT_PREALLOCATE=false python scripts/train.py \
    --root-dir models/CrossingField10s \
    --warmstart-steps 5000 \
    --max-steps 5000000 \
    --discount 0.8 \
    --agent-config.critic-dropout-rate 0.01 \
    --agent-config.critic-layer-norm \
    --agent-config.hidden-dims 256 256 256 \
    --trim-silence \
    --gravity-compensation \
    --reduced-action-space \
    --control-timestep 0.05 \
    --n-steps-lookahead 10 \
    --midi-file "example_midis/Crossing Field 10s.mid" \
    --action-reward-observation \
    --primitive-fingertip-collisions \
    --eval-episodes 1 \
    --tqdm-bar \
    --eval-interval 30000 \
    --disable_wandb
```

To evaluate a trained policy, run:
```bash
python scripts/eval.py \
    --load_checkpoint <YOUR_MODEL_PATH> \
    --midi-file "/Users/almondgod/Repositories/robopianist/midi_files_cut/Cruel Angel's Thesis Cut middle 15s.mid"
```

To generate preference data, run:
```bash
python rlhf/generate_preference_data.py \
    --checkpoints \
        <YOUR_PATH_TO_CHECKPOINT1.pkl> \
        <YOUR_PATH_TO_CHECKPOINT2.pkl> \
        <YOUR_PATH_TO_CHECKPOINT3.pkl> \
        <YOUR_PATH_TO_CHECKPOINT4.pkl> \
        <YOUR_PATH_TO_CHECKPOINT5.pkl> \
    --midi-file "example_midis/Crossing Field 10s.mid"
```

To train a policy with CPL, run:
```bash
python scripts/cpl_train.py \
    --sac_checkpoint <YOUR_SAC_CHECKPOINT.pkl> \
    --preference_data <YOUR_PREFERENCE_DATA.pkl> \
    --output_dir "cpl_trained_models" \
    --midi-file "example_midis/Crossing Field 10s.mid"
```

See scripts/readme.md for information on other utility scripts: clip_midi.py, play_song.py, and interactive_piano.py.
We use human feedback on recorded episodes to train a policy that acts in line with human preferences (in our case, plays the piano better).
Contrastive Preference Learning (CPL) is a method for learning from human preferences without requiring explicit reward engineering. Generically, we:
- Collect pairs of model outputs (segments)
- Get human preferences between these pairs
- Train a policy to maximize the likelihood of preferred behaviors while staying close to the original policy
This approach is well suited to piano playing, where humans have subtle preferences about technique and feel that are difficult to capture in a reward function. At the same time, we want to preserve existing 'good behaviors' while making human-guided improvements.
I worked from an example implementation in the original CPL repository: https://github.com/jhejna/cpl/blob/main/research/algs/cpl.py
For the CPL implementation in cpl_train_sac.py, we have:
Base Policy (scripts/train.py, architecture/sac.py)
- A Soft Actor-Critic policy to act as "pretraining" on the piano playing task
- When fine-tuning with CPL, we discard the critic and use only the policy. In the future, we could try using the critic as an auxiliary loss, augmenting the BC term that keeps the policy close to the original.
Preference Collection (rlhf/generate_preference_dataset.py)
- Generate segments using different checkpoints and noise levels to increase the diversity of the dataset
- Record videos of the policy playing for human evaluation, and rate the quality of each performance from 1-100
- Create pairwise preferences based on the ratings, yielding an $n^2$-size preference dataset
- The data is organized by timestamp and includes full trajectory (state + action) information
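The rating-to-pairs step can be sketched as follows; the function and variable names here are illustrative, not the actual data layout used by the script:

```python
from itertools import permutations

def make_preference_pairs(segments, ratings):
    """Turn per-segment 1-100 human ratings into ordered (preferred, rejected) pairs.

    segments: list of trajectory segments (state + action data)
    ratings: list of human scores, one per segment
    """
    pairs = []
    for i, j in permutations(range(len(segments)), 2):
        if ratings[i] > ratings[j]:  # segment i strictly preferred over segment j
            pairs.append({"preferred": segments[i], "rejected": segments[j]})
    return pairs

# n segments yield up to n*(n-1) ordered pairs (ties are dropped),
# which is where the ~n^2 dataset size comes from.
```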
CPL Training: (rlhf/cpl_train.py, architecture/cpl_sac.py)
- Uses CPL-SAC architecture to wrap SAC with CPL loss update
- The CPL loss computes the log-probabilities of preferred and non-preferred actions, converts them into pseudo-advantages ($\alpha \cdot$ log-prob), and computes the loss as:

$$\mathcal{L}_{\text{CPL}} = -\log \frac{\exp\left(\alpha \sum_t \log \pi(a_t^+ \mid s_t^+)\right)}{\exp\left(\alpha \sum_t \log \pi(a_t^+ \mid s_t^+)\right) + \exp\left(\alpha \sum_t \log \pi(a_t^- \mid s_t^-)\right)}$$

Essentially, we maximize the expected share of the preferred segment's advantage relative to the combined advantages of both segments.
- We then compute the total loss as a weighted combination of the CPL loss and a conservative loss (MSE to the original policy's outputs)
- Finally, we clip gradient norms to prevent exploding gradients
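The two loss terms above can be sketched in numpy, under the assumption that per-segment log-probabilities have already been summed (the actual implementation in rlhf/cpl_train.py operates on JAX policy outputs):

```python
import numpy as np

def cpl_loss(logp_pref, logp_rej, alpha=0.1):
    """Contrastive preference loss over a batch of segment pairs.

    logp_pref / logp_rej: summed log pi(a|s) over each segment, shape (B,)
    alpha: temperature turning log-probs into pseudo-advantages
    """
    adv_pref = alpha * logp_pref   # pseudo-advantage of the preferred segment
    adv_rej = alpha * logp_rej     # pseudo-advantage of the rejected segment
    # -log softmax over the pair, numerically stable: log(1 + exp(adv_rej - adv_pref))
    return np.mean(np.logaddexp(0.0, adv_rej - adv_pref))

def conservative_loss(actions_new, actions_old):
    """MSE between fine-tuned and original policy outputs on the same states."""
    return np.mean((actions_new - actions_old) ** 2)
```

The loss shrinks as the policy assigns relatively more probability mass to preferred segments, and the conservative term anchors the fine-tuned policy to its pretrained behavior.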
- No need to learn a separate reward function; we only need the pretrained policy
- CPL directly incorporates human preferences about playing style
- The conservative loss prevents forgetting the good behaviors learned in pretraining
- Could learn from a relatively small number of human preferences
Hyperparameters that worked for me:
- Learning rate: 1e-4
- Batch size: 32
- Temperature (alpha): 0.1
- Conservative weight: 0.01
- Preference weight (lambda): 0.5
Monitoring:
- Track preference loss
- Monitor conservative regularization and ensure no NaN gradients
- Track the videos in the `eval` folder to ensure consistent performance
Train a policy in high-dimensional continuous action spaces that does not get stuck at local minima and is not as sensitive to hyperparameters as past RL algorithms.
Soft Actor-Critic (SAC) is off-policy, meaning it learns from data not necessarily generated by the current policy. In practice, we often use SAC online (interacting with an environment) and store data in a replay buffer. SAC combines techniques from:
- DDPG: off-policy actor-critic architecture (but with a deterministic actor)
- Soft Q-Learning: maximizes entropy to encourage exploration
We use a typical actor-critic architecture, but add to our critic:
- Taking the minimum of two Q-functions to reduce Q-overestimation bias
- Target networks for more stable TD learning
- An entropy term in the values, which discourages deterministic actions and thus encourages exploration of the state space
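These three critic modifications combine in the TD target. A numpy sketch of that target (clipped double-Q plus entropy bonus); the real JAX implementation lives in architecture/sac.py:

```python
import numpy as np

def soft_td_target(reward, done, q1_targ, q2_targ, logp_next, alpha=0.2, gamma=0.99):
    """Soft Bellman backup used to train both critics.

    q1_targ, q2_targ: target-network Q-values for the next state/action
    logp_next: log pi(a'|s') of the next action sampled from the current policy
    alpha: entropy temperature; gamma: discount factor
    """
    min_q = np.minimum(q1_targ, q2_targ)   # clipped double-Q: take the minimum
    soft_v = min_q - alpha * logp_next     # entropy bonus raises the value of stochastic policies
    return reward + gamma * (1.0 - done) * soft_v
```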
The robopianist-rl implementation (in architecture/sac.py) uses:
- Actor Network: MLP with 3x256 hidden layers; predicts a normal distribution for each action dimension, using a `TanhMultivariateNormalDiag` distribution for bounded actions
- Critic Network: two Q-networks with 3x256 hidden layers, with layer normalization and dropout (0.01) for regularization (DroQ)
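The bounded-action sampling amounts to a tanh-squashed Gaussian with the standard change-of-variables log-prob correction. A numpy illustration (not the actual distribution class used in architecture/sac.py):

```python
import numpy as np

def tanh_gaussian_sample(mean, log_std, rng):
    """Sample a bounded action in (-1, 1) from a tanh-squashed diagonal Gaussian
    and return its log-probability, including the tanh change-of-variables term."""
    std = np.exp(log_std)
    pre_tanh = mean + std * rng.standard_normal(mean.shape)
    action = np.tanh(pre_tanh)
    # log-prob under the diagonal Gaussian, summed over action dimensions
    logp = -0.5 * (((pre_tanh - mean) / std) ** 2 + 2.0 * log_std + np.log(2.0 * np.pi)).sum()
    # correction: subtract log |det d tanh(u)/du| = sum log(1 - tanh(u)^2)
    logp -= np.log(1.0 - action ** 2 + 1e-6).sum()
    return action, logp
```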
- Off-policy learning allows reuse of past experience (sample efficient)
- Double Q-learning reduces Bellman-induced overestimation bias; DroQ improves generalization
- Maximum-entropy objective, so it gets stuck at local minima less often
Hyperparameters that worked for me with no fingering annotations:
- Learning rates: ~3e-4 for all networks
- Batch size: 256
- Target network update rate (tau): 0.01
- Initial temperature: 1.2
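The target network update rate enters a standard Polyak averaging step. A minimal sketch, with parameters stored as plain dicts purely for illustration:

```python
def polyak_update(target, online, tau=0.01):
    """Soft target-network update: target <- tau * online + (1 - tau) * target.

    With tau = 0.01, the target networks trail the online critics slowly,
    which keeps the TD targets stable.
    """
    return {name: tau * online[name] + (1.0 - tau) * target[name]
            for name in target}
```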
Monitoring:
- Track Q-values to detect overestimation
- Ensure entropy stays consistently and relatively high, and that Q-values increase gradually once training stabilizes
The current JAX setup is configured to run on Apple Silicon. Please adjust the JAX METAL lines according to your system.
I encourage you not to use custom fingering, which can introduce errors and is not necessary when SAC is combined with RLHF finetuning.
```bibtex
@article{zakka2023robopianist,
  author  = {Zakka, Kevin and Smith, Laura and Gileadi, Nimrod and Howell, Taylor and Peng, Xue Bin and Singh, Sumeet and Tassa, Yuval and Florence, Pete and Zeng, Andy and Abbeel, Pieter},
  title   = {{RoboPianist: A Benchmark for High-Dimensional Robot Control}},
  journal = {arXiv preprint arXiv:2304.04150},
  year    = {2023},
}

@InProceedings{hejna23contrastive,
  title     = {Contrastive Preference Learning: Learning From Human Feedback without RL},
  author    = {Hejna, Joey and Rafailov, Rafael and Sikchi, Harshit and Finn, Chelsea and Niekum, Scott and Knox, W. Bradley and Sadigh, Dorsa},
  booktitle = {ArXiv preprint},
  year      = {2023},
  url       = {https://arxiv.org/abs/2310.13639},
}
```