Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions gapic/cli/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from gapic import generator
from gapic.schema import api
from gapic.utils import Options
from gapic.utils.cache import generation_cache_context


@click.command()
Expand Down Expand Up @@ -56,15 +57,16 @@ def generate(request: typing.BinaryIO, output: typing.BinaryIO) -> None:
[p.package for p in req.proto_file if p.name in req.file_to_generate]
).rstrip(".")

# Build the API model object.
# This object is a frozen representation of the whole API, and is sent
# to each template in the rendering step.
api_schema = api.API.build(req.proto_file, opts=opts, package=package)
with generation_cache_context():
# Build the API model object.
# This object is a frozen representation of the whole API, and is sent
# to each template in the rendering step.
api_schema = api.API.build(req.proto_file, opts=opts, package=package)

# Translate into a protobuf CodeGeneratorResponse; this reads the
# individual templates and renders them.
# If there are issues, error out appropriately.
res = generator.Generator(opts).get_response(api_schema, opts)
# Translate into a protobuf CodeGeneratorResponse; this reads the
# individual templates and renders them.
# If there are issues, error out appropriately.
res = generator.Generator(opts).get_response(api_schema, opts)

# Output the serialized response.
output.write(res.SerializeToString())
Expand Down
5 changes: 4 additions & 1 deletion gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@

import dataclasses
import re
from typing import FrozenSet, Set, Tuple, Optional
from typing import FrozenSet, Set, Tuple, Optional, Dict

from google.protobuf import descriptor_pb2

from gapic.schema import imp
from gapic.schema import naming
from gapic.utils import cached_property
from gapic.utils import cached_proto_context
from gapic.utils import RESERVED_NAMES

# This class is a minor hack to optimize Address's __eq__ method.
Expand Down Expand Up @@ -359,6 +360,7 @@ def resolve(self, selector: str) -> str:
return f'{".".join(self.package)}.{selector}'
return selector

@cached_proto_context
def with_context(self, *, collisions: Set[str]) -> "Address":
"""Return a derivative of this address with the provided context.

Expand Down Expand Up @@ -398,6 +400,7 @@ def doc(self):
return "\n\n".join(self.documentation.leading_detached_comments)
return ""

@cached_proto_context
def with_context(self, *, collisions: Set[str]) -> "Metadata":
"""Return a derivative of this metadata with the provided context.

Expand Down
35 changes: 28 additions & 7 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
from gapic.schema import metadata
from gapic.utils import uri_sample
from gapic.utils import make_private
from gapic.utils import cached_proto_context


@dataclasses.dataclass(frozen=True)
Expand Down Expand Up @@ -410,6 +411,7 @@ def type(self) -> Union["MessageType", "EnumType", "PrimitiveType"]:
"This code should not be reachable; please file a bug."
)

@cached_proto_context
def with_context(
self,
*,
Expand All @@ -435,8 +437,10 @@ def with_context(
if self.message
else None
),
enum=self.enum.with_context(collisions=collisions) if self.enum else None,
meta=self.meta.with_context(collisions=collisions),
enum=(self.enum.with_context(collisions=collisions) if self.enum else None),
meta=self.meta.with_context(
collisions=collisions,
),
)

