logger = logging.getLogger(__name__)
RETRY_AFTER_SEC = 15 * 60  # 15 minutes


@dataclass
class _PoolItem:  # type: ignore[no-any-unimported]
    """Bookkeeping record for one pooled Twitter client."""

    client: "tweepy.Client"  # type: ignore[no-any-unimported]
    client_key: str  # the client's consumer_key, used as its unique identifier
    dead_at: Optional[float] = None  # time.time() when the breaker tripped; None means alive


class TwitterClientPool:
    """
    Round-robin pool of Twitter clients with a simple circuit breaker.

    * ``acquire`` hands out alive clients in round-robin order.
    * ``report_failure`` trips a client; it is revived ``retry_after`` seconds later.
    * ``remove`` drops a client permanently.

    All state changes happen while holding ``self._lock``.
    """

    def __init__(self, clients: "list[tweepy.Client]", retry_after: float = RETRY_AFTER_SEC) -> None:  # type: ignore[no-any-unimported]
        """
        Args:
            clients: clients to pool; each must expose ``consumer_key``.
            retry_after: seconds a tripped client stays out of rotation.
        """
        self._retry_after = retry_after
        self._pool: list[_PoolItem] = [_PoolItem(c, c.consumer_key) for c in clients]
        self._lock = asyncio.Lock()
        # Wake-up signal for coroutines parked in acquire().
        self._not_empty = asyncio.Event()
        # Round-robin cursor: index at which the next scan starts.
        self._rr_idx = 0
        if any(item.dead_at is None for item in self._pool):
            self._not_empty.set()

    async def acquire(self) -> "tuple[Client, str]":  # type: ignore[no-any-unimported]
        """
        Return the next alive client and its key, in round-robin order.

        When every client is tripped, wait until the earliest one is due to be
        revived (or until the pool state changes) and try again.

        Raises:
            RuntimeError: the pool is empty because every client was removed.
        """
        while True:
            async with self._lock:
                if not self._pool:
                    raise RuntimeError("TwitterClientPool: 所有客户端已被永久摘除,请重建池子")

                # 1. Revive every client whose cool-down has elapsed.
                now = time.time()
                revived = False
                for item in self._pool:
                    if item.dead_at is not None and now - item.dead_at >= self._retry_after:
                        item.dead_at = None
                        revived = True
                        logger.info("client %s revived", item.client_key)
                if revived:
                    # Let other parked coroutines re-check the pool as well.
                    self._not_empty.set()

                # 2. Scan the whole pool once, starting at the cursor.
                size = len(self._pool)
                for offset in range(size):
                    idx = (self._rr_idx + offset) % size
                    candidate = self._pool[idx]
                    if candidate.dead_at is None:
                        self._rr_idx = (idx + 1) % size
                        return candidate.client, candidate.client_key

                # 3. Everyone is tripped. Revival only happens inside acquire(),
                #    so the wait below must be bounded by the earliest revival
                #    time; an unbounded wait would never be woken up.
                self._not_empty.clear()
                earliest = min(item.dead_at for item in self._pool if item.dead_at is not None)
                timeout = max(earliest + self._retry_after - now, 0.0)

            # 4. Wait outside the lock so other coroutines are not blocked.
            try:
                await asyncio.wait_for(self._not_empty.wait(), timeout=timeout)
            except asyncio.TimeoutError:
                pass  # cool-down elapsed; loop around and revive

    # -------------------- permanent removal --------------------
    async def remove(self, client: "tweepy.Client") -> None:  # type: ignore[no-any-unimported]
        """Permanently remove ``client`` from the pool (matched by identity)."""
        async with self._lock:
            kept: list[_PoolItem] = []
            for item in self._pool:
                if item.client is client:
                    logger.info("client %s removed permanently", item.client_key)
                else:
                    kept.append(item)
            self._pool = kept
            if not self._pool:
                # Wake parked acquirers so they raise instead of waiting forever.
                self._not_empty.set()
            elif not any(item.dead_at is None for item in self._pool):
                self._not_empty.clear()

    # -------------------- failure reporting --------------------
    async def report_failure(self, client: "tweepy.Client") -> None:  # type: ignore[no-any-unimported]
        """
        Trip the breaker for ``client`` after a failed operation.

        The client stays in the pool and is revived by ``acquire`` once
        ``retry_after`` seconds have passed. Unknown clients are ignored.
        """
        async with self._lock:
            for item in self._pool:
                if item.client is not client:
                    continue
                # Only stamp the first failure so the cool-down is not extended.
                if item.dead_at is None:
                    item.dead_at = time.time()
                    logger.warning(
                        "client %s dead, will retry after %s min", item.client_key, self._retry_after // 60
                    )
                if not any(other.dead_at is None for other in self._pool):
                    self._not_empty.clear()
                return
