Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 151 additions & 88 deletions async_postgres/pg_connection.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import async_backend, pg_protocol, pg_auth, pg_types

when hasChronos:
import chronos/streams/tlsstream
import bearssl/ssl as bearssl_ssl
import pg_bearssl
elif hasAsyncDispatch:
import std/asyncnet
Expand Down Expand Up @@ -69,6 +70,11 @@ type
sslVerifyCa ## Require SSL + verify CA chain (no hostname verification)
sslVerifyFull ## Require SSL + verify CA chain and hostname

SslNegotiation* = enum
## SSL negotiation method for the connection.
sslnPostgres ## Traditional SSLRequest negotiation (default)
sslnDirect ## Direct SSL: start TLS immediately without SSLRequest (PostgreSQL 17+)

TargetSessionAttrs* = enum
## Target server type for multi-host failover (libpq compatible).
tsaAny ## Connect to any server (default)
Expand All @@ -90,6 +96,7 @@ type
password*: string
database*: string
sslMode*: SslMode
sslNegotiation*: SslNegotiation ## SSL negotiation method (default: sslnPostgres)
sslRootCert*: string ## PEM-encoded CA certificate(s) for sslVerifyCa/sslVerifyFull
applicationName*: string
connectTimeout*: Duration ## TCP connect timeout (default: no timeout)
Expand Down Expand Up @@ -135,6 +142,8 @@ type
tlsStream: TLSAsyncStream
trustAnchorBufs: seq[seq[byte]] ## Backing memory for custom trust anchor pointers
x509Capture: X509CertCaptureContext ## X509 wrapper for cert capture
alpnNames: cstringArray
## Backing storage for ALPN protocol names (must outlive TLS session)
elif hasAsyncDispatch:
socket*: AsyncSocket
serverCertDer: seq[byte] ## DER-encoded server certificate for SCRAM channel binding
Expand Down Expand Up @@ -717,6 +726,9 @@ proc closeTransport(conn: PgConnection) {.async.} =
except CatchableError:
discard
conn.tlsStream = nil
if conn.alpnNames != nil:
deallocCStringArray(conn.alpnNames)
conn.alpnNames = nil
if conn.baseReader != nil:
try:
await conn.baseReader.closeWait()
Expand All @@ -739,6 +751,116 @@ proc closeTransport(conn: PgConnection) {.async.} =
conn.socket.close()
conn.socket = nil

proc setupTls(conn: PgConnection, config: ConnConfig) {.async.} =
## Perform TLS handshake on an already-connected transport.
## Shared by both traditional (SSLRequest) and direct SSL negotiation.
when hasChronos:
conn.baseReader = newAsyncStreamReader(conn.transport)
conn.baseWriter = newAsyncStreamWriter(conn.transport)

let flags =
case config.sslMode
of sslVerifyFull:
{}
of sslVerifyCa:
{TLSFlags.NoVerifyServerName}
else:
{TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}

let serverName = if config.sslMode == sslVerifyFull: config.host else: ""

if config.sslRootCert.len > 0:
let parsed = parseTrustAnchors(config.sslRootCert)
conn.trustAnchorBufs = parsed.backing
# Must outlive TLS session (see parseTrustAnchors doc)
conn.tlsStream = newTLSClientAsyncStream(
conn.baseReader,
conn.baseWriter,
serverName,
flags = flags,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
trustAnchors = parsed.store,
)
else:
conn.tlsStream = newTLSClientAsyncStream(
conn.baseReader,
conn.baseWriter,
serverName,
flags = flags,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
)
if config.sslNegotiation == sslnDirect:
# Direct SSL requires ALPN "postgresql" (PostgreSQL 17+).
# Must set ALPN then re-reset the client context so the handshake
# state machine picks up the protocol names for the ClientHello.
conn.alpnNames = allocCStringArray(["postgresql"])
let names = conn.alpnNames
{.
emit: [
conn.tlsStream[].ccontext.eng, ".protocol_names = (const char **)", names, ";"
]
.}
conn.tlsStream.ccontext.eng.protocolNamesNum = 1
discard sslClientReset(conn.tlsStream.ccontext, nil, 0)
installX509Capture(
conn.x509Capture, conn.tlsStream.ccontext.eng, addr conn.serverCertDer
)
await conn.tlsStream.handshake()
conn.reader = conn.tlsStream.reader
conn.writer = conn.tlsStream.writer
conn.sslEnabled = true
elif hasAsyncDispatch:
when defined(ssl):
let verifyMode =
case config.sslMode
of sslVerifyCa, sslVerifyFull: SslCVerifyMode.CVerifyPeer
else: SslCVerifyMode.CVerifyNone

