Skip to content

Commit 5ad4889

Browse files
committed
code
1 parent ab3a1f0 commit 5ad4889

3 files changed

Lines changed: 90 additions & 42 deletions

File tree

rss_glue/feeds/feed.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,3 +403,36 @@ def posts(
403403
self.post(post.id) for post in self.source.posts(limit=limit, start=start, end=end)
404404
]
405405
return [post for post in posts if post is not None]
406+
407+
408+
class AugmentFeed(BaseFeed, ABC):
    """
    Base class for feeds that wrap one source feed and add data to it.

    The author, origin URL and title are copied from the wrapped feed when
    the instance is constructed, and the namespace is derived from the
    wrapped feed's namespace prefixed with this feed's ``name``.
    """

    # The wrapped feed that this feed augments.
    source: BaseFeed
    post_cls: type[ReferenceFeedItem] = ReferenceFeedItem
    name: str = "augment"
    limit: int

    def __init__(self, source: BaseFeed, limit: int = 10):
        """
        Wrap ``source`` and mirror its descriptive metadata.

        Every attribute is assigned before ``BaseFeed.__init__`` runs
        (presumably the base initializer reads them -- confirm in BaseFeed).
        """
        self.source = source
        self.limit = limit
        # Mirror the wrapped feed's metadata, in the same order as before.
        for mirrored in ("author", "origin_url", "title"):
            setattr(self, mirrored, getattr(source, mirrored))
        super().__init__()

    @property
    def namespace(self):
        """Return this feed's ``name`` joined to the source namespace by ``_``."""
        return f"{self.name}_{self.source.namespace}"

    def sources(self) -> Iterable[BaseFeed]:
        """Yield the single wrapped feed."""
        yield from (self.source,)

    def next_update(self, force):
        """
        Report when this feed should next update and whether it is due now.

        The wrapped feed's own answer is returned unchanged, unless the
        wrapped feed was last updated at or after this feed's last update.
        In that case the result is ``(source.last_updated, True)``.

        NOTE(review): the comparison assumes both ``last_updated`` values
        are mutually comparable (e.g. never None) -- confirm in BaseFeed.
        """
        scheduled_at, is_due = self.source.next_update(force)
        source_updated_at = self.source.last_updated
        if source_updated_at >= self.last_updated:
            # The source holds data this feed has not processed yet.
            return source_updated_at, True
        return scheduled_at, is_due
438+

rss_glue/feeds/smart_filter.py

Lines changed: 17 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from rss_glue.feeds import ai_client, feed
77
from rss_glue.utils import from_subpost
88

9-
base_prompt = """
9+
_base_prompt = """
1010
Please take a look at the following bit of a content from an RSS feed:
1111
1212
Title: {title}
@@ -43,18 +43,17 @@ class AiFilterPost(feed.ReferenceFeedItem):
4343
include_post: bool
4444

4545

46-
class AiFilterFeed(feed.BaseFeed):
46+
class AiFilterFeed(feed.AugmentFeed):
4747
"""
4848
AiFilterFeed is a feed that filters out posts based
4949
on a given prompt.
5050
"""
5151

52-
source: feed.BaseFeed
53-
limit: int
5452
client: ai_client.AiClient
5553
name: str = "smart_filter"
5654
post_cls: type[AiFilterPost] = AiFilterPost
5755

