Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/unit/vertex_adk/test_agent_engine_templates_adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,106 @@ def test_register_operations(self):
for operation in operations:
assert operation in dir(app)

@mock.patch("os.rename")
@mock.patch("tempfile.NamedTemporaryFile")
@mock.patch("time.time", return_value=1000.0)
@mock.patch("os.getpid", return_value=12345)
@mock.patch("os.path.exists", return_value=True)
@mock.patch("os.path.isdir", return_value=True)
def test__update_keep_alive_timestamp(
    self,
    isdir_mock,
    exists_mock,
    getpid_mock,
    time_mock,
    tempfile_mock,
    rename_mock,
):
    """The lease expiry is written to a temp file and renamed into place."""
    # Imported locally so the expected expiry always tracks the production
    # lease constant instead of a hard-coded copy that can drift from it.
    from vertexai.agent_engines.templates import adk as adk_template

    app = agent_engines.AdkApp(agent=_TEST_AGENT)

    # Stand-in for the object returned by tempfile.NamedTemporaryFile(...).
    mock_tf_ret = mock.MagicMock()
    mock_tf_ret.name = "/dev/shm/tmp_xyz123"
    mock_tf_ret.__enter__.return_value = mock_tf_ret
    # A real file's __exit__ returns a falsy value; a truthy one would
    # silently swallow exceptions raised inside the `with` block.
    mock_tf_ret.__exit__.return_value = False
    tempfile_mock.return_value = mock_tf_ret

    app._update_keep_alive_timestamp()

    pid = 12345
    tempfile_mock.assert_called_once_with(
        "w",
        dir="/dev/shm",
        delete=False,
        prefix=f"tmp_keep_alive_timestamp_{pid}_",
    )
    lease = adk_template._KEEP_ALIVE_LEASE_SECONDS
    expected_timestamp = str(1000.0 + lease)
    mock_tf_ret.write.assert_called_once_with(expected_timestamp)
    rename_mock.assert_called_once_with(
        "/dev/shm/tmp_xyz123", f"/dev/shm/keep_alive_timestamp_{pid}"
    )

@mock.patch("glob.glob")
@mock.patch("os.kill")
@mock.patch("os.remove")
@mock.patch("time.time")
def test_keep_alive_no_files(self, time_mock, remove_mock, kill_mock, glob_mock):
    """With no keep-alive files on disk the app reports that it is idle."""
    time_mock.return_value = 1000.0
    glob_mock.return_value = []

    adk_app = agent_engines.AdkApp(agent=_TEST_AGENT)
    is_busy = adk_app.keep_alive()

    assert not is_busy

@mock.patch("glob.glob")
@mock.patch("os.kill")
@mock.patch("os.remove")
@mock.patch("time.time")
def test_keep_alive_one_file_busy(
    self, time_mock, remove_mock, kill_mock, glob_mock
):
    """A running process whose stored timestamp is in the future means busy."""
    owner_pid = 12345
    time_mock.return_value = 1500.0
    glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{owner_pid}"]

    # The timestamp stored in the file (2000.0) is later than "now" (1500.0).
    fake_open = mock.mock_open(read_data=str(2000.0))
    with mock.patch("builtins.open", fake_open):
        adk_app = agent_engines.AdkApp(agent=_TEST_AGENT)
        is_busy = adk_app.keep_alive()

    assert is_busy
    kill_mock.assert_called_once_with(owner_pid, 0)
    remove_mock.assert_not_called()

@mock.patch("glob.glob")
@mock.patch("os.kill")
@mock.patch("os.remove")
@mock.patch("time.time")
def test_keep_alive_one_file_not_busy(
    self, time_mock, remove_mock, kill_mock, glob_mock
):
    """A running process whose stored timestamp is in the past means idle."""
    owner_pid = 12345
    time_mock.return_value = 2500.0
    glob_mock.return_value = [f"/dev/shm/keep_alive_timestamp_{owner_pid}"]

    # The timestamp stored in the file (2000.0) is earlier than "now" (2500.0).
    fake_open = mock.mock_open(read_data=str(2000.0))
    with mock.patch("builtins.open", fake_open):
        adk_app = agent_engines.AdkApp(agent=_TEST_AGENT)
        is_busy = adk_app.keep_alive()

    assert not is_busy
    # The owning process is still alive, so its file must be left in place.
    kill_mock.assert_called_once_with(owner_pid, 0)
    remove_mock.assert_not_called()

