66from rss_glue .feeds import ai_client , feed
77from rss_glue .utils import from_subpost
88
9- base_prompt = """
9+ _base_prompt = """
1010Please take a look at the following bit of a content from an RSS feed:
1111
1212Title: {title}
@@ -43,18 +43,17 @@ class AiFilterPost(feed.ReferenceFeedItem):
4343 include_post : bool
4444
4545
46- class AiFilterFeed (feed .BaseFeed ):
46+ class AiFilterFeed (feed .AugmentFeed ):
4747 """
4848 AiFilterFeed is a feed that filters out posts based
4949 on a given prompt.
5050 """
5151
52- source : feed .BaseFeed
53- limit : int
5452 client : ai_client .AiClient
5553 name : str = "smart_filter"
5654 post_cls : type [AiFilterPost ] = AiFilterPost
5755
56+
5857 def __init__ (
5958 self ,
6059 source : feed .BaseFeed ,
@@ -64,29 +63,29 @@ def __init__(
6463 limit : int = - 1 ,
6564 title : Optional [str ] = None ,
6665 ):
67- self .source = source
68- self .limit = limit
6966 self .prompt = prompt
7067 self .client = client
71-
72- self .author = source .author
73- self .origin_url = source .origin_url
7468 self .content_limit = content_limit
69+
70+ super ().__init__ (source = source , limit = limit )
71+
7572 if title :
7673 self .title = title
7774 else :
7875 self .title = f"Filter { source .title } "
79- super ().__init__ ()
8076
81- @property
82- def namespace (self ):
83- return f"{ self .name } _{ self .source .namespace } "
77+ def posts (
78+ self , limit : int = 50 , start : Optional [datetime ] = None , end : Optional [datetime ] = None
79+ ) -> list [feed .FeedItem ]:
80+ source_posts = self .source .posts (limit = limit , start = start , end = end )
81+ posts = [cast (AiFilterPost , self .post (post .id )) for post in source_posts ]
82+ return [post for post in posts if post and post .include_post ]
8483
85- def format_prompt (self , post : feed .FeedItem ) -> str :
84+ def _format_prompt (self , post : feed .FeedItem ) -> str :
8685 f = HTMLFilter ()
8786 f .feed (post .render ())
8887
89- return base_prompt .format (
88+ return _base_prompt .format (
9089 title = post .title ,
9190 author = post .author ,
9291 content = f .text [: self .content_limit ],
@@ -95,38 +94,14 @@ def format_prompt(self, post: feed.FeedItem) -> str:
9594 prompt = self .prompt ,
9695 )
9796
98- def posts (
99- self , limit : int = 50 , start : Optional [datetime ] = None , end : Optional [datetime ] = None
100- ) -> list [feed .FeedItem ]:
101- source_posts = self .source .posts (limit = limit , start = start , end = end )
102- posts = [cast (AiFilterPost , self .post (post .id )) for post in source_posts ]
103- return [post for post in posts if post and post .include_post ]
104-
105- def sources (self ) -> Iterable [feed .BaseFeed ]:
106- yield self .source
107-
108- def next_update (self , force ):
109- return self .source .next_update (force )
110-
11197 def update (self ) -> None :
112- """
113- This feed only updates when the source feed updates
114- """
115- source_posts = self .source .posts ()
116- # Sort by posted_time
117- source_posts .sort (key = lambda post : post .posted_time , reverse = True )
118- # Limit to the user specified limit
119- if self .limit != - 1 :
120- source_posts = source_posts [: self .limit ]
121- # Figure out which ones we haven't tested yet
122-
123- for source_post in source_posts :
98+ for source_post in self .source .posts (limit = self .limit ):
12499 post = self .post (source_post .id )
125100 if post :
126101 self .logger .debug (f" skipping filter check for { source_post .id } " )
127102 continue
128103
129- msg = self .client .get_response (self .format_prompt (source_post ))
104+ msg = self .client .get_response (self ._format_prompt (source_post ))
130105
131106 include_post = False
132107 if "yes" in msg .response .lower ():
@@ -144,4 +119,4 @@ def update(self) -> None:
144119 include_post = include_post ,
145120 )
146121 self .logger .info (f"post={ source_post .id } include_post={ include_post } " )
147- self .cache_set (value .id , value .to_dict ())
122+ self .cache_set (value .id , value .to_dict ())
0 commit comments