diff --git a/main_test.go b/main_test.go index 3f391ac..07cdf97 100644 --- a/main_test.go +++ b/main_test.go @@ -20,7 +20,13 @@ const ( postgresImage = "docker.io/postgres:16.0-alpine" ) -var postgresDataSource string +// Connection strings for the three databases created in [TestMain] +// These are used in other tests +var ( + postgresDataSource string + postgresDataSourceRR1 string + postgresDataSourceRR2 string +) func withSchemaSQL() testcontainers.CustomizeRequestOption { return func(req *testcontainers.GenericContainerRequest) error { @@ -34,24 +40,74 @@ func withSchemaSQL() testcontainers.CustomizeRequestOption { } func TestMain(m *testing.M) { - var cleanups []func() - cleanup := func(fn func()) { - cleanups = append(cleanups, fn) - } - fatal := func(args ...any) { - fmt.Fprintln(os.Stderr, args...) - for _, fn := range cleanups { - fn() + ctx := context.Background() + cleanups, err := createPostgresContainers(ctx) + if err != nil { + fmt.Printf("did not create postgres containers: %v\n", err) + for _, cleanup := range cleanups { + if cleanup != nil { + cleanup() + } } os.Exit(1) } - ctx := context.Background() + // nolint:gocritic // The docs for Run specify that the returned int is to be passed to os.Exit + retCode := m.Run() + for _, cleanup := range cleanups { + if cleanup != nil { + cleanup() + } + } + os.Exit(retCode) +} + +// createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and +// the last 2 are intended to mimic read replicas. +// A []func() is returned to cleanups. +// Package-level connection strings are set. +func createPostgresContainers(ctx context.Context) ([]func(), error) { + // Database connection strings are the same, except the port + connString := func(port string) string { + return fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, port, dbName) + } + + var cleanups []func() + + // Create the master database + cM, mpM, err := createPostgresContainer(ctx, "init-db.sh") + cleanups = append(cleanups, cM) + if err != nil { + return cleanups, err + } + postgresDataSource = connString(mpM) + + // Create first read replica + cRR1, mpRR1, err := createPostgresContainer(ctx, "init-db-rr.sh") + cleanups = append(cleanups, cRR1) + if err != nil { + return cleanups, err + } + postgresDataSourceRR1 = connString(mpRR1) + + // Create second read replica + cRR2, mpRR2, err := createPostgresContainer(ctx, "init-db-rr.sh") + cleanups = append(cleanups, cRR2) + if err != nil { + return cleanups, err + } + postgresDataSourceRR2 = connString(mpRR2) + + return cleanups, nil +} + +// createPostgresContainer creates a single postgres container on the specified port +func createPostgresContainer(ctx context.Context, initFilename string) (func(), string, error) { postgresContainer, err := postgres.Run(ctx, postgresImage, postgres.WithDatabase(dbName), postgres.WithUsername(dbUser), postgres.WithPassword(dbPassword), - postgres.WithInitScripts(filepath.Join("testdata", "init-db.sh")), + postgres.WithInitScripts(filepath.Join("testdata", initFilename)), withSchemaSQL(), testcontainers.WithWaitStrategy( wait.ForLog("database system is ready to accept connections"). @@ -60,28 +116,27 @@ func TestMain(m *testing.M) { ), ) if err != nil { - fatal("error creating postgres container:", err) + return nil, "", fmt.Errorf("error creating postgres container: %w", err) } - cleanup(func() { + + cleanup := func() { if err := postgresContainer.Terminate(ctx); err != nil { fmt.Fprintln(os.Stderr, "error terminating postgres:", err) } - }) + } postgresState, err := postgresContainer.State(ctx) if err != nil { - fatal(err) + return cleanup, "", fmt.Errorf("checking container state: %w", err) } if !postgresState.Running { - fatal("Postgres status:", postgresState.Status) + return cleanup, "", fmt.Errorf("Postgres status %q is not \"running\"", postgresState.Status) } - postgresPort, err := postgresContainer.MappedPort(ctx, "5432/tcp") + mp, err := postgresContainer.MappedPort(ctx, "5432/tcp") if err != nil { - fatal(err) + return cleanup, "", fmt.Errorf("mapped port 5432/tcp does not seem to be available: %w", err) } - postgresDataSource = fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, postgresPort.Port(), dbName) - - os.Exit(m.Run()) + return cleanup, mp.Port(), nil } diff --git a/read_replica_set.go b/read_replica_set.go new file mode 100644 index 0000000..f966df5 --- /dev/null +++ b/read_replica_set.go @@ -0,0 +1,169 @@ +package sequel + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + + "github.com/go-sqlx/sqlx" +) + +var ErrNoReadReplicaConnections = errors.New("no read replica connections available") + +type readReplica struct { + db *sqlx.DB + + next *readReplica +} + +// ReadReplicaSet contains a set of DB connections. It is intended to give fair round robin access +// through a circular singly linked list. +// +// Replicas are appended after the current one. The intended use is to build the replica set before +// querying, but all operations are concurrent-safe. +type ReadReplicaSet struct { + m sync.Mutex + + current *readReplica +} + +// add adds a DB to the collection of read replicas +func (rr *ReadReplicaSet) add(db *sqlx.DB) { + rr.m.Lock() + defer rr.m.Unlock() + + r := readReplica{ + db: db, + } + + // Empty ring, add new DB + if rr.current == nil { + r.next = &r + + rr.current = &r + + return + } + + // Insert new db after current + n := rr.current.next + r.next = n + rr.current.next = &r +} + +// next returns the current DB. The current pointer is advanced. +func (rr *ReadReplicaSet) next() (*sqlx.DB, error) { + rr.m.Lock() + defer rr.m.Unlock() + + if rr.current == nil { + return nil, ErrNoReadReplicaConnections + } + + c := rr.current + + rr.current = rr.current.next + + return c.db, nil +} + +// Close closes all read replica connections +func (rr *ReadReplicaSet) Close() { + if rr == nil { + return + } + + rr.m.Lock() + defer rr.m.Unlock() + + // If this instance has no replicas, current would be nil + if rr.current == nil { + return + } + + first := rr.current + for c := first; ; c = c.next { + c.db.Close() + + if c.next == first { + break + } + } +} + +// Query executes a query against a read replica. Queries that are not SELECTs may not work. +// The args are for any placeholder parameters in the query. +func (rr *ReadReplicaSet) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if rr == nil { + return nil, ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryContext(ctx, query, args...) +} + +// QueryRow executes a query that is expected to return at most one row against a read replica. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (rr *ReadReplicaSet) QueryRow(ctx context.Context, query string, args ...any) (*sql.Row, error) { + if rr == nil { + return nil, ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryRowContext(ctx, query, args...), nil +} + +// Get populates the given model for the result of the given select query against a read replica. +func (rr *ReadReplicaSet) Get(ctx context.Context, dest Model, query string, args ...any) error { + if rr == nil { + return ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.GetContext(ctx, dest, query, args...) +} + +// GetAll populates the given destination with all the results of the given +// select query (from a read replica). The method will fail if the destination is not a pointer to a +// slice. +func (rr *ReadReplicaSet) GetAll(ctx context.Context, dest any, query string, args ...any) error { + if rr == nil { + return ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return err + } + + defer rows.Close() + + if err := rows.Err(); err != nil { + return err + } + return sqlx.StructScan(rows, dest) +} diff --git a/read_replica_set_test.go b/read_replica_set_test.go new file mode 100644 index 0000000..7314362 --- /dev/null +++ b/read_replica_set_test.go @@ -0,0 +1,145 @@ +package sequel + +import ( + "errors" + "slices" + "testing" + + "github.com/go-sqlx/sqlx" +) + +func Test_readReplicas_add(t *testing.T) { + tests := []struct { + name string + rrs []*sqlx.DB + }{ + { + name: "add single", + rrs: []*sqlx.DB{sqlx.NewDb(nil, "fakeDriver")}, + }, + { + name: "add multiple", + rrs: []*sqlx.DB{sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create ReadReplicaSet + rr := &ReadReplicaSet{} + + // Add connections to read replicas + for _, c := range tt.rrs { + rr.add(c) + } + + // Confirm they are all added + if !rrContains(t, rr, tt.rrs) { + t.Fatal("ReadReplicaSet does not contain all added conns") + } + }) + } +} + +func Test_readReplicas_next(t *testing.T) { + newRRs := func(rrs ...*sqlx.DB) *ReadReplicaSet { + rr := &ReadReplicaSet{} + for _, c := range rrs { + rr.add(c) + } + + return rr + } + + c1, c2, c3 := sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver") + + emptyRR := newRRs() + singleRR := newRRs(c1) + threeRR := newRRs(c1, c2, c3) + + tests := []struct { + name string + rr *ReadReplicaSet + want *sqlx.DB + wantErr error + }{ + { + name: "empty ReadReplicaSet", + rr: emptyRR, + want: nil, + wantErr: ErrNoReadReplicaConnections, + }, + { + name: "one read replica", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "one read replica (2 calls)", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "one read replica (3 calls)", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "three read replicas", + rr: threeRR, + want: c1, + wantErr: nil, + }, + { + name: "three read replicas (2 calls)", + rr: threeRR, + want: c3, + wantErr: nil, + }, + { + name: "three read replicas (3 calls)", + rr: threeRR, + want: c2, + wantErr: nil, + }, + { + name: "three read replicas (4 calls)", + rr: threeRR, + want: c1, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.rr.next() + if !errors.Is(err, tt.wantErr) { + t.Fatalf("got err %v, expecting %v", err, tt.wantErr) + } + + if got != tt.want { + t.Fatalf("got rr %p, want %p", got, tt.want) + } + }) + } +} + +// rrContains is a test helper to see if a [ReadReplicaSet] contains all the connections passed in +func rrContains(t *testing.T, rr *ReadReplicaSet, conns []*sqlx.DB) bool { + t.Helper() + + var findCnt int + head := rr.current + for c := head.next; ; c = c.next { + if slices.Contains(conns, c.db) { + findCnt++ + } + + if c == head { + break + } + } + + return findCnt == len(conns) +} diff --git a/sequel.go b/sequel.go index e9c68b7..c5d1aaf 100644 --- a/sequel.go +++ b/sequel.go @@ -24,6 +24,7 @@ const MaxOpenConnections = 100 // operations on a Model. type DB struct { db *sqlx.DB + rrs *ReadReplicaSet clock clock.Clock doRebindModel bool driverName string @@ -34,6 +35,7 @@ type options struct { DriverName string RebindModel bool MaxOpenConnections int + ReadReplicaDSNs []string } func newOptions(driverName string) *options { @@ -88,6 +90,13 @@ func WithMaxOpenConnections(n int) Option { } } +// WithReadReplica adds a read replica DSN to the list of read replicas. +func WithReadReplica(dsn string) Option { + return func(o *options) { + o.ReadReplicaDSNs = append(o.ReadReplicaDSNs, dsn) + } +} + // New creates a new DB. It will fail if it cannot ping it. func New(dataSourceName string, opts ...Option) (*DB, error) { options := newOptions("pgx/v5").apply(opts) @@ -99,8 +108,19 @@ func New(dataSourceName string, opts ...Option) (*DB, error) { } db.SetMaxOpenConns(options.MaxOpenConnections) + rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) + if err != nil { + // Close any open db connections + _ = db.Close() + if rrs != nil { + rrs.Close() + } + return nil, fmt.Errorf("error creating read replica set: %w", err) + } + return &DB{ db: db, + rrs: rrs, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, @@ -120,14 +140,50 @@ func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { } dbx.SetMaxOpenConns(options.MaxOpenConnections) + rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) + if err != nil { + // Close any open db connections + _ = db.Close() + if rrs != nil { + rrs.Close() + } + return nil, fmt.Errorf("error creating read replica set: %w", err) + } + return &DB{ db: dbx, + rrs: rrs, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, }, nil } +// createReadReplicaSet creates a new ReadReplicaSet from the given list of DSN strings. +// It also connects and pings the DB to ensure connectivity. +func createReadReplicaSet(driverName string, dsns []string, maxConns int) (*ReadReplicaSet, error) { + if len(dsns) == 0 { + return nil, nil //nolint:nilnil // if there are no dsns, then this should not create anything + } + + var rrs ReadReplicaSet + for _, dsn := range dsns { + // Connect to DB and ping + rr, err := sqlx.Connect(driverName, dsn) + if err != nil { + return &rrs, fmt.Errorf("error connecting to the read replica database: %w", err) + } + + // Set max open connections + rr.SetMaxOpenConns(maxConns) + + // Add read replica to ReadReplicaSet + rrs.add(rr) + } + + return &rrs, nil +} + type dbKey struct{} // NewContext returns a new context with the given DB. @@ -180,6 +236,7 @@ func RowsAffected(res sql.Result, n int64) error { // Close closes the database and prevents new queries from starting. Close then // waits for all queries that have started processing on the server to finish. func (d *DB) Close() error { + d.rrs.Close() return d.db.Close() } @@ -193,6 +250,11 @@ func (d *DB) DB() *sql.DB { return d.db.DB } +// ReadReplicaSet returns the read replica set. +func (d *DB) ReadReplicaSet() *ReadReplicaSet { + return d.rrs +} + // Rebind transforms a query from `?` to the DB driver's bind type. func (d *DB) Rebind(query string) string { return d.db.Rebind(query) @@ -228,14 +290,14 @@ func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, e return d.db.ExecContext(ctx, query, args...) } -// Query executes a query that returns rows, typically a SELECT. The query is +// RebindQuery executes a query that returns rows, typically a SELECT. The query is // rebound from `?` to the DB driver's bind type. The args are for any // placeholder parameters in the query. func (d *DB) RebindQuery(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return d.db.QueryContext(ctx, d.db.Rebind(query), args...) } -// QueryRow executes a query that is expected to return at most one row. The +// RebindQueryRow executes a query that is expected to return at most one row. The // query is rebound from `?` to the DB driver's bind type. QueryRowContext // always returns a non-nil value. Errors are deferred until Row's Scan method // is called. @@ -247,7 +309,7 @@ func (d *DB) RebindQueryRow(ctx context.Context, query string, args ...any) *sql return d.db.QueryRowContext(ctx, d.db.Rebind(query), args...) } -// Exec executes a query without returning any rows. The query is rebound from +// RebindExec executes a query without returning any rows. The query is rebound from // `?` to the DB driver's bind type. The args are for any placeholder parameters // in the query. func (d *DB) RebindExec(ctx context.Context, query string, args ...any) (sql.Result, error) { @@ -279,6 +341,9 @@ func (d *DB) GetAll(ctx context.Context, dest any, query string, args ...any) er if err != nil { return err } + + defer rows.Close() + if err := rows.Err(); err != nil { return err } @@ -500,14 +565,14 @@ func (t *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Re return t.tx.ExecContext(ctx, query, args...) } -// Query executes a query that returns rows, typically a SELECT. The query is +// RebindQuery executes a query that returns rows, typically a SELECT. The query is // rebound from `?` to the DB driver's bind type. The args are for any // placeholder parameters in the query. func (t *Tx) RebindQuery(query string, args ...any) (*sql.Rows, error) { return t.tx.Query(t.tx.Rebind(query), args...) } -// QueryRow executes a query that is expected to return at most one row. The +// RebindQueryRow executes a query that is expected to return at most one row. The // query is rebound from `?` to the DB driver's bind type. QueryRowContext // always returns a non-nil value. Errors are deferred until Row's Scan method // is called. @@ -519,7 +584,7 @@ func (t *Tx) RebindQueryRow(query string, args ...any) *sql.Row { return t.tx.QueryRow(t.tx.Rebind(query), args...) } -// Exec executes a query without returning any rows. The query is rebound from +// RebindExec executes a query without returning any rows. The query is rebound from // `?` to the DB driver's bind type. The args are for any placeholder parameters // in the query. func (t *Tx) RebindExec(query string, args ...any) (sql.Result, error) { diff --git a/sequel_test.go b/sequel_test.go index f1d912f..1f6ab40 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -240,12 +240,20 @@ func TestIsUniqueViolation(t *testing.T) { } func TestDBQueries(t *testing.T) { + // Create single DB source db, err := New(postgresDataSource) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) + // Create DB source with two read replicas + dbWithRRs, err := New(postgresDataSource, WithReadReplica(postgresDataSourceRR1), WithReadReplica(postgresDataSourceRR2)) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, dbWithRRs.Close()) + }) + p1 := &personModel{ Name: "Lucky Luke", Email: NullString("lucky@example.com"), @@ -323,6 +331,27 @@ func TestDBQueries(t *testing.T) { assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes }) + t.Run("query RR (no RRs)", func(t *testing.T) { + //nolint:rowserrcheck // rows is expected to be nil, err to be non-nil + rows, err := db.ReadReplicaSet().Query(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.ErrorIs(t, err, ErrNoReadReplicaConnections) + assert.Nil(t, rows) + }) + + t.Run("query RR", func(t *testing.T) { + rows, err := dbWithRRs.ReadReplicaSet().Query(ctx, "SELECT * FROM person_test WHERE email = $1", "read1@replica.com") + if err != nil { + t.Fatalf("expecting no error, got %v", err) + } + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assert.Equal(t, p.Email.String, "read1@replica.com") + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + }) + t.Run("queryRow", func(t *testing.T) { var p personModel row := db.QueryRow(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) @@ -331,6 +360,15 @@ func TestDBQueries(t *testing.T) { assertEqualPerson(t, p1, &p) }) + t.Run("queryRow RR", func(t *testing.T) { + var p personModel + row, err := dbWithRRs.ReadReplicaSet().QueryRow(ctx, "SELECT * FROM person_test WHERE email = $1", "read2@replica.com") + assert.NoError(t, err) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assert.Equal(t, p.Email.String, "read2@replica.com") + }) + t.Run("rebindQuery", func(t *testing.T) { rows, err := db.RebindQuery(ctx, "SELECT * FROM person_test WHERE id = ?", p1.GetID()) assert.NoError(t, err) @@ -385,6 +423,12 @@ func TestDBQueries(t *testing.T) { assertEqualPerson(t, &personModel{}, &pp2) }) + t.Run("get RR", func(t *testing.T) { + var p personModel + assert.NoError(t, dbWithRRs.ReadReplicaSet().Get(ctx, &p, "SELECT * FROM person_test WHERE email = $1", "read3@replica.com")) + assert.Equal(t, p.Email.String, "read3@replica.com") + }) + t.Run("getAll", func(t *testing.T) { var ap []*personModel assert.NoError(t, db.GetAll(ctx, &ap, "SELECT * FROM person_test")) @@ -393,6 +437,15 @@ func TestDBQueries(t *testing.T) { assertEqualPersons(t, []*personModel{}, ap) }) + t.Run("getAll RR", func(t *testing.T) { + var ap []*personModel + assert.NoError(t, dbWithRRs.ReadReplicaSet().GetAll(ctx, &ap, "SELECT * FROM person_test ORDER BY email")) + assert.Equal(t, len(ap), 3) + assert.Equal(t, ap[0].Email.String, "read1@replica.com") + assert.Equal(t, ap[1].Email.String, "read2@replica.com") + assert.Equal(t, ap[2].Email.String, "read3@replica.com") + }) + t.Run("select", func(t *testing.T) { var pp1, pp2 personModel assert.NoError(t, db.Select(ctx, &pp1, p2.GetID())) diff --git a/testdata/init-db-rr.sh b/testdata/init-db-rr.sh new file mode 100755 index 0000000..f733d18 --- /dev/null +++ b/testdata/init-db-rr.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -e + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname sequel --file /tmp/schema.sql +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname sequel -c "INSERT INTO person_test (id, name, email) VALUES ('ffb39de50af9c0ef6719e1c5698bc9ef', 'read replica user1', 'read1@replica.com'), ('2144dd61a17ab3c24d487916698bc9f5', 'read replica user2', 'read2@replica.com'), ('cc577469cc0aca5f0ed8df51698bc9ff', 'read replica user3', 'read3@replica.com');"