package testdock
import (
"context"
"errors"
"fmt"
"runtime/pprof"
"strings"
"time"
)
// goroutineProfileDebug is the debug level handed to the goroutine profile's WriteTo
// by captureGoroutineDump. With runtime/pprof, debug=2 prints goroutine stacks in the
// same form a Go program uses when it dies from an unrecovered panic.
const goroutineProfileDebug = 2
// pgxPoolCloseStats contains pgxpool statistics that explain a close timeout.
// When attached to closeTimeoutDiagnostics, formatCloseTimeoutDiagnostics prints
// each field on its own line.
type pgxPoolCloseStats struct {
// AcquiredConns is printed as "acquired_conns".
AcquiredConns int32
// IdleConns is printed as "idle_conns".
IdleConns int32
// TotalConns is printed as "total_conns".
TotalConns int32
// ConstructingConns is printed as "constructing_conns".
ConstructingConns int32
// MaxConns is printed as "max_conns".
MaxConns int32
}
// closeTimeoutDiagnostics contains fields printed when a returned resource close times out.
// It is the input of formatCloseTimeoutDiagnostics.
type closeTimeoutDiagnostics struct {
// TestName is printed on the "test:" line.
TestName string
// Resource is printed on the "resource:" line.
Resource string
// RedactedDSN is printed on the "dsn:" line; it is expected to carry no password.
RedactedDSN string
// DatabaseName is printed on the "database:" line.
DatabaseName string
// Timeout is printed on the "timeout:" line.
Timeout time.Duration
// PgxStats is optional; the pool statistic lines are omitted when it is nil.
PgxStats *pgxPoolCloseStats
// GoroutineDump is printed verbatim after the "goroutine dump:" line.
GoroutineDump string
}
// closeResourceWithTimeout runs closeResource in its own goroutine and waits at most
// timeout for it to return.
//
// It returns nil as soon as closeResource returns; the error produced by closeResource
// is discarded. When the timer fires first, the result is an error whose text is
// "close timed out after <timeout>", followed by a newline and the text returned by
// diagnostics when diagnostics is non-nil and returns a non-empty string.
// The goroutine is not interrupted on timeout; it ends when closeResource returns.
func closeResourceWithTimeout(timeout time.Duration, closeResource func() error, diagnostics func() string) error {
	finished := make(chan struct{})
	go func() {
		_ = closeResource()
		close(finished)
	}()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case <-finished:
		return nil
	case <-deadline.C:
		// fall through to build the timeout error below
	}

	text := fmt.Sprintf("close timed out after %s", timeout)
	if diagnostics == nil {
		return errors.New(text)
	}
	if extra := diagnostics(); extra != "" {
		text = text + "\n" + extra
	}
	return errors.New(text)
}
// disconnectWithTimeout calls disconnect with a context that expires after timeout.
//
// An error "close timed out after <timeout>" is returned when the context deadline has
// been exceeded by the time disconnect returns, or when the error returned by disconnect
// wraps context.DeadlineExceeded. Every other outcome, including other disconnect
// errors, yields nil.
func disconnectWithTimeout(timeout time.Duration, disconnect func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	disconnectErr := disconnect(ctx)

	deadlineHit := errors.Is(ctx.Err(), context.DeadlineExceeded)
	if !deadlineHit && !errors.Is(disconnectErr, context.DeadlineExceeded) {
		return nil
	}
	return fmt.Errorf("close timed out after %s", timeout)
}
// formatCloseTimeoutDiagnostics renders d as newline-separated "key: value" lines.
//
// The lines appear in a fixed order: test, resource, dsn, database, timeout, then the
// five pool statistics when d.PgxStats is set, then "goroutine dump:" followed by the
// dump text on the next line.
func formatCloseTimeoutDiagnostics(d closeTimeoutDiagnostics) string {
	var out strings.Builder

	out.WriteString("test: " + d.TestName)
	out.WriteString("\nresource: " + d.Resource)
	out.WriteString("\ndsn: " + d.RedactedDSN)
	out.WriteString("\ndatabase: " + d.DatabaseName)
	out.WriteString("\ntimeout: " + d.Timeout.String())

	if stats := d.PgxStats; stats != nil {
		fmt.Fprintf(&out, "\nacquired_conns: %d", stats.AcquiredConns)
		fmt.Fprintf(&out, "\nidle_conns: %d", stats.IdleConns)
		fmt.Fprintf(&out, "\ntotal_conns: %d", stats.TotalConns)
		fmt.Fprintf(&out, "\nconstructing_conns: %d", stats.ConstructingConns)
		fmt.Fprintf(&out, "\nmax_conns: %d", stats.MaxConns)
	}

	out.WriteString("\ngoroutine dump:\n")
	out.WriteString(d.GoroutineDump)

	return out.String()
}
// closeTimeoutDetails assembles the diagnostic text for a close timeout of resource.
//
// The goroutine dump is passed through redactKnownSecrets with the raw test DSN and the
// URL password, and only the password-free DSN is reported. stats may be nil.
func (d *testDB) closeTimeoutDetails(resource string, stats *pgxPoolCloseStats) string {
	safeDSN := d.redactedTestDSN()
	dump := redactKnownSecrets(captureGoroutineDump(), d.rawTestDSN(), d.urlPassword())

	var report closeTimeoutDiagnostics
	report.TestName = d.t.Name()
	report.Resource = resource
	report.RedactedDSN = safeDSN
	report.DatabaseName = d.databaseName
	report.Timeout = d.closeTimeout
	report.PgxStats = stats
	report.GoroutineDump = dump

	return formatCloseTimeoutDiagnostics(report)
}
// redactedTestDSN returns the DSN of the temporary database without the password.
// When no parsed URL is available it falls back to dsnNoPass.
func (d *testDB) redactedTestDSN() string {
	if d.url != nil {
		testURL := d.url.replaceDatabase(d.databaseName)
		return testURL.string(true)
	}
	return d.dsnNoPass
}
// rawTestDSN returns the DSN of the temporary database including the password.
// It exists for diagnostic redaction only. When no parsed URL is available it falls
// back to the configured dsn.
func (d *testDB) rawTestDSN() string {
	if d.url != nil {
		testURL := d.url.replaceDatabase(d.databaseName)
		return testURL.string(false)
	}
	return d.dsn
}
// urlPassword returns the password of the parsed URL, or an empty string when no URL
// has been parsed. It exists for diagnostic redaction only.
func (d *testDB) urlPassword() string {
	if d.url != nil {
		return d.url.Password
	}
	return ""
}
// captureGoroutineDump returns the text of the runtime "goroutine" profile written at
// debug level goroutineProfileDebug.
//
// It never fails: when the profile cannot be looked up or written, a short explanatory
// message is returned in place of the dump.
func captureGoroutineDump() string {
	goroutines := pprof.Lookup("goroutine")
	if goroutines == nil {
		return "goroutine profile is unavailable"
	}

	dump := new(strings.Builder)
	writeErr := goroutines.WriteTo(dump, goroutineProfileDebug)
	if writeErr != nil {
		return fmt.Sprintf("goroutine profile write failed: %v", writeErr)
	}
	return dump.String()
}
// redactKnownSecrets replaces every occurrence of each non-empty secret in text with
// "*****" and returns the result.
//
// Secrets are applied one after another in argument order, so a longer secret that
// contains a shorter one should be listed first. Empty secrets are skipped.
func redactKnownSecrets(text string, secrets ...string) string {
	const mask = "*****"

	for i := range secrets {
		if secrets[i] != "" {
			text = strings.ReplaceAll(text, secrets[i], mask)
		}
	}
	return text
}
package testdock
import (
"context"
"database/sql"
"fmt"
"sync"
"testing"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/n-r-w/ctxlog"
)
// Informer exposes connection details of the database prepared for a test.
type Informer interface {
// DSN returns the real database connection string.
DSN() string
// Host returns the host of the database server.
Host() string
// Port returns the port of the database server.
Port() int
// DatabaseName returns the database name for testing.
DatabaseName() string
}
const (
// DefaultRetryTimeout is the default retry timeout.
// newTDB stores it in retryTimeout, which retryConnect uses as the constant pause between attempts.
DefaultRetryTimeout = time.Second * 3
// DefaultTotalRetryDuration is the default total retry duration.
// newTDB stores it in totalRetryDuration, which retryConnect uses as the maximum elapsed time.
DefaultTotalRetryDuration = time.Second * 30
// defaultCloseTimeout is the default timeout for closing returned resources during cleanup.
defaultCloseTimeout = time.Second * 30
)
// PrepareCleanUp is a hook that runs before the temporary test database is dropped,
// for example to disconnect users. It receives an open connection and the name of the
// database about to be deleted; a returned error is logged and does not stop the cleanup.
type PrepareCleanUp func(db *sql.DB, databaseName string) error
// testDB represents a test database: the state of one prepared database together with
// the options it was created with. It implements Informer.
type testDB struct {
t testing.TB
logger ctxlog.ILogger // unified logging facade
databaseName string // name of the test database
url *dbURL // parsed database connection string
dsnNoPass string // database connection string without password
// options
driver string // database driver (pgx, pq, etc)
mode RunMode // run mode (docker or external)
dsn string // database connection string
retryTimeout time.Duration // pause between attempts when connecting to the database
totalRetryDuration time.Duration // total retry duration
closeTimeout time.Duration // timeout for closing returned resources during cleanup
migrationsDir string // migrations directory
migrationTargetVersion int64 // numeric migration file prefix where automatic migration must stop
hasMigrationTargetVersion bool // enables migration up to migrationTargetVersion instead of all migrations
unsetProxyEnv bool // unset HTTP_PROXY, HTTPS_PROXY etc. environment variables
migrateFactory MigrateFactory // factory that creates the migrator
prepareCleanUp []PrepareCleanUp // hooks that run before the temporary test database is deleted
connectDatabase string // database name for connecting to the database server
// NOTE(review): presumably marks that connectDatabase was set explicitly — confirm in prepareOptions.
connectDatabaseOverride bool
dockerPort int // docker port
dockerRepository string // docker hub repository
dockerImage string // docker hub image tag
dockerSocketEndpoint string // docker socket endpoint for connecting to the docker daemon
dockerEnv []string // environment variables for the docker container
}
// globalMu guards globalMuByDSN. globalMuByDSN holds one mutex per connection string;
// newTDB holds that mutex while it prepares a database, so setups that share a DSN run
// one at a time.
//
//nolint:gochecknoglobals // used to synchronize access to the same database connection string across tests.
var (
globalMu sync.Mutex
globalMuByDSN = make(map[string]*sync.Mutex)
)
// newTDB creates a new test database and applies migrations.
//
// The steps are: apply the options, take the per-DSN mutex, start the Docker resources
// in RunModeDocker, create the test database, run the migrations when a migrations
// directory is configured, and register a cleanup that closes the database.
//
// Any failure is reported through tb.Fatalf from a deferred function and nil is
// returned. When the failure happens after the test database may already exist
// (database creation or migrations), the database is closed right away, because the
// regular cleanup has not been registered at that point.
func newTDB(ctx context.Context, tb testing.TB, driver, dsn string, opt []Option) *testDB {
	tb.Helper()

	var (
		db = &testDB{
			t:                         tb,
			logger:                    ctxlog.Must(ctxlog.WithTesting(tb)),
			databaseName:              "",
			url:                       nil,
			dsnNoPass:                 "",
			driver:                    driver,
			mode:                      RunModeAuto,
			dsn:                       dsn,
			retryTimeout:              DefaultRetryTimeout,
			totalRetryDuration:        DefaultTotalRetryDuration,
			closeTimeout:              defaultCloseTimeout,
			migrationsDir:             "",
			migrationTargetVersion:    0,
			hasMigrationTargetVersion: false,
			unsetProxyEnv:             false,
			migrateFactory:            nil,
			prepareCleanUp:            nil,
			connectDatabase:           "",
			connectDatabaseOverride:   false,
			dockerPort:                0,
			dockerRepository:          "",
			dockerImage:               "",
			dockerSocketEndpoint:      "",
			dockerEnv:                 nil,
		}
		errResult error
	)

	// Registered first so it runs last, after the per-DSN mutex has been released.
	defer func() {
		if errResult != nil {
			tb.Fatalf("cannot create test database: %v", errResult)
		}
	}()

	if errResult = db.prepareOptions(driver, opt); errResult != nil {
		return nil
	}

	// Serialize setup for tests that share the same connection string.
	globalMu.Lock()
	mu, ok := globalMuByDSN[db.dsn]
	if !ok {
		mu = &sync.Mutex{}
		globalMuByDSN[db.dsn] = mu
	}
	globalMu.Unlock()

	mu.Lock()
	defer mu.Unlock()

	if db.mode == RunModeDocker {
		db.logger.Info(ctx, "using docker test database", "dsn", db.dsnNoPass)
		if errResult = db.createDockerResources(ctx); errResult != nil {
			return nil
		}
	} else {
		db.logger.Info(ctx, "using real test database", "dsn", db.dsnNoPass)
	}

	// closeAfterFailure removes what was created so far; a close error is only logged
	// because the original failure is the one reported to the test.
	closeAfterFailure := func() {
		if err := db.close(ctx); err != nil {
			db.logger.Info(ctx, "failed to close test database", "dsn", db.dsnNoPass, "error", err)
		}
	}

	if errResult = db.createTestDatabase(ctx); errResult != nil {
		closeAfterFailure()
		return nil
	}

	if db.migrationsDir != "" {
		if errResult = db.migrationsUp(ctx); errResult != nil {
			// The database exists but no cleanup is registered yet; drop it here so a
			// failed migration does not leave a temporary database behind.
			closeAfterFailure()
			return nil
		}
	}

	tb.Cleanup(func() {
		cleanupCtx := context.Background()
		if closeErr := db.close(cleanupCtx); closeErr != nil {
			db.logger.Info(cleanupCtx, "failed to close test database", "dsn", db.dsnNoPass, "error", closeErr)
		} else {
			db.logger.Info(cleanupCtx, "test database closed", "dsn", db.dsnNoPass)
		}
	})

	return db
}
// migrationsUp creates a migrator for the test database through migrateFactory and runs it.
//
// All migrations are applied unless hasMigrationTargetVersion is set, in which case
// migration stops at migrationTargetVersion. Start and end are logged.
func (d *testDB) migrationsUp(ctx context.Context) error {
	d.logger.Info(ctx, "migrations up start", "dsn", d.dsnNoPass)
	defer d.logger.Info(ctx, "migrations up end", "dsn", d.dsnNoPass)

	targetDSN := d.url.replaceDatabase(d.databaseName).string(false)

	migrator, factoryErr := d.migrateFactory(d.t, targetDSN, d.migrationsDir, d.logger)
	if factoryErr != nil {
		return fmt.Errorf("new migrator: %w", factoryErr)
	}

	if !d.hasMigrationTargetVersion {
		if upErr := migrator.Up(ctx); upErr != nil {
			return fmt.Errorf("up migrations: %w", upErr)
		}
		return nil
	}

	if upToErr := migrateUpToVersion(ctx, migrator, d.migrationTargetVersion); upToErr != nil {
		return fmt.Errorf("up migrations to version: %w", upToErr)
	}
	return nil
}
// close removes the temporary test database.
//
// Nothing is done in RunModeDocker or for the mongo driver. Otherwise a connection is
// opened with the configured URL, every prepareCleanUp hook is run (hook errors are
// only logged), and the test database is dropped.
//
// The returned errors never contain the password: the open error reports the
// password-free form of the URL.
func (d *testDB) close(ctx context.Context) error {
	if d.mode == RunModeDocker || d.driver == mongoDriverName {
		return nil
	}

	// remove the database created before applying the migrations
	d.logger.Info(ctx, "deleting test database", "dsn", d.dsnNoPass, "database", d.databaseName)

	db, err := sql.Open(d.driver, d.url.string(false))
	if err != nil {
		// string(true) omits the password, so credentials do not end up in test logs.
		return fmt.Errorf("sql open url (%s): %w", d.url.string(true), err)
	}
	defer func() {
		_ = db.Close() // best effort: the connection only served the cleanup
	}()

	for _, prepare := range d.prepareCleanUp {
		if prepareErr := prepare(db, d.databaseName); prepareErr != nil {
			d.logger.Info(ctx, "failed to prepare clean up", "dsn", d.dsnNoPass, "error", prepareErr)
		}
	}

	if _, err = db.ExecContext(ctx, fmt.Sprintf("DROP DATABASE %s", d.databaseName)); err != nil {
		return fmt.Errorf("drop db: %w", err)
	}

	d.logger.Info(ctx, "test database deleted", "dsn", d.dsnNoPass, "database", d.databaseName)

	return nil
}
// createTestDatabase creates the SQL test database through createSQLDatabase.
// For the mongo driver it does nothing and returns nil.
func (d *testDB) createTestDatabase(ctx context.Context) error {
	if d.driver != mongoDriverName {
		return d.createSQLDatabase(ctx)
	}
	return nil
}
// retryConnect runs op until it succeeds, pausing retryTimeout between attempts and
// giving up once totalRetryDuration has elapsed or ctx is done.
//
// info is only used to label the log line written for every failed attempt. The
// returned error wraps the retry failure and states how many attempts failed.
func (d *testDB) retryConnect(ctx context.Context, info string, op func() error) error {
	var attempt int

	operation := func() (struct{}, error) {
		if err := op(); err != nil {
			d.logger.Info(ctx, "retrying operation", "info", info, "attempt", attempt, "error", err)
			attempt++
			return struct{}{}, err
		}
		return struct{}{}, nil
	}

	// Retry with the caller's context so that cancellation stops the loop instead of
	// letting it run for the whole retry duration.
	_, err := backoff.Retry(
		ctx, operation,
		backoff.WithBackOff(backoff.NewConstantBackOff(d.retryTimeout)),
		backoff.WithMaxElapsedTime(d.totalRetryDuration),
	)
	if err != nil {
		return fmt.Errorf("retry failed after %d attempts: %w", attempt, err)
	}

	return nil
}
// DSN returns the real connection string of the test database: the configured URL with
// its database replaced by the test database name, password included.
func (d *testDB) DSN() string {
	testURL := d.url.replaceDatabase(d.databaseName)
	return testURL.string(false)
}
// Host returns the host stored in the parsed database URL.
func (d *testDB) Host() string {
	host := d.url.Host
	return host
}
// Port returns the port stored in the parsed database URL.
func (d *testDB) Port() int {
	port := d.url.Port
	return port
}
// DatabaseName returns the name of the database created for the test.
func (d *testDB) DatabaseName() string {
	name := d.databaseName
	return name
}
package testdock
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
)
// Docker resources are created only once and shared by all tests.
// globalDockerMu guards globalDockerResources (one entry per connection string) and
// the assignment of globalDockerPool.
//
//nolint:gochecknoglobals // used to synchronize access to the same database connection string across tests.
var (
globalDockerMu sync.Mutex
globalDockerResources = make(map[string]*dockerResourceInfo)
globalDockerPool *dockertest.Pool
)
// dockerResourceInfo is the bookkeeping entry for one shared Docker resource.
type dockerResourceInfo struct {
// resource is the value returned by the pool's RunWithOptions.
resource *dockertest.Resource
// port is the host port the resource was started on; tests that reuse the resource copy it into their URL.
port int
// count is the number of tests currently using the resource; the resource is purged when it drops to zero.
count int
// mu guards the fields above.
mu sync.Mutex
}
// createDockerResources makes sure the shared Docker pool exists and that a resource
// for d.dsn is running, creating either when necessary, and registers a cleanup that
// releases this test's reference to the resource.
func (d *testDB) createDockerResources(ctx context.Context) error {
globalDockerMu.Lock()
// Look up the entry for this DSN. A fresh entry is not stored in the map until the
// resource is usable (see below).
info, ok := globalDockerResources[d.dsn]
if !ok {
info = &dockerResourceInfo{}
}
logDsn := d.dsnNoPass
if globalDockerPool == nil {
// The pool is created while globalDockerMu is held.
if err := d.createDockerPoolLocked(ctx); err != nil {
globalDockerMu.Unlock()
return err
}
// Runs when this function returns: resets the pool again if no resource ended up
// registered, for example because creating the resource failed.
defer d.clearDockerPoolWhenUnused(ctx)
}
globalDockerMu.Unlock()
// info.mu is taken only after globalDockerMu has been released.
info.mu.Lock()
defer info.mu.Unlock()
if info.count > 0 {
// Another test already started the resource: reuse its port.
d.url.Port = info.port
d.logger.Info(ctx, "use existing resources", "component", "docker", "dsn", logDsn)
} else if err := d.createDockerResource(ctx, info, logDsn); err != nil {
return err
}
// Publish the entry. globalDockerMu is acquired while info.mu is held, the same
// order used by the cleanup in registerDockerResourceCleanup.
globalDockerMu.Lock()
globalDockerResources[d.dsn] = info
globalDockerMu.Unlock()
info.count++
d.registerDockerResourceCleanup(info, logDsn)
return nil
}
// createDockerPoolLocked creates the global Docker pool while globalDockerMu is held.
//
// The pool is built, the proxy environment variables are removed when unsetProxyEnv is
// set, and the Docker daemon is pinged. globalDockerPool is assigned only after the
// ping succeeded, so a failed attempt leaves it nil and the next caller tries again
// instead of reusing a pool that never reached the daemon.
func (d *testDB) createDockerPoolLocked(ctx context.Context) error {
	pool, err := dockertest.NewPool(d.dockerSocketEndpoint)
	if err != nil {
		return fmt.Errorf("dockertest NewPool: %w", err)
	}

	if d.unsetProxyEnv {
		d.unsetDockerProxyEnv(ctx)
	}

	if err = pool.Client.Ping(); err != nil {
		return fmt.Errorf("dockertest ping: %w", err)
	}

	globalDockerPool = pool
	d.logger.Info(ctx, "pool created", "component", "docker")

	return nil
}
// unsetDockerProxyEnv removes the HTTP, HTTPS and ALL proxy variables (upper- and
// lower-case spellings) from the process environment, logging each variable it unsets.
// Variables that are missing or empty are left alone.
func (d *testDB) unsetDockerProxyEnv(ctx context.Context) {
	for _, name := range [...]string{
		"HTTP_PROXY",
		"HTTPS_PROXY",
		"ALL_PROXY",
		"http_proxy",
		"https_proxy",
		"all_proxy",
	} {
		if value := os.Getenv(name); value != "" {
			d.logger.Info(ctx, "unset proxy env", "component", "docker", "env", name)
			_ = os.Unsetenv(name)
		}
	}
}
// clearDockerPoolWhenUnused resets the global Docker pool to nil when no Docker
// resource is registered. It takes globalDockerMu itself, so the caller must not hold it.
func (d *testDB) clearDockerPoolWhenUnused(ctx context.Context) {
	globalDockerMu.Lock()
	defer globalDockerMu.Unlock()

	if len(globalDockerResources) == 0 {
		globalDockerPool = nil
		d.logger.Info(ctx, "pool purged", "component", "docker")
	}
}
// createDockerResource starts the Docker resource for this test database and stores it,
// together with the host port finally used, in info.
//
// A bind error (see isDockerBindError) moves on to the next host port immediately and
// does not count as an attempt. Any other error counts as an attempt and is retried
// after sleepTime, up to maxAttempts; the last error is returned wrapped.
func (d *testDB) createDockerResource(ctx context.Context, info *dockerResourceInfo, logDsn string) error {
	const (
		maxAttempts = 10
		sleepTime   = 5 * time.Second
	)

	containerPort := docker.Port(fmt.Sprintf("%d/tcp", d.dockerPort))
	configureHost := func(config *docker.HostConfig) {
		config.AutoRemove = true
		config.RestartPolicy = docker.RestartPolicy{Name: "no", MaximumRetryCount: 0}
	}

	var (
		failures int
		runErr   error
	)

attempts:
	for {
		// Built on every pass because the host port may have changed.
		runOptions := &dockertest.RunOptions{ //nolint:exhaustruct // optional SDK fields use zero values.
			Repository: d.dockerRepository,
			Tag:        d.dockerImage,
			Env:        d.dockerEnv,
			PortBindings: map[docker.Port][]docker.PortBinding{
				containerPort: {{
					HostIP:   d.url.Host,
					HostPort: strconv.Itoa(d.url.Port),
				}},
			},
		}

		info.resource, runErr = globalDockerPool.RunWithOptions(runOptions, configureHost)

		switch {
		case runErr == nil:
			break attempts
		case isDockerBindError(runErr):
			d.logger.Info(ctx, "port is already allocated, trying next port", "dsn", logDsn, "next_port", d.url.Port+1)
			d.url.Port++
		default:
			failures++
			if failures >= maxAttempts {
				break attempts
			}
			d.logger.Info(ctx, "RunWithOptions failed", "component", "docker", "dsn", logDsn, "attempt", failures, "error", runErr)
			time.Sleep(sleepTime)
		}
	}

	if runErr != nil {
		return fmt.Errorf("dockertest RunWithOptions: %w", runErr)
	}

	info.port = d.url.Port
	d.logger.Info(ctx, "resources created", "component", "docker", "dsn", logDsn)

	return nil
}
// isDockerBindError reports whether err looks like Docker failing to bind a host port
// because the port is already taken.
//
// The decision is based on the error text containing one of the known messages.
// A nil error is not a bind error.
func isDockerBindError(err error) bool {
	if err == nil {
		return false
	}

	// Render the message once instead of once per candidate.
	message := err.Error()

	for _, marker := range [...]string{
		"address already in use",
		"port is already allocated",
		"failed to bind host port",
	} {
		if strings.Contains(message, marker) {
			return true
		}
	}

	return false
}
// registerDockerResourceCleanup registers a test cleanup that releases this test's
// reference to the shared Docker resource.
//
// The cleanup decrements info.count under info.mu. When the count reaches zero it
// takes globalDockerMu as well, removes the entry for d.dsn from globalDockerResources
// and purges the resource.
func (d *testDB) registerDockerResourceCleanup(info *dockerResourceInfo, logDsn string) {
	release := func() {
		info.mu.Lock()
		defer info.mu.Unlock()

		if info.count--; info.count != 0 {
			// other tests still use the resource
			return
		}

		globalDockerMu.Lock()
		defer globalDockerMu.Unlock()

		delete(globalDockerResources, d.dsn)
		d.purgeDockerResource(context.Background(), info, logDsn)
	}

	d.t.Cleanup(release)
}
// purgeDockerResource asks the global pool to purge info.resource, retrying every
// retryTimeout for at most maxTime.
//
// Nothing is returned: each failed attempt, the final failure and the success are
// only logged.
func (d *testDB) purgeDockerResource(ctx context.Context, info *dockerResourceInfo, logDsn string) {
	const (
		maxTime      = 10 * time.Second
		retryTimeout = 1 * time.Second
	)

	failures := 0

	tryPurge := func() (struct{}, error) {
		purgeErr := globalDockerPool.Purge(info.resource)
		if purgeErr == nil {
			return struct{}{}, nil
		}

		failures++
		d.logger.Info(ctx, "purge attempt failed",
			"component", "docker", "dsn", logDsn, "attempt", failures, "error", purgeErr)

		return struct{}{}, purgeErr
	}

	_, retryErr := backoff.Retry(ctx, tryPurge,
		backoff.WithBackOff(backoff.NewConstantBackOff(retryTimeout)),
		backoff.WithMaxElapsedTime(maxTime))

	if retryErr == nil {
		d.logger.Info(ctx, "resources purged successfully", "component", "docker", "dsn", logDsn, "attempts", failures)
		return
	}

	d.logger.Info(ctx, "purge failed after retries",
		"component", "docker", "dsn", logDsn, "attempt", failures, "error", retryErr)
}
package testdock
import (
"context"
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"testing"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/mongodb" // require for mongodb
_ "github.com/golang-migrate/migrate/v4/database/postgres" // require for gomigrate
_ "github.com/golang-migrate/migrate/v4/source/file" // require for gomigrate
"github.com/n-r-w/ctxlog"
"github.com/pressly/goose/v3"
)
// MigrateFactory creates a new migrator for the database at dsn using the migration
// files in migrationsDir. t and logger are handed to the migrator for reporting.
type MigrateFactory func(t testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error)
// Migrator applies migrations to a database.
type Migrator interface {
// Up applies all pending migrations.
Up(ctx context.Context) error
}
// VersionedMigrator is the contract for migration factories used with WithMigrationsToVersion.
// The version is the numeric file prefix before "_", including timestamp prefixes.
type VersionedMigrator interface {
Migrator
// UpTo applies pending migrations up to the given version.
UpTo(ctx context.Context, version int64) error
}
// ApplyMigrations applies all pending migrations to an existing test database.
// The helper fails tb on invalid input, migrator creation errors, or migration errors.
func ApplyMigrations(tb testing.TB, dsn, migrationsDir string, migrateFactory MigrateFactory) {
	tb.Helper()

	ctx := context.Background()

	upErr := newMigratorForTest(tb, dsn, migrationsDir, migrateFactory).Up(ctx)
	if upErr != nil {
		tb.Fatalf("cannot apply migrations: %v", upErr)
	}
}
// ApplyMigrationsToVersion applies pending migrations up to and including the target version.
// The version is the numeric file prefix before "_", including timestamp prefixes.
// Custom factories must return a migrator that implements VersionedMigrator.
// The helper fails tb on an invalid version, invalid input, migrator creation errors,
// or migration errors.
func ApplyMigrationsToVersion(tb testing.TB, dsn, migrationsDir string, migrateFactory MigrateFactory, version int64) {
	tb.Helper()

	versionErr := validateMigrationVersion(version)
	if versionErr != nil {
		tb.Fatal(versionErr)
	}

	ctx := context.Background()
	migrator := newMigratorForTest(tb, dsn, migrationsDir, migrateFactory)

	migrateErr := migrateUpToVersion(ctx, migrator, version)
	if migrateErr != nil {
		tb.Fatalf("cannot apply migrations to version: %v", migrateErr)
	}
}
// newMigratorForTest validates helper input and creates a migrator for a test database.
//
// tb is failed when dsn or migrationsDir is empty, when migrateFactory is nil, or when
// the factory returns an error. The migrator logs through a logger bound to tb.
func newMigratorForTest(tb testing.TB, dsn, migrationsDir string, migrateFactory MigrateFactory) Migrator {
	tb.Helper()

	// Checked in this order; each failing precondition is reported through tb.Fatal.
	preconditions := []struct {
		violated bool
		message  string
	}{
		{violated: dsn == "", message: "dsn is empty"},
		{violated: migrationsDir == "", message: "migrationsDir is empty"},
		{violated: migrateFactory == nil, message: "migrateFactory is nil"},
	}
	for _, check := range preconditions {
		if check.violated {
			tb.Fatal(check.message)
		}
	}

	testLogger := ctxlog.Must(ctxlog.WithTesting(tb))

	created, factoryErr := migrateFactory(tb, dsn, migrationsDir, testLogger)
	if factoryErr != nil {
		tb.Fatalf("cannot create migrator: %v", factoryErr)
	}

	return created
}
// migrateUpToVersion applies migrations up to the numeric file prefix requested by the test.
//
// It returns an error when version is not positive or when migrator does not implement
// VersionedMigrator; otherwise it returns the result of the migrator's UpTo.
func migrateUpToVersion(ctx context.Context, migrator Migrator, version int64) error {
	versionErr := validateMigrationVersion(version)
	if versionErr != nil {
		return versionErr
	}

	if versioned, ok := migrator.(VersionedMigrator); ok {
		return versioned.UpTo(ctx, version)
	}

	return errors.New("WithMigrationsToVersion and ApplyMigrationsToVersion require " +
		"migrator to implement VersionedMigrator")
}
// validateMigrationVersion rejects values that cannot match a migration file prefix:
// it returns an error for zero and negative versions and nil for positive ones.
func validateMigrationVersion(version int64) error {
	if version > 0 {
		return nil
	}
	return errors.New("migration version must be greater than 0")
}
// Predefined migrator factories built with GooseMigrateFactory.
//
//nolint:gochecknoglobals // predefined migrator factories.
var (
// GooseMigrateFactoryPGX is a migrator for https://github.com/pressly/goose with pgx driver.
GooseMigrateFactoryPGX = GooseMigrateFactory(goose.DialectPostgres, "pgx")
// GooseMigrateFactoryPQ is a migrator for https://github.com/pressly/goose with pq driver
// (registered under the driver name "postgres").
GooseMigrateFactoryPQ = GooseMigrateFactory(goose.DialectPostgres, "postgres")
// GooseMigrateFactoryMySQL is a migrator for https://github.com/pressly/goose with mysql driver.
GooseMigrateFactoryMySQL = GooseMigrateFactory(goose.DialectMySQL, "mysql")
)
// GooseMigrateFactory creates a new migrator factory for https://github.com/pressly/goose.
//
// dialect selects the goose SQL dialect and driver is the database/sql driver name used
// to open the connection. When the migrator cannot be created the factory returns a
// nil Migrator together with the error, rather than an interface wrapping a nil pointer.
func GooseMigrateFactory(dialect goose.Dialect, driver string) MigrateFactory {
	return func(t testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) {
		migrator, err := newGooseMigrator(t, dialect, driver, dsn, migrationsDir, logger)
		if err != nil {
			return nil, err
		}
		return migrator, nil
	}
}
// gooseMigrator is a migrator for goose. It implements Migrator and VersionedMigrator.
type gooseMigrator struct {
// p is the goose provider; Up and UpTo close it when they return.
p *goose.Provider
}
// newGooseMigrator creates a new migrator for goose.
//
// It opens a database/sql connection with driver and dsn and builds a verbose goose
// provider that reads migrations from migrationsDir and logs through NewGooseLogger.
// The connection is closed again when the provider cannot be created.
//
// The dsn is not included in returned errors because it may contain credentials; the
// open error names the driver instead.
func newGooseMigrator(
	t testing.TB,
	dialect goose.Dialect,
	driver, dsn, migrationsDir string,
	logger ctxlog.ILogger,
) (*gooseMigrator, error) {
	conn, err := sql.Open(driver, dsn)
	if err != nil {
		return nil, fmt.Errorf("sql open (driver %s): %w", driver, err)
	}

	p, err := goose.NewProvider(dialect, conn, os.DirFS(migrationsDir),
		goose.WithLogger(NewGooseLogger(t, logger)),
		goose.WithVerbose(true),
	)
	if err != nil {
		_ = conn.Close() // best effort: the provider error is the one reported
		return nil, fmt.Errorf("new goose provider: %w", err)
	}

	return &gooseMigrator{
		p: p,
	}, nil
}
// Up applies all pending goose migrations and closes the provider afterwards, so the
// migrator is meant to be used once.
func (m *gooseMigrator) Up(ctx context.Context) error {
	defer func() {
		_ = m.p.Close() // Close only releases resources; keep migration result.
	}()

	_, upErr := m.p.Up(ctx)

	return upErr
}
// UpTo applies goose migrations up to and including the target numeric file prefix and
// closes the provider afterwards, so the migrator is meant to be used once.
func (m *gooseMigrator) UpTo(ctx context.Context, version int64) error {
	defer func() {
		_ = m.p.Close() // Close only releases resources; keep migration result.
	}()

	_, upToErr := m.p.UpTo(ctx, version)

	return upToErr
}
// GolangMigrateFactory creates a new migrator for https://github.com/golang-migrate/migrate.
//
// The testing.TB argument is unused. When the migrator cannot be created a nil Migrator
// is returned together with the error, rather than an interface wrapping a nil pointer.
func GolangMigrateFactory(_ testing.TB, dsn, migrationsDir string, logger ctxlog.ILogger) (Migrator, error) {
	migrator, err := newGolangMigrateMigrator(dsn, migrationsDir, logger)
	if err != nil {
		return nil, err
	}
	return migrator, nil
}
// golangMigrateMigrator is a migrator for https://github.com/golang-migrate/migrate.
// It implements Migrator and VersionedMigrator.
type golangMigrateMigrator struct {
// m is the underlying golang-migrate instance.
m *migrate.Migrate
}
// newGolangMigrateMigrator creates a new migrator for https://github.com/golang-migrate/migrate.
//
// A relative migrationsDir is made absolute first; the directory is then handed to
// golang-migrate as a "file://" source. The migrator logs through NewGolangMigrateLogger.
func newGolangMigrateMigrator(dsn, migrationsDir string, logger ctxlog.ILogger) (*golangMigrateMigrator, error) {
	sourceDir := migrationsDir
	if !filepath.IsAbs(sourceDir) {
		absDir, absErr := filepath.Abs(sourceDir)
		if absErr != nil {
			return nil, fmt.Errorf("get absolute path: %w", absErr)
		}
		sourceDir = absDir
	}

	instance, newErr := migrate.New("file://"+sourceDir, dsn)
	if newErr != nil {
		return nil, fmt.Errorf("new migrate: %w", newErr)
	}
	instance.Log = NewGolangMigrateLogger(logger)

	return &golangMigrateMigrator{m: instance}, nil
}
// Up applies all pending golang-migrate migrations.
//
// golang-migrate reports migrate.ErrNoChange when there is nothing to apply; that is
// treated as success, matching the goose migrator. The migrate instance is closed
// before returning so its source and database connections do not outlive the
// migration; the migrator is therefore meant to be used once.
func (m *golangMigrateMigrator) Up(_ context.Context) error {
	defer m.m.Close() //nolint:errcheck // Close only releases resources; keep migration result.

	if err := m.m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return err
	}

	return nil
}
// UpTo applies golang-migrate migrations up to the target numeric file prefix.
//
// An error is returned when version is not positive or does not fit into uint.
// golang-migrate reports migrate.ErrNoChange when the database is already at the
// target version; that is treated as success. The migrate instance is closed before
// returning so its source and database connections do not outlive the migration; the
// migrator is therefore meant to be used once.
func (m *golangMigrateMigrator) UpTo(_ context.Context, version int64) error {
	defer m.m.Close() //nolint:errcheck // Close only releases resources; keep migration result.

	migrationVersion, err := migrationVersionToUint(version)
	if err != nil {
		return err
	}

	if err = m.m.Migrate(migrationVersion); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return err
	}

	return nil
}
// migrationVersionToUint converts the public int64 version to the uint that
// golang-migrate expects.
//
// It returns an error when version is not positive, or when it exceeds the 32-bit
// unsigned range on a platform whose int size is 32 bits.
func migrationVersionToUint(version int64) (uint, error) {
	validationErr := validateMigrationVersion(version)
	if validationErr != nil {
		return 0, validationErr
	}

	const maxUint32 = int64(1<<32 - 1)

	overflows := strconv.IntSize == 32 && version > maxUint32
	if overflows {
		return 0, fmt.Errorf("migration version %d overflows uint", version)
	}

	//nolint:gosec // version is positive, and overflow is checked above on 32-bit platforms.
	return uint(version), nil
}
// GooseLogger is a logger for goose.
type GooseLogger struct {
// t receives fatal messages (see Fatalf).
t testing.TB
// l receives informational messages (see Printf).
l ctxlog.ILogger
}
// NewGooseLogger creates a new goose logger that reports fatal messages to t and
// informational messages to l.
func NewGooseLogger(t testing.TB, l ctxlog.ILogger) *GooseLogger {
	gooseLog := new(GooseLogger)
	gooseLog.t = t
	gooseLog.l = l

	return gooseLog
}
// Fatalf reports a fatal error by forwarding the formatted message to the test's Fatalf.
func (l GooseLogger) Fatalf(format string, v ...any) {
	tb := l.t
	tb.Fatalf(format, v...)
}
// Printf formats the message and writes it to the wrapped logger at info level.
func (l GooseLogger) Printf(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	l.l.Info(context.Background(), message)
}
// GolangMigrateLogger is a logger for golang-migrate.
type GolangMigrateLogger struct {
// l receives every message written through Printf.
l ctxlog.ILogger
}
// NewGolangMigrateLogger creates a new golang-migrate logger that writes to l.
func NewGolangMigrateLogger(l ctxlog.ILogger) *GolangMigrateLogger {
	migrateLog := new(GolangMigrateLogger)
	migrateLog.l = l

	return migrateLog
}
// Printf formats the message and writes it to the wrapped logger at info level.
func (g *GolangMigrateLogger) Printf(format string, v ...any) {
	message := fmt.Sprintf(format, v...)
	g.l.Info(context.Background(), message)
}
// Verbose always returns true, which asks golang-migrate for verbose output.
func (*GolangMigrateLogger) Verbose() bool {
	const verbose = true

	return verbose
}
package testdock
import (
"context"
"fmt"
"testing"
mongov1 "go.mongodb.org/mongo-driver/mongo"
optionsv1 "go.mongodb.org/mongo-driver/mongo/options"
)
// GetMongoDatabase initializes a test MongoDB database, applies migrations, and returns a database connection.
//
// The Docker defaults (repository "mongo", image "latest" and, when the DSN has a user,
// the root credentials) are placed before opt, so caller options are applied after them.
// tb is failed when the DSN cannot be parsed or the connection cannot be established.
//
// The registered cleanup drops the test database when the database does not run in
// Docker, and always disconnects the returned client, even when the separate cleanup
// connection cannot be established. A disconnect timeout is reported through tb.Errorf
// together with diagnostics.
//
//nolint:dupl // similar code, but with different drivers and options.
func GetMongoDatabase(tb testing.TB, dsn string, opt ...Option) (*mongov1.Database, Informer) {
	tb.Helper()

	ctx := context.Background()

	url, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	optPrepared := make([]Option, 0, len(opt))
	optPrepared = append(optPrepared,
		WithDockerRepository("mongo"),
		WithDockerImage("latest"),
	)
	if url.User != "" {
		optPrepared = append(optPrepared,
			WithDockerEnv([]string{
				fmt.Sprintf("MONGO_INITDB_ROOT_USERNAME=%s", url.User),
				fmt.Sprintf("MONGO_INITDB_ROOT_PASSWORD=%s", url.Password),
			}))
	}
	optPrepared = append(optPrepared, opt...)

	tDB := newTDB(ctx, tb, mongoDriverName, dsn, optPrepared)

	client, err := tDB.connectMongoDB(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to mongo: %v", err)
	}

	tb.Cleanup(func() {
		if tDB.mode != RunModeDocker {
			// The drop runs in its own function so that an early return on a failed
			// cleanup connection does not skip disconnecting the returned client below.
			func() {
				// protect against closing connection during tests
				clientClean, connectErr := tDB.connectMongoDB(ctx)
				if connectErr != nil {
					tb.Logf("cannot connect to mongo for cleanup: %v", connectErr)
					return
				}
				defer func() {
					if disconnectErr := clientClean.Disconnect(ctx); disconnectErr != nil {
						tb.Logf("cannot disconnect mongo cleanup client: %v", disconnectErr)
					}
				}()

				dbClean := clientClean.Database(tDB.databaseName)
				if dropErr := dbClean.Drop(ctx); dropErr != nil {
					tb.Logf("failed to drop database %s: %v", tDB.databaseName, dropErr)
				}
			}()
		}

		if closeErr := disconnectWithTimeout(tDB.closeTimeout, client.Disconnect); closeErr != nil {
			tb.Errorf("%v\n%s", closeErr, tDB.closeTimeoutDetails("mongo client", nil))
		}
	})

	return client.Database(tDB.databaseName), tDB
}
// connectMongoDB connects to the test database on MongoDB and pings it, retrying
// through retryConnect until it succeeds or the retry budget is exhausted.
//
// A client whose ping failed is disconnected before the next attempt so failed
// attempts do not pile up open clients. Log lines and the returned error use the
// password-free form of the URL.
func (d *testDB) connectMongoDB(ctx context.Context) (*mongov1.Client, error) {
	var client *mongov1.Client

	url := d.url.replaceDatabase(d.databaseName)

	err := d.retryConnect(ctx, url.string(true), func() error {
		candidate, connectErr := mongov1.Connect(ctx, optionsv1.Client().ApplyURI(url.string(false)))
		if connectErr != nil {
			return fmt.Errorf("mongo connect: %w", connectErr)
		}

		if pingErr := candidate.Ping(ctx, nil); pingErr != nil {
			_ = candidate.Disconnect(ctx) // best effort: the ping error is the one reported
			return fmt.Errorf("mongo ping: %w", pingErr)
		}

		client = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect mongo url (%s): %w", url.string(true), err)
	}

	return client, nil
}
package testdock
import (
"context"
"fmt"
"testing"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
)
// mongoDriverName is the driver name that marks a testDB as MongoDB, so the SQL-only
// steps (creating and dropping the database through database/sql) are skipped for it.
const mongoDriverName = "mongodb"
// GetMongoDatabaseV2 initializes a test MongoDB database, applies migrations, and returns a database connection.
// It uses version 2 of the MongoDB Go driver.
//
// The Docker defaults (repository "mongo", image "latest" and, when the DSN has a user,
// the root credentials) are placed before opt, so caller options are applied after them.
// tb is failed when the DSN cannot be parsed or the connection cannot be established.
//
// The registered cleanup drops the test database when the database does not run in
// Docker, and always disconnects the returned client, even when the separate cleanup
// connection cannot be established. A disconnect timeout is reported through tb.Errorf
// together with diagnostics.
//
//nolint:dupl // similar code, but with different drivers and options.
func GetMongoDatabaseV2(tb testing.TB, dsn string, opt ...Option) (*mongo.Database, Informer) {
	tb.Helper()

	ctx := context.Background()

	url, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	optPrepared := make([]Option, 0, len(opt))
	optPrepared = append(optPrepared,
		WithDockerRepository("mongo"),
		WithDockerImage("latest"),
	)
	if url.User != "" {
		optPrepared = append(optPrepared,
			WithDockerEnv([]string{
				fmt.Sprintf("MONGO_INITDB_ROOT_USERNAME=%s", url.User),
				fmt.Sprintf("MONGO_INITDB_ROOT_PASSWORD=%s", url.Password),
			}))
	}
	optPrepared = append(optPrepared, opt...)

	tDB := newTDB(ctx, tb, mongoDriverName, dsn, optPrepared)

	client, err := tDB.connectMongoDBv2(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to mongo: %v", err)
	}

	tb.Cleanup(func() {
		if tDB.mode != RunModeDocker {
			// The drop runs in its own function so that an early return on a failed
			// cleanup connection does not skip disconnecting the returned client below.
			func() {
				// protect against closing connection during tests
				clientClean, connectErr := tDB.connectMongoDBv2(ctx)
				if connectErr != nil {
					tb.Logf("cannot connect to mongo for cleanup: %v", connectErr)
					return
				}
				defer func() {
					if disconnectErr := clientClean.Disconnect(ctx); disconnectErr != nil {
						tb.Logf("cannot disconnect mongo cleanup client: %v", disconnectErr)
					}
				}()

				dbClean := clientClean.Database(tDB.databaseName)
				if dropErr := dbClean.Drop(ctx); dropErr != nil {
					tb.Logf("failed to drop database %s: %v", tDB.databaseName, dropErr)
				}
			}()
		}

		if closeErr := disconnectWithTimeout(tDB.closeTimeout, client.Disconnect); closeErr != nil {
			tb.Errorf("%v\n%s", closeErr, tDB.closeTimeoutDetails("mongo client", nil))
		}
	})

	return client.Database(tDB.databaseName), tDB
}
// connectMongoDBv2 connects to the test database on MongoDB with version 2 of the
// driver and pings it, retrying through retryConnect until it succeeds or the retry
// budget is exhausted.
//
// A client whose ping failed is disconnected before the next attempt so failed
// attempts do not pile up open clients. Log lines and the returned error use the
// password-free form of the URL.
func (d *testDB) connectMongoDBv2(ctx context.Context) (*mongo.Client, error) {
	var client *mongo.Client

	url := d.url.replaceDatabase(d.databaseName)

	err := d.retryConnect(ctx, url.string(true), func() error {
		candidate, connectErr := mongo.Connect(options.Client().ApplyURI(url.string(false)))
		if connectErr != nil {
			return fmt.Errorf("mongo connect: %w", connectErr)
		}

		if pingErr := candidate.Ping(ctx, nil); pingErr != nil {
			_ = candidate.Disconnect(ctx) // best effort: the ping error is the one reported
			return fmt.Errorf("mongo ping: %w", pingErr)
		}

		client = candidate

		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect mongo url (%s): %w", url.string(true), err)
	}

	return client, nil
}
package testdock
import (
"database/sql"
"fmt"
"testing"
_ "github.com/go-sql-driver/mysql" // mysql driver
)
// GetMySQLConn initializes a test MySQL database and applies migrations by delegating
// to GetSQLConn with the "mysql" driver.
// Use user root for docker test database.
//
// The Docker defaults (repository "mysql", image "9.1.0", root password and database
// taken from the DSN) are placed before opt, so caller options are applied after them.
// tb is failed when the DSN cannot be parsed.
func GetMySQLConn(tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	parsed, parseErr := parseURL(dsn)
	if parseErr != nil {
		tb.Fatalf("failed to parse dsn: %v", parseErr)
	}

	containerEnv := []string{
		fmt.Sprintf("MYSQL_ROOT_PASSWORD=%s", parsed.Password),
		fmt.Sprintf("MYSQL_DATABASE=%s", parsed.Database),
	}

	options := make([]Option, 0, len(opt))
	options = append(options, WithDockerRepository("mysql"))
	options = append(options, WithDockerImage("9.1.0"))
	options = append(options, WithDockerEnv(containerEnv))
	options = append(options, opt...)

	return GetSQLConn(tb, "mysql", dsn, options...)
}
package testdock
import (
"errors"
"fmt"
"os"
"strings"
"time"
"github.com/google/uuid"
"github.com/n-r-w/ctxlog"
)
//nolint:gosec // we use hardcoded credentials for testing purposes, which is not a security issue.
const (
// DefaultMongoDSN is the default MongoDB connection string.
DefaultMongoDSN = "mongodb://testuser:secret@127.0.0.1:27017/testdb?authSource=admin"
// DefaultMySQLDSN is the default MySQL connection string.
DefaultMySQLDSN = "root:secret@tcp(127.0.0.1:3306)/test_db"
// DefaultPostgresDSN is the default PostgreSQL connection string.
DefaultPostgresDSN = "postgres://postgres:secret@127.0.0.1:5432/postgres?sslmode=disable"
)
// RunMode defines the run mode of the test database.
type RunMode int
const (
// RunModeUnknown - unknown run mode.
RunModeUnknown RunMode = 0
// RunModeDocker - run the tests against a database started in docker.
RunModeDocker RunMode = 1
// RunModeExternal - run the tests against an external database.
RunModeExternal RunMode = 2
// RunModeAuto - checks the environment variable TESTDOCK_DSN_[DRIVER]
// (driver name in upper case). If it is set, the mode becomes RunModeExternal
// and the variable's value is used as the DSN; otherwise the mode becomes RunModeDocker.
// If TESTDOCK_DSN_[DRIVER] is set and RunModeAuto, WithDSN option is ignored.
// For example, for postgres pgx driver:
// TESTDOCK_DSN_PGX=postgres://postgres:secret@localhost:5432/postgres?sslmode=disable
RunModeAuto RunMode = 3
)
// Option configures a test database; each option is a function that is applied to the testDB being prepared.
type Option func(*testDB)
// WithMode selects how the test database is provided (docker, external or auto).
// The default is RunModeAuto.
func WithMode(mode RunMode) Option {
	setMode := func(d *testDB) { d.mode = mode }
	return setMode
}
// WithDockerRepository sets the docker hub repository name.
// It is required for RunModeDocker, and for RunModeAuto when the environment
// variable TESTDOCK_DSN_[DRIVER] is empty.
func WithDockerRepository(dockerRepository string) Option {
	setRepository := func(d *testDB) { d.dockerRepository = dockerRepository }
	return setRepository
}
// WithDockerImage sets the docker image name.
// When left empty, `latest` is used.
func WithDockerImage(dockerImage string) Option {
	setImage := func(d *testDB) { d.dockerImage = dockerImage }
	return setImage
}
// WithDockerSocketEndpoint sets the docker socket endpoint used to reach the docker daemon.
// The default is autodetect.
func WithDockerSocketEndpoint(dockerSocketEndpoint string) Option {
	setEndpoint := func(d *testDB) { d.dockerSocketEndpoint = dockerSocketEndpoint }
	return setEndpoint
}
// WithDockerPort sets the port used to connect to the database in docker.
// When it is not greater than zero, the port from the DSN is used.
func WithDockerPort(dockerPort int) Option {
	setPort := func(d *testDB) { d.dockerPort = dockerPort }
	return setPort
}
// WithRetryTimeout sets the timeout for connecting to the database.
// The default is 3 seconds. Must be less than totalRetryDuration.
func WithRetryTimeout(retryTimeout time.Duration) Option {
	setTimeout := func(d *testDB) { d.retryTimeout = retryTimeout }
	return setTimeout
}
// WithTotalRetryDuration sets the total time budget for connection retries.
// The default is 30 seconds. Must be greater than retryTimeout.
func WithTotalRetryDuration(totalRetryDuration time.Duration) Option {
	setDuration := func(d *testDB) { d.totalRetryDuration = totalRetryDuration }
	return setDuration
}
// WithCloseTimeout bounds how long cleanup waits for a returned resource to close.
// The default is 30 seconds. The timeout must be greater than 0.
// The timeout covers pgxpool.Pool.Close, sql.DB.Close, and mongo.Client.Disconnect.
// It does not cover SQL DROP DATABASE, MongoDB Drop, or Docker cleanup.
func WithCloseTimeout(closeTimeout time.Duration) Option {
	setTimeout := func(d *testDB) { d.closeTimeout = closeTimeout }
	return setTimeout
}
// WithLogger sets the logger used by the test database.
// The default is logger from testing.TB.
func WithLogger(logger ctxlog.ILogger) Option {
	setLogger := func(d *testDB) { d.logger = logger }
	return setLogger
}
// WithMigrations sets the migrations directory and the factory that builds the migrator.
// It also clears any target version set by an earlier WithMigrationsToVersion.
func WithMigrations(migrationsDir string, migrateFactory MigrateFactory) Option {
	return func(d *testDB) {
		d.migrationsDir, d.migrateFactory = migrationsDir, migrateFactory
		// No target version: drop whatever a previous option may have stored.
		d.hasMigrationTargetVersion, d.migrationTargetVersion = false, 0
	}
}
// WithMigrationsToVersion applies migrations up to and including the target version.
// The version is the numeric file prefix before "_", including timestamp prefixes.
// Custom factories must return a migrator that implements VersionedMigrator.
func WithMigrationsToVersion(migrationsDir string, migrateFactory MigrateFactory, version int64) Option {
	return func(d *testDB) {
		d.migrationsDir, d.migrateFactory = migrationsDir, migrateFactory
		d.hasMigrationTargetVersion, d.migrationTargetVersion = true, version
	}
}
// WithDockerEnv sets the environment variables passed to the docker container.
// The default is empty. The slice is stored as given, not copied.
func WithDockerEnv(dockerEnv []string) Option {
	setEnv := func(d *testDB) { d.dockerEnv = dockerEnv }
	return setEnv
}
// WithUnsetProxyEnv controls whether the proxy environment variables are unset.
// The default is false.
func WithUnsetProxyEnv(unsetProxyEnv bool) Option {
	setFlag := func(d *testDB) { d.unsetProxyEnv = unsetProxyEnv }
	return setFlag
}
// WithPrepareCleanUp adds a function that prepares the temporary test database for deletion.
// Handlers accumulate: each use of the option appends one more.
// The default is empty, but `GetPgxPool` and `GetPqConn` use it
// to automatically apply cleanup handlers to disconnect all users from the database
// before cleaning up.
func WithPrepareCleanUp(prepareCleanUp PrepareCleanUp) Option {
	addHandler := func(d *testDB) {
		d.prepareCleanUp = append(d.prepareCleanUp, prepareCleanUp)
	}
	return addHandler
}
// WithConnectDatabase sets the name of the database to connect to.
// By default the name is taken from the DSN; once this option is used the
// given value is kept as is, even when it is empty.
func WithConnectDatabase(connectDatabase string) Option {
	return func(d *testDB) {
		d.connectDatabase, d.connectDatabaseOverride = connectDatabase, true
	}
}
// prepareOptions applies options to d, validates the resulting configuration
// and derives the remaining fields: the effective run mode, the parsed URL,
// the password-free DSN, the connect database and the temporary database name.
// Checks run in a fixed order and the first failing one is returned.
func (d *testDB) prepareOptions(driver string, options []Option) error {
	for _, apply := range options {
		apply(d)
	}

	switch {
	case d.totalRetryDuration <= d.retryTimeout:
		return errors.New("totalRetryDuration must be greater than retryTimeout")
	case d.closeTimeout <= 0:
		return errors.New("closeTimeout must be greater than 0")
	case d.driver == "":
		return errors.New("driver is empty")
	}

	// RunModeAuto resolves to external when TESTDOCK_DSN_<DRIVER> is set (its
	// value then replaces the configured DSN) and to docker otherwise.
	if d.mode == RunModeAuto {
		envDSN := os.Getenv("TESTDOCK_DSN_" + strings.ToUpper(driver))
		if envDSN == "" {
			d.mode = RunModeDocker
		} else {
			d.dsn, d.mode = envDSN, RunModeExternal
		}
	}

	if d.dsn == "" {
		return errors.New("dsn is empty")
	}
	parsed, err := parseURL(d.dsn)
	if err != nil {
		return fmt.Errorf("parse dsn: %w", err)
	}
	d.url, d.dsnNoPass = parsed, parsed.string(true)

	// Fall back to the DSN database only when the caller never chose one.
	if d.connectDatabase == "" && !d.connectDatabaseOverride {
		d.connectDatabase = parsed.Database
	}

	if d.mode == RunModeDocker {
		if err = d.prepareDockerOptions(parsed); err != nil {
			return err
		}
	}

	// Temporary database name: "t_<timestamp>_<uuid>" with dashes removed.
	stamp := time.Now().Format("2006_0102_1504_05")
	d.databaseName = strings.ReplaceAll("t_"+stamp+"_"+uuid.New().String(), "-", "")

	factoryMissing, dirMissing := d.migrateFactory == nil, d.migrationsDir == ""
	if factoryMissing != dirMissing {
		return errors.New("MigrateFactory and migrationsDir must be set together")
	}
	if !d.hasMigrationTargetVersion {
		return nil
	}
	if dirMissing {
		return errors.New("migration target version requires migrationsDir and MigrateFactory")
	}
	if err = validateMigrationVersion(d.migrationTargetVersion); err != nil {
		return fmt.Errorf("migration target version: %w", err)
	}
	return nil
}
// prepareDockerOptions validates and fills Docker-specific options:
// the repository must be set, an empty image becomes "latest", and a
// non-positive port is replaced by the port of p, which must then be positive.
func (d *testDB) prepareDockerOptions(p *dbURL) error {
	if d.dockerRepository == "" {
		return errors.New("dockerRepository is empty")
	}
	if d.dockerImage == "" {
		d.dockerImage = "latest"
	}

	// An explicitly configured port wins; otherwise take the one from the DSN.
	if d.dockerPort <= 0 {
		d.dockerPort = p.Port
	}
	if d.dockerPort <= 0 {
		return errors.New("dockerPort must be greater than 0")
	}
	return nil
}
package testdock
import (
"context"
"database/sql"
"fmt"
"testing"
"github.com/jackc/pgx/v5/pgxpool"
_ "github.com/jackc/pgx/v5/stdlib" // pgx postgres driver
_ "github.com/lib/pq" // pq postgres driver
)
// GetPgxPool inits a test postgresql (pgx driver) database, applies migrations,
// and returns pgx connection pool to the database.
//
// A cleanup handler closes the pool within the configured close timeout; when
// the close times out, the test is failed with the pool statistics attached.
func GetPgxPool(tb testing.TB, dsn string, opt ...Option) (*pgxpool.Pool, Informer) {
	tb.Helper()

	ctx := context.Background()
	tDB := newTDB(ctx, tb, "pgx", dsn, getPostgresOptions(tb, dsn, opt...))

	pool, err := tDB.connectPgxDB(ctx)
	if err != nil {
		tb.Fatalf("cannot connect to postgres: %v", err)
	}

	// pgxpool's Close returns nothing, so adapt it to the error-returning shape.
	closePool := func() error {
		pool.Close()
		return nil
	}
	timeoutDetails := func() string {
		return tDB.closeTimeoutDetails("pgxpool", snapshotPgxPoolStats(pool))
	}
	tb.Cleanup(func() {
		closeErr := closeResourceWithTimeout(tDB.closeTimeout, closePool, timeoutDetails)
		if closeErr != nil {
			tb.Errorf("%v", closeErr)
		}
	})

	return pool, tDB
}
// GetPqConn inits a test postgresql (pq driver) database, applies migrations,
// and returns sql connection to the database.
//
// A cleanup handler closes the connection within the configured close timeout
// and fails the test when the close times out.
func GetPqConn(ctx context.Context, tb testing.TB, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	tDB := newTDB(ctx, tb, "postgres", dsn, getPostgresOptions(tb, dsn, opt...))

	conn, err := tDB.connectSQLDB(ctx, true)
	if err != nil {
		tb.Fatalf("cannot connect to postgres: %v", err)
	}

	timeoutDetails := func() string {
		return tDB.closeTimeoutDetails("postgres sql connection", nil)
	}
	tb.Cleanup(func() {
		closeErr := closeResourceWithTimeout(tDB.closeTimeout, conn.Close, timeoutDetails)
		if closeErr != nil {
			tb.Errorf("%v", closeErr)
		}
	})

	return conn, tDB
}
// snapshotPgxPoolStats copies the pgxpool counters used in close-timeout
// diagnostics out of a single pool.Stat() reading.
func snapshotPgxPoolStats(pool *pgxpool.Pool) *pgxPoolCloseStats {
	current := pool.Stat()

	snapshot := new(pgxPoolCloseStats)
	snapshot.AcquiredConns = current.AcquiredConns()
	snapshot.IdleConns = current.IdleConns()
	snapshot.TotalConns = current.TotalConns()
	snapshot.ConstructingConns = current.ConstructingConns()
	snapshot.MaxConns = current.MaxConns()
	return snapshot
}
// connectPgxDB connects to the temporary test database using pgx, running
// each create-pool-and-ping attempt through d.retryConnect.
//
// A pool whose ping fails is closed before the attempt returns. The URL in
// the log line and in the returned error has its password masked.
func (d *testDB) connectPgxDB(ctx context.Context) (*pgxpool.Pool, error) {
	target := d.url.replaceDatabase(d.databaseName)
	redacted := target.string(true)
	d.logger.Info(ctx, "connecting to test database", "url", redacted)

	var pool *pgxpool.Pool
	err := d.retryConnect(ctx, redacted, func() error {
		p, err := pgxpool.New(ctx, target.string(false))
		if err != nil {
			return err
		}
		if err = p.Ping(ctx); err != nil {
			p.Close()
			return err
		}
		// Publish the pool only after it has answered a ping.
		pool = p
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect postgres url (%s): %w", redacted, err)
	}
	return pool, nil
}
// disconnectUsers terminates every backend connected to databaseName except
// the one executing the statement, so the database can be deleted afterwards.
// The statement runs with a background context.
func disconnectUsers(db *sql.DB, databaseName string) error {
	const terminateBackends = `SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE datname = $1 AND pid <> pg_backend_pid()`

	if _, err := db.ExecContext(context.Background(), terminateBackends, databaseName); err != nil {
		return err
	}
	return nil
}
// getPostgresOptions returns the options for the postgresql database:
// postgres presets derived from the DSN, followed by the caller's options.
func getPostgresOptions(tb testing.TB, dsn string, opt ...Option) []Option {
	tb.Helper()

	parsed, err := parseURL(dsn)
	if err != nil {
		tb.Fatalf("failed to parse dsn: %v", err)
	}

	dockerEnv := []string{
		fmt.Sprintf("POSTGRES_USER=%s", parsed.User),
		fmt.Sprintf("POSTGRES_PASSWORD=%s", parsed.Password),
		fmt.Sprintf("POSTGRES_DB=%s", parsed.Database),
		// NOTE(review): the next two entries look like postgresql.conf settings
		// rather than KEY=VALUE environment variables - confirm they take effect.
		"listen_addresses = '*'",
		"max_connections = 1000",
	}
	presets := []Option{
		WithDockerRepository("postgres"),
		WithPrepareCleanUp(disconnectUsers),
		WithDockerEnv(dockerEnv),
	}

	return append(presets, opt...)
}
package testdock
import (
"context"
"database/sql"
"fmt"
"testing"
)
// GetSQLConn inits a test database, applies migrations, and returns sql connection to the database.
// driver: https://go.dev/wiki/SQLDrivers.
// Do not forget to import corresponding driver package.
//
// A cleanup handler closes the connection within the configured close timeout
// and fails the test when the close times out.
func GetSQLConn(tb testing.TB, driver, dsn string, opt ...Option) (*sql.DB, Informer) {
	tb.Helper()

	ctx := context.Background()
	tDB := newTDB(ctx, tb, driver, dsn, opt)

	conn, err := tDB.connectSQLDB(ctx, true)
	if err != nil {
		tb.Fatalf("cannot connect to database: %v", err)
	}

	timeoutDetails := func() string {
		return tDB.closeTimeoutDetails("sql connection", nil)
	}
	tb.Cleanup(func() {
		closeErr := closeResourceWithTimeout(tDB.closeTimeout, conn.Close, timeoutDetails)
		if closeErr != nil {
			tb.Errorf("%v", closeErr)
		}
	})

	return conn, tDB
}
// connectSQLDB connects to the database using database/sql, running each
// open-and-ping attempt through d.retryConnect.
// testDatabase: if true, will be connected to the temporary test database;
// otherwise to the configured connect database.
//
// A handle whose ping fails is closed before the attempt returns. The URL in
// the log line and in the returned error has its password masked.
func (d *testDB) connectSQLDB(ctx context.Context, testDatabase bool) (*sql.DB, error) {
	database := d.connectDatabase
	if testDatabase {
		database = d.databaseName
	}
	target := d.url.replaceDatabase(database)
	redacted := target.string(true)
	d.logger.Info(ctx, "connecting to test database", "url", redacted)

	var conn *sql.DB
	err := d.retryConnect(ctx, redacted, func() error {
		c, err := sql.Open(d.driver, target.string(false))
		if err != nil {
			return err
		}
		if err = c.PingContext(ctx); err != nil {
			// The ping error is the one worth reporting; Close is best effort.
			_ = c.Close()
			return err
		}
		// Publish the handle only after it has answered a ping.
		conn = c
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("connect url (%s): %w", redacted, err)
	}
	return conn, nil
}
// createSQLDatabase creates the temporary test database. It connects to the
// configured connect database, issues CREATE DATABASE for d.databaseName and
// closes that setup connection before returning.
func (d *testDB) createSQLDatabase(ctx context.Context) error {
	d.logger.Info(ctx, "creating new test sql database", "dsn", d.dsnNoPass, "database", d.databaseName)

	setupConn, err := d.connectSQLDB(ctx, false)
	if err != nil {
		return err
	}
	defer setupConn.Close() //nolint:errcheck // Close only releases setup connection; keep ExecContext result.

	createStmt := fmt.Sprintf("CREATE DATABASE %s", d.databaseName)
	if _, execErr := setupConn.ExecContext(ctx, createStmt); execErr != nil {
		return fmt.Errorf("create db: %w", execErr)
	}

	d.logger.Info(ctx, "new test sql database created", "dsn", d.dsnNoPass, "database", d.databaseName)
	return nil
}
package testdock
import (
"errors"
"fmt"
"maps"
"slices"
"strconv"
"strings"
)
// dbURL represents a parsed database connection string.
// Supported connection string format:
// [protocol://][user:password@][transport(]host:port[)][/database][?option1=a&option2=b]
//
// Required fields: host, port.
// Optional fields: protocol, credentials (user and password, always given
// together), transport, database and options.
type dbURL struct {
	Protocol  string
	Transport string
	User      string
	Password  string
	Host      string
	Port      int
	Database  string
	Options   map[string]string // option1=a&option2=b -> {"option1": "a", "option2": "b"}
}

// parseURL parses a connection string into a dbURL.
//
// It returns an error when the string is empty, when "://" is present without
// a protocol, when credentials are present but lack a user or a password, when
// the host or the port is missing, or when the port is not an integer in the
// range 1-65535. Options without "=" are skipped.
func parseURL(connStr string) (*dbURL, error) {
	// maxPort is the largest valid TCP/UDP port number.
	const maxPort = 65535

	if connStr == "" {
		return nil, errors.New("connection string cannot be empty")
	}

	u := &dbURL{Options: make(map[string]string)}
	rest := connStr

	// Protocol: everything before the first "://".
	if protocol, tail, found := strings.Cut(connStr, "://"); found {
		if protocol == "" {
			return nil, errors.New("invalid connection string format: '://' exists, but no protocol")
		}
		u.Protocol, rest = protocol, tail
	}

	// Credentials: split on the last "@" so that "@" inside a password is
	// kept; the first ":" separates user from password.
	if at := strings.LastIndex(rest, "@"); at >= 0 {
		user, password, found := strings.Cut(rest[:at], ":")
		switch {
		case !found:
			return nil, errors.New("invalid connection string format: missing password")
		case user == "":
			return nil, errors.New("user is required")
		case password == "":
			return nil, errors.New("password is required")
		}
		u.User, u.Password = user, password
		rest = rest[at+1:]
	}

	// Options: "key=value" pairs after the first "?", separated by "&".
	var query string
	rest, query, _ = strings.Cut(rest, "?")
	for query != "" {
		var param string
		param, query, _ = strings.Cut(query, "&")
		if key, value, ok := strings.Cut(param, "="); ok {
			u.Options[key] = value
		}
	}

	// Database: everything after the first "/".
	rest, u.Database, _ = strings.Cut(rest, "/")

	// Transport: "transport(host:port)".
	if transport, inner, found := strings.Cut(rest, "("); found && strings.HasSuffix(rest, ")") {
		u.Transport = transport
		rest = strings.TrimSuffix(inner, ")")
	}

	if rest == "" {
		return nil, errors.New("host is required")
	}

	// Host and port, separated by the first ":".
	host, portStr, found := strings.Cut(rest, ":")
	switch {
	case !found:
		return nil, errors.New("invalid connection string format: missing port")
	case host == "":
		return nil, errors.New("host is required")
	case portStr == "":
		return nil, errors.New("port is required")
	}

	port, err := strconv.Atoi(portStr)
	if err != nil {
		return nil, fmt.Errorf("parse port: %w", err)
	}
	if port <= 0 {
		return nil, errors.New("port must be positive")
	}
	if port > maxPort {
		return nil, errors.New("port must not exceed 65535")
	}

	u.Host, u.Port = host, port
	return u, nil
}

// string returns the connection string representation of the URL.
// When hidePassword is true the password is written as "*****".
// Options are written sorted by key so the output is deterministic.
// A nil receiver yields an empty string.
func (u *dbURL) string(hidePassword bool) string {
	if u == nil {
		return ""
	}

	var b strings.Builder
	write := func(parts ...string) {
		for _, part := range parts {
			_, _ = b.WriteString(part)
		}
	}

	if u.Protocol != "" {
		write(u.Protocol, "://")
	}

	// Credentials are written only when a user is present.
	if u.User != "" {
		password := u.Password
		if hidePassword {
			password = "*****"
		}
		write(u.User, ":", password, "@")
	}

	// Transport wraps "host[:port]" in parentheses.
	if u.Transport != "" {
		write(u.Transport, "(")
	}
	write(u.Host)
	if u.Port != 0 {
		write(":", strconv.Itoa(u.Port))
	}
	if u.Transport != "" {
		write(")")
	}

	if u.Database != "" {
		write("/", u.Database)
	}

	if len(u.Options) > 0 {
		keys := make([]string, 0, len(u.Options))
		for k := range u.Options {
			keys = append(keys, k)
		}
		slices.Sort(keys)

		separator := "?"
		for _, k := range keys {
			write(separator, k, "=", u.Options[k])
			separator = "&"
		}
	}

	return b.String()
}

// clone returns a copy of the URL whose Options map is independent of the
// original. A nil receiver yields nil.
func (u *dbURL) clone() *dbURL {
	if u == nil {
		return nil
	}

	// Copy the scalar fields by value, then replace the shared map.
	copied := *u
	copied.Options = make(map[string]string, len(u.Options))
	maps.Copy(copied.Options, u.Options)
	return &copied
}

// replaceDatabase returns a copy of the URL with the database name replaced;
// the receiver is left unchanged. A nil receiver yields nil, matching the
// nil handling of string and clone.
func (u *dbURL) replaceDatabase(newDBName string) *dbURL {
	if u == nil {
		return nil
	}
	replaced := u.clone()
	replaced.Database = newDBName
	return replaced
}