|
| 1 | +"""Amazon SQS transport. Requires the ``sqs`` extra: |
| 2 | +
|
| 3 | + pip install "babelqueue[sqs]" |
| 4 | +
|
| 5 | +Producing sends the canonical envelope as the message body and projects the |
| 6 | +contract envelope fields onto native SQS ``MessageAttributes`` (``bq-job`` = URN, |
| 7 | +``bq-trace-id`` = trace_id, ``bq-message-id`` = meta.id, plus ``bq-schema-version`` / |
| 8 | +``bq-source-lang`` / ``bq-created-at``) — so a Go/PHP/... peer can route on ``bq-job`` |
| 9 | +and correlate on ``bq-trace-id`` without parsing the body. Consuming uses the |
| 10 | +visibility-timeout reservation model (``receive_message`` -> process -> |
| 11 | +``delete_message``); the authoritative attempt count is the broker's |
| 12 | +``ApproximateReceiveCount``, reconciled onto the envelope as ``attempts = count - 1``. |
| 13 | +
|
| 14 | +This implements §3 of the broker-bindings contract. The envelope is unchanged |
| 15 | +(``schema_version`` stays 1); SQS is purely additive. |
| 16 | +
|
| 17 | +URL form: ``sqs://[region][?endpoint=...&prefix=...&fifo=1&group_id=...&content_dedup=1&wait_time=20&visibility_timeout=30]`` |
| 18 | +(e.g. ``sqs://us-east-1?endpoint=http://localhost:4566`` for LocalStack). Credentials |
| 19 | +come from the standard AWS default provider chain. For richer setups, build the |
| 20 | +transport directly and pass it via ``BabelQueue(transport=...)``. |
| 21 | +""" |
| 22 | + |
| 23 | +from __future__ import annotations |
| 24 | + |
| 25 | +from typing import Any, Dict, Optional |
| 26 | +from urllib.parse import parse_qs, urlsplit |
| 27 | + |
| 28 | +from .codec import EnvelopeCodec |
| 29 | +from .transport import ReceivedMessage, Transport |
| 30 | + |
| 31 | + |
class SqsTransport(Transport):
    """Transport backed by Amazon SQS.

    ``publish`` sends the envelope as the message body together with a
    ``bq-*`` ``MessageAttributes`` projection of its contract fields.
    ``pop`` receives at most one message and raises the envelope's
    ``attempts`` to ``ApproximateReceiveCount - 1`` when the broker has seen
    more deliveries than the envelope records. ``ack`` deletes the message
    through its receipt handle.

    Resolved queue URLs are cached per queue name on the instance.
    """

    def __init__(
        self,
        url: str = "sqs://",
        *,
        client: Any = None,
        region: Optional[str] = None,
        endpoint: Optional[str] = None,
        queue_url_prefix: Optional[str] = None,
        wait_time: Optional[int] = None,
        visibility_timeout: Optional[int] = None,
        fifo: bool = False,
        message_group_id: Optional[str] = None,
        content_dedup: bool = False,
    ) -> None:
        """Configure the transport from an ``sqs://`` URL and keyword overrides.

        A keyword argument wins over the matching URL component when it is
        set; the two boolean flags are enabled if either the keyword or the
        query parameter turns them on.

        Args:
            url: ``sqs://[region][?query]``. A falsy value is treated as
                ``"sqs://"``. The host part supplies the region.
            client: Ready-made SQS client. When given it is used as-is and
                ``boto3`` is never imported.
            region: Region name; falls back to the URL host.
            endpoint: Endpoint URL; falls back to ``?endpoint=``.
            queue_url_prefix: When set, queue URLs are built as
                ``<prefix>/<queue>`` instead of calling ``get_queue_url``;
                falls back to ``?prefix=``.
            wait_time: Upper bound applied to the ``WaitTimeSeconds`` used by
                ``pop``; falls back to ``?wait_time=``.
            visibility_timeout: Sent as ``VisibilityTimeout`` on receive when
                not ``None``; falls back to ``?visibility_timeout=``.
            fifo: Add FIFO send parameters on publish (also ``?fifo=``).
            message_group_id: ``MessageGroupId`` for FIFO sends; falls back
                to ``?group_id=`` and finally to the queue name.
            content_dedup: When true, no ``MessageDeduplicationId`` is sent
                (also ``?content_dedup=``). NOTE(review): presumably relies
                on the queue's content-based deduplication - confirm.

        Raises:
            ImportError: If no ``client`` is supplied and ``boto3`` is not
                installed.
        """
        parts = urlsplit(url) if url else urlsplit("sqs://")
        q = parse_qs(parts.query)

        self._region = region or (parts.hostname or None)
        self._endpoint = endpoint or _q1(q, "endpoint")
        self._queue_url_prefix = queue_url_prefix or _q1(q, "prefix")
        # ``is not None`` (not ``or``) so an explicit 0 is honoured.
        self._wait_time = wait_time if wait_time is not None else _qint(q, "wait_time")
        self._visibility_timeout = (
            visibility_timeout if visibility_timeout is not None else _qint(q, "visibility_timeout")
        )
        self._fifo = fifo or _qbool(q, "fifo")
        self._message_group_id = message_group_id or _q1(q, "group_id")
        self._content_dedup = content_dedup or _qbool(q, "content_dedup")
        # queue name -> queue URL, filled lazily by _resolve_url.
        self._urls: Dict[str, str] = {}

        if client is not None:
            self._sqs = client
            return
        try:
            import boto3
        except ImportError as exc:  # pragma: no cover - import guard
            raise ImportError(
                "SqsTransport requires the 'boto3' package. Install with "
                'pip install "babelqueue[sqs]".'
            ) from exc
        kwargs: Dict[str, Any] = {}
        if self._region:
            kwargs["region_name"] = self._region
        if self._endpoint:
            kwargs["endpoint_url"] = self._endpoint
        self._sqs = boto3.client("sqs", **kwargs)  # pragma: no cover - needs AWS/LocalStack

    # -- helpers ------------------------------------------------------------

    def _resolve_url(self, name: str) -> str:
        """Return the queue URL for ``name``, caching the result.

        With a configured prefix the URL is ``<prefix>/<name>`` (trailing
        slashes on the prefix are dropped); otherwise the client's
        ``get_queue_url`` is called once per queue name.
        """
        cached = self._urls.get(name)
        if cached is not None:
            return cached
        if self._queue_url_prefix:
            url = self._queue_url_prefix.rstrip("/") + "/" + name
        else:
            url = self._sqs.get_queue_url(QueueName=name)["QueueUrl"]
        self._urls[name] = url
        return url

    @staticmethod
    def _attributes(body: str) -> Dict[str, Dict[str, str]]:
        """Project the envelope's contract fields onto SQS MessageAttributes.

        The result is a redundant, routable view of the body (the body stays
        authoritative). Fields that are absent or empty are omitted, and a
        body that fails to decode yields an empty mapping.
        """
        try:
            env: Dict[str, Any] = EnvelopeCodec.decode(body)
        except (ValueError, TypeError):  # pragma: no cover - decode is defensive
            return {}
        meta = env.get("meta") or {}

        def s(v: Any) -> Dict[str, str]:
            return {"DataType": "String", "StringValue": str(v)}

        def n(v: Any) -> Dict[str, str]:
            return {"DataType": "Number", "StringValue": str(v)}

        attrs: Dict[str, Dict[str, str]] = {}
        if env.get("job"):
            attrs["bq-job"] = s(env["job"])
        if env.get("trace_id"):
            attrs["bq-trace-id"] = s(env["trace_id"])
        if meta.get("id"):
            attrs["bq-message-id"] = s(meta["id"])
        # ``is not None`` so a legitimate 0 is still projected.
        if meta.get("schema_version") is not None:
            attrs["bq-schema-version"] = n(meta["schema_version"])
        if meta.get("lang"):
            attrs["bq-source-lang"] = s(meta["lang"])
        if meta.get("created_at") is not None:
            attrs["bq-created-at"] = n(meta["created_at"])
        return attrs

    @staticmethod
    def _reconcile(body: str, receive_count: Any) -> str:
        """Raise the envelope's ``attempts`` to ``ApproximateReceiveCount - 1``.

        A first delivery (count <= 1) is returned untouched, a natively
        redelivered message reflects the broker's count, and a counter that
        is already at least as high is never lowered.

        The body is returned unchanged, never raising, when the receive count
        is not an integer or the body cannot be decoded, so the receive path
        behaves the same on every delivery of a malformed message. An
        unreadable ``attempts`` value is treated as 0.
        """
        try:
            rc = int(receive_count)
        except (ValueError, TypeError):
            return body
        if rc <= 1:
            return body
        try:
            env = EnvelopeCodec.decode(body)
        except (ValueError, TypeError):
            # Same guard as _attributes: hand an undecodable body to the
            # caller as-is instead of failing the receive.
            return body
        if not env:
            return body
        native = rc - 1
        try:
            current = int(env.get("attempts", 0))
        except (ValueError, TypeError):
            # A null or garbled counter cannot outrank the broker's count.
            current = 0
        if native <= current:
            return body
        env["attempts"] = native
        return EnvelopeCodec.encode(env)

    # -- Transport ----------------------------------------------------------

    def publish(self, queue: str, body: str) -> None:
        """Send ``body`` to ``queue`` with its ``bq-*`` attribute projection.

        In FIFO mode ``MessageGroupId`` is the configured group id or, failing
        that, the queue name. Unless ``content_dedup`` is enabled, the
        envelope's ``meta.id`` is sent as ``MessageDeduplicationId`` when
        present.
        """
        params: Dict[str, Any] = {"QueueUrl": self._resolve_url(queue), "MessageBody": body}
        attrs = self._attributes(body)
        if attrs:
            params["MessageAttributes"] = attrs
        if self._fifo:
            params["MessageGroupId"] = self._message_group_id or queue
            if not self._content_dedup:
                # ``bq-message-id`` is exactly ``str(meta.id)``; reusing it
                # avoids decoding the body a second time and guarantees the
                # string type the SQS API expects for this parameter.
                dedup_id = attrs.get("bq-message-id", {}).get("StringValue")
                if dedup_id:
                    params["MessageDeduplicationId"] = dedup_id
        self._sqs.send_message(**params)

    def pop(self, queue: str, timeout: float = 1.0) -> Optional[ReceivedMessage]:
        """Receive at most one message from ``queue``.

        ``timeout`` is truncated to whole seconds, capped at 20 and at the
        configured ``wait_time``, and sent as ``WaitTimeSeconds``.

        Returns:
            A ``ReceivedMessage`` carrying the (possibly attempt-reconciled)
            body and the receipt handle, or ``None`` when nothing arrived.
        """
        wait = int(timeout) if timeout and timeout > 0 else 0
        # 20 seconds is the largest WaitTimeSeconds SQS accepts.
        if wait > 20:
            wait = 20
        if self._wait_time is not None and self._wait_time < wait:
            wait = self._wait_time
        params: Dict[str, Any] = {
            "QueueUrl": self._resolve_url(queue),
            "MaxNumberOfMessages": 1,
            "WaitTimeSeconds": wait,
            "MessageAttributeNames": ["All"],
            "AttributeNames": ["ApproximateReceiveCount"],
        }
        if self._visibility_timeout is not None:
            params["VisibilityTimeout"] = self._visibility_timeout
        resp = self._sqs.receive_message(**params)
        messages = resp.get("Messages") or []
        if not messages:
            return None
        msg = messages[0]
        body = msg.get("Body", "")
        receive_count = (msg.get("Attributes") or {}).get("ApproximateReceiveCount")
        if receive_count is not None:
            body = self._reconcile(body, receive_count)
        return ReceivedMessage(body=body, queue=queue, handle=msg.get("ReceiptHandle"))

    def ack(self, message: ReceivedMessage) -> None:
        """Delete ``message`` from its queue; a message without a handle is a no-op."""
        if not message.handle:
            return
        self._sqs.delete_message(
            QueueUrl=self._resolve_url(message.queue), ReceiptHandle=message.handle
        )
| 190 | + |
| 191 | + |
| 192 | +def _q1(q: Dict[str, list], key: str) -> Optional[str]: |
| 193 | + values = q.get(key) |
| 194 | + return values[0] if values else None |
| 195 | + |
| 196 | + |
def _qint(q: Dict[str, list], key: str) -> Optional[int]:
    """Return the first value for ``key`` parsed as an ``int``.

    Yields ``None`` when the key is absent or its value is not a valid
    integer literal.
    """
    raw = _q1(q, key)
    try:
        return None if raw is None else int(raw)
    except ValueError:  # pragma: no cover - defensive
        return None
| 205 | + |
| 206 | + |
def _qbool(q: Dict[str, list], key: str) -> bool:
    """Return whether the first value for ``key`` is a truthy flag.

    The accepted spellings, compared case-insensitively, are ``1``, ``true``,
    ``yes`` and ``on``. Anything else, or a missing key, is ``False``.
    """
    raw = _q1(q, key)
    if raw is None:
        return False
    return raw.lower() in ("1", "true", "yes", "on")
0 commit comments