|
1 | 1 | import argparse |
2 | | -import pandas as pd |
| 2 | + |
3 | 3 | import matplotlib.pyplot as plt |
| 4 | +import pandas as pd |
4 | 5 | import seaborn as sns |
5 | | -from typing import Dict, Tuple |
6 | | - |
7 | 6 |
|
8 | | -def plot_benchmark_results(results: Dict[str, Tuple[float, float]]): |
9 | | - """ |
10 | | - Plot the benchmark results using Seaborn. |
11 | | -
|
12 | | - :param results: Dictionary where the key is the model type and the value is a tuple (average inference time, throughput). |
13 | | - """ |
14 | | - plot_path = "./inference/plot.png" |
15 | | - |
16 | | - # Extract data from the results |
# Where plot_benchmark_results() writes the rendered figure.
PLOT_OUTPUT_PATH = "./inference/plot.png"
# Defaults for the CLI flags built in parse_arguments().
DEFAULT_IMAGE_PATH = "./inference/cat3.jpg"
DEFAULT_ONNX_PATH = "./models/model.onnx"
DEFAULT_OV_PATH = "./models/model.ov"
DEFAULT_TOPK = 5
# Accepted values for the --mode flag; "all" runs every backend.
INFERENCE_MODES = ["onnx", "ov", "cpu", "cuda", "tensorrt", "all"]
| 13 | + |
| 14 | + |
| 15 | +def _create_sorted_dataframe( |
| 16 | + data: dict[str, float], column_name: str, ascending: bool |
| 17 | +) -> pd.DataFrame: |
| 18 | + df = pd.DataFrame(list(data.items()), columns=["Model", column_name]) |
| 19 | + return df.sort_values(column_name, ascending=ascending) |
| 20 | + |
| 21 | + |
def _plot_bar_chart(
    ax,
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    xlabel: str,
    ylabel: str,
    title: str,
    palette: str,
    value_format: str,
):
    """Draw a horizontal bar chart on *ax* and annotate each bar with its value.

    :param ax: Matplotlib axes to draw on.
    :param data: Frame holding one row per bar.
    :param x_col: Column providing bar lengths (numeric values).
    :param y_col: Column providing bar categories (model names).
    :param xlabel: Label for the value axis.
    :param ylabel: Label for the category axis.
    :param title: Axes title.
    :param palette: Seaborn palette name.
    :param value_format: ``str.format`` template applied to each bar's value.
    """
    values = data[x_col]
    categories = data[y_col]
    sns.barplot(
        x=values,
        y=categories,
        # hue mirrors the category column so each bar gets its own palette
        # color without producing a redundant legend.
        hue=categories,
        palette=palette,
        ax=ax,
        legend=False,
    )
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Place the formatted value just to the right of each bar's end.
    for row, bar_value in enumerate(values):
        label = value_format.format(bar_value)
        ax.text(bar_value, row, label, color="black", ha="left", va="center")
| 40 | + |
| 41 | + |
def plot_benchmark_results(results: dict[str, tuple[float, float]]):
    """Plot the benchmark results as two side-by-side bar charts.

    The left chart shows average inference time (sorted fastest first); the
    right chart shows throughput (sorted highest first). The figure is saved
    to ``PLOT_OUTPUT_PATH`` and also shown interactively.

    :param results: Maps model type to (average inference time, throughput).
    """
    # Split the (time, throughput) pairs into one mapping per metric.
    times: dict[str, float] = {}
    throughputs: dict[str, float] = {}
    for model, (avg_time, samples_per_sec) in results.items():
        times[model] = avg_time
        throughputs[model] = samples_per_sec

    time_data = _create_sorted_dataframe(times, "Time", ascending=True)
    throughput_data = _create_sorted_dataframe(
        throughputs, "Throughput", ascending=False
    )

    _, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(20, 6))

    _plot_bar_chart(
        left_ax,
        time_data,
        "Time",
        "Model",
        xlabel="Average Inference Time (ms)",
        ylabel="Model Type",
        title="ResNet50 - Inference Benchmark Results",
        palette="rocket",
        value_format="{:.2f} ms",
    )
    _plot_bar_chart(
        right_ax,
        throughput_data,
        "Throughput",
        "Model",
        xlabel="Throughput (samples/sec)",
        ylabel="",
        title="ResNet50 - Throughput Benchmark Results",
        palette="viridis",
        value_format="{:.2f}",
    )

    plt.tight_layout()
    plt.savefig(PLOT_OUTPUT_PATH, bbox_inches="tight")
    plt.show()

    print(f"Plot saved to {PLOT_OUTPUT_PATH}")
68 | 81 |
|
69 | 82 |
|
def parse_arguments():
    """Build and parse the command-line arguments for the inference script.

    :return: argparse.Namespace with image_path, topk, onnx_path, ov_path,
        mode, and DEBUG attributes.
    """
    arg_parser = argparse.ArgumentParser(description="PyTorch Inference")

    # Input selection.
    arg_parser.add_argument(
        "--image_path",
        type=str,
        default=DEFAULT_IMAGE_PATH,
        help="Path to the image to predict",
    )
    arg_parser.add_argument(
        "--topk",
        type=int,
        default=DEFAULT_TOPK,
        help="Number of top predictions to show",
    )

    # Export destinations for the converted model formats.
    arg_parser.add_argument(
        "--onnx_path",
        type=str,
        default=DEFAULT_ONNX_PATH,
        help="Path where model in ONNX format will be exported",
    )
    arg_parser.add_argument(
        "--ov_path",
        type=str,
        default=DEFAULT_OV_PATH,
        help="Path where model in OpenVINO format will be exported",
    )

    # Execution backend and diagnostics.
    arg_parser.add_argument(
        "--mode",
        choices=INFERENCE_MODES,
        default="all",
        help="Mode for exporting and running the model",
    )
    arg_parser.add_argument(
        "-D",
        "--DEBUG",
        action="store_true",
        help="Enable debug mode",
    )

    return arg_parser.parse_args()
0 commit comments