Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
259 changes: 259 additions & 0 deletions tests/unit/vertexai/genai/test_agent_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -2866,6 +2866,151 @@ def test_query_agent_engine(self):
None,
)

@mock.patch("google.cloud.storage.Client")
@mock.patch.object(agent_engines.AgentEngines, "_get")
@mock.patch("uuid.uuid4")
def test_run_query_job_agent_engine(self, mock_uuid, get_mock, mock_storage_client):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
)

# Mock the GCS bucket and blob so we don't actually try to use GCS
mock_bucket = mock.Mock()
mock_bucket.exists.return_value = False
mock_blob = mock.Mock()
mock_blob.exists.return_value = False
mock_bucket.blob.return_value = mock_blob
mock_storage_client.return_value.bucket.return_value = mock_bucket

# mock uuid
mock_uuid.return_value.hex = "b92b9b89-4585-4146-8ee5-22fe99802a8e"

# Mock _get to return a dummy resource
get_mock.return_value = _genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_genai_types.ReasoningEngineSpec(
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
)
),
)

result = self.client.agent_engines.run_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
config={
"query": _TEST_QUERY_PROMPT,
"gcs_bucket": "gs://my-input-bucket/",
},
)

# Verify bucket creation
assert mock_bucket.create.call_count == 1
# Verify file upload
mock_blob.upload_from_string.assert_called_once_with(_TEST_QUERY_PROMPT)

assert result == _genai_types.RunQueryJobResult(
job_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
input_gcs_uri="gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
output_gcs_uri="gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
)

request_mock.assert_called_with(
"post",
f"{_TEST_AGENT_ENGINE_RESOURCE_NAME}:asyncQuery",
{
"_url": {"name": _TEST_AGENT_ENGINE_RESOURCE_NAME},
"inputGcsUri": "gs://my-input-bucket/input_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
"outputGcsUri": "gs://my-input-bucket/output_b92b9b89-4585-4146-8ee5-22fe99802a8e.json",
},
None,
)

def test_run_query_job_agent_engine_missing_query(self):
with pytest.raises(
ValueError, match="`query` is required in the config object."
):
self.client.agent_engines.run_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
config={"gcs_bucket": "gs://my-input-bucket/"},
)

def test_run_query_job_agent_engine_missing_bucket(self):
with pytest.raises(
ValueError, match="`gcs_bucket` is required in the config object."
):
self.client.agent_engines.run_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
config={"query": _TEST_QUERY_PROMPT},
)

@mock.patch.object(agent_engines.AgentEngines, "_get")
def test_run_query_job_agent_engine_missing_cloud_run_job(self, get_mock):
get_mock.return_value = _genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_genai_types.ReasoningEngineSpec(
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(env=[])
),
)
with pytest.raises(
ValueError,
match="Your ReasoningEngine does not support long running queries, please update your ReasoningEngine and try again.",
):
self.client.agent_engines.run_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
config={
"query": _TEST_QUERY_PROMPT,
"gcs_bucket": "gs://my-input-bucket/",
},
)

@mock.patch("google.cloud.storage.Client")
@mock.patch.object(agent_engines.AgentEngines, "_get")
@mock.patch("uuid.uuid4")
def test_run_query_job_agent_engine_bucket_creation_forbidden(
self, mock_uuid, get_mock, mock_storage_client
):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"name": "projects/123/locations/us-central1/reasoningEngines/456/operations/789"}'
)

from google.api_core import exceptions as api_core_exceptions

mock_bucket = mock.Mock()
mock_bucket.exists.side_effect = api_core_exceptions.Forbidden(
"403 GET Bucket"
)
mock_blob = mock.Mock()
mock_bucket.blob.return_value = mock_blob
mock_storage_client.return_value.bucket.return_value = mock_bucket

mock_uuid.return_value.hex = "b92b9b89-4585-4146-8ee5-22fe99802a8e"

get_mock.return_value = _genai_types.ReasoningEngine(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
spec=_genai_types.ReasoningEngineSpec(
deployment_spec=_genai_types.ReasoningEngineSpecDeploymentSpec(
env=[_genai_types.EnvVar(name="input_gcs_uri", value="")]
)
),
)

