package testdock import ( "context" "database/sql" "fmt" "sync" "testing" "time" "github.com/cenkalti/backoff/v5" "github.com/n-r-w/ctxlog" ) // Informer interface for database information. type Informer interface { // DSN returns the real database connection string. DSN() string // Host returns the host of the database server. Host() string // Port returns the port of the database server. Port() int // DatabaseName returns the database name for testing. DatabaseName() string } const ( // DefaultRetryTimeout is the default retry timeout. DefaultRetryTimeout = time.Second * 3 // DefaultTotalRetryDuration is the default total retry duration. DefaultTotalRetryDuration = time.Second * 30 ) // PrepareCleanUp - function for prepare to delete temporary test database. // For example, disconnect users. type PrepareCleanUp func(db *sql.DB, databaseName string) error // testDB represents a test database. type testDB struct { t testing.TB logger ctxlog.ILogger // unified way to logging databaseName string // name of the test database url *dbURL // parsed database connection string dsnNoPass string // database connection string without password // options driver string // database driver (pgx, pq, etc) mode RunMode // run mode (docker or external) dsn string // database connection string retryTimeout time.Duration // retry timeout for connecting to the database totalRetryDuration time.Duration // total retry duration migrationsDir string // migrations directory unsetProxyEnv bool // unset HTTP_PROXY, HTTPS_PROXY etc. environment variables MigrateFactory MigrateFactory // unified way to create a migrations prepareCleanUp []PrepareCleanUp // function for prepare to delete temporary test database. 
connectDatabase string // database name for connecting to the database server connectDatabaseOverride bool dockerPort int // docker port dockerRepository string // docker hub repository dockerImage string // docker hub image tag dockerSocketEndpoint string // docker socket endpoint for connecting to the docker daemon dockerEnv []string // environment variables for the docker container } var ( globalMu sync.Mutex globalMuByDSN = make(map[string]*sync.Mutex) ) // newTDB creates a new test database and applies migrations. func newTDB(ctx context.Context, tb testing.TB, driver, dsn string, opt []Option) *testDB { tb.Helper() var ( db = &testDB{ t: tb, logger: ctxlog.Must(ctxlog.WithTesting(tb)), driver: driver, dsn: dsn, mode: RunModeAuto, retryTimeout: DefaultRetryTimeout, totalRetryDuration: DefaultTotalRetryDuration, } errResult error ) defer func() { if errResult != nil { tb.Fatalf("cannot create test database: %v", errResult) } }() if errResult = db.prepareOptions(driver, opt); errResult != nil { return nil } globalMu.Lock() mu, ok := globalMuByDSN[db.dsn] if !ok { mu = &sync.Mutex{} globalMuByDSN[db.dsn] = mu } globalMu.Unlock() mu.Lock() defer mu.Unlock() if db.mode == RunModeDocker { db.logger.Info(ctx, "using docker test database", "dsn", db.dsnNoPass) if errResult = db.createDockerResources(ctx); errResult != nil { return nil } } else { db.logger.Info(ctx, "using real test database", "dsn", db.dsnNoPass) } if errResult = db.createTestDatabase(ctx); errResult != nil { if err := db.close(ctx); err != nil { db.logger.Info(ctx, "failed to close test database", "dsn", db.dsnNoPass, "error", err) } return nil } if db.migrationsDir != "" { if errResult = db.migrationsUp(ctx); errResult != nil { return nil } } tb.Cleanup(func() { ctx := context.Background() if err := db.close(ctx); err != nil { db.logger.Info(ctx, "failed to close test database", "dsn", db.dsnNoPass, "error", err) } else { db.logger.Info(ctx, "test database closed", "dsn", db.dsnNoPass) } }) return 
db } // migrationsUp applies migrations to the database. func (d *testDB) migrationsUp(ctx context.Context) error { d.logger.Info(ctx, "migrations up start", "dsn", d.dsnNoPass) defer d.logger.Info(ctx, "migrations up end", "dsn", d.dsnNoPass) dsn := d.url.replaceDatabase(d.databaseName).string(false) migrator, err := d.MigrateFactory(d.t, dsn, d.migrationsDir, d.logger) if err != nil { return fmt.Errorf("new migrator: %w", err) } if err = migrator.Up(context.Background()); err != nil { return fmt.Errorf("up migrations: %w", err) } return nil } // close closes the test database. func (d *testDB) close(ctx context.Context) error { if d.mode != RunModeDocker { if d.driver == mongoDriverName { return nil } // remove the database created before applying the migrations d.logger.Info(ctx, "deleting test database", "dsn", d.dsnNoPass, "database", d.databaseName) dsn := d.url.string(false) db, err := sql.Open(d.driver, dsn) if err != nil { return fmt.Errorf("sql open url (%s): %w", dsn, err) } defer func() { _ = db.Close() }() for _, prepareCleanUp := range d.prepareCleanUp { if err := prepareCleanUp(db, d.databaseName); err != nil { d.logger.Info(ctx, "failed to prepare clean up", "dsn", d.dsnNoPass, "error", err) } } if _, err = db.Exec(fmt.Sprintf("DROP DATABASE %s", d.databaseName)); err != nil { return fmt.Errorf("drop db: %w", err) } d.logger.Info(ctx, "test database deleted", "dsn", d.dsnNoPass, "database", d.databaseName) } return nil } // initDatabase creates a test database or connects to an existing one. func (d *testDB) createTestDatabase(ctx context.Context) error { if d.driver == mongoDriverName { return nil } return d.createSQLDatabase(ctx) } // retryConnect connects to the database with retries. 
func (d *testDB) retryConnect(ctx context.Context, info string, op func() error) error { var attempt int operation := func() (struct{}, error) { if err := op(); err != nil { d.logger.Info(ctx, "retrying operation", "info", info, "attempt", attempt, "error", err) attempt++ return struct{}{}, err } return struct{}{}, nil } _, err := backoff.Retry( context.Background(), operation, backoff.WithBackOff(backoff.NewConstantBackOff(d.retryTimeout)), backoff.WithMaxElapsedTime(d.totalRetryDuration), ) if err != nil { return fmt.Errorf("retry failed after %d attempts: %w", attempt, err) } return nil } // DSN returns the real database connection string. func (d *testDB) DSN() string { return d.url.replaceDatabase(d.databaseName).string(false) } // Host returns the database host. func (d *testDB) Host() string { return d.url.Host } // Port returns the database port. func (d *testDB) Port() int { return d.url.Port } // DatabaseName returns the database name for testing. func (d *testDB) DatabaseName() string { return d.databaseName }
package testdock

