diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py index 8f66961f..bf4ec114 100644 --- a/diffsynth/core/loader/file.py +++ b/diffsynth/core/loader/file.py @@ -15,6 +15,10 @@ def load_state_dict(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): + import torch.distributed as dist + from diffsynth.core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE + if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + device = get_device_name() state_dict = {} with safe_open(file_path, framework="pt", device=str(device)) as f: for k in f.keys(): diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index a666590f..0fe15f76 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -28,6 +28,10 @@ def get_slice(self, name): class DiskMap: def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): + import torch.distributed as dist + from diffsynth.core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE + if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + device = get_device_name() self.path = path if isinstance(path, list) else [path] self.device = device self.torch_dtype = torch_dtype