56+
5857
def __init__(
5958
self,
6059
source: feed.BaseFeed,
@@ -64,29 +63,29 @@ def __init__(
6463
limit: int = -1,
6564
title: Optional[str] = None,
6665
):
67-
self.source = source
68-
self.limit = limit
6966
self.prompt = prompt
7067
self.client = client
71-
72-
self.author = source.author
73-
self.origin_url = source.origin_url
7468
self.content_limit = content_limit
69+
70+
super().__init__(source=source, limit=limit)
71+
7572
if title:
7673
self.title = title
7774
else:
7875
self.title = f"Filter {source.title}"
79-
super().__init__()
8076

81-
@property
82-
def namespace(self):
83-
return f"{self.name}_{self.source.namespace}"
77+
def posts(
78+
self, limit: int = 50, start: Optional[datetime] = None, end: Optional[datetime] = None
79+
) -> list[feed.FeedItem]:
80+
source_posts = self.source.posts(limit=limit, start=start, end=end)
81+
posts = [cast(AiFilterPost, self.post(post.id)) for post in source_posts]
82+
return [post for post in posts if post and post.include_post]
8483

85-
def format_prompt(self, post: feed.FeedItem) -> str:
84+
def _format_prompt(self, post: feed.FeedItem) -> str:
8685
f = HTMLFilter()
8786
f.feed(post.render())
8887

89-
return base_prompt.format(
88+
return _base_prompt.format(
9089
title=post.title,
9190
author=post.author,
9291
content=f.text[: self.content_limit],
@@ -95,38 +94,14 @@ def format_prompt(self, post: feed.FeedItem) -> str:
9594
prompt=self.prompt,
9695
)
9796

98-
def posts(
99-
self, limit: int = 50, start: Optional[datetime] = None, end: Optional[datetime] = None
100-
) -> list[feed.FeedItem]:
101-
source_posts = self.source.posts(limit=limit, start=start, end=end)
102-
posts = [cast(AiFilterPost, self.post(post.id)) for post in source_posts]
103-
return [post for post in posts if post and post.include_post]
104-
105-
def sources(self) -> Iterable[feed.BaseFeed]:
106-
yield self.source
107-
108-
def next_update(self, force):
109-
return self.source.next_update(force)
110-
11197
def update(self) -> None:
112-
"""
113-
This feed only updates when the source feed updates
114-
"""
115-
source_posts = self.source.posts()
116-
# Sort by posted_time
117-
source_posts.sort(key=lambda post: post.posted_time, reverse=True)
118-
# Limit to the user specified limit
119-
if self.limit != -1:
120-
source_posts = source_posts[: self.limit]
121-
# Figure out which ones we haven't tested yet
122-
123-
for source_post in source_posts:
98+
for source_post in self.source.posts(limit=self.limit):
12499
post = self.post(source_post.id)
125100
if post:
126101
self.logger.debug(f" skipping filter check for {source_post.id}")
127102
continue
128103

129-
msg = self.client.get_response(self.format_prompt(source_post))
104+
msg = self.client.get_response(self._format_prompt(source_post))
130105

131106
include_post = False
132107
if "yes" in msg.response.lower():
@@ -144,4 +119,4 @@ def update(self) -> None:
144119
include_post=include_post,
145120
)
146121
self.logger.info(f"post={source_post.id} include_post={include_post}")
147-
self.cache_set(value.id, value.to_dict())
122+
self.cache_set(value.id, value.to_dict())

tests/test_smart_filter.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,43 @@ def test_smart_filter_basic_filtering(self, fs_config, mock_http_requests):
7070
assert filtered_post.include_post is True
7171
assert filtered_post.token_cost == 100
7272
assert filtered_post.namespace == f"smart_filter_{source_feed.namespace}"
73+
74+
def test_smart_filter_updates_when_source_updated_but_not_due(
    self, fs_config, mock_http_requests
):
    """
    A filter that lags behind its source must report that it needs an update.

    The source feed is updated once, then a filter wrapping it is created
    and registered but never updated. ``next_update(force=False)`` on the
    filter is expected to return ``(source.last_updated, True)``.
    """
    mock_rss_feed()

    # Source feed: register it and update it once.
    upstream = RssFeed(
        id="test_feed",
        url="https://example.com/feed.xml",
        limit=10,
        interval=timedelta(hours=1),
    )
    fs_config([upstream])
    upstream.update()

    # Stand-in AI client whose response is always "yes".
    ai_stub = Mock(spec=AiClient)
    ai_stub.get_response.return_value = AiClientResponse(response="yes", tokens_used=100)

    # Build and register the filter, but deliberately never update it.
    lagging_filter = AiFilterFeed(
        source=upstream,
        client=ai_stub,
        prompt="Only include posts about technology",
        limit=10,
    )
    fs_config([upstream, lagging_filter])

    # Precondition: the source carries an update timestamp.
    assert upstream.last_updated is not None

    when, due = lagging_filter.next_update(force=False)

    # The filter is due, and the reported time is the source's update time.
    assert due is True
    assert when == upstream.last_updated

0 commit comments

Comments
 (0)