Skip to content

Commit fa1a4e0

Browse files
committed
Use an enum for callbacks
1 parent 41e0c2e commit fa1a4e0

7 files changed

Lines changed: 134 additions & 59 deletions

File tree

discord/ext/test/_types.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,11 @@
88
import discord
99
import typing
1010

11-
if typing.TYPE_CHECKING:
12-
from discord.types import (
13-
role, gateway, appinfo, user, guild, emoji, channel, message, sticker, # noqa: F401
14-
scheduled_event, member # noqa: F401
15-
)
16-
17-
AnyChannelJson = channel.VoiceChannel | channel.TextChannel | channel.DMChannel | channel.CategoryChannel
18-
else:
19-
class OpenNamespace:
20-
def __getattr__(self, item: str) -> Self:
21-
return self
22-
23-
def __subclasscheck__(self, subclass: type) -> Literal[True]:
24-
return True
25-
26-
def __or__(self, other: T) -> T:
27-
return other
28-
29-
def __getattr__(name: str) -> OpenNamespace:
30-
return OpenNamespace()
31-
3211
T = TypeVar('T')
3312
P = ParamSpec('P')
3413

35-
Callback = Callable[P, Coroutine[None, None, None]]
36-
AnyChannel = (discord.abc.GuildChannel | discord.TextChannel | discord.VoiceChannel | discord.StageChannel | discord.DMChannel | discord.Thread | discord.GroupChannel)
14+
AnyChannel = (discord.abc.GuildChannel | discord.TextChannel | discord.VoiceChannel | discord.StageChannel
15+
| discord.DMChannel | discord.Thread | discord.GroupChannel)
3716

3817

3918
class Wrapper(Protocol[P, T]):
@@ -55,3 +34,25 @@ class Undef(Enum):
5534

5635

5736
undefined: Literal[Undef.undefined] = Undef.undefined
37+
38+
39+
if typing.TYPE_CHECKING:
40+
from discord.types import (
41+
role, gateway, appinfo, user, guild, emoji, channel, message, sticker, snowflake, # noqa: F401
42+
scheduled_event, member # noqa: F401
43+
)
44+
45+
AnyChannelJson = channel.VoiceChannel | channel.TextChannel | channel.DMChannel | channel.CategoryChannel
46+
else:
47+
class OpenNamespace:
48+
def __getattr__(self, item: str) -> Self:
49+
return self
50+
51+
def __subclasscheck__(self, subclass: type) -> Literal[True]:
52+
return True
53+
54+
def __or__(self, other: T) -> T:
55+
return other
56+
57+
def __getattr__(name: str) -> OpenNamespace:
58+
return OpenNamespace()

discord/ext/test/backend.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from ._types import Undef, undefined
2727
from discord.types.snowflake import Snowflake
2828

29+
from .callbacks import CallbackEvent
30+
2931