@mock.patch("glob.glob")
@mock.patch("os.kill")
@mock.patch("os.remove")
@mock.patch("time.time")
def test_keep_alive_stale_file_process_dead(
    self, time_mock, remove_mock, kill_mock, glob_mock
):
    """A file whose owning process no longer exists is deleted and ignored."""
    dead_pid = 12345
    orphaned_path = f"/dev/shm/keep_alive_timestamp_{dead_pid}"
    time_mock.return_value = 1000.0
    glob_mock.return_value = [orphaned_path]
    # Probing the pid with signal 0 reports that no such process exists.
    kill_mock.side_effect = ProcessLookupError()

    adk_app = agent_engines.AdkApp(agent=_TEST_AGENT)
    is_busy = adk_app.keep_alive()

    assert not is_busy
    kill_mock.assert_called_once_with(dead_pid, 0)
    remove_mock.assert_called_once_with(orphaned_path)

def test_stream_query(
self,
default_instrumentor_builder_mock: mock.Mock,
Expand Down
79 changes: 79 additions & 0 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@

_DEFAULT_APP_NAME = "default-app-name"
_DEFAULT_USER_ID = "default-user-id"
# Directory that holds the per-process keep-alive files.
_KEEP_ALIVE_DIR = "/dev/shm"
# Final files are named "<prefix>_<pid>"; keep_alive() globs for this prefix
# and parses the pid from the remainder of the basename.
_KEEP_ALIVE_FILENAME_PREFIX = "keep_alive_timestamp"
# Prefix of the temporary file that is written first and then renamed onto the
# final name.
_KEEP_ALIVE_TEMP_FILENAME_PREFIX = "tmp_keep_alive_timestamp"
# Seconds added to the current time to form the timestamp stored in the file.
_KEEP_ALIVE_LEASE_SECONDS = 60  # TODO: Change to 60 * 60 for 1 hour.
_TELEMETRY_API_DISABLED_WARNING = (
"Tracing integration for Agent Engine has migrated to a new API.\n"
"The 'telemetry.googleapis.com' has not been enabled in project %s. \n"
Expand Down Expand Up @@ -992,6 +996,77 @@ def set_up(self):
memory_service=self._tmpl_attrs.get("in_memory_memory_service"),
)

def _update_keep_alive_timestamp(self):
    """Refreshes this process's keep-alive lease.

    Writes the lease expiry time (current time plus
    `_KEEP_ALIVE_LEASE_SECONDS`) to
    `{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_{pid}`, where pid is
    the id of the current process. The value is first written to a temporary
    file in the same directory and then renamed onto the final name, so a
    reader never observes a partially written file.

    This is best effort: if the directory does not exist the method returns
    without doing anything, and any failure is reported through `_warn`
    instead of being raised. A temporary file that could not be renamed
    into place is removed so that failures do not accumulate files.
    """
    import os
    import tempfile
    import time

    # Path of a temp file that has been created but not yet renamed.
    tmp_path = None
    try:
        # isdir() is already False for a missing path, so a separate
        # existence check is not needed.
        if not os.path.isdir(_KEEP_ALIVE_DIR):
            return
        pid = os.getpid()
        timestamp = str(time.time() + _KEEP_ALIVE_LEASE_SECONDS)
        filename = f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_{pid}"
        with tempfile.NamedTemporaryFile(
            "w",
            dir=_KEEP_ALIVE_DIR,
            delete=False,
            prefix=f"{_KEEP_ALIVE_TEMP_FILENAME_PREFIX}_{pid}_",
        ) as fp:
            # Record the name before writing so a failed write is cleaned up.
            tmp_path = fp.name
            fp.write(timestamp)
        os.rename(tmp_path, filename)
        # The rename succeeded; nothing is left to clean up.
        tmp_path = None
    except Exception as e:
        # If there's any issue writing the timestamp, we log a warning
        # and ignore it.
        _warn(f"Failed to update keep-alive timestamp: {e}")
    finally:
        if tmp_path is not None:
            # delete=False means the temp file outlives the `with` block, so
            # it has to be removed explicitly when it was not renamed.
            try:
                os.remove(tmp_path)
            except OSError:
                pass

def keep_alive(self) -> bool:
    """Checks if the agent is busy.

    Scans the `{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_<pid>` files.
    For every file whose pid refers to an existing process, the stored
    timestamp is read and the largest one is kept. Files whose process no
    longer exists, whose name does not contain a valid pid, or whose content
    is not a number are removed.

    A problem with one file is reported through `_warn` and does not stop
    the remaining files from being examined.

    Returns:
        True if the current time is not later than the largest timestamp
        found, False otherwise (including when no usable file exists).
    """
    import glob
    import os
    import time

    try:
        timestamp_files = glob.glob(
            f"{_KEEP_ALIVE_DIR}/{_KEEP_ALIVE_FILENAME_PREFIX}_*"
        )
    except Exception as e:
        _warn(f"Failed to read timestamp files: {e}")
        timestamp_files = []

    # The pid follows "<prefix>_" in the basename.
    pid_offset = len(_KEEP_ALIVE_FILENAME_PREFIX) + 1
    max_timestamp = -1.0
    for timestamp_file in timestamp_files:
        try:
            # Extract PID from filename (e.g., keep_alive_timestamp_1234)
            pid = int(os.path.basename(timestamp_file)[pid_offset:])
            if pid <= 0:
                # os.kill() addresses process groups for pids <= 0, which
                # would not tell us anything about a single process.
                raise ValueError(f"invalid pid {pid}")

            # Check if the process that created the file is still running.
            # Signal 0 performs the existence check without sending a signal.
            try:
                os.kill(pid, 0)
            except PermissionError:
                # The process exists but we may not signal it: still alive.
                pass

            with open(timestamp_file, "r") as f:
                timestamp = float(f.read())
        except (ProcessLookupError, ValueError, FileNotFoundError):
            # If process is dead or file is missing/corrupt, remove the
            # stale file.
            try:
                os.remove(timestamp_file)
            except FileNotFoundError:
                pass
            except OSError as e:
                _warn(f"Failed to remove stale file {timestamp_file}: {e}")
            continue
        except Exception as e:
            # Skip only this file so other leases are still considered.
            _warn(f"Failed to read timestamp file {timestamp_file}: {e}")
            continue

        if timestamp > max_timestamp:
            max_timestamp = timestamp

    return time.time() <= max_timestamp

async def async_stream_query(
self,
*,
Expand Down Expand Up @@ -1035,6 +1110,7 @@ async def async_stream_query(
from vertexai.agent_engines import _utils
from google.genai import types

self._update_keep_alive_timestamp()
if isinstance(message, Dict):
content = types.Content.model_validate(message)
elif isinstance(message, str):
Expand Down Expand Up @@ -1140,6 +1216,7 @@ def stream_query(
from vertexai.agent_engines import _utils
from google.genai import types

self._update_keep_alive_timestamp()
if isinstance(message, Dict):
content = types.Content.model_validate(message)
elif isinstance(message, str):
Expand Down Expand Up @@ -1191,6 +1268,7 @@ async def streaming_agent_run_with_events(self, request_json: str):
from google.genai import types
from google.genai.errors import ClientError

self._update_keep_alive_timestamp()
request = _StreamRunRequest(**json.loads(request_json))
if not any(
self._tmpl_attrs.get(service)
Expand Down Expand Up @@ -1659,6 +1737,7 @@ def register_operations(self) -> Dict[str, List[str]]:
"async_stream_query",
"streaming_agent_run_with_events",
],
"keep_alive": ["keep_alive"],
}

def _telemetry_enabled(self) -> Optional[bool]:
Expand Down
Loading