Skip to content
111 changes: 87 additions & 24 deletions discord/voice/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ async def received_binary_message(self, msg: bytes) -> None:
)
elif op == OpCodes.mls_proposals:
op_type = msg[3]
epoch_before = state.dave_session.epoch
result = state.dave_session.process_proposals(
(
davey.ProposalsOperationType.append
Expand All @@ -282,6 +283,13 @@ async def received_binary_message(self, msg: bytes) -> None:
),
msg[4:],
)
_log.info(
"process_proposals done — epoch %s→%s ready=%s result=%s",
epoch_before,
state.dave_session.epoch,
state.dave_session.ready,
type(result).__name__,
)

if isinstance(result, davey.CommitWelcome):
data = (
Expand All @@ -294,49 +302,104 @@ async def received_binary_message(self, msg: bytes) -> None:
OpCodes.mls_commit_welcome,
data,
)
# Apply our own commit immediately so we use the same epoch key
# material that Discord will forward to other participants.
# This avoids the mismatch when Discord sends us mls_welcome (op30)
# for a different group context.
try:
state.dave_session.process_commit(result.commit)
auth = state.dave_session.get_epoch_authenticator()
_log.info(
"Self-applied CommitWelcome.commit — epoch=%s ready=%s user_ids=%s privacy_code=%s epoch_auth=%s",
state.dave_session.epoch,
state.dave_session.ready,
state.dave_session.get_user_ids(),
state.dave_session.voice_privacy_code,
auth.hex() if auth else None,
)
except Exception as exc:
_log.warning("Self-commit failed (non-fatal): %s", exc)
_log.debug("Processed MLS proposals for current dave session: %r", result)
elif op == OpCodes.mls_commit_transition:
transt_id = struct.unpack_from(">H", msg, 3)[0]
try:
state.dave_session.process_commit(msg[5:])
# If session is already ready (self-commit was applied), skip re-processing.
if state.dave_session.ready:
_log.info(
"mls_commit_transition (transition %s) skipped — session already ready epoch=%s",
transt_id,
state.dave_session.epoch,
)
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
await self.send_dave_transition_ready(transt_id)
else:
try:
state.dave_session.process_commit(msg[5:])
auth = state.dave_session.get_epoch_authenticator()
_log.info(
"MLS commit processed (transition %s) — dave.ready=%s epoch=%s user_ids=%s privacy_code=%s epoch_auth=%s",
transt_id,
state.dave_session.ready,
state.dave_session.epoch,
state.dave_session.get_user_ids(),
state.dave_session.voice_privacy_code,
auth.hex() if auth else None,
)
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
_log.debug(
"Sending DAVE transition ready from MLS commit transition with data: %s",
state.dave_pending_transition,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS commit for transition %s", transt_id)
except Exception as exc:
_log.debug(
"Sending DAVE transition ready from MLS commit transition with data: %s",
state.dave_pending_transition,
"An exception ocurred while processing a MLS commit, this should be safe to ignore: %s",
exc,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS commit for transition %s", transt_id)
except Exception as exc:
_log.debug(
"An exception ocurred while processing a MLS commit, this should be safe to ignore: %s",
exc,
)
await state.recover_dave_from_invalid_commit(transt_id)
await state.recover_dave_from_invalid_commit(transt_id)
elif op == OpCodes.mls_welcome:
transt_id = struct.unpack_from(">H", msg, 3)[0]
try:
state.dave_session.process_welcome(msg[5:])
# If session is already ready (self-commit was applied), skip re-processing.
if state.dave_session.ready:
_log.info(
"mls_welcome (transition %s) skipped — session already ready epoch=%s",
transt_id,
state.dave_session.epoch,
)
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
await self.send_dave_transition_ready(transt_id)
else:
try:
state.dave_session.process_welcome(msg[5:])
if transt_id != 0:
state.dave_pending_transition = {
"transition_id": transt_id,
"protocol_version": state.dave_protocol_version,
}
_log.debug(
"Sending DAVE transition ready from MLS welcome with data: %s",
state.dave_pending_transition,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS welcome for transition %s", transt_id)
except Exception as exc:
_log.debug(
"Sending DAVE transition ready from MLS welcome with data: %s",
state.dave_pending_transition,
"An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s",
exc,
)
await self.send_dave_transition_ready(transt_id)
_log.debug("Processed MLS welcome for transition %s", transt_id)
except Exception as exc:
_log.debug(
"An exception ocurred while processing a MLS welcome, this should be safe to ignore: %s",
exc,
)
await state.recover_dave_from_invalid_commit(transt_id)
await state.recover_dave_from_invalid_commit(transt_id)

async def ready(self, data: dict[str, Any]) -> None:
state = self.state
Expand Down
150 changes: 133 additions & 17 deletions discord/voice/receive/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def __init__(
self.client: VoiceClient = client
self.after: AfterCallback | None = after

# self.sink._client = client
self.sink.init(client)

self.active: bool = False
self.error: Exception | None = None
Expand Down Expand Up @@ -287,32 +287,144 @@ def _make_box(self, secret_key: bytes) -> EncryptionBox:

return data"""

# Per-SSRC counters used to suppress repetitive log lines.
_dave_success: dict[int, int] = {}
_dave_consecutive_failures: dict[int, int] = {}
_dave_seen_generations: dict[int, set] = {}

@staticmethod
def _parse_dave_generation(data: bytes) -> int:
"""Return the key generation encoded in a DAVE supplemental block, or -1.

DAVE frame layout (from end of payload):
[...ciphertext][auth_tag(8B)][nonce(LEB128)][supp_size(1B)][0xFAFA(2B)]
supp_size counts the entire trailing block including itself and the magic.
"""
if len(data) < 12:
return -1
if data[-2:] != b"\xfa\xfa":
return -1
supp_size = data[-3]
if supp_size < 11 or supp_size > len(data):
return -1
block_start = len(data) - supp_size
nonce_pos = block_start + 8 # skip auth_tag (8B)
nonce_end = len(data) - 3 # position of supp_size byte
nonce = 0
shift = 0
for i in range(nonce_pos, nonce_end):
b = data[i]
nonce |= (b & 0x7F) << shift
shift += 7
if not (b & 0x80):
break
return (nonce >> 24) & 0xFF

def decrypt_rtp(self, packet: RTPPacket) -> bytes:
state = self.client._connection
dave = state.dave_session

raw_payload = self._decryptor_rtp(packet)

# For extended RTP packets (which Discord always sends for audio),
# _decryptor_rtp already strips the RTP extension values so that
# davey.decrypt() receives only the DAVE frame. For non-extended
# packets fall back to the full outer-decrypted buffer.
if packet.extended:
dave_input = raw_payload
else:
dave_input = getattr(packet, "_outer_decrypted", raw_payload)

if dave is not None and dave.ready:
uid = state.ssrc_user_map.get(packet.ssrc)
if uid:
try:
decrypted_audio = dave.decrypt(
uid,
davey.MediaType.audio,
raw_payload,
dave_input,
)

if packet.extended:
offset = packet.update_extended_header(decrypted_audio)
packet.decrypted_data = decrypted_audio[offset:]
else:
packet.decrypted_data = decrypted_audio
success_count = self._dave_success.get(packet.ssrc, 0) + 1
self._dave_success[packet.ssrc] = success_count
prev_fails = self._dave_consecutive_failures.get(packet.ssrc, 0)
self._dave_consecutive_failures[packet.ssrc] = 0

if success_count == 1:
_log.debug(
"DAVE decrypt active ssrc=%s uid=%s", packet.ssrc, uid
)
elif prev_fails > 0:
_log.info(
"DAVE decrypt recovered ssrc=%s uid=%s after %d frame(s)",
packet.ssrc,
uid,
prev_fails,
)

# DAVE output is pure Opus — do NOT call update_extended_header;
# it would misinterpret Opus bytes as RTP extension values.
packet.decrypted_data = decrypted_audio

except Exception as exc:
_log.debug(
"Ignoring exception while decoding DAVE packet", exc_info=exc
)
packet.decrypted_data = OPUS_SILENCE
consec = self._dave_consecutive_failures.get(packet.ssrc, 0) + 1
self._dave_consecutive_failures[packet.ssrc] = consec
gen = self._parse_dave_generation(dave_input)
seen = self._dave_seen_generations.setdefault(packet.ssrc, set())

# Log on the first failure in a burst or when a new generation appears.
if consec == 1 or gen not in seen:
_log.warning(
"DAVE decrypt failed ssrc=%s uid=%s frame_gen=%s epoch=%s err=%s",
packet.ssrc,
uid,
gen,
dave.epoch,
type(exc).__name__,
)
seen.add(gen)

if "UnencryptedWhenPassthroughDisabled" in str(exc):
# Discord sends passthrough (unencrypted) frames even while DAVE
# is active. These carry raw Opus wrapped in a small DAVE
# supplemental block with optional RTP padding appended:
#
# [raw_opus][supp_block(supp_size B)][rtp_padding]
#
# supp_block ends with supp_size(1B) + 0xFAFA(2B); supp_size
# counts the whole block including itself and the magic bytes.
# RTP padding (RFC 3550): last byte = N, strip N bytes from end.
opus_data = raw_payload
if packet.padding and opus_data:
pad_n = opus_data[-1]
if 0 < pad_n < len(opus_data):
opus_data = opus_data[:-pad_n]
if len(opus_data) >= 3 and opus_data[-2:] == b"\xfa\xfa":
supp_size = opus_data[-3]
if 3 <= supp_size < len(opus_data):
opus_data = opus_data[:-supp_size]
packet.decrypted_data = (
opus_data if len(opus_data) >= 3 else OPUS_SILENCE
)
else:
packet.decrypted_data = OPUS_SILENCE
else:
packet.decrypted_data = OPUS_SILENCE
else:
packet.decrypted_data = OPUS_SILENCE

if packet.decrypted_data is None:
if dave is None:
# Non-DAVE mode: outer-decrypted bytes ARE the Opus payload.
if packet.extended:
offset = packet.update_extended_header(raw_payload)
packet.decrypted_data = raw_payload[offset:]
else:
packet.decrypted_data = raw_payload
else:
# DAVE session not ready yet or SSRC not yet mapped — use Opus
# silence to avoid feeding ciphertext to the Opus decoder.
packet.decrypted_data = OPUS_SILENCE

return packet.decrypted_data

Expand Down Expand Up @@ -405,10 +517,6 @@ def _decrypt_rtcp_xsalsa20_poly1305_lite(self, data: bytes) -> bytes:
return header + result

def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> bytes:
_log.debug(
"Decrypting RTP AEAD XChaCha20 Poly1305 RTPSize, has decrypted data?: %s",
packet.decrypted_data is not None,
)
packet.adjust_rtpsize()
nonce = packet.nonce + b"\x00" * 20

Expand All @@ -424,10 +532,18 @@ def _decrypt_rtp_aead_xchacha20_poly1305_rtpsize(self, packet: RTPPacket) -> byt
_log.error("Critical error at AEAD: %s", exc)
raise CryptoError(exc)

# update_extended_header returns the actual payload offset into result.
# For Discord DAVE frames the extension has length=2 (8 bytes) → offset=8.
# For passthrough/unencrypted frames the extension has length=1 (4 bytes)
# → offset=4. Hardcoding result[8:] would strip 4 bytes too many for
# passthrough frames and hand invalid bytes to davey / the Opus decoder.
if packet.extended:
packet.update_extended_header(result)
offset = packet.update_extended_header(result)
else:
offset = 0

return result[8:]
packet._outer_decrypted = result
return result[offset:]

def _decrypt_rtcp_aead_xchacha20_poly1305_rtpsize(self, data: bytes) -> bytes:
_log.debug("Decrypting RTCP AEAD XChaCha20 Poly1305 RTPSize")
Expand Down