File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -53,6 +53,19 @@ def _patched_torch_load(*args, **kwargs):
5353 kwargs ['weights_only' ] = False
5454 return _original_torch_load (* args , ** kwargs )
5555 torch .load = _patched_torch_load
56+
57+ # Also patch torch.serialization._load if it exists (for older PyTorch)
58+ if hasattr (torch .serialization , '_load' ):
59+ _original_serialization_load = torch .serialization ._load
60+ def _patched_serialization_load (* args , ** kwargs ):
61+ if 'weights_only' not in kwargs :
62+ kwargs ['weights_only' ] = False
63+ return _original_serialization_load (* args , ** kwargs )
64+ torch .serialization ._load = _patched_serialization_load
65+
66+ # Set environment variable as additional fallback
67+ os .environ ['PYTORCH_ENABLE_MPS_FALLBACK' ] = '1'
68+
5669except (ImportError , AttributeError ):
5770 # If torch.serialization or classes are not available, continue without the fix
5871 # This allows backward compatibility with older PyTorch versions
You can’t perform that action at this time.
0 commit comments