diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py
index c37db48c0a..5d55a85c49 100644
--- a/src/maxtext/trainers/post_train/rl/utils_rl.py
+++ b/src/maxtext/trainers/post_train/rl/utils_rl.py
@@ -460,32 +460,40 @@ def extract_hash_answer(text: str) -> str | None:
 
 
 def get_optimizer(tmvp_config, max_train_steps):
   """Function to obtain an optax optimizer, currently we use adamw."""
-  optimizer = optax.adamw(
-      learning_rate=optax.schedules.warmup_cosine_decay_schedule(
-          init_value=0.0,
-          peak_value=tmvp_config.learning_rate,
-          # Linearly increase learning rate from 0. to learning_rate in the first
-          # warmup_steps_fraction training steps, and then gradually decrease the
-          # learning rate to 0 using cosine scheduler.
-          warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
-          decay_steps=max_train_steps,
-          end_value=0.0,
-      ),
-      b1=tmvp_config.adam_b1,
-      b2=tmvp_config.adam_b2,
-      weight_decay=tmvp_config.adam_weight_decay,
+  schedule = optax.schedules.warmup_cosine_decay_schedule(
+      init_value=0.0,
+      peak_value=tmvp_config.learning_rate,
+      # Linearly increase learning rate from 0. to learning_rate in the first
+      # warmup_steps_fraction training steps, and then gradually decrease the
+      # learning rate to 0 using cosine scheduler.
+      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
+      decay_steps=max_train_steps,
+      end_value=0.0,
   )
   # TODO: @mazumdera: try optimizer offloading with adamw
   # Add gradient clipping if specified
   # Grad clipping to prevent large gradients. We find this
   # important to keep KL divergence in check.
-  if tmvp_config.gradient_clipping_threshold > 0:
-    optimizer = optax.chain(
-        optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold),
-        optimizer,
+  def make_optimizer(learning_rate):
+    transforms = []
+    if tmvp_config.gradient_clipping_threshold > 0:
+      transforms.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
+    transforms.append(
+        optax.adamw(
+            learning_rate=learning_rate,
+            b1=tmvp_config.adam_b1,
+            b2=tmvp_config.adam_b2,
+            weight_decay=tmvp_config.adam_weight_decay,
+        )
     )
-  return optimizer
+    return optax.chain(*transforms)
+
+  # Wrap the entire optimizer (including gradient clipping) with
+  # inject_hyperparams so opt_state.hyperparams['learning_rate'] is at the
+  # top level of the state tree. This is required for tunix's peft_trainer to
+  # automatically read and log the per-step learning rate.
+  return optax.inject_hyperparams(make_optimizer)(learning_rate=schedule)
 
 
 def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py
index c8aaa7dd83..51b6f074e0 100644
--- a/tests/post_training/unit/rl_utils_test.py
+++ b/tests/post_training/unit/rl_utils_test.py
@@ -332,5 +332,42 @@ def test_without_hash(self):
     self.assertIsNone(utils_rl.extract_hash_answer(""))
 
 
+class TestGetOptimizer(unittest.TestCase):
+  """Tests for utils_rl.get_optimizer."""
+
+  def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
+    return SimpleNamespace(
+        learning_rate=1e-4,
+        warmup_steps_fraction=0.1,
+        gradient_clipping_threshold=gradient_clipping_threshold,
+        adam_b1=0.9,
+        adam_b2=0.999,
+        adam_weight_decay=0.01,
+    )
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_without_clipping(self):
+    """get_optimizer returns an optax optimizer when gradient clipping is disabled."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=0.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    # Should be usable: init on a simple param tree
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+  @pytest.mark.cpu_only
+  def test_returns_optimizer_with_clipping(self):
+    """get_optimizer includes gradient clipping when threshold > 0."""
+    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel
+
+    config = self._make_optimizer_config(gradient_clipping_threshold=1.0)
+    opt = utils_rl.get_optimizer(config, max_train_steps=100)
+    params = {"w": jnp.ones(3)}
+    state = opt.init(params)
+    self.assertIn("learning_rate", state.hyperparams)
+
+
 if __name__ == "__main__":
   unittest.main()