diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index da455a13e2..bf2fc75563 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -164,6 +164,7 @@ max_num_checkpoints_to_keep: 10
# ====== Reward ======
+reward_exact_answer: 5.0
reward_exact_format_match: 3.0
reward_white_space_format_match: 1.5
reward_partial_format_match: 0.5
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 388247b5a0..b91344f126 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1664,6 +1664,7 @@ class RLEvaluation(BaseModel):
class Reward(BaseModel):
"""Configuration for the reward/penalty model in RL."""
+ reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.")
reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.")
reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.")
reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.")
diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 7f5a33eed6..ffdf5d244f 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -564,16 +564,18 @@ def create_rl_components(
**rl_cluster_kwargs,
)
+ reward_fns = [
+ utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_exactly),
+ utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_approximately),
+ # utils_rl.make_reward_fn(trainer_config, utils_rl.check_answer),  # NOTE(atwigg): disabled — overlaps with check_numbers
+ utils_rl.make_reward_fn(trainer_config, utils_rl.check_numbers),
+ ]
+
# Create RL trainer
max_logging.log("Setting up RL trainer...")
rl_trainer = GrpoLearner(
rl_cluster=rl_cluster,
- reward_fns=[ # type: ignore
- lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
- ],
+ reward_fns=reward_fns, # type: ignore
algo_config=grpo_config,
)
diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py
index c37db48c0a..6d748164b7 100644
--- a/src/maxtext/trainers/post_train/rl/utils_rl.py
+++ b/src/maxtext/trainers/post_train/rl/utils_rl.py
@@ -15,6 +15,7 @@
# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
"""RL Utils Module."""
import re
+from functools import wraps
import optax
from maxtext.utils import max_logging
@@ -34,6 +35,16 @@
)
+def make_reward_fn(trainer_config, reward_fn):
+ """Wrap reward_fn with trainer config while preserving function metadata."""
+
+ @wraps(reward_fn)
+ def _reward_fn(**kwargs):
+ return reward_fn(tmvp_config=trainer_config, **kwargs)
+
+ return _reward_fn
+
+
def boxed(x):
"""Wraps the input string in a LaTeX boxed command if it's not already wrapped."""
return "\\boxed{" + x + "}" if not x.startswith("\\boxed{") else x
@@ -257,15 +268,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
except (TimeoutException, Exception):
pass
- # Correct answer gets tmvp_config.reward_exact_format_match points!
+ # Correct answer gets tmvp_config.reward_exact_answer points!
if guess == true_answer:
- score += tmvp_config.reward_exact_format_match
+ score += tmvp_config.reward_exact_answer
# Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k)
elif guess.strip() == true_answer.strip():
score += tmvp_config.reward_white_space_format_match
# Answers match upon robust comparison with math_verify
elif verified_correct:
- score += tmvp_config.reward_exact_format_match
+ score += tmvp_config.reward_exact_answer
else:
# We also reward it if the answer is close via ratios!
# Ie if the answer is within some range, reward it!
@@ -437,14 +448,14 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
# Use math_verify to compare answers (handles both numeric and expression comparison)
score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)])
- # Return scaled score: 1.5 for exact/correct, 0 otherwise
- scores.append(1.5 if score > 0.1 else 0.0)
+ # Return tmvp_config.reward_exact_answer for exact/correct, 0 otherwise
+ scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0)
except (TimeoutException, Exception):
# Fallback to simple numeric comparison if math_verify fails
try:
guess_val = float(normalize_final_answer(guess).strip())
true_val = float(normalize_final_answer(true_answer).strip())
- scores.append(1.5 if guess_val == true_val else 0.0)
+ scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
except:
scores.append(0)
diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py
index c8aaa7dd83..5e7f08244f 100644
--- a/tests/post_training/unit/rl_utils_test.py
+++ b/tests/post_training/unit/rl_utils_test.py
@@ -38,6 +38,7 @@ def _make_config():
reasoning_end_token="",
solution_start_token="",
solution_end_token="",
+ reward_exact_answer=3.0,
reward_exact_format_match=2.0,
reward_partial_format_match=0.5,
reward_white_space_format_match=1.5,
@@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self):
completions=["40 + 2 = 4242"],
answer=["42"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extraction_fails_no_tags(self):
@@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self):
],
answer=["7", "7"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
self.assertEqual(scores[1], 0)
# ---------------------------------------------------------------
@@ -271,12 +272,12 @@ def test_extraction_batch_mixed(self):
@pytest.mark.cpu_only
def test_extracted_matches_integer_answer(self):
- """Extracted integer equal to reference answer earns 1.5."""
+ """Extracted integer equal to reference answer earns reward_exact_answer."""
scores = self._check(
completions=["simple100"],
answer=["100"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_does_not_match_answer(self):
@@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self):
completions=["cost calculation1,000"],
answer=["1000"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_matches_with_currency_prefix(self):
@@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self):
completions=["price is $16$16"],
answer=["16"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_non_numeric_no_match(self):
diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py
index af66d52a98..5f54ee30ef 100644
--- a/tests/post_training/unit/train_rl_test.py
+++ b/tests/post_training/unit/train_rl_test.py
@@ -44,6 +44,25 @@ def _get_mock_devices(devices_per_slice, num_slices=1):
class TrainRLTest(unittest.TestCase):
"""Tests for train_rl.py."""
+ @pytest.mark.cpu_only
+ def test_make_reward_fn_preserves_name_and_injects_config(self):
+ """_make_reward_fn should preserve __name__ and pass through tmvp_config."""
+ trainer_config = SimpleNamespace(name="cfg")
+
+ def sample_reward_fn(tmvp_config, prompts, completions, answer, **kwargs):
+ self.assertIs(tmvp_config, trainer_config)
+ self.assertEqual(prompts, ["p"])
+ self.assertEqual(completions, ["c"])
+ self.assertEqual(answer, ["a"])
+ self.assertEqual(kwargs.get("question"), ["q"])
+ return [1.5]
+
+ wrapped_reward_fn = train_rl.utils_rl.make_reward_fn(trainer_config, sample_reward_fn)
+
+ self.assertEqual(wrapped_reward_fn.__name__, "sample_reward_fn")
+ result = wrapped_reward_fn(prompts=["p"], completions=["c"], answer=["a"], question=["q"])
+ self.assertEqual(result, [1.5])
+
@pytest.mark.cpu_only
def test_setup_configs_and_devices_pathways_split(self):
"""Test setup_configs_and_devices with multiple VMs and Pathways."""