Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/maxtext/configs/post_train/rl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ max_num_checkpoints_to_keep: 10

# ====== Reward ======

reward_exact_answer: 5.0
reward_exact_format_match: 3.0
reward_white_space_format_match: 1.5
reward_partial_format_match: 0.5
Expand Down
1 change: 1 addition & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,7 @@ class RLEvaluation(BaseModel):
class Reward(BaseModel):
"""Configuration for the reward/penalty model in RL."""

reward_exact_answer: float = Field(5.0, description="Reward for an exact answer match.")
reward_exact_format_match: float = Field(3.0, description="Reward for an exact format match.")
reward_white_space_format_match: float = Field(1.5, description="Reward for a format match ignoring whitespace.")
reward_partial_format_match: float = Field(0.5, description="Reward for a partial format match.")
Expand Down
14 changes: 8 additions & 6 deletions src/maxtext/trainers/post_train/rl/train_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,16 +564,18 @@ def create_rl_components(
**rl_cluster_kwargs,
)

reward_fns = [
utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_exactly),
utils_rl.make_reward_fn(trainer_config, utils_rl.match_format_approximately),
# utils_rl.make_reward_fn(utils_rl.check_answer), # atwigg: commenting out since it overlaps with check_numbers
utils_rl.make_reward_fn(trainer_config, utils_rl.check_numbers),
]

