diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index c0e92cd9d6..78f5ad469d 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -26,6 +26,7 @@ from google.protobuf import json_format import json import math import pytest +from collections.abc import Sequence, Mapping from google.api_core import api_core_version from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 index f15326d670..2ab58f5cb8 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_macros.j2 @@ -752,8 +752,18 @@ def test_{{ method_name }}_pager(transport_name: str = "grpc"): results = list(pager) assert len(results) == 6 + {% if method.paged_result_field.type.ident|string == 'struct_pb2.ListValue' %} + assert all(isinstance(i, Sequence) + for i in results) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Struct' %} + assert all(isinstance(i, Mapping) + for i in results) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Value' %} + assert all(True for i in results) + {% else %} assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) for i in results) + {% endif %} {% endif %} def test_{{ method_name }}_pages(transport_name: str = "grpc"): client = {{ service.client_name }}( @@ -913,9 +923,19 @@ async def test_{{ method_name }}_async_pager(): assert async_pager.get('a') is None assert isinstance(async_pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) {% else %} + {% if method.paged_result_field.type.ident|string == 'struct_pb2.ListValue' %} + assert all(isinstance(i, Sequence) + for i in responses) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Struct' %} + assert all(isinstance(i, Mapping) + for i in responses) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Value' %} + assert all(True for i in responses) + {% else %} assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) for i in responses) {% endif %} + {% endif %} @pytest.mark.asyncio @@ -1412,9 +1432,19 @@ def test_{{ method_name }}_rest_pager(transport: str = 'rest'): assert pager.get('a') is None assert isinstance(pager.get('h'), {{ method.paged_result_field.type.fields.get('value').ident }}) {% else %} + {% if method.paged_result_field.type.ident|string == 'struct_pb2.ListValue' %} + assert all(isinstance(i, Sequence) + for i in results) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Struct' %} + assert all(isinstance(i, Mapping) + for i in results) + {% elif method.paged_result_field.type.ident|string == 'struct_pb2.Value' %} + assert all(True for i in results) + {% else %} assert all(isinstance(i, {{ method.paged_result_field.type.ident }}) for i in results) {% endif %} + {% endif %} pages = list(client.{{ method_name }}(request=sample_request).pages) for page_, token in zip(pages, ['abc','def','ghi', '']):