diff --git a/CHANGELOG.md b/CHANGELOG.md index 4209b0dbb0..f1394c8aaf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,10 @@ These changes are available on the `master` branch, but have not yet been releas ### Changed +- Added support for Discord DAVE (Audio & Video E2EE) for voice-receive related features + and refactored the voice-reception system. + ([#3159](https://github.com/Pycord-Development/pycord/pull/3159)) + ### Fixed - Fixed a `TypeError` when using `Label.set_select` and not providing `default_values`. diff --git a/discord/opus.py b/discord/opus.py index f6c42fdd04..fa27a5a112 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -725,9 +725,4 @@ def _decode_packet(self, packet: Packet) -> tuple[Packet, bytes]: else: pcm = self._decoder.decode(None, fec=False) - if HAS_DAVEY: - if user_id is not None and in_dave and dave.can_passthrough(user_id): - _log.debug("User ID %s can passthrough, decrypting with DAVE", user_id) - pcm = dave.decrypt(user_id, davey.MediaType.audio, pcm) - return packet, pcm diff --git a/discord/sinks/core.py b/discord/sinks/core.py index d22969a675..1525ea16d8 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -37,7 +37,10 @@ from .errors import SinkException if TYPE_CHECKING: + from ..member import Member + from ..user import User from ..voice import VoiceClient + from ..voice.packets import VoiceData __all__ = ( "Filters", @@ -210,6 +213,8 @@ class Sink(Filters): Audio may only be formatted after recording is finished. """ + __sink_listeners__: list[tuple[str, str]] = [] + def __init__(self, *, filters=None): if filters is None: filters = default_filters @@ -222,18 +227,38 @@ def __init__(self, *, filters=None): def client(self) -> VoiceClient | None: return self.vc + @property + def recording(self) -> bool: + """Whether the voice client is currently recording.""" + return self.vc is not None and self.vc.is_recording() + + def is_opus(self) -> bool: + """Whether this sink accepts raw opus packets instead of decoded PCM.""" + return False + + def walk_children(self): + """Yields child sinks. Base implementation yields nothing.""" + return + yield # make it a generator + def init(self, vc: VoiceClient): # called under listen self.vc = vc super().init() @Filters.container - def write(self, data, user): + def write(self, data: VoiceData | bytes, user: User | Member | None) -> None: + from ..voice.packets import VoiceData + + if isinstance(data, VoiceData): + pcm_data = data.pcm + else: + pcm_data = data + if user not in self.audio_data: file = io.BytesIO() self.audio_data.update({user: AudioData(file)}) - file = self.audio_data[user] - file.write(data) + self.audio_data[user].write(pcm_data) def cleanup(self): self.finished = True diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 1cff9da538..a857ffffc8 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -57,7 +57,7 @@ def format_audio(self, audio): M4ASinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise M4ASinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index c2bbefb923..cb8c2dcd74 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -55,7 +55,7 @@ def format_audio(self, audio): MKASinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise MKASinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 93f4cc7444..0405758a35 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -55,7 +55,7 @@ def format_audio(self, audio): MKVSinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise MKVSinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 74386a2738..a356009a1c 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -55,7 +55,7 @@ def format_audio(self, audio): MP3SinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise MP3SinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index c4d0ed2b63..3158dc5572 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -57,7 +57,7 @@ def format_audio(self, audio): MP4SinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise MP4SinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 7b531464bd..57232cb5c0 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -55,7 +55,7 @@ def format_audio(self, audio): OGGSinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise OGGSinkError( "Audio may only be formatted after recording is finished." ) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index 37f5aac933..b9b53bb34b 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -23,7 +23,9 @@ """ import wave +from io import BytesIO +from ..opus import Decoder as OpusDecoder from .core import Filters, Sink, default_filters from .errors import WaveSinkError @@ -54,16 +56,20 @@ def format_audio(self, audio): WaveSinkError Formatting the audio failed. """ - if self.vc.recording: + if self.recording: raise WaveSinkError( "Audio may only be formatted after recording is finished." ) - data = audio.file + audio.file.seek(0) + pcm_data = audio.file.read() + + data = BytesIO() with wave.open(data, "wb") as f: f.setnchannels(self.vc.decoder.CHANNELS) f.setsampwidth(self.vc.decoder.SAMPLE_SIZE // self.vc.decoder.CHANNELS) f.setframerate(self.vc.decoder.SAMPLING_RATE) - + f.writeframes(pcm_data) data.seek(0) + audio.file = data audio.on_format(self.encoding) diff --git a/discord/utils.py b/discord/utils.py index cc6d9d3b19..2535c9f91d 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -58,6 +58,7 @@ Iterator, Literal, Mapping, + ParamSpec, Protocol, Sequence, TypeVar, @@ -176,6 +177,8 @@ class _RequestLike(Protocol): T = TypeVar("T") T_co = TypeVar("T_co", covariant=True) +_MC_P = ParamSpec("_MC_P") +_MC_T = TypeVar("_MC_T") _Iter = Union[Iterator[T], AsyncIterator[T]] @@ -880,7 +883,11 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return (reset - now).total_seconds() -async def maybe_coroutine(f, *args, **kwargs): +async def maybe_coroutine( + f: Callable[_MC_P, _MC_T | Awaitable[_MC_T]], + *args: _MC_P.args, + **kwargs: _MC_P.kwargs, +) -> _MC_T: value = f(*args, **kwargs) if _isawaitable(value): return await value diff --git a/discord/voice/__init__.py b/discord/voice/__init__.py index 97fd59cf62..cf89e35aa5 100644 --- a/discord/voice/__init__.py +++ b/discord/voice/__init__.py @@ -11,8 +11,8 @@ from ..errors import MissingVoiceDependenciesError from ..utils import get_missing_voice_dependencies -if missing := get_missing_voice_dependencies(): - raise MissingVoiceDependenciesError(missing=missing) +if _missing := get_missing_voice_dependencies(): + raise MissingVoiceDependenciesError(missing=_missing) from ._types import * from .client import * diff --git a/discord/voice/client.py b/discord/voice/client.py index cea26eb1ad..7ff9f6298d 100644 --- a/discord/voice/client.py +++ b/discord/voice/client.py @@ -30,7 +30,11 @@ import logging import struct import warnings -from typing import TYPE_CHECKING, Any, Literal, overload +from collections.abc import Callable, Coroutine +from typing import TYPE_CHECKING, Any, Literal, cast, overload + +import nacl.secret +import nacl.utils from discord import opus from discord.enums import SpeakingState, try_enum @@ -45,13 +49,15 @@ from .enums import OpCodes from .receive import AudioReader from .state import VoiceConnectionState -from .utils.dependencies import HAS_DAVEY, HAS_NACL +from .utils.dependencies import has_nacl -if HAS_NACL: +if has_nacl: import nacl.secret import nacl.utils if TYPE_CHECKING: + from typing import TypeVar + from typing_extensions import ParamSpec from discord import abc @@ -71,13 +77,14 @@ from .receive.reader import AfterCallback P = ParamSpec("P") + T = TypeVar("T") _log = logging.getLogger(__name__) __all__ = ("VoiceClient",) -class VoiceClient(VoiceProtocol): +class VoiceClient(VoiceProtocol["Client"]): """Represents a Discord voice connection. You do not create these, you typically get them from e.g. @@ -107,15 +114,6 @@ def __init__( client: Client, channel: abc.Connectable, ) -> None: - missing = get_missing_voice_dependencies() - if missing: - deps = ", ".join(missing) - raise RuntimeError( - f"{deps} {'library is' if len(missing) == 1 else 'libraries are'} needed " - "in order to use voice related features, " - 'you can run "pip install py-cord[voice]" to install all voice-related ' - "dependencies." - ) super().__init__(client, channel) state = client._connection @@ -136,7 +134,9 @@ def __init__( self._ssrc_to_id: dict[int, int] = {} self._id_to_ssrc: dict[int, int] = {} - self._event_listeners: dict[str, list] = {} + self._event_listeners: dict[ + str, list[Callable[..., Coroutine[Any, Any, Any]]] + ] = {} self._reader: AudioReader = MISSING @staticmethod @@ -156,7 +156,7 @@ def _set_future_result_if_pending( @property def guild(self) -> Guild: """Returns the guild the channel we're connecting to is bound to.""" - channel: VocalGuildChannel = self.channel + channel = cast("VocalGuildChannel", self.channel) return channel.guild @property @@ -259,7 +259,11 @@ async def _recv_hook(self, ws: VoiceWebSocket, msg: dict[str, Any]) -> None: # maybe handle video and such things? async def _run_event( - self, coro, event_name: str, *args: Any, **kwargs: Any + self, + coro: Callable[..., Coroutine[P, None]], + event_name: str, + *args: P.args, + **kwargs: P.kwargs, ) -> None: try: await coro(*args, **kwargs) @@ -269,8 +273,12 @@ async def _run_event( _log.exception("Error calling %s", event_name) def _schedule_event( - self, coro, event_name: str, *args: Any, **kwargs: Any - ) -> asyncio.Task: + self, + coro: Callable[..., Coroutine[Any, Any, T]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> asyncio.Task[T]: wrapped = self._run_event(coro, event_name, *args, **kwargs) return self.client.loop.create_task( wrapped, name=f"voice-receiver-event-dispatch: {event_name}" @@ -433,6 +441,7 @@ def _get_voice_packet(self, data: Any) -> bytes: return encrypt_packet(header, packet) # encryption methods + # nacl is guaranteed to be available here because __init__ raises if missing def _encrypt_xsalsa20_poly1305(self, header: bytes, data: Any) -> bytes: # deprecated @@ -569,8 +578,10 @@ def play( raise ClientException("Not connected to voice") if self.is_playing(): raise ClientException("Already playing audio") - if not isinstance(source, AudioSource): - raise TypeError( + if not isinstance( + source, AudioSource + ): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( # pyright: ignore[reportUnreachable] f"Source must be an AudioSource, not {source.__class__.__name__}", ) if not self.encoder and not source.is_opus(): @@ -636,8 +647,12 @@ def source(self) -> AudioSource | None: @source.setter def source(self, value: AudioSource) -> None: - if not isinstance(value, AudioSource): - raise TypeError(f"expected AudioSource, not {value.__class__.__name__}") + if not isinstance( + value, AudioSource + ): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"expected AudioSource, not {value.__class__.__name__}" + ) # pyright: ignore[reportUnreachable] if self._player is None: raise ValueError("the client is not playing anything") @@ -699,10 +714,6 @@ def start_recording( .. versionadded:: 2.0 - .. warning:: - - Recording may not work as expected due to the new DAVE (End-to-End Encryption) for voice calls. - Parameters ---------- sink: :class:`~.Sink` @@ -732,17 +743,12 @@ def start_recording( TypeError You did not provide a Sink object. """ - warnings.warn( - "Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. " - + "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139", - RuntimeWarning, - stacklevel=2, - ) - # TODO: remove warning in voice-recv fix PR if not self.is_connected(): raise RecordingException("not connected to a voice channel") - if not isinstance(sink, Sink): - raise TypeError(f"expected a Sink object, got {sink.__class__.__name__}") + if not isinstance(sink, Sink): # pyright: ignore[reportUnnecessaryIsInstance] + raise TypeError( + f"expected a Sink object, got {sink.__class__.__name__}" + ) # pyright: ignore[reportUnreachable] if self.is_recording(): raise ClientException("Already recording audio") @@ -770,12 +776,6 @@ def stop_recording(self) -> None: RecordingException You are not recording. """ - warnings.warn( - "Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. " - + "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139", - RuntimeWarning, - stacklevel=2, - ) if self._reader is not MISSING: self._reader.stop() self._reader = MISSING @@ -797,12 +797,6 @@ def is_speaking(self, member: Member | User) -> bool | None: .. versionadded:: 2.7 """ - warnings.warn( - "Voice reception is currently broken due to Discord's DAVE (End-to-End Encryption) protocol. " - + "Follow development progress at https://github.com/Pycord-Development/pycord/issues/3139", - RuntimeWarning, - stacklevel=2, - ) ssrc = self._id_to_ssrc.get(member.id) if ssrc is None: return None diff --git a/discord/voice/gateway.py b/discord/voice/gateway.py index d277917d2f..8ebb8b88c4 100644 --- a/discord/voice/gateway.py +++ b/discord/voice/gateway.py @@ -35,11 +35,7 @@ from typing import TYPE_CHECKING, Any import aiohttp - -from .utils.dependencies import HAS_DAVEY - -if HAS_DAVEY: - import davey +import davey from discord import utils from discord.enums import SpeakingState @@ -60,7 +56,7 @@ class KeepAliveHandler(KeepAliveHandlerBase): if TYPE_CHECKING: - ws: VoiceWebSocket + ws: VoiceWebSocket # pyright: ignore[reportIncompatibleVariableOverride] def __init__( self, @@ -129,24 +125,34 @@ def __init__( self._hook = hook or state.ws_hook # type: ignore @property - def token(self) -> str | None: + def token( + self, + ) -> str | None: # pyright: ignore[reportIncompatibleVariableOverride] return self.state.token @token.setter - def token(self, value: str | None) -> None: + def token( + self, value: str | None + ) -> None: # pyright: ignore[reportIncompatibleVariableOverride] self.state.token = value @property - def session_id(self) -> str | None: + def session_id( + self, + ) -> str | None: # pyright: ignore[reportIncompatibleVariableOverride] return self.state.session_id @session_id.setter - def session_id(self, value: str | None) -> None: + def session_id( + self, value: str | None + ) -> None: # pyright: ignore[reportIncompatibleVariableOverride] self.state.session_id = value @property def self_id(self) -> int: - return self._connection.self_id + self_id = self._connection.self_id + assert self_id is not None + return self_id async def _hook(self, *args: Any) -> Any: pass @@ -178,7 +184,7 @@ async def resume(self) -> None: } await self.send_as_json(payload) - async def received_message(self, msg: Any, /): + async def received_message(self, msg: Any, /) -> None: _log.debug("Voice websocket frame received: %s", msg) op = msg["op"] data = msg.get("d", {}) # this key should ALWAYS be given, but guard anyways @@ -206,9 +212,11 @@ async def received_message(self, msg: Any, /): await state.reinit_dave_session() elif op == OpCodes.hello: interval = data["heartbeat_interval"] / 1000.0 - self._keep_alive = KeepAliveHandler( - ws=self, - interval=min(interval, 5), + self._keep_alive = ( + KeepAliveHandler( # pyright: ignore[reportIncompatibleVariableOverride] + ws=self, + interval=min(interval, 5), + ) ) self._keep_alive.start() elif state.dave_session: diff --git a/discord/voice/packets/core.py b/discord/voice/packets/core.py index c38fc9face..9b45fea2ad 100644 --- a/discord/voice/packets/core.py +++ b/discord/voice/packets/core.py @@ -46,8 +46,8 @@ class Packet: ssrc: int sequence: int timestamp: int - type: int - decrypted_data: bytes + type: int | None + decrypted_data: bytes | None def __init__(self, data: bytes) -> None: self.data: bytes = data diff --git a/discord/voice/packets/rtp.py b/discord/voice/packets/rtp.py index dd9520cf4e..c7417ae255 100644 --- a/discord/voice/packets/rtp.py +++ b/discord/voice/packets/rtp.py @@ -26,14 +26,10 @@ from __future__ import annotations import struct -from collections import namedtuple -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, NamedTuple from .core import OPUS_SILENCE, Packet -if TYPE_CHECKING: - from typing_extensions import Final - MAX_UINT_32 = 0xFFFFFFFF MAX_UINT_16 = 0xFFFF @@ -48,8 +44,8 @@ def decode(data: bytes) -> Packet: class FakePacket(Packet): data = b"" - decrypted_data: bytes = b"" - extension_data: dict = {} + decrypted_data: bytes | None = b"" + extension_data: dict[int, bytes] = {} def __init__( self, @@ -66,8 +62,8 @@ def __bool__(self) -> Literal[False]: class SilencePacket(Packet): - decrypted_data: Final = OPUS_SILENCE - extension_data: Final[dict[int, Any]] = {} + decrypted_data: bytes | None = OPUS_SILENCE + extension_data: dict[int, Any] = {} sequence: int = -1 def __init__(self, ssrc: int, timestamp: int) -> None: @@ -90,7 +86,12 @@ class RTPPacket(Packet): """ _hstruct = struct.Struct(">xxHII") - _ext_header = namedtuple("Extension", "profile length values") + + class _ext_header(NamedTuple): + profile: bytes + length: int + values: tuple[int, ...] | list[int] + _ext_magic = b"\xbe\xde" def __init__(self, data: bytes) -> None: @@ -210,7 +211,7 @@ def __repr__(self) -> str: class RTCPPacket(Packet): _header = struct.Struct(">BBH") _ssrc_fmt = struct.Struct(">I") - type = None + type: int | None = None def __init__(self, data: bytes) -> None: super().__init__(data) @@ -234,18 +235,26 @@ def _parse_low(x: int, bitlen: int = 32) -> float: return x / 2.0**bitlen -def _to_low(x: float, bitlen: int = 32) -> int: - return int(x * 2.0**bitlen) - - class SenderReportPacket(RTCPPacket): _info_fmt = struct.Struct(">5I") _report_fmt = struct.Struct(">IB3x4I") _24bit_int_fmt = struct.Struct(">4xI") - _info = namedtuple("RRSenderInfo", "ntp_ts rtp_ts packet_count octet_count") - _report = namedtuple( - "RReport", "ssrc perc_loss total_lost last_seq jitter lsr dlsr" - ) + + class _info(NamedTuple): + ntp_ts: float + rtp_ts: int + packet_count: int + octet_count: int + + class _report(NamedTuple): + ssrc: int + perc_loss: int + total_lost: int + last_seq: int + jitter: int + lsr: int + dlsr: int + type = 200 if TYPE_CHECKING: @@ -257,13 +266,12 @@ def __init__(self, data: bytes) -> None: self.ssrc = self._ssrc_fmt.unpack_from(data, 4)[0] self.info = self._read_sender_info(data, 8) - _report = self._report - reports: list[_report] = [] + reports: list[SenderReportPacket._report] = [] for x in range(self.report_count): offset = 28 + 24 * x reports.append(self._read_report(data, offset)) - self.reports: tuple[_report, ...] = tuple(reports) + self.reports: tuple[SenderReportPacket._report, ...] = tuple(reports) self.extension = None if len(data) > 28 + 24 * self.report_count: self.extension = data[28 + 24 * self.report_count :] @@ -282,12 +290,17 @@ def _read_report(self, data: bytes, offset: int) -> _report: class ReceiverReportPacket(RTCPPacket): _report_fmt = struct.Struct(">IB3x4I") _24bit_int_fmt = struct.Struct(">4xI") - _report = namedtuple( - "RReport", "ssrc perc_loss total_loss last_seq jitter lsr dlsr" - ) - type = 201 - reports: tuple[_report, ...] + class _report(NamedTuple): + ssrc: int + perc_loss: int + total_loss: int + last_seq: int + jitter: int + lsr: int + dlsr: int + + type = 201 if TYPE_CHECKING: report_count: int @@ -296,13 +309,12 @@ def __init__(self, data: bytes) -> None: super().__init__(data) self.ssrc: int = self._ssrc_fmt.unpack_from(data, 4)[0] - _report = self._report - reports: list[_report] = [] + reports: list[ReceiverReportPacket._report] = [] for x in range(self.report_count): offset = 8 + 24 * x reports.append(self._read_report(data, offset)) - self.reports = tuple(reports) + self.reports: tuple[ReceiverReportPacket._report, ...] = tuple(reports) self.extension: bytes | None = None if len(data) > 8 + 24 * self.report_count: diff --git a/discord/voice/receive/reader.py b/discord/voice/receive/reader.py index 7ec0300c66..afc38ee9e8 100644 --- a/discord/voice/receive/reader.py +++ b/discord/voice/receive/reader.py @@ -32,19 +32,14 @@ from operator import itemgetter from typing import TYPE_CHECKING, Any, Literal +import davey +import nacl.secret +from nacl.exceptions import CryptoError + from ..packets.core import OPUS_SILENCE from ..packets.rtp import ReceiverReportPacket, RTCPPacket, decode -from ..utils.dependencies import HAS_DAVEY, HAS_NACL from .router import PacketRouter, SinkEventRouter -if HAS_DAVEY: - import davey - -if HAS_NACL: - import nacl.secret - from nacl.exceptions import CryptoError - - if TYPE_CHECKING: from discord.member import Member from discord.sinks import Sink @@ -77,8 +72,10 @@ def __init__( after: AfterCallback | None = None, start: bool = False, ) -> None: - if after is not None and not callable(after): - raise TypeError( + if after is not None and not callable( + after + ): # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( # pyright: ignore[reportUnreachable] f"expected a callable for the 'after' parameter, got {after.__class__.__name__!r} instead" ) @@ -86,7 +83,7 @@ def __init__( self.client: VoiceClient = client self.after: AfterCallback | None = after - # self.sink._client = client + self.sink.init(client) self.active: bool = False self.error: Exception | None = None @@ -124,15 +121,17 @@ def stop(self) -> None: _log.debug("Reader is not active") return + self.active = False self.client._connection.remove_socket_listener(self.callback) self.speaking_timer.notify() self._stop() - self.active = False def _stop(self) -> None: try: if self.packet_router.is_alive(): self.packet_router.stop() + if threading.current_thread() is not self.packet_router: + self.packet_router.join(timeout=5) except Exception as exc: self.error = exc _log.exception("An error ocurred while stopping packet router.") @@ -154,16 +153,13 @@ def _stop(self) -> None: "An error ocurred while calling the after callback on audio reader" ) - """for sink in self.sink.root.walk_children(with_self=True): - try: - sink.cleanup() - except Exception as exc: - _log.exception("Error calling cleanup() for %s", sink, exc_info=exc)""" + try: + self.sink.cleanup() + except Exception as exc: + _log.exception("Error calling cleanup() for %s", self.sink, exc_info=exc) def set_sink(self, sink: Sink) -> Sink: old_sink = self.sink - # old_sink._client = None - # sink._client = self.client self.packet_router.set_sink(sink) self.sink = sink return old_sink @@ -271,22 +267,6 @@ def _make_box(self, secret_key: bytes) -> EncryptionBox: else: return nacl.secret.SecretBox(secret_key) - """def decrypt_rtp(self, packet: RTPPacket) -> bytes: - state = self.client._connection - dave = state.dave_session - data = self._decryptor_rtp(packet) - - if dave is not None and dave.ready and packet.ssrc in state.ssrc_user_map: - data = dave.decrypt( - state.ssrc_user_map[packet.ssrc], davey.MediaType.audio, data - ) - - if packet.extended: - offset = packet.update_extended_header(data) - data = data[offset:] - - return data""" - def decrypt_rtp(self, packet: RTPPacket) -> bytes: state = self.client._connection dave = state.dave_session @@ -295,26 +275,55 @@ def decrypt_rtp(self, packet: RTPPacket) -> bytes: if dave is not None and dave.ready: uid = state.ssrc_user_map.get(packet.ssrc) - if uid: + + if not uid: + # SSRC -> user_id mapping not yet populated (race with member_connect). + # Try every user ID known to the DAVE session until one decrypts. + # This is ugly but I didn't manage to get it to work otherwise. If you have a better implementation, + # please open a PR. + for candidate_uid in dave.get_user_ids(): + try: + int_uid = int(candidate_uid) + decrypted_audio = dave.decrypt( + int_uid, + davey.MediaType.audio, + raw_payload, + ) + # Successfully decrypted - cache the mapping for next time + self.client._connection.user_ssrc_map[int_uid] = packet.ssrc + uid = int_uid + raw_payload = decrypted_audio + _log.debug( + "DAVE: inferred ssrc %s -> user_id %s from decryption", + packet.ssrc, + uid, + ) + break + except ValueError: + continue + else: + raw_payload = OPUS_SILENCE + else: try: - decrypted_audio = dave.decrypt( + raw_payload = dave.decrypt( uid, davey.MediaType.audio, raw_payload, ) - - if packet.extended: - offset = packet.update_extended_header(decrypted_audio) - packet.decrypted_data = decrypted_audio[offset:] - else: - packet.decrypted_data = decrypted_audio - except Exception as exc: + except ValueError: + # UnencryptedWhenPassthroughDisabled here is actually misleading, we can't passthrough, + # it gives a corrupted stream. _log.debug( - "Ignoring exception while decoding DAVE packet", exc_info=exc + "DAVE: Decryption failed, falling back to OPUS_SILENCE", + exc_info=True, ) - packet.decrypted_data = OPUS_SILENCE + raw_payload = OPUS_SILENCE - return packet.decrypted_data + packet.decrypted_data = raw_payload + else: # e.g., stage channels + packet.decrypted_data = raw_payload + + return packet.decrypted_data or b"" def decrypt_rtcp(self, packet: bytes) -> bytes: data = self._decryptor_rtcp(packet) @@ -425,9 +434,10 @@ def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> byt raise CryptoError(exc) if packet.extended: - packet.update_extended_header(result) + offset = packet.update_extended_header(result) + return result[offset:] - return result[8:] + return result def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes: _log.debug("Decrypting RTCP AEAD XChaCha20 Poly1305 RTPSize") diff --git a/discord/voice/receive/router.py b/discord/voice/receive/router.py index b7f2f28064..f73c2e51b0 100644 --- a/discord/voice/receive/router.py +++ b/discord/voice/receive/router.py @@ -34,6 +34,7 @@ from discord.opus import PacketDecoder +from ...sinks.errors import RecordingException from ..utils.multidataevent import MultiDataEvent if TYPE_CHECKING: @@ -121,7 +122,10 @@ def run(self) -> None: _log.exception("Error in %s loop", self) self.reader.error = exc finally: - self.reader.client.stop_recording() + try: + self.reader.client.stop_recording() + except RecordingException: + pass self.waiter.clear() def _do_run(self) -> None: diff --git a/discord/voice/state.py b/discord/voice/state.py index 7f659899ad..742825b9d9 100644 --- a/discord/voice/state.py +++ b/discord/voice/state.py @@ -33,10 +33,12 @@ from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any +import davey + from discord import utils from discord.backoff import ExponentialBackoff from discord.errors import ConnectionClosed -from discord.voice.utils.dependencies import DAVE_PROTOCOL_VERSION, HAS_DAVEY +from discord.voice.utils.dependencies import dave_protocol_version from .enums import ConnectionFlowState, OpCodes from .gateway import VoiceWebSocket @@ -57,9 +59,6 @@ _log = logging.getLogger(__name__) _recv_log = logging.getLogger("discord.voice.receiver") -if HAS_DAVEY: - import davey - class SocketReader(threading.Thread): def __init__( @@ -261,7 +260,7 @@ def __init__( self.recording_done_callbacks: list[ tuple[Callable[..., Coroutine[Any, Any, Any]], tuple[Any, ...]] ] = [] - self._dispatch_task_set: set[asyncio.Task] = set() + self._dispatch_task_set: set[asyncio.Task[None]] = set() if not self._connection.self_id: raise RuntimeError("client self ID is not set") @@ -283,7 +282,7 @@ def ssrc_user_map(self) -> dict[int, int]: @property def max_dave_proto_version(self) -> int: - return DAVE_PROTOCOL_VERSION + return dave_protocol_version @property def state(self) -> ConnectionFlowState: @@ -312,8 +311,8 @@ def user(self) -> ClientUser: return self.client.user @property - def channel_id(self) -> int | None: - return self.client.channel is not None and self.client.channel.id + def channel_id(self) -> int: + return self.client.channel.id @property def guild_id(self) -> int: @@ -393,7 +392,7 @@ async def voice_server_update(self, data: RawVoiceServerUpdateEvent) -> None: self.server_id = data.guild_id endpoint = data.endpoint - if self.token is None or endpoint is None: + if endpoint is None: _log.warning( "Awaiting endpoint... This requires waiting. " "If timeout occurred considering raising the timeout and reconnecting." @@ -713,8 +712,10 @@ async def _voice_connect( self, *, self_deaf: bool = False, self_mute: bool = False ) -> None: channel = self.client.channel - await channel.guild.change_voice_state( - channel=channel, self_deaf=self_deaf, self_mute=self_mute + await self.guild.change_voice_state( + channel=channel, + self_deaf=self_deaf, + self_mute=self_mute, # pyright: ignore[reportArgumentType] ) async def _voice_disconnect(self) -> None: @@ -725,7 +726,7 @@ async def _voice_disconnect(self) -> None: ) self.state = ConnectionFlowState.disconnected - await self.client.channel.guild.change_voice_state( + await self.guild.change_voice_state( channel=None ) # pyright: ignore[reportAttributeAccessIssue] self._expecting_disconnect = True @@ -901,9 +902,9 @@ async def _potential_reconnect(self) -> bool: await previous_ws.close() async def _move_to(self, channel: abc.Snowflake) -> None: - await self.client.channel.guild.change_voice_state( - channel=channel - ) # pyright: ignore[reportAttributeAccessIssue] + await self.guild.change_voice_state( + channel=channel # pyright: ignore[reportArgumentType] + ) self.state = ConnectionFlowState.set_guild_voice_state def _update_voice_channel(self, channel_id: int | None) -> None: diff --git a/discord/voice/utils/buffer.py b/discord/voice/utils/buffer.py index f12fb0ae6e..fd59eecd08 100644 --- a/discord/voice/utils/buffer.py +++ b/discord/voice/utils/buffer.py @@ -28,57 +28,16 @@ import heapq import logging import threading -from typing import Protocol, TypeVar from ..packets import Packet from .wrapped import add_wrapped, gap_wrapped -__all__ = ( - "Buffer", - "JitterBuffer", -) +__all__ = ("JitterBuffer",) - -T = TypeVar("T") -PacketT = TypeVar("PacketT", bound=Packet) _log = logging.getLogger(__name__) -class Buffer(Protocol[T]): - def __len__(self) -> int: ... - def push(self, item: T) -> None: ... - def pop(self) -> T | None: ... - def peek(self) -> T | None: ... - def flush(self) -> list[T]: ... - def reset(self) -> None: ... - - -class BaseBuff(Buffer[PacketT]): - def __init__(self) -> None: - self._buffer: list[PacketT] = [] - - def __len__(self) -> int: - return len(self._buffer) - - def push(self, item: PacketT) -> None: - self._buffer.append(item) - - def pop(self) -> PacketT | None: - return self._buffer.pop() - - def peek(self) -> PacketT | None: - return self._buffer[-1] if self._buffer else None - - def flush(self) -> list[PacketT]: - buf = self._buffer.copy() - self._buffer.clear() - return buf - - def reset(self) -> None: - self._buffer.clear() - - -class JitterBuffer(BaseBuff[PacketT]): +class JitterBuffer: _threshold: int = 10000 def __init__( @@ -96,9 +55,11 @@ def __init__( self._prefill: int = prefill self._last_tx_seq: int = -1 self._has_item: threading.Event = threading.Event() - # self._lock: threading.Lock = threading.Lock() self._buffer: list[Packet] = [] + def __len__(self) -> int: + return len(self._buffer) + def _push(self, packet: Packet) -> None: heapq.heappush(self._buffer, packet) diff --git a/discord/voice/utils/dependencies.py b/discord/voice/utils/dependencies.py index e3a8c3f83d..b373c84ffd 100644 --- a/discord/voice/utils/dependencies.py +++ b/discord/voice/utils/dependencies.py @@ -24,17 +24,19 @@ try: import davey + + _ = davey.DAVE_PROTOCOL_VERSION except ImportError: - HAS_DAVEY = False - DAVE_PROTOCOL_VERSION = 0 + has_davey = False + dave_protocol_version = 0 else: - HAS_DAVEY = True - DAVE_PROTOCOL_VERSION = davey.DAVE_PROTOCOL_VERSION + has_davey = True + dave_protocol_version = davey.DAVE_PROTOCOL_VERSION try: - import nacl.secret - import nacl.utils + import nacl.secret # pyright: ignore[reportUnusedImport] + import nacl.utils # pyright: ignore[reportUnusedImport] except ImportError: - HAS_NACL = False + has_nacl = False else: - HAS_NACL = True + has_nacl = True diff --git a/discord/voice/utils/multidataevent.py b/discord/voice/utils/multidataevent.py index e0079b5a54..b266c2b304 100644 --- a/discord/voice/utils/multidataevent.py +++ b/discord/voice/utils/multidataevent.py @@ -37,7 +37,7 @@ class MultiDataEvent(Generic[T]): with accompanying data object for convenience. """ - def __init__(self): + def __init__(self) -> None: self._items: list[T] = [] self._ready: threading.Event = threading.Event()