3032
class BackendState(NamedTuple):
3133
"""
@@ -140,7 +142,7 @@ async def delete_channel(self, channel_id: Snowflake, *, reason: str | None = No
140142
delete_channel(channel)
141143

142144
async def get_channel(self, channel_id: Snowflake) -> _types.channel.GuildChannel:
143-
await callbacks.dispatch_event("get_channel", channel_id)
145+
await callbacks.dispatch_event(CallbackEvent.get_channel, channel_id)
144146

145147
find = None
146148
for guild in get_state().guilds:
@@ -155,7 +157,7 @@ async def start_private_message(self, user_id: Snowflake) -> _types.channel.DMCh
155157
locs = _get_higher_locs(1)
156158
user = locs["self"]
157159

158-
await callbacks.dispatch_event("start_private_message", user)
160+
await callbacks.dispatch_event(CallbackEvent.start_private_message, user)
159161

160162
return facts.make_dm_channel_dict(user)
161163

@@ -212,22 +214,22 @@ async def send_message(
212214
attachments=attachments,
213215
nonce=nonce
214216
)
215-
await callbacks.dispatch_event("send_message", message)
217+
await callbacks.dispatch_event(CallbackEvent.send_message, message)
216218

217219
return facts.dict_from_object(message)
218220

219221
async def send_typing(self, channel_id: Snowflake) -> None:
220222
locs = _get_higher_locs(1)
221223
channel = locs.get("channel", None)
222224

223-
await callbacks.dispatch_event("send_typing", channel)
225+
await callbacks.dispatch_event(CallbackEvent.send_typing, channel)
224226

225227
async def delete_message(self, channel_id: Snowflake, message_id: Snowflake, *,
226228
reason: str | None = None) -> None:
227229
locs = _get_higher_locs(1)
228230
message = locs["self"]
229231

230-
await callbacks.dispatch_event("delete_message", message.channel, message, reason=reason)
232+
await callbacks.dispatch_event(CallbackEvent.delete_message, message.channel, message, reason=reason)
231233

232234
delete_message(message)
233235

@@ -236,7 +238,7 @@ async def edit_message(self, channel_id: Snowflake, message_id: Snowflake,
236238
locs = _get_higher_locs(1)
237239
message = locs["self"]
238240

239-
await callbacks.dispatch_event("edit_message", message.channel, message, fields)
241+
await callbacks.dispatch_event(CallbackEvent.edit_message, message.channel, message, fields)
240242

241243
return edit_message(message, **fields)
242244

@@ -250,7 +252,7 @@ async def add_reaction(self, channel_id: Snowflake, message_id: Snowflake,
250252

251253
emoji = emoji # TODO: Turn this back into class?
252254

253-
await callbacks.dispatch_event("add_reaction", message, emoji)
255+
await callbacks.dispatch_event(CallbackEvent.add_reaction, message, emoji)
254256

255257
add_reaction(message, user, emoji)
256258

@@ -261,7 +263,7 @@ async def remove_reaction(self, channel_id: Snowflake, message_id: Snowflake,
261263
message = locs["self"]
262264
member = locs["member"]
263265

264-
await callbacks.dispatch_event("remove_reaction", message, emoji, member)
266+
await callbacks.dispatch_event(CallbackEvent.remove_reaction, message, emoji, member)
265267

266268
remove_reaction(message, member, emoji)
267269

@@ -271,7 +273,7 @@ async def remove_own_reaction(self, channel_id: Snowflake, message_id: Snowflake
271273
message = locs["self"]
272274
member = locs["member"]
273275

274-
await callbacks.dispatch_event("remove_own_reaction", message, emoji, member)
276+
await callbacks.dispatch_event(CallbackEvent.remove_own_reaction, message, emoji, member)
275277

276278
remove_reaction(message, self.state.user, emoji)
277279

@@ -285,7 +287,7 @@ async def get_message(self, channel_id: Snowflake,
285287
locs = _get_higher_locs(1)
286288
channel = locs["self"]
287289

288-
await callbacks.dispatch_event("get_message", channel, message_id)
290+
await callbacks.dispatch_event(CallbackEvent.get_message, channel, message_id)
289291

290292
messages = get_config().messages[int(channel_id)]
291293
find = next(filter(lambda m: m["id"] == message_id, messages), None)
@@ -304,7 +306,7 @@ async def logs_from(
304306
locs = _get_higher_locs(1)
305307
channel = locs["self"]
306308

307-
await callbacks.dispatch_event("logs_from", channel, limit, before=None, after=None, around=None)
309+
await callbacks.dispatch_event(CallbackEvent.logs_from, channel, limit, before=None, after=None, around=None)
308310

309311
messages = get_config().messages[int(channel_id)]
310312
if after is not None:
@@ -326,7 +328,7 @@ async def kick(self, user_id: Snowflake, guild_id: Snowflake,
326328
guild = locs["self"]
327329
member = locs["user"]
328330

329-
await callbacks.dispatch_event("kick", guild, member, reason=reason)
331+
await callbacks.dispatch_event(CallbackEvent.kick, guild, member, reason=reason)
330332

331333
delete_member(member)
332334

@@ -337,7 +339,7 @@ async def ban(self, user_id: Snowflake, guild_id: Snowflake,
337339
guild = locs["self"]
338340
member = locs["user"]
339341

340-
await callbacks.dispatch_event("ban", guild, member, delete_message_days, reason=reason)
342+
await callbacks.dispatch_event(CallbackEvent.ban, guild, member, delete_message_days, reason=reason)
341343

342344
delete_member(member)
343345

@@ -346,7 +348,7 @@ async def unban(self, user_id: Snowflake, guild_id: Snowflake, *,
346348
locs = _get_higher_locs(1)
347349
guild = locs["self"]
348350
member = locs["user"]
349-
await callbacks.dispatch_event("unban", guild, member, reason=reason)
351+
await callbacks.dispatch_event(CallbackEvent.unban, guild, member, reason=reason)
350352

351353
async def change_my_nickname(self, guild_id: Snowflake, nickname: str, *,
352354
reason: str | None = None) -> _types.member.Nickname:
@@ -355,7 +357,7 @@ async def change_my_nickname(self, guild_id: Snowflake, nickname: str, *,
355357

356358
me.nick = nickname
357359

358-
await callbacks.dispatch_event("change_nickname", nickname, me, reason=reason)
360+
await callbacks.dispatch_event(CallbackEvent.change_nickname, nickname, me, reason=reason)
359361

360362
return {"nick": nickname}
361363

@@ -365,7 +367,7 @@ async def edit_member(self, guild_id: Snowflake, user_id: Snowflake, *,
365367
locs = _get_higher_locs(1)
366368
member = locs["self"]
367369

368-
await callbacks.dispatch_event("edit_member", fields, member, reason=reason)
370+
await callbacks.dispatch_event(CallbackEvent.edit_member, fields, member, reason=reason)
369371
member = update_member(member, nick=fields.get('nick'), roles=fields.get('roles'))
370372
return facts.dict_from_object(member)
371373

@@ -384,7 +386,7 @@ async def edit_role(self, guild_id: Snowflake, role_id: Snowflake, *,
384386
role = locs["self"]
385387
guild = role.guild
386388

387-
await callbacks.dispatch_event("edit_role", guild, role, fields, reason=reason)
389+
await callbacks.dispatch_event(CallbackEvent.edit_role, guild, role, fields, reason=reason)
388390

389391
update_role(role, **fields)
390392
return facts.dict_from_object(role)
@@ -395,7 +397,7 @@ async def delete_role(self, guild_id: Snowflake, role_id: Snowflake, *,
395397
role = locs["self"]
396398
guild = role.guild
397399

398-
await callbacks.dispatch_event("delete_role", guild, role, reason=reason)
400+
await callbacks.dispatch_event(CallbackEvent.delete_role, guild, role, reason=reason)
399401

400402
delete_role(role)
401403

@@ -405,7 +407,7 @@ async def create_role(self, guild_id: Snowflake, *, reason: str | None = None,
405407
guild = locs["self"]
406408
role = make_role(guild=guild, **fields)
407409

408-
await callbacks.dispatch_event("create_role", guild, role, reason=reason)
410+
await callbacks.dispatch_event(CallbackEvent.create_role, guild, role, reason=reason)
409411

410412
return facts.dict_from_object(role)
411413

@@ -416,7 +418,7 @@ async def move_role_position(self, guild_id: Snowflake,
416418
role = locs["self"]
417419
guild = role.guild
418420

419-
await callbacks.dispatch_event("move_role", guild, role, positions, reason=reason)
421+
await callbacks.dispatch_event(CallbackEvent.move_role, guild, role, positions, reason=reason)
420422

421423
for pair in positions:
422424
guild._roles[pair["id"]].position = pair["position"]
@@ -428,7 +430,7 @@ async def add_role(self, guild_id: Snowflake, user_id: Snowflake,
428430
member = locs["self"]
429431
role = locs["role"]
430432

431-
await callbacks.dispatch_event("add_role", member, role, reason=reason)
433+
await callbacks.dispatch_event(CallbackEvent.add_role, member, role, reason=reason)
432434

433435
roles = [role] + [x for x in member.roles if x.id != member.guild.id]
434436
update_member(member, roles=roles)
@@ -440,7 +442,7 @@ async def remove_role(self, guild_id: Snowflake, user_id: Snowflake,
440442
member = locs["self"]
441443
role = locs["role"]
442444

443-
await callbacks.dispatch_event("remove_role", member, role, reason=reason)
445+
await callbacks.dispatch_event(CallbackEvent.remove_role, member, role, reason=reason)
444446

445447
roles = [x for x in member.roles if x != role and x.id != member.guild.id]
446448
update_member(member, roles=roles)
@@ -463,7 +465,7 @@ async def application_info(self) -> _types.appinfo.AppInfo:
463465
}
464466

465467
appinfo = discord.AppInfo(self.state, data)
466-
await callbacks.dispatch_event("app_info", appinfo)
468+
await callbacks.dispatch_event(CallbackEvent.app_info, appinfo)
467469

468470
return data
469471

@@ -540,7 +542,13 @@ async def get_guilds(self, limit: int, before: Snowflake | None = None,
540542
after: Snowflake | None = None,
541543
with_counts: bool = True) -> list[_types.guild.Guild]:
542544
# self.request(Route('GET', '/users/@me/guilds')
543-
await callbacks.dispatch_event("get_guilds", limit, before=before, after=after, with_counts=with_counts)
545+
await callbacks.dispatch_event(
546+
CallbackEvent.get_guilds,
547+
limit,
548+
before=before,
549+
after=after,
550+
with_counts=with_counts,
551+
)
544552
guilds = get_state().guilds # List[]
545553

546554
guilds_new = [facts.dict_from_object(guild) for guild in guilds]

discord/ext/test/callbacks.py

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,52 @@
66

77
import logging
88
import typing
9+
import discord
10+
from enum import Enum
11+
from typing import Callable, overload, Literal, Any, Awaitable
12+
913
from . import _types
1014

15+
GetChannelCallback = Callable[[_types.snowflake.Snowflake], Awaitable[None]]
16+
SendMessageCallback = Callable[[discord.Message], Awaitable[None]]
17+
EditMemberCallback = Callable[[dict[str, Any], discord.Member, str | None], Awaitable[None]]
18+
Callback = GetChannelCallback | SendMessageCallback | EditMemberCallback | Callable[..., Awaitable[None]]
19+
1120
log = logging.getLogger("discord.ext.tests")
1221

13-
_callbacks: dict[str, _types.Callback[...]] = {}
22+
23+
class CallbackEvent(Enum):
24+
get_channel = "get_channel"
25+
presence = "presence"
26+
start_private_message = "start_private_message"
27+
send_message = "send_message"
28+
send_typing = "send_typing"
29+
delete_message = "delete_message"
30+
edit_message = "edit_message"
31+
add_reaction = "add_reaction"
32+
remove_reaction = "remove_reaction"
33+
remove_own_reaction = "remove_own_reaction"
34+
get_message = "get_message"
35+
logs_from = "logs_from"
36+
kick = "kick"
37+
ban = "ban"
38+
unban = "unban"
39+
change_nickname = "change_nickname"
40+
edit_member = "edit_member"
41+
create_role = "create_role"
42+
edit_role = "edit_role"
43+
delete_role = "delete_role"
44+
move_role = "move_role"
45+
add_role = "add_role"
46+
remove_role = "remove_role"
47+
app_info = "app_info"
48+
get_guilds = "get_guilds"
1449

1550

16-
async def dispatch_event(event: str, *args: typing.Any, **kwargs: typing.Any) -> None:
51+
_callbacks: dict[CallbackEvent, Callback] = {}
52+
53+
54+
async def dispatch_event(event: CallbackEvent, *args: typing.Any, **kwargs: typing.Any) -> None:
1755
"""
1856
Dispatch an event to a set handler, if one exists. Will ignore handler errors,
1957
just print a log
@@ -30,7 +68,19 @@ async def dispatch_event(event: str, *args: typing.Any, **kwargs: typing.Any) ->
3068
log.error(f"Error in handler for event {event}: {e}")
3169

