
Commit 768d52c

Static Kernel Code Checker (#110)

Authored by simonguozirui, Willy-Chan, Simon Guo, Ethan, and ethanboneh

* Start code checker to avoid reward hacking
* Preliminary ThunderKittens code check
* Optional producer-consumer semantics, setup for compilation check
* Proposed rejection behavior
* An initial modularized version; needs simplification and cleanup
* Clearer organization for static reward-hack checking
* Added more static checks Ethan wrote (from the deep reinforce blog) and DSL tests
* Add integration paths and simplify test cases (could add more adversarial ones)
* Gate behind a flag for now as a start for the checker system
* Precision downgrade check and test

Co-authored-by: Willy-Chan <willychan2022@gmail.com>
Co-authored-by: Simon Guo <simoguo@stanford.edu>
Co-authored-by: Ethan <ethanbc@stanford.edu>
Co-authored-by: ethanboneh <ethanboneh@gmail.com>

1 parent f393682 · commit 768d52c

10 files changed: 1709 additions & 16 deletions

EVAL.md (3 additions, 0 deletions)

```diff
@@ -36,6 +36,9 @@ We have (and continue to) implement various approaches to conduct kernel timing
 
 Check out `timing.py` to see available timing methods and `src/unit_tests/test_eval_timing.py` to test out the various timing methods (including leveraging the `cuda_event` marker, Triton `do_bench`, and `host_time` E2E time). @palic and team are working on a blog post explaining the different tradeoffs soon.
 
+### Checkers
+There are potentially many ways a model might reward hack, and we would like to catch the known ways through checkers [experimental and WIP]. We start with `kernel_static_checker.py`, a regex-based checker that matches the generated code against a set of rules. We plan to add AST-based, LM-as-a-judge, and more runtime checks in the future. We welcome suggestions and contributions here.
+
 ### Unit Tests with Adversarial Examples
 
 We've included some unit tests for the eval script in `src/unit_tests/test_eval_adversarial.py`. These tests run adversarial kernels (see `src/unit_tests/test_kernels/`) that contain examples of reward hacking that we've seen from LLMs and ensure that the eval script catches them, either by failing their correctness checks or flagging them for excessive speedups. Examples include:
 
 - Reusing computations cached during the PyTorch reference
```
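To make the rule-matching idea in the new Checkers section concrete, here is a minimal sketch of a regex-based static checker. The pattern lists and the function name are invented for illustration; the actual rules in `kernel_static_checker.py` differ.

```python
import re

# Hypothetical rule sets, loosely inspired by the checks described above;
# these are NOT the real kernel_static_checker.py rules.
FORBIDDEN_PATTERNS = {
    r"\btorch\.(matmul|einsum)\b": "delegates the core computation back to PyTorch",
    r"\bcudaStreamCreate\b": "manipulates CUDA streams, which can skew timing",
}
WARNING_PATTERNS = {
    r"\.half\(\)|torch\.float16": "possible precision downgrade in an fp32 run",
}

def check_kernel_source(src: str):
    """Scan kernel source; return (ok, errors, warnings) like a static checker would."""
    errors = [msg for pat, msg in FORBIDDEN_PATTERNS.items() if re.search(pat, src)]
    warnings = [msg for pat, msg in WARNING_PATTERNS.items() if re.search(pat, src)]
    return len(errors) == 0, errors, warnings
```

Under these illustrative rules, a kernel that calls `torch.matmul` is rejected outright, while a `.half()` cast only raises a warning.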

README.md (2 additions, 16 deletions)

```diff
@@ -117,24 +117,10 @@ uv run python scripts/generate_and_eval_single_sample.py dataset_src=huggingface
 * **`precision`** - You can specify the precision of tensors with `precision=fp32`. Currently all of our reported results are `fp32`, but we have added support for `fp16` & `bf16`.
 * **`backend`** - We also support other GPU programming languages beyond `cuda`. Simply specify `backend=triton`. For now we support the DSLs: `cuda`, `triton`, `cute`, `tilelang`, `thunderkittens`.
 
-Check the config fields for comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
-
-### Running Thunderkittens Locally
-
-If you plan on using `scripts/generate_and_eval_single_sample.py` using `backend=thunderkittens`, make sure to git clone the ThunderKittens repo and you set the following environment variable to point to your local ThunderKittens directory:
-
-```bash
-export THUNDERKITTENS_ROOT=/Users/willychan/Desktop/projects/KernelBench/ThunderKittens
-```
-
-As seen in `src/kernelbench/prompts/model_new_ex_add_thunderkittens.py`, the generated kernels should have the following line:
+Note on setting up ThunderKittens (TK) locally: to use `backend=thunderkittens`, git clone the ThunderKittens repo and set the environment variable `export THUNDERKITTENS_ROOT=<PATH to ThunderKittens folder>` to point to your local ThunderKittens directory. All ThunderKittens programs, as shown in the [example](src/kernelbench/prompts/model_new_ex_add_thunderkittens.py), should contain `tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")`, which enables the kernel to include the right TK primitives. In addition, we only support BF16 for TK right now.
 
-```bash
-tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
-```
-
-This allows the kernel to include the right TK primitives.
-
-*NOTE*: Right now, all generated ThunderKittens kernels are required to be in datatype format BF16. FP16 support is TBD.
+
+Check the config fields for a comprehensive set of options. Note we provide the model with a one-shot example by default along with the minimum set of info; you can check out other prompt settings or construct your own in `src/prompt_constructor_toml.py`.
 
 ### Run on all problems
```
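The `tk_root` lookup the new README note requires is just an environment-variable read with a fallback; in isolation it looks like this (the fallback path `/root/ThunderKittens` is the default used in the repo's example prompt):

```python
import os

# Resolve the local ThunderKittens checkout; fall back to the example's
# default clone location when THUNDERKITTENS_ROOT is not exported.
tk_root = os.environ.get("THUNDERKITTENS_ROOT", "/root/ThunderKittens")
```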

scripts/generate_and_eval_single_sample.py (15 additions, 0 deletions)

```diff
@@ -83,6 +83,8 @@ def __init__(self):
         self.hardware_gpu_name = None
         self.custom_prompt_key = None
 
+        self.check_kernel = True  # [experimental] optional static checker catching potential hacking patterns
+
     def verbose_logging(self):
         self.log = True
         self.log_prompt = True
@@ -260,6 +262,19 @@ def main(config: EvalConfig):
         custom_kernel is not None
     ), f"Custom {config.backend} kernel code generation failed"
 
+    # Optional: static code checker for kernel code using regex matching
+    # NOTE: by no means is this checker complete, but it could help catch some potential hacks
+    if config.check_kernel:
+        from kernelbench.kernel_static_checker import validate_kernel_static
+        static_check_status, errors, warnings = validate_kernel_static(
+            custom_kernel,
+            backend=config.backend,
+            precision=config.precision,
+        )
+        assert static_check_status, f"Static check failed for level {config.level} problem {config.problem_id}. Errors: {errors}. Warnings: {warnings}"
+        if warnings:
+            print(f"Static check warnings for level {config.level} problem {config.problem_id}: {warnings}")
+
     # this should be optional
     if config.log:
         with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
```
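The gating logic this hunk adds reduces to: run the checker, hard-fail on errors, and surface warnings. A standalone sketch of that flow, using a stand-in for `validate_kernel_static` (its real rules live in `kernelbench.kernel_static_checker` and are not reproduced here):

```python
def validate_kernel_static_stub(src, backend="cuda", precision="fp32"):
    """Stand-in for kernelbench.kernel_static_checker.validate_kernel_static;
    the real checker's rules are far more extensive."""
    errors = ["uses torch.matmul"] if "torch.matmul" in src else []
    warnings = ["possible precision downgrade"] if ".half()" in src else []
    return (not errors, errors, warnings)

def gate_kernel(src, check_kernel=True):
    """Mirror of the scripts' gating pattern: assert on errors, print warnings."""
    if not check_kernel:
        return  # checker disabled via config, as the diff's flag allows
    ok, errors, warnings = validate_kernel_static_stub(src)
    assert ok, f"Static check failed. Errors: {errors}. Warnings: {warnings}"
    if warnings:
        print(f"Static check warnings: {warnings}")
```

A clean kernel passes silently; a kernel that trips a forbidden pattern raises `AssertionError` before evaluation ever runs.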

scripts/generate_and_eval_single_sample_modal.py (15 additions, 0 deletions)

```diff
@@ -80,6 +80,8 @@ def __init__(self):
         self.hardware_gpu_name = None
         self.custom_prompt_key = None
 
+        self.check_kernel = True  # [experimental] optional static checker catching potential hacking patterns
+
     def verbose_logging(self):
         self.log = True
         self.log_prompt = True
@@ -283,6 +285,19 @@ def main(config: EvalConfig):
     # check LLM is able to generate custom kernel code
     assert custom_kernel is not None, f"Custom {config.backend} kernel code generation failed"
 
+    # Optional: static code checker for kernel code using regex matching
+    # NOTE: by no means is this checker complete, but it could help catch some potential hacks
+    if config.check_kernel:
+        from kernelbench.kernel_static_checker import validate_kernel_static
+        static_check_status, errors, warnings = validate_kernel_static(
+            custom_kernel,
+            backend=config.backend,
+            precision=config.precision,
+        )
+        assert static_check_status, f"Static check failed for level {config.level} problem {config.problem_id}. Errors: {errors}. Warnings: {warnings}"
+        if warnings:
+            print(f"Static check warnings for level {config.level} problem {config.problem_id}: {warnings}")
+
     # this should be optional
     if config.log:
         with open(os.path.join(config.logdir, f"generated_kernel_level_{config.level}_problem_{config.problem_id}.py"), "w") as f:
```

scripts/generate_samples.py (16 additions, 0 deletions)

```diff
@@ -18,6 +18,7 @@
     read_file,
     set_gpu_arch,
 )
+from kernelbench.kernel_static_checker import validate_kernel_static
 
 """
 Batch Generate Samples for Particular Level
@@ -84,6 +85,8 @@ def __init__(self):
         self.hardware_gpu_name = None
         self.custom_prompt_key = None
 
+        self.check_kernel = True  # [experimental] optional static checker catching potential hacking patterns
+
     def greedy(self):
         # For greedy decoding, especially baseline eval
         self.greedy_sample = True
@@ -162,6 +165,19 @@ def generate_sample_single(
     # check LLM is able to generate custom CUDA code
     assert custom_kernel is not None, "Custom CUDA code generation failed"
 
+    # Optional: we provide a static code checker for kernel code using regex matching
+    # NOTE: by no means is this checker complete, but it might help catch some potential hacks and issues
+    if config.check_kernel:
+        static_check_status, error, warnings = validate_kernel_static(custom_kernel,
+            backend=config.backend,
+            precision=config.precision,
+            # uses the default set of forbidden and warning patterns;
+            # you could adapt the patterns to your own setting (e.g. degree of banning CUDA streams, allowing some torch ops)
+        )
+        assert static_check_status, f"Static check failed for sample {work.sample_id} for problem {problem_number}: {problem_name}. Error: {error}. Warnings: {warnings}"
+        if warnings:
+            print(f"Static check warnings for sample {work.sample_id} for problem {problem_number}: {problem_name}. Warnings: {warnings}")
+
     if config.verbose:
         print(
             f"Generated sample {work.sample_id} for problem {problem_number}: {problem_name}"
```
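One check this commit calls out is the precision downgrade check: under `precision=fp32`, a kernel that quietly computes in half precision can post unrealistic speedups while still looking plausible. A standalone sketch of that idea (the regexes here are illustrative, not the checker's actual patterns):

```python
import re

# Illustrative downgrade markers for an fp32 run; the real patterns in
# kernel_static_checker.py may differ.
DOWNGRADE_PATTERNS = [
    r"\.half\(\)",       # tensor.half()
    r"torch\.float16",   # half-precision dtype literal
    r"torch\.bfloat16",  # bf16 dtype literal
    r"\b__half\b",       # CUDA half type in inline source
]

def has_precision_downgrade(kernel_src: str, precision: str = "fp32") -> bool:
    """Flag half/bf16 usage when the run is supposed to be full precision."""
    if precision != "fp32":
        return False  # downgrading only matters when fp32 is expected
    return any(re.search(p, kernel_src) for p in DOWNGRADE_PATTERNS)
```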

scripts/run_and_check.py (15 additions, 0 deletions)

```diff
@@ -10,6 +10,7 @@
 from kernelbench import utils as kernel_utils
 from scripts.generate_baseline_time import measure_program_time
 from kernelbench.utils import read_file
+from kernelbench.kernel_static_checker import validate_kernel_static
 
 # Modal setup
 app = modal.App("run_and_check")
@@ -120,6 +121,8 @@ def __init__(self):
         self.precision = "fp32"
         self.backend = "cuda"
 
+        self.check_kernel = True  # [experimental] optional static checker catching potential hacking patterns
+
     def __repr__(self):
         return f"ScriptConfig({self.to_dict()})"
 
@@ -279,6 +282,18 @@ def main(config: ScriptConfig):
 
     kernel_src = read_file(config.kernel_src_path)
 
+    # Optional: static code checker for kernel code using regex matching
+    # NOTE: by no means is this checker complete, but it could help catch some potential hacks
+    if config.check_kernel:
+        static_check_status, errors, warnings = validate_kernel_static(
+            kernel_src,
+            backend=config.backend,
+            precision=config.precision,
+        )
+        assert static_check_status, f"Static check failed. Errors: {errors}. Warnings: {warnings}"
+        if warnings:
+            print(f"[WARN] Static check warnings: {warnings}")
+
     # Start Evaluation
     assert config.eval_mode in ["local", "modal"], "eval_mode must be either 'local' or 'modal'"
```
