Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 84 additions & 31 deletions async_postgres/pg_replication.nim
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,50 @@ proc currentPgTimestamp*(): int64 =

# pgoutput decoder

# Bounds-checked readers for the pgoutput decoder.
#
# The raw ``decodeInt16``/``decodeInt32``/``decodeInt64`` helpers index the
# buffer directly and rely on Nim's array bounds checks. Those checks are
# compiled out under ``-d:danger`` (and raise the uncatchable ``IndexDefect``
# otherwise), so feeding a truncated or malicious WAL stream through the
# decoder could read past the end of the buffer. These wrappers validate the
# available length first and raise ``ProtocolError`` (a ``CatchableError`` and
# ``PgError`` subtype) on any shortfall, matching how the rest of the wire
# parsing reports protocol violations.

proc ensureAvail(buf: openArray[byte], pos, n: int) {.inline.} =
## Raise ``ProtocolError`` unless ``n`` bytes are readable at ``pos``.
## ``n > buf.len - pos`` is written so it cannot overflow and so a ``pos``
## past the end (negative ``buf.len - pos``) is rejected for any ``n >= 0``.
if pos < 0 or n < 0 or n > buf.len - pos:
raise newException(
ProtocolError,
"pgoutput: truncated message (need " & $n & " byte(s) at offset " & $pos &
", buffer holds " & $buf.len & ")",
)

proc readByteAt(buf: openArray[byte], pos: int): byte {.inline.} =
ensureAvail(buf, pos, 1)
buf[pos]

proc readInt16At(buf: openArray[byte], pos: int): int16 {.inline.} =
ensureAvail(buf, pos, 2)
decodeInt16(buf, pos)

proc readInt32At(buf: openArray[byte], pos: int): int32 {.inline.} =
ensureAvail(buf, pos, 4)
decodeInt32(buf, pos)

proc readInt64At(buf: openArray[byte], pos: int): int64 {.inline.} =
ensureAvail(buf, pos, 8)
decodeInt64(buf, pos)

proc readBytesAt(buf: openArray[byte], pos, n: int): seq[byte] {.inline.} =
## ``n`` is attacker-controlled (a length prefix from the stream); validate it
## against the buffer before the bulk copy in ``readBytes``.
ensureAvail(buf, pos, n)
readBytes(buf, pos, n)

