-
Notifications
You must be signed in to change notification settings - Fork 34
Expand file tree
/
Copy pathrun_benchmark.py
More file actions
570 lines (501 loc) · 22.1 KB
/
run_benchmark.py
File metadata and controls
570 lines (501 loc) · 22.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
"""This script runs microbenchmarks and collects metrics.
Sample usage (on TPU vm):
$ python run_benchmark.py --config=configs/benchmark_collectives.yaml
"""
import argparse
import ast
import copy
import datetime
import importlib
import inspect
import itertools
import json
import os
import random
import string
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, List, Optional, Tuple

import jax
import pandas as pd
import ray
import yaml

from benchmark_utils import maybe_write_metrics_file, rename_xla_dump, MetricsStatistics
# Benchmark registries. Each map sends a benchmark name (the
# "benchmark_name" used in a config) to "<module>.<function>".
# get_benchmark_functions imports <module>, fetches <function>, and also
# requires a companion <function>_calculate_metrics in the same module.
# All maps are merged into the single flat BENCHMARK_MAP, so a name repeated
# in a later map would silently replace the earlier entry.
COLLECTIVE_BENCHMARK_MAP = {
    "all_gather": "benchmark_collectives.all_gather_benchmark",
    "psum": "benchmark_collectives.psum_benchmark",
    "psum_scatter": "benchmark_collectives.psum_scatter_benchmark",
    "all_to_all": "benchmark_collectives.all_to_all_benchmark",
    "ppermute": "benchmark_collectives.ppermute_benchmark",
    "send_recv": "benchmark_send_recv.send_recv_benchmark",
}
MATMUL_BENCHMARK_MAP = {
    "naive_matmul": "benchmark_matmul.naive_matmul",
    "single_host_naive_matmul": "benchmark_matmul.single_host_naive_matmul",
    "multilayer_collective_matmul": ("benchmark_matmul.multilayer_collective_matmul"),
    "collective_matmul_one_direction": (
        "benchmark_matmul.collective_matmul_one_direction"
    ),
    "collective_matmul_two_directions": (
        "benchmark_matmul.collective_matmul_two_directions"
    ),
}
CONVOLUTION_BENCHMARK_MAP = {
    "numpy_convolve": "benchmark_convolution.numpy_convolve",
    "scipy_signal_convolve": "benchmark_convolution.scipy_signal_convolve",
    "scipy_signal_convolve2d": "benchmark_convolution.scipy_signal_convolve2d",
    "lax_conv_general_dilated": ("benchmark_convolution.lax_conv_general_dilated"),
}
ATTENTION_BENCHMARK_MAP = {
    "tokamax_splash_attention": "benchmark_attention.tokamax_splash_attention_benchmark",
}
HBM_BENCHMARK_MAP = {
    "single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy",
    # NOTE(review): the key says "device" while the function is named
    # "multiple_devices_hbm_copy"; confirm this matches benchmark_hbm.
    "multiple_device_hbm_copy": "benchmark_hbm.multiple_devices_hbm_copy",
}
COMPUTE_BENCHMARK_MAP = {
    "gemm_simple": "benchmark_gemm.gemm_simple",
    "gemm_simple_with_dtype": "benchmark_gemm.gemm_simple_with_dtype",
    "gemm_multiple_run": "benchmark_gemm.gemm_multiple_run",
    "gemm_throttling": "benchmark_gemm_throttling.gemm_throttling",
    "gemm": "benchmark_gemm.gemm",
    "gemm_accum": "benchmark_gemm.gemm_accum",
    "gemm_multiple_devices": "benchmark_gemm.gemm_multiple_devices",
    "quantization": "benchmark_compute.quantization",
    "transpose_quantization": "benchmark_compute.transpose_quantization",
    "quantization_static_scaling": (
        "benchmark_compute.quantization_static_scaling"
    ),
    "transpose_quantization_static_scaling": (
        "benchmark_compute.transpose_quantization_static_scaling"
    ),
    "swiglu_fwd": "benchmark_compute.swiglu_fwd",
    "swiglu_bwd": "benchmark_compute.swiglu_bwd",
    "rmsnorm_fwd": "benchmark_compute.rmsnorm_fwd",
    "rmsnorm_bwd": "benchmark_compute.rmsnorm_bwd",
    "add": "benchmark_compute.add",
    "gemm_fp8_rowwise": "benchmark_gemm_numerics.gemm_fp8_rowwise",
    "gemm_fp8_b128_fp32": "benchmark_gemm_numerics.gemm_fp8_b128_fp32",
    "gemm_fp8_rowwise_static_scaling": (
        "benchmark_gemm_numerics.gemm_fp8_rowwise_static_scaling"
    ),
    "gemm_fp8_b128_fp32_static_scaling": (
        "benchmark_gemm_numerics.gemm_fp8_b128_fp32_static_scaling"
    ),
    "gemm_mxfp8_b32": "benchmark_gemm_numerics.gemm_mxfp8_b32",
    "gemm_mxfp8_b32_static_scaling": (
        "benchmark_gemm_numerics.gemm_mxfp8_b32_static_scaling"
    ),
    "gemm_fp8_rowwise_w_dequantize": (
        "benchmark_gemm_numerics.gemm_fp8_rowwise_w_dequantize"
    ),
    # Inference kernels are registered under an "inference_" prefix: plain
    # names such as "add" are already taken above, and all maps share one
    # namespace once merged.
    "inference_add": "benchmark_inference_compute.add",
    "inference_rmsnorm": "benchmark_inference_compute.rmsnorm",
    "inference_silu_mul": "benchmark_inference_compute.silu_mul",
    "inference_sigmoid": "benchmark_inference_compute.sigmoid",
}
HOST_DEVICE_BENCHMARK_MAP = {
    "host_device": "benchmark_host_device.benchmark_host_device",
}
# One flat registry consulted by get_benchmark_functions. The category maps
# are unpacked in order, so on a duplicate name the later map wins.
BENCHMARK_MAP = {
    **COLLECTIVE_BENCHMARK_MAP,
    **MATMUL_BENCHMARK_MAP,
    **CONVOLUTION_BENCHMARK_MAP,
    **ATTENTION_BENCHMARK_MAP,
    **HBM_BENCHMARK_MAP,
    **COMPUTE_BENCHMARK_MAP,
    **HOST_DEVICE_BENCHMARK_MAP,
}
# Mapping from the dtype strings accepted in benchmark configs to JAX dtypes.
# preprocess_benchmark_param rejects any "dtype" string not listed here.
# "float8" is kept as a short alias of float8_e4m3fn for existing configs.
dtype_mapping = {
    "bfloat16": jax.numpy.bfloat16,
    "float16": jax.numpy.float16,
    "float32": jax.numpy.float32,
    "int8": jax.numpy.int8,
    "int32": jax.numpy.int32,
    "float8": jax.numpy.float8_e4m3fn,
    "float8_e4m3fn": jax.numpy.float8_e4m3fn,
    "float8_e5m2": jax.numpy.float8_e5m2,
    # Add other dtypes as needed
}
# Always dump HLOs into this scratch directory; run_single_benchmark hands it
# to rename_xla_dump after each run.
TMP_XLA_DUMP_DIR = "/tmp/microbenchmarks/hlo_graphs"
# Append rather than overwrite, so XLA flags supplied by the user survive.
# The dump flag goes last — presumably so it wins over any earlier
# --xla_dump_to (TODO confirm XLA's last-flag-wins parsing).
_user_xla_flags = os.environ.get("XLA_FLAGS", "").strip()
os.environ["XLA_FLAGS"] = (
    f"{_user_xla_flags} --xla_dump_to={TMP_XLA_DUMP_DIR}".strip()
)
def get_benchmark_config(config_path: str) -> Dict[str, Any]:
    """Load the benchmark configuration from a YAML file.

    Args:
      config_path: Path to the YAML configuration file.

    Returns:
      The parsed top-level mapping. An empty file yields an empty dict, so
      callers can still use `.get(...)` and report missing keys cleanly.

    Raises:
      ValueError: if the top-level YAML value is not a mapping.
    """
    # safe_load only constructs plain Python objects, never arbitrary classes.
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Benchmark config {config_path} must be a YAML mapping, got"
            f" {type(config).__name__}."
        )
    return config
