diff --git a/async_postgres/pg_replication.nim b/async_postgres/pg_replication.nim index 4e41e4d..a8e928c 100644 --- a/async_postgres/pg_replication.nim +++ b/async_postgres/pg_replication.nim @@ -249,6 +249,50 @@ proc currentPgTimestamp*(): int64 = # pgoutput decoder +# Bounds-checked readers for the pgoutput decoder. +# +# The raw ``decodeInt16``/``decodeInt32``/``decodeInt64`` helpers index the +# buffer directly and rely on Nim's array bounds checks. Those checks are +# compiled out under ``-d:danger`` (and raise the uncatchable ``IndexDefect`` +# otherwise), so feeding a truncated or malicious WAL stream through the +# decoder could read past the end of the buffer. These wrappers validate the +# available length first and raise ``ProtocolError`` (a ``CatchableError`` and +# ``PgError`` subtype) on any shortfall, matching how the rest of the wire +# parsing reports protocol violations. + +proc ensureAvail(buf: openArray[byte], pos, n: int) {.inline.} = + ## Raise ``ProtocolError`` unless ``n`` bytes are readable at ``pos``. + ## ``n > buf.len - pos`` is written so it cannot overflow and so a ``pos`` + ## past the end (negative ``buf.len - pos``) is rejected for any ``n >= 0``. + if pos < 0 or n < 0 or n > buf.len - pos: + raise newException( + ProtocolError, + "pgoutput: truncated message (need " & $n & " byte(s) at offset " & $pos & + ", buffer holds " & $buf.len & ")", + ) + +proc readByteAt(buf: openArray[byte], pos: int): byte {.inline.} = + ensureAvail(buf, pos, 1) + buf[pos] + +proc readInt16At(buf: openArray[byte], pos: int): int16 {.inline.} = + ensureAvail(buf, pos, 2) + decodeInt16(buf, pos) + +proc readInt32At(buf: openArray[byte], pos: int): int32 {.inline.} = + ensureAvail(buf, pos, 4) + decodeInt32(buf, pos) + +proc readInt64At(buf: openArray[byte], pos: int): int64 {.inline.} = + ensureAvail(buf, pos, 8) + decodeInt64(buf, pos) + +proc readBytesAt(buf: openArray[byte], pos, n: int): seq[byte] {.inline.} = + ## ``n`` is attacker-controlled (a length prefix from the stream); validate it + ## against the buffer before the bulk copy in ``readBytes``. + ensureAvail(buf, pos, n) + readBytes(buf, pos, n) + proc decodeCStringAt(buf: openArray[byte], offset: int): (string, int) = ## Decode a null-terminated string at offset. Returns (string, next offset). if offset >= buf.len: @@ -266,11 +310,13 @@ proc decodeCStringAt(buf: openArray[byte], offset: int): (string, int) = proc decodeTuple(buf: openArray[byte], offset: int): (seq[TupleField], int) = ## Decode a pgoutput TupleData structure. var pos = offset - let numCols = decodeInt16(buf, pos) + let numCols = readInt16At(buf, pos) pos += 2 + if numCols < 0: + raise newException(ProtocolError, "pgoutput tuple: negative column count") var fields = newSeq[TupleField](numCols) for i in 0 ..< numCols: - let kind = char(buf[pos]) + let kind = char(readByteAt(buf, pos)) inc pos case kind of 'n': @@ -278,9 +324,9 @@ proc decodeTuple(buf: openArray[byte], offset: int): (seq[TupleField], int) = of 'u': fields[i] = TupleField(kind: tdkUnchanged) of 't', 'b': - let dataLen = decodeInt32(buf, pos) + let dataLen = readInt32At(buf, pos) pos += 4 - let data = readBytes(buf, pos, int(dataLen)) + let data = readBytesAt(buf, pos, int(dataLen)) pos += int(dataLen) fields[i] = TupleField(kind: if kind == 't': tdkText else: tdkBinary, data: data) else: @@ -295,26 +341,26 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage = case msgType of 'B': # Begin var msg = BeginMessage() - msg.finalLsn = Lsn(cast[uint64](decodeInt64(data, 1))) - msg.commitTime = decodeInt64(data, 9) - msg.xid = decodeInt32(data, 17) + msg.finalLsn = Lsn(cast[uint64](readInt64At(data, 1))) + msg.commitTime = readInt64At(data, 9) + msg.xid = readInt32At(data, 17) PgOutputMessage(kind: pomkBegin, begin: msg) of 'C': # Commit var msg = CommitMessage() - msg.flags = data[1] - msg.commitLsn = Lsn(cast[uint64](decodeInt64(data, 2))) - msg.endLsn = Lsn(cast[uint64](decodeInt64(data, 10))) - msg.commitTime = decodeInt64(data, 18) + msg.flags = readByteAt(data, 1) + msg.commitLsn = Lsn(cast[uint64](readInt64At(data, 2))) + msg.endLsn = Lsn(cast[uint64](readInt64At(data, 10))) + msg.commitTime = readInt64At(data, 18) PgOutputMessage(kind: pomkCommit, commit: msg) of 'O': # Origin var msg = OriginMessage() - msg.originLsn = Lsn(cast[uint64](decodeInt64(data, 1))) + msg.originLsn = Lsn(cast[uint64](readInt64At(data, 1))) let (name, _) = decodeCStringAt(data, 9) msg.originName = name PgOutputMessage(kind: pomkOrigin, origin: msg) of 'R': # Relation var msg = RelationInfo() - msg.relationId = decodeInt32(data, 1) + msg.relationId = readInt32At(data, 1) var pos = 5 let (ns, pos2) = decodeCStringAt(data, pos) msg.namespace = ns @@ -322,27 +368,29 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage = let (name, pos3) = decodeCStringAt(data, pos) msg.name = name pos = pos3 - msg.replicaIdentity = char(data[pos]) + msg.replicaIdentity = char(readByteAt(data, pos)) inc pos - let numCols = decodeInt16(data, pos) + let numCols = readInt16At(data, pos) pos += 2 + if numCols < 0: + raise newException(ProtocolError, "pgoutput Relation: negative column count") msg.columns = newSeq[RelationColumn](numCols) for i in 0 ..< numCols: var col = RelationColumn() - col.flags = data[pos] + col.flags = readByteAt(data, pos) inc pos let (colName, nextPos) = decodeCStringAt(data, pos) col.name = colName pos = nextPos - col.typeOid = decodeInt32(data, pos) + col.typeOid = readInt32At(data, pos) pos += 4 - col.typeMod = decodeInt32(data, pos) + col.typeMod = readInt32At(data, pos) pos += 4 msg.columns[i] = col PgOutputMessage(kind: pomkRelation, relation: msg) of 'Y': # Type var msg = TypeMessage() - msg.typeId = decodeInt32(data, 1) + msg.typeId = readInt32At(data, 1) var pos = 5 let (ns, pos2) = decodeCStringAt(data, pos) msg.namespace = ns @@ -352,16 +400,16 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage = PgOutputMessage(kind: pomkType, typeMsg: msg) of 'I': # Insert var msg = InsertMessage() - msg.relationId = decodeInt32(data, 1) + msg.relationId = readInt32At(data, 1) # byte at offset 5 is 'N' (new tuple marker) let (fields, _) = decodeTuple(data, 6) msg.newTuple = fields PgOutputMessage(kind: pomkInsert, insert: msg) of 'U': # Update var msg = UpdateMessage() - msg.relationId = decodeInt32(data, 1) + msg.relationId = readInt32At(data, 1) var pos = 5 - let marker = char(data[pos]) + let marker = char(readByteAt(data, pos)) inc pos if marker == 'K' or marker == 'O': # Old key or old tuple included @@ -377,7 +425,7 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage = PgOutputMessage(kind: pomkUpdate, update: msg) of 'D': # Delete var msg = DeleteMessage() - msg.relationId = decodeInt32(data, 1) + msg.relationId = readInt32At(data, 1) var pos = 5 # byte at offset 5 is 'K' (key) or 'O' (old tuple) inc pos @@ -386,25 +434,30 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage = PgOutputMessage(kind: pomkDelete, delete: msg) of 'T': # Truncate var msg = TruncateMessage() - let numRels = decodeInt32(data, 1) - msg.options = data[5] - msg.relationIds = newSeq[int32](numRels) + let numRels = readInt32At(data, 1) + msg.options = readByteAt(data, 5) var pos = 6 + # Each relation id is exactly 4 bytes; reject a count that cannot fit in the + # remaining buffer before allocating, so a forged count can neither trigger + # a huge allocation nor over-read in the loop below. + if numRels < 0 or numRels.int > (data.len - pos) div 4: + raise newException(ProtocolError, "pgoutput Truncate: invalid relation count") + msg.relationIds = newSeq[int32](numRels) for i in 0 ..< numRels: - msg.relationIds[i] = decodeInt32(data, pos) + msg.relationIds[i] = readInt32At(data, pos) pos += 4 PgOutputMessage(kind: pomkTruncate, truncate: msg) of 'M': # Message var msg = LogicalMessage() - msg.flags = data[1] - msg.lsn = Lsn(cast[uint64](decodeInt64(data, 2))) + msg.flags = readByteAt(data, 1) + msg.lsn = Lsn(cast[uint64](readInt64At(data, 2))) var pos = 10 let (prefix, nextPos) = decodeCStringAt(data, pos) msg.prefix = prefix pos = nextPos - let contentLen = decodeInt32(data, pos) + let contentLen = readInt32At(data, pos) pos += 4 - msg.content = readBytes(data, pos, int(contentLen)) + msg.content = readBytesAt(data, pos, int(contentLen)) PgOutputMessage(kind: pomkMessage, message: msg) else: raise newException(ProtocolError, "Unknown pgoutput message type: " & msgType) diff --git a/tests/all_tests.nim b/tests/all_tests.nim index 02fb4af..8767138 100644 --- a/tests/all_tests.nim +++ b/tests/all_tests.nim @@ -3,6 +3,6 @@ import test_abandonment_e2e, test_advisory_lock, test_auth, test_cancel_e2e, test_dsn, test_e2e, test_fill_recvbuf, test_keepalive, test_largeobject, test_network_failure, test_physical_replication, test_pool, test_protocol, test_protocol_fuzz, - test_replication_keepalive, test_rowdata, test_sql, test_ssl, test_tracing, - test_types, test_pool_cluster + test_replication, test_replication_keepalive, test_rowdata, test_sql, test_ssl, + test_tracing, test_types, test_pool_cluster {.pop.} diff --git a/tests/test_replication.nim b/tests/test_replication.nim index bf873e5..5854f43 100644 --- a/tests/test_replication.nim +++ b/tests/test_replication.nim @@ -369,3 +369,146 @@ suite "pgoutput decoder": test "unknown message type raises": expect(ProtocolError): discard parsePgOutputMessage(@[byte('Z')]) + +suite "pgoutput decoder bounds checking": + # Each case feeds a truncated or forged message and asserts a catchable + # ProtocolError instead of an out-of-bounds read (which would be an + # uncatchable IndexDefect, or undefined behaviour under -d:danger). + + test "Begin truncated raises ProtocolError": + # 'B' followed by only 4 of the 8 finalLsn bytes + var data = @[byte('B'), 0, 0, 0, 0] + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Commit truncated raises ProtocolError": + # 'C' + flags only, the three LSN/time fields are missing + var data = @[byte('C'), 0] + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Origin truncated raises ProtocolError": + var data = @[byte('O'), 0, 0, 0] + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Type truncated raises ProtocolError": + var data = @[byte('Y'), 0, 0] + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Insert with column count past end raises ProtocolError": + var data: seq[byte] + data.add(byte('I')) + data.addInt32(16384'i32) + data.add(byte('N')) + data.addInt16(100'i16) # claims 100 columns, none follow + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Insert with negative column count raises ProtocolError": + var data: seq[byte] + data.add(byte('I')) + data.addInt32(16384'i32) + data.add(byte('N')) + data.addInt16(-1'i16) + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Insert tuple data length past end raises ProtocolError": + var data: seq[byte] + data.add(byte('I')) + data.addInt32(16384'i32) + data.add(byte('N')) + data.addInt16(1'i16) # 1 column + data.add(byte('t')) + data.addInt32(1000'i32) # claims 1000 bytes + data.add(byte('x')) # but only 1 present + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Insert tuple with negative data length raises ProtocolError": + var data: seq[byte] + data.add(byte('I')) + data.addInt32(16384'i32) + data.add(byte('N')) + data.addInt16(1'i16) + data.add(byte('t')) + data.addInt32(-5'i32) + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Relation truncated mid-column raises ProtocolError": + var data: seq[byte] + data.add(byte('R')) + data.addInt32(16384'i32) + for c in "public": + data.add(byte(c)) + data.add(0'u8) + for c in "t": + data.add(byte(c)) + data.add(0'u8) + data.add(byte('d')) # replicaIdentity + data.addInt16(1'i16) # 1 column + data.add(1'u8) # flags + for c in "id": + data.add(byte(c)) + data.add(0'u8) + data.addInt32(23'i32) # typeOid + # typeMod missing — truncated here + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Truncate with oversized relation count raises ProtocolError": + var data: seq[byte] + data.add(byte('T')) + data.addInt32(1_000_000'i32) # claims 1M relations + data.add(0'u8) # options + data.addInt32(16384'i32) # only one actually present + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Truncate with negative relation count raises ProtocolError": + var data: seq[byte] + data.add(byte('T')) + data.addInt32(-1'i32) + data.add(0'u8) + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Update with marker past end raises ProtocolError": + var data: seq[byte] + data.add(byte('U')) + data.addInt32(16384'i32) # relationId, then nothing (no marker byte) + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Message content length past end raises ProtocolError": + var data: seq[byte] + data.add(byte('M')) + data.add(0'u8) # flags + data.addInt64(0x800'i64) # lsn + for c in "p": + data.add(byte(c)) + data.add(0'u8) # prefix terminator + data.addInt32(9999'i32) # claims 9999 bytes + data.add(byte('x')) # but only 1 present + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "Message truncated header raises ProtocolError": + var data = @[byte('M'), 0, 0, 0] + expect(ProtocolError): + discard parsePgOutputMessage(data) + + test "valid messages still decode after hardening": + # Regression guard: a well-formed Begin must continue to parse cleanly. + var data: seq[byte] + data.add(byte('B')) + data.addInt64(0x500'i64) + data.addInt64(99999'i64) + data.addInt32(42'i32) + let msg = parsePgOutputMessage(data) + check msg.kind == pomkBegin + check msg.begin.finalLsn == Lsn(0x500'u64) + check msg.begin.xid == 42'i32