Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions cloud_pipelines_backend/instrumentation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class MetricUnit(str, enum.Enum):

SECONDS = "s"
ERRORS = "{error}"
EXECUTIONS = "{execution}"


# ---------------------------------------------------------------------------
Expand All @@ -51,6 +52,13 @@ class MetricUnit(str, enum.Enum):
unit=MetricUnit.SECONDS,
)

# Gauge reporting how many execution nodes are currently in each active
# (non-terminal) status.  Created with no callbacks at module load time;
# the metrics poller (metrics_poller.PollingService) registers its observe
# callback on this gauge after construction.
execution_status_count = orchestrator_meter.create_observable_gauge(
    name="execution.status.count",
    callbacks=[],
    description="Number of execution nodes in each active (non-terminal) status",
    unit=MetricUnit.EXECUTIONS,
)


def record_status_transition(
from_status: str,
Expand Down
85 changes: 85 additions & 0 deletions cloud_pipelines_backend/instrumentation/metrics_poller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Metrics poller.

Periodically queries the DB and updates ObservableGauges. Currently emits
execution status counts; add new DB-backed metrics here as needed.

Only fluctuating (non-terminal) statuses are emitted as status count gauges —
terminal statuses like SUCCEEDED and FAILED only ever climb and are not useful
as gauges.
"""

import logging
import threading
import time
import typing

import sqlalchemy as sql
from opentelemetry import metrics as otel_metrics
from sqlalchemy import orm

from .. import backend_types_sql as bts
from . import metrics as app_metrics

_logger = logging.getLogger(__name__)


# All statuses minus terminal (ended) ones — these fluctuate up and down.
# Terminal statuses (e.g. SUCCEEDED/FAILED) only ever climb, so they are
# deliberately excluded from the status-count gauges.
_ACTIVE_STATUSES: frozenset[bts.ContainerExecutionStatus] = (
    frozenset(bts.ContainerExecutionStatus) - bts.CONTAINER_STATUSES_ENDED
)


def _make_zero_counts() -> dict[str, int]:
    """Return a fresh counts mapping with every active status set to 0."""
    return {s.value: 0 for s in _ACTIVE_STATUSES}


class PollingService:
    """Polls the DB periodically and emits execution status count gauges.

    Run `run_loop` on a (typically daemon) thread.  Each poll snapshots the
    per-status execution-node counts; the OTel SDK pulls the snapshot via the
    `_observe` callback registered on the `execution.status.count` gauge.
    """

    def __init__(
        self,
        *,
        session_factory: typing.Callable[[], orm.Session],
        poll_interval_seconds: float = 30.0,
    ) -> None:
        self._session_factory = session_factory
        self._poll_interval_seconds = poll_interval_seconds
        # Guards `_counts`/`_has_polled`: `_poll` runs on the polling thread
        # while `_observe` is invoked from the OTel metric-reader thread.
        self._lock = threading.Lock()
        # Initialize all active statuses to 0.
        self._counts: dict[str, int] = _make_zero_counts()
        # Until the first successful poll completes, `_observe` emits
        # nothing so we never report misleading all-zero counts at startup.
        self._has_polled = False
        # Register our observe method as the gauge callback.
        # NOTE(review): `_callbacks` is a private OTel SDK attribute and may
        # change; we append after creation because create_observable_gauge is
        # called at module load time in metrics.py.  TODO: add a public
        # callback-registration hook in metrics.py instead.
        app_metrics.execution_status_count._callbacks.append(self._observe)

    def run_loop(self) -> None:
        """Poll forever, sleeping `poll_interval_seconds` between polls.

        Never raises: any polling error is logged and the loop continues.
        """
        while True:
            try:
                self._poll()
            except Exception:
                _logger.exception("Metrics PollingService: error polling DB")
            time.sleep(self._poll_interval_seconds)

    def _poll(self) -> None:
        """Query per-status execution-node counts and snapshot them."""
        with self._session_factory() as session:
            rows = session.execute(
                sql.select(
                    bts.ExecutionNode.container_execution_status,
                    sql.func.count().label("count"),
                )
                .where(
                    bts.ExecutionNode.container_execution_status.in_(
                        _ACTIVE_STATUSES
                    )
                )
                .group_by(bts.ExecutionNode.container_execution_status)
            ).all()
        # Statuses absent from the query result still need an explicit 0 so
        # the gauge drops back down when a status empties out.
        new_counts = _make_zero_counts()
        for status, count in rows:
            if status is not None:
                new_counts[status.value] = count
        with self._lock:
            self._counts = new_counts
            self._has_polled = True
        _logger.debug(
            "Metrics PollingService: polled status counts: %s", new_counts
        )

    def _observe(
        self, _options: otel_metrics.CallbackOptions
    ) -> typing.Iterable[otel_metrics.Observation]:
        """OTel gauge callback: yield one observation per active status.

        Emits no observations until at least one poll has completed.
        """
        with self._lock:
            if not self._has_polled:
                return
            counts = self._counts.copy()
        for status_value, count in counts.items():
            yield otel_metrics.Observation(
                count, {"execution.status": status_value}
            )
Comment on lines +82 to +85
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious, when will observe be called? Will there ever be a race condition where observe is called before the poll happened/completed? Would it make sense to put a flag here that at last 1 poll completed?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lock will prevent it from copying counts that are currently being modified.

As for there being at least 1 iteration, i'll add a check.

35 changes: 31 additions & 4 deletions start_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,35 @@ def run_orchestrator(
# endregion


# region: Metrics poller initialization

from cloud_pipelines_backend.instrumentation import metrics_poller
from cloud_pipelines_backend.instrumentation.opentelemetry._internal import (
configuration as otel_configuration,
)


def run_metrics_poller(*, db_engine: sqlalchemy.Engine) -> None:
    """Run the DB metrics poller loop (blocks forever).

    Does nothing (beyond logging why) when no OTel metrics exporter endpoint
    is configured, since the emitted gauges would have nowhere to go.
    """
    otel_config = otel_configuration.resolve()
    if otel_config is None or otel_config.metrics is None:
        # Only the middle fragment interpolates anything, so the other two
        # fragments are plain strings (the original marked all three as
        # f-strings unnecessarily).
        logger.info(
            "No OTel metrics endpoint configured"
            f" (set {otel_configuration.EnvVar.METRIC_EXPORTER_ENDPOINT})"
            " — metrics poller not starting"
        )
        return
    session_factory = orm.sessionmaker(
        autocommit=False, autoflush=False, bind=db_engine
    )
    metrics_poller.PollingService(session_factory=session_factory).run_loop()


def run_configured_metrics_poller() -> None:
    """Thread target: run the metrics poller against the module's db_engine.

    A `def` instead of a named lambda (PEP 8 E731): stack traces show the
    real function name and the callable can carry a return annotation.
    """
    run_metrics_poller(db_engine=db_engine)
Comment on lines +237 to +239
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to do

def run_configured_metrics_poller() -> None:
    run_metrics_poller(db_engine=db_engine)

Instead of lambda for variables. Reason why is 1) troubleshooting, if there was a stack trace it'll just say "lambda" instead of the actual function name and 2) cannot have type hinting/docstrings to have a strongly type function for type checking.

Also, if we had a linter, there are some that would complain about this.

Use cases for when using lambda are inline/throwaway like:

sorted(items, key=lambda x: x.created_at)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now this is intentional to match existing code standards set by the orchestrator. It matches the orchestrator code exactly. I'm not sure I want to introduce an alternative format right now unless @Ark-kun would like make the switch.

I do like all your points for not using a lambda in this case.

# endregion


# region: API Server initialization
import contextlib
import threading
Expand All @@ -228,10 +257,8 @@ def run_orchestrator(
@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
database_ops.initialize_and_migrate_db(db_engine=db_engine)
threading.Thread(
target=run_configured_orchestrator,
daemon=True,
).start()
threading.Thread(target=run_configured_orchestrator, daemon=True).start()
threading.Thread(target=run_configured_metrics_poller, daemon=True).start()
if os.environ.get("GOOGLE_CLOUD_SHELL") == "true":
# TODO: Find a way to get fastapi/starlette/uvicorn port
port = 8000
Expand Down