-
Notifications
You must be signed in to change notification settings - Fork 160
Expand file tree
/
Copy patheval_from_generations.py
More file actions
973 lines (815 loc) · 35.7 KB
/
eval_from_generations.py
File metadata and controls
973 lines (815 loc) · 35.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
import json
import multiprocessing as mp
import os
import shutil
import time
from dataclasses import dataclass
from collections import defaultdict
from dataclasses import dataclass
import numpy as np
import pydra
import torch
from pydra import Config, REQUIRED
# Import only what we need
from kernelbench import compile, eval, utils
from kernelbench.dataset import construct_kernelbench_dataset
from kernelbench.eval import (
build_compile_cache,
get_error_name,
check_metadata_serializable_all_types,
eval_kernel_against_ref,
KernelExecResult,
)
from kernelbench.utils import read_file, set_gpu_arch
from tqdm import tqdm
# Modal support
import modal
"""
Batch Evaluation from Existing Generations
This script expects that the kernels have already been generated and stored in the runs/{run_name} directory.
It evaluates each kernel against its reference architecture and stores the results in runs/{run_name}/eval_results.json
(pass@k metrics are written alongside it, to runs/{run_name}/pass_at_k_results.json).
By default, each kernel is checked for
- correctness (num_correct_trials): 5 randomized input trials
- performance (num_perf_trials): 100 randomized input trials
You can increase the number of trials for correctness and performance via the config.
"""
# Repository root: two directory levels above this file.
REPO_TOP_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Compact tensor printing in logs: 4 decimals, summarize tensors with more than 10 elements.
torch.set_printoptions(precision=4, threshold=10)
# Modal Infrastructure Setup
app = modal.App("eval_from_generations_modal")
# Modal GPU type -> architecture list passed to set_gpu_arch() inside the Modal container.
# Callers fall back to ["Ada"] for GPU names not listed here.
gpu_arch_mapping = {"L40S": ["Ada"], "H100": ["Hopper"], "A100": ["Ampere"], "L4": ["Ada"], "T4": ["Turing"], "A10G": ["Ampere"]}
# Base image tag pieces: nvidia/cuda:{cuda_version}-{flavor}-{operating_sys}
cuda_version = "13.0.0" # should be no greater than host CUDA version
flavor = "devel" # includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
# Local directories shipped into the container (see add_local_dir below).
SRC_DIR = os.path.join(REPO_TOP_DIR, "src")
KERNELBENCH_DIR = os.path.join(REPO_TOP_DIR, "KernelBench")
# Container image used by ModalEvaluator:
#   CUDA devel base + Python 3.10, compilers, the project's uv-managed dependencies,
#   a ThunderKittens checkout (exposed via THUNDERKITTENS_ROOT), and the local
#   src/ and KernelBench/ trees (importable via PYTHONPATH).
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
    .apt_install("git",
                 "gcc-10",
                 "g++-10",
                 "clang"
                 )
    .uv_sync(uv_project_dir=REPO_TOP_DIR)
    .run_commands("git clone -b main https://github.com/HazyResearch/ThunderKittens.git /root/ThunderKittens")
    .env({
        "THUNDERKITTENS_ROOT": "/root/ThunderKittens",
        "PYTHONPATH": "/root/src:/root"
    })
    .add_local_dir(SRC_DIR, remote_path="/root/src")
    # NOTE(review): presumably kept last because Modal requires build steps after
    # (non-copied) add_local_* layers to be avoided — confirm against Modal docs.
    .add_local_dir(KERNELBENCH_DIR, remote_path="/root/KernelBench") # must be last
)
class EvalConfig(Config):
    """
    Pydra configuration for batch-evaluating previously generated kernels.

    Fields initialised to REQUIRED must be supplied when the script is launched;
    everything else has a default listed below.
    """

    def __init__(self):
        # ---- run & dataset selection ----
        self.run_name = REQUIRED  # run whose generated kernels are evaluated
        self.dataset_src = REQUIRED  # "huggingface" or "local"
        self.dataset_name = "ScalingIntelligence/KernelBench"  # Hugging Face dataset name

        # ---- problem specification ----
        self.level = REQUIRED
        self.subset = (None, None)  # inclusive (start_id, end_id) over logical problem ids
        self.problem_ids = None  # explicit ids, e.g. [71, 86, 95]; intended to override subset

        # ---- where evaluation runs ----
        self.eval_mode = "local"  # "local" (needs a GPU here) or "modal" (cloud GPU)
        self.gpu = "A10G"  # Modal GPU type: L40S, H100, A100, L4, T4, A10G
        self.gpu_arch = ["Ada"]  # local build target: arch name(s) or SM version(s)

        # ---- logging ----
        self.runs_dir = os.path.join(REPO_TOP_DIR, "runs")  # top directory holding all runs
        self.verbose = False

        # ---- evaluation settings ----
        self.num_correct_trials = 5
        self.num_perf_trials = 100
        self.timeout = 180  # seconds
        self.measure_performance = True
        self.timing_method = "cuda_event"

        # ---- evaluation flow ----
        self.build_cache = False  # pre-build kernels on CPU into the cache to speed up eval
        self.num_cpu_workers = 20  # parallel processes used for that CPU build
        self.kernel_eval_build_dir = os.path.join(REPO_TOP_DIR, "cache")  # kernel build directory
        self.num_gpu_devices = 1  # batch size: local GPUs, or parallel Modal calls

        # ---- kernel implementation & numerics ----
        self.backend = "cuda"  # e.g. "cuda", "triton", "thunderkittens"
        self.precision = "fp32"  # "fp32", "fp16" or "bf16"

        # ---- pass@k analysis ----
        self.num_samples_per_problem = 1  # samples evaluated per problem
        self.pass_at_k_values = [1]  # k values reported as pass@k

    def __repr__(self):
        return f"EvalConfig({self.to_dict()})"