# Dynamically load the benchmark functions.
def get_benchmark_functions(
    benchmark_name: str,
) -> Tuple[Callable[..., Any], Callable[..., Any]]:
    """Look up a registered benchmark and return its two entry points.

    BENCHMARK_MAP stores each benchmark as "<module>.<function>". The module
    is imported, then both <function> and <function>_calculate_metrics are
    read from it.

    Args:
      benchmark_name: A key of BENCHMARK_MAP.

    Returns:
      A (benchmark_func, calculate_metrics_func) tuple.

    Raises:
      ValueError: when the name is not registered, the module or function
        cannot be imported, or the metrics function is missing.
    """
    if benchmark_name not in BENCHMARK_MAP:
        raise ValueError(f"Benchmark {benchmark_name} is not defined in the map.")

    dotted_path = BENCHMARK_MAP[benchmark_name]
    module_path, func_name = dotted_path.rsplit(".", 1)

    try:
        benchmark_module = importlib.import_module(module_path)
        run_fn = getattr(benchmark_module, func_name)
    except ModuleNotFoundError as import_err:
        raise ValueError(
            f"Unable to import {module_path}.{func_name}. "
            f"ModuleNotFoundError {import_err}."
        ) from import_err
    except AttributeError as attr_err:
        raise ValueError(
            f"Unable to import {module_path}.{func_name}. "
            f"AttributeError {attr_err}."
        ) from attr_err

    # By convention every benchmark ships a companion metrics function.
    metrics_fn_name = f"{func_name}_calculate_metrics"
    try:
        metrics_fn = getattr(benchmark_module, metrics_fn_name)
    except AttributeError:
        raise ValueError(
            f"Calculate metrics function for {benchmark_name} not found."
        ) from None

    return run_fn, metrics_fn
