 from xtuner.v1.utils.misc import monkey_patch_hf_modules_cache
 from transformers import AutoTokenizer
 from math import ceil
+import copy
+import random
+from loguru import logger
+import json
+import shutil
 
 
+CACHE_META = ".xpuyu-cache-meta.json"
+
 app = App()
 
 
 def start_ray(nnodes: int, cpus_per_task: int, memory_per_task: int) -> JobSchema:
     assert ClusterParams is not None
     cmd_path = (Path(__file__).parent / "launch_ray.sh").absolute()
+
     clusterx_params = ClusterParams(
         cmd=str(cmd_path),
         job_name="offline_cache",
         cpus_per_task=cpus_per_task,
         memory_per_task=memory_per_task,
-        nnodes=nnodes,
+        num_nodes=nnodes,
         no_env=True,
         image="registry.h.pjlab.org.cn/ailab-llmrazor/xtuner:pt28_latest",
     )
     job = cluster.run(clusterx_params)
+
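+    # Poll until the cluster job reaches RUNNING before handing it back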
     while job.status != JobStatus.RUNNING:
         job = cluster.get_job_info(job.job_id)
+
     return job
 
 
-def cache_worker(dataset_config_list: DatasetConfigList, tokenizer_path):
+def merge_worker_caches(base_cache_dir: Path, num_workers: int) -> None:
+    """Merge all worker cache directories into the base directory.
+
+    Args:
+        base_cache_dir (Path): Base directory containing worker-{i} subdirectories
+        num_workers (int): Number of worker directories to merge
+    """
+    merged_meta = {}
+
+    logger.info(f"Starting cache merge into {base_cache_dir}")
+
+    for worker_id in range(num_workers):
+        worker_dir = base_cache_dir / f"worker-{worker_id}"
+        if not worker_dir.exists():
+            logger.warning(f"Worker directory {worker_dir} does not exist, skipping")
+            continue
+
+        meta_file = worker_dir / CACHE_META
+        if not meta_file.exists():
+            logger.warning(f"Meta file {meta_file} does not exist, skipping worker-{worker_id}")
+            continue
+
+        # Load this worker's cache metadata
+        with open(meta_file) as f:
+            worker_meta = json.load(f)
+
+        logger.info(f"Processing worker-{worker_id}: {len(worker_meta)} hash entries")
+
+        # Move each hash directory into the base cache and merge its metadata
+        for hash_key, hash_meta in worker_meta.items():
+            src_hash_dir = worker_dir / hash_key
+            dst_hash_dir = base_cache_dir / hash_key
+
+            if not src_hash_dir.exists():
+                logger.warning(f"Hash directory {src_hash_dir} does not exist, skipping")
+                continue
+
+            if dst_hash_dir.exists():
+                logger.warning(f"Hash directory {dst_hash_dir} already exists, skipping move")
+            else:
+                shutil.move(src_hash_dir, dst_hash_dir)
+                logger.debug(f"Moved {src_hash_dir} -> {dst_hash_dir}")
+
+            # Merge metadata (assuming no conflicts, otherwise we'd need merge logic)
+            if hash_key in merged_meta:
+                logger.warning(f"Hash key {hash_key} already exists in merged metadata, skipping")
+            else:
+                merged_meta[hash_key] = hash_meta
+
+        # Clean up worker directory
+        try:
+            shutil.rmtree(worker_dir)
+            logger.info(f"Removed worker directory {worker_dir}")
+        except Exception as e:
+            logger.error(f"Failed to remove worker directory {worker_dir}: {e}")
+
+    # Write merged metadata
+    merged_meta_file = base_cache_dir / CACHE_META
+    with open(merged_meta_file, "w") as f:
+        json.dump(merged_meta, f, indent=4)
+
+    logger.info(f"Cache merge completed: {len(merged_meta)} hash entries in {merged_meta_file}")
+
+
+def cache_worker(dataset_config_list: DatasetConfigList, tokenizer_path, worker_id: int):
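+    # Local imports keep the function self-contained when it is shipped to remote Ray workers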
+    import socket
+    import os
+    from loguru import logger
+
+    hostname = socket.gethostname()
+    pid = os.getpid()
+    logger.info(
+        f"[Worker {worker_id}] Running on {hostname}, PID={pid}, processing {len(dataset_config_list)} datasets"
+    )
     monkey_patch_hf_modules_cache()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
     build_datasets(dataset_config_list, tokenizer)
+    logger.info(f"[Worker {worker_id}] Finished on {hostname}, PID={pid}")
 
 
 @app.default
@@ -55,27 +139,67 @@ def main(
     nnodes: int,
     cpus_per_task: int,
     memory_per_task: int,
+    tasks_per_node: int = 1,
 ):
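+    """Launch a Ray cluster, then tokenize and cache all datasets offline across nnodes * tasks_per_node Ray tasks."""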
     job = start_ray(nnodes=nnodes, cpus_per_task=cpus_per_task, memory_per_task=memory_per_task)
     try:
-        print(f"ray clsuter job id {job.job_id}")
+        logger.info(f"Ray cluster job id: {job.job_id}")
         assert job.nodes_ip is not None
         ray_ip = f"ray://{job.nodes_ip[0]}:10001"
 
-        print(f"Connect to ray cluster: {ray_ip}")
+        logger.info(f"Connecting to Ray cluster: {ray_ip}")
         ray.init(ray_ip)
 
+        resources = ray.cluster_resources()
+        logger.info(f"Ray cluster resources: CPU={resources.get('CPU', 0)}, nodes={len(ray.nodes())}")
+        logger.info(f"Ray nodes: {[node['NodeManagerAddress'] for node in ray.nodes()]}")
+
         config = Config.fromfile(config_path)
         dataset_config = config.dataset_config
         tokenizer_path = config.trainer.tokenizer_path
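+        # All datasets share one cache root; per-worker subdirectories are merged back into it after all tasks finish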
+        base_cache_dir = Path(dataset_config[0]["dataset"].cache_dir)
 
         global cache_worker
-        worker = ray.remote(num_cpus=cpus_per_task)(cache_worker)
-        batch_size = ceil(len(dataset_config) / nnodes)
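+        # Split each node's CPUs across tasks_per_node tasks; XTUNER_TOKENIZE_WORKERS tells each task how many tokenizer processes it may use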
+        worker = ray.remote(
+            num_cpus=cpus_per_task // tasks_per_node,
+            runtime_env={
+                "env_vars": {
+                    "XTUNER_TOKENIZE_WORKERS": str(cpus_per_task // tasks_per_node),
+                    "PYTHONPATH": str((Path(__file__).parent.parent).absolute()),
+                }
+            },
+        )(cache_worker)
+        batch_size = ceil(len(dataset_config) / (nnodes * tasks_per_node))
         batch_config_list = list(batched(dataset_config, batch_size))
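+        # Shuffle so that expensive dataset batches are less likely to pile up on one worker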
+        random.shuffle(batch_config_list)
+
+        res = []
+        logger.info(f"Total tasks: {len(batch_config_list)} (nnodes={nnodes}, tasks_per_node={tasks_per_node})")
+        logger.info(f"Batch size: {batch_size}, Total datasets: {len(dataset_config)}")
+        for i, batch in enumerate(batch_config_list):
+            # Each worker writes its tokenized metadata to a separate path to avoid CACHE_META file I/O conflicts
+            batch_copy = copy.deepcopy(list(batch))
+
+            for dataset_cfg in batch_copy:
+                if hasattr(dataset_cfg["dataset"], "cache_dir") and dataset_cfg["dataset"].cache_dir is not None:
+                    original_cache_dir = Path(dataset_cfg["dataset"].cache_dir)
+                    worker_cache_dir = original_cache_dir / f"worker-{i}"
+                    dataset_cfg["dataset"].cache_dir = str(worker_cache_dir)
+
+            res.append(worker.remote(batch_copy, tokenizer_path, i))
+
+            logger.info(f"Submitted task {i + 1}/{len(batch_config_list)}")
+
+        logger.info(f"\nWaiting for {len(res)} tasks to complete...")
+        for i, obj in enumerate(res):
+            ray.get(obj)
+            logger.info(f"Task {i + 1}/{len(res)} completed")
+
+        # Merge worker caches after all tasks complete
+        logger.info("All tasks completed, starting cache merge...")
+
+        merge_worker_caches(base_cache_dir, num_workers=len(batch_config_list))
 
-        for batch in batch_config_list:
-            worker.remote(batch, tokenizer_path)
     finally:
         cluster.stop(job_id=job.job_id)
 