-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconsumer.py
More file actions
222 lines (193 loc) · 8.17 KB
/
consumer.py
File metadata and controls
222 lines (193 loc) · 8.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""Lambda handler for SQS Consumer.
Processes messages from SQS queue, calls Agent Server, sends response to Telegram.
"""
import asyncio
import json
from typing import Any
import httpx
from telegram import Bot, Update
from telegram.constants import ParseMode, ChatAction
from telegram.helpers import escape_markdown
from telegram.error import BadRequest
from config import Config
def _get_reply_to_id(message_id: int, thread_id: int | None, message_thread_id: int | None) -> int | None:
"""Determine if we should reply to the original message.
Only reply to the original message if we're in the same thread.
This prevents Telegram API errors when sending to a different thread (e.g., /newchat).
Args:
message_id: The original message ID
thread_id: The target thread ID (may be overridden by handler)
message_thread_id: The original message's thread ID
Returns:
message_id if in same thread, None otherwise
"""
return message_id if thread_id == message_thread_id else None
def lambda_handler(event: dict, context: Any) -> dict:
    """SQS Consumer Lambda entry point.

    Decodes each record body in ``event['Records']`` as JSON and runs
    ``process_message`` on it in a fresh asyncio event loop.

    - A body that is not valid JSON is logged and skipped; retrying it
      cannot succeed.
    - Any exception raised while processing is logged and re-raised, so the
      invocation fails and SQS can retry.

    NOTE(review): re-raising fails the whole invocation, so records that
    were already processed earlier in the same batch may be redelivered too
    (and their Telegram replies sent again) unless the event source mapping
    uses a batch size of 1 — confirm the deployment configuration.

    Args:
        event: Lambda SQS event; each entry of ``'Records'`` carries a
            JSON string in ``'body'``.
        context: Lambda context object (unused).

    Returns:
        ``{'statusCode': 200}`` once every record has been handled.
    """
    import logging
    log = logging.getLogger()

    for rec in event['Records']:
        try:
            payload = json.loads(rec['body'])
        except json.JSONDecodeError as err:
            # Invalid message format - log and skip
            log.error(f"Failed to parse SQS message: {err}")
            continue

        try:
            asyncio.run(process_message(payload))
        except Exception as err:
            # Log, then propagate so the invocation fails and SQS retries.
            log.exception(f"Failed to process message: {err}")
            raise

    return {'statusCode': 200}
async def process_message(message_data: dict) -> None:
    """Process a single message taken from the SQS queue.

    Rebuilds the Telegram ``Update`` stored in the SQS payload, then either:

    - answers a local or unknown command directly, or
    - forwards the user's text to the Agent Server and posts its answer back
      into the originating chat/thread.

    Args:
        message_data: Decoded SQS body. Must contain ``'telegram_update'``
            (the serialized Telegram update). Optional ``'text'`` and
            ``'thread_id'`` keys override the message text and the target
            thread (e.g. when a handler redirects output to a new thread).

    Raises:
        httpx.TimeoutException: The Agent Server timed out. The user is told,
            then the error is re-raised so SQS retries the message.
        telegram.error.BadRequest: Telegram rejected the final reply for a
            reason other than a Markdown parse error.
        Exception: Other errors from rebuilding the update, the typing
            indicator, or the final reply are propagated unchanged.
    """
    import logging
    logger = logging.getLogger()
    # Enable INFO logging as suggested in issue for better debugging
    logger.setLevel(logging.INFO)

    config = Config.from_env()
    bot = Bot(config.telegram_token)

    # Reconstruct Update object from stored data
    update = Update.de_json(message_data['telegram_update'], bot)
    message = update.message or update.edited_message
    if not message:
        logger.warning("Received update with no message or edited_message")
        return

    # Extract thread_id and user_message early - needed for all message processing
    # (allows handler to override text/thread_id via SQS message_data)
    user_message = message_data.get('text') or message.text
    thread_id = message_data.get('thread_id') or message.message_thread_id
    # Computed once: every reply below targets the same chat/thread.
    reply_to_id = _get_reply_to_id(message.message_id, thread_id, message.message_thread_id)

    # Command detection uses the original message text, not the override.
    cmd = config.get_command(message.text)
    if cmd:
        if config.is_local_command(cmd):
            logger.info(
                "Handling local command in consumer (fallback path)",
                extra={'chat_id': message.chat_id, 'message_id': message.message_id},
            )
            try:
                await bot.send_message(
                    chat_id=message.chat_id,
                    text=config.local_response(cmd),
                    message_thread_id=thread_id,
                    reply_to_message_id=reply_to_id,
                )
            except Exception:
                # Best effort: log instead of failing (and retrying) the message.
                logger.warning("Failed to send local command response", exc_info=True)
            return
        if not config.is_agent_command(cmd):
            # Defensive guard: producer should already block non-agent commands.
            logger.info(
                "Skipping non-agent command (consumer fallback)",
                extra={
                    'chat_id': message.chat_id,
                    'message_id': message.message_id,
                },
            )
            try:
                await bot.send_message(
                    chat_id=message.chat_id,
                    text=config.unknown_command_message(),
                    message_thread_id=thread_id,
                    reply_to_message_id=reply_to_id,
                )
            except Exception:
                logger.warning("Failed to send unknown command response", exc_info=True)
            return

    # Send typing indicator
    await bot.send_chat_action(
        chat_id=message.chat_id,
        action=ChatAction.TYPING,
        message_thread_id=thread_id,
    )

    # Call Agent Server
    try:
        async with httpx.AsyncClient(timeout=600.0) as client:
            response = await client.post(
                config.agent_server_url,
                headers={
                    'Authorization': f'Bearer {config.auth_token}',
                    'Content-Type': 'application/json',
                },
                json={
                    'user_message': user_message,
                    'chat_id': str(message.chat_id),
                    'thread_id': str(thread_id) if thread_id else None,
                },
            )
            response.raise_for_status()
            result = response.json()
        # Validate here so a malformed body goes through the error path below
        # instead of crashing on result.get() further down.
        if not isinstance(result, dict):
            raise ValueError(
                f"Unexpected Agent Server response type: {type(result).__name__}"
            )
    except httpx.TimeoutException:
        logger.warning(f"Agent Server timeout for chat_id={message.chat_id}")
        await bot.send_message(
            chat_id=message.chat_id,
            text="Request timed out.",
            message_thread_id=thread_id,
            reply_to_message_id=reply_to_id,
        )
        raise  # Re-raise to trigger SQS retry for transient errors
    except Exception as e:
        logger.exception(f"Agent Server error for chat_id={message.chat_id}")
        # NOTE(review): str(e) may include the Agent Server URL (e.g. for HTTP
        # status errors) and is shown to the user — consider a generic message.
        error_text = f"Error: {str(e)[:200]}"
        try:
            await bot.send_message(
                chat_id=message.chat_id,
                text=error_text,
                message_thread_id=thread_id,
                reply_to_message_id=reply_to_id,
            )
        except Exception as send_error:
            logger.error(f"Failed to send error message to Telegram: {send_error}")
        # Stop here. Falling through would post a second, generic error
        # message. Not re-raised: an SQS retry would duplicate messages too.
        return

    # Format response
    if result.get('is_error'):
        text = f"Agent error: {result.get('error_message', 'Unknown')}"
    else:
        text = result.get('response') or 'No response'
    # Presumably headroom under Telegram's 4096-character message limit.
    if len(text) > 4000:
        text = text[:4000] + "\n\n... (truncated)"

    # Send response to Telegram
    try:
        await bot.send_message(
            chat_id=message.chat_id,
            text=text,
            parse_mode=ParseMode.MARKDOWN_V2,
            message_thread_id=thread_id,
            reply_to_message_id=reply_to_id,
        )
        logger.info(
            f"Message sent successfully to chat_id={message.chat_id}, "
            f"thread_id={thread_id}, reply_to={reply_to_id}"
        )
    except BadRequest as e:
        if "parse entities" in str(e).lower():
            # The agent's text is not valid MarkdownV2: resend it with every
            # special character escaped so it renders literally.
            logger.warning(f"Markdown parse error, retrying with escaped text: {e}")
            safe_text = escape_markdown(text, version=2)
            await bot.send_message(
                chat_id=message.chat_id,
                text=safe_text,
                parse_mode=ParseMode.MARKDOWN_V2,
                message_thread_id=thread_id,
                reply_to_message_id=reply_to_id,
            )
            logger.info(
                f"Message sent successfully (escaped) to chat_id={message.chat_id}, "
                f"thread_id={thread_id}, reply_to={reply_to_id}"
            )
        else:
            logger.error(f"Failed to send message: {e}")
            raise