def preprocess_benchmark_param(
    benchmark_param: Dict[str, Any], trace_dir: Optional[str] = None
) -> Dict[str, Any]:
    """Prepare one parameter set for a benchmark call.

    Works on a shallow copy, so the caller's dict is never modified. That
    makes it safe to preprocess the same dict more than once (e.g. when a
    YAML alias reuses it).

    - "dtype", if present, is converted from its config string to the JAX
      dtype via `dtype_mapping`. A value that is already one of the mapped
      dtypes is kept as-is.
    - Any string value "SAME_AS_<key>" is replaced by the value of <key>.
      Chains such as {"a": "SAME_AS_b", "b": "SAME_AS_c", "c": 1} are
      followed to the end, independent of key order.
    - "trace_dir" is set to `trace_dir`.

    Args:
      benchmark_param: One parameter set from the config.
      trace_dir: Value stored under the "trace_dir" key (may be None).

    Returns:
      A new, preprocessed parameter dict.

    Raises:
      ValueError: on an unsupported dtype string, a "SAME_AS_" reference to a
        missing key, or a cycle of "SAME_AS_" references.
    """
    same_as_prefix = "SAME_AS_"
    params = dict(benchmark_param)

    if "dtype" in params:
        dtype_value = params["dtype"]
        if dtype_value in dtype_mapping:
            params["dtype"] = dtype_mapping[dtype_value]
        elif not any(dtype_value is mapped for mapped in dtype_mapping.values()):
            raise ValueError(f"Unsupported dtype: {dtype_value}")

    # Resolve "SAME_AS_" references. Only existing keys are reassigned, so
    # iterating over the dict while updating values is safe.
    for key in params:
        value = params[key]
        visited = {key}
        while isinstance(value, str) and value.startswith(same_as_prefix):
            referenced_key = value[len(same_as_prefix):]
            if referenced_key not in params:
                raise ValueError(
                    f"Parameter {referenced_key} not found in the benchmark_param."
                )
            if referenced_key in visited:
                raise ValueError(
                    f"Circular SAME_AS_ reference involving parameter {key}."
                )
            visited.add(referenced_key)
            value = params[referenced_key]
        params[key] = value

    params["trace_dir"] = trace_dir
    return params
def _expand_sweep_range(key: str, spec: Dict[str, Any]) -> List[Any]:
    """Expand one {"start", "end", "multiplier" | "increase_by"} range spec."""
    start = spec.get("start")
    end = spec.get("end")
    multiplier = spec.get("multiplier", None)
    increase_by = spec.get("increase_by", None)
    if start is None or end is None:
        raise ValueError(
            f"Sweep range for '{key}' must define both 'start' and 'end'."
        )
    if not multiplier and not increase_by:
        raise ValueError(
            "In sweep mode, user must provide either multiplier or"
            " increase_by value."
        )

    values = []
    current_value = start
    while current_value <= end:
        values.append(current_value)
        # A truthy multiplier takes precedence over increase_by.
        if multiplier:
            next_value = current_value * multiplier
        else:
            next_value = current_value + increase_by
        # Guard against ranges that never reach `end` (multiplier <= 1, a
        # non-positive start with a multiplier, or increase_by <= 0).
        if next_value <= current_value:
            raise ValueError(
                f"Sweep range for '{key}' does not increase from"
                f" {current_value}; use a multiplier > 1 with a positive start,"
                " or a positive increase_by."
            )
        current_value = next_value
    return values


def generate_benchmark_params_sweeping(
    benchmark_sweep_params: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Expand sweep specifications into concrete benchmark parameter sets.

    Each entry of `benchmark_sweep_params` is a dict whose values become
    candidate lists:
      - a list is used as-is;
      - a dict {"start", "end", "multiplier" | "increase_by"} becomes
        start, start*multiplier, ... (or start, start+increase_by, ...) up to
        and including "end";
      - any other value becomes a one-element list.
    A trailing "_range" and then a trailing "_list" are stripped from each
    key. Every combination (cartesian product) of the candidate lists is
    emitted for each entry, and the results of all entries are concatenated.

    Args:
      benchmark_sweep_params: The "benchmark_sweep_params" list from a config.

    Returns:
      A list of parameter dicts, one per combination.

    Raises:
      ValueError: if a range lacks "start"/"end", has neither a usable
        "multiplier" nor "increase_by", or would never advance toward "end"
        (which would otherwise loop forever).
    """
    generated_params = []
    for sweep_params in benchmark_sweep_params:
        param_sets = {}
        for key, value in sweep_params.items():
            if key.endswith("_range"):
                key = key[: -len("_range")]
            if key.endswith("_list"):
                key = key[: -len("_list")]
            if isinstance(value, list):
                param_sets[key] = value
            elif isinstance(value, dict):
                param_sets[key] = _expand_sweep_range(key, value)
            else:
                # A plain value is swept over a single point.
                param_sets[key] = [value]

        param_names = list(param_sets)
        generated_params.extend(
            dict(zip(param_names, values))
            for values in itertools.product(
                *(param_sets[name] for name in param_names)
            )
        )
    return generated_params
def write_to_csv(csv_path: str, calculate_metrics_results: List[Dict[str, Any]]):
    """Write benchmark results to a tab-separated file (TSV, despite the name).

    Each entry of `calculate_metrics_results` is a dict holding the
    'metadata' and 'metrics' of one benchmark run. Each entry becomes one row:
    nested dicts are flattened (only leaf keys are kept, so a leaf key that
    repeats across sub-dicts keeps the last value seen), and string-encoded
    Python literals such as "[1, 2, 3]" are decoded. When present,
    'ici_average_time_ms_list' gets extra summary-statistic columns and is
    stored as a JSON string. All rows are concatenated and written with a tab
    separator.

    Args:
      csv_path: The path to the output file.
      calculate_metrics_results: A list of dictionaries with benchmark results.

    Raises:
      ValueError: if no results are given or the first result is not a dict.
    """
    if not calculate_metrics_results:
        raise ValueError("0 metrics results are collected.")
    if not isinstance(calculate_metrics_results[0], dict):
        raise ValueError("metrics result is not a dict.")

    def flatten_dict(current_dict: Dict) -> Dict:
        """Recursively flattens a nested dictionary into its leaf keys."""
        output_dict = {}
        for key, val in current_dict.items():
            if isinstance(val, dict):
                output_dict.update(flatten_dict(val))
            else:
                # Try to evaluate string-formatted literals (e.g., "[1, 2, 3]").
                try:
                    output_dict[key] = ast.literal_eval(val)
                except (ValueError, SyntaxError, TypeError):
                    # Not a string literal (or not a string at all): keep it.
                    output_dict[key] = val
        return output_dict

    def convert_dict_to_df(target_dict: Dict) -> pd.DataFrame:
        """Converts a single benchmark result dictionary to a one-row DataFrame."""
        flattened_dict = flatten_dict(target_dict)
        # This section is specific to collective benchmarks that produce
        # 'ici_average_time_ms_list'.
        if "ici_average_time_ms_list" in flattened_dict:
            # Calculate statistics for the timing list.
            ici_average_time_ms_statistics = MetricsStatistics(
                metrics_list=flattened_dict["ici_average_time_ms_list"],
                metrics_name="ici_average_time_ms",
            ).statistics
            for key, val in ici_average_time_ms_statistics.items():
                flattened_dict["ici_average_time_ms_" + key] = val
            # Convert list to JSON string for CSV storage.
            flattened_dict["ici_average_time_ms_list"] = json.dumps(
                flattened_dict["ici_average_time_ms_list"]
            )
        # Build from a list of records so that list/tuple values (e.g. decoded
        # "[2, 3]" shapes) land in a single cell; `pd.DataFrame(d, index=[0])`
        # would instead treat them as columns and fail on a length mismatch.
        return pd.DataFrame([flattened_dict])

    # TODO(hylin2002@)
    # This is a temporary workaround to generate a properly formatted CSV file for the output metrics.
    # We should revert this PR and refactor the code such that metrics object is a flatten dict that can be easily exported as a CSV.
    # For other information that requires nested structures, we should serialize it into a json file.
    df_list = [convert_dict_to_df(each) for each in calculate_metrics_results]
    df = pd.concat(df_list, ignore_index=True)
    df.to_csv(csv_path, index=False, sep="\t")
    print(f"Metrics written to CSV at {csv_path}.")
def _utc_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string ending in "Z"."""
    # isoformat() of an aware UTC datetime already ends in "+00:00"; appending
    # "Z" would give a malformed "...+00:00Z", so swap the offset for "Z".
    return (
        datetime.datetime.now(tz=datetime.timezone.utc)
        .isoformat()
        .replace("+00:00", "Z")
    )


def run_single_benchmark(benchmark_config: Dict[str, Any], output_path: str):
    """Run one benchmark entry of the config, one parameter set at a time.

    Parameter sets are the explicit "benchmark_params" plus any generated from
    "benchmark_sweep_params". A parameter set whose benchmark call raises is
    reported and skipped. Metrics are optionally written via
    maybe_write_metrics_file (when "xlml_metrics_dir" is set) and to a TSV
    inside the "csv_path" directory.

    Args:
      benchmark_config: One entry of the config's "benchmarks" list. It is not
        modified.
      output_path: If non-empty, trace and HLO dump directories are placed
        under `<output_path>/<benchmark_name>/` instead of the config values.

    Raises:
      ValueError: if "benchmark_name" is missing or cannot be resolved, or
        (from write_to_csv) if "csv_path" is set but no parameter set
        succeeded.
    """
    benchmark_name = benchmark_config.get("benchmark_name")
    if not benchmark_name:
        raise ValueError("Each benchmark must have a 'benchmark_name'.")

    # Build a new list so extending it with sweep params leaves the config as-is.
    benchmark_params = list(benchmark_config.get("benchmark_params") or [])
    benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {})
    if benchmark_sweep_params:
        benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params)

    csv_path = benchmark_config.get("csv_path")
    trace_dir = benchmark_config.get("trace_dir")
    xlml_metrics_dir = benchmark_config.get("xlml_metrics_dir")
    xla_dump_dir = benchmark_config.get("xla_dump_dir")
    if output_path:
        trace_dir = os.path.join(output_path, benchmark_name, "trace")
        xla_dump_dir = os.path.join(output_path, benchmark_name, "hlo_graphs")

    # Inject num_runs from the config into parameter sets that do not set it,
    # creating new dicts instead of mutating the config's own.
    global_num_runs = benchmark_config.get("num_runs")
    if global_num_runs is not None:
        benchmark_params = [
            param if "num_runs" in param else {**param, "num_runs": global_num_runs}
            for param in benchmark_params
        ]

    benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name)
    # Only result keys named in the metrics function's signature are forwarded.
    calculate_metrics_params = inspect.signature(calculate_metrics_func).parameters
    # Run-control params consumed by the benchmark, not by the metrics function.
    benchmark_params_to_filter = ("num_runs", "trace_dir")

    print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n")

    calculate_metrics_results = []
    for idx, benchmark_param in enumerate(benchmark_params):
        # Snapshot taken before preprocessing; passed to rename_xla_dump.
        original_benchmark_param = copy.deepcopy(benchmark_param)
        # With no trace_dir configured there is nothing to join onto.
        param_trace_dir = (
            os.path.join(trace_dir, f"benchmark_{idx}")
            if trace_dir is not None
            else None
        )
        benchmark_param = preprocess_benchmark_param(
            benchmark_param, trace_dir=param_trace_dir
        )
        print(f"Running benchmark: {benchmark_name} with params: {benchmark_param}")
        test_start_time = _utc_timestamp()
        try:
            benchmark_results = benchmark_func(**benchmark_param)
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: report this parameter set and continue with the rest.
            print(f"Benchmark func failed: {e}")
            continue
        test_end_time = _utc_timestamp()

        # Post process the xla dump.
        xla_output = None
        if xla_dump_dir:
            xla_output = rename_xla_dump(
                tmp_xla_dump_dir=TMP_XLA_DUMP_DIR,
                dest_xla_dump_dir=xla_dump_dir,
                benchmark_name=benchmark_name,
                benchmark_param=original_benchmark_param,
            )
        benchmark_results["xla_output"] = xla_output

        filtered_benchmark_results = {
            key: value
            for key, value in benchmark_results.items()
            if key in calculate_metrics_params
        }
        filtered_benchmark_param = {
            key: value
            for key, value in benchmark_param.items()
            if key not in benchmark_params_to_filter
        }
        metadata, metrics = calculate_metrics_func(
            **filtered_benchmark_param, **filtered_benchmark_results
        )
        if xlml_metrics_dir:
            maybe_write_metrics_file(
                xlml_metrics_dir,
                metrics,
                metadata,
                benchmark_name,
                test_start_time,
                test_end_time,
            )
        calculate_metrics_results.append({"metadata": metadata, "metrics": metrics})

    # Dump metrics to file.
    if csv_path:
        os.makedirs(csv_path, exist_ok=True)
        test_name = f"t_{benchmark_name}_" + "".join(
            random.choices(string.ascii_uppercase + string.digits, k=10)
        )
        write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)
