Skip to content

Commit 2cafa03

Browse files
vertex-sdk-bot authored and copybara-github committed
feat: Add SDK method async_query to support >8 hour processing of queries.
PiperOrigin-RevId: 879636016
1 parent b3bae32 commit 2cafa03

4 files changed

Lines changed: 1191 additions & 829 deletions

File tree

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,6 +2866,56 @@ def test_query_agent_engine(self):
28662866
None,
28672867
)
28682868

2869+
def test_async_query_agent_engine(self):
    """Checks that `async_query` prepares GCS and posts to `:asyncQuery`.

    The HTTP layer, the GCS client and the `_get` lookup are all mocked, so
    no network or storage access happens.
    """
    agent_engines = self.client.agent_engines
    with mock.patch.object(agent_engines._api_client, "request") as request_mock:
        request_mock.return_value = genai_types.HttpResponse(body="")

        storage_patch = mock.patch("google.cloud.storage.Client")
        get_patch = mock.patch.object(agent_engines, "_get")
        with storage_patch as mock_storage_client, get_patch as get_mock:
            # Fake blob/bucket that both report "does not exist", so the
            # code under test takes the create-bucket path and uploads.
            fake_blob = mock.Mock()
            fake_blob.exists.return_value = False
            fake_bucket = mock.Mock()
            fake_bucket.exists.return_value = False
            fake_bucket.blob.return_value = fake_blob
            mock_storage_client.return_value.bucket.return_value = fake_bucket

            # `_get` hands back a minimal resource with an empty spec.
            get_mock.return_value = _genai_types.ReasoningEngine(
                name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
                spec=_genai_types.ReasoningEngineSpec(),
            )

            agent_engines.async_query(
                name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
                config={
                    "query": _TEST_QUERY_PROMPT,
                    "input_gcs_uri": "gs://my-input-bucket/input.json",
                    "output_gcs_uri": "gs://my-output-bucket/output.json",
                },
            )

            # One create() per bucket (input and output).
            assert fake_bucket.create.call_count == 2
            # The query text is what gets written to the input object.
            fake_blob.upload_from_string.assert_called_once_with(_TEST_QUERY_PROMPT)

            # NOTE(review): the `async_query` implementation in this change
            # sets `config.query = None` before calling `_async_query`, so
            # "query" may not actually be part of the posted body -- confirm
            # this expectation against the real request.
            expected_body = {
                "_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME},
                "query": _TEST_QUERY_PROMPT,
                "inputGcsUri": "gs://my-input-bucket/input.json",
                "outputGcsUri": "gs://my-output-bucket/output.json",
            }
            request_mock.assert_called_with(
                "post",
                f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:asyncQuery",
                expected_body,
                None,
            )