@dataclass
class WorkArgs:
    """One unit of evaluation work: a (problem, sample) pair plus the device to run it on."""

    problem_id: int  # logical problem id within the level
    sample_id: int  # index of the generated sample for that problem
    device: torch.device  # device assigned to this evaluation (e.g. cuda:0)
# Modal Evaluation Class
# GPU must be specified here for all instances
# (callers override it per run via ModalEvaluator.with_options(gpu=...)).
# Retries are configured at the class level to handle GPU attachment failures
@app.cls(
    image=image,
    gpu="A10G",
    retries=modal.Retries(
        max_retries=3,
        backoff_coefficient=2.0,
        initial_delay=1.0,
    )
)
class ModalEvaluator:
    """Remote (Modal) wrapper around eval_kernel_against_ref for a single sample."""

    @modal.method()
    def evaluate_single_sample_modal(
        self,
        ref_arch_src: str,
        kernel_src: str,
        gpu_arch: list[str],
        num_correct_trials: int = 5,
        num_perf_trials: int = 100,
        measure_performance: bool = True,
        timing_method: str = "cuda_event",
        verbose: bool = False,
        backend: str = "cuda",
        precision: str = "fp32",
    ):
        """
        Evaluate one generated kernel against its reference on the container's GPU.

        - Waits up to 30s for CUDA to become available; if it never does, raises
          RuntimeError so the class-level retry policy reschedules the call.
        - Builds for ``gpu_arch`` and evaluates on cuda:0 with no persistent build dir.
        - On torch.cuda.CudaError / torch.AcceleratorError the container is retired
          via modal.experimental.stop_fetching_inputs() and a failed
          KernelExecResult (gpu_error / error_message metadata) is returned.
        - Any other exception propagates to the caller.

        Returns:
            KernelExecResult for this sample.
        """
        from kernelbench.eval import eval_kernel_against_ref, get_torch_dtype_from_string
        from kernelbench.utils import set_gpu_arch
        import torch
        import time
        import modal.experimental

        max_wait_time = 30
        start_time = time.time()
        gpu_available = False
        while time.time() - start_time < max_wait_time:
            if torch.cuda.is_available():
                gpu_available = True
                break
            # Progressive backoff keyed on elapsed time: sleep 0.5s during the first 2s,
            # doubling every further 2s of waiting, capped at 8s.
            wait_time = min(0.5 * (2 ** int((time.time() - start_time) / 2)), 8.0)
            time.sleep(wait_time)
        if not gpu_available:
            raise RuntimeError(
                f"GPU not attached to container after {max_wait_time}s - Modal will retry with new container"
            )

        set_gpu_arch(gpu_arch)
        gpu_corrupted = False
        try:
            result = eval_kernel_against_ref(
                original_model_src=ref_arch_src,
                custom_model_src=kernel_src,
                measure_performance=measure_performance,
                timing_method=timing_method,
                verbose=verbose,
                num_correct_trials=num_correct_trials,
                num_perf_trials=num_perf_trials,
                build_dir=None,
                device=torch.device("cuda:0"),
                backend=backend,
                precision=get_torch_dtype_from_string(precision),
            )
        except (torch.cuda.CudaError, torch.AcceleratorError) as e:
            # GPU error detected - retire this container to prevent contamination
            gpu_corrupted = True
            # TODO: Replace with more stable API in the future, thanks modal team for temp workaround.
            modal.experimental.stop_fetching_inputs()
            result = KernelExecResult(
                compiled=False,
                correctness=False,
                metadata={
                    "gpu_error": type(e).__name__,
                    "error_message": str(e)[:500],  # truncated to keep stored metadata small
                },
                runtime=-1.0,
                runtime_stats={},
            )

        # Skip further CUDA calls on a device that just reported an error.
        if not gpu_corrupted:
            torch.cuda.empty_cache()
        return result