def main(args):
    """Load the YAML config and run every benchmark listed in it.

    Args:
      args: Parsed CLI namespace with `config`, `multithreaded` and
        `output_path` attributes.

    Raises:
      ValueError: if the config has no non-empty "benchmarks" list.
    """
    config_path = args.config
    multithreaded = args.multithreaded
    output_path = args.output_path

    config = get_benchmark_config(config_path)
    benchmarks = config.get("benchmarks")
    if not benchmarks or not isinstance(benchmarks, list):
        raise ValueError("Configuration must contain a 'benchmarks' list.")

    # Clear leftover files in the HLO scratch dir — presumably so that
    # rename_xla_dump only sees dumps produced by this run.
    if os.path.exists(TMP_XLA_DUMP_DIR):
        for filename in os.listdir(TMP_XLA_DUMP_DIR):
            file_path = os.path.join(TMP_XLA_DUMP_DIR, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)

    if multithreaded:
        # The cluster address belongs to ray.init itself; RuntimeEnv only
        # describes the workers' environment.
        ray.init(
            address="ray://tpu-ray-cluster-head-svc:10001",
            runtime_env=ray.runtime_env.RuntimeEnv(
                env_vars={
                    "XLA_IR_DEBUG": "1",
                    "XLA_HLO_DEBUG": "1",
                    "PJRT_DEVICE": "TPU",
                    # "LIBTPU_INIT_ARGS": "--xla_tpu_scoped_vmem_limit_kib=25602",
                },
            ),
        )
        print(ray.available_resources())
        for benchmark_config in benchmarks:
            run_benchmark_multithreaded(benchmark_config, output_path)
    else:
        for benchmark_config in benchmarks:
            run_single_benchmark(benchmark_config, output_path)
