Skip to content

Commit b46c228

Browse files
committed
Added support for Metal Performance Shaders (MPS) by Apple
1 parent 60162db commit b46c228

2 files changed

Lines changed: 23 additions & 7 deletions

File tree

src/icatcher/cli.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,19 @@ def load_models(opt, download_only=False):
192192
face_detector_model_file = file_paths[
193193
file_names.index("Resnet50_Final.pth")
194194
]
195-
face_detector_model = RetinaFace(
195+
if opt.device.startswith("mps"):
196+
face_detector_model = RetinaFace(
197+
gpu_id=opt.gpu_id,
198+
model_path=face_detector_model_file,
199+
network="resnet50",
200+
device="mps",
201+
)
202+
else:
203+
face_detector_model = RetinaFace(
196204
gpu_id=opt.gpu_id,
197205
model_path=face_detector_model_file,
198206
network="resnet50",
199-
)
207+
)
200208
elif opt.fd_model == "opencv_dnn":
201209
face_detector_model_file = file_paths[
202210
file_names.index("face_model.caffemodel")
@@ -215,6 +223,10 @@ def load_models(opt, download_only=False):
215223
state_dict = torch.load(
216224
str(path_to_gaze_model), map_location=torch.device(opt.device)
217225
)
226+
elif opt.device.startswith("mps"):
227+
state_dict = torch.load(
228+
str(path_to_gaze_model), map_location=torch.device("mps")
229+
)
218230
else:
219231
state_dict = torch.load(str(path_to_gaze_model))
220232
try:

src/icatcher/options.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -272,13 +272,17 @@ def parse_arguments(my_string=None):
272272
args.device = "cpu"
273273
else:
274274
import os
275-
276-
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
277-
args.device = "cuda:{}".format(0)
278275
import torch
279276

280-
if not torch.cuda.is_available():
281-
raise ValueError("GPU is not available. Was torch compiled with CUDA?")
277+
if torch.cuda.is_available():
278+
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
279+
args.device = f"cuda:{args.gpu_id}"
280+
else:
281+
if torch.backends.mps.is_available():
282+
args.device = f"mps:{args.gpu_id}"
283+
else:
284+
raise ValueError("GPU is not available. Was torch compiled with CUDA or MPS?")
285+
282286
# figure out how many cpus can be used
283287
use_cpu = True if args.gpu_id == -1 else False
284288
if use_cpu:

0 commit comments

Comments (0)