def fetch_ref_arch_from_problem_id(
dataset, problem_id: int, dataset_src: str = None
) -> str | None:
"""
Fetch reference architecture from problem directory.
Uses the unified dataset interface.
Note: dataset_src parameter is kept for backward compatibility but ignored
since the dataset object already handles both sources.
"""
problem = dataset.get_problem_by_id(problem_id)
return problem.code
def fetch_kernel_from_disk(
run_dir: str, level: int, problem_id: int, sample_id: int
) -> str | None:
"""
Fetch kernel file from disk (stored in runs/{run_name})
"""
kernel_path = os.path.join(
run_dir, f"level_{level}_problem_{problem_id}_sample_{sample_id}_kernel.py"
)
if os.path.exists(kernel_path):
return read_file(kernel_path)
else:
return None
def evaluate_single_sample(
    work_args: WorkArgs, configs: EvalConfig, dataset, run_dir: str
) -> KernelExecResult | None:
    """
    Evaluate a single sample on a single GPU.

    Loads the reference architecture for ``work_args.problem_id`` and the
    generated kernel for (problem_id, sample_id) from ``run_dir``, then runs
    ``eval_kernel_against_ref`` on ``work_args.device`` with build directory
    ``{kernel_eval_build_dir}/{run_name}/{problem_id}/{sample_id}``.

    Any exception raised during evaluation is turned into a failed
    ``KernelExecResult`` (compiled=False, correctness=False). Errors whose text
    contains "CUDA error" are recorded under ``cuda_error`` / ``cuda_error_name``;
    all others under ``other_error`` / ``other_error_name``. Both also record
    ``hardware`` and ``device`` for debugging.

    Raises:
        AssertionError: if no kernel file exists for this sample.
    """
    problem_id = work_args.problem_id
    sample_id = work_args.sample_id
    device = work_args.device

    # fetch reference architecture from problem directory
    ref_arch_src = fetch_ref_arch_from_problem_id(
        dataset, problem_id, configs.dataset_src
    )
    # fetch kernel from disk
    # Add database support in the future
    kernel_src = fetch_kernel_from_disk(run_dir, configs.level, problem_id, sample_id)
    assert (
        kernel_src is not None
    ), f"Kernel not found for problem {problem_id} sample {sample_id}"

    build_dir = os.path.join(
        configs.kernel_eval_build_dir, configs.run_name, f"{problem_id}", f"{sample_id}"
    )

    try:
        return eval_kernel_against_ref(
            original_model_src=ref_arch_src,
            custom_model_src=kernel_src,
            measure_performance=configs.measure_performance,
            timing_method=configs.timing_method,
            verbose=configs.verbose,
            num_correct_trials=configs.num_correct_trials,
            num_perf_trials=configs.num_perf_trials,
            build_dir=build_dir,
            device=device,
            backend=configs.backend,
            precision=eval.get_torch_dtype_from_string(configs.precision),
        )
    except Exception as e:
        # INNER CATCH: Handles errors during kernel execution
        # - CUDA errors (illegal memory access, kernel launch failures)
        # - Runtime errors from the custom kernel
        # - Any exception from eval_kernel_against_ref()
        print(
            f"[WARNING] Last level catch on {sample_id}: Some issue evaluating for kernel: {e} "
        )
        if "CUDA error" in str(e):
            # NOTE: count this as compilation failure as it is not runnable code;
            # this usually signifies an illegal memory access.
            error_key, error_text = "cuda_error", f"CUDA Error: {str(e)}"
        else:
            error_key, error_text = "other_error", f"error: {str(e)}"

        # The device query may itself fail (e.g. if the CUDA context is in an
        # error state); don't let that mask the original evaluation error.
        try:
            hardware = torch.cuda.get_device_name(device=device)
        except Exception:
            hardware = "unknown"

        metadata = {
            error_key: error_text,
            f"{error_key}_name": get_error_name(e),
            "hardware": hardware,
            "device": str(device),
        }  # for debugging
        return KernelExecResult(
            compiled=False, correctness=False, metadata=metadata
        )
def evaluate_single_sample_modal_direct(
    problem_id: int,
    sample_id: int,
    ref_arch_src: str,
    kernel_src: str,
    gpu: str,
    configs: EvalConfig,
):
    """
    Evaluate a single sample using Modal (blocking remote call).

    Args:
        problem_id: problem id, used only for error reporting.
        sample_id: sample id, used only for error reporting.
        ref_arch_src: reference architecture source code.
        kernel_src: generated kernel source code.
        gpu: Modal GPU type (a key of ``gpu_arch_mapping``). It selects both the
            build arch list and the container GPU, matching ``batch_eval_modal``.
        configs: evaluation settings (trials, timing, backend, precision, ...).

    Returns:
        The remote ``KernelExecResult``, or None if the Modal call raised.
    """
    gpu_arch = gpu_arch_mapping.get(gpu, ["Ada"])
    try:
        # The class decorator pins A10G; override it so the container GPU matches
        # `gpu` (and therefore the arch list computed above).
        evaluator_cls = (
            ModalEvaluator.with_options(gpu=gpu) if gpu != "A10G" else ModalEvaluator
        )
        return evaluator_cls().evaluate_single_sample_modal.remote(
            ref_arch_src=ref_arch_src,
            kernel_src=kernel_src,
            gpu_arch=gpu_arch,
            num_correct_trials=configs.num_correct_trials,
            num_perf_trials=configs.num_perf_trials,
            measure_performance=configs.measure_performance,
            timing_method=configs.timing_method,
            verbose=configs.verbose,
            # Forward these too; otherwise the remote side silently uses cuda/fp32.
            backend=configs.backend,
            precision=configs.precision,
        )
    except Exception as e:
        print(f"[ERROR] Modal evaluation failed for problem {problem_id} sample {sample_id}: {e}")
        return None