with pytest.raises(
ValueError, match="Permission denied to check existence of bucket"
):
self.client.agent_engines.run_query_job(
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
config={
"query": _TEST_QUERY_PROMPT,
"gcs_bucket": "gs://my-input-bucket/",
},
)

def test_query_agent_engine_async(self):
agent = self.client.agent_engines._register_api_methods(
agent_engine=_genai_types.AgentEngine(
Expand Down Expand Up @@ -2898,6 +3043,120 @@ def test_query_agent_engine_async(self):
None,
)

def test_check_query_job_agent_engine(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"done": true, "metadata": {"output_gcs_uri": "gs://my-output-bucket/output.json", "input_gcs_uri": "gs://my-input-bucket/input.json"}}'
)
with mock.patch("google.cloud.storage.Client") as mock_storage_client:
mock_bucket = mock.Mock()
mock_blob = mock.Mock()
mock_blob.exists.return_value = True
mock_blob.download_as_string.return_value = b'{"success": true}'
mock_bucket.blob.return_value = mock_blob
mock_storage_client.return_value.bucket.return_value = mock_bucket

result = self.client.agent_engines.check_query_job(
name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
config={"retrieve_result": True},
)

assert result == _genai_types.CheckQueryJobResult(
operation_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
status="SUCCESS",
input_gcs_uri="gs://my-input-bucket/input.json",
output_gcs_uri="gs://my-output-bucket/output.json",
result='{"success": true}',
)

def test_check_query_job_agent_engine_running(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"done": false, "metadata": {"output_gcs_uri": "gs://my-output-bucket/output.json", "input_gcs_uri": "gs://my-input-bucket/input.json"}}'
)

result = self.client.agent_engines.check_query_job(
name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
config={"retrieve_result": True},
)

assert result == _genai_types.CheckQueryJobResult(
operation_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
status="RUNNING",
input_gcs_uri="gs://my-input-bucket/input.json",
output_gcs_uri="gs://my-output-bucket/output.json",
result=None,
)

def test_check_query_job_agent_engine_failed(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"done": true, "error": {"message": "Job failed with errors."}}'
)

result = self.client.agent_engines.check_query_job(
name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
config={"retrieve_result": True},
)

assert result == _genai_types.CheckQueryJobResult(
operation_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
status="FAILED",
input_gcs_uri=None,
output_gcs_uri=None,
result="{'message': 'Job failed with errors.'}",
)

def test_check_query_job_agent_engine_no_retrieve(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"done": true, "metadata": {"output_gcs_uri": "gs://my-output-bucket/output.json", "input_gcs_uri": "gs://my-input-bucket/input.json"}}'
)

result = self.client.agent_engines.check_query_job(
name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
config={"retrieve_result": False},
)

assert result == _genai_types.CheckQueryJobResult(
operation_name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
status="SUCCESS",
input_gcs_uri="gs://my-input-bucket/input.json",
output_gcs_uri="gs://my-output-bucket/output.json",
result=None,
)

def test_check_query_job_agent_engine_blob_not_exists(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request"
) as request_mock:
request_mock.return_value = genai_types.HttpResponse(
body='{"done": true, "metadata": {"output_gcs_uri": "gs://my-output-bucket/output.json", "input_gcs_uri": "gs://my-input-bucket/input.json"}}'
)
with mock.patch("google.cloud.storage.Client") as mock_storage_client:
mock_bucket = mock.Mock()
mock_blob = mock.Mock()
mock_blob.exists.return_value = False
mock_bucket.blob.return_value = mock_blob
mock_storage_client.return_value.bucket.return_value = mock_bucket

with pytest.raises(
ValueError,
match="Failed to retrieve blob results for gs://my-output-bucket/output.json",
):
self.client.agent_engines.check_query_job(
name="projects/123/locations/us-central1/reasoningEngines/456/operations/789",
config={"retrieve_result": True},
)

def test_query_agent_engine_stream(self):
with mock.patch.object(
self.client.agent_engines._api_client, "request_streamed"
Expand Down
Loading
Loading