import (
	"context"
	"fmt"
	"os"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/cenkalti/backoff/v5"
	"github.com/ory/dockertest/v3"
	"github.com/ory/dockertest/v3/docker"
)

// Docker resources are created only once and shared by all tests.
var (
	// globalDockerMu guards globalDockerResources and globalDockerPool.
	globalDockerMu sync.Mutex
	// globalDockerResources maps a DSN to the container started for it.
	globalDockerResources = make(map[string]*dockerResourceInfo)
	// globalDockerPool is the shared dockertest pool; nil until first use.
	globalDockerPool *dockertest.Pool
)

// dockerResourceInfo describes one running container and how many test
// databases currently use it.
type dockerResourceInfo struct {
	// resource is the dockertest handle of the running container.
	resource *dockertest.Resource
	// port is the host port the container ended up bound to.
	port int
	// count is the number of users of the container; the container is purged
	// when it drops back to zero.
	count int
	// mu is held while the container is created and while it is released.
	mu sync.Mutex
}

// createDockerResources creates the shared docker pool (once) and the container
// for d.dsn (once per DSN), then registers a test cleanup that releases it.
//
// When the host port is reported as busy, d.url.Port is incremented and the
// run is retried; other run errors are retried up to maxAttempts times with a
// fixed sleep in between. When a container for the DSN already exists, its
// port is copied into d.url.Port instead.
func (d *testDB) createDockerResources(ctx context.Context) error { //nolint:gocognit // ok
	globalDockerMu.Lock()

	// Reuse the bookkeeping entry for this DSN if there is one. A new entry is
	// only stored in the map after the container has been started (see below).
	info, ok := globalDockerResources[d.dsn]
	if !ok {
		info = &dockerResourceInfo{}
	}

	var (
		err    error
		logDsn = d.dsnNoPass
	)

	if globalDockerPool == nil {
		globalDockerPool, err = dockertest.NewPool(d.dockerSocketEndpoint)
		if err != nil {
			globalDockerMu.Unlock()
			return fmt.Errorf("dockertest NewPool: %w", err)
		}

		if d.unsetProxyEnv {
			// we clear the proxy environment variables, because they can interfere with the work of docker
			proxyEnv := []string{
				"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy",
			}
			for _, env := range proxyEnv {
				if os.Getenv(env) != "" {
					d.logger.Info(ctx, "unset proxy env", "component", "docker", "env", env)
					_ = os.Unsetenv(env)
				}
			}
		}

		err = globalDockerPool.Client.Ping()
		if err != nil {
			globalDockerMu.Unlock()
			return fmt.Errorf("dockertest ping: %w", err)
		}

		d.logger.Info(ctx, "pool created", "component", "docker")

		// Runs when this call returns: if no resource is registered by then
		// (for example because starting the container failed), the pool
		// created above is dropped again.
		defer func() {
			globalDockerMu.Lock()
			defer globalDockerMu.Unlock()

			if len(globalDockerResources) == 0 {
				globalDockerPool = nil
				d.logger.Info(ctx, "pool purged", "component", "docker")
			}
		}()
	}

	globalDockerMu.Unlock()

	info.mu.Lock()
	defer info.mu.Unlock()

	if info.count == 0 {
		// docker does not release the port immediately after a Purge,
		// so starting the container is attempted several times
		const (
			maxAttempts = 10
			sleepTime   = 5 * time.Second
		)

		var (
			attempt    int
			dockerPort = fmt.Sprintf("%d/tcp", d.dockerPort)
		)

		for {
			// Bind the container port d.dockerPort to host d.url.Host:d.url.Port.
			info.resource, err = globalDockerPool.RunWithOptions(&dockertest.RunOptions{
				Repository: d.dockerRepository,
				Tag:        d.dockerImage,
				Env:        d.dockerEnv,
				PortBindings: map[docker.Port][]docker.PortBinding{
					docker.Port(dockerPort): {{
						HostIP:   d.url.Host,
						HostPort: strconv.Itoa(d.url.Port),
					}},
				},
			}, func(config *docker.HostConfig) {
				config.AutoRemove = true
				config.RestartPolicy = docker.RestartPolicy{Name: "no"}
			})
			if err == nil {
				break
			}

			// Error texts that are treated as "the host port is busy".
			bindErrors := []string{
				"address already in use",
				"port is already allocated",
				"failed to bind host port",
			}
			needNextPort := false
			for _, bindError := range bindErrors {
				if strings.Contains(err.Error(), bindError) {
					needNextPort = true
					break
				}
			}

			if needNextPort {
				// increase hostPort by 1; this path does not count towards maxAttempts
				d.logger.Info(ctx, "port is already allocated, trying next port",
					"dsn", logDsn, "next_port", d.url.Port+1)
				d.url.Port++
				continue
			}

			attempt++
			if attempt >= maxAttempts {
				// err is still set, so the check after the loop returns it
				break
			}

			d.logger.Info(ctx, "RunWithOptions failed",
				"component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
			time.Sleep(sleepTime)
		}

		if err != nil {
			return fmt.Errorf("dockertest RunWithOptions: %w", err)
		}

		// Remember the port that finally worked so later users of the same DSN reuse it.
		info.port = d.url.Port
		d.logger.Info(ctx, "resources created", "component", "docker", "dsn", logDsn)
	} else {
		d.url.Port = info.port // restore port
		d.logger.Info(ctx, "use existing resources", "component", "docker", "dsn", logDsn)
	}

	globalDockerMu.Lock()
	globalDockerResources[d.dsn] = info
	globalDockerMu.Unlock()

	info.count++

	// Release the container when the last test using it has finished.
	d.t.Cleanup(func() {
		ctx := context.Background()

		info.mu.Lock()
		defer info.mu.Unlock()

		info.count--

		if info.count == 0 {
			// globalDockerMu stays locked for the whole purge, including its retries.
			globalDockerMu.Lock()
			defer globalDockerMu.Unlock()

			delete(globalDockerResources, d.dsn)

			const (
				maxTime      = 10 * time.Second
				retryTimeout = 1 * time.Second
			)

			var attempt int

			operation := func() (struct{}, error) {
				if err := globalDockerPool.Purge(info.resource); err != nil {
					attempt++
					// ctx is the background context declared at the top of this
					// cleanup function; the test's own context is not used here.
					d.logger.Info(ctx, "purge attempt failed",
						"component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
					return struct{}{}, err
				}
				return struct{}{}, nil
			}

			// Purge failures are only logged.
			if _, err := backoff.Retry(ctx, operation,
				backoff.WithBackOff(backoff.NewConstantBackOff(retryTimeout)),
				backoff.WithMaxElapsedTime(maxTime)); err != nil {
				d.logger.Info(ctx, "purge failed after retries",
					"component", "docker", "dsn", logDsn, "attempt", attempt, "error", err)
			} else {
				d.logger.Info(ctx, "resources purged successfully",
					"component", "docker", "dsn", logDsn, "attempts", attempt)
			}
		}
	})

	return nil
}
package testdock import ( "context" "database/sql" "fmt" "os" "path/filepath" "testing" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/mongodb" // require for mongodb _ "github.com/golang-migrate/migrate/v4/database/postgres" // require for gomigrate _ "github.com/golang-migrate/migrate/v4/source/file" // require for gomigrate "github.com/n-r-w/ctxlog" "github.com/pressly/goose/v3" ) // MigrateFactory creates a new migrator. type MigrateFactory func(t testing.TB, dsn string, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) // Migrator interface for applying migrations. type Migrator interface { Up(ctx context.Context) error } var ( // GooseMigrateFactoryPGX is a migrator for https://github.com/pressly/goose with pgx driver. GooseMigrateFactoryPGX = GooseMigrateFactory(goose.DialectPostgres, "pgx") // GooseMigrateFactoryPQ is a migrator for https://github.com/pressly/goose with pq driver. GooseMigrateFactoryPQ = GooseMigrateFactory(goose.DialectPostgres, "postgres") // GooseMigrateFactoryMySQL is a migrator for https://github.com/pressly/goose with mysql driver. GooseMigrateFactoryMySQL = GooseMigrateFactory(goose.DialectMySQL, "mysql") ) // GooseMigrateFactory creates a new migrator for https://github.com/pressly/goose. func GooseMigrateFactory(dialect goose.Dialect, driver string) MigrateFactory { return func(t testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) { return newGooseMigrator(t, dialect, driver, dsn, migrationsDir, logger) } } // gooseMigrator is a migrator for goose. type gooseMigrator struct { p *goose.Provider } // newGooseMigrator creates a new migrator for goose. 
func newGooseMigrator(t testing.TB, dialect goose.Dialect, driver, dsn, migrationsDir string, logger ctxlog.ILogger) (*gooseMigrator, error) { conn, err := sql.Open(driver, dsn) if err != nil { return nil, fmt.Errorf("sql open url (%s): %w", dsn, err) } p, err := goose.NewProvider(dialect, conn, os.DirFS(migrationsDir), goose.WithLogger(NewGooseLogger(t, logger)), goose.WithVerbose(true), ) if err != nil { _ = conn.Close() return nil, fmt.Errorf("new goose provider: %w", err) } return &gooseMigrator{ p: p, }, nil } func (m *gooseMigrator) Up(ctx context.Context) error { defer m.p.Close() _, err := m.p.Up(ctx) return err } // GolangMigrateFactory creates a new migrator for https://github.com/golang-migrate/migrate. func GolangMigrateFactory(_ testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) { return newGolangMigrateMigrator(dsn, migrationsDir, logger) } // golangMigrateMigrator is a migrator for https://github.com/golang-migrate/migrate. type golangMigrateMigrator struct { m *migrate.Migrate } // newGolangMigrateMigrator creates a new migrator for https://github.com/golang-migrate/migrate. func newGolangMigrateMigrator(dsn, migrationsDir string, logger ctxlog.ILogger) (*golangMigrateMigrator, error) { if !filepath.IsAbs(migrationsDir) { var err error migrationsDir, err = filepath.Abs(migrationsDir) if err != nil { return nil, fmt.Errorf("get absolute path: %w", err) } } m, err := migrate.New("file://"+migrationsDir, dsn) if err != nil { return nil, fmt.Errorf("new migrate: %w", err) } m.Log = NewGolangMigrateLogger(logger) return &golangMigrateMigrator{m: m}, nil } func (m *golangMigrateMigrator) Up(_ context.Context) error { return m.m.Up() } // GooseLogger is a logger for goose. type GooseLogger struct { t testing.TB l ctxlog.ILogger } // NewGooseLogger creates a new goose logger. func NewGooseLogger(t testing.TB, l ctxlog.ILogger) *GooseLogger { return &GooseLogger{t: t, l: l} } // Fatalf logs a fatal error. 
func (l GooseLogger) Fatalf(format string, v ...any) { l.t.Fatalf(format, v...) } // Printf logs a message. func (l GooseLogger) Printf(format string, v ...any) { l.l.Info(context.Background(), fmt.Sprintf(format, v...)) } // GolangMigrateLogger is a logger for golang-migrate. type GolangMigrateLogger struct { l ctxlog.ILogger } // NewGolangMigrateLogger creates a new golang-migrate logger. func NewGolangMigrateLogger(l ctxlog.ILogger) *GolangMigrateLogger { return &GolangMigrateLogger{l: l} } // Printf logs a message. func (g *GolangMigrateLogger) Printf(format string, v ...any) { g.l.Info(context.Background(), fmt.Sprintf(format, v...)) } // Verbose returns true. func (g *GolangMigrateLogger) Verbose() bool { return true }
package testdock import ( "context" "fmt" "testing" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) // mongo driver name for separating sql and mongo const mongoDriverName = "mongodb" // GetMongoDatabase initializes a test MongoDB database, applies migrations, and returns a database connection. func GetMongoDatabase(tb testing.TB, dsn string, opt ...Option) (*mongo.Database, Informer) { tb.Helper() ctx := context.Background() url, err := parseURL(dsn) if err != nil { tb.Fatalf("failed to parse dsn: %v", err) } optPrepared := make([]Option, 0, len(opt)) optPrepared = append(optPrepared, WithDockerRepository("mongo"), WithDockerImage("latest"), ) if url.User != "" { optPrepared = append(optPrepared, WithDockerEnv([]string{ fmt.Sprintf("MONGO_INITDB_ROOT_USERNAME=%s", url.User), fmt.Sprintf("MONGO_INITDB_ROOT_PASSWORD=%s", url.Password), })) } optPrepared = append(optPrepared, opt...) tDB := newTDB(ctx, tb, mongoDriverName, dsn, optPrepared) client, err := tDB.connectMongoDB(ctx) if err != nil { tb.Fatalf("cannot connect to mongo: %v", err) } mongoDatabase := client.Database(tDB.databaseName) tb.Cleanup(func() { if tDB.mode != RunModeDocker { if err := mongoDatabase.Drop(ctx); err != nil { tb.Logf("failed to drop database %s: %v", tDB.databaseName, err) } } _ = client.Disconnect(context.Background()) }) return mongoDatabase, tDB } // connectDB connects to MongoDB with retries func (d *testDB) connectMongoDB(ctx context.Context) (*mongo.Client, error) { var ( client *mongo.Client err error ) url := d.url.replaceDatabase(d.databaseName) err = d.retryConnect(ctx, url.string(true), func() error { client, err = mongo.Connect(ctx, options.Client().ApplyURI(url.string(false))) if err != nil { return fmt.Errorf("mongo connect: %w", err) } if err = client.Ping(ctx, nil); err != nil { return fmt.Errorf("mongo ping: %w", err) } return nil }) if err != nil { return nil, fmt.Errorf("connect mongo url (%s): %w", url.string(false), err) } return 
client, nil }
package testdock import ( "database/sql" "fmt" "testing" _ "github.com/go-sql-driver/mysql" // mysql driver ) // GetMySQLConn inits a test mysql database, applies migrations. // Use user root for docker test database. func GetMySQLConn(tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) { tb.Helper() url, err := parseURL(dsn) if err != nil { tb.Fatalf("failed to parse dsn: %v", err) } optPrepared := make([]Option, 0, len(opt)) optPrepared = append(optPrepared, WithDockerRepository("mysql"), WithDockerImage("9.1.0"), WithDockerEnv([]string{ fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", url.Password), fmt.Sprintf("MYSQL_DATABASE=%s", url.Database), }), ) optPrepared = append(optPrepared, opt...) return GetSQLConn(tb, "mysql", dsn, optPrepared...) }
package testdock import ( "errors" "fmt" "os" "strings" "time" "github.com/google/uuid" "github.com/n-r-w/ctxlog" ) const ( // DefaultMongoDSN - default mongodb connection string. DefaultMongoDSN = "mongodb://testuser:secret@127.0.0.1:27017/testdb?authSource=admin" // DefaultMySQLDSN - default mysql connection string. DefaultMySQLDSN = "root:secret@tcp(127.0.0.1:3306)/test_db" // DefaultPostgresDSN - default postgres connection string. DefaultPostgresDSN = "postgres://postgres:secret@127.0.0.1:5432/postgres?sslmode=disable" ) // RunMode defines the run mode of the test database. type RunMode int const ( // RunModeUnknown - unknown run mode RunModeUnknown RunMode = 0 // RunModeDocker - run the tests in docker RunModeDocker RunMode = 1 // RunModeExternal - run the tests in external database RunModeExternal RunMode = 2 // RunModeAuto - checks the environment variable TESTDOCK_DSN_[DRIVER]. If it is set, // then RunModeExternal, otherwise RunModeDocker. // If TESTDOCK_DSN_[DRIVER] is set and RunModeAuto, WithDSN option is ignored. // For example, for postgres pgx driver: // TESTDOCK_DSN_PGX=postgres://postgres:secret@localhost:5432/postgres&sslmode=disable RunModeAuto RunMode = 3 ) // Option option for creating a test database. type Option func(*testDB) // WithMode sets the mode for the test database. // The default is RunModeAuto. func WithMode(mode RunMode) Option { return func(o *testDB) { o.mode = mode } } // WithDockerRepository sets the name of docker hub repository. // Required for RunModeDocker or RunModeAuto with empty environment variable TESTDOCK_DSN_[DRIVER]. func WithDockerRepository(dockerRepository string) Option { return func(o *testDB) { o.dockerRepository = dockerRepository } } // WithDockerImage sets the name of the docker image. // The default is `latest`. 
func WithDockerImage(dockerImage string) Option {
	apply := func(db *testDB) {
		db.dockerImage = dockerImage
	}

	return apply
}

// WithDockerSocketEndpoint sets the endpoint used to reach the docker daemon.
// The default is autodetect.
func WithDockerSocketEndpoint(dockerSocketEndpoint string) Option {
	apply := func(db *testDB) {
		db.dockerSocketEndpoint = dockerSocketEndpoint
	}

	return apply
}

// WithDockerPort sets the port used to connect to the database in docker.
// The default is the port from the DSN.
func WithDockerPort(dockerPort int) Option {
	apply := func(db *testDB) {
		db.dockerPort = dockerPort
	}

	return apply
}

// WithRetryTimeout sets the pause between attempts to connect to the database.
// The default is 3 seconds. It must be less than the total retry duration.
func WithRetryTimeout(retryTimeout time.Duration) Option {
	apply := func(db *testDB) {
		db.retryTimeout = retryTimeout
	}

	return apply
}

// WithTotalRetryDuration sets the total time allowed for connection retries.
// The default is 30 seconds. It must be greater than the retry timeout.
func WithTotalRetryDuration(totalRetryDuration time.Duration) Option {
	apply := func(db *testDB) {
		db.totalRetryDuration = totalRetryDuration
	}

	return apply
}

// WithLogger replaces the logger of the test database.
// The default is a logger backed by testing.TB.
func WithLogger(logger ctxlog.ILogger) Option {
	apply := func(db *testDB) {
		db.logger = logger
	}

	return apply
}

// WithMigrations sets the migrations directory together with the factory that
// creates the migrator for it.
func WithMigrations(migrationsDir string, migrateFactory MigrateFactory) Option {
	apply := func(db *testDB) {
		db.migrationsDir, db.MigrateFactory = migrationsDir, migrateFactory
	}

	return apply
}

// WithDockerEnv sets the environment variables of the docker container.
// The default is empty. The slice is stored as passed, without copying.
func WithDockerEnv(dockerEnv []string) Option {
	apply := func(db *testDB) {
		db.dockerEnv = dockerEnv
	}

	return apply
}

// WithUnsetProxyEnv unsets the proxy environment variables.
// The default is false.
func WithUnsetProxyEnv(unsetProxyEnv bool) Option { return func(o *testDB) { o.unsetProxyEnv = unsetProxyEnv } } // WithPrepareCleanUp sets the function for prepare to delete temporary test database. // The default is empty, but `GetPgxPool` and `GetPqConn` use it // to automatically apply cleanup handlers to disconnect all users from the database // before cleaning up. func WithPrepareCleanUp(prepareCleanUp PrepareCleanUp) Option { return func(o *testDB) { o.prepareCleanUp = append(o.prepareCleanUp, prepareCleanUp) } } // WithConnectDatabase sets the name of the database to connect to. // The default will be take from the DSN. func WithConnectDatabase(connectDatabase string) Option { return func(o *testDB) { o.connectDatabase = connectDatabase o.connectDatabaseOverride = true } } func (d *testDB) prepareOptions(driver string, options []Option) error { for _, o := range options { o(d) } if d.totalRetryDuration <= d.retryTimeout { return errors.New("totalRetryDuration must be greater than retryTimeout") } if d.driver == "" { return errors.New("driver is empty") } if d.mode == RunModeAuto { dsnEnv := os.Getenv(fmt.Sprintf("TESTDOCK_DSN_%s", strings.ToUpper(driver))) if dsnEnv != "" { d.dsn = dsnEnv d.mode = RunModeExternal } else { d.mode = RunModeDocker } } if d.dsn == "" { return errors.New("dsn is empty") } p, err := parseURL(d.dsn) if err != nil { return fmt.Errorf("parse dsn: %w", err) } d.url = p d.dsnNoPass = p.string(true) if !d.connectDatabaseOverride && d.connectDatabase == "" { d.connectDatabase = p.Database } if d.mode == RunModeDocker { if d.dockerRepository == "" { return errors.New("dockerRepository is empty") } if d.dockerImage == "" { d.dockerImage = "latest" } if d.dockerPort <= 0 { d.dockerPort = p.Port if d.dockerPort <= 0 { return errors.New("dockerPort must be greater than 0") } } } dbName := fmt.Sprintf("t_%s_%s", time.Now().Format("2006_0102_1504_05"), uuid.New().String()) d.databaseName = strings.ReplaceAll(dbName, "-", "") if 
(d.MigrateFactory == nil) != (d.migrationsDir == "") { return errors.New("MigrateFactory and migrationsDir must be set together") } return nil }
package testdock import ( "context" "database/sql" "fmt" "testing" "github.com/jackc/pgx/v5/pgxpool" _ "github.com/jackc/pgx/v5/stdlib" // pgx postgres driver _ "github.com/lib/pq" // pq postgres driver ) // GetPgxPool inits a test postgresql (pgx driver) database, applies migrations, // and returns pgx connection pool to the database. func GetPgxPool(tb testing.TB, dsn string, opt ...Option) (*pgxpool.Pool, Informer) { tb.Helper() ctx := context.Background() tDB := newTDB(ctx, tb, "pgx", dsn, getPostgresOptions(tb, dsn, opt...)) db, err := tDB.connectPgxDB(ctx) if err != nil { tb.Fatalf("cannot connect to postgres: %v", err) } tb.Cleanup(func() { db.Close() }) return db, tDB } // GetPqConn inits a test postgresql (pq driver) database, applies migrations, // and returns sql connection to the database. func GetPqConn(ctx context.Context, tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) { tb.Helper() tDB := newTDB(ctx, tb, "postgres", dsn, getPostgresOptions(tb, dsn, opt...)) db, err := tDB.connectSQLDB(ctx, true) if err != nil { tb.Fatalf("cannot connect to postgres: %v", err) } tb.Cleanup(func() { _ = db.Close() }) return db, tDB } // connectPgxDB connects to the database with retries using pgx. 
func (d *testDB) connectPgxDB(ctx context.Context) (*pgxpool.Pool, error) { var db *pgxpool.Pool dbURL := d.url.replaceDatabase(d.databaseName) d.logger.Info(ctx, "connecting to test database", "url", dbURL.string(true)) err := d.retryConnect(ctx, dbURL.string(true), func() (err error) { db, err = pgxpool.New(ctx, dbURL.string(false)) if err != nil { return err } if err = db.Ping(ctx); err != nil { db.Close() return err } return nil }) if err != nil { return nil, fmt.Errorf("connect postgres url (%s): %w", dbURL.string(false), err) } return db, nil } // disconnect users before deleting the database func disconnectUsers(db *sql.DB, databaseName string) error { _, err := db.Exec( `SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE datname = $1 AND pid <> pg_backend_pid()`, databaseName) return err } // getPostgresOptions returns the options for the postgresql database. func getPostgresOptions(tb testing.TB, dsn string, opt ...Option) []Option { tb.Helper() url, err := parseURL(dsn) if err != nil { tb.Fatalf("failed to parse dsn: %v", err) } optPrepared := make([]Option, 0, len(opt)) optPrepared = append(optPrepared, WithDockerRepository("postgres"), WithPrepareCleanUp(disconnectUsers), WithDockerEnv([]string{ fmt.Sprintf("POSTGRES_USER=%s", url.User), fmt.Sprintf("POSTGRES_PASSWORD=%s", url.Password), fmt.Sprintf("POSTGRES_DB=%s", url.Database), "listen_addresses = '*'", "max_connections = 1000", }), ) optPrepared = append(optPrepared, opt...) return optPrepared }
package testdock import ( "context" "database/sql" "fmt" "testing" ) // GetSQLConn inits a test database, applies migrations, and returns sql connection to the database. // driver: https://go.dev/wiki/SQLDrivers. // Do not forget to import corresponding driver package. func GetSQLConn(tb testing.TB, driver, dsn string, opt ...Option) (*sql.DB, Informer) { tb.Helper() ctx := context.Background() tDB := newTDB(ctx, tb, driver, dsn, opt) db, err := tDB.connectSQLDB(ctx, true) if err != nil { tb.Fatalf("cannot connect to database: %v", err) } tb.Cleanup(func() { _ = db.Close() }) return db, tDB } // connectSQLDB connects to the database with retries using database/sql. // testDatabase: if true, will be connected to the temporary test database. func (d *testDB) connectSQLDB(ctx context.Context, testDatabase bool) (*sql.DB, error) { var dbURL *dbURL if testDatabase { dbURL = d.url.replaceDatabase(d.databaseName) } else { dbURL = d.url.replaceDatabase(d.connectDatabase) } d.logger.Info(ctx, "connecting to test database", "url", dbURL.string(true)) var db *sql.DB err := d.retryConnect(ctx, dbURL.string(true), func() (err error) { db, err = sql.Open(d.driver, dbURL.string(false)) if err != nil { return err } if err = db.Ping(); err != nil { _ = db.Close() return err } return nil }) if err != nil { return nil, fmt.Errorf("connect url (%s): %w", dbURL.string(false), err) } return db, nil } func (d *testDB) createSQLDatabase(ctx context.Context) error { d.logger.Info(ctx, "creating new test sql database", "dsn", d.dsnNoPass, "database", d.databaseName) db, err := d.connectSQLDB(ctx, false) if err != nil { return err } defer db.Close() _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", d.databaseName)) if err != nil { return fmt.Errorf("create db: %w", err) } d.logger.Info(ctx, "new test sql database created", "dsn", d.dsnNoPass, "database", d.databaseName) return nil }
package testdock import ( "errors" "fmt" "sort" "strconv" "strings" ) // dbURL represents a parsed database connection string. // Supported connection string format: // [protocol://]user:password@[transport(]host:port[)][/database][?option1=a&option2=b] // // Required fields: user, password, host, port // Optional fields: protocol, transport, database and options. type dbURL struct { Protocol string Transport string User string Password string Host string Port int Database string Options map[string]string // option1=a&option2=b -> {"option1": "a", "option2": "b"} } // parseURL parses a connection string into a URL. func parseURL(connStr string) (*dbURL, error) { if connStr == "" { return nil, errors.New("connection string cannot be empty") } u := &dbURL{ Options: make(map[string]string), } const splitCount = 2 var rest string // Split protocol and the rest parts := strings.SplitN(connStr, "://", splitCount) if len(parts) == splitCount { // Parse protocol u.Protocol = parts[0] if u.Protocol == "" { return nil, errors.New("invalid connection string format: '://' exists, but no protocol") } rest = parts[1] } else { rest = connStr } // Find the last @ to properly handle @ in passwords atIndex := strings.LastIndex(rest, "@") if atIndex >= 0 { credentials := rest[:atIndex] rest = rest[atIndex+1:] // Parse credentials credParts := strings.SplitN(credentials, ":", splitCount) if len(credParts) != splitCount { return nil, errors.New("invalid connection string format: missing password") } u.User = credParts[0] if u.User == "" { return nil, errors.New("user is required") } u.Password = credParts[1] if u.Password == "" { return nil, errors.New("password is required") } } // Split query parameters if they exist hostAndQuery := strings.SplitN(rest, "?", splitCount) rest = hostAndQuery[0] // Parse query parameters if they exist if len(hostAndQuery) > 1 { queryStr := hostAndQuery[1] for _, param := range strings.Split(queryStr, "&") { kv := strings.SplitN(param, "=", splitCount) 
if len(kv) == splitCount { u.Options[kv[0]] = kv[1] } } } // Parse database if exists hostAndDB := strings.SplitN(rest, "/", splitCount) rest = hostAndDB[0] if len(hostAndDB) > 1 { u.Database = hostAndDB[1] } // Check if transport is specified if strings.Contains(rest, "(") && strings.HasSuffix(rest, ")") { transportParts := strings.SplitN(rest, "(", splitCount) if len(transportParts) != splitCount { return nil, errors.New("invalid connection string format: malformed transport") } u.Transport = transportParts[0] rest = strings.TrimSuffix(transportParts[1], ")") } if rest == "" { return nil, errors.New("host is required") } // Parse host and port hostAndPort := strings.SplitN(rest, ":", splitCount) if len(hostAndPort) != splitCount { return nil, errors.New("invalid connection string format: missing port") } u.Host = hostAndPort[0] if u.Host == "" { return nil, errors.New("host is required") } if hostAndPort[1] == "" { return nil, errors.New("port is required") } p, err := strconv.Atoi(hostAndPort[1]) if err != nil { return nil, fmt.Errorf("parse port: %w", err) } if p <= 0 { return nil, errors.New("port must be positive") } u.Port = p return u, nil } // string returns the connection string representation of the URL. 
func (u *dbURL) string(hidePassword bool) string { if u == nil { return "" } var b strings.Builder // Write protocol if u.Protocol != "" { b.WriteString(u.Protocol) b.WriteString("://") } if u.User != "" { // Write credentials b.WriteString(u.User) b.WriteString(":") if hidePassword { b.WriteString("*****") } else { b.WriteString(u.Password) } b.WriteString("@") } // Write transport, host and port if u.Transport != "" { b.WriteString(u.Transport) b.WriteString("(") } b.WriteString(u.Host) if u.Port != 0 { b.WriteString(":" + strconv.Itoa(u.Port)) } if u.Transport != "" { b.WriteString(")") } // Write database if exists if u.Database != "" { b.WriteString("/" + u.Database) } // Write options if exist if len(u.Options) > 0 { b.WriteString("?") // Sort keys for deterministic output keys := make([]string, 0, len(u.Options)) for k := range u.Options { keys = append(keys, k) } sort.Strings(keys) for i, k := range keys { if i > 0 { b.WriteString("&") } b.WriteString(k) b.WriteString("=") b.WriteString(u.Options[k]) } } return b.String() } // clone returns a copy of the URL. func (u *dbURL) clone() *dbURL { if u == nil { return nil } clone := &dbURL{ Protocol: u.Protocol, Transport: u.Transport, User: u.User, Password: u.Password, Host: u.Host, Port: u.Port, Database: u.Database, Options: make(map[string]string, len(u.Options)), } // Deep copy the options map for k, v := range u.Options { clone.Options[k] = v } return clone } // replaceDatabase replaces the database name in the URL. func (u *dbURL) replaceDatabase(newDBName string) *dbURL { clone := u.clone() clone.Database = newDBName return clone }