var ctx: SslContext
var tmpPath: string
if config.sslRootCert.len > 0:
let (tmpFile, tp) = createTempFile("pg_ca_", ".pem")
tmpPath = tp
try:
tmpFile.write(config.sslRootCert)
tmpFile.close()
ctx = newContext(verifyMode = verifyMode, caFile = tmpPath)
except CatchableError:
removeFile(tmpPath)
raise
else:
ctx = newContext(verifyMode = verifyMode)

if config.sslNegotiation == sslnDirect:
# Direct SSL requires ALPN "postgresql" (PostgreSQL 17+)
const alpnProto = "\x0apostgresql"
discard
SSL_CTX_set_alpn_protos(ctx.context, alpnProto.cstring, cuint(alpnProto.len))

try:
let hostname = if config.sslMode == sslVerifyFull: config.host else: ""
wrapConnectedSocket(ctx, conn.socket, handshakeAsClient, hostname)
conn.sslEnabled = true
# Extract server certificate DER for SCRAM-SHA-256-PLUS channel binding
let peerCert = SSL_get_peer_certificate(conn.socket.sslHandle)
if peerCert != nil:
try:
let derStr = i2d_X509(peerCert)
if derStr.len > 0:
conn.serverCertDer = newSeq[byte](derStr.len)
for i in 0 ..< derStr.len:
conn.serverCertDer[i] = byte(derStr[i])
finally:
X509_free(peerCert)
finally:
if tmpPath.len > 0:
removeFile(tmpPath)
else:
raise
newException(PgConnectionError, "SSL support requires compiling with -d:ssl")

proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} =
## Send SSLRequest and negotiate TLS if server accepts.
let sslReq = encodeSSLRequest()
Expand All @@ -760,93 +882,7 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} =

case respChar
of 'S':
when hasChronos:
conn.baseReader = newAsyncStreamReader(conn.transport)
conn.baseWriter = newAsyncStreamWriter(conn.transport)

let flags =
case config.sslMode
of sslVerifyFull:
{}
of sslVerifyCa:
{TLSFlags.NoVerifyServerName}
else:
{TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName}

let serverName = if config.sslMode == sslVerifyFull: config.host else: ""

if config.sslRootCert.len > 0:
let parsed = parseTrustAnchors(config.sslRootCert)
conn.trustAnchorBufs = parsed.backing
# Must outlive TLS session (see parseTrustAnchors doc)
conn.tlsStream = newTLSClientAsyncStream(
conn.baseReader,
conn.baseWriter,
serverName,
flags = flags,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
trustAnchors = parsed.store,
)
else:
conn.tlsStream = newTLSClientAsyncStream(
conn.baseReader,
conn.baseWriter,
serverName,
flags = flags,
minVersion = TLSVersion.TLS12,
maxVersion = TLSVersion.TLS12,
)
installX509Capture(
conn.x509Capture, conn.tlsStream.ccontext.eng, addr conn.serverCertDer
)
await conn.tlsStream.handshake()
conn.reader = conn.tlsStream.reader
conn.writer = conn.tlsStream.writer
conn.sslEnabled = true
elif hasAsyncDispatch:
when defined(ssl):
let verifyMode =
case config.sslMode
of sslVerifyCa, sslVerifyFull: SslCVerifyMode.CVerifyPeer
else: SslCVerifyMode.CVerifyNone

var ctx: SslContext
var tmpPath: string
if config.sslRootCert.len > 0:
let (tmpFile, tp) = createTempFile("pg_ca_", ".pem")
tmpPath = tp
try:
tmpFile.write(config.sslRootCert)
tmpFile.close()
ctx = newContext(verifyMode = verifyMode, caFile = tmpPath)
except:
removeFile(tmpPath)
raise
else:
ctx = newContext(verifyMode = verifyMode)

try:
let hostname = if config.sslMode == sslVerifyFull: config.host else: ""
wrapConnectedSocket(ctx, conn.socket, handshakeAsClient, hostname)
conn.sslEnabled = true
# Extract server certificate DER for SCRAM-SHA-256-PLUS channel binding
let peerCert = SSL_get_peer_certificate(conn.socket.sslHandle)
if peerCert != nil:
try:
let derStr = i2d_X509(peerCert)
if derStr.len > 0:
conn.serverCertDer = newSeq[byte](derStr.len)
for i in 0 ..< derStr.len:
conn.serverCertDer[i] = byte(derStr[i])
finally:
X509_free(peerCert)
finally:
if tmpPath.len > 0:
removeFile(tmpPath)
else:
raise
newException(PgConnectionError, "SSL support requires compiling with -d:ssl")
await conn.setupTls(config)
of 'N':
if config.sslMode in {sslRequire, sslVerifyCa, sslVerifyFull}:
raise newException(PgConnectionError, "Server does not support SSL")
Expand Down Expand Up @@ -957,6 +993,19 @@ proc connectToHost(
): Future[PgConnection] {.async.} =
## Connect to a single PostgreSQL host. Internal helper for multi-host connect.

