From 593b3cc28263791dcbf8167bad54308623512e2a Mon Sep 17 00:00:00 2001 From: yuan Date: Sat, 10 Jan 2026 12:12:27 +0800 Subject: [PATCH 1/3] fix multi-gpu model loading bug --- diffsynth/core/vram/disk_map.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index a666590fa..5aa286e05 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -44,6 +44,9 @@ def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, bu self.rename_dict = self.fetch_rename_dict(state_dict_converter) def flush_files(self): + import torch.distributed as dist + if dist.is_available() and dist.is_initialized() and str(self.device).strip() == "cuda": + self.device = f"cuda:{torch.cuda.current_device()}" if len(self.files) == 0: for path in self.path: if path.endswith(".safetensors"): From c6aab9e2d43483f09075ba5dd5d1b8a5c63b0686 Mon Sep 17 00:00:00 2001 From: yuan Date: Sat, 10 Jan 2026 15:39:15 +0800 Subject: [PATCH 2/3] optimize code structure --- diffsynth/core/vram/disk_map.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index 5aa286e05..f95ee4179 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -30,6 +30,9 @@ class DiskMap: def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): self.path = path if isinstance(path, list) else [path] self.device = device + import torch.distributed as dist + if dist.is_available() and dist.is_initialized() and str(self.device).strip() == "cuda": + self.device = f"cuda:{torch.cuda.current_device()}" self.torch_dtype = torch_dtype if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None: self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE')) @@ -44,9 +47,6 @@ def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, bu self.rename_dict = self.fetch_rename_dict(state_dict_converter) def flush_files(self): - import torch.distributed as dist - if dist.is_available() and dist.is_initialized() and str(self.device).strip() == "cuda": - self.device = f"cuda:{torch.cuda.current_device()}" if len(self.files) == 0: for path in self.path: if path.endswith(".safetensors"): From 0bbf574b7f16323b8a683860b20d080358d881dd Mon Sep 17 00:00:00 2001 From: yuan Date: Mon, 12 Jan 2026 13:10:03 +0800 Subject: [PATCH 3/3] add npu support; update load_state_dict_from_safetensors --- diffsynth/core/loader/file.py | 4 ++++ diffsynth/core/vram/disk_map.py | 7 ++++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py index 8f66961f2..bf4ec1141 100644 --- a/diffsynth/core/loader/file.py +++ b/diffsynth/core/loader/file.py @@ -15,6 +15,10 @@ def load_state_dict(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): + import torch.distributed as dist + from diffsynth.core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE + if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + device = get_device_name() state_dict = {} with safe_open(file_path, framework="pt", device=str(device)) as f: for k in f.keys(): diff --git a/diffsynth/core/vram/disk_map.py b/diffsynth/core/vram/disk_map.py index f95ee4179..0fe15f761 100644 --- a/diffsynth/core/vram/disk_map.py +++ b/diffsynth/core/vram/disk_map.py @@ -28,11 +28,12 @@ def get_slice(self, name): class DiskMap: def __init__(self, path, device, torch_dtype=None, state_dict_converter=None, buffer_size=10**9): + import torch.distributed as dist + from diffsynth.core.device.npu_compatible_device import get_device_name, IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE + if dist.is_available() and dist.is_initialized() and (IS_CUDA_AVAILABLE or IS_NPU_AVAILABLE): + device = get_device_name() self.path = path if isinstance(path, list) else [path] self.device = device - import torch.distributed as dist - if dist.is_available() and dist.is_initialized() and str(self.device).strip() == "cuda": - self.device = f"cuda:{torch.cuda.current_device()}" self.torch_dtype = torch_dtype if os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE') is not None: self.buffer_size = int(os.environ.get('DIFFSYNTH_DISK_MAP_BUFFER_SIZE'))