proc decodeCStringAt(buf: openArray[byte], offset: int): (string, int) =
## Decode a null-terminated string at offset. Returns (string, next offset).
if offset >= buf.len:
Expand All @@ -266,21 +310,23 @@ proc decodeCStringAt(buf: openArray[byte], offset: int): (string, int) =
proc decodeTuple(buf: openArray[byte], offset: int): (seq[TupleField], int) =
## Decode a pgoutput TupleData structure.
var pos = offset
let numCols = decodeInt16(buf, pos)
let numCols = readInt16At(buf, pos)
pos += 2
if numCols < 0:
raise newException(ProtocolError, "pgoutput tuple: negative column count")
var fields = newSeq[TupleField](numCols)
for i in 0 ..< numCols:
let kind = char(buf[pos])
let kind = char(readByteAt(buf, pos))
inc pos
case kind
of 'n':
fields[i] = TupleField(kind: tdkNull)
of 'u':
fields[i] = TupleField(kind: tdkUnchanged)
of 't', 'b':
let dataLen = decodeInt32(buf, pos)
let dataLen = readInt32At(buf, pos)
pos += 4
let data = readBytes(buf, pos, int(dataLen))
let data = readBytesAt(buf, pos, int(dataLen))
pos += int(dataLen)
fields[i] = TupleField(kind: if kind == 't': tdkText else: tdkBinary, data: data)
else:
Expand All @@ -295,54 +341,56 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage =
case msgType
of 'B': # Begin
var msg = BeginMessage()
msg.finalLsn = Lsn(cast[uint64](decodeInt64(data, 1)))
msg.commitTime = decodeInt64(data, 9)
msg.xid = decodeInt32(data, 17)
msg.finalLsn = Lsn(cast[uint64](readInt64At(data, 1)))
msg.commitTime = readInt64At(data, 9)
msg.xid = readInt32At(data, 17)
PgOutputMessage(kind: pomkBegin, begin: msg)
of 'C': # Commit
var msg = CommitMessage()
msg.flags = data[1]
msg.commitLsn = Lsn(cast[uint64](decodeInt64(data, 2)))
msg.endLsn = Lsn(cast[uint64](decodeInt64(data, 10)))
msg.commitTime = decodeInt64(data, 18)
msg.flags = readByteAt(data, 1)
msg.commitLsn = Lsn(cast[uint64](readInt64At(data, 2)))
msg.endLsn = Lsn(cast[uint64](readInt64At(data, 10)))
msg.commitTime = readInt64At(data, 18)
PgOutputMessage(kind: pomkCommit, commit: msg)
of 'O': # Origin
var msg = OriginMessage()
msg.originLsn = Lsn(cast[uint64](decodeInt64(data, 1)))
msg.originLsn = Lsn(cast[uint64](readInt64At(data, 1)))
let (name, _) = decodeCStringAt(data, 9)
msg.originName = name
PgOutputMessage(kind: pomkOrigin, origin: msg)
of 'R': # Relation
var msg = RelationInfo()
msg.relationId = decodeInt32(data, 1)
msg.relationId = readInt32At(data, 1)
var pos = 5
let (ns, pos2) = decodeCStringAt(data, pos)
msg.namespace = ns
pos = pos2
let (name, pos3) = decodeCStringAt(data, pos)
msg.name = name
pos = pos3
msg.replicaIdentity = char(data[pos])
msg.replicaIdentity = char(readByteAt(data, pos))
inc pos
let numCols = decodeInt16(data, pos)
let numCols = readInt16At(data, pos)
pos += 2
if numCols < 0:
raise newException(ProtocolError, "pgoutput Relation: negative column count")
msg.columns = newSeq[RelationColumn](numCols)
for i in 0 ..< numCols:
var col = RelationColumn()
col.flags = data[pos]
col.flags = readByteAt(data, pos)
inc pos
let (colName, nextPos) = decodeCStringAt(data, pos)
col.name = colName
pos = nextPos
col.typeOid = decodeInt32(data, pos)
col.typeOid = readInt32At(data, pos)
pos += 4
col.typeMod = decodeInt32(data, pos)
col.typeMod = readInt32At(data, pos)
pos += 4
msg.columns[i] = col
PgOutputMessage(kind: pomkRelation, relation: msg)
of 'Y': # Type
var msg = TypeMessage()
msg.typeId = decodeInt32(data, 1)
msg.typeId = readInt32At(data, 1)
var pos = 5
let (ns, pos2) = decodeCStringAt(data, pos)
msg.namespace = ns
Expand All @@ -352,16 +400,16 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage =
PgOutputMessage(kind: pomkType, typeMsg: msg)
of 'I': # Insert
var msg = InsertMessage()
msg.relationId = decodeInt32(data, 1)
msg.relationId = readInt32At(data, 1)
# byte at offset 5 is 'N' (new tuple marker)
let (fields, _) = decodeTuple(data, 6)
msg.newTuple = fields
PgOutputMessage(kind: pomkInsert, insert: msg)
of 'U': # Update
var msg = UpdateMessage()
msg.relationId = decodeInt32(data, 1)
msg.relationId = readInt32At(data, 1)
var pos = 5
let marker = char(data[pos])
let marker = char(readByteAt(data, pos))
inc pos
if marker == 'K' or marker == 'O':
# Old key or old tuple included
Expand All @@ -377,7 +425,7 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage =
PgOutputMessage(kind: pomkUpdate, update: msg)
of 'D': # Delete
var msg = DeleteMessage()
msg.relationId = decodeInt32(data, 1)
msg.relationId = readInt32At(data, 1)
var pos = 5
# byte at offset 5 is 'K' (key) or 'O' (old tuple)
inc pos
Expand All @@ -386,25 +434,30 @@ proc parsePgOutputMessage*(data: openArray[byte]): PgOutputMessage =
PgOutputMessage(kind: pomkDelete, delete: msg)
of 'T': # Truncate
var msg = TruncateMessage()
let numRels = decodeInt32(data, 1)
msg.options = data[5]
msg.relationIds = newSeq[int32](numRels)
let numRels = readInt32At(data, 1)
msg.options = readByteAt(data, 5)
var pos = 6
# Each relation id is exactly 4 bytes; reject a count that cannot fit in the
# remaining buffer before allocating, so a forged count can neither trigger
# a huge allocation nor over-read in the loop below.
if numRels < 0 or numRels.int > (data.len - pos) div 4:
raise newException(ProtocolError, "pgoutput Truncate: invalid relation count")
msg.relationIds = newSeq[int32](numRels)
for i in 0 ..< numRels:
msg.relationIds[i] = decodeInt32(data, pos)
msg.relationIds[i] = readInt32At(data, pos)
pos += 4
PgOutputMessage(kind: pomkTruncate, truncate: msg)
of 'M': # Message
var msg = LogicalMessage()
msg.flags = data[1]
msg.lsn = Lsn(cast[uint64](decodeInt64(data, 2)))
msg.flags = readByteAt(data, 1)
msg.lsn = Lsn(cast[uint64](readInt64At(data, 2)))
var pos = 10
let (prefix, nextPos) = decodeCStringAt(data, pos)
msg.prefix = prefix
pos = nextPos
let contentLen = decodeInt32(data, pos)
let contentLen = readInt32At(data, pos)
pos += 4
msg.content = readBytes(data, pos, int(contentLen))
msg.content = readBytesAt(data, pos, int(contentLen))
PgOutputMessage(kind: pomkMessage, message: msg)
else:
raise newException(ProtocolError, "Unknown pgoutput message type: " & msgType)
Expand Down
4 changes: 2 additions & 2 deletions tests/all_tests.nim
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ import
test_abandonment_e2e, test_advisory_lock, test_auth, test_cancel_e2e, test_dsn,
test_e2e, test_fill_recvbuf, test_keepalive, test_largeobject, test_network_failure,
test_physical_replication, test_pool, test_protocol, test_protocol_fuzz,
test_replication_keepalive, test_rowdata, test_sql, test_ssl, test_tracing,
test_types, test_pool_cluster
test_replication, test_replication_keepalive, test_rowdata, test_sql, test_ssl,
test_tracing, test_types, test_pool_cluster
{.pop.}
143 changes: 143 additions & 0 deletions tests/test_replication.nim
Original file line number Diff line number Diff line change
Expand Up @@ -369,3 +369,146 @@ suite "pgoutput decoder":
test "unknown message type raises":
expect(ProtocolError):
discard parsePgOutputMessage(@[byte('Z')])

