-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathast_utils.py
More file actions
201 lines (169 loc) · 6.96 KB
/
ast_utils.py
File metadata and controls
201 lines (169 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
# ast_utils.py
import ast
import astunparse
import logging
import os
import re
class TestFunctionTransformer(ast.NodeTransformer):
    """AST transformer that turns a fuzz-harness module into a test template.

    The transformation performs three rewrites:

    * functions named ``main`` are removed;
    * ``if __name__ == "__main__":`` blocks are removed (either operand order);
    * functions named ``TestInput`` / ``TestOneInput`` are renamed to
      ``test_``, stripped of their parameters, and given a placeholder
      assignment for the former first parameter.
    """

    # Harness entry points that are converted into the ``test_`` template.
    _HARNESS_NAMES = ("TestInput", "TestOneInput")

    # Literal text that tags the placeholder assignment. TestGenTransformer
    # compares against this exact string, so it must not change.
    _TEMPLATE_MARKER = " # This is a test template"

    def visit_FunctionDef(self, node):
        """Drop ``main`` and convert harness entry points to ``test_``.

        Returns ``None`` (which removes the node) for ``main``; otherwise
        returns the possibly modified node after visiting its children.
        """
        if node.name == "main":
            return None

        if node.name in self._HARNESS_NAMES:
            # Remember the first parameter (only one is assumed). Positional-only
            # parameters come first in a signature, so look at them as well.
            params = list(node.args.posonlyargs) + list(node.args.args)
            param_name = params[0].arg if params else None

            node.name = "test_"
            # Replace the signature with an empty argument list.
            node.args = ast.arguments(
                posonlyargs=[],
                args=[],
                vararg=None,
                kwonlyargs=[],
                kw_defaults=[],
                kwarg=None,
                defaults=[],
            )
            # The body still refers to the old parameter, so bind it locally.
            if param_name:
                self.add_param_assignment(node, param_name)

        # Keep traversing so nested definitions / guards are processed too.
        self.generic_visit(node)
        return node

    def add_param_assignment(self, node, param_name):
        """Insert the tagged placeholder ``param_name = ...`` into ``node.body``.

        The assigned value is a ``JoinedStr`` (f-string) made of a formatted
        ``b""`` followed by the literal marker text. The marker is f-string
        text, not a real comment; it exists so the placeholder assignment can
        be recognised later and replaced with concrete input.

        The assignment goes right after the docstring when the function has
        one, otherwise at the very start of the body.
        """
        placeholder_value = ast.JoinedStr(
            values=[
                ast.FormattedValue(value=ast.Constant(value=b""), conversion=-1),
                ast.Constant(value=self._TEMPLATE_MARKER),
            ]
        )
        assign_node = ast.Assign(
            targets=[ast.Name(id=param_name, ctx=ast.Store())],
            value=placeholder_value,
        )

        has_docstring = (
            bool(node.body)
            and isinstance(node.body[0], ast.Expr)
            and isinstance(node.body[0].value, ast.Constant)
            and isinstance(node.body[0].value.value, str)
        )
        node.body.insert(1 if has_docstring else 0, assign_node)

    def remove_print_param(self, node, param_name):
        """Remove top-level ``print(...)`` statements that pass ``param_name``.

        Only direct statements in ``node.body`` are examined, and only calls
        to the bare name ``print`` with ``param_name`` as a positional
        argument are dropped.
        """

        def prints_param(stmt):
            return (
                isinstance(stmt, ast.Expr)
                and isinstance(stmt.value, ast.Call)
                and isinstance(stmt.value.func, ast.Name)
                and stmt.value.func.id == "print"
                and any(
                    isinstance(arg, ast.Name) and arg.id == param_name
                    for arg in stmt.value.args
                )
            )

        node.body = [stmt for stmt in node.body if not prints_param(stmt)]

    @staticmethod
    def _is_main_guard(test):
        """Return True if ``test`` compares ``__name__`` with ``"__main__"``.

        Both ``__name__ == "__main__"`` and ``"__main__" == __name__`` match.
        """
        if not isinstance(test, ast.Compare):
            return False
        if not isinstance(test.ops[0], ast.Eq):
            return False

        def is_dunder_name(expr):
            return isinstance(expr, ast.Name) and expr.id == "__name__"

        def is_main_literal(expr):
            return isinstance(expr, ast.Constant) and expr.value == "__main__"

        left, right = test.left, test.comparators[0]
        return (is_dunder_name(left) and is_main_literal(right)) or (
            is_main_literal(left) and is_dunder_name(right)
        )

    def visit_If(self, node):
        """Remove ``if __name__ == '__main__'`` blocks; keep every other ``if``."""
        if self._is_main_guard(node.test):
            # Returning None removes the whole statement, including any else.
            return None
        # Keep traversing so nested nodes are processed too.
        self.generic_visit(node)
        return node
class TestGenTransformer(ast.NodeTransformer):
    """Turn the ``test_`` template function into a concrete numbered test.

    The function named ``test_`` is renamed to ``test_<idx>`` and the first
    placeholder assignment in its body (an f-string whose second part is the
    template marker text) gets ``fuzz_input`` as its new constant value.
    ``found_test_function`` records whether a ``test_`` function was seen.
    """

    def __init__(self, idx, fuzz_input):
        self.idx = idx
        self.fuzz_input = fuzz_input
        self.found_test_function = False

    @staticmethod
    def _is_template_assignment(stmt):
        """Return True if ``stmt`` is the tagged placeholder assignment."""
        if not isinstance(stmt, ast.Assign):
            return False
        value = stmt.value
        if not isinstance(value, ast.JoinedStr) or len(value.values) < 2:
            return False
        marker = value.values[1]
        return (
            isinstance(marker, ast.Constant)
            and marker.value == " # This is a test template"
        )

    def visit_FunctionDef(self, node):
        """Rename ``test_`` and fill in its placeholder; leave others alone."""
        # Anything other than the template function is returned untouched.
        if node.name != "test_":
            return node

        self.found_test_function = True
        node.name = f"test_{self.idx}"

        # Only the first tagged assignment in the body is replaced.
        placeholder = next(
            (stmt for stmt in node.body if self._is_template_assignment(stmt)),
            None,
        )
        if placeholder is not None:
            placeholder.value = ast.Constant(value=self.fuzz_input)
        return node
def generate_test_template(target_name: str, repo_path: str):
    """
    Generate a Python test template using AST for more precise code transformations.

    The source file ``<repo_path>/<target_name>`` (``.py`` is appended when
    missing) is read, its leading license comment block is stripped, and the
    code is rewritten with ``TestFunctionTransformer``. The result is written
    to ``<repo_path>/tests-gen/<target_name without extension>.py``; an empty
    ``tests-gen/__init__.py`` is created if it does not exist yet. A shebang
    line, when present, is preserved at the top of the output.

    Args:
        target_name: Path of the target relative to ``repo_path``, with or
            without the ``.py`` extension.
        repo_path: Root directory of the repository.

    Returns:
        The path of the generated template, or ``None`` when the source file
        does not exist or cannot be parsed.
    """
    src_file = os.path.join(repo_path, target_name)
    if not src_file.endswith(".py"):
        src_file += ".py"
    # Log after normalising the extension so the message shows the real path.
    logging.info(f"Generating test template for {src_file}")

    if not os.path.exists(src_file):
        logging.error(f"Source target file not found: {src_file}")
        return None

    with open(src_file, "r", encoding="utf-8") as f:
        original_code = f.read()

    # --- 1. Keep shebang but remove license comments ---
    shebang = ""
    if original_code.startswith("#!"):
        # partition() also copes with a file that is just a shebang line
        # without a trailing newline (split + unpacking would raise there).
        shebang, _, original_code = original_code.partition("\n")
        shebang += "\n"

    # Matches a run of comment lines ending in the line that contains
    # "limitations under the license"; only the first match is removed.
    license_pattern = re.compile(
        r"^(?:\s*#.*\n)*\s*#.*limitations\s+under\s+the\s+license.*\n",
        re.IGNORECASE | re.MULTILINE,
    )
    code_no_license = license_pattern.sub("", original_code, count=1)

    # --- 2. Parse code to AST ---
    try:
        tree = ast.parse(code_no_license)
    except SyntaxError as e:
        logging.error(f"Syntax error in {src_file}: {e}")
        return None

    # --- 3. AST transformation ---
    transformer = TestFunctionTransformer()
    new_tree = transformer.visit(tree)
    # Nodes created by the transformer carry no line/column information.
    ast.fix_missing_locations(new_tree)

    # --- 4. Generate cleaned code ---
    cleaned_code = astunparse.unparse(new_tree)

    # --- 5. Output to tests-gen directory ---
    template_dir = os.path.join(repo_path, "tests-gen")
    os.makedirs(template_dir, exist_ok=True)

    init_path = os.path.join(template_dir, "__init__.py")
    if not os.path.exists(init_path):
        with open(init_path, "w", encoding="utf-8") as f:
            f.write("")

    # Use the base part of target_name (remove extension) as the output file name
    base_target_name = os.path.splitext(target_name)[0]
    template_path = os.path.join(template_dir, f"{base_target_name}.py")
    # target_name may contain sub-directories; make sure they exist too.
    os.makedirs(os.path.dirname(template_path), exist_ok=True)

    with open(template_path, "w", encoding="utf-8") as f:
        f.write(shebang + cleaned_code.strip() + "\n")

    logging.info(f"Generated cleaned template: {template_path}")
    return template_path