diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index e8eee1f034..674979f7d9 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -23,6 +23,7 @@ from gapic import generator from gapic.schema import api from gapic.utils import Options +from gapic.utils.cache import generation_cache_context @click.command() @@ -56,10 +57,11 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None: [p.package for p in req.proto_file if p.name in req.file_to_generate] ).rstrip(".") - # Build the API model object. - # This object is a frozen representation of the whole API, and is sent - # to each template in the rendering step. - api_schema = api.API.build(req.proto_file, opts=opts, package=package) + with generation_cache_context(): + # Build the API model object. + # This object is a frozen representation of the whole API, and is sent + # to each template in the rendering step. + api_schema = api.API.build(req.proto_file, opts=opts, package=package) # Translate into a protobuf CodeGeneratorResponse; this reads the # individual templates and renders them. diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index 7df8d0291f..d973a4faf9 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -28,13 +28,14 @@ import dataclasses import re -from typing import FrozenSet, Set, Tuple, Optional +from typing import FrozenSet, Set, Tuple, Optional, Dict from google.protobuf import descriptor_pb2 from gapic.schema import imp from gapic.schema import naming from gapic.utils import cached_property +from gapic.utils import cached_proto_context from gapic.utils import RESERVED_NAMES # This class is a minor hack to optimize Address's __eq__ method. @@ -359,6 +360,7 @@ def resolve(self, selector: str) -> str: return f'{".".join(self.package)}.{selector}' return selector + @cached_proto_context def with_context(self, *, collisions: Set[str]) -> "Address": """Return a derivative of this address with the provided context. @@ -398,6 +400,7 @@ def doc(self): return "\n\n".join(self.documentation.leading_detached_comments) return "" + @cached_proto_context def with_context(self, *, collisions: Set[str]) -> "Metadata": """Return a derivative of this metadata with the provided context. diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 17a7832756..26118f9c1b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -69,6 +69,7 @@ from gapic.schema import metadata from gapic.utils import uri_sample from gapic.utils import make_private +from gapic.utils import cached_proto_context @dataclasses.dataclass(frozen=True) @@ -410,6 +411,7 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]: "This code should not be reachable; please file a bug." ) + @cached_proto_context def with_context( self, *, @@ -435,8 +437,10 @@ def with_context( if self.message else None ), - enum=self.enum.with_context(collisions=collisions) if self.enum else None, - meta=self.meta.with_context(collisions=collisions), + enum=(self.enum.with_context(collisions=collisions) if self.enum else None), + meta=self.meta.with_context( + collisions=collisions, + ), ) def add_to_address_allowlist( @@ -737,7 +741,9 @@ def path_regex_str(self) -> str: return parsing_regex_str def get_field( - self, *field_path: str, collisions: Optional[Set[str]] = None + self, + *field_path: str, + collisions: Optional[Set[str]] = None, ) -> Field: """Return a field arbitrarily deep in this message's structure. @@ -805,6 +811,7 @@ def get_field( # message. return cursor.message.get_field(*field_path[1:], collisions=collisions) + @cached_proto_context def with_context( self, *, @@ -829,7 +836,8 @@ def with_context( fields=( { k: v.with_context( - collisions=collisions, visited_messages=visited_messages + collisions=collisions, + visited_messages=visited_messages, ) for k, v in self.fields.items() } @@ -937,7 +945,12 @@ def ident(self) -> metadata.Address: """Return the identifier data to be used in templates.""" return self.meta.address - def with_context(self, *, collisions: Set[str]) -> "EnumType": + @cached_proto_context + def with_context( + self, + *, + collisions: Set[str], + ) -> "EnumType": """Return a derivative of this enum with the provided context. This method is used to address naming collisions. The returned @@ -1058,6 +1071,7 @@ class ExtendedOperationInfo: request_type: MessageType operation_type: MessageType + @cached_proto_context def with_context( self, *, @@ -1127,6 +1141,7 @@ class OperationInfo: response_type: MessageType metadata_type: MessageType + @cached_proto_context def with_context( self, *, @@ -1937,6 +1952,7 @@ def void(self) -> bool: """Return True if this method has no return value, False otherwise.""" return self.output.ident.proto == "google.protobuf.Empty" + @cached_proto_context def with_context( self, *, @@ -1981,7 +1997,9 @@ def with_context( collisions=collisions, visited_messages=visited_messages, ), - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context( + collisions=collisions, + ), ) def add_to_address_allowlist( @@ -2357,6 +2375,7 @@ def operation_polling_method(self) -> Optional[Method]: def is_internal(self) -> bool: return any(m.is_internal for m in self.methods.values()) + @cached_proto_context def with_context( self, *, @@ -2380,7 +2399,9 @@ def with_context( ) for k, v in self.methods.items() }, - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context( + collisions=collisions, + ), ) def add_to_address_allowlist( diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 8b48801730..23c5739156 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from gapic.utils.cache import cached_property +from gapic.utils.cache import cached_proto_context from gapic.utils.case import to_snake_case from gapic.utils.case import to_camel_case from gapic.utils.checks import is_msg_field_pb @@ -34,6 +35,7 @@ __all__ = ( "cached_property", + "cached_proto_context", "convert_uri_fieldnames", "doc", "empty", diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index f9c4d703f5..2991af3d85 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -13,6 +13,9 @@ # limitations under the License. import functools +import contextlib +import threading +from typing import Dict, Optional, Any def cached_property(fx): @@ -43,3 +46,49 @@ def inner(self): return self._cached_values[fx.__name__] return property(inner) + + +# Thread-local storage for the simple cache dictionary +_thread_local = threading.local() + + +@contextlib.contextmanager +def generation_cache_context(): + """Context manager to explicitly manage the cache lifecycle.""" + # Initialize the cache as a standard dictionary + _thread_local.cache = {} + try: + yield + finally: + # Delete the dictionary to free all memory and pinned objects + del _thread_local.cache + + +def cached_proto_context(func): + """Decorator to memoize with_context calls based on self and collisions.""" + + @functools.wraps(func) + def wrapper(self, *, collisions, **kwargs): + # 1. Initialize cache if not provided (handles the root call case) + + context_cache = getattr(_thread_local, "cache", None) + if context_cache is None: + return func(self, collisions=collisions, **kwargs) + + # 2. Create the cache key + collisions_key = frozenset(collisions) if collisions else None + key = (id(self), collisions_key) + + # 3. Check Cache + if key in context_cache: + return context_cache[key] + + # 4. Execute the actual function + # We ensure context_cache is passed down to the recursive calls + result = func(self, collisions=collisions, **kwargs) + + # 5. Update Cache + context_cache[key] = result + return result + + return wrapper diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index eee6865e6d..bc46920e60 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -31,3 +31,87 @@ def bar(self): assert foo.call_count == 1 assert foo.bar == 42 assert foo.call_count == 1 + + +def test_generation_cache_context(): + assert cache.generation_cache.get() is None + with cache.generation_cache_context(): + assert isinstance(cache.generation_cache.get(), dict) + cache.generation_cache.get()["foo"] = "bar" + assert cache.generation_cache.get()["foo"] == "bar" + assert cache.generation_cache.get() is None + + +def test_cached_proto_context(): + # A dummy object to act as a visited message + class MockMessage: + pass + + class Foo: + def __init__(self): + self.call_count = 0 + + # We define a signature that matches the real Proto.with_context + # to ensure arguments are propagated correctly. + @cache.cached_proto_context + def with_context(self, collisions, *, skip_fields=False, visited_messages=None): + self.call_count += 1 + return f"val-{self.call_count}" + + foo = Foo() + msg_a = MockMessage() + msg_b = MockMessage() + + # 1. Test Bypass (No Context) + # The cache is not active, so every call increments the counter. + assert foo.with_context(collisions={"a"}) == "val-1" + assert foo.with_context(collisions={"a"}) == "val-2" + + # 2. Test Context Activation + with cache.generation_cache_context(): + # Reset counter to make tracking easier + foo.call_count = 0 + + # A. Basic Cache Hit + assert foo.with_context({"a"}) == "val-1" + assert foo.with_context({"a"}) == "val-1" # Hit + assert foo.call_count == 1 + + # B. Collision Difference + # Changing collisions creates a new key + assert foo.with_context({"b"}) == "val-2" + assert foo.call_count == 2 + + # C. skip_fields Difference (The Critical Fix) + # Verify that skip_fields=True is cached separately from the default (False). + # This prevents the "Husk" object from poisoning the "Full" object cache. + assert foo.with_context({"a"}, skip_fields=True) == "val-3" + assert foo.with_context({"a"}, skip_fields=True) == "val-3" # Hit + assert foo.call_count == 3 + + # Verify the original (skip_fields=False) is still accessible + assert foo.with_context({"a"}) == "val-1" # Hit (Still 1) + + # D. visited_messages Difference + # Verify that sets of objects are correctly distinguished. + + # Call with msg_a + assert foo.with_context({"a"}, visited_messages={msg_a}) == "val-4" + assert foo.with_context({"a"}, visited_messages={msg_a}) == "val-4" # Hit + assert foo.call_count == 4 + + # Call with msg_b (Should be a Miss) + assert foo.with_context({"a"}, visited_messages={msg_b}) == "val-5" + assert foo.call_count == 5 + + # E. Set Stability + # Verify that the order of items in the set doesn't matter (sets are unordered). + # {a, b} should cache-hit with {b, a} + assert foo.with_context({"a"}, visited_messages={msg_a, msg_b}) == "val-6" + assert foo.with_context({"a"}, visited_messages={msg_b, msg_a}) == "val-6" # Hit + assert foo.call_count == 6 + + # 3. Context Cleared + # Everything should be forgotten now. + assert cache.generation_cache.get() is None + assert foo.with_context({"a"}) == "val-7" \ No newline at end of file