Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,32 +460,40 @@ def extract_hash_answer(text: str) -> str | None:

def get_optimizer(tmvp_config, max_train_steps):
  """Build the optax optimizer (AdamW with a warmup-cosine LR schedule).

  Args:
    tmvp_config: Config object providing ``learning_rate``,
      ``warmup_steps_fraction``, ``gradient_clipping_threshold``,
      ``adam_b1``, ``adam_b2`` and ``adam_weight_decay``.
    max_train_steps: Total number of training steps, used as the decay
      horizon of the schedule.

  Returns:
    An ``optax.GradientTransformation`` whose state exposes
    ``opt_state.hyperparams['learning_rate']`` at the top level.
  """
  # Linearly increase the learning rate from 0 to learning_rate over the
  # first warmup_steps_fraction of training, then cosine-decay it back to 0.
  lr_schedule = optax.schedules.warmup_cosine_decay_schedule(
      init_value=0.0,
      peak_value=tmvp_config.learning_rate,
      warmup_steps=int(tmvp_config.warmup_steps_fraction * max_train_steps),
      decay_steps=max_train_steps,
      end_value=0.0,
  )

  # TODO: @mazumdera: try optimizer offloading with adamw
  def build(learning_rate):
    """Assemble (optional global-norm clipping) -> AdamW as a single chain."""
    chain_parts = []
    # Gradient clipping prevents large gradients; we find this important to
    # keep KL divergence in check.
    if tmvp_config.gradient_clipping_threshold > 0:
      chain_parts.append(optax.clip_by_global_norm(max_norm=tmvp_config.gradient_clipping_threshold))
    chain_parts.append(
        optax.adamw(
            learning_rate=learning_rate,
            b1=tmvp_config.adam_b1,
            b2=tmvp_config.adam_b2,
            weight_decay=tmvp_config.adam_weight_decay,
        )
    )
    return optax.chain(*chain_parts)

  # Wrapping the whole optimizer (clipping included) in inject_hyperparams
  # puts opt_state.hyperparams['learning_rate'] at the top level of the state
  # tree, which tunix's peft_trainer needs in order to read and log the
  # per-step learning rate automatically. The kwarg name 'learning_rate'
  # below is the hyperparam key and must not change.
  return optax.inject_hyperparams(build)(learning_rate=lr_schedule)


def process_data(dataset_name, model_tokenizer, template_config, tmvp_config, x):
Expand Down
37 changes: 37 additions & 0 deletions tests/post_training/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,5 +332,42 @@ def test_without_hash(self):
self.assertIsNone(utils_rl.extract_hash_answer(""))


class TestGetOptimizer(unittest.TestCase):
  """Tests for utils_rl.get_optimizer."""

  def _make_optimizer_config(self, gradient_clipping_threshold=0.0):
    """Return a minimal config namespace with the fields get_optimizer reads."""
    return SimpleNamespace(
        learning_rate=1e-4,
        warmup_steps_fraction=0.1,
        gradient_clipping_threshold=gradient_clipping_threshold,
        adam_b1=0.9,
        adam_b2=0.999,
        adam_weight_decay=0.01,
    )

  def _check_optimizer(self, gradient_clipping_threshold):
    """Shared assertions: build the optimizer, then init and update it.

    Builds get_optimizer's result for the given clipping threshold, runs
    ``init`` and one ``update`` on a tiny param tree, and checks that the
    learning rate is exposed at the top level of the state (required by
    tunix's peft_trainer for per-step LR logging).
    """
    import jax.numpy as jnp  # pylint: disable=import-outside-toplevel

    config = self._make_optimizer_config(gradient_clipping_threshold=gradient_clipping_threshold)
    opt = utils_rl.get_optimizer(config, max_train_steps=100)
    params = {"w": jnp.ones(3)}
    state = opt.init(params)
    # inject_hyperparams must surface the learning rate at the top level.
    self.assertIn("learning_rate", state.hyperparams)
    # The optimizer must be usable end-to-end, not just for init: this
    # exercises the clipping transform (when enabled) and AdamW itself.
    grads = {"w": jnp.full(3, 100.0)}
    updates, _ = opt.update(grads, state, params)
    self.assertEqual(set(updates), {"w"})

  @pytest.mark.cpu_only
  def test_returns_optimizer_without_clipping(self):
    """get_optimizer returns a usable optimizer when gradient clipping is disabled."""
    self._check_optimizer(gradient_clipping_threshold=0.0)

  @pytest.mark.cpu_only
  def test_returns_optimizer_with_clipping(self):
    """get_optimizer includes gradient clipping when threshold > 0 and stays usable."""
    self._check_optimizer(gradient_clipping_threshold=1.0)

# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
  unittest.main()
Loading