# ---------- Request field sets ----------
TWEET_FIELDS = [
    "id",
    "created_at",
    "author_id",
    "text",
    "public_metrics",
    "referenced_tweets",
    "conversation_id",
    "entities",
    "display_text_range",
    "attachments",
    "withheld",
    "note_tweet",
    "edit_controls",
    "edit_history_tweet_ids",
    "possibly_sensitive",
    "reply_settings",
    "source",
    "lang",
    "geo",
    "context_annotations",
    "card_uri",
    "community_id",
    "in_reply_to_user_id",
    "media_metadata",
]
EXPANSIONS = [
    "author_id",
    "referenced_tweets.id",
    "referenced_tweets.id.author_id",
    "attachments.media_keys",
    "attachments.poll_ids",
    "geo.place_id",
]
USER_FIELDS = [
    "id",
    "username",
    "name",
    "public_metrics",
    "created_at",
    "description",
    "entities",
    "location",
    "pinned_tweet_id",
    "profile_image_url",
    "protected",
    "verified",
    "verified_type",
    "is_identity_verified",
    "affiliation",
    "connection_status",
    "most_recent_tweet_id",
    "parody",
    "receives_your_dm",
    "subscription",
    "subscription_type",
    "profile_banner_url",
    "withheld",
]
MEDIA_FIELDS = [
    "alt_text",
    "duration_ms",
    "height",
    "media_key",
    "preview_image_url",
    "public_metrics",
    "type",
    "url",
    "variants",
    "width",
]
POLL_FIELDS = ["duration_minutes", "end_datetime", "id", "options", "voting_status"]
PLACE_FIELDS = ["contained_within", "country", "country_code", "full_name", "geo", "id", "name", "place_type"]
MAX_RESULTS = 100
PROCESS_KEY_PREFIX = "P:"
FREQ_KEY_PREFIX = "F:"
HOME_TIMELINE_ID = "last_home_timeline"
MENTIONS_TIMELINE_ID = "last_mentions_timeline"
MONTHLY_CAP_INFO = "Monthly product cap"


