Skip to content

Commit 2f89da5

Browse files
committed
ssl: look for ~/.postgresql/root.crt
If the sslrootcert option hasn't been specified, use ~/.postgresql/root.crt if it exists (or %APPDATA%\postgresql\root.crt on Windows). This is what libpq does. See - https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-CONNECT-SSLROOTCERT - https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES
1 parent 300ec9b commit 2f89da5

2 files changed

Lines changed: 133 additions & 7 deletions

File tree

ssl.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,39 @@ import (
88
"os"
99
"os/user"
1010
"path/filepath"
11+
"runtime"
1112
"strings"
1213
)
1314

15+
var testUser *user.User // for replacing user.Current() in tests
16+
1417
// ssl generates a function to upgrade a net.Conn based on the "sslmode" and
1518
// related settings. The function is nil when no upgrade should take place.
1619
func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
20+
var usr *user.User
21+
// usr.Current() might fail when cross-compiling. We have to ignore the
22+
// error and continue without home directory defaults, since we wouldn't
23+
// know from where to load certificates.
24+
if testUser != nil {
25+
usr = new(user.User)
26+
*usr = *testUser
27+
} else {
28+
usr, _ = user.Current()
29+
}
30+
1731
verifyCaOnly := false
1832
tlsConf := tls.Config{}
33+
34+
if usr != nil && o["sslmode"] != "disable" && o["sslrootcert"] == "" {
35+
// https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNECT-SSLROOTCERT
36+
// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBQ-SSL-CERTIFICATES
37+
if runtime.GOOS == "windows" {
38+
o["sslrootcert"] = filepath.Join(usr.HomeDir, "AppData", "Roaming", "postgresql", "root.crt")
39+
} else {
40+
o["sslrootcert"] = filepath.Join(usr.HomeDir, ".postgresql", "root.crt")
41+
}
42+
}
43+
1944
switch mode := o["sslmode"]; mode {
2045
// "require" is the default.
2146
case "", "require":
@@ -61,7 +86,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
6186
tlsConf.ServerName = o["host"]
6287
}
6388

64-
err := sslClientCertificates(&tlsConf, o)
89+
err := sslClientCertificates(&tlsConf, o, usr)
6590
if err != nil {
6691
return nil, err
6792
}
@@ -93,7 +118,7 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) {
93118
// "sslkey" settings, or if they aren't set, from the .postgresql directory
94119
// in the user's home directory. The configured files must exist and have
95120
// the correct permissions.
96-
func sslClientCertificates(tlsConf *tls.Config, o values) error {
121+
func sslClientCertificates(tlsConf *tls.Config, o values, user *user.User) error {
97122
sslinline := o["sslinline"]
98123
if sslinline == "true" {
99124
cert, err := tls.X509KeyPair([]byte(o["sslcert"]), []byte(o["sslkey"]))
@@ -104,11 +129,6 @@ func sslClientCertificates(tlsConf *tls.Config, o values) error {
104129
return nil
105130
}
106131

107-
// user.Current() might fail when cross-compiling. We have to ignore the
108-
// error and continue without home directory defaults, since we wouldn't
109-
// know from where to load them.
110-
user, _ := user.Current()
111-
112132
// In libpq, the client certificate is only loaded if the setting is not blank.
113133
//
114134
// https://github.com/postgres/postgres/blob/REL9_6_2/src/interfaces/libpq/fe-secure-openssl.c#L1036-L1037

ssl_test.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package pq
44

55
import (
66
"bytes"
7+
"context"
78
_ "crypto/sha256"
89
"crypto/tls"
910
"crypto/x509"
@@ -13,6 +14,7 @@ import (
1314
"io"
1415
"net"
1516
"os"
17+
"os/user"
1618
"path/filepath"
1719
"strings"
1820
"testing"
@@ -377,6 +379,110 @@ func TestSNISupport(t *testing.T) {
377379
}
378380
}
379381

382+
func TestDefaultRootCert(t *testing.T) {
383+
homeDir, err := setupHomeWithRootCRT(t)
384+
if err != nil {
385+
t.Fatalf("setup mock $HOME: %v", err)
386+
}
387+
388+
testUser = &user.User{
389+
// no leading slash to we can be sure that $HOME/.postgresql/root.crt
390+
// does not exist
391+
HomeDir: homeDir,
392+
}
393+
defer func() { testUser = nil }()
394+
395+
o := values{"sslmode": "verify-ca"}
396+
397+
upgrade, err := ssl(o)
398+
if err != nil {
399+
t.Fatal(err)
400+
}
401+
402+
addr, handshakeErr := mockTLSServer(t, "certs/server.crt", "certs/server.key")
403+
404+
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
405+
defer cancel()
406+
407+
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr)
408+
if err != nil {
409+
t.Fatal(err)
410+
}
411+
defer conn.Close()
412+
413+
if _, err := upgrade(conn); err != nil {
414+
t.Fatal(err)
415+
}
416+
417+
select {
418+
case <-ctx.Done():
419+
t.Fatal(ctx.Err())
420+
case err := <-handshakeErr:
421+
if err != nil {
422+
t.Fatal(err)
423+
}
424+
}
425+
}
426+
427+
func setupHomeWithRootCRT(t *testing.T) (string, error) {
428+
t.Helper()
429+
430+
homeDir, err := os.MkdirTemp("", "lib-pg-ssl-test-*")
431+
if err != nil {
432+
return "", err
433+
}
434+
t.Cleanup(func() { os.RemoveAll(homeDir) })
435+
436+
err = os.MkdirAll(filepath.Join(homeDir, ".postgresql"), 0700)
437+
if err != nil {
438+
return "", err
439+
}
440+
441+
b, err := os.ReadFile("certs/root.crt")
442+
if err != nil {
443+
return "", err
444+
}
445+
446+
err = os.WriteFile(filepath.Join(homeDir, ".postgresql", "root.crt"), b, 0600)
447+
if err != nil {
448+
return "", err
449+
}
450+
451+
return homeDir, nil
452+
}
453+
454+
func mockTLSServer(t *testing.T, certFile, keyFile string) (string, chan error) {
455+
t.Helper()
456+
457+
ln, err := net.Listen("tcp", "localhost:0")
458+
if err != nil {
459+
t.Fatal(err)
460+
}
461+
t.Cleanup(func() { ln.Close() })
462+
463+
serverCert, err := tls.LoadX509KeyPair(certFile, keyFile)
464+
if err != nil {
465+
t.Fatal(err)
466+
}
467+
serverConf := &tls.Config{
468+
Certificates: []tls.Certificate{serverCert},
469+
}
470+
471+
handshakeErr := make(chan error, 1)
472+
go func() {
473+
conn, err := ln.Accept()
474+
if err != nil {
475+
t.Logf("mockTLSServer: cannot Accept: %v", err)
476+
return
477+
}
478+
defer conn.Close()
479+
480+
handshakeErr <- tls.Server(conn, serverConf).Handshake()
481+
}()
482+
483+
return ln.Addr().String(), handshakeErr
484+
}
485+
380486
// Make a postgres mock server to test TLS SNI
381487
//
382488
// Accepts postgres StartupMessage and handles TLS clientHello, then closes a connection.

0 commit comments

Comments
 (0)