def cuda_single_eval_wrapper(curr_work: WorkArgs, configs: EvalConfig, dataset, run_dir: str):
    """
    Evaluate one sample in a single-worker process pool, with a timeout.

    Wrapper to handle timeout and keyboard interrupt: ``evaluate_single_sample``
    runs in a child process so a hung kernel can be abandoned after
    ``configs.timeout`` seconds (leaving the ``with mp.Pool`` block terminates
    the worker).

    Returns:
        The sample's ``KernelExecResult``; on timeout, a failed result with
        metadata {"error": "Evaluation timed out"} (the same shape batch_eval
        records for timeouts).

    Raises:
        KeyboardInterrupt: re-raised after the worker has been terminated.
    """
    with mp.Pool(1) as pool:
        try:
            result = pool.apply_async(
                evaluate_single_sample,
                args=(curr_work, configs, dataset, run_dir),
            ).get(timeout=configs.timeout)
        except KeyboardInterrupt:
            print("\n [Terminate] Caught KeyboardInterrupt, terminating workers...")
            pool.terminate()
            pool.join()
            raise
        except mp.TimeoutError as e:
            print(
                f"[WARNING] Evaluation TIMED OUT for Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}\nException: {e}"
            )
            # Record a failed result rather than leaving `result` unbound.
            result = KernelExecResult(
                compiled=False,
                correctness=False,
                metadata={"error": "Evaluation timed out"},
                runtime=-1.0,
                runtime_stats={},
            )

    print(
        f"[Eval Result] Problem ID: {curr_work.problem_id}, Sample ID: {curr_work.sample_id}: {result}"
    )
    return result
def remove_cache_dir(cache_dir: str, run_name: str, problem_id, sample_id):
    """
    Remove the cached build folder of one sample so its next build starts clean.

    Useful after a timeout, failed build, etc. Only
    ``{cache_dir}/{run_name}/{problem_id}/{sample_id}`` is removed; the shared
    cache root and other samples' builds are left untouched. A missing folder
    is a no-op, and removal failures are reported as warnings rather than raised.
    """
    problem_cache_dir = os.path.join(
        cache_dir, run_name, f"{problem_id}", f"{sample_id}"
    )
    print(f"cache_dir to remove: {problem_cache_dir}")
    if not os.path.exists(problem_cache_dir):
        return
    try:
        # Without ignore_errors, failures surface here and are reported below.
        shutil.rmtree(problem_cache_dir)
        print(
            f"\n[INFO] Removed cached folder for Problem ID: {problem_id}, Sample ID: {sample_id}"
        )
    except OSError as e:
        print(f"\n[WARNING] Failed to remove cache directory {problem_cache_dir}: {str(e)}")
