
Commit de28cfb

[Feat] fix and enhance offline caching script
1 parent: 441aeaf

2 files changed

Lines changed: 137 additions & 12 deletions


.dev_scripts/offline_cache.py

Lines changed: 132 additions & 8 deletions
```diff
@@ -11,8 +11,15 @@
 from xtuner.v1.utils.misc import monkey_patch_hf_modules_cache
 from transformers import AutoTokenizer
 from math import ceil
+import copy
+import random
+from loguru import logger
+import json
+import shutil
 
 
+CACHE_META = ".xpuyu-cache-meta.json"
+
 app = App()
 
 
```
```diff
@@ -28,25 +35,102 @@
 def start_ray(nnodes: int, cpus_per_task: int, memory_per_task: int) -> JobSchema:
     assert ClusterParams is not None
     cmd_path = (Path(__file__).parent / "launch_ray.sh").absolute()
+
     clusterx_params = ClusterParams(
         cmd=str(cmd_path),
         job_name="offline_cache",
         cpus_per_task=cpus_per_task,
         memory_per_task=memory_per_task,
-        nnodes=nnodes,
+        num_nodes=nnodes,
         no_env=True,
         image="registry.h.pjlab.org.cn/ailab-llmrazor/xtuner:pt28_latest",
     )
     job = cluster.run(clusterx_params)
+
     while job.status != JobStatus.RUNNING:
         job = cluster.get_job_info(job.job_id)
+
     return job
 
 
-def cache_worker(dataset_config_list: DatasetConfigList, tokenizer_path):
+def merge_worker_caches(base_cache_dir: Path, num_workers: int) -> None:
+    """Merge all worker cache directories into the base directory.
+
+    Args:
+        base_cache_dir (Path): Base directory containing worker-{i} subdirectories
+        num_workers (int): Number of worker directories to merge
+    """
+    merged_meta = {}
+
+    logger.info(f"Starting cache merge into {base_cache_dir}")
+
+    for worker_id in range(num_workers):
+        worker_dir = base_cache_dir / f"worker-{worker_id}"
+        if not worker_dir.exists():
+            logger.warning(f"Worker directory {worker_dir} does not exist, skipping")
+            continue
+
+        meta_file = worker_dir / CACHE_META
+        if not meta_file.exists():
+            logger.warning(f"Meta file {meta_file} does not exist, skipping worker-{worker_id}")
+            continue
+
+        # Load the worker's cache meta file
+        with open(meta_file) as f:
+            worker_meta = json.load(f)
+
+        logger.info(f"Processing worker-{worker_id}: {len(worker_meta)} hash entries")
+
+        # Move each hash directory and merge metadata
+        for hash_key, hash_meta in worker_meta.items():
+            src_hash_dir = worker_dir / hash_key
+            dst_hash_dir = base_cache_dir / hash_key
+
+            if not src_hash_dir.exists():
+                logger.warning(f"Hash directory {src_hash_dir} does not exist, skipping")
+                continue
+
+            if dst_hash_dir.exists():
+                logger.warning(f"Hash directory {dst_hash_dir} already exists, skipping copy")
+            else:
+                shutil.move(src_hash_dir, dst_hash_dir)
+                logger.debug(f"Moved {src_hash_dir} -> {dst_hash_dir}")
+
+            # Merge metadata (assuming no conflicts, otherwise we'd need merge logic)
+            if hash_key in merged_meta:
+                logger.warning(f"Hash key {hash_key} already exists in merged metadata, skipping")
+            else:
+                merged_meta[hash_key] = hash_meta
+
+        # Clean up the worker directory
+        try:
+            shutil.rmtree(worker_dir)
+            logger.info(f"Removed worker directory {worker_dir}")
+        except Exception as e:
+            logger.error(f"Failed to remove worker directory {worker_dir}: {e}")
+
+    # Write merged metadata
+    merged_meta_file = base_cache_dir / CACHE_META
+    with open(merged_meta_file, "w") as f:
+        json.dump(merged_meta, f, indent=4)
+
+    logger.info(f"Cache merge completed: {len(merged_meta)} hash entries in {merged_meta_file}")
+
+
+def cache_worker(dataset_config_list: DatasetConfigList, tokenizer_path, worker_id: int):
+    import socket
+    import os
+    from loguru import logger
+
+    hostname = socket.gethostname()
+    pid = os.getpid()
+    logger.info(
+        f"[Worker {worker_id}] Running on {hostname}, PID={pid}, processing {len(dataset_config_list)} datasets"
+    )
     monkey_patch_hf_modules_cache()
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
     build_datasets(dataset_config_list, tokenizer)
+    logger.info(f"[Worker {worker_id}] Finished on {hostname}, PID={pid}")
 
 
 @app.default
```
```diff
@@ -55,27 +139,67 @@ def main(
     nnodes: int,
     cpus_per_task: int,
     memory_per_task: int,
+    tasks_per_node: int = 1,
 ):
     job = start_ray(nnodes=nnodes, cpus_per_task=cpus_per_task, memory_per_task=memory_per_task)
     try:
-        print(f"ray clsuter job id {job.job_id}")
+        logger.info(f"ray cluster job id {job.job_id}")
         assert job.nodes_ip is not None
         ray_ip = f"ray://{job.nodes_ip[0]}:10001"
 
-        print(f"Connect to ray cluster: {ray_ip}")
+        logger.info(f"Connect to ray cluster: {ray_ip}")
         ray.init(ray_ip)
 
+        resources = ray.cluster_resources()
+        logger.info(f"Ray cluster resources: CPU={resources.get('CPU', 0)}, nodes={len(ray.nodes())}")
+        logger.info(f"Ray nodes: {[node['NodeManagerAddress'] for node in ray.nodes()]}")
+
         config = Config.fromfile(config_path)
         dataset_config = config.dataset_config
         tokenizer_path = config.trainer.tokenizer_path
+        base_cache_dir = Path(dataset_config[0]["dataset"].cache_dir)
 
         global cache_worker
-        worker = ray.remote(num_cpus=cpus_per_task)(cache_worker)
-        batch_size = ceil(len(dataset_config) / nnodes)
+        worker = ray.remote(
+            num_cpus=cpus_per_task // tasks_per_node,
+            runtime_env={
+                "env_vars": {
+                    "XTUNER_TOKENIZE_WORKERS": str(cpus_per_task // tasks_per_node),
+                    "PYTHONPATH": str((Path(__file__).parent.parent).absolute()),
+                }
+            },
+        )(cache_worker)
+        batch_size = ceil(len(dataset_config) / (nnodes * tasks_per_node))
         batch_config_list = list(batched(dataset_config, batch_size))
+        random.shuffle(batch_config_list)
+
+        res = []
+        logger.info(f"Total tasks: {len(batch_config_list)} (nnodes={nnodes}, tasks_per_node={tasks_per_node})")
+        logger.info(f"Batch size: {batch_size}, Total datasets: {len(dataset_config)}")
+        for i, batch in enumerate(batch_config_list):
+            # Each worker should cache tokenized metadata to a different path to avoid CACHE_META file I/O conflicts
+            batch_copy = copy.deepcopy(list(batch))
+
+            for dataset_cfg in batch_copy:
+                if hasattr(dataset_cfg["dataset"], "cache_dir") and dataset_cfg["dataset"].cache_dir is not None:
+                    original_cache_dir = Path(dataset_cfg["dataset"].cache_dir)
+                    worker_cache_dir = original_cache_dir / f"worker-{i}"
+                    dataset_cfg["dataset"].cache_dir = str(worker_cache_dir)
+
+            res.append(worker.remote(batch_copy, tokenizer_path, i))
+
+            logger.info(f"Submitted task {i + 1}/{len(batch_config_list)}")
+
+        logger.info(f"\nWaiting for {len(res)} tasks to complete...")
+        for i, obj in enumerate(res):
+            ray.get(obj)
+            logger.info(f"Task {i + 1}/{len(res)} completed")
+
+        # Merge worker caches after all tasks complete
+        logger.info("All tasks completed, starting cache merge...")
+
+        merge_worker_caches(base_cache_dir, num_workers=len(batch_config_list))
 
-        for batch in batch_config_list:
-            worker.remote(batch, tokenizer_path)
     finally:
         cluster.stop(job_id=job.job_id)
 
```
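The new `tasks_per_node` knob changes the scheduling arithmetic: the dataset list is split into roughly `nnodes * tasks_per_node` batches, each batch becoming one Ray task that receives `cpus_per_task // tasks_per_node` CPUs (the same value is exported as `XTUNER_TOKENIZE_WORKERS` so each task's tokenizer pool matches its CPU share). A toy run of that arithmetic, with invented dataset names and a minimal stand-in for the `batched` helper the script imports:

```python
from itertools import islice
from math import ceil

def batched(iterable, n):
    # Minimal stand-in for itertools.batched (Python 3.12+).
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk

datasets = [f"ds-{i}" for i in range(10)]  # pretend dataset_config entries
nnodes, tasks_per_node = 2, 2              # 4 worker slots in total

batch_size = ceil(len(datasets) / (nnodes * tasks_per_node))  # ceil(10 / 4) == 3
batches = list(batched(datasets, batch_size))
print([len(b) for b in batches])  # [3, 3, 3, 1] -> 4 Ray tasks
```

Note the last batch can be short, so the number of submitted tasks (`len(batch_config_list)`) is what the merge step uses as `num_workers`, not `nnodes * tasks_per_node` itself.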

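Because every task writes under its own `worker-{i}` subdirectory, the cache ends up sharded one level too deep, and `merge_worker_caches` lifts it back. Below is a standalone sketch of the same move-and-union shape using a throwaway directory; the hash keys and the meta payload (`num_tokens`) are invented placeholders, not the script's real cache contents:

```python
import json
import shutil
from pathlib import Path
from tempfile import mkdtemp

CACHE_META = ".xpuyu-cache-meta.json"  # same marker file name the script uses

def fake_worker_cache(base: Path, worker_id: int, hash_key: str) -> None:
    # Mimic one worker's output: worker-{i}/<hash>/payload plus a meta index.
    hash_dir = base / f"worker-{worker_id}" / hash_key
    hash_dir.mkdir(parents=True)
    (hash_dir / "tokenized.bin").write_bytes(b"\x00")  # stand-in cache payload
    (base / f"worker-{worker_id}" / CACHE_META).write_text(
        json.dumps({hash_key: {"num_tokens": 128}})  # invented meta entry
    )

base = Path(mkdtemp())
fake_worker_cache(base, 0, "abc123")
fake_worker_cache(base, 1, "def456")

# Lift each hash directory up one level and union the per-worker metas.
merged_meta: dict = {}
for worker_dir in sorted(base.glob("worker-*")):
    worker_meta = json.loads((worker_dir / CACHE_META).read_text())
    for hash_key, hash_meta in worker_meta.items():
        if not (base / hash_key).exists():
            shutil.move(str(worker_dir / hash_key), str(base / hash_key))
        merged_meta.setdefault(hash_key, hash_meta)
    shutil.rmtree(worker_dir)

(base / CACHE_META).write_text(json.dumps(merged_meta, indent=4))
print(sorted(p.name for p in base.iterdir()))
# ['.xpuyu-cache-meta.json', 'abc123', 'def456']
```

Duplicate hash keys across workers are possible if two batches share a dataset fingerprint; like the real function, the sketch keeps the first copy it sees and skips the rest.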
xtuner/v1/datasets/jsonl.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -70,10 +70,11 @@ def tokenize_worker(
     out_queue: Queue,
     cpu_ids: list[int],
 ):
-    try:
-        os.sched_setaffinity(os.getpid(), cpu_ids)
-    except OSError as e:
-        logger.debug(f"Failed to set CPU affinity: {e}")
+    # For the offline caching script to work, CPU affinity has to be turned off.
+    # try:
+    #     os.sched_setaffinity(os.getpid(), cpu_ids)
+    # except OSError as e:
+    #     logger.debug(f"Failed to set CPU affinity: {e}")
 
     shared_memory = SharedMemory(name=shm_name, create=False)
     while True:
```
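For context on this change: `os.sched_setaffinity` (Linux-only) restricts a process to an explicit CPU set. Pinning each tokenize worker is reasonable for a single local job, but presumably collides once several Ray tasks share a node and pin their workers to overlapping cores, hence the commit disables the call. A standalone illustration of what the disabled call does (it restores the original mask before exiting):

```python
import os

if hasattr(os, "sched_getaffinity"):  # Linux-only API
    before = os.sched_getaffinity(0)  # 0 means the current process
    print("allowed CPUs before:", sorted(before))

    cpu_ids = sorted(before)[:1]      # hypothetical pinning set: one CPU
    os.sched_setaffinity(0, cpu_ids)  # the call tokenize_worker used to make
    print("allowed CPUs after: ", sorted(os.sched_getaffinity(0)))

    os.sched_setaffinity(0, before)   # undo the demonstration
else:
    print("sched_(get|set)affinity unavailable on this platform")
```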
