From 39965abfa7806ae091fcec8e87a385d992f12e76 Mon Sep 17 00:00:00 2001 From: Rob Hague Date: Mon, 27 Apr 2026 22:02:55 +0200 Subject: [PATCH] Remove plaintext receive buffer #1752 added a persistent buffer into which to decrypt packets, rather than allocating a new array for each packet. This was on the back of #1733 which added support in the cipher types for decrypting into a given buffer, but for the case of AES-CTR, not into the same buffer in-place. #1787 adds that missing support, meaning we can now decrypt in-place, and the plaintext buffer becomes unnecessary. --- .../Ciphers/ChaCha20Poly1305Cipher.cs | 25 ++-- src/Renci.SshNet/Session.cs | 135 ++++++++++-------- 2 files changed, 84 insertions(+), 76 deletions(-) diff --git a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs index 9ce4f53d0..a87bf047b 100644 --- a/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs +++ b/src/Renci.SshNet/Security/Cryptography/Ciphers/ChaCha20Poly1305Cipher.cs @@ -138,23 +138,28 @@ public override int Encrypt(byte[] input, int offset, int length, byte[] output, /// The decrypted plaintext. public override byte[] Decrypt(byte[] input, int offset, int length) { - byte[] output; + var output = new byte[length]; + + _cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv)); + + var keyStream = new byte[64]; + _cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0); + _mac.Init(new KeyParameter(keyStream, 0, 32)); if (_aadLength > 0) { // If we are in 'AAD mode', then put these bytes through the AAD cipher. + _mac.BlockUpdate(input, offset, length); + Debug.Assert(_aadCipher != null); _aadCipher.Init(forEncryption: false, new ParametersWithIV(_aadKeyParameter, _iv)); - output = new byte[length]; _aadCipher.ProcessBytes(input, offset, length, output, 0); } else { - output = new byte[length]; - var bytesWritten = Decrypt(input, offset, length, output, 0); Debug.Assert(bytesWritten == length); @@ -169,7 +174,7 @@ public override byte[] Decrypt(byte[] input, int offset, int length) /// /// The input data with below format: /// - /// [----][----Cipher AAD----(offset)][----Cipher Text----(length)][----TAG----] + /// [----(offset)][----Cipher Text----(length)][----TAG----] /// /// /// The zero-based offset in at which to begin decrypting and authenticating. @@ -179,16 +184,8 @@ public override byte[] Decrypt(byte[] input, int offset, int length) /// The number of plaintext bytes written to . public override int Decrypt(byte[] input, int offset, int length, byte[] output, int outputOffset) { - Debug.Assert(offset >= _aadLength, "The offset must be greater than or equals to aad length"); - - _cipher.Init(forEncryption: false, new ParametersWithIV(_keyParameter, _iv)); - - var keyStream = new byte[64]; - _cipher.ProcessBytes(keyStream, 0, keyStream.Length, keyStream, 0); - _mac.Init(new KeyParameter(keyStream, 0, 32)); - var tag = new byte[TagSize]; - _mac.BlockUpdate(input, offset - _aadLength, length + _aadLength); + _mac.BlockUpdate(input, offset, length); _ = _mac.DoFinal(tag, 0); if (!Arrays.FixedTimeEquals(TagSize, tag, 0, input, offset + length)) { diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index a6576d9af..c73c3db26 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -105,6 +105,23 @@ public sealed class Session : ISession /// private readonly SemaphoreSlim _connectLock = new SemaphoreSlim(1, 1); + private readonly byte[] _inboundPacketSequenceBytes = new byte[4]; + + /// + /// Gets or sets the incoming packet number. + /// + private uint InboundPacketSequence + { + get + { + return BinaryPrimitives.ReadUInt32BigEndian(_inboundPacketSequenceBytes); + } + set + { + BinaryPrimitives.WriteUInt32BigEndian(_inboundPacketSequenceBytes, value); + } + } + /// /// Holds metadata about session messages. /// @@ -120,11 +137,6 @@ public sealed class Session : ISession /// private volatile uint _outboundPacketSequence; - /// - /// Specifies incoming packet number. - /// - private uint _inboundPacketSequence; - /// /// WaitHandle to signal that last service request was accepted. /// @@ -200,7 +212,6 @@ public sealed class Session : ISession private Socket _socket; private ArrayBuffer _receiveBuffer = new(4 * 1024); - private byte[] _plaintextReceiveBuffer = new byte[4 * 1024]; /// /// Gets the session semaphore that controls session channels. @@ -1213,9 +1224,6 @@ private bool TrySendMessage(Message message) /// private Message ReceiveMessage(Socket socket) { - // the length of the packet sequence field in bytes - const int inboundPacketSequenceLength = 4; - // The length of the "packet length" field in bytes const int packetLengthFieldLength = 4; @@ -1272,31 +1280,28 @@ private Message ReceiveMessage(Socket socket) } } - var firstBlock = new ArraySegment( - _receiveBuffer.DangerousGetUnderlyingBuffer(), - _receiveBuffer.ActiveStartOffset, - blockSize); - - var plainFirstBlock = firstBlock; - - // For ETM or AES-GCM, firstBlock holds the packet length which is - // not encrypted. Otherwise, we decrypt the first "blockSize" bytes. - // (For chacha20-poly1305, this means passing the encrypted packet - // length as AAD). + // For ETM or AES-GCM, the first "blockSize" bytes hold the packet length + // which is not encrypted. Otherwise, we decrypt them. + // (For chacha20-poly1305, this means passing the encrypted packet length + // to its AAD cipher instance - it is the awkward difference between the + // 3-arg and 5-arg Decrypt, and explains why we don't just decrypt these + // bytes in-place). if (_serverCipher is not null and not Security.Cryptography.Ciphers.AesGcmCipher) { - _serverCipher.SetSequenceNumber(_inboundPacketSequence); + _serverCipher.SetSequenceNumber(InboundPacketSequence); if (_serverMac == null || !_serverEtm) { - plainFirstBlock = new ArraySegment(_serverCipher.Decrypt( - firstBlock.Array, - firstBlock.Offset, - firstBlock.Count)); + var plainFirstBlock = _serverCipher.Decrypt( + _receiveBuffer.DangerousGetUnderlyingBuffer(), + _receiveBuffer.ActiveStartOffset, + blockSize); + + plainFirstBlock.CopyTo(_receiveBuffer.ActiveSpan); } } - var packetLength = BinaryPrimitives.ReadInt32BigEndian(plainFirstBlock); + var packetLength = BinaryPrimitives.ReadInt32BigEndian(_receiveBuffer.ActiveReadOnlySpan); // Test packet minimum and maximum boundaries if (packetLength < Math.Max((byte)8, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4) @@ -1330,26 +1335,13 @@ private Message ReceiveMessage(Socket socket) } } - // Construct buffer for holding the payload and the inbound packet sequence as we need both in order - // to generate the hash. - var plaintextLength = 4 + totalPacketLength - serverMacLength; - - if (_plaintextReceiveBuffer.Length < plaintextLength) - { - Array.Resize(ref _plaintextReceiveBuffer, Math.Max(plaintextLength, 2 * _plaintextReceiveBuffer.Length)); - } - - BinaryPrimitives.WriteUInt32BigEndian(_plaintextReceiveBuffer, _inboundPacketSequence); - - plainFirstBlock.AsSpan().CopyTo(_plaintextReceiveBuffer.AsSpan(4)); - if (_serverMac != null && _serverEtm) { // ETM mac = MAC(key, sequence_number || packet_length || encrypted_packet) // sequence_number _ = _serverMac.TransformBlock( - inputBuffer: _plaintextReceiveBuffer, + inputBuffer: _inboundPacketSequenceBytes, inputOffset: 0, inputCount: 4, outputBuffer: null, @@ -1377,41 +1369,52 @@ private Message ReceiveMessage(Socket socket) { Debug.Assert(numberOfBytesToDecrypt % blockSize == 0); + var decryptBuffer = _receiveBuffer.DangerousGetUnderlyingBuffer(); + var decryptOffset = _receiveBuffer.ActiveStartOffset + blockSize; + var numberOfBytesDecrypted = _serverCipher.Decrypt( - input: _receiveBuffer.DangerousGetUnderlyingBuffer(), - offset: _receiveBuffer.ActiveStartOffset + blockSize, + input: decryptBuffer, + offset: decryptOffset, length: numberOfBytesToDecrypt, - output: _plaintextReceiveBuffer, - outputOffset: 4 + blockSize); + output: decryptBuffer, + outputOffset: decryptOffset); Debug.Assert(numberOfBytesDecrypted == numberOfBytesToDecrypt); } - else - { - _receiveBuffer.ActiveReadOnlySpan - .Slice(blockSize, numberOfBytesToDecrypt) - .CopyTo(_plaintextReceiveBuffer.AsSpan(4 + blockSize)); - } if (_serverMac != null && !_serverEtm) { // non-ETM mac = MAC(key, sequence_number || unencrypted_packet) - var clientHash = _serverMac.ComputeHash(_plaintextReceiveBuffer, 0, plaintextLength); + // sequence_number + _ = _serverMac.TransformBlock( + inputBuffer: _inboundPacketSequenceBytes, + inputOffset: 0, + inputCount: 4, + outputBuffer: null, + outputOffset: 0); + + // unencrypted_packet + _ = _serverMac.TransformBlock( + inputBuffer: _receiveBuffer.DangerousGetUnderlyingBuffer(), + inputOffset: _receiveBuffer.ActiveStartOffset, + inputCount: totalPacketLength - serverMacLength, + outputBuffer: null, + outputOffset: 0); + + _ = _serverMac.TransformFinalBlock(Array.Empty(), 0, 0); - if (!CryptoAbstraction.FixedTimeEquals(clientHash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) + if (!CryptoAbstraction.FixedTimeEquals(_serverMac.Hash, _receiveBuffer.ActiveSpan.Slice(totalPacketLength - serverMacLength, serverMacLength))) { throw new SshConnectionException("MAC error", DisconnectReason.MacError); } } - _receiveBuffer.Discard(totalPacketLength); - - var paddingLength = _plaintextReceiveBuffer[inboundPacketSequenceLength + packetLengthFieldLength]; + var paddingLength = _receiveBuffer.ActiveReadOnlySpan[packetLengthFieldLength]; ArraySegment payload = new( - _plaintextReceiveBuffer, - offset: inboundPacketSequenceLength + packetLengthFieldLength + paddingLengthFieldLength, + _receiveBuffer.DangerousGetUnderlyingBuffer(), + offset: _receiveBuffer.ActiveStartOffset + packetLengthFieldLength + paddingLengthFieldLength, count: packetLength - paddingLength - paddingLengthFieldLength); if (_serverDecompression != null) @@ -1419,16 +1422,24 @@ private Message ReceiveMessage(Socket socket) payload = new(_serverDecompression.Decompress(payload.Array, payload.Offset, payload.Count)); } - _inboundPacketSequence++; + var newInboundPacketSequence = ++InboundPacketSequence; // The below code mirrors from https://github.com/openssh/openssh-portable/commit/1edb00c58f8a6875fad6a497aa2bacf37f9e6cd5 // It ensures the integrity of key exchange process. - if (_inboundPacketSequence == uint.MaxValue && _isInitialKex) + if (newInboundPacketSequence == uint.MaxValue && _isInitialKex) { throw new SshConnectionException("Inbound packet sequence number is about to wrap during initial key exchange.", DisconnectReason.KeyExchangeFailed); } - return LoadMessage(payload.Array, payload.Offset, payload.Count); + var message = LoadMessage(payload.Array, payload.Offset, payload.Count); + + // The deserialised message may still reference data in the buffer, so calling Discard + // here might seem misguided. It is OK because Discard does not mutate the buffer + // and it will not be touched again until the next call to ReceiveMessage, which will + // only occur after the message has been fully processed. + _receiveBuffer.Discard(totalPacketLength); + + return message; } private void TrySendDisconnect(DisconnectReason reasonCode, string message) @@ -1545,7 +1556,7 @@ internal void OnKeyExchangeInitReceived(KeyExchangeInitMessage message) _logger.LogDebug("[{SessionId}] Enabling strict key exchange extension.", SessionIdHex); - if (_inboundPacketSequence != 1) + if (InboundPacketSequence != 1) { throw new SshConnectionException("KEXINIT was not the first packet during strict key exchange.", DisconnectReason.KeyExchangeFailed); } @@ -1646,7 +1657,7 @@ internal void OnNewKeysReceived(NewKeysMessage message) if (_isStrictKex) { - _inboundPacketSequence = 0; + InboundPacketSequence = 0; } NewKeysReceived?.Invoke(this, new MessageEventArgs(message));