# Create RL trainer
max_logging.log("Setting up RL trainer...")
rl_trainer = GrpoLearner(
rl_cluster=rl_cluster,
reward_fns=[ # type: ignore
lambda **kwargs: utils_rl.match_format_exactly(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.match_format_approximately(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.check_answer(tmvp_config=trainer_config, **kwargs),
lambda **kwargs: utils_rl.check_numbers(tmvp_config=trainer_config, **kwargs),
],
reward_fns=reward_fns, # type: ignore
algo_config=grpo_config,
)

Expand Down
23 changes: 17 additions & 6 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# pylint: disable=bare-except, consider-using-generator, chained-comparison, broad-exception-caught
"""RL Utils Module."""
import re
from functools import wraps
import optax
from maxtext.utils import max_logging

Expand All @@ -34,6 +35,16 @@
)


def make_reward_fn(trainer_config, reward_fn):
  """Bind `trainer_config` into `reward_fn`, preserving the wrapped function's metadata.

  Args:
    trainer_config: Config object to inject as the ``tmvp_config`` keyword argument.
    reward_fn: Reward callable that accepts ``tmvp_config`` plus arbitrary keyword args.

  Returns:
    A keyword-only callable that forwards every keyword argument to ``reward_fn``
    together with ``tmvp_config=trainer_config``. ``functools.wraps`` keeps the
    original ``__name__``/``__doc__`` so reward logging still shows the real name.
  """

  @wraps(reward_fn)
  def _bound_reward_fn(**kwargs):
    return reward_fn(**kwargs, tmvp_config=trainer_config)

  return _bound_reward_fn


def boxed(x):
  """Return *x* wrapped in a LaTeX ``\\boxed{...}`` command, unless it already is."""
  if x.startswith("\\boxed{"):
    return x
  return "\\boxed{" + x + "}"
Expand Down Expand Up @@ -257,15 +268,15 @@ def check_answer(prompts, completions, answer, tmvp_config, **kargs):
except (TimeoutException, Exception):
pass

# Correct answer gets tmvp_config.reward_exact_format_match points!
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this becomes the primary reward function now and we are skipping check_numbers() by default, could you please bring the debug logic from check_numbers() into here? It really helps in debugging the extraction logic.

# Correct answer gets tmvp_config.reward_exact_answer points!
if guess == true_answer:
score += tmvp_config.reward_exact_format_match
score += tmvp_config.reward_exact_answer
# Give credit if spaces are seen but otherwise the answers match (useful for simple datasets like gsm8k)
elif guess.strip() == true_answer.strip():
score += tmvp_config.reward_white_space_format_match
# Answers match upon robust comparison with math_verify
elif verified_correct:
score += tmvp_config.reward_exact_format_match
score += tmvp_config.reward_exact_answer
else:
# We also reward it if the answer is close via ratios!
# Ie if the answer is within some range, reward it!
Expand Down Expand Up @@ -437,14 +448,14 @@ def check_numbers(prompts, completions, answer, tmvp_config, **kargs):

# Use math_verify to compare answers (handles both numeric and expression comparison)
score, _ = math_verify_func([boxed(true_answer_fixed)], [boxed(guess_fixed)])
# Return scaled score: 1.5 for exact/correct, 0 otherwise
scores.append(1.5 if score > 0.1 else 0.0)
# Return binary score for exact/correct, 0 otherwise
scores.append(tmvp_config.reward_exact_answer if score > 0.1 else 0.0)
except (TimeoutException, Exception):
# Fallback to simple numeric comparison if math_verify fails
try:
guess_val = float(normalize_final_answer(guess).strip())
true_val = float(normalize_final_answer(true_answer).strip())
scores.append(1.5 if guess_val == true_val else 0.0)
scores.append(tmvp_config.reward_exact_answer if guess_val == true_val else 0.0)
except:
scores.append(0)

Expand Down
13 changes: 7 additions & 6 deletions tests/post_training/unit/rl_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def _make_config():
reasoning_end_token="</reasoning>",
solution_start_token="<answer>",
solution_end_token="</answer>",
reward_exact_answer=3.0,
reward_exact_format_match=2.0,
reward_partial_format_match=0.5,
reward_white_space_format_match=1.5,
Expand Down Expand Up @@ -223,7 +224,7 @@ def test_extraction_succeeds_full_format(self):
completions=["<reasoning>40 + 2 = 42</reasoning><answer>42</answer>"],
answer=["42"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extraction_fails_no_tags(self):
Expand Down Expand Up @@ -262,7 +263,7 @@ def test_extraction_batch_mixed(self):
],
answer=["7", "7"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)
self.assertEqual(scores[1], 0)

# ---------------------------------------------------------------
Expand All @@ -271,12 +272,12 @@ def test_extraction_batch_mixed(self):

@pytest.mark.cpu_only
def test_extracted_matches_integer_answer(self):
"""Extracted integer equal to reference answer earns 1.5."""
"""Extracted integer equal to reference answer earns reward_exact_answer."""
scores = self._check(
completions=["<reasoning>simple</reasoning><answer>100</answer>"],
answer=["100"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_does_not_match_answer(self):
Expand All @@ -294,7 +295,7 @@ def test_extracted_matches_comma_formatted_number(self):
completions=["<reasoning>cost calculation</reasoning><answer>1,000</answer>"],
answer=["1000"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_matches_with_currency_prefix(self):
Expand All @@ -303,7 +304,7 @@ def test_extracted_matches_with_currency_prefix(self):
completions=["<reasoning>price is $16</reasoning><answer>$16</answer>"],
answer=["16"],
)
self.assertEqual(scores[0], 1.5)
self.assertEqual(scores[0], self.config.reward_exact_answer)

@pytest.mark.cpu_only
def test_extracted_non_numeric_no_match(self):
Expand Down
19 changes: 19 additions & 0 deletions tests/post_training/unit/train_rl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,25 @@ def _get_mock_devices(devices_per_slice, num_slices=1):
class TrainRLTest(unittest.TestCase):
"""Tests for train_rl.py."""

@pytest.mark.cpu_only
def test_make_reward_fn_preserves_name_and_injects_config(self):
  """make_reward_fn should preserve __name__ and pass through tmvp_config."""
  cfg = SimpleNamespace(name="cfg")

  def sample_reward_fn(tmvp_config, prompts, completions, answer, **kwargs):
    # The wrapper must inject the exact config object and forward all kwargs untouched.
    self.assertIs(tmvp_config, cfg)
    self.assertEqual((prompts, completions, answer), (["p"], ["c"], ["a"]))
    self.assertEqual(kwargs.get("question"), ["q"])
    return [1.5]

  wrapped = train_rl.utils_rl.make_reward_fn(cfg, sample_reward_fn)

  # functools.wraps must keep the original function name (used in reward logging).
  self.assertEqual(wrapped.__name__, "sample_reward_fn")
  result = wrapped(prompts=["p"], completions=["c"], answer=["a"], question=["q"])
  self.assertEqual(result, [1.5])

@pytest.mark.cpu_only
def test_setup_configs_and_devices_pathways_split(self):
"""Test setup_configs_and_devices with multiple VMs and Pathways."""
Expand Down
Loading