Skip to content

Commit c896919

Browse files
imranq2ebetica
andauthored
Enable running ESM on Mac silicon using MPS (#99)
Signed-off-by: Zeming Lin <ebetica0@gmail.com> Co-authored-by: Zeming Lin <ebetica0@gmail.com>
1 parent 97eb26b commit c896919

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

esm/utils/misc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from collections import defaultdict
3+
from contextlib import nullcontext
34
from io import BytesIO
45
from typing import Any, ContextManager, Sequence, TypeVar
56
from warnings import warn
@@ -224,6 +225,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast
224225
"""
225226
if device_type == "cpu":
226227
return torch.amp.autocast(device_type, enabled=False) # type: ignore
228+
elif device_type == "mps":
229+
# For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
230+
return nullcontext()
227231
elif device_type == "cuda":
228232
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
229233
else:

0 commit comments

Comments
 (0)