From 24bc8d15a6d6b28634a1a7a938df08d331fa1f82 Mon Sep 17 00:00:00 2001 From: hanhainebula <2512674094@qq.com> Date: Tue, 23 Sep 2025 23:02:26 +0800 Subject: [PATCH] fix bug: safe dist.get_rank() --- FlagEmbedding/abc/finetune/embedder/AbsDataset.py | 6 ++++-- FlagEmbedding/abc/finetune/embedder/AbsModeling.py | 4 ++-- FlagEmbedding/abc/finetune/reranker/AbsDataset.py | 3 ++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py index 26430330..4da7dbee 100644 --- a/FlagEmbedding/abc/finetune/embedder/AbsDataset.py +++ b/FlagEmbedding/abc/finetune/embedder/AbsDataset.py @@ -63,7 +63,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: Loaded HF dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path) @@ -342,7 +343,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: The loaded dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path) diff --git a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py index 9ba9e829..de8c80b0 100644 --- a/FlagEmbedding/abc/finetune/embedder/AbsModeling.py +++ b/FlagEmbedding/abc/finetune/embedder/AbsModeling.py @@ -54,8 +54,8 @@ def __init__( if self.negatives_cross_device: if not dist.is_initialized(): raise ValueError('Distributed training has not been initialized for representation all gather.') - self.process_rank = dist.get_rank() - self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() if dist.is_initialized() else 0 + self.world_size = dist.get_world_size() if dist.is_initialized() else 1 self.sub_batch_size = sub_batch_size self.kd_loss_type = kd_loss_type diff --git a/FlagEmbedding/abc/finetune/reranker/AbsDataset.py b/FlagEmbedding/abc/finetune/reranker/AbsDataset.py index 22389770..73830bbb 100644 --- a/FlagEmbedding/abc/finetune/reranker/AbsDataset.py +++ b/FlagEmbedding/abc/finetune/reranker/AbsDataset.py @@ -64,7 +64,8 @@ def _load_dataset(self, file_path: str): Returns: datasets.Dataset: Loaded HF dataset. """ - if dist.get_rank() == 0: + safe_rank = dist.get_rank() if dist.is_initialized() else 0 + if safe_rank == 0: logger.info(f'loading data from {file_path} ...') temp_dataset = datasets.load_dataset('json', data_files=file_path, split='train', cache_dir=self.args.cache_path)