# ---------- Main class ----------
class TweetGetContext:
    """
    Incremental fetcher for the home timeline and mentions, with reply-chain context.

    Every network call goes through a ``TwitterClientPool``; an optional cache
    (any object with ``get(key)`` / ``set(key, value)``) stores the newest seen
    tweet id, processed markers and per-conversation reply counters.
    """

    def __init__(  # type: ignore[no-untyped-def]
        self,
        pool: "TwitterClientPool",  # externally owned client pool
        cache=None,  # optional cache with get/set
        max_results: int = MAX_RESULTS,
        block_user_ids: Optional[list[str]] = None,
        white_user_ids: Optional[list[str]] = None,
        reply_freq_limit: int = 5,
        max_depth: int = 5,
    ) -> None:
        """
        Args:
            pool: client pool used for every API request.
            cache: optional key/value cache; ``None`` disables all caching.
            max_results: page size passed to the timeline endpoints.
            block_user_ids: authors whose tweets are always dropped.
            white_user_ids: authors exempt from the reply-frequency limit.
            reply_freq_limit: per-conversation limit on kept tweets.
            max_depth: recursion depth limit when walking up a reply chain.
        """
        self.pool = pool
        self.cache = cache
        self.max_depth = max_depth
        self.max_results = max_results
        self.block_uids = set(block_user_ids or [])
        self.white_uids = set(white_user_ids or [])
        self.freq_limit = reply_freq_limit
        # Used for the ``mentions_me`` flag; set by the public fetch methods
        # and may also be injected by the caller.
        self.me_id: Optional[str] = None

    # ===================== Public API =====================
    async def get_home_timeline_with_context(
        self,
        me_id: str,
        agent_id: str,
        hours: int = 24,
        since_id: Optional[str] = None,
        filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
    ) -> list[Dict[str, Any]]:
        """Fetch new home-timeline tweets (with reply context) for ``agent_id``."""
        # Keep ``mentions_me`` meaningful on this path too, as on the mentions path.
        self.me_id = me_id
        return await self._fetch_timeline(
            endpoint="home",
            me_id=me_id,
            hours=hours,
            since_id=since_id,
            agent_id=agent_id,
            filter_func=filter_func or (lambda _: True),
        )

    async def get_mentions_with_context(
        self,
        me_id: str,
        agent_id: str,
        hours: int = 24,
        since_id: Optional[str] = None,
        filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
    ) -> list[Dict[str, Any]]:
        """Fetch new mentions of ``me_id`` (with reply context) for ``agent_id``."""
        self.me_id = me_id
        return await self._fetch_timeline(
            endpoint="mentions",
            me_id=me_id,
            hours=hours,
            agent_id=agent_id,
            since_id=since_id,
            filter_func=filter_func or (lambda _: True),
        )

    # ===================== Shared fetch loop =====================
    async def _fetch_timeline(
        self,
        endpoint: str,
        me_id: str,
        hours: int,
        since_id: Optional[str],
        filter_func: Callable[[Dict[str, Any]], bool],
        agent_id: str,
    ) -> list[Dict[str, Any]]:
        """
        Page through the ``home`` or mentions endpoint and return kept tweets.

        Tweets are returned sorted by ascending id. When a cache is configured,
        the newest returned id is stored and reused as ``since_id`` next time.
        Any error stops pagination; the tweets collected so far are returned.
        """
        since = datetime.now(timezone.utc) - timedelta(hours=hours)
        start_time = since.isoformat(timespec="seconds")
        next_token = None
        all_raw: list[Dict[str, Any]] = []
        cache_key = f"{agent_id}:{MENTIONS_TIMELINE_ID}"
        if endpoint == "home":
            cache_key = f"{agent_id}:{HOME_TIMELINE_ID}"
        if not since_id and self.cache:
            since_id = self.cache.get(cache_key)

        while True:
            cli, client_key = await self.pool.acquire()
            try:
                # NOTE(review): these client calls are made without ``await``;
                # if the client is synchronous they block the event loop.
                if endpoint == "home":
                    resp = cli.get_home_timeline(
                        tweet_fields=TWEET_FIELDS,
                        expansions=EXPANSIONS,
                        media_fields=MEDIA_FIELDS,
                        poll_fields=POLL_FIELDS,
                        user_fields=USER_FIELDS,
                        place_fields=PLACE_FIELDS,
                        exclude=["replies", "retweets"],
                        start_time=start_time,
                        since_id=since_id,
                        max_results=self.max_results,
                        pagination_token=next_token,
                        user_auth=True,
                    )
                else:  # mentions
                    resp = cli.get_users_mentions(
                        id=me_id,
                        tweet_fields=TWEET_FIELDS,
                        expansions=EXPANSIONS,
                        media_fields=MEDIA_FIELDS,
                        poll_fields=POLL_FIELDS,
                        user_fields=USER_FIELDS,
                        place_fields=PLACE_FIELDS,
                        start_time=start_time,
                        since_id=since_id,
                        max_results=self.max_results,
                        pagination_token=next_token,
                        user_auth=True,
                    )

                # Success metric: number of tweets read with this client.
                read_tweet_success_count.labels(client_key=client_key).inc(len(resp.data or []))

                # Hand the page to the processing hook.
                tweet_list, next_token = await self.on_twitter_response(agent_id, resp, filter_func)
                all_raw.extend(tweet_list)
                if not next_token:
                    break
            except (NotFound, TwitterServerError):
                break
            except Exception as e:
                logger.warning("timeline %s error: %s", endpoint, e)
                # Failure metric.
                read_tweet_failure_count.labels(client_key=client_key).inc()
                # Monthly-cap detection: such a client is removed for good.
                if MONTHLY_CAP_INFO in str(e):
                    tweet_monthly_cap.labels(client_key=client_key).set(0)
                    await self.pool.remove(cli)
                    logger.error("client %s removed due to monthly cap", client_key)
                    break
                else:
                    tweet_monthly_cap.labels(client_key=client_key).set(1)
                    await self.pool.report_failure(cli)
                    break

        # Sort numerically: ids may be strings, and "99" > "100" as text.
        all_raw.sort(key=lambda t: int(t["id"]))
        if all_raw and self.cache:
            newest_id = all_raw[-1]["id"]
            self.cache.set(cache_key, str(newest_id))
        return all_raw

    # ===================== Response-processing hook =====================
    async def on_twitter_response(  # type: ignore[no-any-unimported]
        self,
        agent_id: str,
        response: "TwitterResponse",
        filter_func: Callable[[Dict[str, Any]], bool],
    ) -> tuple[list[Dict[str, Any]], Optional[str]]:
        """Turn one API page into ``(kept normalized tweets, next_token)``."""
        next_token = response.meta.get("next_token")
        if response.meta.get("result_count", 0) == 0 or response.data is None:
            return [], next_token

        users = self._build_users(response.includes)
        medias = self._build_medias(response.includes)
        all_tweets = self._get_all_tweets(response, users, medias)
        out: list[Dict[str, Any]] = []

        for tweet in all_tweets:
            if not await self._should_keep(agent_id, tweet, filter_func):
                continue
            norm = await self._normalize_tweet(tweet)
            out.append(norm)
        return out, next_token

    async def _should_keep(
        self, agent_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool]
    ) -> bool:
        """
        Decide whether ``tweet`` passes the processed, block-list, frequency and
        caller filters. The conversation counter is incremented before
        ``filter_func`` runs.
        """
        is_processed = await self._check_tweet_process(tweet["id"], agent_id)
        if is_processed:
            logger.info("already processed %s", tweet["id"])
            return False
        author_id = str(tweet["author_id"])
        if author_id in self.block_uids:
            logger.info("blocked user %s", author_id)
            return False
        freq = await self._get_freq(agent_id, tweet)
        if freq >= self.freq_limit and author_id not in self.white_uids:
            logger.info("skip tweet %s freq %s", tweet["id"], freq)
            return False
        await self._increase_freq(agent_id, tweet)
        return filter_func(tweet)

    async def _check_tweet_process(self, tweet_id: str, agent_id: str) -> bool:
        """Return True when the cache holds a processed marker for the tweet."""
        if self.cache is None:
            return False
        try:
            return self.cache.get(f"{agent_id}:{PROCESS_KEY_PREFIX}{tweet_id}") is not None
        except Exception:
            # Treat the tweet as processed when the cache is unavailable.
            return True

    async def _mark_tweet_process(self, tweet_id: str, agent_id: str) -> None:
        """Store a processed marker for the tweet (best effort)."""
        if self.cache is None:
            return
        try:
            self.cache.set(f"{agent_id}:{PROCESS_KEY_PREFIX}{tweet_id}", "")
        except Exception as e:
            # Best effort: a cache outage must not break the caller.
            logger.warning("mark processed %s failed: %s", tweet_id, e)

    async def _get_freq(self, agent_id: str, tweet: Dict[str, Any]) -> int:
        """Return the conversation counter; -1 without a cache, 0 on cache errors."""
        if self.cache is None:
            return -1
        try:
            freq = self.cache.get(f"{agent_id}:{FREQ_KEY_PREFIX}{tweet['conversation_id']}")
            return int(freq) if freq else 0
        except Exception:
            return 0

    async def _increase_freq(self, agent_id: str, tweet: Dict[str, Any]) -> None:
        """Increment the conversation counter in the cache (best effort)."""
        if self.cache is None:
            return
        freq = await self._get_freq(agent_id, tweet)
        try:
            self.cache.set(f"{agent_id}:{FREQ_KEY_PREFIX}{tweet['conversation_id']}", str(freq + 1))
        except Exception as e:
            # Best effort: a cache outage must not break the caller.
            logger.warning("increase freq for tweet %s failed: %s", tweet.get("id"), e)

    async def _normalize_tweet(self, tweet: Dict[str, Any]) -> Dict[str, Any]:
        """Project the tweet onto the output fields and attach ``history``."""
        wanted = [
            "id",
            "created_at",
            "author_id",
            "author",
            "text",
            "public_metrics",
            "conversation_id",
            "entities",
        ]
        out = {k: tweet[k] for k in wanted if k in tweet}
        out["history"] = await self._build_context(tweet)
        out["sampling_quote"] = not tweet.get("referenced_tweets")
        return out

    async def _build_context(self, tweet: Dict[str, Any]) -> str:
        """Render the reply chain ending at ``tweet`` as text, oldest first."""
        chain: list[Dict[str, Any]] = []
        await self._recursive_fetch(tweet, chain, depth=0)
        lines = [""]
        for t in chain:
            lines.append(f"{t.get('text', '')}")
            lines.append("")
        return "\n".join(lines)

    async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, Any]], depth: int) -> None:
        """
        Append the ancestors of ``tweet`` and then ``tweet`` itself to ``chain``.

        Recursion stops once ``depth`` exceeds ``self.max_depth``. The parent is
        the first ``replied_to`` entry anywhere in ``referenced_tweets``, so a
        reply that also quotes another tweet still resolves its parent.
        """
        if depth > self.max_depth:
            chain.append(tweet)
            return
        parent_id = next(
            (
                ref.get("id")
                for ref in (tweet.get("referenced_tweets") or [])
                if ref.get("type") == "replied_to"
            ),
            None,
        )
        if parent_id:
            parent = await self._get_tweet_with_retry(parent_id)
            if parent:
                await self._recursive_fetch(parent, chain, depth + 1)
        chain.append(tweet)

    async def _get_tweet_with_retry(self, tweet_id: str) -> Optional[Dict[str, Any]]:
        """
        Fetch and format a single tweet, trying up to three times.

        Returns None when the tweet has no data, on NotFound / server errors,
        or after the third failed attempt. Every other failure is reported to
        the pool and followed by an exponential back-off (1 s, then 2 s).
        """
        for attempt in range(3):
            cli, client_key = await self.pool.acquire()
            try:
                resp = cli.get_tweet(
                    tweet_id, tweet_fields=TWEET_FIELDS, expansions=EXPANSIONS, user_fields=USER_FIELDS, user_auth=True
                )
                if not resp.data:
                    return None
                tw: Dict[str, Any] = resp.data.data
                users = self._build_users(resp.includes)
                self._format_tweet_data(tw, users, self._build_medias(resp.includes))
                return tw
            except (NotFound, TwitterServerError):
                return None
            except Exception as e:
                logger.warning("get_tweet retry %s: %s", attempt + 1, e)
                await self.pool.report_failure(cli)
                if attempt == 2:
                    return None
                await asyncio.sleep(2**attempt)
        return None

    # ===================== Helpers with unchanged signatures =====================
    def _format_tweet_data(self, tweet: Dict[str, Any], users: "Dict[str, User]", medias: "Dict[str, Media]") -> None:  # type: ignore[no-any-unimported]
        """
        Normalize ``tweet`` in place.

        Adds ``author``, ``is_robot`` and ``mentions_me``, rewrites ``text`` as
        ``"<author>:\\n<display text>\\n\\n"`` and sets ``image_url`` when the
        first attached media item is a photo.
        """
        author_id = tweet["author_id"]
        user = users[author_id] if author_id in users else None
        # Fall back to the raw author id when the user record is missing.
        author = str(user.username) if user and "username" in user else author_id
        tweet["author"] = author
        tweet["is_robot"] = (
            "Automated" in user["affiliation"]["description"]
            if user and "affiliation" in user and "description" in user["affiliation"]
            else False
        )
        tweet["mentions_me"] = (
            "entities" in tweet
            and "mentions" in tweet["entities"]
            and self.me_id in [str(i["id"]) for i in tweet["entities"]["mentions"]]
        )
        text = tweet["text"]
        if "display_text_range" in tweet:
            display_text_range: list[int] = tweet["display_text_range"]
            text = text[display_text_range[0] : display_text_range[1]]
        tweet["text"] = f"{author}:\n{text}\n\n"

        if (
            "attachments" in tweet
            and "media_keys" in tweet["attachments"]
            and len(tweet["attachments"]["media_keys"]) > 0
        ):
            key = tweet["attachments"]["media_keys"][0]
            if key in medias and medias[key].type == "photo":
                tweet["image_url"] = medias[key].url

    def _build_users(self, includes: Dict[str, Any]) -> "Dict[str, User]":  # type: ignore[no-any-unimported]
        """Index ``includes["users"]`` by stringified user id."""
        users: Dict[str, User] = {}  # type: ignore[no-any-unimported]
        if "users" in includes:
            for user in includes["users"]:
                users[str(user.id)] = user
        return users

    def _build_medias(self, includes: Dict[str, Any]) -> "Dict[str, Media]":  # type: ignore[no-any-unimported]
        """Index ``includes["media"]`` by stringified media key."""
        medias: Dict[str, Media] = {}  # type: ignore[no-any-unimported]
        if "media" in includes:
            for media in includes["media"]:
                medias[str(media.media_key)] = media
        return medias

    def _get_all_tweets(  # type: ignore[no-any-unimported]
        self, response: "TwitterResponse", users: "Dict[str, User]", medias: "Dict[str, Media]"
    ) -> list[Dict[str, Any]]:
        """Return the raw dict of every tweet in ``response.data``, formatted in place."""
        all_tweets: list[Dict[str, Any]] = []
        for tweet in response.data:
            t = tweet.data
            self._format_tweet_data(t, users, medias)
            all_tweets.append(t)
        return all_tweets