Skip to content
97 changes: 76 additions & 21 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@ const (
postgresImage = "docker.io/postgres:16.0-alpine"
)

var postgresDataSource string
// Connection strings for the three databases created in [TestMain]
// These are used in other tests
var (
postgresDataSource string
postgresDataSourceRR1 string
postgresDataSourceRR2 string
)

func withSchemaSQL() testcontainers.CustomizeRequestOption {
return func(req *testcontainers.GenericContainerRequest) error {
Expand All @@ -34,24 +40,74 @@ func withSchemaSQL() testcontainers.CustomizeRequestOption {
}

func TestMain(m *testing.M) {
var cleanups []func()
cleanup := func(fn func()) {
cleanups = append(cleanups, fn)
}
fatal := func(args ...any) {
fmt.Fprintln(os.Stderr, args...)
for _, fn := range cleanups {
fn()
ctx := context.Background()
cleanups, err := createPostgresContainers(ctx)
if err != nil {
fmt.Printf("did not create postgres containers: %v\n", err)
for _, cleanup := range cleanups {
if cleanup != nil {
cleanup()
}
}
os.Exit(1)
}

ctx := context.Background()
// nolint:gocritic // The docs for Run specify that the returned int is to be passed to os.Exit
retCode := m.Run()
for _, cleanup := range cleanups {
if cleanup != nil {
cleanup()
}
}
os.Exit(retCode)
}

// createPostgresContainers creates 3 databases as containers. The first is intended to mimic a master database and
// the last 2 are intended to mimic read replicas.
// A []func() is returned to cleanups.
// Package-level connection strings are set.
func createPostgresContainers(ctx context.Context) ([]func(), error) {
// Database connection strings are the same, except the port
connString := func(port string) string {
return fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, port, dbName)
}

var cleanups []func()

// Create the master database
cM, mpM, err := createPostgresContainer(ctx, "init-db.sh")
cleanups = append(cleanups, cM)
if err != nil {
return cleanups, err
}
postgresDataSource = connString(mpM)

// Create first read replica
cRR1, mpRR1, err := createPostgresContainer(ctx, "init-db-rr.sh")
cleanups = append(cleanups, cRR1)
if err != nil {
return cleanups, err
}
postgresDataSourceRR1 = connString(mpRR1)

// Create second read replica
cRR2, mpRR2, err := createPostgresContainer(ctx, "init-db-rr.sh")
cleanups = append(cleanups, cRR2)
if err != nil {
return cleanups, err
}
postgresDataSourceRR2 = connString(mpRR2)

return cleanups, nil
}

// createPostgresContainer creates a single postgres container on the specified port
func createPostgresContainer(ctx context.Context, initFilename string) (func(), string, error) {
postgresContainer, err := postgres.Run(ctx, postgresImage,
postgres.WithDatabase(dbName),
postgres.WithUsername(dbUser),
postgres.WithPassword(dbPassword),
postgres.WithInitScripts(filepath.Join("testdata", "init-db.sh")),
postgres.WithInitScripts(filepath.Join("testdata", initFilename)),
withSchemaSQL(),
testcontainers.WithWaitStrategy(
wait.ForLog("database system is ready to accept connections").
Expand All @@ -60,28 +116,27 @@ func TestMain(m *testing.M) {
),
)
if err != nil {
fatal("error creating postgres container:", err)
return nil, "", fmt.Errorf("error creating postgres container: %w", err)
}
cleanup(func() {

cleanup := func() {
if err := postgresContainer.Terminate(ctx); err != nil {
fmt.Fprintln(os.Stderr, "error terminating postgres:", err)
}
})
}

postgresState, err := postgresContainer.State(ctx)
if err != nil {
fatal(err)
return cleanup, "", fmt.Errorf("checking container state: %w", err)
}
if !postgresState.Running {
fatal("Postgres status:", postgresState.Status)
return cleanup, "", fmt.Errorf("Postgres status %q is not \"running\"", postgresState.Status)
}

postgresPort, err := postgresContainer.MappedPort(ctx, "5432/tcp")
mp, err := postgresContainer.MappedPort(ctx, "5432/tcp")
if err != nil {
fatal(err)
return cleanup, "", fmt.Errorf("mapped port 5432/tcp does not seem to be available: %w", err)
}

postgresDataSource = fmt.Sprintf("postgres://%s:%s@localhost:%s/%s?sslmode=disable&application_name=test", dbUser, dbPassword, postgresPort.Port(), dbName)

os.Exit(m.Run())
return cleanup, mp.Port(), nil
}
169 changes: 169 additions & 0 deletions read_replica_set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package sequel

import (
"context"
"database/sql"
"errors"
"fmt"
"sync"

"github.com/go-sqlx/sqlx"
)

var ErrNoReadReplicaConnections = errors.New("no read replica connections available")

type readReplica struct {
db *sqlx.DB

next *readReplica
}

// ReadReplicaSet contains a set of DB connections. It is intended to give fair round robin access
// through a circular singly linked list.
//
// Replicas are appended after the current one. The intended use is to build the replica set before
// querying, but all operations are concurrent-safe.
type ReadReplicaSet struct {
m sync.Mutex

current *readReplica
}

// add adds a DB to the collection of read replicas
func (rr *ReadReplicaSet) add(db *sqlx.DB) {
rr.m.Lock()
defer rr.m.Unlock()

r := readReplica{
db: db,
}

// Empty ring, add new DB
if rr.current == nil {
r.next = &r

rr.current = &r

return
}

// Insert new db after current
n := rr.current.next
r.next = n
rr.current.next = &r
}

// next returns the current DB. The current pointer is advanced.
func (rr *ReadReplicaSet) next() (*sqlx.DB, error) {
rr.m.Lock()
defer rr.m.Unlock()

if rr.current == nil {
return nil, ErrNoReadReplicaConnections
}

c := rr.current

rr.current = rr.current.next

return c.db, nil
}

// Close closes all read replica connections
func (rr *ReadReplicaSet) Close() {
if rr == nil {
return
}

rr.m.Lock()
defer rr.m.Unlock()

// If this instance has no replicas, current would be nil
if rr.current == nil {
return
}

first := rr.current
for c := first; ; c = c.next {
c.db.Close()

if c.next == first {
break
}
}
}

// Query executes a query against a read replica. Queries that are not SELECTs may not work.
// The args are for any placeholder parameters in the query.
func (rr *ReadReplicaSet) Query(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
if rr == nil {
return nil, ErrNoReadReplicaConnections
}

db, err := rr.next()
if err != nil {
return nil, fmt.Errorf("did not get read replica connection: %w", err)
}

return db.QueryContext(ctx, query, args...)
}

// QueryRow executes a query that is expected to return at most one row against a read replica.
// QueryRowContext always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
//
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Otherwise, the *Row's Scan scans the first selected row and discards the
// rest.
func (rr *ReadReplicaSet) QueryRow(ctx context.Context, query string, args ...any) (*sql.Row, error) {
if rr == nil {
return nil, ErrNoReadReplicaConnections
}

db, err := rr.next()
if err != nil {
return nil, fmt.Errorf("did not get read replica connection: %w", err)
}

return db.QueryRowContext(ctx, query, args...), nil
}

// Get populates the given model for the result of the given select query against a read replica.
func (rr *ReadReplicaSet) Get(ctx context.Context, dest Model, query string, args ...any) error {
if rr == nil {
return ErrNoReadReplicaConnections
}

db, err := rr.next()
if err != nil {
return fmt.Errorf("did not get read replica connection: %w", err)
}

return db.GetContext(ctx, dest, query, args...)
}

// GetAll populates the given destination with all the results of the given
// select query (from a read replica). The method will fail if the destination is not a pointer to a
// slice.
func (rr *ReadReplicaSet) GetAll(ctx context.Context, dest any, query string, args ...any) error {
if rr == nil {
return ErrNoReadReplicaConnections
}

db, err := rr.next()
if err != nil {
return fmt.Errorf("did not get read replica connection: %w", err)
}

rows, err := db.QueryContext(ctx, query, args...)
if err != nil {
return err
}

defer rows.Close()

if err := rows.Err(); err != nil {
return err
}
return sqlx.StructScan(rows, dest)
}
Loading
Loading