diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
index 86840722f0..5c6b5991c3 100644
--- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
+++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
@@ -368,6 +368,116 @@ def test_register_operations(self):
         for operation in operations:
             assert operation in dir(app)
 
+    @mock.patch("os.rename")
+    @mock.patch("tempfile.NamedTemporaryFile")
+    @mock.patch("time.time", return_value=1000.0)
+    @mock.patch("os.getpid", return_value=12345)
+    @mock.patch("os.path.exists", return_value=True)
+    @mock.patch("os.path.isdir", return_value=True)
+    def test__update_keep_alive_timestamp(
+        self,
+        isdir_mock,
+        exists_mock,
+        getpid_mock,
+        time_mock,
+        tempfile_mock,
+        rename_mock,
+    ):
+        """Writes the lease expiry to a temp file and renames it into place."""
+        # Imported locally so the expected value tracks the template's lease
+        # constant instead of duplicating its value here.
+        from vertexai.agent_engines.templates import adk as adk_template
+
+        pid = 12345
+        app = agent_engines.AdkApp(agent=_TEST_AGENT)
+
+        mock_tf_ret = mock.MagicMock()  # this is result of NamedTemporaryFile call
+        mock_tf_ret.name = "/dev/shm/tmp_xyz123"
+        mock_tf_ret.__enter__.return_value = mock_tf_ret
+        # A truthy __exit__ result suppresses exceptions raised inside the
+        # `with` body, which would hide failures; return False instead.
+        mock_tf_ret.__exit__.return_value = False
+        tempfile_mock.return_value = mock_tf_ret
+
+        app._update_keep_alive_timestamp()
+
+        tempfile_mock.assert_called_once_with(
+            "w",
+            dir="/dev/shm",
+            delete=False,
+            prefix=f"tmp_keep_alive_timestamp_{pid}_",
+        )
+        expected_timestamp = str(1000.0 + adk_template._KEEP_ALIVE_LEASE_SECONDS)
+        mock_tf_ret.write.assert_called_once_with(expected_timestamp)
+        rename_mock.assert_called_once_with(
+            "/dev/shm/tmp_xyz123", f"/dev/shm/keep_alive_timestamp_{pid}"
+        )
+
+    @mock.patch("glob.glob")
+    @mock.patch("os.kill")
+    @mock.patch("os.remove")
+    @mock.patch("time.time")
+    def test_keep_alive_no_files(self, time_mock, remove_mock, kill_mock, glob_mock):
+        """Reports not busy when there are no keep-alive files."""
+        glob_mock.return_value = []
+        time_mock.return_value = 1000.0
+        app = agent_engines.AdkApp(agent=_TEST_AGENT)
+        assert not app.keep_alive()
+
+    @mock.patch("glob.glob")
+    @mock.patch("os.kill")
+    @mock.patch("os.remove")
+    @mock.patch("time.time")
+    def test_keep_alive_one_file_busy(
+        self, time_mock, remove_mock, kill_mock, glob_mock
+    ):
+        """Reports busy while a live process's lease is still in the future."""
+        pid = 12345
+        glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{pid}"]
+        time_mock.return_value = 1500.0
+        mock_read_data = str(2000.0)  # Timestamp in file is in future
+        with mock.patch("builtins.open", mock.mock_open(read_data=mock_read_data)):
+            app = agent_engines.AdkApp(agent=_TEST_AGENT)
+            assert app.keep_alive()
+        kill_mock.assert_called_once_with(pid, 0)
+        remove_mock.assert_not_called()
+
+    @mock.patch("glob.glob")
+    @mock.patch("os.kill")
+    @mock.patch("os.remove")
+    @mock.patch("time.time")
+    def test_keep_alive_one_file_not_busy(
+        self, time_mock, remove_mock, kill_mock, glob_mock
+    ):
+        """Reports not busy once a live process's lease has expired."""
+        pid = 12345
+        glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{pid}"]
+        time_mock.return_value = 2500.0
+        mock_read_data = str(2000.0)  # Timestamp in file is in past
+        with mock.patch("builtins.open", mock.mock_open(read_data=mock_read_data)):
+            app = agent_engines.AdkApp(agent=_TEST_AGENT)
+            assert not app.keep_alive()
+        kill_mock.assert_called_once_with(pid, 0)
+        remove_mock.assert_not_called()
+
+    @mock.patch("glob.glob")
+    @mock.patch("os.kill")
+    @mock.patch("os.remove")
+    @mock.patch("time.time")
+    def test_keep_alive_stale_file_process_dead(
+        self, time_mock, remove_mock, kill_mock, glob_mock
+    ):
+        """Removes the file of a dead process and reports not busy."""
+        pid = 12345
+        stale_file = f"/dev/shm/keep_alive_timestamp_{pid}"
+        glob_mock.return_value = [stale_file]
+        kill_mock.side_effect = ProcessLookupError()
+        time_mock.return_value = 1000.0
+        app = agent_engines.AdkApp(agent=_TEST_AGENT)
+        assert not app.keep_alive()
+        kill_mock.assert_called_once_with(pid, 0)
+        remove_mock.assert_called_once_with(stale_file)
+
     def test_stream_query(
         self,
         default_instrumentor_builder_mock: mock.Mock,
diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py
index f52b46f1ea..635f25838f 100644
--- a/vertexai/agent_engines/templates/adk.py
+++ b/vertexai/agent_engines/templates/adk.py
@@ 
-95,6 +95,12 @@
 
 _DEFAULT_APP_NAME = "default-app-name"
 _DEFAULT_USER_ID = "default-user-id"
+# Directory holding the per-process keep-alive files.
+_KEEP_ALIVE_DIR = "/dev/shm"
+_KEEP_ALIVE_FILENAME_PREFIX = "keep_alive_timestamp"
+_KEEP_ALIVE_TEMP_FILENAME_PREFIX = "tmp_keep_alive_timestamp"
+# Seconds after a request starts during which `keep_alive` reports busy.
+_KEEP_ALIVE_LEASE_SECONDS = 60 * 60
 _TELEMETRY_API_DISABLED_WARNING = (
     "Tracing integration for Agent Engine has migrated to a new API.\n"
     "The 'telemetry.googleapis.com' has not been enabled in project %s. \n"
@@ -992,6 +998,110 @@ def set_up(self):
             memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
         )
 
+    def _update_keep_alive_timestamp(self):
+        """Refreshes this process's keep-alive lease.
+
+        Writes the lease expiry time (the current time plus
+        `_KEEP_ALIVE_LEASE_SECONDS`) to
+        `{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_{pid}`, where pid is
+        the id of the current process. The value is written to a temporary
+        file in the same directory and then renamed into place, so readers
+        never observe a partially written file.
+
+        This is best-effort: if the directory is missing nothing is written,
+        and any failure is reported as a warning rather than raised.
+        """
+        import os
+        import tempfile
+        import time
+
+        tmp_path = None
+        try:
+            if not os.path.isdir(_KEEP_ALIVE_DIR):
+                return
+            pid = os.getpid()
+            timestamp = str(time.time() + _KEEP_ALIVE_LEASE_SECONDS)
+            filename = f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_{pid}"
+            with tempfile.NamedTemporaryFile(
+                "w",
+                dir=_KEEP_ALIVE_DIR,
+                delete=False,
+                prefix=f"{_KEEP_ALIVE_TEMP_FILENAME_PREFIX}_{pid}_",
+            ) as fp:
+                tmp_path = fp.name
+                fp.write(timestamp)
+            os.rename(tmp_path, filename)
+            # The temporary file no longer exists under its old name.
+            tmp_path = None
+        except Exception as e:
+            # If there's any issue writing the timestamp, we log a warning
+            # and ignore it.
+            _warn(f"Failed to update keep-alive timestamp: {e}")
+        finally:
+            if tmp_path is not None:
+                # The write or rename failed; because the file was created
+                # with delete=False it would otherwise be left behind.
+                try:
+                    os.remove(tmp_path)
+                except OSError:
+                    pass
+
+    def keep_alive(self) -> bool:
+        """Reports whether any live agent process holds an unexpired lease.
+
+        Scans `{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_*`. Files whose
+        process no longer exists, or whose name or content cannot be parsed,
+        are removed. A file that fails for any other reason is skipped with a
+        warning so that it cannot hide the leases of other processes.
+
+        Returns:
+            True if the latest lease expiry found is not in the past, False
+            otherwise (including when there are no lease files).
+        """
+        import glob
+        import os
+        import time
+
+        max_timestamp = -1.0
+        try:
+            timestamp_files = glob.glob(
+                f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_*"
+            )
+        except Exception as e:
+            _warn(f"Failed to read timestamp files: {e}")
+            timestamp_files = []
+
+        for timestamp_file in timestamp_files:
+            try:
+                # Extract PID from filename (e.g., keep_alive_timestamp_1234)
+                basename = os.path.basename(timestamp_file)
+                pid = int(basename[len(_KEEP_ALIVE_FILENAME_PREFIX) + 1 :])
+                if pid <= 0:
+                    # os.kill treats non-positive ids as process groups (or
+                    # every process), never as a single process.
+                    raise ValueError(f"Invalid pid in {basename}")
+
+                # Check if the process that created the file is still running
+                try:
+                    os.kill(pid, 0)
+                except PermissionError:
+                    # The process exists but belongs to another user.
+                    pass
+
+                with open(timestamp_file, "r") as f:
+                    timestamp = float(f.read())
+                max_timestamp = max(max_timestamp, timestamp)
+            except (ProcessLookupError, ValueError, FileNotFoundError):
+                # Process is dead or the file is missing/corrupt: remove it.
+                try:
+                    os.remove(timestamp_file)
+                except OSError:
+                    pass
+            except Exception as e:
+                _warn(f"Failed to read timestamp file {timestamp_file}: {e}")
+
+        return time.time() <= max_timestamp
+
     async def async_stream_query(
         self,
         *,
@@ -1035,6 +1145,7 @@ async def async_stream_query(
         from vertexai.agent_engines import _utils
         from google.genai import types
 
+        self._update_keep_alive_timestamp()
         if isinstance(message, Dict):
             content = types.Content.model_validate(message)
         elif isinstance(message, str):
@@ -1140,6 +1251,7 @@ def stream_query(
         from vertexai.agent_engines import _utils
         from google.genai import types
 
+        self._update_keep_alive_timestamp()
         if isinstance(message, Dict):
             content = types.Content.model_validate(message)
         elif isinstance(message, str):
@@ -1191,6 +1303,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
         from google.genai import types
         from google.genai.errors import ClientError
 
+        self._update_keep_alive_timestamp()
         request = _StreamRunRequest(**json.loads(request_json))
         if not any(
             self._tmpl_attrs.get(service)
@@ -1659,6 +1772,7 @@ def register_operations(self) -> Dict[str, List[str]]:
                 "async_stream_query",
                 "streaming_agent_run_with_events",
             ],
+            "keep_alive": ["keep_alive"],
         }
 
     def _telemetry_enabled(self) -> Optional[bool]: