Skip to content

Commit 3dc24aa

Browse files
committed
Work on improving types, dict construction
1 parent a0b8b2a commit 3dc24aa

3 files changed

Lines changed: 90 additions & 31 deletions

File tree

discord/ext/test/backend.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ async def create_channel(
130130
channel = make_category_channel(name, guild, permission_overwrites=perms)
131131
elif channel_type == discord.ChannelType.voice.value:
132132
channel = make_voice_channel(name, guild, permission_overwrites=perms)
133-
134133
else:
135134
raise NotImplementedError(
136135
"Operation occurred that isn't captured by the tests framework. This is dpytest's fault, please report"
@@ -892,22 +891,22 @@ def make_member(user: discord.user.BaseUser, guild: discord.Guild,
892891

893892
def update_member(member: discord.Member, nick: str | None = None,
894893
roles: list[discord.Role] | None = None) -> discord.Member:
895-
data = facts.dict_from_object(member)
894+
data = facts.dict_from_object(member, guild=True)
896895
if nick is not None:
897896
data["nick"] = nick
898897
if roles is not None:
899898
data["roles"] = list(map(lambda x: x.id, roles))
900899

901900
state = get_state()
902-
state.parse_guild_member_update(data) # type: ignore[arg-type]
901+
state.parse_guild_member_update(data)
903902

904903
return member
905904

906905

907906
def delete_member(member: discord.Member) -> None:
908-
out = facts.dict_from_object(member)
907+
out = facts.dict_from_object(member, guild=True)
909908
state = get_state()
910-
state.parse_guild_member_remove(out) # type: ignore[arg-type]
909+
state.parse_guild_member_remove(out)
911910

912911

913912
def make_message(

discord/ext/test/factories.py

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def _fill_optional(
4949
) -> None: ...
5050

5151

52+
@overload
53+
def _fill_optional(
54+
data: _types.gateway.GuildMemberUpdateEvent,
55+
obj: discord.Member | dict[str, object],
56+
items: Iterable[str]
57+
) -> None: ...
58+
59+
5260
@overload
5361
def _fill_optional(
5462
data: _types.guild.Guild,
@@ -95,17 +103,36 @@ def _fill_optional( # type: ignore[misc]
95103
items: Iterable[str]
96104
) -> None:
97105
if isinstance(obj, dict):
98-
for item in items:
99-
result = obj.pop(item, None)
100-
if result is None:
101-
continue
102-
data[item] = result
103-
if len(obj) > 0:
104-
print("Warning: Invalid attributes passed")
106+
_fill_optional_dict(data, obj, items)
105107
else:
106-
for item in items:
107-
if hasattr(obj, item):
108-
data[item] = getattr(obj, item)
108+
_fill_optional_value(data, obj, items)
109+
110+
111+
def _fill_optional_dict(
112+
data: dict[str, object],
113+
obj: dict[str, object],
114+
items: Iterable[str],
115+
) -> None:
116+
for item in items:
117+
result = obj.pop(item, None)
118+
if result is None:
119+
continue
120+
data[item] = result
121+
if len(obj) > 0:
122+
print("Warning: Invalid attributes passed")
123+
124+
125+
def _fill_optional_value(
126+
data: dict[str, object],
127+
obj: object,
128+
items: Iterable[str],
129+
) -> None:
130+
for item in items:
131+
if item == "permissions":
132+
print()
133+
if (val := getattr(obj, item, None)) is None and (val := getattr(obj, f"_{item}", None)) is None:
134+
continue
135+
data[item] = val
109136

110137

111138
def make_user_dict(username: str, discrim: str | int, avatar: str | None, id_num: int = -1, flags: int = 0,
@@ -125,7 +152,7 @@ def make_user_dict(username: str, discrim: str | int, avatar: str | None, id_num
125152
'avatar': avatar,
126153
'flags': flags,
127154
}
128-
items = ("bot", "mfa_enabled", "locale", "verified", "email", "premium_type")
155+
items = ("bot", "system", "mfa_enabled", "locale", "verified", "email", "premium_type", "public_flags")
129156
_fill_optional(out, kwargs, items)
130157
return out
131158

@@ -147,7 +174,7 @@ def make_member_dict(
147174
'mute': mute,
148175
'flags': flags,
149176
}
150-
items = ("nick",)
177+
items = ("avatar", "nick", "premium_since", "pending", "permissions", "communication_disabled_until", "avatar_decoration_data")
151178
_fill_optional(out, kwargs, items)
152179
return out
153180

@@ -168,7 +195,11 @@ class DictFromObject(Protocol):
168195
@overload
169196
def __call__(self, obj: discord.user.BaseUser) -> _types.member.UserWithMember: ...
170197
@overload
171-
def __call__(self, obj: discord.Member) -> _types.member.MemberWithUser: ...
198+
def __call__(self, obj: discord.Member, *, guild: Literal[False] = ...) -> _types.member.MemberWithUser: ...
199+
@overload
200+
def __call__(self, obj: discord.Member, *, guild: Literal[True] = ...) -> _types.gateway.GuildMemberUpdateEvent: ...
201+
@overload
202+
def __call__(self, obj: discord.Member, *, guild: bool = ...) -> _types.member.MemberWithUser | _types.gateway.GuildMemberUpdateEvent: ...
172203
@overload
173204
def __call__(self, obj: discord.Role) -> _types.role.Role: ...
174205

@@ -238,21 +269,36 @@ def _from_base_user(user: discord.user.BaseUser) -> _types.member.UserWithMember
238269

239270

240271
@dict_from_object.register(discord.Member)
241-
def _from_member(member: discord.Member) -> _types.member.MemberWithUser:
272+
def _from_member(member: discord.Member, *, guild: bool = False) -> _types.member.MemberWithUser | _types.gateway.GuildMemberUpdateEvent:
242273
# discord code adds default role to every member later on in Member constructor
243274
roles_no_default = list(filter(lambda r: not r == member.guild.default_role, member.roles))
244-
out: _types.member.MemberWithUser = {
245-
'guild_id': member.guild.id, # type: ignore[typeddict-unknown-key]
246-
'user': dict_from_object(member._user),
247-
'roles': list(map(lambda role: int(role.id), roles_no_default)),
248-
'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None,
249-
'flags': member.flags.value,
250-
'deaf': member.voice.deaf if member.voice else False,
251-
'mute': member.voice.mute if member.voice else False,
252-
}
253-
items = ("nick",)
254-
_fill_optional(out, member, items)
255-
return out
275+
items: tuple[str, ...]
276+
if guild:
277+
out: _types.gateway.GuildMemberUpdateEvent = {
278+
'guild_id': member.guild.id,
279+
'user': dict_from_object(member._user),
280+
'avatar': member.avatar.url if member.avatar else "",
281+
'roles': list(map(lambda role: int(role.id), roles_no_default)),
282+
'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None,
283+
'flags': member.flags.value,
284+
'deaf': member.voice.deaf if member.voice else False,
285+
'mute': member.voice.mute if member.voice else False,
286+
}
287+
items = ("nick", "premium_since", "pending", "permissions", "communication_disabled_until", "avatar_decoration_data")
288+
_fill_optional(out, member, items)
289+
return out
290+
else:
291+
mem_user: _types.member.MemberWithUser = {
292+
'user': dict_from_object(member._user),
293+
'roles': list(map(lambda role: int(role.id), roles_no_default)),
294+
'joined_at': str(int(member.joined_at.timestamp())) if member.joined_at else None,
295+
'flags': member.flags.value,
296+
'deaf': member.voice.deaf if member.voice else False,
297+
'mute': member.voice.mute if member.voice else False,
298+
}
299+
items = ("avatar", "nick", "premium_since", "pending", "permissions", "communication_disabled_until", "avatar_decoration_data")
300+
_fill_optional(mem_user, member, items)
301+
return mem_user
256302

257303

258304
@dict_from_object.register(discord.Role)

tests/test_role.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,17 @@ async def test_remove_role2(bot: discord.Client) -> None:
5555
# then remove_role
5656
await dpytest.remove_role(0, staff_role)
5757
assert staff_role not in guild.members[0].roles
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_member_add_roles(bot: discord.Client) -> None:
62+
guild = bot.guilds[0]
63+
member = guild.members[0]
64+
65+
staff_role = await guild.create_role(name="Staff")
66+
user_role = await guild.create_role(name="User")
67+
68+
await member.add_roles(*[staff_role, user_role])
69+
70+
assert staff_role in member.roles
71+
assert user_role in member.roles

0 commit comments

Comments
 (0)