From 7c2a3d39f07705dde338bea9a055f2c41daaf9dd Mon Sep 17 00:00:00 2001 From: fox0430 Date: Wed, 8 Apr 2026 19:34:23 +0900 Subject: [PATCH] Add Direct SSL Negotiation --- async_postgres/pg_connection.nim | 239 ++++++++++++++++++---------- tests/test_dsn.nim | 23 +++ tests/test_e2e.nim | 39 +++++ tests/test_ssl.nim | 265 +++++++++++++++++++++++++++++++ 4 files changed, 478 insertions(+), 88 deletions(-) diff --git a/async_postgres/pg_connection.nim b/async_postgres/pg_connection.nim index db75828..b933d2a 100644 --- a/async_postgres/pg_connection.nim +++ b/async_postgres/pg_connection.nim @@ -6,6 +6,7 @@ import async_backend, pg_protocol, pg_auth, pg_types when hasChronos: import chronos/streams/tlsstream + import bearssl/ssl as bearssl_ssl import pg_bearssl elif hasAsyncDispatch: import std/asyncnet @@ -69,6 +70,11 @@ type sslVerifyCa ## Require SSL + verify CA chain (no hostname verification) sslVerifyFull ## Require SSL + verify CA chain and hostname + SslNegotiation* = enum + ## SSL negotiation method for the connection. + sslnPostgres ## Traditional SSLRequest negotiation (default) + sslnDirect ## Direct SSL: start TLS immediately without SSLRequest (PostgreSQL 17+) + TargetSessionAttrs* = enum ## Target server type for multi-host failover (libpq compatible). tsaAny ## Connect to any server (default) @@ -90,6 +96,7 @@ type password*: string database*: string sslMode*: SslMode + sslNegotiation*: SslNegotiation ## SSL negotiation method (default: sslnPostgres) sslRootCert*: string ## PEM-encoded CA certificate(s) for sslVerifyCa/sslVerifyFull applicationName*: string connectTimeout*: Duration ## TCP connect timeout (default: no timeout) @@ -135,6 +142,8 @@ type tlsStream: TLSAsyncStream trustAnchorBufs: seq[seq[byte]] ## Backing memory for custom trust anchor pointers x509Capture: X509CertCaptureContext ## X509 wrapper for cert capture + alpnNames: cstringArray + ## Backing storage for ALPN protocol names (must outlive TLS session) elif hasAsyncDispatch: socket*: AsyncSocket serverCertDer: seq[byte] ## DER-encoded server certificate for SCRAM channel binding @@ -717,6 +726,9 @@ proc closeTransport(conn: PgConnection) {.async.} = except CatchableError: discard conn.tlsStream = nil + if conn.alpnNames != nil: + deallocCStringArray(conn.alpnNames) + conn.alpnNames = nil if conn.baseReader != nil: try: await conn.baseReader.closeWait() @@ -739,6 +751,116 @@ proc closeTransport(conn: PgConnection) {.async.} = conn.socket.close() conn.socket = nil +proc setupTls(conn: PgConnection, config: ConnConfig) {.async.} = + ## Perform TLS handshake on an already-connected transport. + ## Shared by both traditional (SSLRequest) and direct SSL negotiation. + when hasChronos: + conn.baseReader = newAsyncStreamReader(conn.transport) + conn.baseWriter = newAsyncStreamWriter(conn.transport) + + let flags = + case config.sslMode + of sslVerifyFull: + {} + of sslVerifyCa: + {TLSFlags.NoVerifyServerName} + else: + {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName} + + let serverName = if config.sslMode == sslVerifyFull: config.host else: "" + + if config.sslRootCert.len > 0: + let parsed = parseTrustAnchors(config.sslRootCert) + conn.trustAnchorBufs = parsed.backing + # Must outlive TLS session (see parseTrustAnchors doc) + conn.tlsStream = newTLSClientAsyncStream( + conn.baseReader, + conn.baseWriter, + serverName, + flags = flags, + minVersion = TLSVersion.TLS12, + maxVersion = TLSVersion.TLS12, + trustAnchors = parsed.store, + ) + else: + conn.tlsStream = newTLSClientAsyncStream( + conn.baseReader, + conn.baseWriter, + serverName, + flags = flags, + minVersion = TLSVersion.TLS12, + maxVersion = TLSVersion.TLS12, + ) + if config.sslNegotiation == sslnDirect: + # Direct SSL requires ALPN "postgresql" (PostgreSQL 17+). + # Must set ALPN then re-reset the client context so the handshake + # state machine picks up the protocol names for the ClientHello. + conn.alpnNames = allocCStringArray(["postgresql"]) + let names = conn.alpnNames + {. + emit: [ + conn.tlsStream[].ccontext.eng, ".protocol_names = (const char **)", names, ";" + ] + .} + conn.tlsStream.ccontext.eng.protocolNamesNum = 1 + discard sslClientReset(conn.tlsStream.ccontext, nil, 0) + installX509Capture( + conn.x509Capture, conn.tlsStream.ccontext.eng, addr conn.serverCertDer + ) + await conn.tlsStream.handshake() + conn.reader = conn.tlsStream.reader + conn.writer = conn.tlsStream.writer + conn.sslEnabled = true + elif hasAsyncDispatch: + when defined(ssl): + let verifyMode = + case config.sslMode + of sslVerifyCa, sslVerifyFull: SslCVerifyMode.CVerifyPeer + else: SslCVerifyMode.CVerifyNone + + var ctx: SslContext + var tmpPath: string + if config.sslRootCert.len > 0: + let (tmpFile, tp) = createTempFile("pg_ca_", ".pem") + tmpPath = tp + try: + tmpFile.write(config.sslRootCert) + tmpFile.close() + ctx = newContext(verifyMode = verifyMode, caFile = tmpPath) + except CatchableError: + removeFile(tmpPath) + raise + else: + ctx = newContext(verifyMode = verifyMode) + + if config.sslNegotiation == sslnDirect: + # Direct SSL requires ALPN "postgresql" (PostgreSQL 17+) + const alpnProto = "\x0apostgresql" + discard + SSL_CTX_set_alpn_protos(ctx.context, alpnProto.cstring, cuint(alpnProto.len)) + + try: + let hostname = if config.sslMode == sslVerifyFull: config.host else: "" + wrapConnectedSocket(ctx, conn.socket, handshakeAsClient, hostname) + conn.sslEnabled = true + # Extract server certificate DER for SCRAM-SHA-256-PLUS channel binding + let peerCert = SSL_get_peer_certificate(conn.socket.sslHandle) + if peerCert != nil: + try: + let derStr = i2d_X509(peerCert) + if derStr.len > 0: + conn.serverCertDer = newSeq[byte](derStr.len) + for i in 0 ..< derStr.len: + conn.serverCertDer[i] = byte(derStr[i]) + finally: + X509_free(peerCert) + finally: + if tmpPath.len > 0: + removeFile(tmpPath) + else: + raise + newException(PgConnectionError, "SSL support requires compiling with -d:ssl") + proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} = ## Send SSLRequest and negotiate TLS if server accepts. let sslReq = encodeSSLRequest() @@ -760,93 +882,7 @@ proc negotiateSSL(conn: PgConnection, config: ConnConfig) {.async.} = case respChar of 'S': - when hasChronos: - conn.baseReader = newAsyncStreamReader(conn.transport) - conn.baseWriter = newAsyncStreamWriter(conn.transport) - - let flags = - case config.sslMode - of sslVerifyFull: - {} - of sslVerifyCa: - {TLSFlags.NoVerifyServerName} - else: - {TLSFlags.NoVerifyHost, TLSFlags.NoVerifyServerName} - - let serverName = if config.sslMode == sslVerifyFull: config.host else: "" - - if config.sslRootCert.len > 0: - let parsed = parseTrustAnchors(config.sslRootCert) - conn.trustAnchorBufs = parsed.backing - # Must outlive TLS session (see parseTrustAnchors doc) - conn.tlsStream = newTLSClientAsyncStream( - conn.baseReader, - conn.baseWriter, - serverName, - flags = flags, - minVersion = TLSVersion.TLS12, - maxVersion = TLSVersion.TLS12, - trustAnchors = parsed.store, - ) - else: - conn.tlsStream = newTLSClientAsyncStream( - conn.baseReader, - conn.baseWriter, - serverName, - flags = flags, - minVersion = TLSVersion.TLS12, - maxVersion = TLSVersion.TLS12, - ) - installX509Capture( - conn.x509Capture, conn.tlsStream.ccontext.eng, addr conn.serverCertDer - ) - await conn.tlsStream.handshake() - conn.reader = conn.tlsStream.reader - conn.writer = conn.tlsStream.writer - conn.sslEnabled = true - elif hasAsyncDispatch: - when defined(ssl): - let verifyMode = - case config.sslMode - of sslVerifyCa, sslVerifyFull: SslCVerifyMode.CVerifyPeer - else: SslCVerifyMode.CVerifyNone - - var ctx: SslContext - var tmpPath: string - if config.sslRootCert.len > 0: - let (tmpFile, tp) = createTempFile("pg_ca_", ".pem") - tmpPath = tp - try: - tmpFile.write(config.sslRootCert) - tmpFile.close() - ctx = newContext(verifyMode = verifyMode, caFile = tmpPath) - except: - removeFile(tmpPath) - raise - else: - ctx = newContext(verifyMode = verifyMode) - - try: - let hostname = if config.sslMode == sslVerifyFull: config.host else: "" - wrapConnectedSocket(ctx, conn.socket, handshakeAsClient, hostname) - conn.sslEnabled = true - # Extract server certificate DER for SCRAM-SHA-256-PLUS channel binding - let peerCert = SSL_get_peer_certificate(conn.socket.sslHandle) - if peerCert != nil: - try: - let derStr = i2d_X509(peerCert) - if derStr.len > 0: - conn.serverCertDer = newSeq[byte](derStr.len) - for i in 0 ..< derStr.len: - conn.serverCertDer[i] = byte(derStr[i]) - finally: - X509_free(peerCert) - finally: - if tmpPath.len > 0: - removeFile(tmpPath) - else: - raise - newException(PgConnectionError, "SSL support requires compiling with -d:ssl") + await conn.setupTls(config) of 'N': if config.sslMode in {sslRequire, sslVerifyCa, sslVerifyFull}: raise newException(PgConnectionError, "Server does not support SSL") @@ -957,6 +993,19 @@ proc connectToHost( ): Future[PgConnection] {.async.} = ## Connect to a single PostgreSQL host. Internal helper for multi-host connect. + if config.sslNegotiation == sslnDirect: + if config.sslMode == sslDisable: + raise newException( + PgConnectionError, "sslnegotiation=direct is incompatible with sslmode=disable" + ) + if config.sslMode in {sslAllow, sslPrefer}: + raise newException( + PgConnectionError, + "sslnegotiation=direct is incompatible with sslmode=" & + (if config.sslMode == sslAllow: "allow" else: "prefer") & + " (direct SSL does not support fallback)", + ) + if config.sslMode == sslAllow: # sslAllow: try plaintext first, then fall back to SSL. var plainConfig = config @@ -1059,7 +1108,10 @@ proc connectToHost( try: # SSL negotiation (before StartupMessage) — skip for Unix sockets if config.sslMode != sslDisable and not isUnix: - await negotiateSSL(conn, config) + if config.sslNegotiation == sslnDirect: + await conn.setupTls(config) + else: + await negotiateSSL(conn, config) when hasChronos: # If SSL was not established, create plain streams @@ -1739,6 +1791,15 @@ proc parseSslMode(s: string): SslMode = else: raise newException(PgError, "Invalid sslmode: " & s) +proc parseSslNegotiation(s: string): SslNegotiation = + case s + of "postgres": + sslnPostgres + of "direct": + sslnDirect + else: + raise newException(PgError, "Invalid sslnegotiation: " & s) + proc parseTargetSessionAttrs(s: string): TargetSessionAttrs = case s of "any": @@ -1779,6 +1840,8 @@ proc applyParam(result: var ConnConfig, key, val: string) = result.password = val of "sslmode": result.sslMode = parseSslMode(val) + of "sslnegotiation": + result.sslNegotiation = parseSslNegotiation(val) of "application_name": result.applicationName = val of "connect_timeout": diff --git a/tests/test_dsn.nim b/tests/test_dsn.nim index 210d47a..2a9a928 100644 --- a/tests/test_dsn.nim +++ b/tests/test_dsn.nim @@ -94,6 +94,21 @@ suite "parseDsn": else: discard + test "query param sslnegotiation": + for mode in ["postgres", "direct"]: + let cfg = parseDsn("postgresql://host/db?sslnegotiation=" & mode) + case mode + of "postgres": + check cfg.sslNegotiation == sslnPostgres + of "direct": + check cfg.sslNegotiation == sslnDirect + else: + discard + + test "error: invalid sslnegotiation": + expect PgError: + discard parseDsn("postgresql://host/db?sslnegotiation=bogus") + test "query param application_name": let cfg = parseDsn("postgresql://host/db?application_name=myapp") check cfg.applicationName == "myapp" @@ -410,6 +425,14 @@ suite "parseDsn keyword=value": let cfg = parseDsn("host=h dbname=d sslmode=require") check cfg.sslMode == sslRequire + test "sslnegotiation parameter": + let cfg = parseDsn("host=h dbname=d sslnegotiation=direct") + check cfg.sslNegotiation == sslnDirect + + test "error: invalid sslnegotiation (key-value)": + expect PgError: + discard parseDsn("host=h sslnegotiation=bogus") + test "connect_timeout parameter": let cfg = parseDsn("host=h connect_timeout=30") check cfg.connectTimeout == seconds(30) diff --git a/tests/test_e2e.nim b/tests/test_e2e.nim index 80f705b..decdd03 100644 --- a/tests/test_e2e.nim +++ b/tests/test_e2e.nim @@ -333,6 +333,45 @@ suite "E2E: SSL Connection": waitFor t() +suite "E2E: Direct SSL Negotiation": + test "sslnDirect + sslRequire connects with SSL": + proc t() {.async.} = + let conn = await connect( + ConnConfig( + host: PgHost, + port: PgPort, + user: PgUser, + password: PgPassword, + database: PgDatabase, + sslMode: sslRequire, + sslNegotiation: sslnDirect, + ) + ) + doAssert conn.state == csReady + doAssert conn.sslEnabled == true + await conn.close() + + waitFor t() + + test "sslnDirect query over SSL connection": + proc t() {.async.} = + let conn = await connect( + ConnConfig( + host: PgHost, + port: PgPort, + user: PgUser, + password: PgPassword, + database: PgDatabase, + sslMode: sslRequire, + sslNegotiation: sslnDirect, + ) + ) + let results = await conn.simpleQuery("SELECT 42 AS answer") + doAssert results[0].rows[0][0].get().toString() == "42" + await conn.close() + + waitFor t() + suite "E2E: SSL Verification": test "sslVerifyCa connects with CA verification": proc t() {.async.} = diff --git a/tests/test_ssl.nim b/tests/test_ssl.nim index 4c072d6..90fcc93 100644 --- a/tests/test_ssl.nim +++ b/tests/test_ssl.nim @@ -545,3 +545,268 @@ suite "SSL negotiation - sslDisable": check firstMsgVersion == 196608'i32 check connState == csReady check connSslEnabled == false + +suite "Direct SSL negotiation (sslnDirect)": + test "sslnDirect sends TLS ClientHello as first bytes (no SSLRequest)": + when defined(ssl) or hasChronos: + var firstByte: byte = 0 + var gotData = false + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + # Read first 5 bytes (TLS record header: type + version + length) + let data = await readN(st, 5) + firstByte = data[0] + gotData = true + # Read remaining TLS ClientHello body to avoid broken pipe + let bodyLen = int(data[3]) shl 8 or int(data[4]) + if bodyLen > 0 and bodyLen < 65536: + try: + discard await readN(st, bodyLen) + except CatchableError: + discard + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslRequire, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except CatchableError: + raised = true + + await serverFut + await closeServer(ms) + + waitFor testBody() + # TLS record content type 0x16 = Handshake (ClientHello) + check gotData + check firstByte == 0x16'u8 + # Connection will fail because mock server doesn't do TLS + check raised + else: + skip() + + test "sslnDirect + sslVerifyCa sends TLS ClientHello as first bytes": + when defined(ssl) or hasChronos: + var firstByte: byte = 0 + var gotData = false + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + let data = await readN(st, 5) + firstByte = data[0] + gotData = true + let bodyLen = int(data[3]) shl 8 or int(data[4]) + if bodyLen > 0 and bodyLen < 65536: + try: + discard await readN(st, bodyLen) + except CatchableError: + discard + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslVerifyCa, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except CatchableError: + raised = true + + await serverFut + await closeServer(ms) + + waitFor testBody() + check gotData + check firstByte == 0x16'u8 + check raised + else: + skip() + + test "sslnDirect + sslVerifyFull sends TLS ClientHello as first bytes": + when defined(ssl) or hasChronos: + var firstByte: byte = 0 + var gotData = false + var raised = false + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + let data = await readN(st, 5) + firstByte = data[0] + gotData = true + let bodyLen = int(data[3]) shl 8 or int(data[4]) + if bodyLen > 0 and bodyLen < 65536: + try: + discard await readN(st, bodyLen) + except CatchableError: + discard + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslVerifyFull, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except CatchableError: + raised = true + + await serverFut + await closeServer(ms) + + waitFor testBody() + check gotData + check firstByte == 0x16'u8 + check raised + else: + skip() + + test "sslnDirect + sslDisable raises error": + var raised = false + + proc testBody() {.async.} = + let config = ConnConfig( + host: "127.0.0.1", + port: 5432, + user: "test", + database: "test", + sslMode: sslDisable, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError: + raised = true + + waitFor testBody() + check raised + + test "sslnDirect + sslAllow raises error": + var raised = false + + proc testBody() {.async.} = + let config = ConnConfig( + host: "127.0.0.1", + port: 5432, + user: "test", + database: "test", + sslMode: sslAllow, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError: + raised = true + + waitFor testBody() + check raised + + test "sslnDirect + sslPrefer raises error": + var raised = false + + proc testBody() {.async.} = + let config = ConnConfig( + host: "127.0.0.1", + port: 5432, + user: "test", + database: "test", + sslMode: sslPrefer, + sslNegotiation: sslnDirect, + ) + + try: + let conn = await connect(config) + await conn.close() + except PgError: + raised = true + + waitFor testBody() + check raised + + test "sslnPostgres is default and sends SSLRequest as before": + var sslReqMagic: int32 = 0 + + proc testBody() {.async.} = + let ms = startMockServer() + + proc serverHandler() {.async.} = + let st = await ms.accept() + try: + let sslReq = await readN(st, 8) + sslReqMagic = decodeInt32(sslReq, 4) + await sendBytes(st, @[byte('N')]) + await drainStartupMessage(st) + await sendAuthOkAndReady(st) + await drainUntilClose(st) + except CatchableError: + discard + await closeClient(st) + + let serverFut = serverHandler() + + let config = ConnConfig( + host: "127.0.0.1", + port: ms.port, + user: "test", + database: "test", + sslMode: sslPrefer, # sslNegotiation defaults to sslnPostgres + ) + + let conn = await connect(config) + await conn.close() + + await serverFut + await closeServer(ms) + + waitFor testBody() + check sslReqMagic == 80877103'i32