if config.sslNegotiation == sslnDirect:
if config.sslMode == sslDisable:
raise newException(
PgConnectionError, "sslnegotiation=direct is incompatible with sslmode=disable"
)
if config.sslMode in {sslAllow, sslPrefer}:
raise newException(
PgConnectionError,
"sslnegotiation=direct is incompatible with sslmode=" &
(if config.sslMode == sslAllow: "allow" else: "prefer") &
" (direct SSL does not support fallback)",
)

if config.sslMode == sslAllow:
# sslAllow: try plaintext first, then fall back to SSL.
var plainConfig = config
Expand Down Expand Up @@ -1059,7 +1108,10 @@ proc connectToHost(
try:
# SSL negotiation (before StartupMessage) — skip for Unix sockets
if config.sslMode != sslDisable and not isUnix:
await negotiateSSL(conn, config)
if config.sslNegotiation == sslnDirect:
await conn.setupTls(config)
else:
await negotiateSSL(conn, config)

when hasChronos:
# If SSL was not established, create plain streams
Expand Down Expand Up @@ -1739,6 +1791,15 @@ proc parseSslMode(s: string): SslMode =
else:
raise newException(PgError, "Invalid sslmode: " & s)

proc parseSslNegotiation(s: string): SslNegotiation =
case s
of "postgres":
sslnPostgres
of "direct":
sslnDirect
else:
raise newException(PgError, "Invalid sslnegotiation: " & s)

proc parseTargetSessionAttrs(s: string): TargetSessionAttrs =
case s
of "any":
Expand Down Expand Up @@ -1779,6 +1840,8 @@ proc applyParam(result: var ConnConfig, key, val: string) =
result.password = val
of "sslmode":
result.sslMode = parseSslMode(val)
of "sslnegotiation":
result.sslNegotiation = parseSslNegotiation(val)
of "application_name":
result.applicationName = val
of "connect_timeout":
Expand Down
23 changes: 23 additions & 0 deletions tests/test_dsn.nim
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,21 @@ suite "parseDsn":
else:
discard

test "query param sslnegotiation":
for mode in ["postgres", "direct"]:
let cfg = parseDsn("postgresql://host/db?sslnegotiation=" & mode)
case mode
of "postgres":
check cfg.sslNegotiation == sslnPostgres
of "direct":
check cfg.sslNegotiation == sslnDirect
else:
discard

test "error: invalid sslnegotiation":
expect PgError:
discard parseDsn("postgresql://host/db?sslnegotiation=bogus")

test "query param application_name":
let cfg = parseDsn("postgresql://host/db?application_name=myapp")
check cfg.applicationName == "myapp"
Expand Down Expand Up @@ -410,6 +425,14 @@ suite "parseDsn keyword=value":
let cfg = parseDsn("host=h dbname=d sslmode=require")
check cfg.sslMode == sslRequire

test "sslnegotiation parameter":
let cfg = parseDsn("host=h dbname=d sslnegotiation=direct")
check cfg.sslNegotiation == sslnDirect

test "error: invalid sslnegotiation (key-value)":
expect PgError:
discard parseDsn("host=h sslnegotiation=bogus")

test "connect_timeout parameter":
let cfg = parseDsn("host=h connect_timeout=30")
check cfg.connectTimeout == seconds(30)
Expand Down
39 changes: 39 additions & 0 deletions tests/test_e2e.nim
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,45 @@ suite "E2E: SSL Connection":

waitFor t()

suite "E2E: Direct SSL Negotiation":
test "sslnDirect + sslRequire connects with SSL":
proc t() {.async.} =
let conn = await connect(
ConnConfig(
host: PgHost,
port: PgPort,
user: PgUser,
password: PgPassword,
database: PgDatabase,
sslMode: sslRequire,
sslNegotiation: sslnDirect,
)
)
doAssert conn.state == csReady
doAssert conn.sslEnabled == true
await conn.close()

waitFor t()

test "sslnDirect query over SSL connection":
proc t() {.async.} =
let conn = await connect(
ConnConfig(
host: PgHost,
port: PgPort,
user: PgUser,
password: PgPassword,
database: PgDatabase,
sslMode: sslRequire,
sslNegotiation: sslnDirect,
)
)
let results = await conn.simpleQuery("SELECT 42 AS answer")
doAssert results[0].rows[0][0].get().toString() == "42"
await conn.close()

waitFor t()

suite "E2E: SSL Verification":
test "sslVerifyCa connects with CA verification":
proc t() {.async.} =
Expand Down
Loading
Loading