utils.py
import argparse

import onnxruntime as ort
import openvino as ov
import torch

from src.model import ModelLoader
from src.onnx_exporter import ONNXExporter
from src.ov_exporter import OVExporter


def export_onnx_model(
    onnx_path: str, model_loader: ModelLoader, device: torch.device
) -> None:
    """Export the loaded PyTorch model to ONNX format at onnx_path."""
    onnx_exporter = ONNXExporter(model_loader.model, device, onnx_path)
    onnx_exporter.export_model()


def init_onnx_model(
    onnx_path: str, model_loader: ModelLoader, device: torch.device
) -> ort.InferenceSession:
    """Export the model to ONNX and return a CPU ONNX Runtime session for it."""
    export_onnx_model(onnx_path=onnx_path, model_loader=model_loader, device=device)
    return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])


def init_ov_model(onnx_path: str) -> ov.CompiledModel:
    """Compile the ONNX model at onnx_path into an OpenVINO CompiledModel."""
    ov_exporter = OVExporter(onnx_path)
    return ov_exporter.export_model()


def init_cuda_model(
    model_loader: ModelLoader, device: torch.device, dtype: torch.dtype
) -> torch.nn.Module:
    """Move the model to the target device and, on CUDA, JIT-trace it."""
    cuda_model = model_loader.model.to(device)
    # Compare device.type rather than the device object itself, so that
    # devices such as torch.device("cuda:0") are also recognized.
    if device.type == "cuda":
        # Note: dtype is accepted for the caller's convenience but the trace
        # input below uses torch.randn's default float32.
        cuda_model = torch.jit.trace(
            cuda_model, [torch.randn((1, 3, 224, 224)).to(device)]
        )
    return cuda_model


def parse_arguments() -> argparse.Namespace:
    """Parse command-line arguments for the inference script."""
    parser = argparse.ArgumentParser(description="PyTorch Inference")
    parser.add_argument(
        "--image_path",
        type=str,
        default="./inference/cat3.jpg",
        help="Path to the image to predict",
    )
    parser.add_argument(
        "--topk", type=int, default=5, help="Number of top predictions to show"
    )
    parser.add_argument(
        "--onnx_path",
        type=str,
        default="./inference/model.onnx",
        help="Path where the model will be exported in ONNX format",
    )
    parser.add_argument(
        "--mode",
        choices=["onnx", "ov", "cuda", "all"],
        default="all",
        help="Mode for exporting and running the model. Choices are: onnx, ov, cuda, or all.",
    )
    return parser.parse_args()
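
Below is a minimal usage sketch showing one way a driver script might wire these helpers together. The main() function, the print statements, and the no-argument ModelLoader() construction are assumptions for illustration (ModelLoader lives in src/model.py and its constructor is not shown here); this is not part of utils.py.

# sketch_main.py -- hypothetical driver built on the helpers above
import torch

from src.model import ModelLoader  # assumption: no-argument constructor loads the model
from utils import init_cuda_model, init_onnx_model, init_ov_model, parse_arguments


def main() -> None:
    args = parse_arguments()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_loader = ModelLoader()  # assumed construction, see note above

    if args.mode in ("onnx", "all"):
        # Exports the model to args.onnx_path and opens a CPU inference session.
        session = init_onnx_model(args.onnx_path, model_loader, device)
        print("ONNX Runtime providers:", session.get_providers())

    if args.mode in ("ov", "all"):
        # Assumes an ONNX file already exists at args.onnx_path
        # (exported above when mode is "all", or by an earlier "onnx" run).
        compiled = init_ov_model(args.onnx_path)
        print("OpenVINO model inputs:", compiled.inputs)

    if args.mode in ("cuda", "all"):
        cuda_model = init_cuda_model(model_loader, device, torch.float32)
        cuda_model.eval()


if __name__ == "__main__":
    main()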