suite "pgoutput decoder bounds checking":
# Each case feeds a truncated or forged message and asserts a catchable
# ProtocolError instead of an out-of-bounds read (which would be an
# uncatchable IndexDefect, or undefined behaviour under -d:danger).

test "Begin truncated raises ProtocolError":
# 'B' followed by only 4 of the 8 finalLsn bytes
var data = @[byte('B'), 0, 0, 0, 0]
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Commit truncated raises ProtocolError":
# 'C' + flags only, the three LSN/time fields are missing
var data = @[byte('C'), 0]
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Origin truncated raises ProtocolError":
var data = @[byte('O'), 0, 0, 0]
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Type truncated raises ProtocolError":
var data = @[byte('Y'), 0, 0]
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Insert with column count past end raises ProtocolError":
var data: seq[byte]
data.add(byte('I'))
data.addInt32(16384'i32)
data.add(byte('N'))
data.addInt16(100'i16) # claims 100 columns, none follow
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Insert with negative column count raises ProtocolError":
var data: seq[byte]
data.add(byte('I'))
data.addInt32(16384'i32)
data.add(byte('N'))
data.addInt16(-1'i16)
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Insert tuple data length past end raises ProtocolError":
var data: seq[byte]
data.add(byte('I'))
data.addInt32(16384'i32)
data.add(byte('N'))
data.addInt16(1'i16) # 1 column
data.add(byte('t'))
data.addInt32(1000'i32) # claims 1000 bytes
data.add(byte('x')) # but only 1 present
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Insert tuple with negative data length raises ProtocolError":
var data: seq[byte]
data.add(byte('I'))
data.addInt32(16384'i32)
data.add(byte('N'))
data.addInt16(1'i16)
data.add(byte('t'))
data.addInt32(-5'i32)
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Relation truncated mid-column raises ProtocolError":
var data: seq[byte]
data.add(byte('R'))
data.addInt32(16384'i32)
for c in "public":
data.add(byte(c))
data.add(0'u8)
for c in "t":
data.add(byte(c))
data.add(0'u8)
data.add(byte('d')) # replicaIdentity
data.addInt16(1'i16) # 1 column
data.add(1'u8) # flags
for c in "id":
data.add(byte(c))
data.add(0'u8)
data.addInt32(23'i32) # typeOid
# typeMod missing — truncated here
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Truncate with oversized relation count raises ProtocolError":
var data: seq[byte]
data.add(byte('T'))
data.addInt32(1_000_000'i32) # claims 1M relations
data.add(0'u8) # options
data.addInt32(16384'i32) # only one actually present
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Truncate with negative relation count raises ProtocolError":
var data: seq[byte]
data.add(byte('T'))
data.addInt32(-1'i32)
data.add(0'u8)
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Update with marker past end raises ProtocolError":
var data: seq[byte]
data.add(byte('U'))
data.addInt32(16384'i32) # relationId, then nothing (no marker byte)
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Message content length past end raises ProtocolError":
var data: seq[byte]
data.add(byte('M'))
data.add(0'u8) # flags
data.addInt64(0x800'i64) # lsn
for c in "p":
data.add(byte(c))
data.add(0'u8) # prefix terminator
data.addInt32(9999'i32) # claims 9999 bytes
data.add(byte('x')) # but only 1 present
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "Message truncated header raises ProtocolError":
var data = @[byte('M'), 0, 0, 0]
expect(ProtocolError):
discard parsePgOutputMessage(data)

test "valid messages still decode after hardening":
# Regression guard: a well-formed Begin must continue to parse cleanly.
var data: seq[byte]
data.add(byte('B'))
data.addInt64(0x500'i64)
data.addInt64(99999'i64)
data.addInt32(42'i32)
let msg = parsePgOutputMessage(data)
check msg.kind == pomkBegin
check msg.begin.finalLsn == Lsn(0x500'u64)
check msg.begin.xid == 42'i32