def batch_eval_modal(
    total_work: list[tuple[int, int]],
    config: EvalConfig,
    curr_level_dataset,
    run_dir: str,
    eval_file_path: str,
):
    """
    Evaluate (problem_id, sample_id) pairs on Modal GPUs in parallel batches.

    Work is processed in batches of ``config.num_gpu_devices``. Every sample in
    a batch is spawned as a remote ``ModalEvaluator`` call. All results are then
    collected and appended to ``eval_file_path`` from this process, one batch
    at a time.

    Samples whose kernel file is missing, or whose remote call raises, are
    recorded as failed ``KernelExecResult`` entries instead of aborting the run.

    Raises:
        ValueError: if ``config.num_gpu_devices`` is less than 1. A zero-sized
            batch would never consume ``total_work``.
    """
    batch_size = config.num_gpu_devices
    if batch_size < 1:
        raise ValueError(
            f"num_gpu_devices must be >= 1 for Modal batching, got {batch_size}"
        )

    print(f"[Modal] Starting batch evaluation on {config.gpu} GPUs")
    print(f"[Modal] Processing {len(total_work)} samples in parallel batches of {config.num_gpu_devices}")

    # Loop-invariant: the arch list depends only on config.gpu.
    gpu_arch = gpu_arch_mapping.get(config.gpu, ["Ada"])

    with app.run():
        # Override GPU if different from default in decorator
        # .with_options() overrides the decorator's parameters
        evaluator_cls = ModalEvaluator.with_options(gpu=config.gpu) if config.gpu != "A10G" else ModalEvaluator

        with tqdm(total=len(total_work), desc="Modal Evaluation Progress") as pbar:
            while len(total_work) > 0:
                curr_work_batch = total_work[:batch_size]
                total_work = total_work[batch_size:]
                print(f"\n[Modal Batch] Processing {len(curr_work_batch)} samples; {len(total_work)} remaining")
                start_time = time.time()

                # Prepare work items - fetch all data first
                work_items = []
                for problem_id, sample_id in curr_work_batch:
                    ref_arch_src = fetch_ref_arch_from_problem_id(
                        curr_level_dataset, problem_id, config.dataset_src
                    )
                    kernel_src = fetch_kernel_from_disk(run_dir, config.level, problem_id, sample_id)
                    if kernel_src is None:
                        print(f"[WARNING] Kernel not found for problem {problem_id} sample {sample_id}")
                        work_items.append(None)
                    else:
                        work_items.append({
                            'problem_id': problem_id,
                            'sample_id': sample_id,
                            'ref_arch_src': ref_arch_src,
                            'kernel_src': kernel_src,
                        })

                # Spawn all tasks in parallel
                # Modal assigns these to available containers
                # sometimes GPU mem state is corrupted so we will drain this container and find a new one with clean mem state.
                # GPU corruption is handled via stop_fetching_inputs() in evaluate_single_sample_modal
                futures = []
                for item in work_items:
                    if item is None:
                        futures.append(None)
                        continue
                    futures.append(
                        evaluator_cls().evaluate_single_sample_modal.spawn(
                            ref_arch_src=item['ref_arch_src'],
                            kernel_src=item['kernel_src'],
                            gpu_arch=gpu_arch,
                            num_correct_trials=config.num_correct_trials,
                            num_perf_trials=config.num_perf_trials,
                            measure_performance=config.measure_performance,
                            timing_method=config.timing_method,
                            verbose=config.verbose,
                            backend=config.backend,
                            precision=config.precision,
                        )
                    )

                # Collect results from all futures (same order as curr_work_batch)
                results = []
                for (problem_id, sample_id), future in zip(curr_work_batch, futures):
                    if future is None:
                        # No future was spawned (kernel file missing): record a failure
                        fail_result = KernelExecResult(
                            compiled=False,
                            correctness=False,
                            metadata={"error": "Future was None - evaluation did not complete"},
                            runtime=-1.0,
                            runtime_stats={},
                        )
                        results.append((problem_id, sample_id, fail_result))
                        continue
                    try:
                        result = future.get()
                        results.append((problem_id, sample_id, result))
                    except Exception as e:
                        # OUTER CATCH: Modal infrastructure or remote execution failures
                        # - GPU attachment failures after retries
                        # - Network/container issues
                        # - Import errors in the kernel (can't even start evaluation)
                        # - Any exception from future.get()
                        error_msg = str(e)
                        # Check if it's a GPU attachment failure that exhausted retries
                        if "GPU not attached" in error_msg or "CUDA is not available" in error_msg:
                            print(f"[ERROR] Modal GPU attachment FAILED after retries for Problem ID: {problem_id}, Sample ID: {sample_id}")
                            print(f" This is a Modal infrastructure issue. Sample will be recorded as failed.")
                        else:
                            print(f"[ERROR] Modal evaluation FAILED for Problem ID: {problem_id}, Sample ID: {sample_id}: {error_msg}")
                        # Create a failure result instead of None
                        fail_result = KernelExecResult(
                            compiled=False,
                            correctness=False,
                            metadata={"error": error_msg},
                            runtime=-1.0,
                            runtime_stats={},
                        )
                        results.append((problem_id, sample_id, fail_result))

                end_time = time.time()

                # Save results (single writer: this process)
                for problem_id, sample_id, result in results:
                    print("-" * 128)
                    print(f"[Eval Result] Problem ID: {problem_id}, Sample ID: {sample_id}")
                    print(result)
                    print(f"Adding Eval Result to file for problem {problem_id} sample {sample_id}")
                    add_to_eval_results_file(
                        problem_id, sample_id, result, eval_file_path
                    )

                print("-" * 128)
                print(f"[Modal Batch] Evaluation took {end_time - start_time:.2f} seconds")
                pbar.update(len(curr_work_batch))
