Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions slack_bolt/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore

from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities
from slack_bolt.error import BoltError, BoltUnhandledRequestError
from slack_bolt.lazy_listener.thread_runner import ThreadLazyListenerRunner
from slack_bolt.listener.builtins import TokenRevocationListeners
Expand Down Expand Up @@ -83,10 +82,6 @@
from slack_bolt.oauth.internals import select_consistent_installation_store
from slack_bolt.oauth.oauth_settings import OAuthSettings
from slack_bolt.request import BoltRequest
from slack_bolt.request.payload_utils import (
is_assistant_event,
to_event,
)
from slack_bolt.response import BoltResponse
from slack_bolt.util.utils import (
create_web_client,
Expand Down Expand Up @@ -1398,20 +1393,6 @@ def _init_context(self, req: BoltRequest):
# It is intended for apps that start lazy listeners from their custom global middleware.
req.context["listener_runner"] = self.listener_runner

# For AI Agents & Assistants
if is_assistant_event(req.body):
assistant = AssistantUtilities(
payload=to_event(req.body), # type:ignore[arg-type]
context=req.context,
thread_context_store=self._assistant_thread_context_store,
)
req.context["say"] = assistant.say
req.context["set_status"] = assistant.set_status
req.context["set_title"] = assistant.set_title
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
req.context["get_thread_context"] = assistant.get_thread_context
req.context["save_thread_context"] = assistant.save_thread_context

@staticmethod
def _to_listener_functions(
kwargs: dict,
Expand Down
16 changes: 0 additions & 16 deletions slack_bolt/app/async_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from aiohttp import web

from slack_bolt.app.async_server import AsyncSlackAppServer
from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities
from slack_bolt.context.assistant.thread_context_store.async_store import (
AsyncAssistantThreadContextStore,
)
Expand All @@ -30,7 +29,6 @@
AsyncMessageListenerMatches,
)
from slack_bolt.oauth.async_internals import select_consistent_installation_store
from slack_bolt.request.payload_utils import is_assistant_event, to_event
from slack_bolt.util.utils import get_name_for_callable, is_callable_coroutine
from slack_bolt.workflows.step.async_step import (
AsyncWorkflowStep,
Expand Down Expand Up @@ -1431,20 +1429,6 @@ def _init_context(self, req: AsyncBoltRequest):
# It is intended for apps that start lazy listeners from their custom global middleware.
req.context["listener_runner"] = self.listener_runner

# For AI Agents & Assistants
if is_assistant_event(req.body):
assistant = AsyncAssistantUtilities(
payload=to_event(req.body), # type:ignore[arg-type]
context=req.context,
thread_context_store=self._assistant_thread_context_store,
)
req.context["say"] = assistant.say
req.context["set_status"] = assistant.set_status
req.context["set_title"] = assistant.set_title
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
req.context["get_thread_context"] = assistant.get_thread_context
req.context["save_thread_context"] = assistant.save_thread_context

@staticmethod
def _to_listener_functions(
kwargs: dict,
Expand Down
2 changes: 1 addition & 1 deletion slack_bolt/context/async_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ async def handle_button_clicks(ack, say):
Callable `say()` function
"""
if "say" not in self:
self["say"] = AsyncSay(client=self.client, channel=self.channel_id, thread_ts=self.thread_ts)
self["say"] = AsyncSay(client=self.client, channel=self.channel_id)
return self["say"]

@property
Expand Down
2 changes: 1 addition & 1 deletion slack_bolt/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def handle_button_clicks(ack, say):
Callable `say()` function
"""
if "say" not in self:
self["say"] = Say(client=self.client, channel=self.channel_id, thread_ts=self.thread_ts)
self["say"] = Say(client=self.client, channel=self.channel_id)
return self["say"]

@property
Expand Down
11 changes: 11 additions & 0 deletions slack_bolt/middleware/assistant/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
from slack_bolt.listener_matcher.builtins import build_listener_matcher

from slack_bolt.middleware.assistant.attaching_assistant_kwargs import AttachingAssistantKwargs
from slack_bolt.request.request import BoltRequest
from slack_bolt.response.response import BoltResponse
from slack_bolt.listener_matcher import CustomListenerMatcher
Expand Down Expand Up @@ -236,6 +237,15 @@ def process( # type:ignore[return]
if listeners is not None:
for listener in listeners:
if listener.matches(req=req, resp=resp):
middleware_resp, next_was_not_called = listener.run_middleware(req=req, resp=resp)
if next_was_not_called:
if middleware_resp is not None:
return middleware_resp
# The listener middleware didn't call next().
# This means the listener is not for this incoming request.
continue
if middleware_resp is not None:
resp = middleware_resp
return listener_runner.run(
request=req,
response=resp,
Expand All @@ -262,6 +272,7 @@ def build_listener(
return listener_or_functions
elif isinstance(listener_or_functions, list):
middleware = middleware if middleware else []
middleware.insert(0, AttachingAssistantKwargs(self.thread_context_store))
functions = listener_or_functions
ack_function = functions.pop(0)

Expand Down
11 changes: 11 additions & 0 deletions slack_bolt/middleware/assistant/async_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from slack_bolt.listener.asyncio_runner import AsyncioListenerRunner
from slack_bolt.listener_matcher.builtins import build_listener_matcher
from slack_bolt.middleware.assistant.async_attaching_assistant_kwargs import AsyncAttachingAssistantKwargs
from slack_bolt.request.async_request import AsyncBoltRequest
from slack_bolt.response import BoltResponse
from slack_bolt.error import BoltError
Expand Down Expand Up @@ -265,6 +266,15 @@ async def async_process( # type:ignore[return]
if listeners is not None:
for listener in listeners:
if listener is not None and await listener.async_matches(req=req, resp=resp):
middleware_resp, next_was_not_called = await listener.run_async_middleware(req=req, resp=resp)
if next_was_not_called:
if middleware_resp is not None:
return middleware_resp
# The listener middleware didn't call next().
# This means the listener is not for this incoming request.
continue
if middleware_resp is not None:
resp = middleware_resp
return await listener_runner.run(
request=req,
response=resp,
Expand All @@ -291,6 +301,7 @@ def build_listener(
return listener_or_functions
elif isinstance(listener_or_functions, list):
middleware = middleware if middleware else []
middleware.insert(0, AsyncAttachingAssistantKwargs(self.thread_context_store))
functions = listener_or_functions
ack_function = functions.pop(0)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Optional, Callable, Awaitable

from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities
from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore
from slack_bolt.middleware.async_middleware import AsyncMiddleware
from slack_bolt.request.async_request import AsyncBoltRequest
from slack_bolt.request.payload_utils import to_event
from slack_bolt.response import BoltResponse


class AsyncAttachingAssistantKwargs(AsyncMiddleware):

thread_context_store: Optional[AsyncAssistantThreadContextStore]

def __init__(self, thread_context_store: Optional[AsyncAssistantThreadContextStore]):
self.thread_context_store = thread_context_store

async def async_process(
self,
*,
req: AsyncBoltRequest,
resp: BoltResponse,
next: Callable[[], Awaitable[BoltResponse]],
) -> Optional[BoltResponse]:
event = to_event(req.body)
if event is not None:
assistant = AsyncAssistantUtilities(
payload=event,
context=req.context,
thread_context_store=self.thread_context_store,
)
req.context["say"] = assistant.say
req.context["set_status"] = assistant.set_status
req.context["set_title"] = assistant.set_title
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
req.context["get_thread_context"] = assistant.get_thread_context
req.context["save_thread_context"] = assistant.save_thread_context
return await next()
32 changes: 32 additions & 0 deletions slack_bolt/middleware/assistant/attaching_assistant_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Optional, Callable

from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities
from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore
from slack_bolt.middleware import Middleware
from slack_bolt.request.payload_utils import to_event
from slack_bolt.request.request import BoltRequest
from slack_bolt.response.response import BoltResponse


class AttachingAssistantKwargs(Middleware):

thread_context_store: Optional[AssistantThreadContextStore]

def __init__(self, thread_context_store: Optional[AssistantThreadContextStore]):
self.thread_context_store = thread_context_store

def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], BoltResponse]) -> Optional[BoltResponse]:
event = to_event(req.body)
if event is not None:
assistant = AssistantUtilities(
payload=event,
context=req.context,
thread_context_store=self.thread_context_store,
)
req.context["say"] = assistant.say
req.context["set_status"] = assistant.set_status
req.context["set_title"] = assistant.set_title
req.context["set_suggested_prompts"] = assistant.set_suggested_prompts
req.context["get_thread_context"] = assistant.get_thread_context
req.context["save_thread_context"] = assistant.save_thread_context
return next()
39 changes: 11 additions & 28 deletions slack_bolt/request/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from urllib.parse import parse_qsl, parse_qs

from slack_bolt.context import BoltContext
from slack_bolt.request.payload_utils import is_assistant_event


def parse_query(query: Optional[Union[str, Dict[str, str], Dict[str, Sequence[str]]]]) -> Dict[str, Sequence[str]]:
Expand Down Expand Up @@ -215,33 +214,17 @@ def extract_channel_id(payload: Dict[str, Any]) -> Optional[str]:


def extract_thread_ts(payload: Dict[str, Any]) -> Optional[str]:
# This utility initially supports only the use cases for AI assistants, but it may be fine to add more patterns.
# That said, note that thread_ts is always required for assistant threads, but it's not for channels.
# Thus, blindly setting this thread_ts to say utility can break existing apps' behaviors.
#
# The BoltAgent class handles non-assistant thread_ts separately by reading from the event directly,
# allowing it to work correctly without affecting say() behavior.
if is_assistant_event(payload):
event = payload["event"]
if (
event.get("assistant_thread") is not None
and event["assistant_thread"].get("channel_id") is not None
and event["assistant_thread"].get("thread_ts") is not None
):
# assistant_thread_started, assistant_thread_context_changed
# "assistant_thread" property can exist for message event without channel_id and thread_ts
# Thus, the above if check verifies these properties exist
return event["assistant_thread"]["thread_ts"]
elif event.get("channel") is not None:
if event.get("thread_ts") is not None:
# message in an assistant thread
return event["thread_ts"]
elif event.get("message", {}).get("thread_ts") is not None:
# message_changed
return event["message"]["thread_ts"]
elif event.get("previous_message", {}).get("thread_ts") is not None:
# message_deleted
return event["previous_message"]["thread_ts"]
thread_ts = payload.get("thread_ts")
if thread_ts is not None:
return thread_ts
if payload.get("event") is not None:
return extract_thread_ts(payload["event"])
if isinstance(payload.get("assistant_thread"), dict):
return extract_thread_ts(payload["assistant_thread"])
if isinstance(payload.get("message"), dict):
return extract_thread_ts(payload["message"])
if isinstance(payload.get("previous_message"), dict):
return extract_thread_ts(payload["previous_message"])
return None


Expand Down
86 changes: 86 additions & 0 deletions tests/scenario_tests/test_events_assistant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from time import sleep
from typing import Callable

from slack_sdk.web import WebClient

from slack_bolt import App, BoltRequest, Assistant, Say, SetSuggestedPrompts, SetStatus, BoltContext
from slack_bolt.middleware import Middleware
from slack_bolt.request import BoltRequest as BoltRequestType
from slack_bolt.response import BoltResponse
from tests.mock_web_api_server import (
setup_mock_web_api_server,
cleanup_mock_web_api_server,
Expand Down Expand Up @@ -44,6 +48,7 @@ def assert_target_called():
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
say("Hi, how can I help you today?")
set_suggested_prompts(prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}])
set_suggested_prompts(
Expand All @@ -61,6 +66,7 @@ def handle_thread_context_changed(context: BoltContext):
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
try:
set_status("is typing...")
say("Here you are!")
Expand Down Expand Up @@ -102,6 +108,86 @@ def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
response = app.dispatch(request)
assert response.status == 404

def test_assistant_threads_with_custom_listener_middleware(self):
app = App(client=self.web_client)
assistant = Assistant()

state = {"called": False, "middleware_called": False}

def assert_target_called():
count = 0
while state["called"] is False and count < 20:
sleep(0.1)
count += 1
assert state["called"] is True
state["called"] = False

class TestMiddleware(Middleware):
def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]):
state["middleware_called"] = True
# Verify assistant utilities are available
assert req.context.get("set_status") is not None
assert req.context.get("set_title") is not None
assert req.context.get("set_suggested_prompts") is not None
assert req.context.get("get_thread_context") is not None
assert req.context.get("save_thread_context") is not None
return next()

@assistant.thread_started(middleware=[TestMiddleware()])
def start_thread(say: Say, set_suggested_prompts: SetSuggestedPrompts, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
say("Hi, how can I help you today?")
set_suggested_prompts(prompts=[{"title": "What does SLACK stand for?", "message": "What does SLACK stand for?"}])
state["called"] = True

@assistant.user_message(middleware=[TestMiddleware()])
def handle_user_message(say: Say, set_status: SetStatus, context: BoltContext):
assert context.channel_id == "D111"
assert context.thread_ts == "1726133698.626339"
assert say.thread_ts == context.thread_ts
set_status("is typing...")
say("Here you are!")
state["called"] = True

app.assistant(assistant)

request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
assert state["middleware_called"] is True
state["middleware_called"] = False

request = BoltRequest(body=user_message_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert_target_called()
assert state["middleware_called"] is True

def test_assistant_threads_custom_middleware_can_short_circuit(self):
app = App(client=self.web_client)
assistant = Assistant()

state = {"handler_called": False}

class BlockingMiddleware(Middleware):
def process(self, *, req: BoltRequestType, resp: BoltResponse, next: Callable[[], BoltResponse]):
# Intentionally not calling next() to short-circuit
return BoltResponse(status=200)

@assistant.thread_started(middleware=[BlockingMiddleware()])
def start_thread(say: Say, context: BoltContext):
state["handler_called"] = True

app.assistant(assistant)

request = BoltRequest(body=thread_started_event_body, mode="socket_mode")
response = app.dispatch(request)
assert response.status == 200
assert state["handler_called"] is False


def build_payload(event: dict) -> dict:
return {
Expand Down
Loading
Loading