def run_benchmark_multithreaded(benchmark_config, output_path):
    """Run every parameter set of one benchmark concurrently in a thread pool.

    Expects ray.init to have been called already: the pool size is derived
    from the cluster's "TPU" resources. Parameter sets are the explicit
    "benchmark_params" plus any generated from "benchmark_sweep_params". A
    parameter set whose benchmark call raises is reported and skipped. Metrics
    are written to a TSV inside "csv_path", or inside
    `<output_path>/<benchmark_name>` when `output_path` is set.

    Args:
      benchmark_config: One entry of the config's "benchmarks" list. It is not
        modified.
      output_path: If non-empty, overrides the TSV output directory.

    Raises:
      ValueError: if "benchmark_name" is missing or cannot be resolved.
      RuntimeError: if the Ray cluster reports fewer than 4 "TPU" resources.
    """
    benchmark_name = benchmark_config.get("benchmark_name")
    if not benchmark_name:
        raise ValueError("Each benchmark must have a 'benchmark_name'.")

    # Build a new list so extending it with sweep params leaves the config as-is.
    benchmark_params = list(benchmark_config.get("benchmark_params") or [])
    benchmark_sweep_params = benchmark_config.get("benchmark_sweep_params", {})
    if benchmark_sweep_params:
        benchmark_params += generate_benchmark_params_sweeping(benchmark_sweep_params)

    csv_path = benchmark_config.get("csv_path")
    if output_path:
        csv_path = os.path.join(output_path, benchmark_name)

    # Inject num_runs from the config into parameter sets that do not set it,
    # creating new dicts instead of mutating the config's own.
    global_num_runs = benchmark_config.get("num_runs")
    if global_num_runs is not None:
        benchmark_params = [
            param if "num_runs" in param else {**param, "num_runs": global_num_runs}
            for param in benchmark_params
        ]

    benchmark_func, calculate_metrics_func = get_benchmark_functions(benchmark_name)
    # Only result keys named in the metrics function's signature are forwarded.
    calculate_metrics_params = inspect.signature(calculate_metrics_func).parameters
    # Run-control params consumed by the benchmark, not by the metrics function
    # (the same keys the sequential runner strips).
    benchmark_params_to_filter = ("num_runs", "trace_dir")

    print(f"\n{'=' * 30}Starting benchmark '{benchmark_name}'{'=' * 30}\n")

    # Base name of the output TSV for this benchmark.
    test_name = f"t_{benchmark_name}_" + "".join(
        random.choices(string.ascii_uppercase + string.digits, k=10)
    )
    preprocessed_benchmark_params = [
        preprocess_benchmark_param(benchmark_param, trace_dir=None)
        for benchmark_param in benchmark_params
    ]

    # NOTE(review): assumes 4 "TPU" resources per host — confirm for the
    # target topology.
    num_hosts = int(ray.available_resources().get("TPU", 0)) // 4
    print(f"Num hosts detected: {num_hosts}")
    if num_hosts < 1:
        raise RuntimeError(
            "No TPU hosts detected in the Ray cluster; cannot size the thread pool."
        )

    calculate_metrics_results = []
    with ThreadPoolExecutor(max_workers=num_hosts) as executor:
        submitted = [
            (executor.submit(benchmark_func, **benchmark_param), benchmark_param)
            for benchmark_param in preprocessed_benchmark_params
        ]
        # Collect in submission order (not completion order) so TSV rows
        # follow the order of the parameter sets.
        for future, benchmark_param in submitted:
            try:
                benchmark_results = future.result()
            except Exception as e:  # pylint: disable=broad-except
                # Best effort: report this parameter set and keep the others.
                print(f"Benchmark func failed: {e}")
                continue
            filtered_benchmark_results = {
                key: value
                for key, value in benchmark_results.items()
                if key in calculate_metrics_params
            }
            filtered_benchmark_param = {
                key: value
                for key, value in benchmark_param.items()
                if key not in benchmark_params_to_filter
            }
            metadata, metrics = calculate_metrics_func(
                **filtered_benchmark_param, **filtered_benchmark_results
            )
            calculate_metrics_results.append({"metadata": metadata, "metrics": metrics})

    if csv_path:
        os.makedirs(csv_path, exist_ok=True)
        write_to_csv(f"{csv_path}/{test_name}.tsv", calculate_metrics_results)
def parse_bool_flag(value: str) -> bool:
    """Parse a command-line boolean such as "true", "False", "1" or "no".

    argparse's `type=bool` treats every non-empty string (including "False")
    as True, so boolean options need an explicit parser. The empty string maps
    to False, as `bool("")` did.

    Raises:
      argparse.ArgumentTypeError: if `value` is not a recognized boolean.
    """
    normalized = str(value).strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run microbenchmarks and collect metrics."
    )
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the YAML configuration file.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="",
        help=(
            "Root directory for per-benchmark outputs; when set, it overrides"
            " the output locations from the config."
        ),
    )
    parser.add_argument(
        "--multithreaded",
        type=parse_bool_flag,
        nargs="?",
        const=True,
        default=False,
        help=(
            "Run each benchmark's parameter sets concurrently on a Ray cluster."
            " Accepts an optional boolean value; a bare flag means true."
        ),
    )
    args = parser.parse_args()
    main(args)