-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
288 lines (224 loc) · 8.71 KB
/
main.py
File metadata and controls
288 lines (224 loc) · 8.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import argparse
import os
import subprocess
from collections import namedtuple
from pathlib import Path

from dotenv import load_dotenv
from halo import Halo
from openai import OpenAI

import template
from compile import try_compile
from diff_wrapper import diff_asm
OPENAI_MODEL = "gpt-4o-2024-08-06"  # model used for every chat-completion request
ASM_FILENAME = "inputs/input.s"  # input assembly (also assembled to outputs/expected.o)
M2C_OUTPUT_FILENAME = "outputs/m2c-output.c"  # where m2c's decompilation is written

# Load variables from a .env file into the environment before the client
# below reads OPENAI_API_KEY.
load_dotenv()
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# The system prompt and the input assembly are read once, at import time,
# and shared by every pass.
system_prompt = Path("system.txt").read_text()
asm = Path(ASM_FILENAME).read_text()
# The asm differ compares a single symbol rather than the whole file, so we
# need the name of the first function defined in the input assembly.
def get_first_function_asm(asm):
    """Return the name of the first function declared with `.fn` in *asm*.

    Looks for directive lines of the form ``.fn <name>, ...``; anything after
    the first comma is ignored, and the comma itself is optional.

    Raises:
        ValueError: if *asm* contains no `.fn` directive.
    """
    for line in asm.splitlines():
        fields = line.split(None, 1)
        # Match only the directive token itself (any leading whitespace and
        # any separator after it), not `.fn` appearing mid-line.
        if len(fields) == 2 and fields[0] == ".fn":
            return fields[1].split(",", 1)[0].strip()
    raise ValueError("no .fn directive found in the assembly input")
# Symbol name handed to diff_asm (in main() and State.add_candidate).
first_function_name = get_first_function_asm(asm=asm)
def extract_c_from_openai_response(response):
    """Return the C source contained in a chat-completion *response*.

    Uses the first choice's message text. If it contains a fenced block opened
    with ```c (which also matches tags such as ```cpp), the body of that first
    block is returned with surrounding whitespace stripped. If the closing
    fence is missing (e.g. a truncated reply), everything after the opening
    fence is returned instead. Without a ```c fence, the message is returned
    unchanged.
    """
    output_message = response.choices[0].message.content
    fence = "```"
    opener = fence + "c"

    code_start = output_message.find(opener)
    if code_start == -1:
        # Assume the model answered with bare C code.
        return output_message

    body_start = code_start + len(opener)
    # Skip the rest of the opening line (e.g. the "pp" of ```cpp), unless the
    # whole block sits on that single line.
    line_end = output_message.find("\n", body_start)
    if line_end != -1 and fence not in output_message[body_start:line_end]:
        body_start = line_end + 1

    # The block ends at the first fence after the opener. Searching backwards
    # from the end of the message would swallow any prose and later blocks,
    # and would find the opener itself when the reply has no closing fence.
    code_end = output_message.find(fence, body_start)
    if code_end == -1:
        return output_message[body_start:].strip()
    return output_message[body_start:code_end].strip()
def query_chatgpt(system_message, user_message, console_message, filename_pass: int):
    """Send a system + user prompt to the chat model and return the raw response.

    The user prompt is first saved to outputs/tmp-message-<filename_pass>.md
    so each pass's exact input can be inspected afterwards. A console spinner
    showing *console_message* runs while the request is in flight.
    """
    Path(f"outputs/tmp-message-{filename_pass}.md").write_text(user_message)

    with Halo(text=console_message, spinner="dots"):
        return openai_client.chat.completions.create(
            model=OPENAI_MODEL,
            messages=[
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message},
            ],
        )
def initial_pass():
    """Generate the first C attempt (pass 0) from the input asm and m2c's output.

    The model's reply is written to outputs/output-0.c.
    """
    # TODO(sjayakar): not sure if this is the best place to read the m2c output.
    m2c_output = Path(M2C_OUTPUT_FILENAME).read_text()
    response = query_chatgpt(
        system_prompt,
        template.initial_pass_message(asm, m2c_output),
        "Querying ChatGPT for the first .c file...",
        0,
    )
    write_c_file(0, extract_c_from_openai_response(response))
def write_c_file(pass_number, c_code):
    """Save *c_code* to outputs/output-<pass_number>.c and print where it went."""
    out_path = f"outputs/output-{pass_number}.c"
    with open(out_path, "w") as out_file:
        out_file.write(c_code)
    print(f"The C code has been generated and saved to {out_path}.")
def clean():
    """Delete the files a previous run left in outputs/.

    Keeps the semantics of the former ``rm outputs/*``: only top-level,
    non-hidden entries are removed (so a dotfile such as ``.gitkeep``
    survives) and subdirectories are left alone. Unlike the shell command, it
    needs no POSIX shell, and it creates outputs/ when missing because later
    steps write their files there.
    """
    output_dir = Path("outputs")
    output_dir.mkdir(exist_ok=True)
    for entry in output_dir.iterdir():
        # Shell globs skip dotfiles; do the same.
        if entry.name.startswith("."):
            continue
        # `rm` without -r removes files and symlinks but not directories.
        if entry.is_symlink() or entry.is_file():
            entry.unlink()
    print("Removed all files in output directory")
