diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml index da455a13e2..bf2fc75563 100644 --- a/src/maxtext/configs/post_train/rl.yml +++ b/src/maxtext/configs/post_train/rl.yml @@ -164,6 +164,7 @@ max_num_checkpoints_to_keep: 10 # ====== Reward ====== +reward_exact_answer: 5.0 reward_exact_format_match: 3.0 reward_white_space_format_match: 1.5 reward_partial_format_match: 0.5 diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 388247b5a0..b91344f126 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -1664,6 +1664,7 @@ class RLEvaluation(BaseModel): class Reward(BaseModel): """Configuration for the reward/penalty model in RL.""" + reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.") reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.") reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.") reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.") diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 7f5a33eed6..ffdf5d244f 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -564,16 +564,18 @@ def create_rl_components( **rl_cluster_kwargs, ) + reward_fns = [ + utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_exactly), + utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_approximately), + # utils_rl.make_reward_fn(trainer_config, utils_rl.check_answer), # atwigg: commenting out since it overlaps with check_numbers + utils_rl.make_reward_fn(trainer_config, utils_rl.check_numbers), + ] + # Create RL trainer max_logging.log("Setting up RL trainer...") rl_trainer = GrpoLearner( rl_cluster=rl_cluster, - reward_fns=[ # type: ignore - lambda **kwargs: 
utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs), - lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs), - ], + reward_fns=reward_fns, # type: ignore algo_config=grpo_config, ) diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py index c37db48c0a..6d748164b7 100644 --- a/src/maxtext/trainers/post_train/rl/utils_rl.py +++ b/src/maxtext/trainers/post_train/rl/utils_rl.py @@ -15,6 +15,7 @@ # pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught """RL Utils Module.""" import re +from functools import wraps import optax from maxtext.utils import max_logging @@ -34,6 +35,16 @@ ) +def make_reward_fn(trainer_config, reward_fn): + """Wrap reward_fn with trainer config while preserving function metadata.""" + + @wraps(reward_fn) + def _reward_fn(**kwargs): + return reward_fn(tmvp_config=trainer_config, **kwargs) + + return _reward_fn + + def boxed(x): """Wraps the input string in a LaTeX boxed command if it's not already wrapped.""" return "\\boxed{" + x + "}" if not x.startswith("\\boxed{") else x @@ -257,15 +268,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs): except (TimeoutException, Exception): pass - # Correct answer gets tmvp_config.reward_exact_format_match points! + # Correct answer gets tmvp_config.reward_exact_answer points! 
if guess == true_answer: - score += tmvp_config.reward_exact_format_match + score += tmvp_config.reward_exact_answer # Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k) elif guess.strip() == true_answer.strip(): score += tmvp_config.reward_white_space_format_match # Answers match upon robust comparison with math_verify elif verified_correct: - score += tmvp_config.reward_exact_format_match + score += tmvp_config.reward_exact_answer else: # We also reward it if the answer is close via ratios! # Ie if the answer is within some range, reward it! @@ -437,14 +448,14 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs): # Use math_verify to compare answers (handles both numeric and expression comparison) score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)]) - # Return scaled score: 1.5 for exact/correct, 0 otherwise - scores.append(1.5 if score > 0.1 else 0.0) + # Return tmvp_config.reward_exact_answer for exact/correct, 0 otherwise + scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0) except (TimeoutException, Exception): # Fallback to simple numeric comparison if math_verify fails try: guess_val = float(normalize_final_answer(guess).strip()) true_val = float(normalize_final_answer(true_answer).strip()) - scores.append(1.5 if guess_val == true_val else 0.0) + scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0) except: scores.append(0) diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py index c8aaa7dd83..5e7f08244f 100644 --- a/tests/post_training/unit/rl_utils_test.py +++ b/tests/post_training/unit/rl_utils_test.py @@ -38,6 +38,7 @@ def _make_config(): reasoning_end_token="", solution_start_token="", solution_end_token="", + reward_exact_answer=3.0, reward_exact_format_match=2.0, reward_partial_format_match=0.5, reward_white_space_format_match=1.5, @@ -223,7 +224,7 @@ def 
test_extraction_succeeds_full_format(self): completions=["40 + 2 = 4242"], answer=["42"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extraction_fails_no_tags(self): @@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self): ], answer=["7", "7"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) self.assertEqual(scores[1], 0) # --------------------------------------------------------------- @@ -271,12 +272,12 @@ def test_extraction_batch_mixed(self): @pytest.mark.cpu_only def test_extracted_matches_integer_answer(self): - """Extracted integer equal to reference answer earns 1.5.""" + """Extracted integer equal to reference answer earns reward_exact_answer.""" scores = self._check( completions=["simple100"], answer=["100"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_does_not_match_answer(self): @@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self): completions=["cost calculation1,000"], answer=["1000"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_matches_with_currency_prefix(self): @@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self): completions=["price is $16$16"], answer=["16"], ) - self.assertEqual(scores[0], 1.5) + self.assertEqual(scores[0], self.config.reward_exact_answer) @pytest.mark.cpu_only def test_extracted_non_numeric_no_match(self): diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py index af66d52a98..5f54ee30ef 100644 --- a/tests/post_training/unit/train_rl_test.py +++ b/tests/post_training/unit/train_rl_test.py @@ -44,6 +44,25 @@ def _get_mock_devices(devices_per_slice, num_slices=1): class TrainRLTest(unittest.TestCase): """Tests for 
train_rl.py.""" + @pytest.mark.cpu_only + def test_make_reward_fn_preserves_name_and_injects_config(self): + """make_reward_fn should preserve __name__ and pass through tmvp_config.""" + trainer_config = SimpleNamespace(name="cfg") + + def sample_reward_fn(tmvp_config, prompts, completions, answer, **kwargs): + self.assertIs(tmvp_config, trainer_config) + self.assertEqual(prompts, ["p"]) + self.assertEqual(completions, ["c"]) + self.assertEqual(answer, ["a"]) + self.assertEqual(kwargs.get("question"), ["q"]) + return [1.5] + + wrapped_reward_fn = train_rl.utils_rl.make_reward_fn(trainer_config, sample_reward_fn) + + self.assertEqual(wrapped_reward_fn.__name__, "sample_reward_fn") + result = wrapped_reward_fn(prompts=["p"], completions=["c"], answer=["a"], question=["q"]) + self.assertEqual(result, [1.5]) + @pytest.mark.cpu_only def test_setup_configs_and_devices_pathways_split(self): """Test setup_configs_and_devices with multiple VMs and Pathways."""