|
49 | 49 | logger.setLevel(logging.INFO) |
50 | 50 |
|
51 | 51 |
|
| 52 | +def _AsyncQueryAgentEngineConfig_to_vertex( |
| 53 | + from_object: Union[dict[str, Any], object], |
| 54 | + parent_object: Optional[dict[str, Any]] = None, |
| 55 | +) -> dict[str, Any]: |
| 56 | + to_object: dict[str, Any] = {} |
| 57 | + |
| 58 | + if getv(from_object, ["query"]) is not None: |
| 59 | + setv(parent_object, ["query"], getv(from_object, ["query"])) |
| 60 | + |
| 61 | + if getv(from_object, ["input_gcs_uri"]) is not None: |
| 62 | + setv(parent_object, ["inputGcsUri"], getv(from_object, ["input_gcs_uri"])) |
| 63 | + |
| 64 | + if getv(from_object, ["output_gcs_uri"]) is not None: |
| 65 | + setv(parent_object, ["outputGcsUri"], getv(from_object, ["output_gcs_uri"])) |
| 66 | + |
| 67 | + return to_object |
| 68 | + |
| 69 | + |
| 70 | +def _AsyncQueryAgentEngineRequestParameters_to_vertex( |
| 71 | + from_object: Union[dict[str, Any], object], |
| 72 | + parent_object: Optional[dict[str, Any]] = None, |
| 73 | +) -> dict[str, Any]: |
| 74 | + to_object: dict[str, Any] = {} |
| 75 | + if getv(from_object, ["name"]) is not None: |
| 76 | + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) |
| 77 | + |
| 78 | + if getv(from_object, ["config"]) is not None: |
| 79 | + setv( |
| 80 | + to_object, |
| 81 | + ["config"], |
| 82 | + _AsyncQueryAgentEngineConfig_to_vertex( |
| 83 | + getv(from_object, ["config"]), to_object |
| 84 | + ), |
| 85 | + ) |
| 86 | + |
| 87 | + return to_object |
| 88 | + |
| 89 | + |
52 | 90 | def _CreateAgentEngineConfig_to_vertex( |
53 | 91 | from_object: Union[dict[str, Any], object], |
54 | 92 | parent_object: Optional[dict[str, Any]] = None, |
@@ -337,6 +375,61 @@ def _UpdateAgentEngineRequestParameters_to_vertex( |
337 | 375 |
|
338 | 376 | class AgentEngines(_api_module.BaseModule): |
339 | 377 |
|
| 378 | + def _async_query( |
| 379 | + self, |
| 380 | + *, |
| 381 | + name: str, |
| 382 | + config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None, |
| 383 | + ) -> types.AgentEngineOperation: |
| 384 | + """ |
| 385 | + Query an Agent Engine asynchronously. |
| 386 | + """ |
| 387 | + |
| 388 | + parameter_model = types._AsyncQueryAgentEngineRequestParameters( |
| 389 | + name=name, |
| 390 | + config=config, |
| 391 | + ) |
| 392 | + |
| 393 | + request_url_dict: Optional[dict[str, str]] |
| 394 | + if not self._api_client.vertexai: |
| 395 | + raise ValueError("This method is only supported in the Vertex AI client.") |
| 396 | + else: |
| 397 | + request_dict = _AsyncQueryAgentEngineRequestParameters_to_vertex( |
| 398 | + parameter_model |
| 399 | + ) |
| 400 | + request_url_dict = request_dict.get("_url") |
| 401 | + if request_url_dict: |
| 402 | + path = "{name}:asyncQuery".format_map(request_url_dict) |
| 403 | + else: |
| 404 | + path = "{name}:asyncQuery" |
| 405 | + |
| 406 | + query_params = request_dict.get("_query") |
| 407 | + if query_params: |
| 408 | + path = f"{path}?{urlencode(query_params)}" |
| 409 | + # TODO: remove the hack that pops config. |
| 410 | + request_dict.pop("config", None) |
| 411 | + |
| 412 | + http_options: Optional[types.HttpOptions] = None |
| 413 | + if ( |
| 414 | + parameter_model.config is not None |
| 415 | + and parameter_model.config.http_options is not None |
| 416 | + ): |
| 417 | + http_options = parameter_model.config.http_options |
| 418 | + |
| 419 | + request_dict = _common.convert_to_dict(request_dict) |
| 420 | + request_dict = _common.encode_unserializable_types(request_dict) |
| 421 | + |
| 422 | + response = self._api_client.request("post", path, request_dict, http_options) |
| 423 | + |
| 424 | + response_dict = {} if not response.body else json.loads(response.body) |
| 425 | + |
| 426 | + return_value = types.AgentEngineOperation._from_response( |
| 427 | + response=response_dict, kwargs=parameter_model.model_dump() |
| 428 | + ) |
| 429 | + |
| 430 | + self._api_client._verify_response(return_value) |
| 431 | + return return_value |
| 432 | + |
340 | 433 | def _create( |
341 | 434 | self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None |
342 | 435 | ) -> types.AgentEngineOperation: |
@@ -795,6 +888,96 @@ def _is_lightweight_creation( |
795 | 888 | return False |
796 | 889 | return True |
797 | 890 |
|
| 891 | + def async_query( |
| 892 | + self, |
| 893 | + *, |
| 894 | + name: str, |
| 895 | + config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None, |
| 896 | + ) -> types.AgentEngineOperation: |
| 897 | + """Queries an agent engine asynchronously. |
| 898 | +
|
| 899 | + Args: |
| 900 | + name (str): |
| 901 | + Required. A fully-qualified resource name or ID. |
| 902 | + config (AsyncQueryAgentEngineConfigOrDict): |
| 903 | + Optional. The configuration for the async query. If not provided, |
| 904 | + the default configuration will be used. This can be used to specify |
| 905 | + the following fields: |
| 906 | + - query: The query to send to the agent engine. |
| 907 | + - input_gcs_uri: The GCS URI of the input file to use for the query. |
| 908 | + - output_gcs_uri: The GCS URI of the output file to store the results of the query. |
| 909 | + """ |
| 910 | + from google.cloud import storage # type: ignore[attr-defined] |
| 911 | + |
| 912 | + if config is None: |
| 913 | + config = types.AsyncQueryAgentEngineConfig() |
| 914 | + elif isinstance(config, dict): |
| 915 | + config = types.AsyncQueryAgentEngineConfig(**config) |
| 916 | + |
| 917 | + api_resource = self._get(name=name) |
| 918 | + |
| 919 | + # Extract default GCS URIs from ReasoningEngine deployment spec env if needed |
| 920 | + default_input_gcs_uri = None |
| 921 | + default_output_gcs_uri = None |
| 922 | + |
| 923 | + if ( |
| 924 | + api_resource.spec |
| 925 | + and api_resource.spec.deployment_spec |
| 926 | + and api_resource.spec.deployment_spec.env |
| 927 | + ): |
| 928 | + for env_var in api_resource.spec.deployment_spec.env: |
| 929 | + if env_var.name == "INPUT_GCS_URI": |
| 930 | + default_input_gcs_uri = env_var.value |
| 931 | + elif env_var.name == "OUTPUT_GCS_URI": |
| 932 | + default_output_gcs_uri = env_var.value |
| 933 | + |
| 934 | + storage_client = storage.Client() |
| 935 | + |
| 936 | + # Set up input_gcs_uri |
| 937 | + input_gcs_uri = config.input_gcs_uri |
| 938 | + if not input_gcs_uri: |
| 939 | + if not default_input_gcs_uri: |
| 940 | + raise ValueError( |
| 941 | + "Could not determine a default GCS bucket for `input_gcs_uri` from the agent engine configuration. Please specify `input_gcs_uri`." |
| 942 | + ) |
| 943 | + input_gcs_uri = default_input_gcs_uri |
| 944 | + config.input_gcs_uri = input_gcs_uri |
| 945 | + |
| 946 | + # Handle creating the bucket if it does not exist |
| 947 | + bucket_name = input_gcs_uri.replace("gs://", "").split("/")[0] |
| 948 | + bucket = storage_client.bucket(bucket_name) |
| 949 | + |
| 950 | + if not bucket.exists(): |
| 951 | + bucket.create() |
| 952 | + |
| 953 | + if config.query: |
| 954 | + blob_name = input_gcs_uri.replace(f"gs://{bucket_name}/", "") |
| 955 | + blob = bucket.blob(blob_name) |
| 956 | + if blob.exists(): |
| 957 | + logger.warning(f"Overwriting existing file at {input_gcs_uri}") |
| 958 | + blob.upload_from_string(config.query) |
| 959 | + |
| 960 | + # Set up output_gcs_uri |
| 961 | + output_gcs_uri = config.output_gcs_uri |
| 962 | + if not output_gcs_uri: |
| 963 | + if not default_output_gcs_uri: |
| 964 | + raise ValueError( |
| 965 | + "Could not determine a default GCS bucket for `output_gcs_uri` from the agent engine configuration. Please specify `output_gcs_uri`." |
| 966 | + ) |
| 967 | + output_gcs_uri = default_output_gcs_uri |
| 968 | + config.output_gcs_uri = output_gcs_uri |
| 969 | + |
| 970 | + output_bucket_name = output_gcs_uri.replace("gs://", "").split("/")[0] |
| 971 | + output_bucket = storage_client.bucket(output_bucket_name) |
| 972 | + if not output_bucket.exists(): |
| 973 | + output_bucket.create() |
| 974 | + |
| 975 | + # Set query to None before it goes back to the server |
| 976 | + config.query = None |
| 977 | + |
| 978 | + # Proceed with sending the async query via the auto-generated method |
| 979 | + return self._async_query(name=name, config=config) |
| 980 | + |
798 | 981 | def get( |
799 | 982 | self, |
800 | 983 | *, |
@@ -2055,6 +2238,63 @@ def list_session_events( |
2055 | 2238 |
|
2056 | 2239 | class AsyncAgentEngines(_api_module.BaseModule): |
2057 | 2240 |
|
| 2241 | + async def _async_query( |
| 2242 | + self, |
| 2243 | + *, |
| 2244 | + name: str, |
| 2245 | + config: Optional[types.AsyncQueryAgentEngineConfigOrDict] = None, |
| 2246 | + ) -> types.AgentEngineOperation: |
| 2247 | + """ |
| 2248 | + Query an Agent Engine asynchronously. |
| 2249 | + """ |
| 2250 | + |
| 2251 | + parameter_model = types._AsyncQueryAgentEngineRequestParameters( |
| 2252 | + name=name, |
| 2253 | + config=config, |
| 2254 | + ) |
| 2255 | + |
| 2256 | + request_url_dict: Optional[dict[str, str]] |
| 2257 | + if not self._api_client.vertexai: |
| 2258 | + raise ValueError("This method is only supported in the Vertex AI client.") |
| 2259 | + else: |
| 2260 | + request_dict = _AsyncQueryAgentEngineRequestParameters_to_vertex( |
| 2261 | + parameter_model |
| 2262 | + ) |
| 2263 | + request_url_dict = request_dict.get("_url") |
| 2264 | + if request_url_dict: |
| 2265 | + path = "{name}:asyncQuery".format_map(request_url_dict) |
| 2266 | + else: |
| 2267 | + path = "{name}:asyncQuery" |
| 2268 | + |
| 2269 | + query_params = request_dict.get("_query") |
| 2270 | + if query_params: |
| 2271 | + path = f"{path}?{urlencode(query_params)}" |
| 2272 | + # TODO: remove the hack that pops config. |
| 2273 | + request_dict.pop("config", None) |
| 2274 | + |
| 2275 | + http_options: Optional[types.HttpOptions] = None |
| 2276 | + if ( |
| 2277 | + parameter_model.config is not None |
| 2278 | + and parameter_model.config.http_options is not None |
| 2279 | + ): |
| 2280 | + http_options = parameter_model.config.http_options |
| 2281 | + |
| 2282 | + request_dict = _common.convert_to_dict(request_dict) |
| 2283 | + request_dict = _common.encode_unserializable_types(request_dict) |
| 2284 | + |
| 2285 | + response = await self._api_client.async_request( |
| 2286 | + "post", path, request_dict, http_options |
| 2287 | + ) |
| 2288 | + |
| 2289 | + response_dict = {} if not response.body else json.loads(response.body) |
| 2290 | + |
| 2291 | + return_value = types.AgentEngineOperation._from_response( |
| 2292 | + response=response_dict, kwargs=parameter_model.model_dump() |
| 2293 | + ) |
| 2294 | + |
| 2295 | + self._api_client._verify_response(return_value) |
| 2296 | + return return_value |
| 2297 | + |
2058 | 2298 | async def _create( |
2059 | 2299 | self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None |
2060 | 2300 | ) -> types.AgentEngineOperation: |
|
0 commit comments