diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index 9cc305d090..3787146d90 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -164,13 +164,14 @@ max_num_checkpoints_to_keep: 10
# ====== Reward ======
-reward_exact_format_match: 3.0
-reward_white_space_format_match: 1.5
-reward_partial_format_match: 0.5
-reward_ratio_guess_to_answer_high: 0.5
-reward_ratio_guess_to_answer_low: 0.25
-penalty_incorrect_format: -0.5
-penalty_incorrect_answer: -1.0
+reward_exact_answer: 1.0
+reward_white_space_format_match: 1.0
+reward_exact_format_match: 0.1
+reward_partial_format_match: 0.0
+reward_ratio_guess_to_answer_high: 0.0
+reward_ratio_guess_to_answer_low: 0.0
+penalty_incorrect_format: 0.0
+penalty_incorrect_answer: 0.0
# ====== Special tokens/templates for GSM8K reasoning ======
reasoning_start_token: ''
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index 9ee1a7bb59..cdb61a25fc 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1686,6 +1686,7 @@ class RLEvaluation(BaseModel):
class Reward(BaseModel):
"""Configuration for the reward/penalty model in RL."""
+ reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.")
reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.")
reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.")
reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.")
diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py
index 7f5a33eed6..d77270f62e 100644
--- a/src/maxtext/trainers/post_train/rl/train_rl.py
+++ b/src/maxtext/trainers/post_train/rl/train_rl.py
@@ -44,6 +44,7 @@
"""
from __future__ import annotations
+from functools import wraps
from typing import Sequence
import collections
@@ -564,15 +565,24 @@ def create_rl_components(
**rl_cluster_kwargs,
)
+  def make_reward_fn(fn):
+    """Bind the trainer config into a reward function at definition time."""
+    @wraps(fn)
+    def _reward_fn(**kwargs):  # pragma: no cover
+      return fn(tmvp_config=trainer_config, **kwargs)
+
+    return _reward_fn
+
# Create RL trainer
max_logging.log("Setting up RL trainer...")
rl_trainer = GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=[ # type: ignore
- lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
- lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
+ make_reward_fn(utils_rl.match_format_exactly),
+ make_reward_fn(utils_rl.match_format_approximately),
+      # TODO(atwigg): consider removing check_answer — it overlaps with check_numbers, and dropping it would simplify the reward.
+ make_reward_fn(utils_rl.check_answer),
+ make_reward_fn(utils_rl.check_numbers),
],
algo_config=grpo_config,
)
diff --git a/src/maxtext/trainers/post_train/rl/utils_rl.py b/src/maxtext/trainers/post_train/rl/utils_rl.py
index 782983cefa..bca0ae7bbd 100644
--- a/src/maxtext/trainers/post_train/rl/utils_rl.py
+++ b/src/maxtext/trainers/post_train/rl/utils_rl.py
@@ -266,15 +266,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
except (TimeoutException, Exception):
pass
- # Correct answer gets tmvp_config.reward_exact_format_match points!
+ # Correct answer gets tmvp_config.reward_exact_answer points!
if guess == true_answer:
- score += tmvp_config.reward_exact_format_match
+ score += tmvp_config.reward_exact_answer
# Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k)
elif guess.strip() == true_answer.strip():
score += tmvp_config.reward_white_space_format_match
# Answers match upon robust comparison with math_verify
elif verified_correct:
- score += tmvp_config.reward_exact_format_match
+ score += tmvp_config.reward_exact_answer
else:
# We also reward it if the answer is close via ratios!
# Ie if the answer is within some range, reward it!
@@ -456,13 +456,13 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
# Use math_verify to compare answers (handles both numeric and expression comparison)
score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)])
-      # Return scaled score: 1.5 for exact/correct, 0 otherwise
-      scores.append(1.5 if score > 0.1 else 0.0)
+      # Return scaled score: reward_exact_answer for exact/correct, 0 otherwise
+      scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0)
except (TimeoutException, Exception):
# Fallback to simple numeric comparison if math_verify fails
try:
guess_val = float(normalize_final_answer(guess).strip())
true_val = float(normalize_final_answer(true_answer).strip())
- scores.append(1.5 if guess_val == true_val else 0.0)
+ scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
except:
scores.append(0)
diff --git a/tests/post_training/unit/rl_utils_test.py b/tests/post_training/unit/rl_utils_test.py
index 60bb01a51d..0d4c6285c8 100644
--- a/tests/post_training/unit/rl_utils_test.py
+++ b/tests/post_training/unit/rl_utils_test.py
@@ -38,6 +38,7 @@ def _make_config():
reasoning_end_token="",
solution_start_token="",
solution_end_token="",
+ reward_exact_answer=3.0,
reward_exact_format_match=2.0,
reward_partial_format_match=0.5,
reward_white_space_format_match=1.5,
@@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self):
completions=["40 + 2 = 4242"],
answer=["42"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extraction_fails_no_tags(self):
@@ -241,7 +242,7 @@ def test_extraction_fails_answer_tags_only(self):
completions=["42"],
answer=["42"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extraction_fails_reasoning_tags_only(self):
@@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self):
],
answer=["7", "7"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
self.assertEqual(scores[1], 0)
# ---------------------------------------------------------------
@@ -276,7 +277,7 @@ def test_extracted_matches_integer_answer(self):
completions=["simple100"],
answer=["100"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_does_not_match_answer(self):
@@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self):
completions=["cost calculation1,000"],
answer=["1000"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_matches_with_currency_prefix(self):
@@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self):
completions=["price is $16$16"],
answer=["16"],
)
- self.assertEqual(scores[0], 1.5)
+ self.assertEqual(scores[0], self.config.reward_exact_answer)
@pytest.mark.cpu_only
def test_extracted_non_numeric_no_match(self):