1616import importlib
1717import json
1818import os
19+ import cloudpickle
20+ import sys
1921from unittest import mock
2022from typing import Optional
21- import dataclasses
2223
2324from google import auth
25+ from google .auth import credentials as auth_credentials
26+ from google .cloud import storage
2427import vertexai
28+ from google .cloud import aiplatform
29+ from google .cloud .aiplatform_v1 import types as aip_types
30+ from google .cloud .aiplatform_v1 .services import reasoning_engine_service
31+ from google .cloud .aiplatform import base
2532from google .cloud .aiplatform import initializer
2633from vertexai .agent_engines import _utils
2734from vertexai import agent_engines
35+ from vertexai .agent_engines .templates import adk as adk_template
36+ from vertexai .agent_engines import _agent_engines
37+ from google .api_core import operation as ga_operation
2838from google .genai import types
2939import pytest
3040import uuid
@@ -76,6 +86,52 @@ def __init__(self, name: str, model: str):
7686 "streaming_mode" : "sse" ,
7787 "max_llm_calls" : 500 ,
7888}
89+ _TEST_STAGING_BUCKET = "gs://test-bucket"
90+ _TEST_CREDENTIALS = mock .Mock (spec = auth_credentials .AnonymousCredentials ())
91+ _TEST_PARENT = f"projects/{ _TEST_PROJECT } /locations/{ _TEST_LOCATION } "
92+ _TEST_RESOURCE_ID = "1028944691210842416"
93+ _TEST_AGENT_ENGINE_RESOURCE_NAME = (
94+ f"{ _TEST_PARENT } /reasoningEngines/{ _TEST_RESOURCE_ID } "
95+ )
96+ _TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name"
97+ _TEST_GCS_DIR_NAME = _agent_engines ._DEFAULT_GCS_DIR_NAME
98+ _TEST_BLOB_FILENAME = _agent_engines ._BLOB_FILENAME
99+ _TEST_REQUIREMENTS_FILE = _agent_engines ._REQUIREMENTS_FILE
100+ _TEST_EXTRA_PACKAGES_FILE = _agent_engines ._EXTRA_PACKAGES_FILE
101+ _TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}" .format (
102+ _TEST_STAGING_BUCKET ,
103+ _TEST_GCS_DIR_NAME ,
104+ _TEST_BLOB_FILENAME ,
105+ )
106+ _TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}" .format (
107+ _TEST_STAGING_BUCKET ,
108+ _TEST_GCS_DIR_NAME ,
109+ _TEST_EXTRA_PACKAGES_FILE ,
110+ )
111+ _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}" .format (
112+ _TEST_STAGING_BUCKET ,
113+ _TEST_GCS_DIR_NAME ,
114+ _TEST_REQUIREMENTS_FILE ,
115+ )
116+ _TEST_AGENT_ENGINE_PACKAGE_SPEC = aip_types .ReasoningEngineSpec .PackageSpec (
117+ python_version = f"{ sys .version_info .major } .{ sys .version_info .minor } " ,
118+ pickle_object_gcs_uri = _TEST_AGENT_ENGINE_GCS_URI ,
119+ dependency_files_gcs_uri = _TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI ,
120+ requirements_gcs_uri = _TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI ,
121+ )
122+ _ADK_AGENT_FRAMEWORK = adk_template .AdkApp .agent_framework
123+ _TEST_AGENT_ENGINE_OBJ = aip_types .ReasoningEngine (
124+ name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
125+ display_name = _TEST_AGENT_ENGINE_DISPLAY_NAME ,
126+ spec = aip_types .ReasoningEngineSpec (
127+ package_spec = _TEST_AGENT_ENGINE_PACKAGE_SPEC ,
128+ agent_framework = _ADK_AGENT_FRAMEWORK ,
129+ ),
130+ )
131+
132+ GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
133+ "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
134+ )
79135
80136
81137@pytest .fixture (scope = "module" )
@@ -97,27 +153,11 @@ def vertexai_init_mock():
97153
98154
99155@pytest .fixture
100- def cloud_trace_exporter_mock ():
101- import sys
102- import opentelemetry
103-
104- mock_cloud_trace_exporter = mock .Mock ()
105-
106- opentelemetry .exporter = type (sys )("exporter" )
107- opentelemetry .exporter .cloud_trace = type (sys )("cloud_trace" )
108- opentelemetry .exporter .cloud_trace .CloudTraceSpanExporter = (
109- mock_cloud_trace_exporter
110- )
111-
112- sys .modules ["opentelemetry.exporter" ] = opentelemetry .exporter
113- sys .modules ["opentelemetry.exporter.cloud_trace" ] = (
114- opentelemetry .exporter .cloud_trace
115- )
116-
117- yield mock_cloud_trace_exporter
118-
119- del sys .modules ["opentelemetry.exporter.cloud_trace" ]
120- del sys .modules ["opentelemetry.exporter" ]
156+ def otlp_span_exporter_mock ():
157+ with mock .patch (
158+ "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter"
159+ ) as otlp_span_exporter_mock :
160+ yield otlp_span_exporter_mock
121161
122162
123163@pytest .fixture
@@ -609,9 +649,9 @@ def test_custom_instrumentor_enablement(
609649 )
610650 def test_tracing_setup (
611651 self ,
612- trace_provider_mock : mock .Mock ,
613- cloud_trace_exporter_mock : mock .Mock ,
614652 monkeypatch ,
653+ trace_provider_mock : mock .Mock ,
654+ otlp_span_exporter_mock : mock .Mock ,
615655 ):
616656 monkeypatch .setattr (
617657 "uuid.uuid4" , lambda : uuid .UUID ("12345678123456781234567812345678" )
@@ -633,17 +673,9 @@ def test_tracing_setup(
633673 "some-attribute" : "some-value" ,
634674 }
635675
636- @dataclasses .dataclass
637- class RegexMatchingAll :
638- keys : set [str ]
639-
640- def __eq__ (self , regex : object ) -> bool :
641- return isinstance (regex , str ) and set (regex .split ("|" )) == self .keys
642-
643- cloud_trace_exporter_mock .assert_called_once_with (
644- project_id = _TEST_PROJECT ,
645- client = mock .ANY ,
646- resource_regex = RegexMatchingAll (keys = set (expected_attributes .keys ())),
676+ otlp_span_exporter_mock .assert_called_once_with (
677+ session = mock .ANY ,
678+ endpoint = "https://telemetry.googleapis.com/v1/traces" ,
647679 )
648680
649681 assert (
@@ -655,7 +687,6 @@ def __eq__(self, regex: object) -> bool:
655687 def test_enable_tracing (
656688 self ,
657689 caplog ,
658- cloud_trace_exporter_mock ,
659690 tracer_provider_mock ,
660691 simple_span_processor_mock ,
661692 ):
@@ -752,3 +783,174 @@ async def test_async_stream_query_invalid_message_type(self):
752783 ):
753784 async for _ in app .async_stream_query (user_id = _TEST_USER_ID , message = 123 ):
754785 pass
786+
787+
788+ @pytest .fixture (scope = "module" )
789+ def create_agent_engine_mock ():
790+ with mock .patch .object (
791+ reasoning_engine_service .ReasoningEngineServiceClient ,
792+ "create_reasoning_engine" ,
793+ ) as create_agent_engine_mock :
794+ create_agent_engine_lro_mock = mock .Mock (ga_operation .Operation )
795+ create_agent_engine_lro_mock .result .return_value = _TEST_AGENT_ENGINE_OBJ
796+ create_agent_engine_mock .return_value = create_agent_engine_lro_mock
797+ yield create_agent_engine_mock
798+
799+
800+ @pytest .fixture (scope = "module" )
801+ def get_agent_engine_mock ():
802+ with mock .patch .object (
803+ reasoning_engine_service .ReasoningEngineServiceClient ,
804+ "get_reasoning_engine" ,
805+ ) as get_agent_engine_mock :
806+ api_client_mock = mock .Mock ()
807+ api_client_mock .get_reasoning_engine .return_value = _TEST_AGENT_ENGINE_OBJ
808+ get_agent_engine_mock .return_value = api_client_mock
809+ yield get_agent_engine_mock
810+
811+
812+ @pytest .fixture (scope = "module" )
813+ def cloud_storage_create_bucket_mock ():
814+ with mock .patch .object (storage , "Client" ) as cloud_storage_mock :
815+ bucket_mock = mock .Mock (spec = storage .Bucket )
816+ bucket_mock .blob .return_value .open .return_value = "blob_file"
817+ bucket_mock .blob .return_value .upload_from_filename .return_value = None
818+ bucket_mock .blob .return_value .upload_from_string .return_value = None
819+
820+ cloud_storage_mock .get_bucket = mock .Mock (
821+ side_effect = ValueError ("bucket not found" )
822+ )
823+ cloud_storage_mock .bucket .return_value = bucket_mock
824+ cloud_storage_mock .create_bucket .return_value = bucket_mock
825+
826+ yield cloud_storage_mock
827+
828+
829+ @pytest .fixture (scope = "module" )
830+ def cloudpickle_dump_mock ():
831+ with mock .patch .object (cloudpickle , "dump" ) as cloudpickle_dump_mock :
832+ yield cloudpickle_dump_mock
833+
834+
835+ @pytest .fixture (scope = "module" )
836+ def cloudpickle_load_mock ():
837+ with mock .patch .object (cloudpickle , "load" ) as cloudpickle_load_mock :
838+ yield cloudpickle_load_mock
839+
840+
841+ @pytest .fixture (scope = "function" )
842+ def get_gca_resource_mock ():
843+ with mock .patch .object (
844+ base .VertexAiResourceNoun ,
845+ "_get_gca_resource" ,
846+ ) as get_gca_resource_mock :
847+ get_gca_resource_mock .return_value = _TEST_AGENT_ENGINE_OBJ
848+ yield get_gca_resource_mock
849+
850+
851+ # Function scope is required for the pytest parameterized tests.
852+ @pytest .fixture (scope = "function" )
853+ def update_agent_engine_mock ():
854+ with mock .patch .object (
855+ reasoning_engine_service .ReasoningEngineServiceClient ,
856+ "update_reasoning_engine" ,
857+ ) as update_agent_engine_mock :
858+ yield update_agent_engine_mock
859+
860+
861+ @pytest .mark .usefixtures ("google_auth_mock" )
862+ class TestAgentEngines :
863+ def setup_method (self ):
864+ importlib .reload (initializer )
865+ importlib .reload (aiplatform )
866+ aiplatform .init (
867+ project = _TEST_PROJECT ,
868+ location = _TEST_LOCATION ,
869+ credentials = _TEST_CREDENTIALS ,
870+ staging_bucket = _TEST_STAGING_BUCKET ,
871+ )
872+
873+ def teardown_method (self ):
874+ initializer .global_pool .shutdown (wait = True )
875+
876+ @pytest .mark .parametrize (
877+ "env_vars,expected_env_vars" ,
878+ [
879+ ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" }),
880+ (None , {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" }),
881+ (
882+ {"some_env" : "some_val" },
883+ {
884+ "some_env" : "some_val" ,
885+ GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" ,
886+ },
887+ ),
888+ (
889+ {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "false" },
890+ {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "false" },
891+ ),
892+ ],
893+ )
894+ def test_create_default_telemetry_enablement (
895+ self ,
896+ create_agent_engine_mock : mock .Mock ,
897+ cloud_storage_create_bucket_mock : mock .Mock ,
898+ cloudpickle_dump_mock : mock .Mock ,
899+ cloudpickle_load_mock : mock .Mock ,
900+ get_gca_resource_mock : mock .Mock ,
901+ env_vars : dict [str , str ],
902+ expected_env_vars : dict [str , str ],
903+ ):
904+ agent_engines .create (
905+ agent_engine = agent_engines .AdkApp (agent = _TEST_AGENT ),
906+ env_vars = env_vars ,
907+ )
908+ create_agent_engine_mock .assert_called_once ()
909+ deployment_spec = create_agent_engine_mock .call_args .kwargs [
910+ "reasoning_engine"
911+ ].spec .deployment_spec
912+ assert _utils .to_dict (deployment_spec )["env" ] == [
913+ {"name" : key , "value" : value } for key , value in expected_env_vars .items ()
914+ ]
915+
916+ @pytest .mark .parametrize (
917+ "env_vars,expected_env_vars" ,
918+ [
919+ ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" }),
920+ (None , {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" }),
921+ (
922+ {"some_env" : "some_val" },
923+ {
924+ "some_env" : "some_val" ,
925+ GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "true" ,
926+ },
927+ ),
928+ (
929+ {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "false" },
930+ {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY : "false" },
931+ ),
932+ ],
933+ )
934+ def test_update_default_telemetry_enablement (
935+ self ,
936+ update_agent_engine_mock : mock .Mock ,
937+ cloud_storage_create_bucket_mock : mock .Mock ,
938+ cloudpickle_dump_mock : mock .Mock ,
939+ cloudpickle_load_mock : mock .Mock ,
940+ get_gca_resource_mock : mock .Mock ,
941+ get_agent_engine_mock : mock .Mock ,
942+ env_vars : dict [str , str ],
943+ expected_env_vars : dict [str , str ],
944+ ):
945+ agent_engines .update (
946+ resource_name = _TEST_AGENT_ENGINE_RESOURCE_NAME ,
947+ description = "foobar" , # avoid "At least one of ... must be specified" errors.
948+ env_vars = env_vars ,
949+ )
950+ update_agent_engine_mock .assert_called_once ()
951+ deployment_spec = update_agent_engine_mock .call_args .kwargs [
952+ "request"
953+ ].reasoning_engine .spec .deployment_spec
954+ assert _utils .to_dict (deployment_spec )["env" ] == [
955+ {"name" : key , "value" : value } for key , value in expected_env_vars .items ()
956+ ]
0 commit comments