diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py index 7fcd6cd0db..cc3a02f79d 100644 --- a/nemo_rl/models/generation/vllm/vllm_worker_async.py +++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py @@ -184,6 +184,9 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None: self.server_thread, self.base_url, self.http_server = None, None, None if self.cfg["vllm_cfg"].get("expose_http_server"): + # Must run after AsyncLLM.from_engine_args and before + # _setup_vllm_server spawns the uvicorn thread. + self._install_engine_input_socket_lock() self.server_thread, self.base_url, self.http_server = ( self._setup_vllm_server() ) @@ -194,6 +197,24 @@ def _create_engine(self, llm_kwargs: dict[str, Any]) -> None: if self.cfg["vllm_cfg"].get("enable_vllm_metrics_logger", False): self._start_vllm_metrics_logger() + def _install_engine_input_socket_lock(self) -> None: + """Serialise sends on AsyncMPClient.input_socket across OS threads + to prevent race conditions that block the vLLM engine (e.g. during + in flight weight updates in async grpo). + """ + shadow_sock = self.llm.engine_core.input_socket._shadow_sock + + lock = threading.Lock() + original_send_multipart = shadow_sock.send_multipart + + def locked_send_multipart(*args: Any, **kwargs: Any) -> Any: + with lock: + return original_send_multipart(*args, **kwargs) + + # Replace the bound method on this socket instance only; other zmq + # sockets in the process are unaffected. + shadow_sock.send_multipart = locked_send_multipart # type: ignore[assignment] + def _start_vllm_metrics_logger(self) -> None: """Start a background thread that periodically collects vLLM logger metrics.