From 9fa948710b36de263fe509112193f3c906cbf4b7 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Wed, 14 Jan 2026 23:18:20 +0000 Subject: [PATCH 01/22] optimize gapic --- gapic/schema/api.py | 19 ++++++++-- gapic/schema/metadata.py | 21 +++++++--- gapic/schema/wrappers.py | 82 +++++++++++++++++++++++++++++++--------- gapic/utils/__init__.py | 2 + gapic/utils/cache.py | 31 +++++++++++++++ 5 files changed, 128 insertions(+), 27 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index b05f6b6dec..e06c237960 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -115,6 +115,7 @@ def build( prior_protos: Optional[Mapping[str, "Proto"]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, + context_cache: Optional[Dict[tuple, "wrappers.MessageType"]] = None, ) -> "Proto": """Build and return a Proto instance. @@ -139,6 +140,7 @@ def build( prior_protos=prior_protos or {}, load_services=load_services, all_resources=all_resources or {}, + context_cache=context_cache, ).proto @cached_property @@ -455,6 +457,7 @@ def disambiguate_keyword_sanitize_fname( # We just load all the APIs types first and then # load the services and methods with the full scope of types. pre_protos: Dict[str, Proto] = dict(prior_protos or {}) + context_cache = {} for fd in file_descriptors: fd.name = disambiguate_keyword_sanitize_fname(fd.name, pre_protos) pre_protos[fd.name] = Proto.build( @@ -465,6 +468,7 @@ def disambiguate_keyword_sanitize_fname( prior_protos=pre_protos, # Ugly, ugly hack. load_services=False, + context_cache=context_cache, ) # A file descriptor's file-level resources are NOT visible to any importers. @@ -485,6 +489,7 @@ def disambiguate_keyword_sanitize_fname( opts=opts, prior_protos=pre_protos, all_resources=MappingProxyType(all_file_resources), + context_cache=context_cache, ) for name, proto in pre_protos.items() } @@ -1103,6 +1108,8 @@ def __init__( prior_protos: Optional[Mapping[str, Proto]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, + context_cache: Optional[Dict[tuple, "wrappers.MessageType"]] = None, + ): self.proto_messages: Dict[str, wrappers.MessageType] = {} self.proto_enums: Dict[str, wrappers.EnumType] = {} @@ -1111,6 +1118,7 @@ def __init__( self.file_to_generate = file_to_generate self.prior_protos = prior_protos or {} self.opts = opts + self.context_cache = context_cache # Iterate over the documentation and place it into a dictionary. # @@ -1220,20 +1228,22 @@ def proto(self) -> Proto: if not self.file_to_generate: return naive + global_collisions = frozenset(naive.names) visited_messages: Set[wrappers.MessageType] = set() # Return a context-aware proto object. return dataclasses.replace( naive, all_enums=collections.OrderedDict( - (k, v.with_context(collisions=naive.names)) + (k, v.with_context(collisions=global_collisions, context_cache=self.context_cache)) for k, v in naive.all_enums.items() ), all_messages=collections.OrderedDict( ( k, v.with_context( - collisions=naive.names, + collisions=global_collisions, visited_messages=visited_messages, + context_cache=self.context_cache, ), ) for k, v in naive.all_messages.items() @@ -1244,8 +1254,9 @@ def proto(self) -> Proto: ( k, v.with_context( - collisions=v.names, - visited_messages=visited_messages, + collisions=global_collisions, + visited_messages=frozenset(v.names), + context_cache=self.context_cache, ), ) for k, v in naive.services.items() diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index 7df8d0291f..806b9c9bce 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -28,13 +28,14 @@ import dataclasses import re -from typing import FrozenSet, Set, Tuple, Optional +from typing import FrozenSet, Set, Tuple, Optional, Dict from google.protobuf import descriptor_pb2 from gapic.schema import imp from gapic.schema import naming from gapic.utils import cached_property +from gapic.utils import cached_proto_context from gapic.utils import RESERVED_NAMES # This class is a minor hack to optimize Address's __eq__ method. @@ -359,19 +360,23 @@ def resolve(self, selector: str) -> str: return f'{".".join(self.package)}.{selector}' return selector - def with_context(self, *, collisions: Set[str]) -> "Address": + @cached_proto_context + def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None) -> "Address": """Return a derivative of this address with the provided context. This method is used to address naming collisions. The returned ``Address`` object aliases module names to avoid naming collisions in the file being written. """ - return ( + + updated_msg = ( dataclasses.replace(self, collisions=collisions) if collisions and collisions != self.collisions else self ) + return updated_msg + @dataclasses.dataclass(frozen=True) class Metadata: @@ -398,22 +403,26 @@ def doc(self): return "\n\n".join(self.documentation.leading_detached_comments) return "" - def with_context(self, *, collisions: Set[str]) -> "Metadata": + @cached_proto_context + def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None) -> "Metadata": """Return a derivative of this metadata with the provided context. This method is used to address naming collisions. The returned ``Address`` object aliases module names to avoid naming collisions in the file being written. """ - return ( + + updated_msg = ( dataclasses.replace( self, - address=self.address.with_context(collisions=collisions), + address=self.address.with_context(collisions=collisions, context_cache=context_cache), ) if collisions and collisions != self.address.collisions else self ) + return updated_msg + @dataclasses.dataclass(frozen=True) class FieldIdentifier: diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 17a7832756..ffc4e99426 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -69,6 +69,7 @@ from gapic.schema import metadata from gapic.utils import uri_sample from gapic.utils import make_private +from gapic.utils import cached_proto_context @dataclasses.dataclass(frozen=True) @@ -410,11 +411,13 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]: "This code should not be reachable; please file a bug." ) + @cached_proto_context def with_context( self, *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "Field": """Return a derivative of this field with the provided context. @@ -422,7 +425,8 @@ def with_context( ``Field`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( + + updated_msg = dataclasses.replace( self, message=( self.message.with_context( @@ -431,14 +435,17 @@ def with_context( self.message in visited_messages if visited_messages else False ), visited_messages=visited_messages, + context_cache=context_cache ) if self.message else None ), - enum=self.enum.with_context(collisions=collisions) if self.enum else None, - meta=self.meta.with_context(collisions=collisions), + enum=self.enum.with_context(collisions=collisions, context_cache=context_cache) if self.enum else None, + meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) + return updated_msg + def add_to_address_allowlist( self, *, @@ -737,7 +744,7 @@ def path_regex_str(self) -> str: return parsing_regex_str def get_field( - self, *field_path: str, collisions: Optional[Set[str]] = None + self, *field_path: str, context_cache: Optional[Dict] = None, collisions: Optional[Set[str]] = None ) -> Field: """Return a field arbitrarily deep in this message's structure. @@ -781,6 +788,7 @@ def get_field( return cursor.with_context( collisions=collisions, visited_messages=set({self}), + context_cache=context_cache, ) # Quick check: If cursor is a repeated field, then raise an exception. @@ -805,12 +813,14 @@ def get_field( # message. return cursor.message.get_field(*field_path[1:], collisions=collisions) + @cached_proto_context def with_context( self, *, collisions: Set[str], skip_fields: bool = False, visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "MessageType": """Return a derivative of this message with the provided context. @@ -824,12 +834,13 @@ def with_context( """ visited_messages = visited_messages or set() visited_messages = visited_messages | {self} - return dataclasses.replace( + + updated_msg = dataclasses.replace( self, fields=( { k: v.with_context( - collisions=collisions, visited_messages=visited_messages + collisions=collisions, visited_messages=visited_messages, context_cache=context_cache ) for k, v in self.fields.items() } @@ -837,7 +848,7 @@ def with_context( else self.fields ), nested_enums={ - k: v.with_context(collisions=collisions) + k: v.with_context(collisions=collisions, context_cache=context_cache) for k, v in self.nested_enums.items() }, nested_messages={ @@ -845,12 +856,15 @@ def with_context( collisions=collisions, skip_fields=skip_fields, visited_messages=visited_messages, + context_cache=context_cache ) for k, v in self.nested_messages.items() }, - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) + return updated_msg + def add_to_address_allowlist( self, *, @@ -937,22 +951,26 @@ def ident(self) -> metadata.Address: """Return the identifier data to be used in templates.""" return self.meta.address - def with_context(self, *, collisions: Set[str]) -> "EnumType": + @cached_proto_context + def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None,) -> "EnumType": """Return a derivative of this enum with the provided context. This method is used to address naming collisions. The returned ``EnumType`` object aliases module names to avoid naming collisions in the file being written. """ - return ( + + updated_msg = ( dataclasses.replace( self, - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) if collisions else self ) + return updated_msg + def add_to_address_allowlist( self, *, address_allowlist: Set["metadata.Address"] ) -> None: @@ -1058,11 +1076,13 @@ class ExtendedOperationInfo: request_type: MessageType operation_type: MessageType + @cached_proto_context def with_context( self, *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "ExtendedOperationInfo": """Return a derivative of this OperationInfo with the provided context. @@ -1070,7 +1090,8 @@ def with_context( ``OperationInfo`` object aliases module names to avoid naming collisions in the file being written. """ - return ( + + updated_msg = ( self if not collisions else dataclasses.replace( @@ -1078,14 +1099,18 @@ def with_context( request_type=self.request_type.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), operation_type=self.operation_type.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), ) ) + return updated_msg + def add_to_address_allowlist( self, *, @@ -1127,11 +1152,13 @@ class OperationInfo: response_type: MessageType metadata_type: MessageType + @cached_proto_context def with_context( self, *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "OperationInfo": """Return a derivative of this OperationInfo with the provided context. @@ -1139,18 +1166,24 @@ def with_context( ``OperationInfo`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( + + updated_msg = dataclasses.replace( self, response_type=self.response_type.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), metadata_type=self.metadata_type.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), ) + return updated_msg + + def add_to_address_allowlist( self, *, @@ -1937,11 +1970,13 @@ def void(self) -> bool: """Return True if this method has no return value, False otherwise.""" return self.output.ident.proto == "google.protobuf.Empty" + @cached_proto_context def with_context( self, *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "Method": """Return a derivative of this method with the provided context. @@ -1949,12 +1984,14 @@ def with_context( ``Method`` object aliases module names to avoid naming collisions in the file being written. """ + maybe_lro = None if self.lro: maybe_lro = ( self.lro.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ) if collisions else self.lro @@ -1964,26 +2001,31 @@ def with_context( self.extended_lro.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ) if self.extended_lro else None ) - return dataclasses.replace( + updated_msg = dataclasses.replace( self, lro=maybe_lro, extended_lro=maybe_extended_lro, input=self.input.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), output=self.output.with_context( collisions=collisions, visited_messages=visited_messages, + context_cache=context_cache ), - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) + return updated_msg + def add_to_address_allowlist( self, *, @@ -2357,11 +2399,13 @@ def operation_polling_method(self) -> Optional[Method]: def is_internal(self) -> bool: return any(m.is_internal for m in self.methods.values()) + @cached_proto_context def with_context( self, *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "Service": """Return a derivative of this service with the provided context. @@ -2369,7 +2413,8 @@ def with_context( ``Service`` object aliases module names to avoid naming collisions in the file being written. """ - return dataclasses.replace( + + updated_msg = dataclasses.replace( self, methods={ k: v.with_context( @@ -2377,11 +2422,14 @@ def with_context( # that may conflict with module imports. collisions=collisions | set(v.flattened_fields.keys()), visited_messages=visited_messages, + context_cache=context_cache ) for k, v in self.methods.items() }, - meta=self.meta.with_context(collisions=collisions), + meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) + + return updated_msg def add_to_address_allowlist( self, diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 8b48801730..61acd16d35 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from gapic.utils.cache import cached_property +from gapic.utils.cache import cached_proto_context from gapic.utils.case import to_snake_case from gapic.utils.case import to_camel_case from gapic.utils.checks import is_msg_field_pb @@ -34,6 +35,7 @@ __all__ = ( "cached_property", + "cached_proto_context" "convert_uri_fieldnames", "doc", "empty", diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index f9c4d703f5..8d401a9eaf 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +from typing import Dict, Optional def cached_property(fx): @@ -43,3 +44,33 @@ def inner(self): return self._cached_values[fx.__name__] return property(inner) + + +def cached_proto_context(func): + """Decorator to memoize with_context calls based on self and collisions.""" + + @functools.wraps(func) + def wrapper(self, *, collisions, context_cache: Optional[Dict] = None, **kwargs): + # 1. Initialize cache if not provided (handles the root call case) + if context_cache is None: + context_cache = {} + + # 2. Create the cache key + collisions_key = frozenset(collisions) if collisions else None + key = (id(self), collisions_key) + + # 3. Check Cache + if key in context_cache: + return context_cache[key] + + # 4. Execute the actual function + # We ensure context_cache is passed down to the recursive calls + result = func( + self, collisions=collisions, context_cache=context_cache, **kwargs + ) + + # 5. Update Cache + context_cache[key] = result + return result + + return wrapper From afaa05089d31cf96f2482f1231a9f9c0cf75611b Mon Sep 17 00:00:00 2001 From: ohmayr Date: Thu, 22 Jan 2026 21:24:56 +0000 Subject: [PATCH 02/22] update type to Dict --- gapic/schema/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index e06c237960..a7deed86ba 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -115,7 +115,7 @@ def build( prior_protos: Optional[Mapping[str, "Proto"]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, - context_cache: Optional[Dict[tuple, "wrappers.MessageType"]] = None, + context_cache: Optional[Dict] = None, ) -> "Proto": """Build and return a Proto instance. @@ -1108,7 +1108,7 @@ def __init__( prior_protos: Optional[Mapping[str, Proto]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, - context_cache: Optional[Dict[tuple, "wrappers.MessageType"]] = None, + context_cache: Optional[Dict] = None, ): self.proto_messages: Dict[str, wrappers.MessageType] = {} From edad384db0a8cf87d004030a949089b212a7dc63 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Thu, 22 Jan 2026 21:41:14 +0000 Subject: [PATCH 03/22] reduce diff --- gapic/schema/api.py | 2 +- gapic/schema/metadata.py | 10 ++-------- gapic/schema/wrappers.py | 36 +++++++----------------------------- 3 files changed, 10 insertions(+), 38 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index a7deed86ba..25ed24e88b 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1255,7 +1255,7 @@ def proto(self) -> Proto: k, v.with_context( collisions=global_collisions, - visited_messages=frozenset(v.names), + visited_messages=visited_messages, context_cache=self.context_cache, ), ) diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index 806b9c9bce..b33e416269 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -368,15 +368,12 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = ``Address`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = ( + return ( dataclasses.replace(self, collisions=collisions) if collisions and collisions != self.collisions else self ) - return updated_msg - @dataclasses.dataclass(frozen=True) class Metadata: @@ -411,8 +408,7 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = ``Address`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = ( + return ( dataclasses.replace( self, address=self.address.with_context(collisions=collisions, context_cache=context_cache), @@ -421,8 +417,6 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = else self ) - return updated_msg - @dataclasses.dataclass(frozen=True) class FieldIdentifier: diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index ffc4e99426..e7164f40eb 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -425,8 +425,7 @@ def with_context( ``Field`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = dataclasses.replace( + return dataclasses.replace( self, message=( self.message.with_context( @@ -444,8 +443,6 @@ def with_context( meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) - return updated_msg - def add_to_address_allowlist( self, *, @@ -834,8 +831,7 @@ def with_context( """ visited_messages = visited_messages or set() visited_messages = visited_messages | {self} - - updated_msg = dataclasses.replace( + return dataclasses.replace( self, fields=( { @@ -863,8 +859,6 @@ def with_context( meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) - return updated_msg - def add_to_address_allowlist( self, *, @@ -959,8 +953,7 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = ``EnumType`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = ( + return ( dataclasses.replace( self, meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), @@ -969,8 +962,6 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = else self ) - return updated_msg - def add_to_address_allowlist( self, *, address_allowlist: Set["metadata.Address"] ) -> None: @@ -1090,8 +1081,7 @@ def with_context( ``OperationInfo`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = ( + return ( self if not collisions else dataclasses.replace( @@ -1109,8 +1099,6 @@ def with_context( ) ) - return updated_msg - def add_to_address_allowlist( self, *, @@ -1166,8 +1154,7 @@ def with_context( ``OperationInfo`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = dataclasses.replace( + return dataclasses.replace( self, response_type=self.response_type.with_context( collisions=collisions, @@ -1181,9 +1168,6 @@ def with_context( ), ) - return updated_msg - - def add_to_address_allowlist( self, *, @@ -1984,7 +1968,6 @@ def with_context( ``Method`` object aliases module names to avoid naming collisions in the file being written. """ - maybe_lro = None if self.lro: maybe_lro = ( @@ -2007,7 +1990,7 @@ def with_context( else None ) - updated_msg = dataclasses.replace( + return dataclasses.replace( self, lro=maybe_lro, extended_lro=maybe_extended_lro, @@ -2024,8 +2007,6 @@ def with_context( meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) - return updated_msg - def add_to_address_allowlist( self, *, @@ -2413,8 +2394,7 @@ def with_context( ``Service`` object aliases module names to avoid naming collisions in the file being written. """ - - updated_msg = dataclasses.replace( + return dataclasses.replace( self, methods={ k: v.with_context( @@ -2428,8 +2408,6 @@ def with_context( }, meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) - - return updated_msg def add_to_address_allowlist( self, From 2103a8a09aa5cce98f791a3269c9122139829a71 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Thu, 22 Jan 2026 21:48:56 +0000 Subject: [PATCH 04/22] use module level cache --- gapic/schema/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 25ed24e88b..fe1acc95ff 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1230,6 +1230,7 @@ def proto(self) -> Proto: global_collisions = frozenset(naive.names) visited_messages: Set[wrappers.MessageType] = set() + self.context_cache = {} # Return a context-aware proto object. return dataclasses.replace( naive, From fa92e027144b6ecb9ecd5f1cb2385c563eac9e10 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Thu, 22 Jan 2026 21:52:05 +0000 Subject: [PATCH 05/22] pass cache to meta --- gapic/schema/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index fe1acc95ff..cb46510a4b 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1262,7 +1262,7 @@ def proto(self) -> Proto: ) for k, v in naive.services.items() ), - meta=naive.meta.with_context(collisions=naive.names), + meta=naive.meta.with_context(collisions=naive.names, self.context_cache), ) @cached_property From 58f3049e1cdbc7ce52894438cb9e3cb576a2207c Mon Sep 17 00:00:00 2001 From: ohmayr Date: Thu, 22 Jan 2026 23:35:18 +0000 Subject: [PATCH 06/22] fix positional argument --- gapic/schema/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index cb46510a4b..584dcb7c81 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1262,7 +1262,7 @@ def proto(self) -> Proto: ) for k, v in naive.services.items() ), - meta=naive.meta.with_context(collisions=naive.names, self.context_cache), + meta=naive.meta.with_context(collisions=naive.names, context_cache=self.context_cache), ) @cached_property From 49650118c77a676dca53f1687596eb695ffb298c Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Jan 2026 09:13:55 +0000 Subject: [PATCH 07/22] pass down service names --- gapic/schema/api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 584dcb7c81..3a446ceadb 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1255,7 +1255,7 @@ def proto(self) -> Proto: ( k, v.with_context( - collisions=global_collisions, + collisions=v.names, visited_messages=visited_messages, context_cache=self.context_cache, ), From e03a978f64ee4635593ea71d44cabc339f6e8b96 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Jan 2026 09:28:32 +0000 Subject: [PATCH 08/22] fix lint --- gapic/schema/api.py | 12 ++++++-- gapic/schema/metadata.py | 12 ++++++-- gapic/schema/wrappers.py | 66 ++++++++++++++++++++++++++++------------ gapic/utils/__init__.py | 3 +- 4 files changed, 65 insertions(+), 28 deletions(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 3a446ceadb..eaac31f23d 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -1109,7 +1109,6 @@ def __init__( load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, context_cache: Optional[Dict] = None, - ): self.proto_messages: Dict[str, wrappers.MessageType] = {} self.proto_enums: Dict[str, wrappers.EnumType] = {} @@ -1235,7 +1234,12 @@ def proto(self) -> Proto: return dataclasses.replace( naive, all_enums=collections.OrderedDict( - (k, v.with_context(collisions=global_collisions, context_cache=self.context_cache)) + ( + k, + v.with_context( + collisions=global_collisions, context_cache=self.context_cache + ), + ) for k, v in naive.all_enums.items() ), all_messages=collections.OrderedDict( @@ -1262,7 +1266,9 @@ def proto(self) -> Proto: ) for k, v in naive.services.items() ), - meta=naive.meta.with_context(collisions=naive.names, context_cache=self.context_cache), + meta=naive.meta.with_context( + collisions=naive.names, context_cache=self.context_cache + ), ) @cached_property diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index b33e416269..c2b3c7b13b 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -361,7 +361,9 @@ def resolve(self, selector: str) -> str: return selector @cached_proto_context - def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None) -> "Address": + def with_context( + self, *, collisions: Set[str], context_cache: Optional[Dict] = None + ) -> "Address": """Return a derivative of this address with the provided context. This method is used to address naming collisions. The returned @@ -401,7 +403,9 @@ def doc(self): return "" @cached_proto_context - def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None) -> "Metadata": + def with_context( + self, *, collisions: Set[str], context_cache: Optional[Dict] = None + ) -> "Metadata": """Return a derivative of this metadata with the provided context. This method is used to address naming collisions. The returned @@ -411,7 +415,9 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = return ( dataclasses.replace( self, - address=self.address.with_context(collisions=collisions, context_cache=context_cache), + address=self.address.with_context( + collisions=collisions, context_cache=context_cache + ), ) if collisions and collisions != self.address.collisions else self diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index e7164f40eb..292d832585 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -434,13 +434,21 @@ def with_context( self.message in visited_messages if visited_messages else False ), visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ) if self.message else None ), - enum=self.enum.with_context(collisions=collisions, context_cache=context_cache) if self.enum else None, - meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), + enum=( + self.enum.with_context( + collisions=collisions, context_cache=context_cache + ) + if self.enum + else None + ), + meta=self.meta.with_context( + collisions=collisions, context_cache=context_cache + ), ) def add_to_address_allowlist( @@ -741,7 +749,10 @@ def path_regex_str(self) -> str: return parsing_regex_str def get_field( - self, *field_path: str, context_cache: Optional[Dict] = None, collisions: Optional[Set[str]] = None + self, + *field_path: str, + context_cache: Optional[Dict] = None, + collisions: Optional[Set[str]] = None, ) -> Field: """Return a field arbitrarily deep in this message's structure. @@ -836,7 +847,9 @@ def with_context( fields=( { k: v.with_context( - collisions=collisions, visited_messages=visited_messages, context_cache=context_cache + collisions=collisions, + visited_messages=visited_messages, + context_cache=context_cache, ) for k, v in self.fields.items() } @@ -852,11 +865,13 @@ def with_context( collisions=collisions, skip_fields=skip_fields, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ) for k, v in self.nested_messages.items() }, - meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), + meta=self.meta.with_context( + collisions=collisions, context_cache=context_cache + ), ) def add_to_address_allowlist( @@ -946,7 +961,12 @@ def ident(self) -> metadata.Address: return self.meta.address @cached_proto_context - def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = None,) -> "EnumType": + def with_context( + self, + *, + collisions: Set[str], + context_cache: Optional[Dict] = None, + ) -> "EnumType": """Return a derivative of this enum with the provided context. This method is used to address naming collisions. The returned @@ -956,7 +976,9 @@ def with_context(self, *, collisions: Set[str], context_cache: Optional[Dict] = return ( dataclasses.replace( self, - meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), + meta=self.meta.with_context( + collisions=collisions, context_cache=context_cache + ), ) if collisions else self @@ -1089,12 +1111,12 @@ def with_context( request_type=self.request_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ), operation_type=self.operation_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ), ) ) @@ -1159,12 +1181,12 @@ def with_context( response_type=self.response_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ), metadata_type=self.metadata_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ), ) @@ -1974,7 +1996,7 @@ def with_context( self.lro.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ) if collisions else self.lro @@ -1984,7 +2006,7 @@ def with_context( self.extended_lro.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ) if self.extended_lro else None @@ -1997,14 +2019,16 @@ def with_context( input=self.input.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ), output=self.output.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, + ), + meta=self.meta.with_context( + collisions=collisions, context_cache=context_cache ), - meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), ) def add_to_address_allowlist( @@ -2402,11 +2426,13 @@ def with_context( # that may conflict with module imports. collisions=collisions | set(v.flattened_fields.keys()), visited_messages=visited_messages, - context_cache=context_cache + context_cache=context_cache, ) for k, v in self.methods.items() }, - meta=self.meta.with_context(collisions=collisions, context_cache=context_cache), + meta=self.meta.with_context( + collisions=collisions, context_cache=context_cache + ), ) def add_to_address_allowlist( diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 61acd16d35..1b8b866559 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -35,8 +35,7 @@ __all__ = ( "cached_property", - "cached_proto_context" - "convert_uri_fieldnames", + "cached_proto_context" "convert_uri_fieldnames", "doc", "empty", "is_msg_field_pb", From fd827c6e766fd4bc84bbe9af12597a3c77acd996 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Jan 2026 22:29:59 +0000 Subject: [PATCH 09/22] avoid resetting the context cache --- gapic/schema/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index eaac31f23d..69c2706141 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -457,7 +457,6 @@ def disambiguate_keyword_sanitize_fname( # We just load all the APIs types first and then # load the services and methods with the full scope of types. pre_protos: Dict[str, Proto] = dict(prior_protos or {}) - context_cache = {} for fd in file_descriptors: fd.name = disambiguate_keyword_sanitize_fname(fd.name, pre_protos) pre_protos[fd.name] = Proto.build( From fd5c07ba0adf022c6e1bae4990cdb82254fcf124 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Jan 2026 22:40:25 +0000 Subject: [PATCH 10/22] define cache in API layer --- gapic/schema/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index 69c2706141..ba7a5d59ca 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -456,6 +456,7 @@ def disambiguate_keyword_sanitize_fname( # type into the proto file that defines an LRO. # We just load all the APIs types first and then # load the services and methods with the full scope of types. + context_cache = {} pre_protos: Dict[str, Proto] = dict(prior_protos or {}) for fd in file_descriptors: fd.name = disambiguate_keyword_sanitize_fname(fd.name, pre_protos) From 2537d980e02abbce206b53aa3a3953d9c1b10401 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Fri, 23 Jan 2026 22:44:16 +0000 Subject: [PATCH 11/22] fix init file --- gapic/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/gapic/utils/__init__.py b/gapic/utils/__init__.py index 1b8b866559..23c5739156 100644 --- a/gapic/utils/__init__.py +++ b/gapic/utils/__init__.py @@ -35,7 +35,8 @@ __all__ = ( "cached_property", - "cached_proto_context" "convert_uri_fieldnames", + "cached_proto_context", + "convert_uri_fieldnames", "doc", "empty", "is_msg_field_pb", From fbad3e14708a03c620ccaf145f2816f80df2c2c4 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Mon, 26 Jan 2026 05:35:49 +0000 Subject: [PATCH 12/22] use a context manager instead of passing args --- gapic/cli/generate.py | 19 ++++++++++--------- gapic/schema/api.py | 24 +++--------------------- gapic/schema/metadata.py | 6 +++--- gapic/schema/wrappers.py | 35 +++++++---------------------------- gapic/utils/cache.py | 40 ++++++++++++++++++++++++++++++++++++---- 5 files changed, 59 insertions(+), 65 deletions(-) diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index e8eee1f034..32f3847fde 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -23,7 +23,7 @@ from gapic import generator from gapic.schema import api from gapic.utils import Options - +from gapic.utils.cache import generation_cache_context @click.command() @click.option( @@ -56,15 +56,16 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None: [p.package for p in req.proto_file if p.name in req.file_to_generate] ).rstrip(".") - # Build the API model object. - # This object is a frozen representation of the whole API, and is sent - # to each template in the rendering step. - api_schema = api.API.build(req.proto_file, opts=opts, package=package) + with generation_cache_context(): + # Build the API model object. + # This object is a frozen representation of the whole API, and is sent + # to each template in the rendering step. + api_schema = api.API.build(req.proto_file, opts=opts, package=package) - # Translate into a protobuf CodeGeneratorResponse; this reads the - # individual templates and renders them. - # If there are issues, error out appropriately. - res = generator.Generator(opts).get_response(api_schema, opts) + # Translate into a protobuf CodeGeneratorResponse; this reads the + # individual templates and renders them. + # If there are issues, error out appropriately. + res = generator.Generator(opts).get_response(api_schema, opts) # Output the serialized response. output.write(res.SerializeToString()) diff --git a/gapic/schema/api.py b/gapic/schema/api.py index ba7a5d59ca..b05f6b6dec 100644 --- a/gapic/schema/api.py +++ b/gapic/schema/api.py @@ -115,7 +115,6 @@ def build( prior_protos: Optional[Mapping[str, "Proto"]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, - context_cache: Optional[Dict] = None, ) -> "Proto": """Build and return a Proto instance. @@ -140,7 +139,6 @@ def build( prior_protos=prior_protos or {}, load_services=load_services, all_resources=all_resources or {}, - context_cache=context_cache, ).proto @cached_property @@ -456,7 +454,6 @@ def disambiguate_keyword_sanitize_fname( # type into the proto file that defines an LRO. # We just load all the APIs types first and then # load the services and methods with the full scope of types. - context_cache = {} pre_protos: Dict[str, Proto] = dict(prior_protos or {}) for fd in file_descriptors: fd.name = disambiguate_keyword_sanitize_fname(fd.name, pre_protos) @@ -468,7 +465,6 @@ def disambiguate_keyword_sanitize_fname( prior_protos=pre_protos, # Ugly, ugly hack. load_services=False, - context_cache=context_cache, ) # A file descriptor's file-level resources are NOT visible to any importers. @@ -489,7 +485,6 @@ def disambiguate_keyword_sanitize_fname( opts=opts, prior_protos=pre_protos, all_resources=MappingProxyType(all_file_resources), - context_cache=context_cache, ) for name, proto in pre_protos.items() } @@ -1108,7 +1103,6 @@ def __init__( prior_protos: Optional[Mapping[str, Proto]] = None, load_services: bool = True, all_resources: Optional[Mapping[str, wrappers.MessageType]] = None, - context_cache: Optional[Dict] = None, ): self.proto_messages: Dict[str, wrappers.MessageType] = {} self.proto_enums: Dict[str, wrappers.EnumType] = {} @@ -1117,7 +1111,6 @@ def __init__( self.file_to_generate = file_to_generate self.prior_protos = prior_protos or {} self.opts = opts - self.context_cache = context_cache # Iterate over the documentation and place it into a dictionary. # @@ -1227,28 +1220,20 @@ def proto(self) -> Proto: if not self.file_to_generate: return naive - global_collisions = frozenset(naive.names) visited_messages: Set[wrappers.MessageType] = set() - self.context_cache = {} # Return a context-aware proto object. return dataclasses.replace( naive, all_enums=collections.OrderedDict( - ( - k, - v.with_context( - collisions=global_collisions, context_cache=self.context_cache - ), - ) + (k, v.with_context(collisions=naive.names)) for k, v in naive.all_enums.items() ), all_messages=collections.OrderedDict( ( k, v.with_context( - collisions=global_collisions, + collisions=naive.names, visited_messages=visited_messages, - context_cache=self.context_cache, ), ) for k, v in naive.all_messages.items() @@ -1261,14 +1246,11 @@ def proto(self) -> Proto: v.with_context( collisions=v.names, visited_messages=visited_messages, - context_cache=self.context_cache, ), ) for k, v in naive.services.items() ), - meta=naive.meta.with_context( - collisions=naive.names, context_cache=self.context_cache - ), + meta=naive.meta.with_context(collisions=naive.names), ) @cached_property diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index c2b3c7b13b..10460fd369 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -362,7 +362,7 @@ def resolve(self, selector: str) -> str: @cached_proto_context def with_context( - self, *, collisions: Set[str], context_cache: Optional[Dict] = None + self, *, collisions: Set[str] ) -> "Address": """Return a derivative of this address with the provided context. @@ -404,7 +404,7 @@ def doc(self): @cached_proto_context def with_context( - self, *, collisions: Set[str], context_cache: Optional[Dict] = None + self, *, collisions: Set[str] ) -> "Metadata": """Return a derivative of this metadata with the provided context. @@ -416,7 +416,7 @@ def with_context( dataclasses.replace( self, address=self.address.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions ), ) if collisions and collisions != self.address.collisions diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index 292d832585..aae4e59939 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -417,7 +417,6 @@ def with_context( *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "Field": """Return a derivative of this field with the provided context. @@ -434,20 +433,19 @@ def with_context( self.message in visited_messages if visited_messages else False ), visited_messages=visited_messages, - context_cache=context_cache, ) if self.message else None ), enum=( self.enum.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions ) if self.enum else None ), meta=self.meta.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions, ), ) @@ -751,7 +749,6 @@ def path_regex_str(self) -> str: def get_field( self, *field_path: str, - context_cache: Optional[Dict] = None, collisions: Optional[Set[str]] = None, ) -> Field: """Return a field arbitrarily deep in this message's structure. @@ -796,7 +793,6 @@ def get_field( return cursor.with_context( collisions=collisions, visited_messages=set({self}), - context_cache=context_cache, ) # Quick check: If cursor is a repeated field, then raise an exception. @@ -828,7 +824,6 @@ def with_context( collisions: Set[str], skip_fields: bool = False, visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "MessageType": """Return a derivative of this message with the provided context. @@ -849,7 +844,6 @@ def with_context( k: v.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ) for k, v in self.fields.items() } @@ -857,7 +851,7 @@ def with_context( else self.fields ), nested_enums={ - k: v.with_context(collisions=collisions, context_cache=context_cache) + k: v.with_context(collisions=collisions) for k, v in self.nested_enums.items() }, nested_messages={ @@ -865,12 +859,11 @@ def with_context( collisions=collisions, skip_fields=skip_fields, visited_messages=visited_messages, - context_cache=context_cache, ) for k, v in self.nested_messages.items() }, meta=self.meta.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions ), ) @@ -965,7 +958,6 @@ def with_context( self, *, collisions: Set[str], - context_cache: Optional[Dict] = None, ) -> "EnumType": """Return a derivative of this enum with the provided context. @@ -977,7 +969,7 @@ def with_context( dataclasses.replace( self, meta=self.meta.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions ), ) if collisions @@ -1095,7 +1087,6 @@ def with_context( *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "ExtendedOperationInfo": """Return a derivative of this OperationInfo with the provided context. @@ -1111,12 +1102,10 @@ def with_context( request_type=self.request_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), operation_type=self.operation_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), ) ) @@ -1168,7 +1157,6 @@ def with_context( *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "OperationInfo": """Return a derivative of this OperationInfo with the provided context. @@ -1181,12 +1169,10 @@ def with_context( response_type=self.response_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), metadata_type=self.metadata_type.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), ) @@ -1982,7 +1968,6 @@ def with_context( *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "Method": """Return a derivative of this method with the provided context. @@ -1996,7 +1981,6 @@ def with_context( self.lro.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ) if collisions else self.lro @@ -2006,7 +1990,6 @@ def with_context( self.extended_lro.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ) if self.extended_lro else None @@ -2019,15 +2002,13 @@ def with_context( input=self.input.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), output=self.output.with_context( collisions=collisions, visited_messages=visited_messages, - context_cache=context_cache, ), meta=self.meta.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions, ), ) @@ -2410,7 +2391,6 @@ def with_context( *, collisions: Set[str], visited_messages: Optional[Set["MessageType"]] = None, - context_cache: Optional[Dict] = None, ) -> "Service": """Return a derivative of this service with the provided context. @@ -2426,12 +2406,11 @@ def with_context( # that may conflict with module imports. collisions=collisions | set(v.flattened_fields.keys()), visited_messages=visited_messages, - context_cache=context_cache, ) for k, v in self.methods.items() }, meta=self.meta.with_context( - collisions=collisions, context_cache=context_cache + collisions=collisions, ), ) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index 8d401a9eaf..d65c3e21ba 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -13,7 +13,9 @@ # limitations under the License. import functools -from typing import Dict, Optional +import contextlib +from contextvars import ContextVar +from typing import Dict, Optional, Any def cached_property(fx): @@ -45,15 +47,45 @@ def inner(self): return property(inner) +# 1. The ContextVar (Default is None) +# This replaces threading.local with the async-safe standard. +generation_cache: ContextVar[Optional[Dict[Any, Any]]] = ContextVar( + "generation_cache", default=None +) + +# Optimization: Bind .get for speed +_get_cache = generation_cache.get + +@contextlib.contextmanager +def generation_cache_context(): + """Context manager to explicitly manage the cache lifecycle. + + Usage: + with generation_cache_context(): + # Cache is active (fast) + ... + """ + # Initialize: Create a new dictionary and set it in the ContextVar + token = generation_cache.set({}) + try: + yield + finally: + # Cleanup: Reset the ContextVar to its previous state (None) + # This allows the dictionary to be garbage collected. + generation_cache.reset(token) def cached_proto_context(func): """Decorator to memoize with_context calls based on self and collisions.""" @functools.wraps(func) - def wrapper(self, *, collisions, context_cache: Optional[Dict] = None, **kwargs): + def wrapper(self, *, collisions, **kwargs): # 1. Initialize cache if not provided (handles the root call case) + + context_cache = _get_cache() if context_cache is None: - context_cache = {} + raise RuntimeError( + f"Cache MISSING! {func.__name__} called on {type(self).__name__} outside of context manager." + ) # 2. Create the cache key collisions_key = frozenset(collisions) if collisions else None @@ -66,7 +98,7 @@ def wrapper(self, *, collisions, context_cache: Optional[Dict] = None, **kwargs) # 4. Execute the actual function # We ensure context_cache is passed down to the recursive calls result = func( - self, collisions=collisions, context_cache=context_cache, **kwargs + self, collisions=collisions, **kwargs ) # 5. Update Cache From db204f68141adabf954689ae9467126f5bfccfcf Mon Sep 17 00:00:00 2001 From: ohmayr Date: Mon, 26 Jan 2026 05:37:29 +0000 Subject: [PATCH 13/22] fix formatting --- gapic/cli/generate.py | 1 + gapic/schema/metadata.py | 12 +++--------- gapic/schema/wrappers.py | 16 +++------------- gapic/utils/cache.py | 11 ++++++----- 4 files changed, 13 insertions(+), 27 deletions(-) diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index 32f3847fde..ec13df8ebf 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -25,6 +25,7 @@ from gapic.utils import Options from gapic.utils.cache import generation_cache_context + @click.command() @click.option( "--request", diff --git a/gapic/schema/metadata.py b/gapic/schema/metadata.py index 10460fd369..d973a4faf9 100644 --- a/gapic/schema/metadata.py +++ b/gapic/schema/metadata.py @@ -361,9 +361,7 @@ def resolve(self, selector: str) -> str: return selector @cached_proto_context - def with_context( - self, *, collisions: Set[str] - ) -> "Address": + def with_context(self, *, collisions: Set[str]) -> "Address": """Return a derivative of this address with the provided context. This method is used to address naming collisions. The returned @@ -403,9 +401,7 @@ def doc(self): return "" @cached_proto_context - def with_context( - self, *, collisions: Set[str] - ) -> "Metadata": + def with_context(self, *, collisions: Set[str]) -> "Metadata": """Return a derivative of this metadata with the provided context. This method is used to address naming collisions. The returned @@ -415,9 +411,7 @@ def with_context( return ( dataclasses.replace( self, - address=self.address.with_context( - collisions=collisions - ), + address=self.address.with_context(collisions=collisions), ) if collisions and collisions != self.address.collisions else self diff --git a/gapic/schema/wrappers.py b/gapic/schema/wrappers.py index aae4e59939..26118f9c1b 100644 --- a/gapic/schema/wrappers.py +++ b/gapic/schema/wrappers.py @@ -437,13 +437,7 @@ def with_context( if self.message else None ), - enum=( - self.enum.with_context( - collisions=collisions - ) - if self.enum - else None - ), + enum=(self.enum.with_context(collisions=collisions) if self.enum else None), meta=self.meta.with_context( collisions=collisions, ), @@ -862,9 +856,7 @@ def with_context( ) for k, v in self.nested_messages.items() }, - meta=self.meta.with_context( - collisions=collisions - ), + meta=self.meta.with_context(collisions=collisions), ) def add_to_address_allowlist( @@ -968,9 +960,7 @@ def with_context( return ( dataclasses.replace( self, - meta=self.meta.with_context( - collisions=collisions - ), + meta=self.meta.with_context(collisions=collisions), ) if collisions else self diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index d65c3e21ba..87ac0b8253 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -47,6 +47,7 @@ def inner(self): return property(inner) + # 1. The ContextVar (Default is None) # This replaces threading.local with the async-safe standard. generation_cache: ContextVar[Optional[Dict[Any, Any]]] = ContextVar( @@ -56,10 +57,11 @@ def inner(self): # Optimization: Bind .get for speed _get_cache = generation_cache.get + @contextlib.contextmanager def generation_cache_context(): """Context manager to explicitly manage the cache lifecycle. - + Usage: with generation_cache_context(): # Cache is active (fast) @@ -74,6 +76,7 @@ def generation_cache_context(): # This allows the dictionary to be garbage collected. generation_cache.reset(token) + def cached_proto_context(func): """Decorator to memoize with_context calls based on self and collisions.""" @@ -83,7 +86,7 @@ def wrapper(self, *, collisions, **kwargs): context_cache = _get_cache() if context_cache is None: - raise RuntimeError( + raise RuntimeError( f"Cache MISSING! {func.__name__} called on {type(self).__name__} outside of context manager." ) @@ -97,9 +100,7 @@ def wrapper(self, *, collisions, **kwargs): # 4. Execute the actual function # We ensure context_cache is passed down to the recursive calls - result = func( - self, collisions=collisions, **kwargs - ) + result = func(self, collisions=collisions, **kwargs) # 5. Update Cache context_cache[key] = result From 5d9183a29c48fbe6eb7987ed4fc87bceb38df665 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Mon, 26 Jan 2026 05:41:11 +0000 Subject: [PATCH 14/22] return default func if cache not set --- gapic/utils/cache.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index 87ac0b8253..e87990d794 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -86,9 +86,7 @@ def wrapper(self, *, collisions, **kwargs): context_cache = _get_cache() if context_cache is None: - raise RuntimeError( - f"Cache MISSING! {func.__name__} called on {type(self).__name__} outside of context manager." - ) + return func(self, collisions=collisions, **kwargs) # 2. Create the cache key collisions_key = frozenset(collisions) if collisions else None From b4a2815ad304744d49954f10eb5bb511159af505 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Mon, 26 Jan 2026 07:40:32 +0000 Subject: [PATCH 15/22] add test cases for cache --- tests/unit/utils/test_cache.py | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index eee6865e6d..591dd6c929 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -31,3 +31,48 @@ def bar(self): assert foo.call_count == 1 assert foo.bar == 42 assert foo.call_count == 1 + + +def test_generation_cache_context(): + assert cache.generation_cache.get() is None + with cache.generation_cache_context(): + assert isinstance(cache.generation_cache.get(), dict) + cache.generation_cache.get()["foo"] = "bar" + assert cache.generation_cache.get()["foo"] == "bar" + assert cache.generation_cache.get() is None + + +def test_cached_proto_context(): + class Foo: + def __init__(self): + self.call_count = 0 + + @cache.cached_proto_context + def bar(self, *, collisions, other=None): + self.call_count += 1 + return f"baz-{self.call_count}" + + foo = Foo() + + # Without context, no caching + assert foo.bar(collisions={"a", "b"}) == "baz-1" + assert foo.bar(collisions={"a", "b"}) == "baz-2" + + # With context, caching works + with cache.generation_cache_context(): + assert foo.bar(collisions={"a", "b"}) == "baz-3" + assert foo.bar(collisions={"a", "b"}) == "baz-3" + assert foo.call_count == 3 + + # Different collisions, different cache entry + assert foo.bar(collisions={"c"}) == "baz-4" + assert foo.bar(collisions={"c"}) == "baz-4" + assert foo.call_count == 4 + + # Same collisions again, still cached + assert foo.bar(collisions={"a", "b"}) == "baz-3" + assert foo.call_count == 4 + + # Context cleared + assert cache.generation_cache.get() is None + assert foo.bar(collisions={"a", "b"}) == "baz-5" From 565664dae0e6a16e81f0bb6195b00e4d2c8c904b Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 00:28:25 +0000 Subject: [PATCH 16/22] update key --- gapic/utils/cache.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index e87990d794..f1c1f597f6 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -88,9 +88,17 @@ def wrapper(self, *, collisions, **kwargs): if context_cache is None: return func(self, collisions=collisions, **kwargs) - # 2. Create the cache key - collisions_key = frozenset(collisions) if collisions else None - key = (id(self), collisions_key) + col_key = frozenset(collisions) if collisions else None + + # CRITICAL FIX: Include these flags in the key. + # This prevents the "Summary" (skip_fields=True) from being returned + # when the template asks for the "Full" (skip_fields=False) object. + skip_key = kwargs.get('skip_fields', False) + + visited = kwargs.get('visited_messages') + vis_key = tuple(sorted(id(m) for m in visited)) if visited else None + + key = (id(self), col_key, skip_key, vis_key) # 3. Check Cache if key in context_cache: From 5c13472553b7707efbc6a670b02d34b7d15840e2 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 00:32:26 +0000 Subject: [PATCH 17/22] update tests --- tests/unit/utils/test_cache.py | 73 ++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/tests/unit/utils/test_cache.py b/tests/unit/utils/test_cache.py index 591dd6c929..bc46920e60 100644 --- a/tests/unit/utils/test_cache.py +++ b/tests/unit/utils/test_cache.py @@ -43,36 +43,75 @@ def test_generation_cache_context(): def test_cached_proto_context(): + # A dummy object to act as a visited message + class MockMessage: + pass + class Foo: def __init__(self): self.call_count = 0 + # We define a signature that matches the real Proto.with_context + # to ensure arguments are propagated correctly. @cache.cached_proto_context - def bar(self, *, collisions, other=None): + def with_context(self, collisions, *, skip_fields=False, visited_messages=None): self.call_count += 1 - return f"baz-{self.call_count}" + return f"val-{self.call_count}" foo = Foo() + msg_a = MockMessage() + msg_b = MockMessage() - # Without context, no caching - assert foo.bar(collisions={"a", "b"}) == "baz-1" - assert foo.bar(collisions={"a", "b"}) == "baz-2" + # 1. Test Bypass (No Context) + # The cache is not active, so every call increments the counter. + assert foo.with_context(collisions={"a"}) == "val-1" + assert foo.with_context(collisions={"a"}) == "val-2" - # With context, caching works + # 2. Test Context Activation with cache.generation_cache_context(): - assert foo.bar(collisions={"a", "b"}) == "baz-3" - assert foo.bar(collisions={"a", "b"}) == "baz-3" + # Reset counter to make tracking easier + foo.call_count = 0 + + # A. Basic Cache Hit + assert foo.with_context({"a"}) == "val-1" + assert foo.with_context({"a"}) == "val-1" # Hit + assert foo.call_count == 1 + + # B. Collision Difference + # Changing collisions creates a new key + assert foo.with_context({"b"}) == "val-2" + assert foo.call_count == 2 + + # C. skip_fields Difference (The Critical Fix) + # Verify that skip_fields=True is cached separately from the default (False). + # This prevents the "Husk" object from poisoning the "Full" object cache. + assert foo.with_context({"a"}, skip_fields=True) == "val-3" + assert foo.with_context({"a"}, skip_fields=True) == "val-3" # Hit assert foo.call_count == 3 - - # Different collisions, different cache entry - assert foo.bar(collisions={"c"}) == "baz-4" - assert foo.bar(collisions={"c"}) == "baz-4" + + # Verify the original (skip_fields=False) is still accessible + assert foo.with_context({"a"}) == "val-1" # Hit (Still 1) + + # D. visited_messages Difference + # Verify that sets of objects are correctly distinguished. + + # Call with msg_a + assert foo.with_context({"a"}, visited_messages={msg_a}) == "val-4" + assert foo.with_context({"a"}, visited_messages={msg_a}) == "val-4" # Hit assert foo.call_count == 4 - # Same collisions again, still cached - assert foo.bar(collisions={"a", "b"}) == "baz-3" - assert foo.call_count == 4 + # Call with msg_b (Should be a Miss) + assert foo.with_context({"a"}, visited_messages={msg_b}) == "val-5" + assert foo.call_count == 5 + + # E. Set Stability + # Verify that the order of items in the set doesn't matter (sets are unordered). + # {a, b} should cache-hit with {b, a} + assert foo.with_context({"a"}, visited_messages={msg_a, msg_b}) == "val-6" + assert foo.with_context({"a"}, visited_messages={msg_b, msg_a}) == "val-6" # Hit + assert foo.call_count == 6 - # Context cleared + # 3. Context Cleared + # Everything should be forgotten now. assert cache.generation_cache.get() is None - assert foo.bar(collisions={"a", "b"}) == "baz-5" + assert foo.with_context({"a"}) == "val-7" \ No newline at end of file From 6012919c1af8c61faabb02b78a5bca674f2e0548 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 00:40:30 +0000 Subject: [PATCH 18/22] update key --- gapic/utils/cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index f1c1f597f6..f08ae8e48a 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -98,7 +98,7 @@ def wrapper(self, *, collisions, **kwargs): visited = kwargs.get('visited_messages') vis_key = tuple(sorted(id(m) for m in visited)) if visited else None - key = (id(self), col_key, skip_key, vis_key) + key = (self, col_key, skip_key, vis_key) # 3. Check Cache if key in context_cache: From 16160ad0eb63fdf6e12a768d16812c51df63e85c Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 00:48:25 +0000 Subject: [PATCH 19/22] add keep alive --- gapic/utils/cache.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index f08ae8e48a..3e2d0c227d 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -98,12 +98,14 @@ def wrapper(self, *, collisions, **kwargs): visited = kwargs.get('visited_messages') vis_key = tuple(sorted(id(m) for m in visited)) if visited else None - key = (self, col_key, skip_key, vis_key) + key = (id(self), col_key, skip_key, vis_key) # 3. Check Cache if key in context_cache: return context_cache[key] + keep_alive.append(self) + # 4. Execute the actual function # We ensure context_cache is passed down to the recursive calls result = func(self, collisions=collisions, **kwargs) From e67c36ac66232f11f6ea33b271400aef1e4f6f92 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 01:03:26 +0000 Subject: [PATCH 20/22] separate out context --- gapic/cli/generate.py | 1 + gapic/utils/cache.py | 16 +++------------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index ec13df8ebf..e069b9939c 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -63,6 +63,7 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None: # to each template in the rendering step. api_schema = api.API.build(req.proto_file, opts=opts, package=package) + with generation_cache_context(): # Translate into a protobuf CodeGeneratorResponse; this reads the # individual templates and renders them. # If there are issues, error out appropriately. diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index 3e2d0c227d..e87990d794 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -88,24 +88,14 @@ def wrapper(self, *, collisions, **kwargs): if context_cache is None: return func(self, collisions=collisions, **kwargs) - col_key = frozenset(collisions) if collisions else None - - # CRITICAL FIX: Include these flags in the key. - # This prevents the "Summary" (skip_fields=True) from being returned - # when the template asks for the "Full" (skip_fields=False) object. - skip_key = kwargs.get('skip_fields', False) - - visited = kwargs.get('visited_messages') - vis_key = tuple(sorted(id(m) for m in visited)) if visited else None - - key = (id(self), col_key, skip_key, vis_key) + # 2. Create the cache key + collisions_key = frozenset(collisions) if collisions else None + key = (id(self), collisions_key) # 3. Check Cache if key in context_cache: return context_cache[key] - keep_alive.append(self) - # 4. Execute the actual function # We ensure context_cache is passed down to the recursive calls result = func(self, collisions=collisions, **kwargs) From 9ebb566fd4a6080c15cc32416b4b7b3c712001d0 Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 01:07:07 +0000 Subject: [PATCH 21/22] avoid using context for get response --- gapic/cli/generate.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gapic/cli/generate.py b/gapic/cli/generate.py index e069b9939c..674979f7d9 100644 --- a/gapic/cli/generate.py +++ b/gapic/cli/generate.py @@ -63,11 +63,10 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None: # to each template in the rendering step. api_schema = api.API.build(req.proto_file, opts=opts, package=package) - with generation_cache_context(): - # Translate into a protobuf CodeGeneratorResponse; this reads the - # individual templates and renders them. - # If there are issues, error out appropriately. - res = generator.Generator(opts).get_response(api_schema, opts) + # Translate into a protobuf CodeGeneratorResponse; this reads the + # individual templates and renders them. + # If there are issues, error out appropriately. + res = generator.Generator(opts).get_response(api_schema, opts) # Output the serialized response. output.write(res.SerializeToString()) From c25fb2b50db3f993c2750887be026782e3e6761b Mon Sep 17 00:00:00 2001 From: ohmayr Date: Tue, 27 Jan 2026 01:41:55 +0000 Subject: [PATCH 22/22] change logic of cache --- gapic/utils/cache.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/gapic/utils/cache.py b/gapic/utils/cache.py index e87990d794..2991af3d85 100644 --- a/gapic/utils/cache.py +++ b/gapic/utils/cache.py @@ -14,7 +14,7 @@ import functools import contextlib -from contextvars import ContextVar +import threading from typing import Dict, Optional, Any @@ -48,33 +48,20 @@ def inner(self): return property(inner) -# 1. The ContextVar (Default is None) -# This replaces threading.local with the async-safe standard. -generation_cache: ContextVar[Optional[Dict[Any, Any]]] = ContextVar( - "generation_cache", default=None -) - -# Optimization: Bind .get for speed -_get_cache = generation_cache.get +# Thread-local storage for the simple cache dictionary +_thread_local = threading.local() @contextlib.contextmanager def generation_cache_context(): - """Context manager to explicitly manage the cache lifecycle. - - Usage: - with generation_cache_context(): - # Cache is active (fast) - ... - """ - # Initialize: Create a new dictionary and set it in the ContextVar - token = generation_cache.set({}) + """Context manager to explicitly manage the cache lifecycle.""" + # Initialize the cache as a standard dictionary + _thread_local.cache = {} try: yield finally: - # Cleanup: Reset the ContextVar to its previous state (None) - # This allows the dictionary to be garbage collected. - generation_cache.reset(token) + # Delete the dictionary to free all memory and pinned objects + del _thread_local.cache def cached_proto_context(func): @@ -84,7 +71,7 @@ def cached_proto_context(func): def wrapper(self, *, collisions, **kwargs): # 1. Initialize cache if not provided (handles the root call case) - context_cache = _get_cache() + context_cache = getattr(_thread_local, "cache", None) if context_cache is None: return func(self, collisions=collisions, **kwargs)