File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -192,11 +192,19 @@ def load_models(opt, download_only=False):
192192 face_detector_model_file = file_paths [
193193 file_names .index ("Resnet50_Final.pth" )
194194 ]
195- face_detector_model = RetinaFace (
195+ if opt .device .startswith ("mps" ):
196+ face_detector_model = RetinaFace (
197+ gpu_id = opt .gpu_id ,
198+ model_path = face_detector_model_file ,
199+ network = "resnet50" ,
200+ device = "mps" ,
201+ )
202+ else :
203+ face_detector_model = RetinaFace (
196204 gpu_id = opt .gpu_id ,
197205 model_path = face_detector_model_file ,
198206 network = "resnet50" ,
199- )
207+ )
200208 elif opt .fd_model == "opencv_dnn" :
201209 face_detector_model_file = file_paths [
202210 file_names .index ("face_model.caffemodel" )
@@ -215,6 +223,10 @@ def load_models(opt, download_only=False):
215223 state_dict = torch .load (
216224 str (path_to_gaze_model ), map_location = torch .device (opt .device )
217225 )
226+ elif opt .device .startswith ("mps" ):
227+ state_dict = torch .load (
228+ str (path_to_gaze_model ), map_location = torch .device ("mps" )
229+ )
218230 else :
219231 state_dict = torch .load (str (path_to_gaze_model ))
220232 try :
Original file line number Diff line number Diff line change @@ -272,13 +272,17 @@ def parse_arguments(my_string=None):
272272 args .device = "cpu"
273273 else :
274274 import os
275-
276- os .environ ["CUDA_VISIBLE_DEVICES" ] = str (args .gpu_id )
277- args .device = "cuda:{}" .format (0 )
278275 import torch
279276
280- if not torch .cuda .is_available ():
281- raise ValueError ("GPU is not available. Was torch compiled with CUDA?" )
277+ if torch .cuda .is_available ():
278+ os .environ ["CUDA_VISIBLE_DEVICES" ] = str (args .gpu_id )
279+ args .device = f"cuda:{ args .gpu_id } "
280+ else :
281+ if torch .backends .mps .is_available ():
282+ args .device = f"mps:{ args .gpu_id } "
283+ else :
284+ raise ValueError ("GPU is not available. Was torch compiled with CUDA or MPS?" )
285+
282286 # figure out how many cpus can be used
283287 use_cpu = True if args .gpu_id == - 1 else False
284288 if use_cpu :
You can’t perform that action at this time.
0 commit comments