Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions nemo_rl/models/generation/vllm/vllm_worker_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,9 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:

self.server_thread, self.base_url, self.http_server = None, None, None
if self.cfg["vllm_cfg"].get("expose_http_server"):
# Must run after AsyncLLM.from_engine_args and before
# _setup_vllm_server spawns the uvicorn thread.
self._install_engine_input_socket_lock()
self.server_thread, self.base_url, self.http_server = (
self._setup_vllm_server()
)
Expand All @@ -194,6 +197,24 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None:
if self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False):
self._start_vllm_metrics_logger()

def _install_engine_input_socket_lock(self) -> None:
"""Serialise sends on AsyncMPClient.input_socket across OS threads
to prevent race conditions that block the vLLM engine (e.g. during
in flight weight updates in async grpo).
"""
shadow_sock = self.llm.engine_core.input_socket._shadow_sock

lock = threading.Lock()
original_send_multipart = shadow_sock.send_multipart

def locked_send_multipart(*args: Any, **kwargs: Any) -> Any:
with lock:
return original_send_multipart(*args, **kwargs)

# Replace the bound method on this socket instance only; other zmq
# sockets in the process are unaffected.
shadow_sock.send_multipart = locked_send_multipart # type: ignore[assignment]

def _start_vllm_metrics_logger(self) -> None:
"""Start a background thread that periodically collects vLLM logger metrics.

Expand Down
Loading