|
18 | 18 | read_file, |
19 | 19 | set_gpu_arch, |
20 | 20 | ) |
| 21 | +from kernelbench.kernel_static_checker import validate_kernel_static |
21 | 22 |
|
22 | 23 | """ |
23 | 24 | Batch Generate Samples for Particular Level |
@@ -84,6 +85,8 @@ def __init__(self): |
84 | 85 | self.hardware_gpu_name = None |
85 | 86 | self.custom_prompt_key = None |
86 | 87 |
|
| 88 | + self.check_kernel = True # [experimental] optional static checker catching potential hacking patterns |
| 89 | + |
87 | 90 | def greedy(self): |
88 | 91 | # For greedy decoding, especially baseline eval
89 | 92 | self.greedy_sample = True |
@@ -162,6 +165,19 @@ def generate_sample_single( |
162 | 165 | # check LLM is able to generate custom CUDA code |
163 | 166 | assert custom_kernel is not None, "Custom CUDA code generation failed" |
164 | 167 |
|
| 168 | + # Optional: we provide a static code checker for kernel code using regex matching |
| 169 | + # NOTE: this checker is by no means complete, but it might help catch some potential hacks and issues
| 170 | + if config.check_kernel: |
| 171 | + static_check_status, error, warnings = validate_kernel_static(custom_kernel, |
| 172 | + backend=config.backend, |
| 173 | + precision=config.precision, |
| 174 | + # uses the default set of forbidden and warning patterns, |
| 175 | + # you could adapt the patterns to your own setting (e.g., how strictly CUDA streams are banned, or which torch ops are allowed)
| 176 | + ) |
| 177 | + assert static_check_status, f"Static check failed for sample {work.sample_id} for problem {problem_number}: {problem_name}. Error: {error}. Warnings: {warnings}" |
| 178 | + if warnings: |
| 179 | + print(f"Static check warnings for sample {work.sample_id} for problem {problem_number}: {problem_name}. Warnings: {warnings}") |
| 180 | + |
165 | 181 | if config.verbose: |
166 | 182 | print( |
167 | 183 | f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}" |
|
0 commit comments