def batch_eval(
    total_work: list[tuple[int, int]],
    config: EvalConfig,
    curr_level_dataset,
    run_dir: str,
    eval_file_path: str,
):
    """
    Batch evaluation across multiple GPUs (local or Modal)

    In "modal" eval_mode this delegates to batch_eval_modal. Otherwise,
    (problem_id, sample_id) pairs are processed in batches of
    ``config.num_gpu_devices``. Each sample in a batch runs in its own pool worker
    on device cuda:{i % batch_size}.

    The timeout budget (``config.timeout``) is shared by the whole batch: each
    result waits only for whatever is left of it since the batch started. If the
    batch did not finish building, consider trying again with a larger timeout.

    Samples that time out or fail are recorded as failed results, and their
    per-sample build cache is removed via remove_cache_dir. Results are appended
    to ``eval_file_path`` from this process after each batch.
    """
    # Use Modal-based evaluation if eval_mode is "modal"
    if config.eval_mode == "modal":
        return batch_eval_modal(total_work, config, curr_level_dataset, run_dir, eval_file_path)
    # Original local GPU evaluation
    # construct a list of work args
    batch_size = config.num_gpu_devices
    with tqdm(total=len(total_work), desc="Processing batches") as pbar:
        while len(total_work) > 0:
            curr_work_batch = total_work[:batch_size]
            total_work = total_work[batch_size:] # pop the first batch_size elements
            print(
                f"[Curr Batch] {len(curr_work_batch)} tasks over {config.num_gpu_devices} GPUs; [Total Work left] {len(total_work)}"
            )
            assert (
                len(curr_work_batch) <= batch_size
            ), f"Current batch size {len(curr_work_batch)} is greater than the number of GPUs {batch_size}"
            # One worker per GPU; leaving this block terminates any still-running workers.
            with mp.Pool(batch_size) as pool:
                # Each sample is pinned to its own device: cuda:{i % batch_size}
                work_args = [
                    (
                        WorkArgs(
                            problem_id=p_id,
                            sample_id=s_idx,
                            device=torch.device(f"cuda:{i%batch_size}"),
                        ),
                        config,
                        curr_level_dataset,
                        run_dir,
                    )
                    for i, (p_id, s_idx) in enumerate(curr_work_batch)
                ]
                start_time = time.time()
                async_results = []
                for work_arg in work_args:
                    async_results.append(
                        pool.apply_async(evaluate_single_sample, work_arg)
                    )
                # Collect results with a batch timeout
                # (later samples only get the time left over from earlier ones)
                results = []
                batch_timeout = config.timeout
                for i, async_result in enumerate(async_results):
                    problem_id, sample_id = curr_work_batch[i]
                    try:
                        elapsed_time = time.time() - start_time
                        remaining_time = max(0, batch_timeout - elapsed_time)
                        result = async_result.get(timeout=remaining_time)
                        results.append((problem_id, sample_id, result))
                    except mp.TimeoutError:
                        # OUTER CATCH: Evaluation exceeded timeout (config.timeout seconds)
                        # - Kernel hangs, infinite loops, very slow compilation
                        print(
                            f"[WARNING] Evaluation TIMED OUT for Problem ID: {problem_id}, Sample ID: {sample_id}"
                        )
                        fail_result = KernelExecResult(
                            compiled=False,
                            correctness=False,
                            metadata={"error": "Evaluation timed out"},
                            runtime=-1.0,
                            runtime_stats={},
                        )
                        results.append((problem_id, sample_id, fail_result))
                        # Drop the (possibly partial) build so the next attempt starts clean
                        remove_cache_dir(
                            config.kernel_eval_build_dir,
                            config.run_name,
                            problem_id,
                            sample_id,
                        )
                    except Exception as e:
                        # OUTER CATCH: Multiprocessing-level failures
                        # - Process crashes, pickling errors
                        # - Errors that escape the inner handler
                        error_msg = str(e)
                        print(
                            f"[ERROR] Evaluation FAILED for Problem ID: {problem_id}, Sample ID: {sample_id}: {error_msg}"
                        )
                        fail_result = KernelExecResult(
                            compiled=False,
                            correctness=False,
                            metadata={"error": error_msg},
                            runtime=-1.0,
                            runtime_stats={},
                        )
                        results.append((problem_id, sample_id, fail_result))
                        remove_cache_dir(
                            config.kernel_eval_build_dir,
                            config.run_name,
                            problem_id,
                            sample_id,
                        )
                end_time = time.time()
                # current batch summary
                for problem_id, sample_id, result in results:
                    print("-" * 128)
                    print(
                        f"[Eval Result] Problem ID: {problem_id}, Sample ID: {sample_id}"
                    )
                    print(result)
                    # add all the batch results here to avoid file race condition
                    print(
                        f"Adding Eval Result to file for problem {problem_id} sample {sample_id}"
                    )
                    add_to_eval_results_file(
                        problem_id, sample_id, result, eval_file_path
                    )
                print("-" * 128)
                print(
                    f"[Curr batch] Evaluation took {end_time - start_time:.2f} seconds"
                )
                pbar.update(len(curr_work_batch))
def check_if_eval_exists_local(
    problem_id: int, sample_id: int, eval_file_path: str
) -> bool:
    """
    Return True if ``eval_file_path`` already holds a result for this exact
    (problem_id, sample_id) pair.

    The results file maps str(problem_id) to a list of per-sample records, each
    with a "sample_id" field (as written by add_to_eval_results_file). Matching
    on sample_id, not just problem_id, lets a resumed run evaluate the remaining
    samples of a partially evaluated problem.
    """
    if not os.path.exists(eval_file_path):
        return False
    with open(eval_file_path, "r") as f:
        eval_results = json.load(f)
    return any(
        record.get("sample_id") == sample_id
        for record in eval_results.get(str(problem_id), [])
    )
