|
5 | 5 | import platform |
6 | 6 | import sys |
7 | 7 | import tempfile |
| 8 | +import time |
8 | 9 | import urllib.request |
9 | 10 | from pathlib import Path |
10 | 11 | from typing import Any, Optional |
@@ -144,47 +145,57 @@ def _extracted_dir_trtllm(platform_system: str, platform_machine: str) -> Path: |
144 | 145 |
|
145 | 146 |
|
146 | 147 | def extract_wheel_file(wheel_path: Path, extract_dir: Path) -> None: |
147 | | - # this will not be encountered in case of platforms not supporting torch distributed/nccl/TRT-LLM |
148 | | - from torch.distributed import barrier, get_rank, is_initialized |
149 | | - |
150 | | - if not is_initialized(): |
151 | | - # Single process case, just unzip |
152 | | - is_master = True |
153 | | - else: |
154 | | - is_master = get_rank() == 0 # only rank 0 does the unzip |
155 | | - |
156 | | - if is_master: |
| 148 | + """ |
| 149 | + Safely extract a wheel file to a directory with a lock to prevent concurrent extraction. |
| 150 | + """ |
| 151 | + rank = int(os.environ.get("OMPI_COMM_WORLD_RANK", 0)) # MPI rank from OpenMPI |
| 152 | + torch.cuda.set_device(rank) |
| 153 | + lock_file = extract_dir / ".extracting" |
| 154 | + |
| 155 | + # Rank 0 performs extraction |
| 156 | + if rank == 0: |
| 157 | + logger.debug( |
| 158 | + f"[Rank {rank}] Starting extraction of {wheel_path} to {extract_dir}" |
| 159 | + ) |
157 | 160 | try: |
158 | 161 | import zipfile |
159 | 162 | except ImportError as e: |
160 | 163 | raise ImportError( |
161 | 164 | "zipfile module is required but not found. Please install zipfile" |
162 | 165 | ) |
| 166 | + # Create lock file to signal extraction in progress |
| 167 | + extract_dir.mkdir(parents=True, exist_ok=False) |
| 168 | + lock_file.touch(exist_ok=False) |
163 | 169 | try: |
164 | 170 | with zipfile.ZipFile(wheel_path) as zip_ref: |
165 | 171 | zip_ref.extractall(extract_dir) |
166 | | - logger.debug(f"Extracted wheel to {extract_dir}") |
167 | | - |
| 172 | + logger.debug(f"[Rank {rank}] Extraction complete: {extract_dir}") |
168 | 173 | except FileNotFoundError as e: |
169 | | - # This should capture the errors in the download failure above |
170 | | - logger.error(f"Wheel file not found at {wheel_path}: {e}") |
| 174 | + logger.error(f"[Rank {rank}] Wheel file not found at {wheel_path}: {e}") |
171 | 175 | raise RuntimeError( |
172 | 176 | f"Failed to find downloaded wheel file at {wheel_path}" |
173 | 177 | ) from e |
174 | 178 | except zipfile.BadZipFile as e: |
175 | | - logger.error(f"Invalid or corrupted wheel file: {e}") |
| 179 | + logger.error(f"[Rank {rank}] Invalid or corrupted wheel file: {e}") |
176 | 180 | raise RuntimeError( |
177 | 181 | "Downloaded wheel file is corrupted or not a valid zip archive" |
178 | 182 | ) from e |
179 | 183 | except Exception as e: |
180 | | - logger.error(f"Unexpected error while extracting wheel: {e}") |
| 184 | + logger.error(f"[Rank {rank}] Unexpected error while extracting wheel: {e}") |
181 | 185 | raise RuntimeError( |
182 | 186 | "Unexpected error during extraction of TensorRT-LLM wheel" |
183 | 187 | ) from e |
| 188 | + finally: |
| 189 | + # Remove lock file to signal completion |
| 190 | + lock_file.unlink(missing_ok=True) |
184 | 191 |
|
185 | | - # Make sure others wait until unzip is done |
186 | | - if is_initialized(): |
187 | | - barrier() |
| 192 | + else: |
| 193 | + # Other ranks wait for extraction to complete |
| 194 | + while lock_file.exists(): |
| 195 | + logger.debug( |
| 196 | + f"[Rank {rank}] Waiting for extraction to finish at {extract_dir}..." |
| 197 | + ) |
| 198 | + time.sleep(0.5) |
188 | 199 |
|
189 | 200 |
|
190 | 201 | def download_and_get_plugin_lib_path() -> Optional[str]: |
@@ -216,7 +227,6 @@ def download_and_get_plugin_lib_path() -> Optional[str]: |
216 | 227 | return str(plugin_lib_path) |
217 | 228 |
|
218 | 229 | wheel_path.parent.mkdir(parents=True, exist_ok=True) |
219 | | - extract_dir.mkdir(parents=True, exist_ok=True) |
220 | 230 |
|
221 | 231 | if not wheel_path.exists(): |
222 | 232 | base_url = "https://pypi.nvidia.com/tensorrt-llm/" |
|
0 commit comments