From ee59b2fd5ee1148ff978cf673c4d460ad2b6b020 Mon Sep 17 00:00:00 2001 From: Andy Twigg Date: Tue, 24 Mar 2026 21:05:04 +0000 Subject: [PATCH] simplify reward structure; add exact answer reward --- src/maxtext/configs/post_train/rl.yml | 15 ++++++++------- src/maxtext/configs/types.py | 1 + src/maxtext/trainers/post_train/rl/train_rl.py | 18 ++++++++++++++---- src/maxtext/trainers/post_train/rl/utils_rl.py | 10 +++++----- tests/post_training/unit/rl_utils_test.py | 13 +++++++------ 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index 9cc305d090..3787146d90 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -164,13 +164,14 @@ max_num_checkpoints_to_keep: 10 # ====== Reward ====== -reward_exact_format_match: 3.0 -reward_white_space_format_match: 1.5 -reward_partial_format_match: 0.5 -reward_ratio_guess_to_answer_high: 0.5 -reward_ratio_guess_to_answer_low: 0.25 -penalty_incorrect_format: -0.5 -penalty_incorrect_answer: -1.0 +reward_exact_answer: 1.0 +reward_white_space_format_match: 1.0 +reward_exact_format_match: 0.1 +reward_partial_format_match: 0.0 +reward_ratio_guess_to_answer_high: 0.0 +reward_ratio_guess_to_answer_low: 0.0 +penalty_incorrect_format: 0.0 +penalty_incorrect_answer: 0.0 # ====== Special tokens/templates for GSM8K reasoning ====== reasoning_start_token: '' diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 9ee1a7bb59..cdb61a25fc 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1686,6 +1686,7 @@ class RLEvaluation(BaseModel): class Reward(BaseModel): """Configuration for the reward/penalty model in RL.""" + reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.") reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.") reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.") reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.") diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 7f5a33eed6..d77270f62e 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -44,6 +44,7 @@ """ from __future__ import annotations +from functools import wraps from typing import Sequence import collections @@ -564,15 +565,24 @@ def create_rl_components( **rl_cluster_kwargs, ) + def make_reward_fn(fn): + # pragma: no cover + @wraps(fn) + def _reward_fn(**kwargs): + return fn(tmvp_config=trainer_config, **kwargs) + + return _reward_fn + # Create RL trainer max_logging.log("Setting up RL trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, reward_fns=[ # type: ignore - lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs), + make_reward_fn(utils_rl.match_format_exactly), + make_reward_fn(utils_rl.match_format_approximately), + # TODO(atwigg): comment out to simplify reward and overlap with check_numbers + make_reward_fn(utils_rl.check_answer), + make_reward_fn(utils_rl.check_numbers), ], algo_config=grpo_config, ) diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py index 782983cefa..bca0ae7bbd 100644 --- a/src/maxtext/trainers/post_train/rl/utils_rl.py +++ b/src/maxtext/trainers/post_train/rl/utils_rl.py @@ -266,15 +266,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs): except (TimeoutException, Exception): pass - # Correct answer gets tmvp_config.reward_exact_format_match points! + # Correct answer gets tmvp_config.reward_exact_answer points! if guess == true_answer: - score += tmvp_config.reward_exact_format_match + score += tmvp_config.reward_exact_answer # Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k) elif guess.strip() == true_answer.strip(): score += tmvp_config.reward_white_space_format_match # Answers match upon robust comparison with math_verify elif verified_correct: - score += tmvp_config.reward_exact_format_match + score += tmvp_config.reward_exact_answer else: # We also reward it if the answer is close via ratios! # Ie if the answer is within some range, reward it! @@ -456,13 +456,13 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): # Use math_verify to compare answers (handles both numeric and expression comparison) score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)]) # Return scaled score: 1.5 for exact/correct, 0 otherwise - scores.append(1.5 if score > 0.1 else 0.0) + scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0) except (TimeoutException, Exception): # Fallback to simple numeric comparison if math_verify fails try: guess_val = float(normalize_final_answer(guess).strip()) true_val = float(normalize_final_answer(true_answer).strip()) - scores.append(1.5 if guess_val == true_val else 0.0) + scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0) except: scores.append(0) diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py index 60bb01a51d..0d4c6285c8 100644 --- a/tests/post_training/unit/rl_utils_test.py +++ b/tests/post_training/unit/rl_utils_test.py @@ -38,6 +38,7 @@ def _make_config(): reasoning_end_token="", solution_start_token="", solution_end_token="", + reward_exact_answer=3.0, reward_exact_format_match=2.0, reward_partial_format_match=0.5, reward_white_space_format_match=1.5, @@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self): completions=["40 + 2 = 4242"], answer=["42"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extraction_fails_no_tags(self): @@ -241,7 +242,7 @@ def test_extraction_fails_answer_tags_only(self): completions=["42"], answer=["42"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extraction_fails_reasoning_tags_only(self): @@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self): ], answer=["7", "7"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) self.assertEqual(scores[1], 0) # --------------------------------------------------------------- @@ -276,7 +277,7 @@ def test_extracted_matches_integer_answer(self): completions=["simple100"], answer=["100"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_does_not_match_answer(self): @@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self): completions=["cost calculation1,000"], answer=["1000"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_matches_with_currency_prefix(self): @@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self): completions=["price is $16$16"], answer=["16"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_non_numeric_no_match(self):