Skip to content

Commit ee59b2f

Browse files
committed
simplify reward structure; add exact answer reward
1 parent 61fa4f3 commit ee59b2f

5 files changed

Lines changed: 35 additions & 22 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -164,13 +164,14 @@ max_num_checkpoints_to_keep: 10
164164

165165
# ====== Reward ======
166166

167-
reward_exact_format_match: 3.0
168-
reward_white_space_format_match: 1.5
169-
reward_partial_format_match: 0.5
170-
reward_ratio_guess_to_answer_high: 0.5
171-
reward_ratio_guess_to_answer_low: 0.25
172-
penalty_incorrect_format: -0.5
173-
penalty_incorrect_answer: -1.0
167+
reward_exact_answer: 1.0
168+
reward_white_space_format_match: 1.0
169+
reward_exact_format_match: 0.1
170+
reward_partial_format_match: 0.0
171+
reward_ratio_guess_to_answer_high: 0.0
172+
reward_ratio_guess_to_answer_low: 0.0
173+
penalty_incorrect_format: 0.0
174+
penalty_incorrect_answer: 0.0
174175

175176
# ====== Special tokens/templates for GSM8K reasoning ======
176177
reasoning_start_token: '<reasoning>'

src/maxtext/configs/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1686,6 +1686,7 @@ class RLEvaluation(BaseModel):
16861686
class Reward(BaseModel):
16871687
"""Configuration for the reward/penalty model in RL."""
16881688

1689+
reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.")
16891690
reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.")
16901691
reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.")
16911692
reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.")

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
"""
4545

4646
from __future__ import annotations
47+
from functools import wraps
4748
from typing import Sequence
4849

4950
import collections
@@ -564,15 +565,24 @@ def create_rl_components(
564565
**rl_cluster_kwargs,
565566
)
566567

568+
def make_reward_fn(fn):
569+
# pragma: no cover
570+
@wraps(fn)
571+
def _reward_fn(**kwargs):
572+
return fn(tmvp_config=trainer_config, **kwargs)
573+
574+
return _reward_fn
575+
567576
# Create RL trainer
568577
max_logging.log("Setting up RL trainer...")
569578
rl_trainer = GrpoLearner(
570579
rl_cluster=rl_cluster,
571580
reward_fns=[ # type: ignore
572-
lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs),
573-
lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs),
574-
lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
575-
lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
581+
make_reward_fn(utils_rl.match_format_exactly),
582+
make_reward_fn(utils_rl.match_format_approximately),
583+
# TODO(atwigg): consider commenting out check_answer to simplify the reward, since it overlaps with check_numbers
584+
make_reward_fn(utils_rl.check_answer),
585+
make_reward_fn(utils_rl.check_numbers),
576586
],
577587
algo_config=grpo_config,
578588
)

src/maxtext/trainers/post_train/rl/utils_rl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,15 +266,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
266266
except (TimeoutException, Exception):
267267
pass
268268

269-
# Correct answer gets tmvp_config.reward_exact_format_match points!
269+
# Correct answer gets tmvp_config.reward_exact_answer points!
270270
if guess == true_answer:
271-
score += tmvp_config.reward_exact_format_match
271+
score += tmvp_config.reward_exact_answer
272272
# Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k)
273273
elif guess.strip() == true_answer.strip():
274274
score += tmvp_config.reward_white_space_format_match
275275
# Answers match upon robust comparison with math_verify
276276
elif verified_correct:
277-
score += tmvp_config.reward_exact_format_match
277+
score += tmvp_config.reward_exact_answer
278278
else:
279279
# We also reward it if the answer is close via ratios!
280280
# I.e., if the answer is within some range, reward it!
@@ -456,13 +456,13 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
456456
# Use math_verify to compare answers (handles both numeric and expression comparison)
457457
score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)])
458458
# Return scaled score: tmvp_config.reward_exact_answer for exact/correct, 0 otherwise
459-
scores.append(1.5 if score > 0.1 else 0.0)
459+
scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0)
460460
except (TimeoutException, Exception):
461461
# Fallback to simple numeric comparison if math_verify fails
462462
try:
463463
guess_val = float(normalize_final_answer(guess).strip())
464464
true_val = float(normalize_final_answer(true_answer).strip())
465-
scores.append(1.5 if guess_val == true_val else 0.0)
465+
scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
466466
except:
467467
scores.append(0)
468468

tests/post_training/unit/rl_utils_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def _make_config():
3838
reasoning_end_token="</reasoning>",
3939
solution_start_token="<answer>",
4040
solution_end_token="</answer>",
41+
reward_exact_answer=3.0,
4142
reward_exact_format_match=2.0,
4243
reward_partial_format_match=0.5,
4344
reward_white_space_format_match=1.5,
@@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self):
223224
completions=["<reasoning>40 + 2 = 42</reasoning><answer>42</answer>"],
224225
answer=["42"],
225226
)
226-
self.assertEqual(scores[0], 1.5)
227+
self.assertEqual(scores[0], self.config.reward_exact_answer)
227228

228229
@pytest.mark.cpu_only
229230
def test_extraction_fails_no_tags(self):
@@ -241,7 +242,7 @@ def test_extraction_fails_answer_tags_only(self):
241242
completions=["<answer>42</answer>"],
242243
answer=["42"],
243244
)
244-
self.assertEqual(scores[0], 1.5)
245+
self.assertEqual(scores[0], self.config.reward_exact_answer)
245246

246247
@pytest.mark.cpu_only
247248
def test_extraction_fails_reasoning_tags_only(self):
@@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self):
262263
],
263264
answer=["7", "7"],
264265
)
265-
self.assertEqual(scores[0], 1.5)
266+
self.assertEqual(scores[0], self.config.reward_exact_answer)
266267
self.assertEqual(scores[1], 0)
267268

268269
# ---------------------------------------------------------------
@@ -276,7 +277,7 @@ def test_extracted_matches_integer_answer(self):
276277
completions=["<reasoning>simple</reasoning><answer>100</answer>"],
277278
answer=["100"],
278279
)
279-
self.assertEqual(scores[0], 1.5)
280+
self.assertEqual(scores[0], self.config.reward_exact_answer)
280281

281282
@pytest.mark.cpu_only
282283
def test_extracted_does_not_match_answer(self):
@@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self):
294295
completions=["<reasoning>cost calculation</reasoning><answer>1,000</answer>"],
295296
answer=["1000"],
296297
)
297-
self.assertEqual(scores[0], 1.5)
298+
self.assertEqual(scores[0], self.config.reward_exact_answer)
298299

299300
@pytest.mark.cpu_only
300301
def test_extracted_matches_with_currency_prefix(self):
@@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self):
303304
completions=["<reasoning>price is $16</reasoning><answer>$16</answer>"],
304305
answer=["16"],
305306
)
306-
self.assertEqual(scores[0], 1.5)
307+
self.assertEqual(scores[0], self.config.reward_exact_answer)
307308

308309
@pytest.mark.cpu_only
309310
def test_extracted_non_numeric_no_match(self):

0 commit comments

Comments
 (0)