# TODO(sjayakar): this is silly on multiple levels. I read more and
# more files as the number of passes grows, when I could have just
# held them in memory 😭.
def fix_compiler_errors(state):
    """Ask the model to fix the latest compile failure and write its answer.

    Every attempt in ``state.attempts`` is sent back to the model. These are
    the failed attempts recorded since the run last entered the error state,
    each given as a (C source, compiler output) pair, so the model can see
    what has already been tried. The reply is written to
    outputs/output-<state.filename_counter>.c.
    """
    c_and_errs = [(attempt.c_code, attempt.errors) for attempt in state.attempts]
    user_message = template.error_message(asm, c_and_errs)
    # No separate write of the prompt here: query_chatgpt already saves
    # user_message to outputs/tmp-message-<filename_counter>.md.
    response = query_chatgpt(
        system_prompt,
        user_message,
        "Querying ChatGPT to fix compiler errors",
        state.filename_counter,
    )
    write_c_file(state.filename_counter, extract_c_from_openai_response(response))
# Ask the model to improve on the compiling candidates found so far.
def successful_chain(state):
    """Request a better-matching C version, given every candidate so far.

    The prompt is the initial-pass prompt (asm + m2c output), followed by one
    section per candidate with its code, ASM score and diff. The reply is
    written to outputs/output-<state.filename_counter>.c.
    """
    history = "\n".join(
        template.successful_chain_message(candidate.c_code, candidate.score, candidate.diff)
        for candidate in state.candidates
    )
    opening = template.initial_pass_message(asm, state.m2c)
    prompt = f"{opening}\n{history}"

    response = query_chatgpt(
        system_prompt,
        prompt,
        "Querying ChatGPT to improve the ASM score",
        state.filename_counter,
    )
    write_c_file(state.filename_counter, extract_c_from_openai_response(response))
# Assemble the input .s into outputs/expected.o, the reference object that
# the generated C's output is eventually compared against.
def assemble_base():
    """Assemble ASM_FILENAME into outputs/expected.o.

    Raises:
        FileNotFoundError: if the assembler exits non-zero or leaves no
            object file behind. subprocess raises the same type when the
            assembler binary itself cannot be found.
    """
    print("Assembling the .s file to get expected.o")
    expected_obj = "outputs/expected.o"
    # Write straight to the destination with -o instead of assembling to
    # ./a.out and moving it. With the move, a stale a.out from an earlier run
    # would be put in place even when this assembly failed. Passing an
    # argument list also avoids going through a shell.
    result = subprocess.run(
        ["bin/powerpc-eabi-as", "-I", "inputs", "-o", expected_obj, ASM_FILENAME]
    )
    if result.returncode != 0 or not os.path.exists(expected_obj):
        raise FileNotFoundError("failed to assemble")
# Get the default output of m2c to help ground the ChatGPT response.
def m2c():
    """Run m2c on ASM_FILENAME and save its C output to M2C_OUTPUT_FILENAME.

    A non-zero exit status that still comes with output is reported but
    tolerated, and that output is saved.

    Raises:
        FileNotFoundError: if m2c produces no output at all; in that case no
            output file is written.
    """
    print("Running m2c...")
    # Capture stdout rather than using a shell redirect. `> file` creates the
    # file even when m2c fails, so a file-exists check can never detect a
    # failure. stderr is not captured and still reaches the console.
    result = subprocess.run(
        ["python3", "../m2c/m2c.py", "--target", "ppc-mwcc-c", ASM_FILENAME],
        stdout=subprocess.PIPE,
    )
    if not result.stdout:
        raise FileNotFoundError("m2c failed")
    if result.returncode != 0:
        print(f"m2c exited with status {result.returncode}; using its output anyway")
    # Write raw bytes so the file matches what a redirect would have produced.
    Path(M2C_OUTPUT_FILENAME).write_bytes(result.stdout)
def compile_and_log_error(base_name):
    """Compile outputs/<base_name>.c and return try_compile's success flag.

    On failure, the compiler message from try_compile is saved to
    outputs/<base_name>.error, where State.add_err later reads it.
    """
    with Halo(text="Compiling", spinner="dots"):
        ok, message = try_compile(f"outputs/{base_name}.c")
    if ok:
        return ok
    with open(f"outputs/{base_name}.error", "w") as error_file:
        error_file.write(message)
    return ok
def parse_args():
    """Parse the command line; the result exposes ``diff_only`` (bool)."""
    cli = argparse.ArgumentParser(description="AI decompilation script")
    cli.add_argument(
        "--diff-only",
        action="store_true",
        help="Skip generation and just diff with expected",
    )
    return cli.parse_args()
