From 4624dde29e5a319ba01a6a0f56bca0a5a13cf745 Mon Sep 17 00:00:00 2001 From: jeff Date: Sat, 7 Feb 2026 23:20:24 -0800 Subject: [PATCH 01/15] close rows resource --- sequel.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sequel.go b/sequel.go index e9c68b7..d6945d5 100644 --- a/sequel.go +++ b/sequel.go @@ -122,6 +122,7 @@ func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { return &DB{ db: dbx, + dbRRs: &readReplicas{}, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, @@ -279,6 +280,9 @@ func (d *DB) GetAll(ctx context.Context, dest any, query string, args ...any) er if err != nil { return err } + + defer rows.Close() + if err := rows.Err(); err != nil { return err } From f3ebb0d8d2aec66b0ea1de9aa4ca863dece44963 Mon Sep 17 00:00:00 2001 From: jeff Date: Sat, 7 Feb 2026 23:21:53 -0800 Subject: [PATCH 02/15] clean up comments --- sequel.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sequel.go b/sequel.go index d6945d5..33266d9 100644 --- a/sequel.go +++ b/sequel.go @@ -229,14 +229,14 @@ func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, e return d.db.ExecContext(ctx, query, args...) } -// Query executes a query that returns rows, typically a SELECT. The query is +// RebindQuery executes a query that returns rows, typically a SELECT. The query is // rebound from `?` to the DB driver's bind type. The args are for any // placeholder parameters in the query. func (d *DB) RebindQuery(ctx context.Context, query string, args ...any) (*sql.Rows, error) { return d.db.QueryContext(ctx, d.db.Rebind(query), args...) } -// QueryRow executes a query that is expected to return at most one row. The +// RebindQueryRow executes a query that is expected to return at most one row. The // query is rebound from `?` to the DB driver's bind type. QueryRowContext // always returns a non-nil value. Errors are deferred until Row's Scan method // is called. @@ -248,7 +248,7 @@ func (d *DB) RebindQueryRow(ctx context.Context, query string, args ...any) *sql return d.db.QueryRowContext(ctx, d.db.Rebind(query), args...) } -// Exec executes a query without returning any rows. The query is rebound from +// RebindExec executes a query without returning any rows. The query is rebound from // `?` to the DB driver's bind type. The args are for any placeholder parameters // in the query. func (d *DB) RebindExec(ctx context.Context, query string, args ...any) (sql.Result, error) { @@ -504,14 +504,14 @@ func (t *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Re return t.tx.ExecContext(ctx, query, args...) } -// Query executes a query that returns rows, typically a SELECT. The query is +// RebindQuery executes a query that returns rows, typically a SELECT. The query is // rebound from `?` to the DB driver's bind type. The args are for any // placeholder parameters in the query. func (t *Tx) RebindQuery(query string, args ...any) (*sql.Rows, error) { return t.tx.Query(t.tx.Rebind(query), args...) } -// QueryRow executes a query that is expected to return at most one row. The +// RebindQueryRow executes a query that is expected to return at most one row. The // query is rebound from `?` to the DB driver's bind type. QueryRowContext // always returns a non-nil value. Errors are deferred until Row's Scan method // is called. @@ -523,7 +523,7 @@ func (t *Tx) RebindQueryRow(query string, args ...any) *sql.Row { return t.tx.QueryRow(t.tx.Rebind(query), args...) } -// Exec executes a query without returning any rows. The query is rebound from +// RebindExec executes a query without returning any rows. The query is rebound from // `?` to the DB driver's bind type. The args are for any placeholder parameters // in the query. func (t *Tx) RebindExec(query string, args ...any) (sql.Result, error) { From 8c7f35ba4589399cc71fac9b1ad22b8fa5b9331a Mon Sep 17 00:00:00 2001 From: jeff Date: Tue, 10 Feb 2026 16:28:19 -0800 Subject: [PATCH 03/15] add support for querying against read replicas --- main_test.go | 91 ++++++++++++++++++++------ read_replicas.go | 78 ++++++++++++++++++++++ read_replicas_test.go | 145 +++++++++++++++++++++++++++++++++++++++++ sequel.go | 69 ++++++++++++++++++++ sequel_test.go | 61 +++++++++++++++++ testdata/init-db-rr.sh | 5 ++ 6 files changed, 428 insertions(+), 21 deletions(-) create mode 100644 read_replicas.go create mode 100644 read_replicas_test.go create mode 100755 testdata/init-db-rr.sh diff --git a/main_test.go b/main_test.go index 3f391ac..bb0106c 100644 --- a/main_test.go +++ b/main_test.go @@ -20,7 +20,13 @@ const ( postgresImage = "docker.io/postgres:16.0-alpine" ) -var postgresDataSource string +// Connection strings for the three databases created in [TestMain] +// These are used in other tests +var ( + postgresDataSource string + postgresDataSourceRR1 string + postgresDataSourceRR2 string +) func withSchemaSQL() testcontainers.CustomizeRequestOption { return func(req *testcontainers.GenericContainerRequest) error { @@ -34,24 +40,68 @@ func withSchemaSQL() testcontainers.CustomizeRequestOption { } func TestMain(m *testing.M) { + ctx := context.Background() + cleanups, err := createPostgresContainers(ctx) + defer func() { + for _, cleanup := range cleanups { + if cleanup != nil { + cleanup() + } + } + }() + if err != nil { + fmt.Printf("did not create postgres containers: %v\n", err) + } + + os.Exit(m.Run()) +} + +// createPostgresContainers creates 3 databases as containers. The first is intended to mimick a master database and +// the last 2 are intended to mimick read replicas. +// A []func() is returned to cleanups. +// Package-level connction strings are set. +func createPostgresContainers(ctx context.Context) ([]func(), error) { + // Database connection strings are the same, except the port + connString := func(port string) string { + return fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, port, dbName) + } + var cleanups []func() - cleanup := func(fn func()) { - cleanups = append(cleanups, fn) + + // Create the master database + cM, mpM, err := createPostgresContainer(ctx, "init-db.sh") + cleanups = append(cleanups, cM) + if err != nil { + return cleanups, err } - fatal := func(args ...any) { - fmt.Fprintln(os.Stderr, args...) - for _, fn := range cleanups { - fn() - } - os.Exit(1) + postgresDataSource = connString(mpM) + + // Create first read replica + cRR1, mpRR1, err := createPostgresContainer(ctx, "init-db-rr.sh") + cleanups = append(cleanups, cRR1) + if err != nil { + return cleanups, err } + postgresDataSourceRR1 = connString(mpRR1) - ctx := context.Background() + // Create second read replica + cRR2, mpRR2, err := createPostgresContainer(ctx, "init-db-rr.sh") + cleanups = append(cleanups, cRR2) + if err != nil { + return cleanups, err + } + postgresDataSourceRR2 = connString(mpRR2) + + return cleanups, nil +} + +// createPostgresContainer creates a single postgres container on the specifed port +func createPostgresContainer(ctx context.Context, initFilename string) (func(), string, error) { postgresContainer, err := postgres.Run(ctx, postgresImage, postgres.WithDatabase(dbName), postgres.WithUsername(dbUser), postgres.WithPassword(dbPassword), - postgres.WithInitScripts(filepath.Join("testdata", "init-db.sh")), + postgres.WithInitScripts(filepath.Join("testdata", initFilename)), withSchemaSQL(), testcontainers.WithWaitStrategy( wait.ForLog("database system is ready to accept connections"). @@ -60,28 +110,27 @@ func TestMain(m *testing.M) { ), ) if err != nil { - fatal("error creating postgres container:", err) + return nil, "", fmt.Errorf("error creating postgres container: %w", err) } - cleanup(func() { + + cleanup := func() { if err := postgresContainer.Terminate(ctx); err != nil { fmt.Fprintln(os.Stderr, "error terminating postgres:", err) } - }) + } postgresState, err := postgresContainer.State(ctx) if err != nil { - fatal(err) + return cleanup, "", fmt.Errorf("checking container state: %w", err) } if !postgresState.Running { - fatal("Postgres status:", postgresState.Status) + return cleanup, "", fmt.Errorf("Postgres status %q is not \"running\"", postgresState.Status) } - postgresPort, err := postgresContainer.MappedPort(ctx, "5432/tcp") + mp, err := postgresContainer.MappedPort(ctx, "5432/tcp") if err != nil { - fatal(err) + return cleanup, "", fmt.Errorf("mapped port 5432/tcp does not seem to be available: %w", err) } - postgresDataSource = fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, postgresPort.Port(), dbName) - - os.Exit(m.Run()) + return cleanup, mp.Port(), nil } diff --git a/read_replicas.go b/read_replicas.go new file mode 100644 index 0000000..c49a942 --- /dev/null +++ b/read_replicas.go @@ -0,0 +1,78 @@ +package sequel + +import ( + "errors" + "sync" + + "github.com/go-sqlx/sqlx" +) + +var ErrorNoReadReplicaConnection = errors.New("no read replica connections") + +type readReplica struct { + db *sqlx.DB + + next *readReplica +} + +// readReplicas contains a set of DB connections. It is intended to give fair round robin access +// through a circular singularly linked list. +// +// Replicas are appended after the current one. The intented use is to build the replica set before +// querying, but all operations are concurrent-safe. +type readReplicas struct { + m sync.Mutex + + current *readReplica +} + +// add adds a DB to the collection of read replicas +func (rr *readReplicas) add(db *sqlx.DB) { + rr.m.Lock() + defer rr.m.Unlock() + + r := readReplica{ + db: db, + } + + // Empty ring, add new DB + if rr.current == nil { + r.next = &r + + rr.current = &r + + return + } + + // Insert new db after current + n := rr.current.next + r.next = n + rr.current.next = &r +} + +// next returns the current DB. The current pointer is advanced. +func (rr *readReplicas) next() (*sqlx.DB, error) { + rr.m.Lock() + defer rr.m.Unlock() + + if rr.current == nil { + return nil, ErrorNoReadReplicaConnection + } + + c := rr.current + + rr.current = rr.current.next + + return c.db, nil +} + +// Close closes all read replica connections +func (rr *readReplicas) Close() { + rr.m.Lock() + defer rr.m.Unlock() + + first := rr.current + for c := first; c != first; c = c.next { + c.db.Close() + } +} diff --git a/read_replicas_test.go b/read_replicas_test.go new file mode 100644 index 0000000..264b21d --- /dev/null +++ b/read_replicas_test.go @@ -0,0 +1,145 @@ +package sequel + +import ( + "errors" + "slices" + "testing" + + "github.com/go-sqlx/sqlx" +) + +func Test_readReplicas_add(t *testing.T) { + tests := []struct { + name string + rrs []*sqlx.DB + }{ + { + name: "add single", + rrs: []*sqlx.DB{sqlx.NewDb(nil, "fakeDriver")}, + }, + { + name: "add multiple", + rrs: []*sqlx.DB{sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver")}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create readReplicas + rr := &readReplicas{} + + // Add connections to read replicas + for _, c := range tt.rrs { + rr.add(c) + } + + // Confirm they are all added + if !rrContains(t, rr, tt.rrs) { + t.Fatal("readReplicas does not contain all added conns") + } + }) + } +} + +func Test_readReplicas_next(t *testing.T) { + newRRs := func(rrs ...*sqlx.DB) *readReplicas { + rr := &readReplicas{} + for _, c := range rrs { + rr.add(c) + } + + return rr + } + + c1, c2, c3 := sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver"), sqlx.NewDb(nil, "fakeDriver") + + emptyRR := newRRs() + singleRR := newRRs(c1) + threeRR := newRRs(c1, c2, c3) + + tests := []struct { + name string + rr *readReplicas + want *sqlx.DB + wantErr error + }{ + { + name: "empty readReplicas", + rr: emptyRR, + want: nil, + wantErr: ErrorNoReadReplicaConnection, + }, + { + name: "one read replica", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "one read replica (2 calls)", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "one read replica (3 calls)", + rr: singleRR, + want: c1, + wantErr: nil, + }, + { + name: "three read replicas", + rr: threeRR, + want: c1, + wantErr: nil, + }, + { + name: "three read replicas (2 calls)", + rr: threeRR, + want: c3, + wantErr: nil, + }, + { + name: "three read replicas (3 calls)", + rr: threeRR, + want: c2, + wantErr: nil, + }, + { + name: "three read replicas (4 calls)", + rr: threeRR, + want: c1, + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.rr.next() + if !errors.Is(err, tt.wantErr) { + t.Fatalf("got err %v, expecting %v", err, tt.wantErr) + } + + if got != tt.want { + t.Fatalf("got rr %p, want %p", got, tt.want) + } + }) + } +} + +// rrContains is atest helper to see if a [readReplicas] contains all the connections passed in +func rrContains(t *testing.T, rr *readReplicas, conns []*sqlx.DB) bool { + t.Helper() + + var findCnt int + head := rr.current + for c := head.next; ; c = c.next { + if slices.Contains(conns, c.db) { + findCnt++ + } + + if c == head { + break + } + } + + return findCnt == len(conns) +} diff --git a/sequel.go b/sequel.go index 33266d9..5771157 100644 --- a/sequel.go +++ b/sequel.go @@ -24,6 +24,7 @@ const MaxOpenConnections = 100 // operations on a Model. type DB struct { db *sqlx.DB + dbRRs *readReplicas clock clock.Clock doRebindModel bool driverName string @@ -101,6 +102,7 @@ func New(dataSourceName string, opts ...Option) (*DB, error) { return &DB{ db: db, + dbRRs: &readReplicas{}, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, @@ -178,9 +180,17 @@ func RowsAffected(res sql.Result, n int64) error { return fmt.Errorf("unexpected number of rows: got %d, want %d", got, n) } +// WithReadReplica adds a read replica connection to the [DB]. +func (d *DB) WithReadReplica(conn *sqlx.DB) { + if conn != nil { + d.dbRRs.add(conn) + } +} + // Close closes the database and prevents new queries from starting. Close then // waits for all queries that have started processing on the server to finish. func (d *DB) Close() error { + d.dbRRs.Close() return d.db.Close() } @@ -212,6 +222,17 @@ func (d *DB) Query(ctx context.Context, query string, args ...any) (*sql.Rows, e return d.db.QueryContext(ctx, query, args...) } +// QueryRR executes a query against a read replica. Queries that are not SELECTs may not work. +// The args are for any placeholder parameters in the query. +func (d *DB) QueryRR(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + db, err := d.dbRRs.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryContext(ctx, query, args...) +} + // QueryRow executes a query that is expected to return at most one row. // QueryRowContext always returns a non-nil value. Errors are deferred until // Row's Scan method is called. @@ -223,6 +244,22 @@ func (d *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row { return d.db.QueryRowContext(ctx, query, args...) } +// QueryRowRR executes a query that is expected to return at most one row against a read replica. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (d *DB) QueryRowRR(ctx context.Context, query string, args ...any) (*sql.Row, error) { + db, err := d.dbRRs.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryRowContext(ctx, query, args...), nil +} + // Exec executes a query without returning any rows. The args are for any // placeholder parameters in the query. func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { @@ -272,6 +309,16 @@ func (d *DB) Get(ctx context.Context, dest Model, query string, args ...any) err return d.db.GetContext(ctx, dest, query, args...) } +// GetRR populates the given model for the result of the given select query against a read replica. +func (d *DB) GetRR(ctx context.Context, dest Model, query string, args ...any) error { + db, err := d.dbRRs.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.GetContext(ctx, dest, query, args...) +} + // GetAll populates the given destination with all the results of the given // select query. The method will fail if the destination is not a pointer to a // slice. @@ -289,6 +336,28 @@ func (d *DB) GetAll(ctx context.Context, dest any, query string, args ...any) er return sqlx.StructScan(rows, dest) } +// GetAllRR populates the given destination with all the results of the given +// select query (from a read replica). The method will fail if the destination is not a pointer to a +// slice. +func (d *DB) GetAllRR(ctx context.Context, dest any, query string, args ...any) error { + db, err := d.dbRRs.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return err + } + + defer rows.Close() + + if err := rows.Err(); err != nil { + return err + } + return sqlx.StructScan(rows, dest) +} + // Select populates the given model with the result of a select by id query. func (d *DB) Select(ctx context.Context, dest Model, id string) error { return d.db.GetContext(ctx, dest, d.rebindModel(dest.Select()), id) diff --git a/sequel_test.go b/sequel_test.go index f1d912f..affa90a 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -160,12 +160,14 @@ func TestNewDB(t *testing.T) { }{ {"ok", args{db, "pgx/v5", nil}, &DB{ db: sqlx.NewDb(db, "pgx/v5"), + dbRRs: &readReplicas{}, clock: clock.New(), doRebindModel: false, driverName: "pgx/v5", }, assert.NoError}, {"ok with options", args{db, "pgx/v5", []Option{WithClock(clock.NewMock(testTime)), WithDriver("pgx"), WithRebindModel()}}, &DB{ db: sqlx.NewDb(db, "pgx"), + dbRRs: &readReplicas{}, clock: clock.NewMock(testTime), doRebindModel: true, driverName: "pgx", @@ -240,12 +242,28 @@ func TestIsUniqueViolation(t *testing.T) { } func TestDBQueries(t *testing.T) { + // Create single DB source db, err := New(postgresDataSource) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, db.Close()) }) + // Create DB source with two read replicas + dbWithRRs, err := New(postgresDataSource) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, dbWithRRs.Close()) + }) + + rr1, err := sqlx.Open("pgx/v5", postgresDataSourceRR1) + assert.NoError(t, err) + dbWithRRs.WithReadReplica(rr1) + + rr2, err := sqlx.Open("pgx/v5", postgresDataSourceRR2) + assert.NoError(t, err) + dbWithRRs.WithReadReplica(rr2) + p1 := &personModel{ Name: "Lucky Luke", Email: NullString("lucky@example.com"), @@ -323,6 +341,25 @@ func TestDBQueries(t *testing.T) { assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes }) + t.Run("query RR (no RRs)", func(t *testing.T) { + _, err := db.QueryRR(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.Error(t, err, ErrorNoReadReplicaConnection) + }) + + t.Run("query RR", func(t *testing.T) { + rows, err := dbWithRRs.QueryRR(ctx, "SELECT * FROM person_test WHERE email = $1", "read1@replica.com") + if err != nil { + t.Fatalf("expecting no error, got %v", err) + } + for rows.Next() { + var p personModel + assert.NoError(t, rows.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assert.Equal(t, p.Email.String, "read1@replica.com") + } + assert.NoError(t, rows.Err()) + assert.NoError(t, rows.Close()) //nolint:sqlclosecheck // no defer for testing purposes + }) + t.Run("queryRow", func(t *testing.T) { var p personModel row := db.QueryRow(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) @@ -331,6 +368,15 @@ func TestDBQueries(t *testing.T) { assertEqualPerson(t, p1, &p) }) + t.Run("queryRow RR", func(t *testing.T) { + var p personModel + row, err := dbWithRRs.QueryRowRR(ctx, "SELECT * FROM person_test WHERE email = $1", "read2@replica.com") + assert.NoError(t, err) + assert.NoError(t, row.Err()) + assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) + assert.Equal(t, p.Email.String, "read2@replica.com") + }) + t.Run("rebindQuery", func(t *testing.T) { rows, err := db.RebindQuery(ctx, "SELECT * FROM person_test WHERE id = ?", p1.GetID()) assert.NoError(t, err) @@ -385,6 +431,12 @@ func TestDBQueries(t *testing.T) { assertEqualPerson(t, &personModel{}, &pp2) }) + t.Run("get RR", func(t *testing.T) { + var p personModel + assert.NoError(t, dbWithRRs.GetRR(ctx, &p, "SELECT * FROM person_test WHERE email = $1", "read3@replica.com")) + assert.Equal(t, p.Email.String, "read3@replica.com") + }) + t.Run("getAll", func(t *testing.T) { var ap []*personModel assert.NoError(t, db.GetAll(ctx, &ap, "SELECT * FROM person_test")) @@ -393,6 +445,15 @@ func TestDBQueries(t *testing.T) { assertEqualPersons(t, []*personModel{}, ap) }) + t.Run("getAll RR", func(t *testing.T) { + var ap []*personModel + assert.NoError(t, dbWithRRs.GetAllRR(ctx, &ap, "SELECT * FROM person_test")) + assert.Equal(t, len(ap), 3) + assert.Equal(t, ap[0].Email.String, "read1@replica.com") + assert.Equal(t, ap[1].Email.String, "read2@replica.com") + assert.Equal(t, ap[2].Email.String, "read3@replica.com") + }) + t.Run("select", func(t *testing.T) { var pp1, pp2 personModel assert.NoError(t, db.Select(ctx, &pp1, p2.GetID())) diff --git a/testdata/init-db-rr.sh b/testdata/init-db-rr.sh new file mode 100755 index 0000000..f733d18 --- /dev/null +++ b/testdata/init-db-rr.sh @@ -0,0 +1,5 @@ +#!/bin/bash +set -e + +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname sequel --file /tmp/schema.sql +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname sequel -c "INSERT INTO person_test (id, name, email) VALUES ('ffb39de50af9c0ef6719e1c5698bc9ef', 'read replica user1', 'read1@replica.com'), ('2144dd61a17ab3c24d487916698bc9f5', 'read replica user2', 'read2@replica.com'), ('cc577469cc0aca5f0ed8df51698bc9ff', 'read replica user3', 'read3@replica.com');" From d6274ddf5d50998582d7a45caec98bdb2590a4ee Mon Sep 17 00:00:00 2001 From: jeff Date: Tue, 10 Feb 2026 16:33:51 -0800 Subject: [PATCH 04/15] lint --- main_test.go | 2 +- read_replicas.go | 6 +++--- read_replicas_test.go | 2 +- sequel_test.go | 5 +++-- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/main_test.go b/main_test.go index bb0106c..36ae169 100644 --- a/main_test.go +++ b/main_test.go @@ -95,7 +95,7 @@ func createPostgresContainers(ctx context.Context) ([]func(), error) { return cleanups, nil } -// createPostgresContainer creates a single postgres container on the specifed port +// createPostgresContainer creates a single postgres container on the specified port func createPostgresContainer(ctx context.Context, initFilename string) (func(), string, error) { postgresContainer, err := postgres.Run(ctx, postgresImage, postgres.WithDatabase(dbName), diff --git a/read_replicas.go b/read_replicas.go index c49a942..ad9988e 100644 --- a/read_replicas.go +++ b/read_replicas.go @@ -7,7 +7,7 @@ import ( "github.com/go-sqlx/sqlx" ) -var ErrorNoReadReplicaConnection = errors.New("no read replica connections") +var ErrNoReadReplicaConnection = errors.New("no read replica connections") type readReplica struct { db *sqlx.DB @@ -18,7 +18,7 @@ type readReplica struct { // readReplicas contains a set of DB connections. It is intended to give fair round robin access // through a circular singularly linked list. // -// Replicas are appended after the current one. The intented use is to build the replica set before +// Replicas are appended after the current one. The intended use is to build the replica set before // querying, but all operations are concurrent-safe. type readReplicas struct { m sync.Mutex @@ -56,7 +56,7 @@ func (rr *readReplicas) next() (*sqlx.DB, error) { defer rr.m.Unlock() if rr.current == nil { - return nil, ErrorNoReadReplicaConnection + return nil, ErrNoReadReplicaConnection } c := rr.current diff --git a/read_replicas_test.go b/read_replicas_test.go index 264b21d..3e3033f 100644 --- a/read_replicas_test.go +++ b/read_replicas_test.go @@ -66,7 +66,7 @@ func Test_readReplicas_next(t *testing.T) { name: "empty readReplicas", rr: emptyRR, want: nil, - wantErr: ErrorNoReadReplicaConnection, + wantErr: ErrNoReadReplicaConnection, }, { name: "one read replica", diff --git a/sequel_test.go b/sequel_test.go index affa90a..ee9076c 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -342,8 +342,9 @@ func TestDBQueries(t *testing.T) { }) t.Run("query RR (no RRs)", func(t *testing.T) { - _, err := db.QueryRR(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) - assert.Error(t, err, ErrorNoReadReplicaConnection) + rows, err := db.QueryRR(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.Error(t, err, ErrNoReadReplicaConnection) + assert.Nil(t, rows) }) t.Run("query RR", func(t *testing.T) { From 1fb9ea48749e3e2c4f856d0b82f73368412cd042 Mon Sep 17 00:00:00 2001 From: jeff Date: Tue, 10 Feb 2026 19:35:45 -0800 Subject: [PATCH 05/15] lint --- main_test.go | 1 + sequel_test.go | 1 + 2 files changed, 2 insertions(+) diff --git a/main_test.go b/main_test.go index 36ae169..ae89522 100644 --- a/main_test.go +++ b/main_test.go @@ -53,6 +53,7 @@ func TestMain(m *testing.M) { fmt.Printf("did not create postgres containers: %v\n", err) } + // nolint:gocritic // The docs for Run specify that the returned int is to be passed to os.Exit os.Exit(m.Run()) } diff --git a/sequel_test.go b/sequel_test.go index ee9076c..bd3c16c 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -342,6 +342,7 @@ func TestDBQueries(t *testing.T) { }) t.Run("query RR (no RRs)", func(t *testing.T) { + //nolint:rowserrcheck // rows is expected to be nil, err to be non-nil rows, err := db.QueryRR(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) assert.Error(t, err, ErrNoReadReplicaConnection) assert.Nil(t, rows) From 1d7377cfabf5026d85628735eb1647f3c4acc925 Mon Sep 17 00:00:00 2001 From: smst-jeff Date: Tue, 10 Feb 2026 19:51:31 -0800 Subject: [PATCH 06/15] Update read_replicas_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- read_replicas_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/read_replicas_test.go b/read_replicas_test.go index 3e3033f..a67a373 100644 --- a/read_replicas_test.go +++ b/read_replicas_test.go @@ -125,7 +125,7 @@ func Test_readReplicas_next(t *testing.T) { } } -// rrContains is atest helper to see if a [readReplicas] contains all the connections passed in +// rrContains is a test helper to see if a [readReplicas] contains all the connections passed in func rrContains(t *testing.T, rr *readReplicas, conns []*sqlx.DB) bool { t.Helper() From d47b45c4e7860bd46a03429d1281713225fa9181 Mon Sep 17 00:00:00 2001 From: smst-jeff Date: Tue, 10 Feb 2026 19:51:42 -0800 Subject: [PATCH 07/15] Update main_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- main_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main_test.go b/main_test.go index ae89522..297cd02 100644 --- a/main_test.go +++ b/main_test.go @@ -57,8 +57,8 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -// createPostgresContainers creates 3 databases as containers. The first is intended to mimick a master database and -// the last 2 are intended to mimick read replicas. +// createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and +// the last 2 are intended to mimic read replicas. // A []func() is returned to cleanups. // Package-level connction strings are set. func createPostgresContainers(ctx context.Context) ([]func(), error) { From 8e10ce5604450ceb6b10959932b7dc46e2430474 Mon Sep 17 00:00:00 2001 From: smst-jeff Date: Tue, 10 Feb 2026 19:51:52 -0800 Subject: [PATCH 08/15] Update main_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- main_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_test.go b/main_test.go index 297cd02..70ecb64 100644 --- a/main_test.go +++ b/main_test.go @@ -60,7 +60,7 @@ func TestMain(m *testing.M) { // createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and // the last 2 are intended to mimic read replicas. // A []func() is returned to cleanups. -// Package-level connction strings are set. +// Package-level connection strings are set. func createPostgresContainers(ctx context.Context) ([]func(), error) { // Database connection strings are the same, except the port connString := func(port string) string { From 10e79d189cf1bf5a95671fe0c5b3a4f9f43b89da Mon Sep 17 00:00:00 2001 From: jeff Date: Tue, 10 Feb 2026 19:54:34 -0800 Subject: [PATCH 09/15] lint --- read_replicas.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/read_replicas.go b/read_replicas.go index ad9988e..5d317f9 100644 --- a/read_replicas.go +++ b/read_replicas.go @@ -71,8 +71,17 @@ func (rr *readReplicas) Close() { rr.m.Lock() defer rr.m.Unlock() + // If this instance has no replicas, current would be nil + if rr.current == nil { + return + } + first := rr.current - for c := first; c != first; c = c.next { + for c := first; ; c = c.next { c.db.Close() + + if c.next == first { + break + } } } From 1d9800c8cc3806a44ff3ae6c791eca227bbb7f76 Mon Sep 17 00:00:00 2001 From: jeff Date: Mon, 23 Feb 2026 10:11:17 -0800 Subject: [PATCH 10/15] manually clean up containers --- main_test.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/main_test.go b/main_test.go index 70ecb64..07cdf97 100644 --- a/main_test.go +++ b/main_test.go @@ -42,19 +42,24 @@ func withSchemaSQL() testcontainers.CustomizeRequestOption { func TestMain(m *testing.M) { ctx := context.Background() cleanups, err := createPostgresContainers(ctx) - defer func() { + if err != nil { + fmt.Printf("did not create postgres containers: %v\n", err) for _, cleanup := range cleanups { if cleanup != nil { cleanup() } } - }() - if err != nil { - fmt.Printf("did not create postgres containers: %v\n", err) + os.Exit(1) } // nolint:gocritic // The docs for Run specify that the returned int is to be passed to os.Exit - os.Exit(m.Run()) + retCode := m.Run() + for _, cleanup := range cleanups { + if cleanup != nil { + cleanup() + } + } + os.Exit(retCode) } // createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and From aaeb52f2155ce25f20355add9f9ad184c058af3c Mon Sep 17 00:00:00 2001 From: jeff Date: Sun, 15 Mar 2026 17:07:29 -0700 Subject: [PATCH 11/15] refactor --- read_replica_set.go | 169 ++++++++++++++++++ ...plicas_test.go => read_replica_set_test.go | 20 +-- read_replicas.go | 87 --------- sequel.go | 122 ++++++------- sequel_test.go | 24 +-- 5 files changed, 238 insertions(+), 184 deletions(-) create mode 100644 read_replica_set.go rename read_replicas_test.go => read_replica_set_test.go (82%) delete mode 100644 read_replicas.go diff --git a/read_replica_set.go b/read_replica_set.go new file mode 100644 index 0000000..87c4a25 --- /dev/null +++ b/read_replica_set.go @@ -0,0 +1,169 @@ +package sequel + +import ( + "context" + "database/sql" + "errors" + "fmt" + "sync" + + "github.com/go-sqlx/sqlx" +) + +var ErrNoReadReplicaConnections = errors.New("no read replica connections available") + +type readReplica struct { + db *sqlx.DB + + next *readReplica +} + +// ReadReplicaSet contains a set of DB connections. It is intended to give fair round robin access +// through a circular singularly linked list. +// +// Replicas are appended after the current one. The intended use is to build the replica set before +// querying, but all operations are concurrent-safe. +type ReadReplicaSet struct { + m sync.Mutex + + current *readReplica +} + +// add adds a DB to the collection of read replicas +func (rr *ReadReplicaSet) add(db *sqlx.DB) { + rr.m.Lock() + defer rr.m.Unlock() + + r := readReplica{ + db: db, + } + + // Empty ring, add new DB + if rr.current == nil { + r.next = &r + + rr.current = &r + + return + } + + // Insert new db after current + n := rr.current.next + r.next = n + rr.current.next = &r +} + +// next returns the current DB. The current pointer is advanced. +func (rr *ReadReplicaSet) next() (*sqlx.DB, error) { + rr.m.Lock() + defer rr.m.Unlock() + + if rr.current == nil { + return nil, ErrNoReadReplicaConnections + } + + c := rr.current + + rr.current = rr.current.next + + return c.db, nil +} + +// Close closes all read replica connections +func (rr *ReadReplicaSet) Close() { + if rr == nil { + return + } + + rr.m.Lock() + defer rr.m.Unlock() + + // If this instance has no replicas, current would be nil + if rr.current == nil { + return + } + + first := rr.current + for c := first; ; c = c.next { + c.db.Close() + + if c.next == first { + break + } + } +} + +// Query executes a query against a read replica. Queries that are not SELECTs may not work. +// The args are for any placeholder parameters in the query. +func (rr *ReadReplicaSet) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) { + if rr == nil { + return nil, ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryContext(ctx, query, args...) +} + +// QueryRow executes a query that is expected to return at most one row against a read replica. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards the +// rest. +func (rr *ReadReplicaSet) QueryRow(ctx context.Context, query string, args ...any) (*sql.Row, error) { + if rr == nil { + return nil, ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return nil, fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.QueryRowContext(ctx, query, args...), nil +} + +// Get populates the given model for the result of the given select query against a read replica. +func (rr *ReadReplicaSet) Get(ctx context.Context, dest Model, query string, args ...any) error { + if rr == nil { + return ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + return db.GetContext(ctx, dest, query, args...) +} + +// GetAll populates the given destination with all the results of the given +// select query (from a read replica). The method will fail if the destination is not a pointer to a +// slice. +func (rr *ReadReplicaSet) GetAll(ctx context.Context, dest any, query string, args ...any) error { + if rr == nil { + return ErrNoReadReplicaConnections + } + + db, err := rr.next() + if err != nil { + return fmt.Errorf("did not get read replica connection: %w", err) + } + + rows, err := db.QueryContext(ctx, query, args...) + if err != nil { + return err + } + + defer rows.Close() + + if err := rows.Err(); err != nil { + return err + } + return sqlx.StructScan(rows, dest) +} diff --git a/read_replicas_test.go b/read_replica_set_test.go similarity index 82% rename from read_replicas_test.go rename to read_replica_set_test.go index a67a373..7314362 100644 --- a/read_replicas_test.go +++ b/read_replica_set_test.go @@ -24,8 +24,8 @@ func Test_readReplicas_add(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Create readReplicas - rr := &readReplicas{} + // Create ReadReplicaSet + rr := &ReadReplicaSet{} // Add connections to read replicas for _, c := range tt.rrs { @@ -34,15 +34,15 @@ func Test_readReplicas_add(t *testing.T) { // Confirm they are all added if !rrContains(t, rr, tt.rrs) { - t.Fatal("readReplicas does not contain all added conns") + t.Fatal("ReadReplicaSet does not contain all added conns") } }) } } func Test_readReplicas_next(t *testing.T) { - newRRs := func(rrs ...*sqlx.DB) *readReplicas { - rr := &readReplicas{} + newRRs := func(rrs ...*sqlx.DB) *ReadReplicaSet { + rr := &ReadReplicaSet{} for _, c := range rrs { rr.add(c) } @@ -58,15 +58,15 @@ func Test_readReplicas_next(t *testing.T) { tests := []struct { name string - rr *readReplicas + rr *ReadReplicaSet want *sqlx.DB wantErr error }{ { - name: "empty readReplicas", + name: "empty ReadReplicaSet", rr: emptyRR, want: nil, - wantErr: ErrNoReadReplicaConnection, + wantErr: ErrNoReadReplicaConnections, }, { name: "one read replica", @@ -125,8 +125,8 @@ func Test_readReplicas_next(t *testing.T) { } } -// rrContains is a test helper to see if a [readReplicas] contains all the connections passed in -func rrContains(t *testing.T, rr *readReplicas, conns []*sqlx.DB) bool { +// rrContains is a test helper to see if a [ReadReplicaSet] contains all the connections passed in +func rrContains(t *testing.T, rr *ReadReplicaSet, conns []*sqlx.DB) bool { t.Helper() var findCnt int diff --git a/read_replicas.go b/read_replicas.go deleted file mode 100644 index 5d317f9..0000000 --- a/read_replicas.go +++ /dev/null @@ -1,87 +0,0 @@ -package sequel - -import ( - "errors" - "sync" - - "github.com/go-sqlx/sqlx" -) - -var ErrNoReadReplicaConnection = errors.New("no read replica connections") - -type readReplica struct { - db *sqlx.DB - - next *readReplica -} - -// readReplicas contains a set of DB connections. It is intended to give fair round robin access -// through a circular singularly linked list. -// -// Replicas are appended after the current one. The intended use is to build the replica set before -// querying, but all operations are concurrent-safe. -type readReplicas struct { - m sync.Mutex - - current *readReplica -} - -// add adds a DB to the collection of read replicas -func (rr *readReplicas) add(db *sqlx.DB) { - rr.m.Lock() - defer rr.m.Unlock() - - r := readReplica{ - db: db, - } - - // Empty ring, add new DB - if rr.current == nil { - r.next = &r - - rr.current = &r - - return - } - - // Insert new db after current - n := rr.current.next - r.next = n - rr.current.next = &r -} - -// next returns the current DB. The current pointer is advanced. -func (rr *readReplicas) next() (*sqlx.DB, error) { - rr.m.Lock() - defer rr.m.Unlock() - - if rr.current == nil { - return nil, ErrNoReadReplicaConnection - } - - c := rr.current - - rr.current = rr.current.next - - return c.db, nil -} - -// Close closes all read replica connections -func (rr *readReplicas) Close() { - rr.m.Lock() - defer rr.m.Unlock() - - // If this instance has no replicas, current would be nil - if rr.current == nil { - return - } - - first := rr.current - for c := first; ; c = c.next { - c.db.Close() - - if c.next == first { - break - } - } -} diff --git a/sequel.go b/sequel.go index 5771157..bc9b7de 100644 --- a/sequel.go +++ b/sequel.go @@ -24,7 +24,7 @@ const MaxOpenConnections = 100 // operations on a Model. type DB struct { db *sqlx.DB - dbRRs *readReplicas + rrs *ReadReplicaSet clock clock.Clock doRebindModel bool driverName string @@ -35,6 +35,7 @@ type options struct { DriverName string RebindModel bool MaxOpenConnections int + ReadReplicaDSNs []string } func newOptions(driverName string) *options { @@ -89,6 +90,13 @@ func WithMaxOpenConnections(n int) Option { } } +// WithReadReplica adds a read replica DSN to the list of read replicas. +func WithReadReplica(dsn string) Option { + return func(o *options) { + o.ReadReplicaDSNs = append(o.ReadReplicaDSNs, dsn) + } +} + // New creates a new DB. It will fail if it cannot ping it. func New(dataSourceName string, opts ...Option) (*DB, error) { options := newOptions("pgx/v5").apply(opts) @@ -100,9 +108,14 @@ func New(dataSourceName string, opts ...Option) (*DB, error) { } db.SetMaxOpenConns(options.MaxOpenConnections) + rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) + if err != nil { + return nil, fmt.Errorf("error creating read replica set: %w", err) + } + return &DB{ db: db, - dbRRs: &readReplicas{}, + rrs: rrs, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, @@ -122,15 +135,45 @@ func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { } dbx.SetMaxOpenConns(options.MaxOpenConnections) + rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) + if err != nil { + return nil, fmt.Errorf("error creating read replica set: %w", err) + } + return &DB{ db: dbx, - dbRRs: &readReplicas{}, + rrs: rrs, clock: options.Clock, doRebindModel: options.RebindModel, driverName: options.DriverName, }, nil } +// createReadReplicaSet creates a new ReadReplicaSet from the given list of DSN strings. +// It also connects and pings the DB to ensure connectivity. +func createReadReplicaSet(driverName string, dsns []string, maxConns int) (*ReadReplicaSet, error) { + if len(dsns) == 0 { + return nil, nil + } + + var rss ReadReplicaSet + for _, dsn := range dsns { + // Connect to DB and ping + rr, err := sqlx.Connect(driverName, dsn) + if err != nil { + return nil, fmt.Errorf("error connecting to the read replica database: %w", err) + } + + // Set max open connections + rr.SetMaxOpenConns(maxConns) + + // Add read replica to ReadReplicaSet + rss.add(rr) + } + + return &rss, nil +} + type dbKey struct{} // NewContext returns a new context with the given DB. @@ -180,17 +223,10 @@ func RowsAffected(res sql.Result, n int64) error { return fmt.Errorf("unexpected number of rows: got %d, want %d", got, n) } -// WithReadReplica adds a read replica connection to the [DB]. -func (d *DB) WithReadReplica(conn *sqlx.DB) { - if conn != nil { - d.dbRRs.add(conn) - } -} - // Close closes the database and prevents new queries from starting. Close then // waits for all queries that have started processing on the server to finish. func (d *DB) Close() error { - d.dbRRs.Close() + d.rrs.Close() return d.db.Close() } @@ -204,6 +240,11 @@ func (d *DB) DB() *sql.DB { return d.db.DB } +// ReadReplicaSet returns the read replica set. +func (d *DB) ReadReplicaSet() *ReadReplicaSet { + return d.rrs +} + // Rebind transforms a query from `?` to the DB driver's bind type. func (d *DB) Rebind(query string) string { return d.db.Rebind(query) @@ -222,17 +263,6 @@ func (d *DB) Query(ctx context.Context, query string, args ...any) (*sql.Rows, e return d.db.QueryContext(ctx, query, args...) } -// QueryRR executes a query against a read replica. Queries that are not SELECTs may not work. -// The args are for any placeholder parameters in the query. -func (d *DB) QueryRR(ctx context.Context, query string, args ...any) (*sql.Rows, error) { - db, err := d.dbRRs.next() - if err != nil { - return nil, fmt.Errorf("did not get read replica connection: %w", err) - } - - return db.QueryContext(ctx, query, args...) -} - // QueryRow executes a query that is expected to return at most one row. // QueryRowContext always returns a non-nil value. Errors are deferred until // Row's Scan method is called. @@ -244,22 +274,6 @@ func (d *DB) QueryRow(ctx context.Context, query string, args ...any) *sql.Row { return d.db.QueryRowContext(ctx, query, args...) } -// QueryRowRR executes a query that is expected to return at most one row against a read replica. -// QueryRowContext always returns a non-nil value. Errors are deferred until -// Row's Scan method is called. -// -// If the query selects no rows, the *Row's Scan will return ErrNoRows. -// Otherwise, the *Row's Scan scans the first selected row and discards the -// rest. -func (d *DB) QueryRowRR(ctx context.Context, query string, args ...any) (*sql.Row, error) { - db, err := d.dbRRs.next() - if err != nil { - return nil, fmt.Errorf("did not get read replica connection: %w", err) - } - - return db.QueryRowContext(ctx, query, args...), nil -} - // Exec executes a query without returning any rows. The args are for any // placeholder parameters in the query. func (d *DB) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) { @@ -309,16 +323,6 @@ func (d *DB) Get(ctx context.Context, dest Model, query string, args ...any) err return d.db.GetContext(ctx, dest, query, args...) } -// GetRR populates the given model for the result of the given select query against a read replica. -func (d *DB) GetRR(ctx context.Context, dest Model, query string, args ...any) error { - db, err := d.dbRRs.next() - if err != nil { - return fmt.Errorf("did not get read replica connection: %w", err) - } - - return db.GetContext(ctx, dest, query, args...) -} - // GetAll populates the given destination with all the results of the given // select query. The method will fail if the destination is not a pointer to a // slice. @@ -336,28 +340,6 @@ func (d *DB) GetAll(ctx context.Context, dest any, query string, args ...any) er return sqlx.StructScan(rows, dest) } -// GetAllRR populates the given destination with all the results of the given -// select query (from a read replica). The method will fail if the destination is not a pointer to a -// slice. -func (d *DB) GetAllRR(ctx context.Context, dest any, query string, args ...any) error { - db, err := d.dbRRs.next() - if err != nil { - return fmt.Errorf("did not get read replica connection: %w", err) - } - - rows, err := db.QueryContext(ctx, query, args...) - if err != nil { - return err - } - - defer rows.Close() - - if err := rows.Err(); err != nil { - return err - } - return sqlx.StructScan(rows, dest) -} - // Select populates the given model with the result of a select by id query. func (d *DB) Select(ctx context.Context, dest Model, id string) error { return d.db.GetContext(ctx, dest, d.rebindModel(dest.Select()), id) diff --git a/sequel_test.go b/sequel_test.go index bd3c16c..79fdb09 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -160,14 +160,12 @@ func TestNewDB(t *testing.T) { }{ {"ok", args{db, "pgx/v5", nil}, &DB{ db: sqlx.NewDb(db, "pgx/v5"), - dbRRs: &readReplicas{}, clock: clock.New(), doRebindModel: false, driverName: "pgx/v5", }, assert.NoError}, {"ok with options", args{db, "pgx/v5", []Option{WithClock(clock.NewMock(testTime)), WithDriver("pgx"), WithRebindModel()}}, &DB{ db: sqlx.NewDb(db, "pgx"), - dbRRs: &readReplicas{}, clock: clock.NewMock(testTime), doRebindModel: true, driverName: "pgx", @@ -250,20 +248,12 @@ func TestDBQueries(t *testing.T) { }) // Create DB source with two read replicas - dbWithRRs, err := New(postgresDataSource) + dbWithRRs, err := New(postgresDataSource, WithReadReplica(postgresDataSourceRR1), WithReadReplica(postgresDataSourceRR2)) require.NoError(t, err) t.Cleanup(func() { assert.NoError(t, dbWithRRs.Close()) }) - rr1, err := sqlx.Open("pgx/v5", postgresDataSourceRR1) - assert.NoError(t, err) - dbWithRRs.WithReadReplica(rr1) - - rr2, err := sqlx.Open("pgx/v5", postgresDataSourceRR2) - assert.NoError(t, err) - dbWithRRs.WithReadReplica(rr2) - p1 := &personModel{ Name: "Lucky Luke", Email: NullString("lucky@example.com"), @@ -343,13 +333,13 @@ func TestDBQueries(t *testing.T) { t.Run("query RR (no RRs)", func(t *testing.T) { //nolint:rowserrcheck // rows is expected to be nil, err to be non-nil - rows, err := db.QueryRR(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) - assert.Error(t, err, ErrNoReadReplicaConnection) + rows, err := db.ReadReplicaSet().Query(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) + assert.Error(t, err, ErrNoReadReplicaConnections) assert.Nil(t, rows) }) t.Run("query RR", func(t *testing.T) { - rows, err := dbWithRRs.QueryRR(ctx, "SELECT * FROM person_test WHERE email = $1", "read1@replica.com") + rows, err := dbWithRRs.ReadReplicaSet().Query(ctx, "SELECT * FROM person_test WHERE email = $1", "read1@replica.com") if err != nil { t.Fatalf("expecting no error, got %v", err) } @@ -372,7 +362,7 @@ func TestDBQueries(t *testing.T) { t.Run("queryRow RR", func(t *testing.T) { var p personModel - row, err := dbWithRRs.QueryRowRR(ctx, "SELECT * FROM person_test WHERE email = $1", "read2@replica.com") + row, err := dbWithRRs.ReadReplicaSet().QueryRow(ctx, "SELECT * FROM person_test WHERE email = $1", "read2@replica.com") assert.NoError(t, err) assert.NoError(t, row.Err()) assert.NoError(t, row.Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt, &p.DeletedAt, &p.Name, &p.Email)) @@ -435,7 +425,7 @@ func TestDBQueries(t *testing.T) { t.Run("get RR", func(t *testing.T) { var p personModel - assert.NoError(t, dbWithRRs.GetRR(ctx, &p, "SELECT * FROM person_test WHERE email = $1", "read3@replica.com")) + assert.NoError(t, dbWithRRs.ReadReplicaSet().Get(ctx, &p, "SELECT * FROM person_test WHERE email = $1", "read3@replica.com")) assert.Equal(t, p.Email.String, "read3@replica.com") }) @@ -449,7 +439,7 @@ func TestDBQueries(t *testing.T) { t.Run("getAll RR", func(t *testing.T) { var ap []*personModel - assert.NoError(t, dbWithRRs.GetAllRR(ctx, &ap, "SELECT * FROM person_test")) + assert.NoError(t, dbWithRRs.ReadReplicaSet().GetAll(ctx, &ap, "SELECT * FROM person_test")) assert.Equal(t, len(ap), 3) assert.Equal(t, ap[0].Email.String, "read1@replica.com") assert.Equal(t, ap[1].Email.String, "read2@replica.com") From 95a0d8adb60af5e65aebf6915086dd899204c589 Mon Sep 17 00:00:00 2001 From: jeff Date: Mon, 16 Mar 2026 10:33:27 -0700 Subject: [PATCH 12/15] lint --- sequel.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sequel.go b/sequel.go index bc9b7de..b757338 100644 --- a/sequel.go +++ b/sequel.go @@ -153,7 +153,7 @@ func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { // It also connects and pings the DB to ensure connectivity. func createReadReplicaSet(driverName string, dsns []string, maxConns int) (*ReadReplicaSet, error) { if len(dsns) == 0 { - return nil, nil + return nil, nil //nolint:nilnil // if there are no dsns, then this should not create anything } var rss ReadReplicaSet From 49ef26a5fb3ec4d949057ed4730176e904647c77 Mon Sep 17 00:00:00 2001 From: smst-jeff Date: Fri, 20 Mar 2026 14:15:09 -0700 Subject: [PATCH 13/15] Update sequel_test.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- sequel_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sequel_test.go b/sequel_test.go index 79fdb09..cb0f936 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -334,7 +334,7 @@ func TestDBQueries(t *testing.T) { t.Run("query RR (no RRs)", func(t *testing.T) { //nolint:rowserrcheck // rows is expected to be nil, err to be non-nil rows, err := db.ReadReplicaSet().Query(ctx, "SELECT * FROM person_test WHERE id = $1", p1.GetID()) - assert.Error(t, err, ErrNoReadReplicaConnections) + assert.ErrorIs(t, err, ErrNoReadReplicaConnections) assert.Nil(t, rows) }) From 6945d45dd0f581eaa1b8f554ab397f45e37507f1 Mon Sep 17 00:00:00 2001 From: smst-jeff Date: Fri, 20 Mar 2026 14:19:06 -0700 Subject: [PATCH 14/15] Update read_replica_set.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- read_replica_set.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/read_replica_set.go b/read_replica_set.go index 87c4a25..f966df5 100644 --- a/read_replica_set.go +++ b/read_replica_set.go @@ -19,7 +19,7 @@ type readReplica struct { } // ReadReplicaSet contains a set of DB connections. It is intended to give fair round robin access -// through a circular singularly linked list. +// through a circular singly linked list. // // Replicas are appended after the current one. The intended use is to build the replica set before // querying, but all operations are concurrent-safe. From aef55c5daef62a26e8e624340cb6af847f1b650c Mon Sep 17 00:00:00 2001 From: jeff Date: Fri, 20 Mar 2026 14:21:18 -0700 Subject: [PATCH 15/15] copilot suggestions --- sequel.go | 18 ++++++++++++++---- sequel_test.go | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sequel.go b/sequel.go index b757338..c5d1aaf 100644 --- a/sequel.go +++ b/sequel.go @@ -110,6 +110,11 @@ func New(dataSourceName string, opts ...Option) (*DB, error) { rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) if err != nil { + // Close any open db connections + _ = db.Close() + if rrs != nil { + rrs.Close() + } return nil, fmt.Errorf("error creating read replica set: %w", err) } @@ -137,6 +142,11 @@ func NewDB(db *sql.DB, driverName string, opts ...Option) (*DB, error) { rrs, err := createReadReplicaSet(options.DriverName, options.ReadReplicaDSNs, options.MaxOpenConnections) if err != nil { + // Close any open db connections + _ = db.Close() + if rrs != nil { + rrs.Close() + } return nil, fmt.Errorf("error creating read replica set: %w", err) } @@ -156,22 +166,22 @@ func createReadReplicaSet(driverName string, dsns []string, maxConns int) (*Read return nil, nil //nolint:nilnil // if there are no dsns, then this should not create anything } - var rss ReadReplicaSet + var rrs ReadReplicaSet for _, dsn := range dsns { // Connect to DB and ping rr, err := sqlx.Connect(driverName, dsn) if err != nil { - return nil, fmt.Errorf("error connecting to the read replica database: %w", err) + return &rrs, fmt.Errorf("error connecting to the read replica database: %w", err) } // Set max open connections rr.SetMaxOpenConns(maxConns) // Add read replica to ReadReplicaSet - rss.add(rr) + rrs.add(rr) } - return &rss, nil + return &rrs, nil } type dbKey struct{} diff --git a/sequel_test.go b/sequel_test.go index cb0f936..1f6ab40 100644 --- a/sequel_test.go +++ b/sequel_test.go @@ -439,7 +439,7 @@ func TestDBQueries(t *testing.T) { t.Run("getAll RR", func(t *testing.T) { var ap []*personModel - assert.NoError(t, dbWithRRs.ReadReplicaSet().GetAll(ctx, &ap, "SELECT * FROM person_test")) + assert.NoError(t, dbWithRRs.ReadReplicaSet().GetAll(ctx, &ap, "SELECT * FROM person_test ORDER BY email")) assert.Equal(t, len(ap), 3) assert.Equal(t, ap[0].Email.String, "read1@replica.com") assert.Equal(t, ap[1].Email.String, "read2@replica.com")