2918+
28692919
def test_query_agent_engine_async(self):
28702920
agent = self.client.agent_engines._register_api_methods(
28712921
agent_engine=_genai_types.AgentEngine(

vertexai/_genai/agent_engines.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,44 @@
4949
logger.setLevel(logging.INFO)
5050

5151

52+
def _AsyncQueryAgentEngineConfig_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Copies the async-query config fields onto the parent request object.

    Each field that is not None on `from_object` is written to
    `parent_object` under its camelCase name. Nothing is written to the
    returned dict, which is always empty.

    Args:
        from_object: The config (model or dict) to read fields from.
        parent_object: The request dict that receives the converted fields.

    Returns:
        An empty dict.
    """
    # (attribute on the config, key in the request body)
    field_names = (
        ("query", "query"),
        ("input_gcs_uri", "inputGcsUri"),
        ("output_gcs_uri", "outputGcsUri"),
    )
    for source_key, request_key in field_names:
        value = getv(from_object, [source_key])
        if value is not None:
            setv(parent_object, [request_key], value)

    converted: dict[str, Any] = {}
    return converted
68+
69+
70+
def _AsyncQueryAgentEngineRequestParameters_to_vertex(
    from_object: Union[dict[str, Any], object],
    parent_object: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
    """Builds the request dict for an `asyncQuery` call.

    The resource name goes under `["_url", "name"]`. When a config is
    present, its fields are written directly onto the request dict by
    `_AsyncQueryAgentEngineConfig_to_vertex`, and that converter's return
    value is stored under `"config"`.

    Args:
        from_object: The request parameters (model or dict) to convert.
        parent_object: Not used by this converter.

    Returns:
        The assembled request dict.
    """
    request: dict[str, Any] = {}

    resource_name = getv(from_object, ["name"])
    if resource_name is not None:
        setv(request, ["_url", "name"], resource_name)

    query_config = getv(from_object, ["config"])
    if query_config is not None:
        converted_config = _AsyncQueryAgentEngineConfig_to_vertex(
            query_config, request
        )
        setv(request, ["config"], converted_config)

    return request
88+
89+
5290
def _CreateAgentEngineConfig_to_vertex(
5391
from_object: Union[dict[str, Any], object],
5492
parent_object: Optional[dict[str, Any]] = None,
@@ -337,6 +375,61 @@ def _UpdateAgentEngineRequestParameters_to_vertex(
337375

338376
class AgentEngines(_api_module.BaseModule):
339377

378+
def _async_query(
    self,
    *,
    name: str,
    config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None,
) -> types.AgentEngineOperation:
    """Sends the `asyncQuery` request for an Agent Engine.

    Args:
        name: Resource name substituted into the `{name}:asyncQuery` path.
        config: Optional async-query configuration.

    Returns:
        The `AgentEngineOperation` parsed from the response body.

    Raises:
        ValueError: If the underlying API client is not a Vertex AI client.
    """
    params = types._AsyncQueryAgentEngineRequestParameters(
        name=name,
        config=config,
    )

    if not self._api_client.vertexai:
        raise ValueError("This method is only supported in the Vertex AI client.")

    body = _AsyncQueryAgentEngineRequestParameters_to_vertex(params)

    # Fill the path template from the "_url" entry when there is one.
    url_fields: Optional[dict[str, str]] = body.get("_url")
    path = "{name}:asyncQuery"
    if url_fields:
        path = path.format_map(url_fields)

    query_fields = body.get("_query")
    if query_fields:
        path = f"{path}?{urlencode(query_fields)}"

    # TODO: remove the hack that pops config.
    body.pop("config", None)

    model_config = params.config
    http_options: Optional[types.HttpOptions] = (
        model_config.http_options if model_config is not None else None
    )

    body = _common.encode_unserializable_types(_common.convert_to_dict(body))

    response = self._api_client.request("post", path, body, http_options)

    # An empty body is treated as an empty JSON object.
    payload = json.loads(response.body) if response.body else {}

    operation = types.AgentEngineOperation._from_response(
        response=payload, kwargs=params.model_dump()
    )
    self._api_client._verify_response(operation)
    return operation
432+
340433
def _create(
341434
self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None
342435
) -> types.AgentEngineOperation:
@@ -795,6 +888,96 @@ def _is_lightweight_creation(
795888
return False
796889
return True
797890

891+
def async_query(
    self,
    *,
    name: str,
    config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None,
) -> types.AgentEngineOperation:
    """Queries an agent engine asynchronously.

    The query text (if any) is uploaded to `input_gcs_uri`, the input and
    output buckets are created when they do not exist, and the request is
    then sent with the two GCS URIs only (`query` is not sent inline).

    Args:
        name (str):
            Required. A fully-qualified resource name or ID.
        config (AsyncQueryAgentEngineConfigOrDict):
            Optional. The configuration for the async query. If not provided,
            the default configuration will be used. This can be used to specify
            the following fields:
            - query: The query to send to the agent engine.
            - input_gcs_uri: The GCS URI of the input file to use for the query.
            - output_gcs_uri: The GCS URI of the output file to store the results of the query.
            When a URI is omitted, the value of the `INPUT_GCS_URI` /
            `OUTPUT_GCS_URI` environment variable in the agent engine's
            deployment spec is used instead. The passed-in config object
            is not modified.

    Returns:
        AgentEngineOperation: The operation returned by the `asyncQuery`
        request.

    Raises:
        ValueError: If `input_gcs_uri` or `output_gcs_uri` is neither given
            nor available from the agent engine's deployment spec, if a URI
            has no bucket name, or if `query` is set but `input_gcs_uri`
            does not name an object inside the bucket.
    """
    from google.cloud import storage  # type: ignore[attr-defined]

    def _split_gcs_uri(uri: str) -> tuple[str, str]:
        """Splits `gs://bucket/path/to/obj` into (`bucket`, `path/to/obj`)."""
        scheme = "gs://"
        remainder = uri[len(scheme) :] if uri.startswith(scheme) else uri
        bucket_part, _, object_part = remainder.partition("/")
        if not bucket_part:
            raise ValueError(f"GCS URI `{uri}` does not contain a bucket name.")
        return bucket_part, object_part

    # Always work on a private config object: the URIs are filled in and
    # `query` is cleared below, and the caller's object must stay untouched.
    if config is None:
        config = types.AsyncQueryAgentEngineConfig()
    elif isinstance(config, dict):
        config = types.AsyncQueryAgentEngineConfig(**config)
    else:
        config = config.model_copy()

    api_resource = self._get(name=name)

    # Collect default GCS URIs from the deployment spec's env, if present.
    env_defaults: dict[str, Any] = {}
    spec = api_resource.spec
    deployment_spec = spec.deployment_spec if spec else None
    env_vars = (deployment_spec.env if deployment_spec else None) or []
    for env_var in env_vars:
        if env_var.name in ("INPUT_GCS_URI", "OUTPUT_GCS_URI"):
            env_defaults[env_var.name] = env_var.value

    # Resolve and validate both URIs before touching GCS, so that a bad
    # configuration does not leave a half-prepared state behind.
    input_gcs_uri = config.input_gcs_uri or env_defaults.get("INPUT_GCS_URI")
    if not input_gcs_uri:
        raise ValueError(
            "Could not determine a default GCS bucket for `input_gcs_uri` from the agent engine configuration. Please specify `input_gcs_uri`."
        )
    output_gcs_uri = config.output_gcs_uri or env_defaults.get("OUTPUT_GCS_URI")
    if not output_gcs_uri:
        raise ValueError(
            "Could not determine a default GCS bucket for `output_gcs_uri` from the agent engine configuration. Please specify `output_gcs_uri`."
        )

    input_bucket_name, input_blob_name = _split_gcs_uri(input_gcs_uri)
    output_bucket_name, _ = _split_gcs_uri(output_gcs_uri)
    if config.query and not input_blob_name:
        raise ValueError(
            f"`input_gcs_uri` must include an object path to upload the query to, got `{input_gcs_uri}`."
        )

    storage_client = storage.Client()

    # Input side: make sure the bucket exists, then upload the query text.
    input_bucket = storage_client.bucket(input_bucket_name)
    if not input_bucket.exists():
        input_bucket.create()

    if config.query:
        blob = input_bucket.blob(input_blob_name)
        if blob.exists():
            logger.warning("Overwriting existing file at %s", input_gcs_uri)
        blob.upload_from_string(config.query)

    # Output side: only the bucket has to exist.
    output_bucket = storage_client.bucket(output_bucket_name)
    if not output_bucket.exists():
        output_bucket.create()

    config.input_gcs_uri = input_gcs_uri
    config.output_gcs_uri = output_gcs_uri
    # Set query to None before it goes back to the server
    config.query = None

    # Proceed with sending the async query via the auto-generated method
    return self._async_query(name=name, config=config)
980+
798981
def get(
799982
self,
800983
*,
@@ -2055,6 +2238,63 @@ def list_session_events(
20552238

20562239
class AsyncAgentEngines(_api_module.BaseModule):
20572240

2241+
async def _async_query(
    self,
    *,
    name: str,
    config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None,
) -> types.AgentEngineOperation:
    """Sends the `asyncQuery` request for an Agent Engine (awaitable).

    Args:
        name: Resource name substituted into the `{name}:asyncQuery` path.
        config: Optional async-query configuration.

    Returns:
        The `AgentEngineOperation` parsed from the response body.

    Raises:
        ValueError: If the underlying API client is not a Vertex AI client.
    """
    params = types._AsyncQueryAgentEngineRequestParameters(
        name=name,
        config=config,
    )

    if not self._api_client.vertexai:
        raise ValueError("This method is only supported in the Vertex AI client.")

    body = _AsyncQueryAgentEngineRequestParameters_to_vertex(params)

    # Fill the path template from the "_url" entry when there is one.
    url_fields: Optional[dict[str, str]] = body.get("_url")
    path = "{name}:asyncQuery"
    if url_fields:
        path = path.format_map(url_fields)

    query_fields = body.get("_query")
    if query_fields:
        path = f"{path}?{urlencode(query_fields)}"

    # TODO: remove the hack that pops config.
    body.pop("config", None)

    model_config = params.config
    http_options: Optional[types.HttpOptions] = (
        model_config.http_options if model_config is not None else None
    )

    body = _common.encode_unserializable_types(_common.convert_to_dict(body))

    response = await self._api_client.async_request(
        "post", path, body, http_options
    )

    # An empty body is treated as an empty JSON object.
    payload = json.loads(response.body) if response.body else {}

    operation = types.AgentEngineOperation._from_response(
        response=payload, kwargs=params.model_dump()
    )
    self._api_client._verify_response(operation)
    return operation
2297+
20582298
async def _create(
20592299
self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None
20602300
) -> types.AgentEngineOperation:

0 commit comments

Comments
 (0)