-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathmodel_zoo.py
More file actions
59 lines (51 loc) · 1.76 KB
/
model_zoo.py
File metadata and controls
59 lines (51 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
# @Organization : insightface.ai
# @Author : Jia Guo
# @Time : 2021-05-04
# @Function : Load an ONNX model file and route it to the matching model wrapper (SCRFD or ArcFaceONNX)
import os
import os.path as osp
import glob
import onnxruntime
from .arcface_onnx import *
from .scrfd import *
#__all__ = ['get_model', 'get_model_list', 'get_arcface_onnx', 'get_scrfd']
__all__ = ['get_model']
class ModelRouter:
    """Choose the model wrapper class that matches an ONNX file.

    The choice is made by inspecting the graph signature of an
    onnxruntime session:
    graphs with five or more outputs are wrapped in ``SCRFD``; graphs whose
    first input has 112 in both dims 2 and 3 are wrapped in ``ArcFaceONNX``.
    """

    def __init__(self, onnx_file):
        # Path of the .onnx file to load and classify.
        self.onnx_file = onnx_file

    def get_model(self, providers=None):
        """Load ``self.onnx_file`` and wrap it in the matching model class.

        Args:
            providers: optional list of onnxruntime execution providers.
                ``None`` (the default) means
                ``['CUDAExecutionProvider', 'CPUExecutionProvider']``,
                which is the provider list used previously.

        Returns:
            An ``SCRFD`` or ``ArcFaceONNX`` instance. The session created
            here is passed to the wrapper, so the file is not loaded twice.

        Raises:
            RuntimeError: if the graph matches neither signature.
        """
        if providers is None:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        session = onnxruntime.InferenceSession(self.onnx_file, providers=providers)
        input_shape = session.get_inputs()[0].shape
        outputs = session.get_outputs()
        if len(outputs) >= 5:
            return SCRFD(model_file=self.onnx_file, session=session)
        # Check the rank before indexing dims 2 and 3. A lower-rank input
        # would otherwise raise IndexError instead of the routing error below.
        if len(input_shape) >= 4 and input_shape[2] == 112 and input_shape[3] == 112:
            return ArcFaceONNX(model_file=self.onnx_file, session=session)
        raise RuntimeError('error on model routing')
def find_onnx_file(dir_path):
    """Return the lexicographically last ``*.onnx`` file directly inside ``dir_path``.

    Args:
        dir_path: directory to search. The search is not recursive, and
            hidden (dot-prefixed) files are not matched.

    Returns:
        The path of the last ``.onnx`` file in sorted order, or ``None``
        if ``dir_path`` does not exist or contains no ``.onnx`` files.
    """
    if not os.path.exists(dir_path):
        return None
    # Escape the directory so glob metacharacters in its name (e.g. "[")
    # are matched literally; only the "*.onnx" part is meant as a pattern.
    paths = glob.glob(os.path.join(glob.escape(dir_path), '*.onnx'))
    if not paths:
        return None
    return max(paths)
def get_model(name, **kwargs):
    """Load an ONNX model either from a file path or by model-directory name.

    Args:
        name: either a path ending in ``.onnx``, or the name of a model
            directory under ``root``. In the latter case, the
            lexicographically last ``*.onnx`` file in that directory is used.
        **kwargs: ``root`` (default ``'~/.insightface/models'``, with ``~``
            expanded) is the directory searched when ``name`` is not an
            ``.onnx`` path. Other keys are ignored.

    Returns:
        The model object returned by ``ModelRouter.get_model`` for the
        resolved file, or ``None`` if ``name`` is a directory name and no
        ``.onnx`` file was found for it.

    Raises:
        AssertionError: if the resolved model path is not an existing file.
    """
    root = os.path.expanduser(kwargs.get('root', '~/.insightface/models'))
    if name.endswith('.onnx'):
        model_file = name
    else:
        model_file = find_onnx_file(os.path.join(root, name))
        if model_file is None:
            return None
    # Explicit raise rather than ``assert``, so the check also runs under
    # ``python -O``. AssertionError is kept so existing callers still catch it.
    if not osp.isfile(model_file):
        raise AssertionError('model should be file')
    # Route the *resolved* file. Passing ``name`` here would hand a bare
    # directory name to onnxruntime whenever the directory lookup was used.
    router = ModelRouter(model_file)
    return router.get_model()