package testdock
import (
"context"
"database/sql"
"fmt"
"sync"
"testing"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/n-r-w/ctxlog"
)
// Informer exposes read-only connection details of a prepared test database.
type Informer interface {
	// DSN returns the real database connection string.
	DSN() string
	// Host returns the host of the database server.
	Host() string
	// Port returns the port of the database server.
	Port() int
	// DatabaseName returns the database name for testing.
	DatabaseName() string
}
const (
	// DefaultRetryTimeout is the default retry timeout.
	// It is used as the constant pause between connection attempts.
	DefaultRetryTimeout = time.Second * 3
	// DefaultTotalRetryDuration is the default total retry duration.
	// It is used as the maximum elapsed time of all connection attempts together.
	DefaultTotalRetryDuration = time.Second * 30
)
// PrepareCleanUp is a hook that is called right before the temporary test database is dropped,
// for example to disconnect the users that still hold connections to it.
// It receives an open connection to the server and the name of the database that is about to be dropped.
// An error returned by the hook is only logged; the drop is attempted anyway.
type PrepareCleanUp func(db *sql.DB, databaseName string) error
// testDB represents a test database.
// It holds the resolved options together with the state derived from them
// and implements Informer.
type testDB struct {
	t            testing.TB
	logger       ctxlog.ILogger // unified way to logging
	databaseName string         // name of the test database
	url          *dbURL         // parsed database connection string
	dsnNoPass    string         // database connection string without password
	// options
	driver                  string           // database driver (pgx, pq, etc)
	mode                    RunMode          // run mode (docker or external)
	dsn                     string           // database connection string
	retryTimeout            time.Duration    // retry timeout for connecting to the database
	totalRetryDuration      time.Duration    // total retry duration
	migrationsDir           string           // migrations directory
	unsetProxyEnv           bool             // unset HTTP_PROXY, HTTPS_PROXY etc. environment variables
	MigrateFactory          MigrateFactory   // unified way to create a migrations
	prepareCleanUp          []PrepareCleanUp // function for prepare to delete temporary test database.
	connectDatabase         string           // database name for connecting to the database server
	connectDatabaseOverride bool             // true when connectDatabase was set explicitly via WithConnectDatabase (even to an empty value)
	dockerPort              int              // docker port
	dockerRepository        string           // docker hub repository
	dockerImage             string           // docker hub image tag
	dockerSocketEndpoint    string           // docker socket endpoint for connecting to the docker daemon
	dockerEnv               []string         // environment variables for the docker container
}
var (
	// globalMu guards globalMuByDSN.
	globalMu sync.Mutex
	// globalMuByDSN holds one mutex per DSN; newTDB holds it for the whole time
	// it prepares a database, so preparations for the same DSN do not overlap.
	globalMuByDSN = make(map[string]*sync.Mutex)
)
// newTDB creates a new test database and applies migrations.
//
// Steps, in order:
//  1. apply opt on top of the defaults and validate the result;
//  2. take the per-DSN mutex, so preparations for the same DSN are serialized;
//  3. in RunModeDocker start (or reuse) the docker container;
//  4. create the temporary database;
//  5. run the migrations when a migrations directory is configured;
//  6. register a tb.Cleanup hook that closes (drops) the database.
//
// If a step fails after the database may already exist (steps 4 and 5), the database
// is closed immediately, because the cleanup hook of step 6 is not registered yet.
// Any failure is reported with tb.Fatalf from the deferred handler; the nil result
// is only observable with a testing.TB whose Fatalf returns.
func newTDB(ctx context.Context, tb testing.TB, driver, dsn string, opt []Option) *testDB {
	tb.Helper()

	var (
		db = &testDB{
			t:                  tb,
			logger:             ctxlog.Must(ctxlog.WithTesting(tb)),
			driver:             driver,
			dsn:                dsn,
			mode:               RunModeAuto,
			retryTimeout:       DefaultRetryTimeout,
			totalRetryDuration: DefaultTotalRetryDuration,
		}
		errResult error
	)

	// Registered first, therefore executed last: the per-DSN mutex is already released
	// when Fatalf is called.
	defer func() {
		if errResult != nil {
			tb.Fatalf("cannot create test database: %v", errResult)
		}
	}()

	if errResult = db.prepareOptions(driver, opt); errResult != nil {
		return nil
	}

	// db.dsn may have been replaced by prepareOptions, so the mutex is looked up only now.
	globalMu.Lock()
	mu, ok := globalMuByDSN[db.dsn]
	if !ok {
		mu = &sync.Mutex{}
		globalMuByDSN[db.dsn] = mu
	}
	globalMu.Unlock()

	mu.Lock()
	defer mu.Unlock()

	if db.mode == RunModeDocker {
		db.logger.Info(ctx, "using docker test database", "dsn", db.dsnNoPass)
		if errResult = db.createDockerResources(ctx); errResult != nil {
			return nil
		}
	} else {
		db.logger.Info(ctx, "using real test database", "dsn", db.dsnNoPass)
	}

	// closeOnFailure releases the test database when a later step fails.
	closeOnFailure := func() {
		if err := db.close(ctx); err != nil {
			db.logger.Info(ctx, "failed to close test database", "dsn", db.dsnNoPass, "error", err)
		}
	}

	if errResult = db.createTestDatabase(ctx); errResult != nil {
		closeOnFailure()
		return nil
	}

	if db.migrationsDir != "" {
		if errResult = db.migrationsUp(ctx); errResult != nil {
			// the database exists at this point; without this call a failed migration
			// would leave it behind on an external server
			closeOnFailure()
			return nil
		}
	}

	tb.Cleanup(func() {
		ctx := context.Background()
		if err := db.close(ctx); err != nil {
			db.logger.Info(ctx, "failed to close test database", "dsn", db.dsnNoPass, "error", err)
		} else {
			db.logger.Info(ctx, "test database closed", "dsn", db.dsnNoPass)
		}
	})

	return db
}
// migrationsUp applies migrations to the database.
//
// The migrator is built by the configured MigrateFactory for the DSN of the temporary
// test database and is run with the caller's ctx, so cancellation and deadlines of ctx
// reach the migrator.
func (d *testDB) migrationsUp(ctx context.Context) error {
	d.logger.Info(ctx, "migrations up start", "dsn", d.dsnNoPass)
	defer d.logger.Info(ctx, "migrations up end", "dsn", d.dsnNoPass)

	// the migrator needs the real password, so the unmasked DSN is used here
	dsn := d.url.replaceDatabase(d.databaseName).string(false)

	migrator, err := d.MigrateFactory(d.t, dsn, d.migrationsDir, d.logger)
	if err != nil {
		return fmt.Errorf("new migrator: %w", err)
	}

	if err = migrator.Up(ctx); err != nil {
		return fmt.Errorf("up migrations: %w", err)
	}

	return nil
}
// close closes the test database.
//
// In RunModeDocker and for the mongo driver it does nothing. Otherwise it connects
// with the configured DSN, runs every prepareCleanUp hook (hook errors are only
// logged) and drops the temporary database.
// Returned errors never contain the password: the DSN is rendered masked, the same
// way it is rendered in the log messages.
func (d *testDB) close(ctx context.Context) error {
	if d.mode == RunModeDocker || d.driver == mongoDriverName {
		return nil
	}

	// remove the database created before applying the migrations
	d.logger.Info(ctx, "deleting test database", "dsn", d.dsnNoPass, "database", d.databaseName)

	db, err := sql.Open(d.driver, d.url.string(false))
	if err != nil {
		return fmt.Errorf("sql open url (%s): %w", d.url.string(true), err)
	}
	defer func() {
		_ = db.Close() // best effort: the connection is used only for the cleanup
	}()

	for _, prepareCleanUp := range d.prepareCleanUp {
		if err := prepareCleanUp(db, d.databaseName); err != nil {
			d.logger.Info(ctx, "failed to prepare clean up", "dsn", d.dsnNoPass, "error", err)
		}
	}

	// databaseName is generated by prepareOptions, it is not user input
	if _, err = db.Exec(fmt.Sprintf("DROP DATABASE %s", d.databaseName)); err != nil {
		return fmt.Errorf("drop db: %w", err)
	}

	d.logger.Info(ctx, "test database deleted", "dsn", d.dsnNoPass, "database", d.databaseName)

	return nil
}
// createTestDatabase creates the temporary test database.
// For every driver except the mongo one it delegates to createSQLDatabase;
// for the mongo driver nothing is created here.
func (d *testDB) createTestDatabase(ctx context.Context) error {
	if d.driver != mongoDriverName {
		return d.createSQLDatabase(ctx)
	}

	return nil
}
// retryConnect runs op until it succeeds, with a constant pause of retryTimeout between
// attempts and totalRetryDuration as the maximum elapsed time.
//
// ctx is handed to the retry loop, so a cancelled ctx stops the retries.
// info is only used in the log message of a failed attempt.
// Failed attempts are counted from 1, both in the log and in the returned error.
func (d *testDB) retryConnect(ctx context.Context, info string, op func() error) error {
	var attempt int

	operation := func() (struct{}, error) {
		err := op()
		if err == nil {
			return struct{}{}, nil
		}

		attempt++
		d.logger.Info(ctx, "retrying operation", "info", info, "attempt", attempt, "error", err)

		return struct{}{}, err
	}

	_, err := backoff.Retry(
		ctx, operation,
		backoff.WithBackOff(backoff.NewConstantBackOff(d.retryTimeout)),
		backoff.WithMaxElapsedTime(d.totalRetryDuration),
	)
	if err != nil {
		return fmt.Errorf("retry failed after %d attempts: %w", attempt, err)
	}

	return nil
}
// DSN returns the connection string that points at the temporary test database,
// with the password included.
func (d *testDB) DSN() string {
	testURL := d.url.replaceDatabase(d.databaseName)

	return testURL.string(false)
}
// Host returns the host part of the parsed connection string.
func (d *testDB) Host() string {
	host := d.url.Host

	return host
}
// Port returns the port part of the parsed connection string.
func (d *testDB) Port() int {
	port := d.url.Port

	return port
}
// DatabaseName returns the generated name of the temporary test database.
func (d *testDB) DatabaseName() string {
	name := d.databaseName

	return name
}
package testdock
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
// we ensure the creation of docker resources only once for all tests
var (
	// globalDockerMu guards globalDockerResources and globalDockerPool.
	globalDockerMu sync.Mutex
	// globalDockerResources maps a DSN to the docker resource that was started for it.
	globalDockerResources = make(map[string]*dockerResourceInfo)
	// globalDockerPool is created on first use and shared by all resources.
	globalDockerPool *dockertest.Pool
)
// dockerResourceInfo describes one docker resource shared by the tests that use the same DSN.
type dockerResourceInfo struct {
	resource *dockertest.Resource // resource returned by the pool for the started container
	port     int                  // host port the container was actually bound to
	count    int                  // number of tests currently using the resource
	mu       sync.Mutex           // guards the fields above
}
// createDockerResources creates the shared docker pool (once) and a docker resource for the
// DSN of this test database, or reuses the resource that is already running for that DSN.
//
// Resources are reference counted per DSN: the first user starts the container, every
// user registers a cleanup on d.t, and the cleanup of the last user purges the container.
// When the requested host port is busy, the next port is tried and d.url.Port is updated
// accordingly; users that reuse the resource get the same port restored into d.url.Port.
//
// Returns an error when the pool cannot be created or pinged, or when the container
// cannot be started.
func (d *testDB) createDockerResources(ctx context.Context) error { //nolint:gocognit // ok
	globalDockerMu.Lock()
	info, ok := globalDockerResources[d.dsn]
	if !ok {
		// not stored in the map yet: that happens only after the container has started
		info = &dockerResourceInfo{}
	}
	var (
		err    error
		logDsn = d.dsnNoPass // masked DSN for logging; keeps the originally requested port
	)
	if globalDockerPool == nil {
		globalDockerPool, err = dockertest.NewPool(d.dockerSocketEndpoint)
		if err != nil {
			globalDockerMu.Unlock()
			return fmt.Errorf("dockertest NewPool: %w", err)
		}
		if d.unsetProxyEnv {
			// we clear the proxy environment variables, because they can interfere with the work of docker
			proxyEnv := []string{
				"HTTP_PROXY",
				"HTTPS_PROXY",
				"ALL_PROXY",
				"http_proxy",
				"https_proxy",
				"all_proxy",
			}
			for _, env := range proxyEnv {
				if os.Getenv(env) != "" {
					d.logger.Info(ctx, "unset proxy env", "component", "docker", "env", env)
					_ = os.Unsetenv(env)
				}
			}
		}
		err = globalDockerPool.Client.Ping()
		if err != nil {
			// NOTE(review): globalDockerPool stays non-nil on this path, so later calls skip this branch
			globalDockerMu.Unlock()
			return fmt.Errorf("dockertest ping: %w", err)
		}
		d.logger.Info(ctx, "pool created", "component", "docker")
		// Runs when this call returns (after info.mu is released by the defer registered below):
		// if no resource ended up registered, e.g. because the container failed to start,
		// the pool is dropped so that the next call creates it again.
		defer func() {
			globalDockerMu.Lock()
			defer globalDockerMu.Unlock()
			if len(globalDockerResources) == 0 {
				globalDockerPool = nil
				d.logger.Info(ctx, "pool purged", "component", "docker")
			}
		}()
	}
	globalDockerMu.Unlock()
	info.mu.Lock()
	defer info.mu.Unlock()
	if info.count == 0 {
		// docker releases the port after calling globalDockerPool.Purge(globalDockerResource) not instantly, so we try several times
		const (
			maxAttempts = 10
			sleepTime   = 5 * time.Second
		)
		var (
			attempt    int
			dockerPort = fmt.Sprintf("%d/tcp", d.dockerPort) // container port in docker notation
		)
		for {
			info.resource, err = globalDockerPool.RunWithOptions(&dockertest.RunOptions{
				Repository: d.dockerRepository,
				Tag:        d.dockerImage,
				Env:        d.dockerEnv,
				PortBindings: map[docker.Port][]docker.PortBinding{
					docker.Port(dockerPort): {{
						HostIP:   d.url.Host,
						HostPort: strconv.Itoa(d.url.Port),
					}},
				},
			}, func(config *docker.HostConfig) {
				config.AutoRemove = true
				config.RestartPolicy = docker.RestartPolicy{Name: "no"}
			})
			if err == nil {
				break
			}
			// error texts that are treated as "the host port is busy"
			bindErrors := []string{
				"address already in use",
				"port is already allocated",
				"failed to bind host port",
			}
			needNextPort := false
			for _, bindError := range bindErrors {
				if strings.Contains(err.Error(), bindError) {
					needNextPort = true
					break
				}
			}
			if needNextPort {
				// increase hostPort by 1
				// a busy port does not count as an attempt and is retried without a pause
				d.logger.Info(ctx, "port is already allocated, trying next port", "dsn", logDsn, "next_port", d.url.Port+1)
				d.url.Port++
				continue
			}
			// any other error: retry with a pause, at most maxAttempts times
			attempt++
			if attempt >= maxAttempts {
				break
			}
			d.logger.Info(ctx, "RunWithOptions failed", "component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
			time.Sleep(sleepTime)
		}
		if err != nil {
			return fmt.Errorf("dockertest RunWithOptions: %w", err)
		}
		// remember the port that was finally used, for the users that reuse the resource
		info.port = d.url.Port
		d.logger.Info(ctx, "resources created", "component", "docker", "dsn", logDsn)
	} else {
		d.url.Port = info.port // restore port
		d.logger.Info(ctx, "use existing resources", "component", "docker", "dsn", logDsn)
	}
	// register the resource; info.mu is held, globalDockerMu is taken second (same order as in the cleanup)
	globalDockerMu.Lock()
	globalDockerResources[d.dsn] = info
	globalDockerMu.Unlock()
	info.count++
	d.t.Cleanup(func() {
		ctx := context.Background()
		info.mu.Lock()
		defer info.mu.Unlock()
		info.count--
		if info.count == 0 {
			// last user of the resource: unregister it and purge the container
			globalDockerMu.Lock()
			defer globalDockerMu.Unlock()
			delete(globalDockerResources, d.dsn)
			const (
				maxTime      = 10 * time.Second
				retryTimeout = 1 * time.Second
			)
			var attempt int
			operation := func() (struct{}, error) {
				if err := globalDockerPool.Purge(info.resource); err != nil {
					attempt++
					// the background context declared above is used, since this runs inside a cleanup function
					d.logger.Info(ctx, "purge attempt failed", "component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
					return struct{}{}, err
				}
				return struct{}{}, nil
			}
			// a purge failure is only logged: the cleanup has no way to return an error
			if _, err := backoff.Retry(ctx, operation,
				backoff.WithBackOff(backoff.NewConstantBackOff(retryTimeout)),
				backoff.WithMaxElapsedTime(maxTime)); err != nil {
				d.logger.Info(ctx, "purge failed after retries", "component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
			} else {
				d.logger.Info(ctx, "resources purged successfully", "component", "docker", "dsn", logDsn, "attempts", attempt)
			}
		}
	})
	return nil
}
package testdock
import (
"context"
"database/sql"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/mongodb" // require for mongodb
_ "github.com/golang-migrate/migrate/v4/database/postgres" // require for gomigrate
_ "github.com/golang-migrate/migrate/v4/source/file" // require for gomigrate
"github.com/n-r-w/ctxlog"
"github.com/pressly/goose/v3"
)
// MigrateFactory creates a new migrator for the database behind dsn,
// using the migrations stored in migrationsDir and logger for its output.
type MigrateFactory func(t testing.TB, dsn string, migrationsDir string, logger ctxlog.ILogger) (Migrator, error)
// Migrator interface for applying migrations.
type Migrator interface {
	// Up applies the migrations.
	Up(ctx context.Context) error
}
// Ready-made goose factories; the second argument is the database/sql driver name
// that the migrator opens its connection with.
var (
	// GooseMigrateFactoryPGX is a migrator for https://github.com/pressly/goose with pgx driver.
	GooseMigrateFactoryPGX = GooseMigrateFactory(goose.DialectPostgres, "pgx")
	// GooseMigrateFactoryPQ is a migrator for https://github.com/pressly/goose with pq driver.
	GooseMigrateFactoryPQ = GooseMigrateFactory(goose.DialectPostgres, "postgres")
	// GooseMigrateFactoryMySQL is a migrator for https://github.com/pressly/goose with mysql driver.
	GooseMigrateFactoryMySQL = GooseMigrateFactory(goose.DialectMySQL, "mysql")
)
// GooseMigrateFactory returns a MigrateFactory that builds goose
// (https://github.com/pressly/goose) migrators for the given dialect,
// opening the database with the given database/sql driver name.
func GooseMigrateFactory(dialect goose.Dialect, driver string) MigrateFactory {
	factory := func(tb testing.TB, dsn, dir string, logger ctxlog.ILogger) (Migrator, error) {
		return newGooseMigrator(tb, dialect, driver, dsn, dir, logger)
	}

	return factory
}
// gooseMigrator is a migrator for goose.
type gooseMigrator struct {
	p *goose.Provider // provider the migrations are applied with; closed by Up
}
// newGooseMigrator opens a database/sql connection with driver and dsn and wraps it into a
// verbose goose provider that reads the migrations from migrationsDir.
// The connection is closed again when the provider cannot be created.
func newGooseMigrator(t testing.TB, dialect goose.Dialect, driver, dsn, migrationsDir string, logger ctxlog.ILogger) (*gooseMigrator, error) {
	conn, err := sql.Open(driver, dsn)
	if err != nil {
		return nil, fmt.Errorf("sql open url (%s): %w", dsn, err)
	}

	migrationsFS := os.DirFS(migrationsDir)
	gooseLog := NewGooseLogger(t, logger)

	provider, err := goose.NewProvider(dialect, conn, migrationsFS, goose.WithLogger(gooseLog), goose.WithVerbose(true))
	if err != nil {
		_ = conn.Close()
		return nil, fmt.Errorf("new goose provider: %w", err)
	}

	return &gooseMigrator{p: provider}, nil
}
// Up applies the pending goose migrations and closes the provider afterwards,
// whether the migrations succeeded or not.
func (m *gooseMigrator) Up(ctx context.Context) error {
	defer m.p.Close()

	if _, err := m.p.Up(ctx); err != nil {
		return err
	}

	return nil
}
// GolangMigrateFactory is a MigrateFactory for https://github.com/golang-migrate/migrate.
// The testing.TB argument is not used.
func GolangMigrateFactory(_ testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) {
	migrator, err := newGolangMigrateMigrator(dsn, migrationsDir, logger)

	return migrator, err
}
// golangMigrateMigrator is a migrator for https://github.com/golang-migrate/migrate.
type golangMigrateMigrator struct {
	m *migrate.Migrate // migrate instance created for the migrations directory and the DSN
}
// newGolangMigrateMigrator builds a migrator for https://github.com/golang-migrate/migrate.
// A relative migrationsDir is made absolute before it is turned into a file:// source URL.
func newGolangMigrateMigrator(dsn, migrationsDir string, logger ctxlog.ILogger) (*golangMigrateMigrator, error) {
	sourceDir := migrationsDir
	if !filepath.IsAbs(sourceDir) {
		absDir, err := filepath.Abs(sourceDir)
		if err != nil {
			return nil, fmt.Errorf("get absolute path: %w", err)
		}
		sourceDir = absDir
	}

	instance, err := migrate.New("file://"+sourceDir, dsn)
	if err != nil {
		return nil, fmt.Errorf("new migrate: %w", err)
	}
	instance.Log = NewGolangMigrateLogger(logger)

	return &golangMigrateMigrator{m: instance}, nil
}
// Up applies all pending migrations and then closes the migrate instance.
//
// Closing releases the migration source and the database connection held by the
// instance; without it the connection to the temporary test database stays open
// after the migrations. The goose migrator closes its provider the same way.
// The context is not used: migrate.Migrate.Up takes no context.
func (m *golangMigrateMigrator) Up(_ context.Context) error {
	defer func() {
		// best effort, as in gooseMigrator.Up: the result of Up is what matters to the caller
		_, _ = m.m.Close()
	}()

	return m.m.Up()
}
// GooseLogger adapts testing.TB and ctxlog.ILogger to the logger expected by goose.
type GooseLogger struct {
	t testing.TB
	l ctxlog.ILogger
}

// NewGooseLogger creates a new goose logger that reports fatal errors to t
// and writes everything else to l.
func NewGooseLogger(t testing.TB, l ctxlog.ILogger) *GooseLogger {
	gooseLogger := GooseLogger{t: t, l: l}

	return &gooseLogger
}

// Fatalf reports a fatal error through the test's Fatalf.
func (gl GooseLogger) Fatalf(format string, v ...any) {
	gl.t.Fatalf(format, v...)
}

// Printf formats the message and writes it to the logger at info level.
func (gl GooseLogger) Printf(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	gl.l.Info(context.Background(), message)
}
// GolangMigrateLogger adapts ctxlog.ILogger to the logger expected by golang-migrate.
type GolangMigrateLogger struct {
	l ctxlog.ILogger
}

// NewGolangMigrateLogger creates a new golang-migrate logger that writes to l.
func NewGolangMigrateLogger(l ctxlog.ILogger) *GolangMigrateLogger {
	migrateLogger := GolangMigrateLogger{l: l}

	return &migrateLogger
}

// Printf formats the message and writes it to the logger at info level.
func (g *GolangMigrateLogger) Printf(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	g.l.Info(context.Background(), message)
}

// Verbose always reports true.
func (g *GolangMigrateLogger) Verbose() bool {
	const verbose = true

	return verbose
}
package testdock
import (
"context"
"fmt"
"testing"
mongov1 "go.mongodb.org/mongo-driver/mongo"
optionsv1 "go.mongodb.org/mongo-driver/mongo/options"
)
// GetMongoDatabase prepares a test MongoDB database (mongo-driver v1), applies migrations
// and returns a handle to the database together with its connection details.
//
// Defaults (docker repository "mongo", image "latest", root credentials taken from the DSN
// when it has a user) are placed before opt, so opt can override them.
// The registered cleanup drops the database when it does not run in docker and
// disconnects the client.
func GetMongoDatabase(tb testing.TB, dsn string, opt ...Option) (*mongov1.Database, Informer) {
	tb.Helper()

	ctx := context.Background()

	parsed, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	options := []Option{
		WithDockerRepository("mongo"),
		WithDockerImage("latest"),
	}
	if parsed.User != "" {
		rootEnv := []string{
			fmt.Sprintf("MONGO_INITDB_ROOT_USERNAME=%s", parsed.User),
			fmt.Sprintf("MONGO_INITDB_ROOT_PASSWORD=%s", parsed.Password),
		}
		options = append(options, WithDockerEnv(rootEnv))
	}
	options = append(options, opt...)

	tDB := newTDB(ctx, tb, mongoDriverName, dsn, options)

	client, err := tDB.connectMongoDB(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to mongo: %v", err)
	}

	tb.Cleanup(func() {
		if tDB.mode == RunModeDocker {
			_ = client.Disconnect(context.Background())
			return
		}

		// protect against closing connection during tests
		cleaner, err := tDB.connectMongoDB(ctx)
		if err != nil {
			tb.Logf("cannot connect to mongo for cleanup: %v", err)
			return
		}
		defer cleaner.Disconnect(ctx)

		if err := cleaner.Database(tDB.databaseName).Drop(ctx); err != nil {
			tb.Logf("failed to drop database %s: %v", tDB.databaseName, err)
		}

		_ = client.Disconnect(context.Background())
	})

	return client.Database(tDB.databaseName), tDB
}
// connectMongoDB connects to the temporary test database with the mongo-driver v1 client,
// retrying connect+ping through retryConnect.
//
// A client whose ping failed is disconnected before the next attempt, so failed
// attempts do not leave clients behind. The returned error contains the DSN with the
// password masked, like the log messages do.
func (d *testDB) connectMongoDB(ctx context.Context) (*mongov1.Client, error) {
	var client *mongov1.Client

	url := d.url.replaceDatabase(d.databaseName)

	err := d.retryConnect(ctx, url.string(true), func() error {
		candidate, err := mongov1.Connect(ctx, optionsv1.Client().ApplyURI(url.string(false)))
		if err != nil {
			return fmt.Errorf("mongo connect: %w", err)
		}

		if err = candidate.Ping(ctx, nil); err != nil {
			// best effort: the ping error is the one worth reporting
			_ = candidate.Disconnect(ctx)
			return fmt.Errorf("mongo ping: %w", err)
		}

		client = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect mongo url (%s): %w", url.string(true), err)
	}

	return client, nil
}
package testdock
import (
"context"
"fmt"
"testing"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// mongoDriverName is the driver name used for MongoDB test databases; it is compared
// against testDB.driver to separate the mongo code paths from the database/sql ones.
const mongoDriverName = "mongodb"
// GetMongoDatabaseV2 prepares a test MongoDB database (mongo-driver v2), applies migrations
// and returns a handle to the database together with its connection details.
//
// Defaults (docker repository "mongo", image "latest", root credentials taken from the DSN
// when it has a user) are placed before opt, so opt can override them.
// The registered cleanup drops the database when it does not run in docker and
// disconnects the client.
func GetMongoDatabaseV2(tb testing.TB, dsn string, opt ...Option) (*mongo.Database, Informer) {
	tb.Helper()

	ctx := context.Background()

	parsed, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	options := []Option{
		WithDockerRepository("mongo"),
		WithDockerImage("latest"),
	}
	if parsed.User != "" {
		rootEnv := []string{
			fmt.Sprintf("MONGO_INITDB_ROOT_USERNAME=%s", parsed.User),
			fmt.Sprintf("MONGO_INITDB_ROOT_PASSWORD=%s", parsed.Password),
		}
		options = append(options, WithDockerEnv(rootEnv))
	}
	options = append(options, opt...)

	tDB := newTDB(ctx, tb, mongoDriverName, dsn, options)

	client, err := tDB.connectMongoDBv2(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to mongo: %v", err)
	}

	tb.Cleanup(func() {
		if tDB.mode == RunModeDocker {
			_ = client.Disconnect(context.Background())
			return
		}

		// protect against closing connection during tests
		cleaner, err := tDB.connectMongoDBv2(ctx)
		if err != nil {
			tb.Logf("cannot connect to mongo for cleanup: %v", err)
			return
		}
		defer cleaner.Disconnect(ctx)

		if err := cleaner.Database(tDB.databaseName).Drop(ctx); err != nil {
			tb.Logf("failed to drop database %s: %v", tDB.databaseName, err)
		}

		_ = client.Disconnect(context.Background())
	})

	return client.Database(tDB.databaseName), tDB
}
// connectMongoDBv2 connects to the temporary test database with the mongo-driver v2 client,
// retrying connect+ping through retryConnect.
//
// A client whose ping failed is disconnected before the next attempt, so failed
// attempts do not leave clients behind. The returned error contains the DSN with the
// password masked, like the log messages do.
func (d *testDB) connectMongoDBv2(ctx context.Context) (*mongo.Client, error) {
	var client *mongo.Client

	url := d.url.replaceDatabase(d.databaseName)

	err := d.retryConnect(ctx, url.string(true), func() error {
		candidate, err := mongo.Connect(options.Client().ApplyURI(url.string(false)))
		if err != nil {
			return fmt.Errorf("mongo connect: %w", err)
		}

		if err = candidate.Ping(ctx, nil); err != nil {
			// best effort: the ping error is the one worth reporting
			_ = candidate.Disconnect(ctx)
			return fmt.Errorf("mongo ping: %w", err)
		}

		client = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect mongo url (%s): %w", url.string(true), err)
	}

	return client, nil
}
package testdock
import (
"database/sql"
"fmt"
"testing"
_ "github.com/go-sql-driver/mysql" // mysql driver
)
// GetMySQLConn prepares a test mysql database, applies migrations and returns
// a database/sql connection to it together with its connection details.
// Use user root for docker test database.
//
// Defaults (docker repository "mysql", image "9.1.0", root password and database taken
// from the DSN) are placed before opt, so opt can override them.
func GetMySQLConn(tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	parsed, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	containerEnv := []string{
		fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", parsed.Password),
		fmt.Sprintf("MYSQL_DATABASE=%s", parsed.Database),
	}

	options := []Option{
		WithDockerRepository("mysql"),
		WithDockerImage("9.1.0"),
		WithDockerEnv(containerEnv),
	}
	options = append(options, opt...)

	return GetSQLConn(tb, "mysql", dsn, options...)
}
package testdock
import (
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/google/uuid"
"github.com/n-r-w/ctxlog"
)
// Default connection strings for the supported databases.
const (
	// DefaultMongoDSN - default mongodb connection string.
	DefaultMongoDSN = "mongodb://testuser:secret@127.0.0.1:27017/testdb?authSource=admin"
	// DefaultMySQLDSN - default mysql connection string (user root, tcp transport).
	DefaultMySQLDSN = "root:secret@tcp(127.0.0.1:3306)/test_db"
	// DefaultPostgresDSN - default postgres connection string.
	DefaultPostgresDSN = "postgres://postgres:secret@127.0.0.1:5432/postgres?sslmode=disable"
)
// RunMode defines the run mode of the test database.
type RunMode int

const (
	// RunModeUnknown - unknown run mode
	RunModeUnknown RunMode = 0
	// RunModeDocker - run the tests in docker
	RunModeDocker RunMode = 1
	// RunModeExternal - run the tests in external database
	RunModeExternal RunMode = 2
	// RunModeAuto - checks the environment variable TESTDOCK_DSN_[DRIVER]
	// ([DRIVER] is the upper-cased driver name). If it is set,
	// then RunModeExternal, otherwise RunModeDocker.
	// If TESTDOCK_DSN_[DRIVER] is set and RunModeAuto, WithDSN option is ignored.
	// For example, for postgres pgx driver:
	// TESTDOCK_DSN_PGX=postgres://postgres:secret@localhost:5432/postgres?sslmode=disable
	RunModeAuto RunMode = 3
)
// Option is a functional option for creating a test database.
// Options are applied in the order they are passed, so a later option overrides an earlier one.
type Option func(*testDB)
// WithMode selects the run mode of the test database.
// The default is RunModeAuto.
func WithMode(mode RunMode) Option {
	return func(tdb *testDB) { tdb.mode = mode }
}
// WithDockerRepository sets the docker hub repository of the image to run.
// Required for RunModeDocker or RunModeAuto with empty environment variable TESTDOCK_DSN_[DRIVER].
func WithDockerRepository(dockerRepository string) Option {
	return func(tdb *testDB) { tdb.dockerRepository = dockerRepository }
}
// WithDockerImage sets the tag of the docker image to run.
// The default is `latest`.
func WithDockerImage(dockerImage string) Option {
	return func(tdb *testDB) { tdb.dockerImage = dockerImage }
}
// WithDockerSocketEndpoint sets the endpoint of the docker daemon socket.
// The default is autodetect.
func WithDockerSocketEndpoint(dockerSocketEndpoint string) Option {
	return func(tdb *testDB) { tdb.dockerSocketEndpoint = dockerSocketEndpoint }
}
// WithDockerPort sets the port the database listens on inside the docker container.
// The default is the port from the DSN.
func WithDockerPort(dockerPort int) Option {
	return func(tdb *testDB) { tdb.dockerPort = dockerPort }
}
// WithRetryTimeout sets the pause between attempts to connect to the database.
// The default is 3 seconds. Must be less than totalRetryDuration.
func WithRetryTimeout(retryTimeout time.Duration) Option {
	return func(tdb *testDB) { tdb.retryTimeout = retryTimeout }
}
// WithTotalRetryDuration sets the total time allowed for all connection attempts.
// The default is 30 seconds. Must be greater than retryTimeout.
func WithTotalRetryDuration(totalRetryDuration time.Duration) Option {
	return func(tdb *testDB) { tdb.totalRetryDuration = totalRetryDuration }
}
// WithLogger replaces the logger of the test database.
// The default is logger from testing.TB.
func WithLogger(logger ctxlog.ILogger) Option {
	return func(tdb *testDB) { tdb.logger = logger }
}
// WithMigrations sets the migrations directory together with the factory
// that creates the migrator for it.
func WithMigrations(migrationsDir string, migrateFactory MigrateFactory) Option {
	return func(tdb *testDB) {
		tdb.migrationsDir, tdb.MigrateFactory = migrationsDir, migrateFactory
	}
}
// WithDockerEnv sets the environment variables passed to the docker container.
// The default is empty. The slice is stored as is, it is not copied.
func WithDockerEnv(dockerEnv []string) Option {
	return func(tdb *testDB) { tdb.dockerEnv = dockerEnv }
}
// WithUnsetProxyEnv controls whether the proxy environment variables
// (HTTP_PROXY, HTTPS_PROXY, etc.) are unset before docker is used.
// The default is false.
func WithUnsetProxyEnv(unsetProxyEnv bool) Option {
	return func(tdb *testDB) { tdb.unsetProxyEnv = unsetProxyEnv }
}
// WithPrepareCleanUp adds a hook that is called before the temporary test database is deleted.
// Hooks accumulate: every call appends one more hook.
// The default is empty, but `GetPgxPool` and `GetPqConn` use it
// to automatically apply cleanup handlers to disconnect all users from the database
// before cleaning up.
func WithPrepareCleanUp(prepareCleanUp PrepareCleanUp) Option {
	return func(tdb *testDB) {
		hooks := append(tdb.prepareCleanUp, prepareCleanUp)
		tdb.prepareCleanUp = hooks
	}
}
// WithConnectDatabase sets the name of the database to connect to
// when the temporary test database is created.
// By default the name is taken from the DSN; once this option is used,
// the given name is kept even when it is empty.
func WithConnectDatabase(connectDatabase string) Option {
	return func(tdb *testDB) {
		tdb.connectDatabaseOverride = true
		tdb.connectDatabase = connectDatabase
	}
}
// prepareOptions applies options to d, validates the result and derives the remaining state:
// the effective run mode and DSN (RunModeAuto is resolved through the environment variable
// TESTDOCK_DSN_[DRIVER]), the parsed URL, the masked DSN, the database to connect to,
// the docker defaults and a freshly generated name for the temporary database.
// The first violated requirement is returned as an error.
func (d *testDB) prepareOptions(driver string, options []Option) error {
	for i := range options {
		options[i](d)
	}

	switch {
	case d.totalRetryDuration <= d.retryTimeout:
		return errors.New("totalRetryDuration must be greater than retryTimeout")
	case d.driver == "":
		return errors.New("driver is empty")
	}

	if d.mode == RunModeAuto {
		// an external DSN from the environment wins over the DSN passed by the caller
		d.mode = RunModeDocker
		if external := os.Getenv("TESTDOCK_DSN_" + strings.ToUpper(driver)); external != "" {
			d.dsn, d.mode = external, RunModeExternal
		}
	}

	if d.dsn == "" {
		return errors.New("dsn is empty")
	}

	parsed, err := parseURL(d.dsn)
	if err != nil {
		return fmt.Errorf("parse dsn: %w", err)
	}
	d.url, d.dsnNoPass = parsed, parsed.string(true)

	// an explicitly set connect database is kept, even when it is empty
	if d.connectDatabase == "" && !d.connectDatabaseOverride {
		d.connectDatabase = parsed.Database
	}

	if d.mode == RunModeDocker {
		if d.dockerRepository == "" {
			return errors.New("dockerRepository is empty")
		}
		if d.dockerImage == "" {
			d.dockerImage = "latest"
		}
		if d.dockerPort <= 0 {
			if d.dockerPort = parsed.Port; d.dockerPort <= 0 {
				return errors.New("dockerPort must be greater than 0")
			}
		}
	}

	// timestamp plus uuid without dashes
	stamp := time.Now().Format("2006_0102_1504_05")
	d.databaseName = strings.ReplaceAll(fmt.Sprintf("t_%s_%s", stamp, uuid.New().String()), "-", "")

	hasFactory, hasDir := d.MigrateFactory != nil, d.migrationsDir != ""
	if hasFactory != hasDir {
		return errors.New("MigrateFactory and migrationsDir must be set together")
	}

	return nil
}
package testdock
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/jackc/pgx/v5/stdlib" // pgx postgres driver
_ "github.com/lib/pq" // pq postgres driver
)
// GetPgxPool prepares a test postgresql database (pgx driver), applies migrations
// and returns a pgx connection pool to it together with its connection details.
// The pool is closed by a cleanup registered on tb.
func GetPgxPool(tb testing.TB, dsn string, opt ...Option) (*pgxpool.Pool, Informer) {
	tb.Helper()

	ctx := context.Background()
	options := getPostgresOptions(tb, dsn, opt...)
	tDB := newTDB(ctx, tb, "pgx", dsn, options)

	pool, err := tDB.connectPgxDB(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to postgres: %v", err)
	}
	tb.Cleanup(pool.Close)

	return pool, tDB
}
// GetPqConn prepares a test postgresql database (pq driver), applies migrations
// and returns a database/sql connection to it together with its connection details.
// The connection is closed by a cleanup registered on tb.
func GetPqConn(ctx context.Context, tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	options := getPostgresOptions(tb, dsn, opt...)
	tDB := newTDB(ctx, tb, "postgres", dsn, options)

	conn, err := tDB.connectSQLDB(ctx, true)
	if err != nil {
		tb.Fatalf("cannot connect to postgres: %v", err)
	}
	tb.Cleanup(func() {
		_ = conn.Close()
	})

	return conn, tDB
}
// connectPgxDB connects to the temporary test database with retries using pgx.
//
// Every attempt creates a pool and pings it; a pool whose ping failed is closed before
// the next attempt. The returned error contains the connection string with the password
// masked, like the log messages do.
func (d *testDB) connectPgxDB(ctx context.Context) (*pgxpool.Pool, error) {
	var (
		pool      *pgxpool.Pool
		connURL   = d.url.replaceDatabase(d.databaseName)
		maskedURL = connURL.string(true)
	)

	d.logger.Info(ctx, "connecting to test database", "url", maskedURL)

	err := d.retryConnect(ctx, maskedURL, func() error {
		candidate, err := pgxpool.New(ctx, connURL.string(false))
		if err != nil {
			return err
		}

		if err = candidate.Ping(ctx); err != nil {
			candidate.Close()
			return err
		}

		pool = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect postgres url (%s): %w", maskedURL, err)
	}

	return pool, nil
}
// disconnectUsers terminates every backend connected to databaseName except the
// connection that runs the statement, so that the database can be deleted.
func disconnectUsers(db *sql.DB, databaseName string) error {
	const terminateBackends = `SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE datname = $1 AND pid <> pg_backend_pid()`

	if _, err := db.Exec(terminateBackends, databaseName); err != nil {
		return err
	}

	return nil
}
// getPostgresOptions returns the default options for a postgresql test database
// (docker repository, user disconnect hook, container environment built from the DSN),
// followed by opt, so opt can override the defaults.
func getPostgresOptions(tb testing.TB, dsn string, opt ...Option) []Option {
	tb.Helper()

	parsed, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	// NOTE(review): the last two entries are passed as container environment entries;
	// whether the postgres image applies them as server settings should be confirmed.
	containerEnv := []string{
		fmt.Sprintf("POSTGRES_USER=%s", parsed.User),
		fmt.Sprintf("POSTGRES_PASSWORD=%s", parsed.Password),
		fmt.Sprintf("POSTGRES_DB=%s", parsed.Database),
		"listen_addresses = '*'",
		"max_connections = 1000",
	}

	defaults := []Option{
		WithDockerRepository("postgres"),
		WithPrepareCleanUp(disconnectUsers),
		WithDockerEnv(containerEnv),
	}

	return append(defaults, opt...)
}
package testdock
import (
"context"
"database/sql"
"fmt"
"testing"
)
// GetSQLConn prepares a test database, applies migrations and returns a database/sql
// connection to it together with its connection details.
// driver: https://go.dev/wiki/SQLDrivers.
// Do not forget to import corresponding driver package.
// The connection is closed by a cleanup registered on tb.
func GetSQLConn(tb testing.TB, driver, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	ctx := context.Background()
	tDB := newTDB(ctx, tb, driver, dsn, opt)

	conn, err := tDB.connectSQLDB(ctx, true)
	if err != nil {
		tb.Fatalf("cannot connect to database: %v", err)
	}
	tb.Cleanup(func() {
		_ = conn.Close()
	})

	return conn, tDB
}
// connectSQLDB connects to the database with retries using database/sql.
// testDatabase: if true, will be connected to the temporary test database,
// otherwise to the database configured for connecting to the server.
//
// Every attempt opens a connection and pings it; a connection whose ping failed is
// closed before the next attempt. The returned error contains the connection string
// with the password masked, like the log messages do.
func (d *testDB) connectSQLDB(ctx context.Context, testDatabase bool) (*sql.DB, error) {
	target := d.connectDatabase
	if testDatabase {
		target = d.databaseName
	}

	connURL := d.url.replaceDatabase(target)
	maskedURL := connURL.string(true)

	d.logger.Info(ctx, "connecting to test database", "url", maskedURL)

	var db *sql.DB

	err := d.retryConnect(ctx, maskedURL, func() error {
		candidate, err := sql.Open(d.driver, connURL.string(false))
		if err != nil {
			return err
		}

		if err = candidate.Ping(); err != nil {
			_ = candidate.Close() // best effort: the ping error is the one worth reporting
			return err
		}

		db = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect url (%s): %w", maskedURL, err)
	}

	return db, nil
}
// createSQLDatabase connects to the server through the connect database and creates
// the temporary test database with a CREATE DATABASE statement.
func (d *testDB) createSQLDatabase(ctx context.Context) error {
	d.logger.Info(ctx, "creating new test sql database", "dsn", d.dsnNoPass, "database", d.databaseName)

	conn, err := d.connectSQLDB(ctx, false)
	if err != nil {
		return err
	}
	defer conn.Close()

	createStmt := fmt.Sprintf("CREATE DATABASE %s", d.databaseName)
	if _, err = conn.Exec(createStmt); err != nil {
		return fmt.Errorf("create db: %w", err)
	}

	d.logger.Info(ctx, "new test sql database created", "dsn", d.dsnNoPass, "database", d.databaseName)

	return nil
}
package testdock
import (
"errors"
"fmt"
"sort"
"strconv"
"strings"
)
// dbURL represents a parsed database connection string.
// Supported connection string format:
// [protocol://][user:password@][transport(]host:port[)][/database][?option1=a&option2=b]
//
// Required fields: host, port
// Optional fields: protocol, credentials, transport, database and options.
// Credentials are optional as a pair: when the string contains '@',
// both user and password must be present and non-empty.
type dbURL struct {
	Protocol  string
	Transport string
	User      string
	Password  string
	Host      string
	Port      int
	Database  string
	Options   map[string]string // option1=a&option2=b -> {"option1": "a", "option2": "b"}
}
// parseURL parses a connection string into a URL.
//
// The string is taken apart in this order: protocol ("://"), credentials (everything
// before the last '@', split at the first ':'), query options ('?'), database ('/'),
// transport ("name(...)") and finally host and port (':').
//
// An error is returned when the string is empty, when "://" is present without a
// protocol, when credentials are present but user or password is missing or empty,
// when host or port is missing, or when the port is not a number in the range 1-65535.
// Query options without '=' are skipped.
func parseURL(connStr string) (*dbURL, error) {
	if connStr == "" {
		return nil, errors.New("connection string cannot be empty")
	}

	// highest valid TCP port
	const maxPort = 65535

	u := &dbURL{
		Options: make(map[string]string),
	}

	// Split protocol and the rest
	rest := connStr
	if protocol, tail, found := strings.Cut(connStr, "://"); found {
		if protocol == "" {
			return nil, errors.New("invalid connection string format: '://' exists, but no protocol")
		}
		u.Protocol, rest = protocol, tail
	}

	// Find the last @ to properly handle @ in passwords
	if atIndex := strings.LastIndex(rest, "@"); atIndex >= 0 {
		user, password, found := strings.Cut(rest[:atIndex], ":")
		if !found {
			return nil, errors.New("invalid connection string format: missing password")
		}
		if user == "" {
			return nil, errors.New("user is required")
		}
		if password == "" {
			return nil, errors.New("password is required")
		}
		u.User, u.Password = user, password
		rest = rest[atIndex+1:]
	}

	// Split off and parse query parameters if they exist
	rest, query, hasQuery := strings.Cut(rest, "?")
	if hasQuery {
		for _, param := range strings.Split(query, "&") {
			if key, value, ok := strings.Cut(param, "="); ok {
				u.Options[key] = value
			}
		}
	}

	// Parse database if exists
	rest, database, _ := strings.Cut(rest, "/")
	u.Database = database

	// Check if transport is specified: transport(host:port)
	if strings.HasSuffix(rest, ")") {
		if transport, addr, found := strings.Cut(rest, "("); found {
			u.Transport = transport
			rest = strings.TrimSuffix(addr, ")")
		}
	}

	if rest == "" {
		return nil, errors.New("host is required")
	}

	// Parse host and port
	host, portStr, found := strings.Cut(rest, ":")
	if !found {
		return nil, errors.New("invalid connection string format: missing port")
	}
	if host == "" {
		return nil, errors.New("host is required")
	}
	if portStr == "" {
		return nil, errors.New("port is required")
	}

	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("parse port: %w", err)
	}
	if port <= 0 {
		return nil, errors.New("port must be positive")
	}
	if port > maxPort {
		return nil, fmt.Errorf("port must not exceed %d", maxPort)
	}

	u.Host, u.Port = host, port

	return u, nil
}
// string renders the URL back into a connection string.
// With hidePassword set, the password is replaced by "*****".
// Options are written sorted by key, so the output is deterministic.
// A nil URL renders as an empty string.
func (u *dbURL) string(hidePassword bool) string {
	if u == nil {
		return ""
	}

	var out strings.Builder

	if u.Protocol != "" {
		out.WriteString(u.Protocol + "://")
	}

	// credentials are written only when a user is set
	if u.User != "" {
		password := u.Password
		if hidePassword {
			password = "*****"
		}
		out.WriteString(u.User + ":" + password + "@")
	}

	// host[:port], optionally wrapped into transport(...)
	address := u.Host
	if u.Port != 0 {
		address += ":" + strconv.Itoa(u.Port)
	}
	if u.Transport != "" {
		address = u.Transport + "(" + address + ")"
	}
	out.WriteString(address)

	if u.Database != "" {
		out.WriteString("/" + u.Database)
	}

	if len(u.Options) > 0 {
		names := make([]string, 0, len(u.Options))
		for name := range u.Options {
			names = append(names, name)
		}
		sort.Strings(names)

		pairs := make([]string, 0, len(names))
		for _, name := range names {
			pairs = append(pairs, name+"="+u.Options[name])
		}

		out.WriteString("?" + strings.Join(pairs, "&"))
	}

	return out.String()
}
// clone returns an independent copy of the URL; the options map is copied as well,
// so changes to the copy do not affect the original. A nil URL yields nil.
func (u *dbURL) clone() *dbURL {
	if u == nil {
		return nil
	}

	// copy all scalar fields at once, then replace the shared map with a fresh one
	dup := *u
	dup.Options = make(map[string]string, len(u.Options))
	for key, value := range u.Options {
		dup.Options[key] = value
	}

	return &dup
}
// replaceDatabase returns a copy of the URL that points at newDBName;
// the receiver itself is left unchanged.
func (u *dbURL) replaceDatabase(newDBName string) *dbURL {
	replaced := u.clone()
	replaced.Database = newDBName

	return replaced
}