def add_to_eval_results_file(
    problem_id: int, sample_id: int, eval_result: KernelExecResult, eval_file_path: str
):
    """
    Add one sample's evaluation result to the JSON eval results file.

    The file maps str(problem_id) to a list of per-sample records and is written
    with keys sorted numerically.

    The new content is first written to ``{eval_file_path}.tmp`` and then moved
    over the real file with ``os.replace``. An interrupted write (e.g. Ctrl-C
    mid-dump) therefore leaves the previous results intact instead of a
    truncated, unparseable file.

    This is a read-modify-write of the whole file and is not safe for concurrent
    writers.

    TODO: migrate database support
    """
    # Load existing results if file exists
    eval_results = defaultdict(list)
    if os.path.exists(eval_file_path):
        with open(eval_file_path, "r") as f:
            eval_results.update(json.load(f))

    # Add new result
    eval_results[str(problem_id)].append(
        {
            "sample_id": sample_id,
            "compiled": eval_result.compiled,
            "correctness": eval_result.correctness,
            "metadata": check_metadata_serializable_all_types(eval_result.metadata),
            "runtime": eval_result.runtime,
            "runtime_stats": eval_result.runtime_stats,
        }
    )

    # os.path.dirname is "" for a bare filename, and os.makedirs("") raises.
    parent_dir = os.path.dirname(eval_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Write updated results back to file (sorted by numeric key) via an atomic swap
    sorted_results = dict(sorted(eval_results.items(), key=lambda x: int(x[0])))
    tmp_path = f"{eval_file_path}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(sorted_results, f, indent=4)
    os.replace(tmp_path, eval_file_path)
def single_eval_example(
    config: EvalConfig,
    curr_level_dataset,
    run_dir: str,
    eval_file_path,
    problem_id: int = 1,
    sample_id: int = 0,
):
    """
    Debug helper: evaluate one sample on cuda:0 and record it if not already present.

    Args:
        config: evaluation settings.
        curr_level_dataset: dataset object used to look up the reference architecture.
        run_dir: directory holding the generated kernels.
        eval_file_path: JSON results file to append to.
        problem_id: problem to evaluate (defaults to 1, the previously hard-coded value).
        sample_id: sample to evaluate (defaults to 0, the previously hard-coded value).
    """
    device = torch.device("cuda:0")
    example_work = WorkArgs(problem_id=problem_id, sample_id=sample_id, device=device)
    # In-process alternative without timeout protection:
    # example_eval_result = evaluate_single_sample(example_work, config, curr_level_dataset, run_dir)
    example_eval_result = cuda_single_eval_wrapper(
        example_work, config, curr_level_dataset, run_dir
    )
    print(example_eval_result)
    if example_eval_result is None:
        # Nothing to record; add_to_eval_results_file expects a KernelExecResult.
        return
    if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path):
        add_to_eval_results_file(problem_id, sample_id, example_eval_result, eval_file_path)
def _select_problem_ids(all_problem_ids, subset, problem_ids=None):
    """
    Return the problem ids to evaluate, preserving dataset order.

    ``problem_ids`` (a single id or an iterable of ids), when non-empty, overrides
    ``subset``. Otherwise ``subset`` is an inclusive (start_id, end_id) range in
    which either bound may be None to leave that side unbounded.
    """
    if isinstance(problem_ids, int):
        problem_ids = [problem_ids]
    if problem_ids:
        wanted = {int(pid) for pid in problem_ids}
        selected = [pid for pid in all_problem_ids if pid in wanted]
        missing = sorted(wanted.difference(selected))
        if missing:
            print(f"Warning: requested problem_ids not found in dataset: {missing}")
        return selected

    start, end = subset
    if start is None and end is None:
        return all_problem_ids
    return [
        pid
        for pid in all_problem_ids
        if (start is None or pid >= start) and (end is None or pid <= end)
    ]


@pydra.main(base=EvalConfig)
def main(config: EvalConfig):
    """
    Batch Eval Samples from Particular Run
    Store Eval Results in specified eval results file

    Steps: apply backend-specific overrides, validate the local GPU setup (local
    mode), load the dataset, select problems (``problem_ids`` overrides
    ``subset``), and skip samples already in eval_results.json. Then optionally
    pre-build kernels on CPU (local mode), evaluate the remaining samples, and
    finally compute pass@k metrics.
    """
    print(f"Starting Batch Eval with config: {config}")

    # Handle backend-specific settings
    backend = config.backend.lower()
    # thunderkittens requires bf16 and H100 GPU
    if backend == "thunderkittens":
        config.precision = "bf16"
        config.gpu = "H100"
        print(f"[ThunderKittens] Auto-configured: precision=bf16, gpu=H100")

    # Check if CUDA is available (only for local mode)
    if config.eval_mode == "local":
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA device not available. Local evaluation requires GPU.")
        # set GPU arch to configure what target to build for
        set_gpu_arch(config.gpu_arch)
        assert (
            config.num_gpu_devices <= torch.cuda.device_count()
        ), f"Number of GPUs requested ({config.num_gpu_devices}) is greater than the number of available GPUs ({torch.cuda.device_count()})"
    else:
        print(f"[Modal] Using Modal for evaluation with GPU: {config.gpu}")

    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method("spawn")

    # Dataset Configurations - Unified loading
    dataset = construct_kernelbench_dataset(
        level=config.level,
        source=config.dataset_src,
        dataset_name=config.dataset_name,
    )
    all_problem_ids = dataset.get_problem_ids()

    problem_ids_to_run = _select_problem_ids(
        all_problem_ids, config.subset, config.problem_ids
    )
    if not problem_ids_to_run:
        if config.problem_ids:
            print(f"Warning: None of the requested problem_ids {config.problem_ids} were found")
        else:
            print(f"Warning: No problems found in subset range {config.subset}")

    print(
        f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
    )

    run_dir = os.path.join(config.runs_dir, config.run_name)
    eval_file_path = os.path.join(run_dir, "eval_results.json")

    # To Debug
    # single_eval_example(config, dataset, run_dir, eval_file_path)

    total_work = [
        (problem_id, sample_id)
        for problem_id in problem_ids_to_run
        for sample_id in range(config.num_samples_per_problem)
        if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path)
    ]
    print(
        f"Start evaluation on {len(total_work)} unevaluated samples"
        f" in range: {problem_ids_to_run}"
    )

    # Build Cache on CPU as that is faster (only for local mode)
    if config.build_cache and config.eval_mode == "local":
        compile.batch_compile(total_work, config.to_dict())

    batch_eval(total_work, config, dataset, run_dir, eval_file_path)

    # Calculate pass@k metrics
    calculate_pass_at_k(eval_file_path, config.pass_at_k_values)
