Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 177 additions & 16 deletions maxapi/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .webhook.aiohttp import AiohttpMaxWebhook

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Iterator

from magic_filter import MagicFilter

Expand Down Expand Up @@ -236,7 +236,9 @@ def _prepare_handlers(self, bot: Bot) -> None:

handlers_count = 0

for router in self.routers:
for router, *_ in self._iter_unique_routers(
self.routers, warn_duplicates=True
):
router.bot = bot

for handler in router.event_handlers:
Expand Down Expand Up @@ -339,26 +341,177 @@ async def process_base_filters(

return data

def _iter_routers(
self,
routers: list[Router | Dispatcher],
parent_middlewares: list[BaseMiddleware] | None = None,
parent_filters: list[MagicFilter] | None = None,
parent_base_filters: list[BaseFilter] | None = None,
path: set[int] | None = None,
) -> Iterator[
tuple[
Router | Dispatcher,
list[BaseMiddleware],
list[MagicFilter],
list[BaseFilter],
]
]:
"""
Рекурсивно обходит роутеры, накапливая middleware и фильтры родителей.

Args:
routers: Список роутеров для обхода.
parent_middlewares: Накопленные middleware от родительских
роутеров.
parent_filters: Накопленные MagicFilter от родительских
роутеров.
parent_base_filters: Накопленные BaseFilter от родительских
роутеров.
path: Идентификаторы роутеров в текущей ветви обхода; используется,
чтобы не уходить в бесконечную рекурсию при циклических
включениях между роутерами.

Yields:
Кортеж (роутер, middleware, MagicFilter, BaseFilter) с накопленными
значениями от всех родителей.
"""
parent_middlewares = parent_middlewares or []
parent_filters = parent_filters or []
parent_base_filters = parent_base_filters or []
path = path if path is not None else set()

for router in routers:
router_key = id(router)
if router_key in path:
continue

if router is self:
accumulated_middlewares = parent_middlewares
else:
accumulated_middlewares = (
parent_middlewares + router.middlewares
)

accumulated_filters = parent_filters + router.filters
accumulated_base_filters = (
parent_base_filters + router.base_filters
)

yield (
router,
accumulated_middlewares,
accumulated_filters,
accumulated_base_filters,
)

sub_routers = (
[]
if router is self
else [r for r in router.routers if r is not self]
)
if sub_routers:
path.add(router_key)
try:
yield from self._iter_routers(
routers=sub_routers,
parent_middlewares=accumulated_middlewares,
parent_filters=accumulated_filters,
parent_base_filters=accumulated_base_filters,
path=path,
)
finally:
path.discard(router_key)

def _iter_unique_routers(
self,
routers: list[Router | Dispatcher],
parent_middlewares: list[BaseMiddleware] | None = None,
parent_filters: list[MagicFilter] | None = None,
parent_base_filters: list[BaseFilter] | None = None,
*,
warn_duplicates: bool = False,
) -> Iterator[
tuple[
Router | Dispatcher,
list[BaseMiddleware],
list[MagicFilter],
list[BaseFilter],
]
]:
"""
Обходит дерево роутеров и исключает повторные экземпляры роутеров.

При повторном включении одного и того же объекта роутера используется
контекст первого вхождения (накопленные middleware и фильтры).

Args:
routers: Список роутеров для обхода.
parent_middlewares: Накопленные middleware от родительских
роутеров.
parent_filters: Накопленные MagicFilter от родительских
роутеров.
parent_base_filters: Накопленные BaseFilter от родительских
роутеров.
warn_duplicates: Если True, выводит предупреждение при обнаружении
повторных включений одного и того же экземпляра роутера.
"""
seen: set[int] = set()
duplicate_keys: set[int] = set()
duplicate_titles: list[str] = []
try:
for item in self._iter_routers(
routers=routers,
parent_middlewares=parent_middlewares,
parent_filters=parent_filters,
parent_base_filters=parent_base_filters,
):
router = item[0]
router_key = id(router)
if router_key in seen:
if warn_duplicates and router_key not in duplicate_keys:
duplicate_keys.add(router_key)
rid = getattr(router, "router_id", None)
router_title = (
str(rid)
if rid is not None
else router.__class__.__name__
)
duplicate_titles.append(router_title)
continue
seen.add(router_key)
yield item
finally:
if warn_duplicates and duplicate_titles:
logger_dp.warning(
"Обнаружены повторные включения роутеров: %s. "
"Повторные вхождения будут дедуплицированы.",
", ".join(duplicate_titles),
)

async def _check_router_filters(
self, event: UpdateUnion, router: Router | Dispatcher
) -> dict[str, Any] | None | Literal[False]:
self,
event: UpdateUnion,
filters: list[MagicFilter],
base_filters: list[BaseFilter],
) -> dict[str, Any] | Literal[False]:
"""
Проверяет фильтры роутера для события.
Проверяет накопленные фильтры роутера для события.

Args:
event (UpdateUnion): Событие.
router (Router | Dispatcher): Роутер для проверки.
filters: Накопленные MagicFilter.
base_filters: Накопленные BaseFilter.

Returns:
Optional[Dict[str, Any]] | Literal[False]: Словарь с данными
или False, если фильтры не прошли.
Dict[str, Any] | Literal[False]: Словарь с данными или False,
если фильтры не прошли.
"""
if router.filters and not filter_attrs(event, *router.filters):
if filters and not filter_attrs(event, *filters):
return False

if router.base_filters:
if base_filters:
result = await self.process_base_filters(
event=event, filters=router.base_filters
event=event, filters=base_filters
)
if isinstance(result, dict):
return result
Expand Down Expand Up @@ -482,7 +635,7 @@ async def handle_raw_response(
"""
Специальный метод для обработки сырых ответов API.
"""
for router in self.routers:
for router, *_ in self._iter_unique_routers(self.routers):
matching_handlers = self._find_matching_handlers(
router, event_type
)
Expand Down Expand Up @@ -525,14 +678,19 @@ async def _process_event(

data["context"] = memory_context

for index, router in enumerate(self.routers):
for index, (
router,
router_middlewares,
router_filters,
router_base_filters,
) in enumerate(self._iter_unique_routers(self.routers)):
if is_handled:
break

router_id = router.router_id or index

router_filter_result = await self._check_router_filters(
event_object, router
event_object, router_filters, router_base_filters
)

if router_filter_result is False:
Expand All @@ -545,6 +703,9 @@ async def _process_event(
router, event_object.update_type
)

if not matching_handlers:
continue

async def _process_handlers(
event: UpdateUnion, handler_data: dict[str, Any]
) -> None:
Expand Down Expand Up @@ -582,9 +743,9 @@ async def _process_handlers(
is_handled = True
break

if isinstance(router, Router) and router.middlewares:
if router_middlewares:
router_chain = self.build_middleware_chain(
router.middlewares, _process_handlers
router_middlewares, _process_handlers
)
await router_chain(event_object, data)
else:
Expand Down
Loading