-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path grpo_trainer.py
More file actions
155 lines (127 loc) · 3.97 KB
/
grpo_trainer.py
File metadata and controls
155 lines (127 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import re
import signal
import torch
from datasets import load_dataset
from trl import GRPOTrainer, GRPOConfig
from unsloth import FastLanguageModel
# Training configuration: 4-bit quantized Qwen2.5-Coder base model trained on
# LeetCode problems with GRPO.
MODEL_NAME = "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit"
DATASET_NAME = "newfacade/LeetCodeDataset"
OUTPUT_DIR = "grpo-leetcode-coder"
# Full training split. Rows are expected to carry a 'problem_description'
# column (read by prompt_func) and a 'test' column (read by
# correctness_reward_func) — confirm against the dataset schema.
dataset = load_dataset(DATASET_NAME, split="train")
class TimeoutException(Exception):
    """Raised when sandboxed code execution exceeds its wall-clock budget."""
def timeout_handler(signum, frame):
    """SIGALRM handler (standard (signum, frame) signature): abort the
    currently executing code by raising TimeoutException, which
    run_test_case catches and converts to a 0.0 reward."""
    raise TimeoutException("Execution timed out")
def extract_python_code(text):
    """Extract the contents of the first fenced code block in *text*.

    Prefers a ```python fence; falls back to the first generic ``` fence.

    Args:
        text: a model completion that may contain Markdown code fences.

    Returns:
        The inner code with surrounding whitespace stripped, or None when
        no fenced block is found.
    """
    # `\s+` after the opening fence keeps us from matching ```python as a
    # prefix of e.g. ```pythonic; `\s*` before the closing fence (was \s+)
    # also accepts blocks whose code runs directly into ``` with no
    # trailing newline — a common model output.
    pattern = r"```python\s+(.*?)\s*```"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        return match.group(1)
    # Fallback: a bare ``` fence (no language tag).
    pattern_generic = r"```\s+(.*?)\s*```"
    match_generic = re.search(pattern_generic, text, re.DOTALL)
    if match_generic:
        return match_generic.group(1)
    return None
def run_test_case(generated_code, test_code, timeout=2):
    """Execute model-generated code, then run the dataset's `check` function
    against the last callable it defined.

    Args:
        generated_code: Python source extracted from a model completion.
        test_code: dataset test snippet expected to define check(candidate).
        timeout: wall-clock budget in seconds for the whole attempt.

    Returns:
        1.0 if check() passes, 0.0 on assertion failure, timeout, missing
        check function, no callable defined, or any other error.

    NOTE(review): SIGALRM-based timeouts only work on Unix and only in the
    main thread — confirm the trainer invokes reward functions there.
    SECURITY: exec() runs untrusted model output with full interpreter
    privileges; acceptable only inside a sandboxed training environment.
    """
    # Arm a one-shot alarm; timeout_handler raises TimeoutException.
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        exec_globals = {}
        exec(generated_code, exec_globals)
        # Heuristic: take the LAST callable bound by the generated code.
        # Classes and imported callables also qualify, so this can pick
        # the wrong object for multi-definition completions.
        candidate_func = None
        for name, obj in exec_globals.items():
            if callable(obj) and name != '__builtins__':
                candidate_func = obj
        if not candidate_func:
            return 0.0
        test_globals = {}
        exec(test_code, test_globals)
        # The test snippet must expose a check(candidate) entry point.
        if 'check' not in test_globals:
            return 0.0
        check_fn = test_globals['check']
        check_fn(candidate_func)  # raises AssertionError on a wrong answer
        return 1.0
    except AssertionError:
        # Wrong answer: check()'s asserts failed.
        return 0.0
    except TimeoutException:
        # Exceeded the wall-clock budget.
        return 0.0
    except Exception as e:
        # Any crash in generated or test code counts as failure.
        return 0.0
    finally:
        # Always disarm the alarm so it cannot fire later in training.
        signal.alarm(0)
def correctness_reward_func(prompts, completions, **kwargs):
    """Hard reward: 1.0 when the completion's extracted code passes the
    row's test case, else 0.0.

    Args:
        prompts: prompt batch (unused here; required by the TRL reward
            function signature).
        completions: model completions, one per prompt.
        **kwargs: extra dataset columns forwarded by the trainer; 'test'
            holds each row's check snippet.

    Returns:
        list[float]: one reward per completion.
    """
    # Fix: removed a leftover debug `print(code)` that dumped every sampled
    # completion's extracted code to stdout on each training step.
    tests = kwargs["test"]
    rewards = []
    for completion, test_case in zip(completions, tests):
        code = extract_python_code(completion)
        # No fenced code block (or an empty one) earns zero reward.
        if not code:
            rewards.append(0.0)
            continue
        rewards.append(run_test_case(code, test_case))
    return rewards
def format_reward_func(completions, **kwargs):
    """
    Soft reward nudging completions toward the expected format
    {Explanation ... Code}: 0.1 when a ```python fence is opened and later
    closed with ```, otherwise 0.0.
    """
    def _has_closed_python_fence(text):
        # Require the opener, then a closing ``` somewhere after it.
        opener = "```python"
        if opener not in text:
            return False
        _, _, tail = text.partition(opener)
        return "```" in tail

    return [0.1 if _has_closed_python_fence(c) else 0.0 for c in completions]
def prompt_func(data):
    """Build one GRPO prompt per row from the batch's problem descriptions.

    Args:
        data: batch mapping; 'problem_description' holds the problem texts.

    Returns:
        list[str]: formatted prompts asking for an explanation plus a
        ```python code block.
    """
    template = (
        "Problem Description:\n{description}\n\nPlease provide a solution "
        "in Python. Output an explanation followed by the code block "
        "wrapped in ```python ... ```."
    )
    return [
        template.format(description=description)
        for description in data['problem_description']
    ]
# Load the base model in 4-bit (QLoRA-style) via Unsloth's fast loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_NAME,
    # max_seq_length = max_seq_length,
    load_in_4bit = True,
    full_finetuning = False
)
# Wrap the quantized base model with LoRA adapters; only the adapter
# weights are trained.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32,  # LoRA rank
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 64,  # scaling factor (alpha / r = 2)
    lora_dropout = 0,
    bias = "lora_only",
    use_gradient_checkpointing = "unsloth",  # memory-saving recomputation
)
# GRPO hyperparameters.
training_args = GRPOConfig(
    output_dir=OUTPUT_DIR,
    learning_rate=1e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,
    bf16=True,  # NOTE(review): requires bfloat16-capable hardware — confirm
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch = 4 * 4 per device
    num_generations=4,  # completions sampled per prompt (the GRPO group)
    max_prompt_length=512,
    max_completion_length=512,
    max_steps=500,
    save_steps=100,  # checkpoint every 100 steps
)
# Combine the hard correctness reward with the soft formatting reward;
# TRL sums reward functions per completion.
trainer = GRPOTrainer(
    model=model,
    reward_funcs=[correctness_reward_func, format_reward_func],
    args=training_args,
    train_dataset=dataset,
)
if __name__ == "__main__":
    trainer.train()
    # Persist the trained (LoRA-adapted) model to OUTPUT_DIR.
    trainer.save_model(OUTPUT_DIR)