Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,19 @@

"""Interceptor for collecting Cloud Spanner metrics."""

import inspect
import logging
import re
from typing import Dict
from typing import Any, Dict

import grpc
from grpc_interceptor import ClientInterceptor

from .constants import GOOGLE_CLOUD_RESOURCE_KEY, SPANNER_METHOD_PREFIX
from .spanner_metrics_tracer_factory import SpannerMetricsTracerFactory

logger = logging.getLogger(__name__)


class MetricsInterceptor(ClientInterceptor):
"""Interceptor that collects metrics for Cloud Spanner operations."""
Expand Down Expand Up @@ -88,6 +93,8 @@ def _set_metrics_tracer_attributes(self, resources: Dict[str, str]) -> None:
if "database" in resources:
tracer.set_database(resources["database"])



def intercept(self, invoked_method, request_or_iterator, call_details):
"""Intercept gRPC calls to collect metrics.

Expand Down Expand Up @@ -122,10 +129,265 @@ def intercept(self, invoked_method, request_or_iterator, call_details):
tracer.set_method(method_name)
tracer.record_attempt_start()
response = invoked_method(request_or_iterator, call_details)
tracer.record_attempt_completion()

# Process and send GFE metrics if enabled
if tracer.gfe_enabled:
metadata = response.initial_metadata()
return _wrap_response(response, tracer)


def _wrap_response(response: Any, tracer: Any) -> Any:
"""Wraps the response if it is streaming, or records metrics immediately if unary."""
if hasattr(response, "__next__"):
return _StreamingResponseWrapper(response, tracer)
else:
# Unary call: execute completion and record metrics immediately
try:
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
return response
Comment on lines +140 to 158

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The metrics recording block for unary calls is not wrapped in a try-except block. If tracer.record_attempt_completion() or tracer.record_gfe_metrics(metadata) raises an exception (e.g., due to OpenTelemetry configuration issues or unexpected metadata formats), it will crash the entire unary RPC call and prevent the response from being returned. Telemetry and metrics collection should be non-blocking and fail-safe, meaning they should never disrupt the main application flow. Avoid broad except Exception: blocks that silently pass; instead, log the exception to aid in debugging.

Suggested change
else:
# Unary call: execute completion and record metrics immediately
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception:
pass
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception:
pass
tracer.record_gfe_metrics(metadata)
return response
else:
# Unary call: execute completion and record metrics immediately
try:
tracer.record_attempt_completion()
metadata = []
if hasattr(response, "initial_metadata"):
try:
metadata.extend(response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(response, "trailing_metadata"):
try:
metadata.extend(response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")
return response
References
  1. Avoid broad except Exception: blocks that silently return None. Instead, log the exception (e.g., using logger.warning) to aid in debugging and prevent masking underlying issues.



class AsyncMetricsInterceptor(
grpc.aio.UnaryUnaryClientInterceptor,
grpc.aio.UnaryStreamClientInterceptor,
grpc.aio.StreamUnaryClientInterceptor,
grpc.aio.StreamStreamClientInterceptor,
):
"""Async Interceptor that collects metrics for Cloud Spanner operations."""

async def intercept_unary_unary(self, continuation, client_call_details, request):
return await self._async_intercept(continuation, client_call_details, request)

async def intercept_unary_stream(self, continuation, client_call_details, request):
return await self._async_intercept(continuation, client_call_details, request)

async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
return await self._async_intercept(continuation, client_call_details, request_iterator)

async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
return await self._async_intercept(continuation, client_call_details, request_iterator)

async def _async_intercept(
self,
continuation: Any,
call_details: grpc.ClientCallDetails,
request_or_iterator: Any,
) -> Any:
# Implementation for async interceptor
factory = SpannerMetricsTracerFactory()
tracer = SpannerMetricsTracerFactory.get_current_tracer()
if tracer is None or not factory.enabled:
return await continuation(call_details, request_or_iterator)

if not (
tracer.client_attributes.get("project_id")
and tracer.client_attributes.get("instance_id")
and tracer.client_attributes.get("database")
):
resources = MetricsInterceptor._extract_resource_from_path(call_details.metadata)
MetricsInterceptor._set_metrics_tracer_attributes(resources)

method_name = call_details.method.removeprefix(SPANNER_METHOD_PREFIX).replace(
"/", "."
)

tracer.set_method(method_name)
tracer.record_attempt_start()
response = await continuation(call_details, request_or_iterator)

if hasattr(response, "__anext__"):
return _AsyncStreamingResponseWrapper(response, tracer)
else:
return _AsyncUnaryResponseWrapper(response, tracer)


class _StreamingResponseWrapper:
"""Wrapper for streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False
self._iterator = None

def __iter__(self):
self._iterator = iter(self._response)
return self

def __next__(self):
if self._iterator is None:
self._iterator = iter(self._response)
try:
return next(self._iterator)
except StopIteration:
self._record_metrics()
raise
except Exception:
self._record_metrics()
raise

def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
metadata.extend(self._response.initial_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
metadata.extend(self._response.trailing_metadata() or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
try:
self._record_metrics()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)


class _AsyncUnaryResponseWrapper:
"""Wrapper for async unary RPC response to defer metrics recording until awaited."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False

def __await__(self):
async def _wait():
try:
return await self._response
finally:
await self._record_metrics()
return _wait().__await__()

async def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
res = self._response.initial_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
res = self._response.trailing_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
if not self._metrics_recorded:
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)


class _AsyncStreamingResponseWrapper:
"""Wrapper for async streaming RPC response iterators to defer metrics recording."""

def __init__(self, response, tracer):
self._response = response
self._tracer = tracer
self._metrics_recorded = False
self._iterator = None

def __aiter__(self):
if hasattr(self._response, "__aiter__"):
self._iterator = self._response.__aiter__()
else:
self._iterator = self._response
return self

async def __anext__(self):
if self._iterator is None:
if hasattr(self._response, "__aiter__"):
self._iterator = self._response.__aiter__()
else:
self._iterator = self._response
try:
return await self._iterator.__anext__()
except StopAsyncIteration:
await self._record_metrics()
raise
except Exception:
await self._record_metrics()
raise

async def _record_metrics(self):
if self._metrics_recorded:
return
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
metadata = []
if hasattr(self._response, "initial_metadata"):
try:
res = self._response.initial_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve initial metadata: {e}")
if hasattr(self._response, "trailing_metadata"):
try:
res = self._response.trailing_metadata()
if inspect.isawaitable(res):
res = await res
metadata.extend(res or [])
except Exception as e:
logger.warning(f"Failed to retrieve trailing metadata: {e}")
self._tracer.record_gfe_metrics(metadata)
except Exception as e:
logger.warning(f"Failed to record metrics: {e}")

def __del__(self):
if not self._metrics_recorded:
self._metrics_recorded = True
try:
self._tracer.record_attempt_completion()
except Exception:
pass

def __getattr__(self, name):
return getattr(self._response, name)
Loading
Loading