Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,14 @@ max_num_checkpoints_to_keep: 10

# ====== Reward ======

reward_exact_format_match: 3.0
reward_white_space_format_match: 1.5
reward_partial_format_match: 0.5
reward_ratio_guess_to_answer_high: 0.5
reward_ratio_guess_to_answer_low: 0.25
penalty_incorrect_format: -0.5
penalty_incorrect_answer: -1.0
reward_exact_answer: 1.0
reward_white_space_format_match: 1.0
reward_exact_format_match: 0.1
reward_partial_format_match: 0.0
reward_ratio_guess_to_answer_high: 0.0
reward_ratio_guess_to_answer_low: 0.0
penalty_incorrect_format: 0.0
penalty_incorrect_answer: 0.0

# ====== Special tokens/templates for GSM8K reasoning ======
reasoning_start_token: '<reasoning>'
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,7 @@ class RLEvaluation(BaseModel):
class Reward(BaseModel):
"""Configuration for the reward/penalty model in RL."""

reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.")
reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.")
reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.")
reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.")
Expand Down
18 changes: 14 additions & 4 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"""

from __future__ import annotations
from functools import wraps
from typing import Sequence

import collections
Expand Down Expand Up @@ -564,15 +565,24 @@ def create_rl_components(
**rl_cluster_kwargs,
)

def make_reward_fn(fn):
  """Adapt a ``utils_rl`` reward function to the GrpoLearner reward interface.

  Binds the enclosing ``trainer_config`` as the ``tmvp_config`` keyword
  argument, so the learner can invoke the reward function with only the
  rollout kwargs (prompts, completions, answer, ...).

  Args:
    fn: A reward callable that accepts ``tmvp_config`` plus arbitrary kwargs.

  Returns:
    A wrapper carrying ``fn``'s metadata (via ``functools.wraps``) that
    forwards all kwargs to ``fn`` together with ``trainer_config``.
  """

  # The pragma belongs on the excluded line itself: as a standalone line it
  # excluded nothing from coverage measurement.
  @wraps(fn)
  def _reward_fn(**kwargs):  # pragma: no cover - trivial passthrough
    return fn(tmvp_config=trainer_config, **kwargs)

  return _reward_fn

# Create RL trainer
max_logging.log("Setting up RL trainer...")
rl_trainer = GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=[ # type: ignore
lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
make_reward_fn(utils_rl.match_format_exactly),
make_reward_fn(utils_rl.match_format_approximately),
# TODO(atwigg): consider commenting this out to simplify the reward; it overlaps with check_numbers
make_reward_fn(utils_rl.check_answer),
make_reward_fn(utils_rl.check_numbers),
],
algo_config=grpo_config,
)
Expand Down
10 changes: 5 additions & 5 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,15 +266,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
except (TimeoutException, Exception):
pass

# Correct answer gets tmvp_config.reward_exact_format_match points!
# Correct answer gets tmvp_config.reward_exact_answer points!
if guess == true_answer:
score += tmvp_config.reward_exact_format_match
score += tmvp_config.reward_exact_answer
# Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k)
elif guess.strip() == true_answer.strip():
score += tmvp_config.reward_white_space_format_match
# Answers match upon robust comparison with math_verify
elif verified_correct:
score += tmvp_config.reward_exact_format_match
score += tmvp_config.reward_exact_answer
else:
# We also reward it if the answer is close via ratios!
# Ie if the answer is within some range, reward it!
Expand Down Expand Up @@ -456,13 +456,13 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):
# Use math_verify to compare answers (handles both numeric and expression comparison)
score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)])
# Return scaled score: tmvp_config.reward_exact_answer for exact/correct, 0.0 otherwise
scores.append(1.5 if score > 0.1 else 0.0)
scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0)
except (TimeoutException, Exception):
# Fallback to simple numeric comparison if math_verify fails
try:
guess_val = float(normalize_final_answer(guess).strip())
true_val = float(normalize_final_answer(true_answer).strip())
scores.append(1.5 if guess_val == true_val else 0.0)
scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
except:
scores.append(0)

Expand Down
13 changes: 7 additions & 6 deletions tests/post_training/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _make_config():
reasoning_end_token="</reasoning>",
solution_start_token="<answer>",
solution_end_token="</answer>",
reward_exact_answer=3.0,
reward_exact_format_match=2.0,
reward_partial_format_match=0.5,
reward_white_space_format_match=1.5,
Expand Down Expand Up @@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self):
completions=["<reasoning>40 + 2 = 42</reasoning><answer>42</answer>"],
answer=["42"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extraction_fails_no_tags(self):
Expand All @@ -241,7 +242,7 @@ def test_extraction_fails_answer_tags_only(self):
completions=["<answer>42</answer>"],
answer=["42"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extraction_fails_reasoning_tags_only(self):
Expand All @@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self):
],
answer=["7", "7"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)
self.assertEqual(scores[1], 0)

# ---------------------------------------------------------------
Expand All @@ -276,7 +277,7 @@ def test_extracted_matches_integer_answer(self):
completions=["<reasoning>simple</reasoning><answer>100</answer>"],
answer=["100"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_does_not_match_answer(self):
Expand All @@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self):
completions=["<reasoning>cost calculation</reasoning><answer>1,000</answer>"],
answer=["1000"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_matches_with_currency_prefix(self):
Expand All @@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self):
completions=["<reasoning>price is $16</reasoning><answer>$16</answer>"],
answer=["16"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_non_numeric_no_match(self):
Expand Down
Loading