def add_to_address_allowlist(
Expand Down Expand Up @@ -737,7 +741,9 @@ def path_regex_str(self) -> str:
return parsing_regex_str

def get_field(
self, *field_path: str, collisions: Optional[Set[str]] = None
self,
*field_path: str,
collisions: Optional[Set[str]] = None,
) -> Field:
"""Return a field arbitrarily deep in this message's structure.

Expand Down Expand Up @@ -805,6 +811,7 @@ def get_field(
# message.
return cursor.message.get_field(*field_path[1:], collisions=collisions)

@cached_proto_context
def with_context(
self,
*,
Expand All @@ -829,7 +836,8 @@ def with_context(
fields=(
{
k: v.with_context(
collisions=collisions, visited_messages=visited_messages
collisions=collisions,
visited_messages=visited_messages,
)
for k, v in self.fields.items()
}
Expand Down Expand Up @@ -937,7 +945,12 @@ def ident(self) -> metadata.Address:
"""Return the identifier data to be used in templates."""
return self.meta.address

def with_context(self, *, collisions: Set[str]) -> "EnumType":
@cached_proto_context
def with_context(
self,
*,
collisions: Set[str],
) -> "EnumType":
"""Return a derivative of this enum with the provided context.

This method is used to address naming collisions. The returned
Expand Down Expand Up @@ -1058,6 +1071,7 @@ class ExtendedOperationInfo:
request_type: MessageType
operation_type: MessageType

@cached_proto_context
def with_context(
self,
*,
Expand Down Expand Up @@ -1127,6 +1141,7 @@ class OperationInfo:
response_type: MessageType
metadata_type: MessageType

@cached_proto_context
def with_context(
self,
*,
Expand Down Expand Up @@ -1937,6 +1952,7 @@ def void(self) -> bool:
"""Return True if this method has no return value, False otherwise."""
return self.output.ident.proto == "google.protobuf.Empty"

@cached_proto_context
def with_context(
self,
*,
Expand Down Expand Up @@ -1981,7 +1997,9 @@ def with_context(
collisions=collisions,
visited_messages=visited_messages,
),
meta=self.meta.with_context(collisions=collisions),
meta=self.meta.with_context(
collisions=collisions,
),
)

def add_to_address_allowlist(
Expand Down Expand Up @@ -2357,6 +2375,7 @@ def operation_polling_method(self) -> Optional[Method]:
def is_internal(self) -> bool:
return any(m.is_internal for m in self.methods.values())

@cached_proto_context
def with_context(
self,
*,
Expand All @@ -2380,7 +2399,9 @@ def with_context(
)
for k, v in self.methods.items()
},
meta=self.meta.with_context(collisions=collisions),
meta=self.meta.with_context(
collisions=collisions,
),
)

def add_to_address_allowlist(
Expand Down
2 changes: 2 additions & 0 deletions gapic/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from gapic.utils.cache import cached_property
from gapic.utils.cache import cached_proto_context
from gapic.utils.case import to_snake_case
from gapic.utils.case import to_camel_case
from gapic.utils.checks import is_msg_field_pb
Expand All @@ -34,6 +35,7 @@

__all__ = (
"cached_property",
"cached_proto_context",
"convert_uri_fieldnames",
"doc",
"empty",
Expand Down
62 changes: 62 additions & 0 deletions gapic/utils/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

import functools
import contextlib
from contextvars import ContextVar
from typing import Dict, Optional, Any


def cached_property(fx):
Expand Down Expand Up @@ -43,3 +46,62 @@ def inner(self):
return self._cached_values[fx.__name__]

return property(inner)


# Process-wide memoization store for ``with_context`` calls.
#
# A ContextVar (rather than threading.local) is used because it is the
# async-safe standard for context-scoped state.  The default of None means
# "no cache active": the decorator below becomes a pass-through unless
# generation_cache_context() has installed a dictionary.
generation_cache: ContextVar[Optional[Dict[Any, Any]]] = ContextVar(
    "generation_cache", default=None
)

# Micro-optimization: bind the bound method once so the decorator's hot
# path performs a single fast name lookup per call.
_get_cache = generation_cache.get


@contextlib.contextmanager
def generation_cache_context():
    """Context manager to explicitly manage the cache lifecycle.

    Usage:
        with generation_cache_context():
            # Cache is active (fast)
            ...

    On exit the ContextVar is reset to its previous state (normally None),
    which allows the cache dictionary to be garbage collected.
    """
    # Initialize: install a fresh dictionary for this context.
    token = generation_cache.set({})
    try:
        yield
    finally:
        # Cleanup: restore the previous value so the dict can be freed.
        generation_cache.reset(token)


def cached_proto_context(func):
    """Memoize ``with_context`` calls keyed on the identity of ``self``
    and the ``collisions`` set.

    Caching only happens while a :func:`generation_cache_context` is
    active; outside of one the wrapped function is called directly.  The
    key is identity-based (``id(self)``) rather than value-based, so the
    cache is only shared between calls on the exact same object.  A
    strong reference to ``self`` is stored alongside the result to pin
    its ``id()`` for the lifetime of the context — otherwise a garbage
    collected object's recycled address could yield a stale cache hit.
    """

    @functools.wraps(func)
    def wrapper(self, collisions, **kwargs):
        context_cache = _get_cache()
        # No active cache (e.g. a call outside the generator's context):
        # execute the function without memoization.
        if context_cache is None:
            return func(self, collisions=collisions, **kwargs)

        # Build a hashable key.  An empty collisions set and None are
        # deliberately collapsed onto the same key.
        collisions_key = frozenset(collisions) if collisions else None
        key = (id(self), collisions_key)

        # Cache hit: entries are (self, result) pairs; return the result.
        entry = context_cache.get(key)
        if entry is not None:
            return entry[1]

        # Cache miss: execute the actual function.
        result = func(self, collisions=collisions, **kwargs)

        # Store (self, result).  Keeping ``self`` alive guarantees its
        # id() cannot be reused by another object while cached.
        context_cache[key] = (self, result)
        return result

    return wrapper
45 changes: 45 additions & 0 deletions tests/unit/utils/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,48 @@ def bar(self):
assert foo.call_count == 1
assert foo.bar == 42
assert foo.call_count == 1


def test_generation_cache_context():
    """The context manager installs a writable dict and restores None on exit."""
    cache_var = cache.generation_cache
    # Outside any context the cache is disabled.
    assert cache_var.get() is None
    with cache.generation_cache_context():
        # Inside the context a fresh, mutable dict is active.
        active = cache_var.get()
        assert isinstance(active, dict)
        active["foo"] = "bar"
        assert cache_var.get()["foo"] == "bar"
    # Leaving the context restores the disabled state.
    assert cache_var.get() is None


def test_cached_proto_context():
    """Verify cached_proto_context memoizes only while a cache context is active."""

    class Foo:
        def __init__(self):
            # Counts real (non-memoized) executions of ``bar``.
            self.call_count = 0

        @cache.cached_proto_context
        def bar(self, *, collisions, other=None):
            self.call_count += 1
            # Embeds the call count so repeated values prove a cache hit.
            return f"baz-{self.call_count}"

    foo = Foo()

    # Without context, no caching
    assert foo.bar(collisions={"a", "b"}) == "baz-1"
    assert foo.bar(collisions={"a", "b"}) == "baz-2"

    # With context, caching works
    with cache.generation_cache_context():
        assert foo.bar(collisions={"a", "b"}) == "baz-3"
        assert foo.bar(collisions={"a", "b"}) == "baz-3"
        assert foo.call_count == 3

        # Different collisions, different cache entry
        assert foo.bar(collisions={"c"}) == "baz-4"
        assert foo.bar(collisions={"c"}) == "baz-4"
        assert foo.call_count == 4

        # Same collisions again, still cached
        assert foo.bar(collisions={"a", "b"}) == "baz-3"
        assert foo.call_count == 4

    # Context cleared
    assert cache.generation_cache.get() is None
    assert foo.bar(collisions={"a", "b"}) == "baz-5"
Loading