We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 97eb26b commit c896919Copy full SHA for c896919
1 file changed
esm/utils/misc.py
@@ -1,5 +1,6 @@
1
import os
2
from collections import defaultdict
3
+from contextlib import nullcontext
4
from io import BytesIO
5
from typing import Any, ContextManager, Sequence, TypeVar
6
from warnings import warn
@@ -224,6 +225,9 @@ def fp32_autocast_context(device_type: str) -> ContextManager[torch.amp.autocast
224
225
"""
226
if device_type == "cpu":
227
return torch.amp.autocast(device_type, enabled=False) # type: ignore
228
+ elif device_type == "mps":
229
+ # For MPS, just return a no-op context manager (nullcontext) since MPS does not support autocast.
230
+ return nullcontext()
231
elif device_type == "cuda":
232
return torch.amp.autocast(device_type, dtype=torch.float32) # type: ignore
233
else:
0 commit comments