3270

33-
def set_callback(cb: _types.Callback[...], event: str) -> None:
71+
@overload
72+
def set_callback(cb: GetChannelCallback, event: Literal[CallbackEvent.get_channel]) -> None: ...
73+
74+
75+
@overload
76+
def set_callback(cb: SendMessageCallback, event: Literal[CallbackEvent.send_message]) -> None: ...
77+
78+
79+
@overload
80+
def set_callback(cb: EditMemberCallback, event: Literal[CallbackEvent.edit_member]) -> None: ...
81+
82+
83+
def set_callback(cb: Callback, event: CallbackEvent) -> None:
3484
"""
3585
Set the callback to use for a specific event
3686
@@ -40,7 +90,7 @@ def set_callback(cb: _types.Callback[...], event: str) -> None:
4090
_callbacks[event] = cb
4191

4292

43-
def get_callback(event: str) -> _types.Callback[...]:
93+
def get_callback(event: CallbackEvent) -> Callback:
4494
"""
4595
Get the current callback for an event, or raise an exception if one isn't set
4696
@@ -52,7 +102,7 @@ def get_callback(event: str) -> _types.Callback[...]:
52102
return _callbacks[event]
53103

54104

55-
def remove_callback(event: str) -> _types.Callback[...] | None:
105+
def remove_callback(event: CallbackEvent) -> Callback | None:
56106
"""
57107
Remove the callback set for an event, returning it, or None if one isn't set
58108

0 commit comments

Comments
 (0)