# A candidate is generated C code that compiled, together with its ASM diff
# and that diff's score. The typename must match the variable name so that
# repr() and pickling refer to "Candidate".
Candidate = namedtuple("Candidate", ["c_code", "diff", "score"])
# An attempt is generated C code that failed to compile, plus the compiler's
# error output.
Attempt = namedtuple("Attempt", ["c_code", "errors"])
# State machine with metadata. After the first pass, the program is either
# improving on the ASM diff result (STATE_CANDIDATE) or attempting to fix
# compiler errors (STATE_ERRORS).
class State:
    """Mutable bookkeeping for main()'s generate/compile loop."""

    def __init__(self):
        # Stuff that's always set.
        self.state = STATE_INITIAL
        self.filename_counter = 0  # N in outputs/output-N.c for the current pass
        # Inputs read once up front.
        with open(M2C_OUTPUT_FILENAME, "r") as m2c_file:
            self.m2c = m2c_file.read()
        with open("outputs/output-0.c", "r") as initial_pass_file:
            self.initial_pass = initial_pass_file.read()
        self.candidates = []  # every Candidate recorded during the run
        # Failed attempts since the last switch into STATE_ERRORS. Created here
        # as well so the attribute exists before the first error.
        self.attempts = []

    def __repr__(self):
        # Used when a State is formatted into messages (e.g. "invalid state ...").
        return (
            f"{type(self).__name__}(state={getattr(self, 'state', None)!r}, "
            f"filename_counter={getattr(self, 'filename_counter', None)!r}, "
            f"candidates={len(getattr(self, 'candidates', ()))}, "
            f"attempts={len(getattr(self, 'attempts', ()))})"
        )

    def current_filename_prefix(self):
        """Return the basename, without extension, of the current pass's files."""
        return f"output-{self.filename_counter}"

    def _to_err(self):
        # Entering the error state starts a fresh list of failed attempts.
        self.state = STATE_ERRORS
        self.attempts = []

    def add_err(self, filename_prefix):
        """Record outputs/<prefix>.c and outputs/<prefix>.error as a failed attempt."""
        if self.state != STATE_ERRORS:
            self._to_err()
        with open(f"outputs/{filename_prefix}.c", "r") as c_file:
            c_code = c_file.read()
        with open(f"outputs/{filename_prefix}.error", "r") as err_file:
            errs = err_file.read()
        self.attempts.append(Attempt(c_code, errs))

    def _to_candidate(self):
        self.state = STATE_CANDIDATE

    def add_candidate(self, filename_prefix):
        """Record outputs/<prefix>.c, with its ASM diff and score, as a candidate."""
        if self.state != STATE_CANDIDATE:
            self._to_candidate()
        # TODO: diff_asm is not told which file to diff; this assumes the
        # candidate is the most recently compiled code. Yikes!
        with open(f"outputs/{filename_prefix}.c", "r") as c_file:
            c_code = c_file.read()
        diff_output, score = diff_asm(first_function_name)
        self.candidates.append(Candidate(c_code, diff_output, score))
# Values held in State.state: before any compile result has been recorded,
# while fixing compiler errors, and while improving a compiling candidate.
STATE_INITIAL, STATE_ERRORS, STATE_CANDIDATE = "INITIAL", "ERRORS", "CANDIDATE"
def _record_compile_result(state, filename_prefix, compiled_successfully):
    """Add outputs/<filename_prefix>.c to *state* as a candidate or a failed attempt."""
    if compiled_successfully:
        state.add_candidate(filename_prefix)
    else:
        state.add_err(filename_prefix)


def main():
    """Run the decompilation loop, or only print the ASM diff with --diff-only.

    Without --diff-only, this clears outputs/, assembles the reference
    object, runs m2c and generates an initial C file. It then repeatedly asks
    the model either to fix compiler errors or to improve the ASM score. The
    loop has no exit condition: it runs until interrupted or until an
    exception escapes.
    """
    args = parse_args()
    if args.diff_only:
        diff_output, score = diff_asm(first_function_name)
        print(f"ASM Score: {score}")
        print(diff_output)
        return

    clean()
    assemble_base()
    m2c()
    initial_pass()

    # initial_pass() has just written outputs/output-0.c; compile it, then
    # build the State (which reads that file) and record the result.
    prefix = "output-0"
    compiled_successfully = compile_and_log_error(prefix)
    state = State()
    _record_compile_result(state, prefix, compiled_successfully)

    while True:
        state.filename_counter += 1
        filename_prefix = state.current_filename_prefix()
        if state.state == STATE_ERRORS:
            print(f"❌ Did not compile, starting compile pass {state.filename_counter}")
            fix_compiler_errors(state)
        elif state.state == STATE_CANDIDATE:
            last_candidate = state.candidates[-1]
            print(
                f"✅ Code compiled. Current ASM score: {last_candidate.score}. Attempting to improve"
            )
            successful_chain(state)
        else:
            # Report the raw state value rather than formatting the whole
            # State object.
            raise Exception(f"invalid state {state.state!r}")
        # By now, the new C file has been written to disk. Compile it and
        # transition the state machine according to the result.
        _record_compile_result(
            state, filename_prefix, compile_and_log_error(filename_prefix)
        )


if __name__ == "__main__":
    main()