Skip to content

Commit 2671d58

Browse files
Fix
1 parent 03f6fcb commit 2671d58

1 file changed

Lines changed: 13 additions & 0 deletions

File tree

predai/rootfs/predai.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,19 @@ def _patched_torch_load(*args, **kwargs):
5353
kwargs['weights_only'] = False
5454
return _original_torch_load(*args, **kwargs)
5555
torch.load = _patched_torch_load
56+
57+
# Also patch torch.serialization._load if it exists (for older PyTorch)
58+
if hasattr(torch.serialization, '_load'):
59+
_original_serialization_load = torch.serialization._load
60+
def _patched_serialization_load(*args, **kwargs):
61+
if 'weights_only' not in kwargs:
62+
kwargs['weights_only'] = False
63+
return _original_serialization_load(*args, **kwargs)
64+
torch.serialization._load = _patched_serialization_load
65+
66+
# Set environment variable as additional fallback
67+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
68+
5669
except (ImportError, AttributeError):
5770
# If torch.serialization or classes are not available, continue without the fix
5871
# This allows backward compatibility with older PyTorch versions

0 commit comments

Comments
 (0)