def calc_pass_at_k(n, c, k):
    """
    Unbiased pass@k estimator: 1 - C(n - c, k) / C(n, k).

    The binomial ratio is evaluated as a running product rather than by forming
    the (potentially huge) coefficients.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@$k$
    :return: probability that at least one of k samples drawn without
             replacement from the n samples is correct
    """
    num_incorrect = n - c
    # Fewer than k incorrect samples: every size-k draw contains a correct one.
    if num_incorrect < k:
        return 1.0
    denominators = np.arange(num_incorrect + 1, n + 1)
    return 1.0 - np.prod(1.0 - k / denominators)
def calculate_pass_at_k(eval_file_path: str, k_values: list[int]) -> dict:
    """
    Calculate pass@k metrics from evaluation results.

    pass@k is the probability that at least one of k samples passes (is
    compiled and correct). Per problem it is computed with the unbiased
    estimator in ``calc_pass_at_k``: 1 - C(n - c, k) / C(n, k). Here n is the
    number of evaluated samples and c the number of correct ones.

    A problem with fewer than k samples is scored with k = n instead, and that
    value is stored under ``pass@{n}`` (not ``pass@{k}``).

    ``avg_pass@{k}`` is the mean over the problems that do have a ``pass@{k}``
    entry; their count is reported as ``metadata["pass@{k}_count"]``. It is None
    when no problem has such an entry, rather than silently treating those
    problems as 0.

    The results are also written to ``pass_at_k_results.json`` next to
    ``eval_file_path``.

    Args:
        eval_file_path: Path to evaluation results file
        k_values: List of k values to calculate pass@k for

    Returns:
        Dictionary with "averages", "metadata" and per-problem "problems"
        entries, or {} if the evaluation file does not exist.
    """
    if not os.path.exists(eval_file_path):
        print(
            f"[WARNING] Evaluation file {eval_file_path} does not exist. Cannot calculate pass@k."
        )
        return {}

    with open(eval_file_path, "r") as f:
        eval_results = json.load(f)

    # Calculate pass@k for each problem
    pass_at_k_results = {}
    for problem_id, results in eval_results.items():
        total_samples = len(results)
        correct_samples = sum(1 for r in results if r["correctness"] and r["compiled"])

        pass_at_k_metrics = {}
        for k in k_values:
            effective_k = k
            if k > total_samples:
                print(
                    f"[WARNING] k={k} is greater than total samples {total_samples} for problem {problem_id}. Using k={total_samples}."
                )
                effective_k = total_samples
            pass_at_k_metrics[f"pass@{effective_k}"] = float(
                calc_pass_at_k(total_samples, correct_samples, effective_k)
            )

        pass_at_k_results[problem_id] = {
            "total_samples": total_samples,
            "correct_samples": correct_samples,
            **pass_at_k_metrics,
        }

    # Average pass@k over the problems that actually report pass@k
    avg_pass_at_k = {}
    total_problems = len(pass_at_k_results)
    if total_problems > 0:
        for k in k_values:
            scores = [
                r[f"pass@{k}"] for r in pass_at_k_results.values() if f"pass@{k}" in r
            ]
            avg_pass_at_k[f"avg_pass@{k}"] = (
                float(sum(scores) / len(scores)) if scores else None
            )

    # Add metadata about the evaluation
    metadata = {
        "total_problems": total_problems,
        "problems_with_samples": len(
            [p for p, r in pass_at_k_results.items() if r["total_samples"] > 0]
        ),
        "total_evaluated_samples": sum(
            r["total_samples"] for r in pass_at_k_results.values()
        ),
        "total_correct_samples": sum(
            r["correct_samples"] for r in pass_at_k_results.values()
        ),
    }
    # Add pass@k metadata: how many problems contributed to each average
    for k in k_values:
        metadata[f"pass@{k}_count"] = sum(
            1 for r in pass_at_k_results.values() if f"pass@{k}" in r
        )

    # Construct the final result with averages, individual problem results, and metadata
    final_results = {
        "averages": avg_pass_at_k,
        "metadata": metadata,
        "problems": pass_at_k_results,
    }

    # Write pass@k results to file
    pass_at_k_file_path = os.path.join(
        os.path.dirname(eval_file_path), "pass_at_k_results.json"
    )
    with open(pass_at_k_file_path, "w") as f:
        json.dump(final_results, f, indent=2)

    # Print the average pass@k metrics
    print(f"Pass@k Correctness metrics calculated and saved to {pass_at_k_file_path}")
    print(f"Evaluation metadata: {metadata}")
    print(f"Average pass@k metrics: {avg_pass_at_k}")

    return final_results
# Script entry point. main is wrapped by @pydra.main, which constructs the
# EvalConfig (presumably from command-line overrides — see pydra docs).
if __name__ == "__main__":
    main()